diff --git a/CHANGELOG.md b/CHANGELOG.md index 55621b37ecee2..fdd02a9a79009 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,489 @@ # TiDB Changelog All notable changes to this project will be documented in this file. See also [Release Notes](https://github.com/pingcap/docs/blob/master/releases/rn.md), [TiKV Changelog](https://github.com/tikv/tikv/blob/master/CHANGELOG.md) and [PD Changelog](https://github.com/pingcap/pd/blob/master/CHANGELOG.md). +## [3.0.5] 2019-10-23 +## SQL Optimizer +* Support boundary checking on Window Functions [#12404](https://github.com/pingcap/tidb/pull/12404) +* Fix the issue that `IndexJoin` on the partition table returns incorrect results [#12712](https://github.com/pingcap/tidb/pull/12712) +* Fix the issue that the `ifnull` function on the top of the outer join `Apply` operator returns incorrect results [#12694](https://github.com/pingcap/tidb/pull/12694) +* Fix the issue of update failure when a subquery was included in the `where` condition of `UPDATE` [#12597](https://github.com/pingcap/tidb/pull/12597) +* Fix the issue that outer join was incorrectly converted to inner join when the `cast` function was included in the query conditions [#12790](https://github.com/pingcap/tidb/pull/12790) +* Fix incorrect expression passing in the join condition of `AntiSemiJoin` [#12799](https://github.com/pingcap/tidb/pull/12799) +* Fix the statistics error caused by shallow copy when initializing statistics [#12817](https://github.com/pingcap/tidb/pull/12817) +* Fix the issue that the `str_to_date` function in TiDB returns a different result from MySQL when the date string and the format string do not match [#12725](https://github.com/pingcap/tidb/pull/12725) + +## SQL Execution Engine +* Fix the panic issue when the `from_unixtime` function handles null [#12551](https://github.com/pingcap/tidb/pull/12551) +* Fix the `invalid list index` error reported when canceling DDL jobs [#12671](https://github.com/pingcap/tidb/pull/12671) +* Fix the issue that arrays were out of bounds when Window Functions are used [#12660](https://github.com/pingcap/tidb/pull/12660) +* Improve the behavior of the `AutoIncrement` column when it is implicitly allocated, to keep it consistent with the default mode of MySQL auto-increment locking (["consecutive" lock mode](https://dev.mysql.com/doc/refman/5.7/en/innodb-auto-increment-handling.html)): for the implicit allocation of multiple `AutoIncrement` IDs in a single-line `Insert` statement, TiDB guarantees the continuity of the allocated values. This improvement ensures that the JDBC `getGeneratedKeys()` method will get the correct results in any scenario. [#12602](https://github.com/pingcap/tidb/pull/12602) +* Fix the issue that the query is hanged when `HashAgg` serves as a child node of `Apply` [#12766](https://github.com/pingcap/tidb/pull/12766) +* Fix the issue that the `AND` and `OR` logical expressions return incorrect results when it comes to type conversion [#12811](https://github.com/pingcap/tidb/pull/12811) + +## Server +* Implement the interface function that modifies transaction TTL to help support large transactions later [#12397](https://github.com/pingcap/tidb/pull/12397) +* Support extending the transaction TTL as needed (up to 10 minutes) to support pessimistic transactions [#12579](https://github.com/pingcap/tidb/pull/12579) +* Adjust the number of times that TiDB caches schema changes and corresponding changed table information from 100 to 1024, and support modification by using the `tidb_max_delta_schema_count` system variable [#12502](https://github.com/pingcap/tidb/pull/12502) +* Update the behavior of the `kvrpc.Cleanup` protocol to no longer clean locks of transactions that are not overtime [#12417](https://github.com/pingcap/tidb/pull/12417) +* Support logging Partition table information to the `information_schema.tables` table [#12631](https://github.com/pingcap/tidb/pull/12631) +* Support modifying the TTL of Region Cache by configuring `region-cache-ttl` [#12683](https://github.com/pingcap/tidb/pull/12683) +* Support printing the execution plan compression-encoded information in the slow log. This feature is enabled by default and can be controlled by using the `slow-log-plan` configuration or the `tidb_record_plan_in_slow_log` variable. In addition, the `tidb_decode_plan` function can parse the execution plan column encoding information in the slow log into execution plan information. [#12808](https://github.com/pingcap/tidb/pull/12808) +* Support displaying memory usage information in the `information_schema.processlist` table [#12801](https://github.com/pingcap/tidb/pull/12801) +* Fix the issue that an error and an unexpected alarm might occur when the TiKV Client judges an idle connection [#12846](https://github.com/pingcap/tidb/pull/12846) +* Fix the issue that the `INSERT IGNORE` statement performance is decreased because `tikvSnapshot` does not properly cache the KV results of `BatchGet()` [#12872](https://github.com/pingcap/tidb/pull/12872) +* Fix the issue that the TiDB response speed was relatively low because of slow connection to some KV services [#12814](https://github.com/pingcap/tidb/pull/12814) + +## DDL +* Fix the issue that the `Create Table` operation does not correctly set the Int type default value for the Set column [#12267](https://github.com/pingcap/tidb/pull/12267) +* Support multiple `unique`s when creating a unique index using the `Create Table` statement [#12463](https://github.com/pingcap/tidb/pull/12463) +* Fix the issue that populating the default value of this column for existing rows might cause an error when adding a Bit type column using `Alter Table` [#12489](https://github.com/pingcap/tidb/pull/12489) +* Fix the failure of adding a partition when the Range partitioned table uses a Date or Datetime type column as the partitioning key [#12815](https://github.com/pingcap/tidb/pull/12815) +* Support checking the consistency of the partition type and the partition key type when creating a table or adding a partition, for the Range partitioned table with the Date or Datetime type column as the partition key [#12792](https://github.com/pingcap/tidb/pull/12792) +* Add a check that the Unique Key column set needs to be greater than or equal to the partitioned column set when creating a Range partitioned table [#12718](https://github.com/pingcap/tidb/pull/12718) + +## Monitor +* Add the monitoring metrics of Commit and Rollback operations to the `Transaction OPS` dashboard [#12505](https://github.com/pingcap/tidb/pull/12505) +* Add the monitoring metrics of `Add Index` operation progress [#12390](https://github.com/pingcap/tidb/pull/12390) + +## [3.0.4] 2019-10-08 +## New features +* Add system table `performance_schema.events_statements_summary_by_digest` to troubleshoot performance issues at SQL level +* Add the `WHERE` clause in TiDB’s `SHOW TABLE REGIONS` syntax +* Add the `worker-count` and `txn-batch` configuration items in Reparo to control the recovery speed + +## Improvements +* Support `Split` commands in batches and the empty `Split` command in TiKV to split Regions in batches +* Support double linked list for RocksDB in TiKV to improve performance of reverse scan +* Add two perf tools -`iosnoop` and `funcslower` in TiDB Ansible to better diagnose the cluster state +* Optimize the output of slow query logs in TiDB by deleting redundant fields + +## Changed behaviors +* Update the default value of `txn-local-latches.enable` to `false` to disable the default behaviour of checking conflicts of local transactions in TiDB +* Add the `tidb_txn_mode` system variable of global scope in TiDB and allow using pessimistic lock by default. Note that TiDB still adopts optimistic lock by default. +* Replace the `Index_ids` field in TiDB slow query logs with `Index_names` to improve the usability of slow query logs +* Add the `split-region-max-num` parameter in the TiDB configuration file to modify the maximum number of Regions allowed in the `SPLIT TABLE` syntax +* Return the `Out Of Memory Quota` error instead of disconnect links in TiDB when a SQL execution exceeds the memory limit +* Disallow dropping the `AUTO INCREMENT` attribute of columns in TiDB to avoid misoperations. To drop this attribute, change the `tidb_allow_remove_auto_inc` system * variable. + +## Fixed Issues +* Fix the issue in TiDB that the special syntax `PRE_SPLIT_REGIONS` does not replicate data to the downstream by using notes +* Fix the issue in TiDB that the slow query logs are incorrect when getting the result of `PREPARE` + `EXECUTE` by using the cursor +* Fix the issue in PD that adjacent small Regions cannot be merged +* Fix the issue that too many file descriptors are opened in idle clusters +* Fix the issue that + +## Contributors +Our thanks go to the following contributors from the community for helping this release: +* [sduzh](https://github.com/sduzh) +* [lizhenda](https://github.com/lizhenda) + +## SQL Optimizer +* Fix the issue that invalid query ranges might be resulted when splitted by feedback [#12170](https://github.com/pingcap/tidb/pull/12170) +* Display the returned error of the `SHOW STATS_BUCKETS` statement in hexadecimal when the result contains invalid Keys [#12094](https://github.com/pingcap/tidb/pull/12094) +* Fix the issue that when a query contains the `SLEEP` function (for example, `select 1 from (select sleep(1)) t;)`), column pruning causes invalid `sleep(1)` during query [#11953](https://github.com/pingcap/tidb/pull/11953) +* Use index scan to lower IO when a query only concerns the number of columns rather than the table data [#12112](https://github.com/pingcap/tidb/pull/12112) +* Do not use any index when no index is specified in `use index()` to becompatible with MySQL [#12100](https://github.com/pingcap/tidb/pull/12100) +* Strictly limit the number of `TopN` records in the `CMSketch` statistics to fix the issue that the `ANALYZE` statement fails because the statement count exceeds TiDB’s limit on the size of a transaction [#11914](https://github.com/pingcap/tidb/pull/11914) +* Fix the error occurred when converting the subqueries contained in the `Update` statement [#12483](https://github.com/pingcap/tidb/pull/12483) +* Optimize execution performance of the `select ... limit ... offset ...` statement by pushing the Limit operator down to the `IndexLookUpReader` execution logic [#12378](https://github.com/pingcap/tidb/pull/12378) + +## SQL Execution Engine +* Print the SQL statement in the log when the `PREPARED` statement is incorrectly executed [#12191](https://github.com/pingcap/tidb/pull/12191) +* Support partition pruning when the `UNIX_TIMPESTAMP` function is used to implement partitioning [#12169](https://github.com/pingcap/tidb/pull/12169) +* Fix the issue that no error is reported when `AUTO INCREMENT` incorrectly allocates `MAX int64` and `MAX uint64` [#12162](https://github.com/pingcap/tidb/pull/12162) +* Add the `WHERE` clause in the `SHOW TABLE … REGIONS` and `SHOW TABLE .. INDEX … REGIONS` syntaxes [#12123](https://github.com/pingcap/tidb/pull/12123) +* Return the `Out Of Memory Quota` error instead of disconnect the link when a SQL execution exceeds the memory limit [#12127](https://github.com/pingcap/tidb/pull/12127) +* Fix the issue that incorrect result is returned when `JSON_UNQUOTE` function handles JSON text [#11955](https://github.com/pingcap/tidb/pull/11955) +* Fix the issue that `LAST INSERT ID` is incorrect when assigning values to the `AUTO_INCREMENT` column in the first row (for example, `insert into t (pk, c) values (1, 2), (NULL, 3)`) [#12002](https://github.com/pingcap/tidb/pull/12002) +* Fix the issue that the `GROUPBY` parsing rule is incorrect in the `PREPARE` statement [#12351](https://github.com/pingcap/tidb/pull/12351) +* Fix the issue that the privilege check is incorrect in point queries [#12340](https://github.com/pingcap/tidb/pull/12340) +* Fix the issue that the duration by sql_type for the `PREPARE` statement is not shown in the monitoring record [#12331](https://github.com/pingcap/tidb/pull/12331) +* Support using aliases for tables in the point queries (for example, `select * from t tmp where a = "aa"`) [#12282](https://github.com/pingcap/tidb/pull/12282) +* Fix the error occurred when not handling negative values as unsigned when inserting negative numbers into BIT type columns [#12423](https://github.com/pingcap/tidb/pull/12423) +* Fix the incorrectly rounding of time (for example, `2019-09-11 11:17:47.999999666` should be rounded to `2019-09-11 11:17:48`.) [#12258](https://github.com/pingcap/tidb/pull/12258) +* Refine the usage of expression blacklist (for example, `<` is equivalent to `lt`.) [#11975](https://github.com/pingcap/tidb/pull/11975) +* Add the database prefix to the function name in the error message of non-existing functions (for example, `[expression:1305]FUNCTION test.std_samp does not exist`) [#12111](https://github.com/pingcap/tidb/pull/12111) + +## Server +* Add the `Prev_stmt` field in slow query logs to output the previous statement when the last statement is `COMMIT` [#12180](https://github.com/pingcap/tidb/pull/12180) +* Optimize the output of slow query logs by deleting redundant fields [#12144](https://github.com/pingcap/tidb/pull/12144) +* Modify the default value of `txn-local-latches.enable` to `false` to disable the check for conflicts of local transactions in TiDB [#12095](https://github.com/pingcap/tidb/pull/12095) +* Replace the `Index_ids` field in TiDB slow query logs with `Index_names` to improve the usability of slow query logs [#12061](https://github.com/pingcap/tidb/pull/12061) +* Add the global system variable `tidb_txn_mode` in TiDB and allow using pessimistic lock by default [#12049](https://github.com/pingcap/tidb/pull/12049) +* Add the `Backoff` field in the slow query logs to record the Backoff information in the commit phase of 2PC [#12335](https://github.com/pingcap/tidb/pull/12335) +* Fix the issue that the slow query logs are incorrect when getting the result of `PREPARE` + `EXECUTE` by using the cursor (for example, `PREPARE stmt1FROM SELECT * FROM t WHERE a > ?; EXECUTE stmt1 USING @variable`) [#12392](https://github.com/pingcap/tidb/pull/12392) +* Support `tidb_enable_stmt_summary`. When this feature is enabled, TiDB counts the SQL statements and the result can be queried by using the system table `performance_schema.events_statements_summary_by_digest`. [#12308](https://github.com/pingcap/tidb/pull/12308) +* Adjust the level of some logs in tikv-client (for example, change the log level of `batchRecvLoop fails` from `ERROR` to `INFO`) [#12383](https://github.com/pingcap/tidb/pull/12383) + +## DDL +* Add the variable of `tidb_allow_remove_auto_inc`. Dropping the `AUTO INCREMENT` attribute of the column is disabled by default [#12145](https://github.com/pingcap/tidb/pull/12145) +* Fix the issue that the uncommented TiDB-specific syntax `PRE_SPLIT_REGIONS` might cause t errors in the downstream database during data replication [#12120](https://github.com/pingcap/tidb/pull/12120) +* Add the `split-region-max-num` variable in the configuration file so that the maximum allowable number of Regions is adjustable [#12097](https://github.com/pingcap/tidb/pull/12079) +* Support splitting a Region into multiple Regions and fix the timeout issue during Region scattering [#12343](https://github.com/pingcap/tidb/pull/12343) +* Fix the issue that the `drop index` statement fails when the index that contains an `auto_increment` column referenced byanother two indexes [#12344](https://github.com/pingcap/tidb/pull/12344) + +## Monitor +* Add the `connection_transient_failure_count` monitoring metric to count the number of gRPC connection errors in `tikvclient` [#12093](https://github.com/pingcap/tidb/pull/12093) + +## [3.0.3] 2019-08-29 +### SQL Optimizer +* Add the `opt_rule_blacklist` table to disable logic optimization rules such as `aggregation_eliminate` and `column_prune` [#11658](https://github.com/pingcap/tidb/pull/11658) +* Fix the issue that incorrect results might be returned for `Index Join` when the join key uses a prefix index or an unsigned index column that is equal to a negative value [#11759](https://github.com/pingcap/tidb/pull/11759) +* Fix the issue that `"` or `\` in the `SELECT` statements of `create ... binding ...` might result in parsing errors [#11726](https://github.com/pingcap/tidb/pull/11726) + +### SQL Execution Engine +* Fix the issue that type errors in the return value might occur when the `Quote` function handles a null value [#11619](https://github.com/pingcap/tidb/pull/11619) +* Fix the issue that incorrect results for `Ifnull` might be returned when `Max`/`min` is used for type inferring with `NotNullFlag` retained [#11641](https://github.com/pingcap/tidb/pull/11641) +* Fix the potential error that occurs when comparing bit type data in string form [#11660](https://github.com/pingcap/tidb/pull/11660) +* Decrease the concurrency for data that requires sequential read to reduce the possibility of OOM [#11679](https://github.com/pingcap/tidb/pull/11679) +* Fix the issue that incorrect type inferring might be caused when multiple parameters are unsigned for some built-in functions (e.g. `If`, `Coalesce`) [#11621](https://github.com/pingcap/tidb/pull/11621) +* Fix the incompatibility with MySQL when the `Div` function handles unsigned decimal types [#11813](https://github.com/pingcap/tidb/pull/11813) +* Fix the issue that panic might occur when executing SQL statements that modify the status of Pump/Drainer [#11827](https://github.com/pingcap/tidb/pull/11827) +* Fix the issue that panic might occur for `select ... for update` when Autocommit = 1 and there is no `begin` statement [#11736](https://github.com/pingcap/tidb/pull/11736) +* Fix the permission check error that might occur when the `set default role` statement is executed [#11777](https://github.com/pingcap/tidb/pull/11777) +* Fix the permission check error that might occur when `create user` or `drop user` is executed [#11814](https://github.com/pingcap/tidb/pull/11814) +* Fix the issue that the `select ... for update` statement might auto retry when it is constructed into the `PointGetExecutor` function [#11718](https://github.com/pingcap/tidb/pull/11718) +* Fix the boundary error that might occur when the Window function handles partition [#11825](https://github.com/pingcap/tidb/pull/11825) +* Fix the issue that the `Time` function hits EOF errors when handling an incorrectly formatted argument [#11893](https://github.com/pingcap/tidb/pull/11893) +* Fix the issue that the Window function does not check the passed-in parameters [#11705](https://github.com/pingcap/tidb/pull/11705) +* Fix the issue that the plan result viewed via `Explain` is inconsistent with the actually executed plan [#11186](https://github.com/pingcap/tidb/pull/11186) +* Fix the issue that duplicate memory referenced by the Window function might result in a crash or incorrect results [#11823](https://github.com/pingcap/tidb/pull/11823) +* Update the incorrect information in the `Succ` field in the slow log [#11887](https://github.com/pingcap/tidb/pull/11887) + +### Server +* Rename the `tidb_back_off_wexight` variable to `tidb_backoff_weight` [#11665](https://github.com/pingcap/tidb/pull/11665) +* Update the minimum TiKV version compatible with the current TiDB to v3.0.0 [#11618](https://github.com/pingcap/tidb/pull/11618) +* Support `make testSuite` to ensure the suites in the test are correctly used [#11685](https://github.com/pingcap/tidb/pull/11685) + +### DDL +* Skip the execution of unsupported partition-related DDL statements, including statements that modify the partition type while deleting multiple partitions [#11373](https://github.com/pingcap/tidb/pull/11373) +* Disallow a Generated Column to be placed before its dependent columns [#11686](https://github.com/pingcap/tidb/pull/11686) +* Modify the default values of `tidb_ddl_reorg_worker_cnt` and `tidb_ddl_reorg_batch_size` [#11874](https://github.com/pingcap/tidb/pull/11874) + +### Monitor +* Add new backoff monitoring types to record duration for each backoff type; add more backoff metrics to cover previously uncounted types such as commit backoff [#11728](https://github.com/pingcap/tidb/pull/11728) + + +## [3.0.2] 2019-08-06 +### SQL Optimizer +* Fix the issue that the "Can't find column in schema" message is reported when the same table occurs multiple times in a query and logically the query result is always empty [#11247](https://github.com/pingcap/tidb/pull/11247) +* Fix the issue that the query plan does not meet the expectation caused by the `TIDB_INLJ` hint not working correctly in some cases (like `explain select /*+ TIDB_INLJ(t1) */ t1.b, t2.a from t t1, t t2 where t1.b = t2.a`) [#11362](https://github.com/pingcap/tidb/pull/11362) +* Fix the issue that the column name in the query result is wrong in some cases (like `SELECT IF(1,c,c) FROM t`) [#11379](https://github.com/pingcap/tidb/pull/11379) +* Fix the issue that some queries like `SELECT 0 LIKE 'a string'` return `TRUE` becausethe `LIKE` expression is implicitly converted to 0 in some cases [#11411](https://github.com/pingcap/tidb/pull/11411) +* Support sub-queries in the `SHOW` statement, like `SHOW COLUMNS FROM tbl WHERE FIELDS IN (SELECT 'a')` [#11459](https://github.com/pingcap/tidb/pull/11459) +* Fix the issue that the related column of the aggregate function cannot be found and an error is reported caused by the `outerJoinElimination` optimizing rule not correctly handling the column alias; improve alias parsing in the optimizing process to make optimization cover more query types [#11377](https://github.com/pingcap/tidb/pull/11377) +* Fix the issue that no error is reported when the syntax restriction is violated in the Window function (for example, `UNBOUNDED PRECEDING` is not allowed to appear at the end of the Frame definition) [#11543](https://github.com/pingcap/tidb/pull/11543) +* Fix the issue that `FUNCTION_NAME` is in uppercase in the `ERROR 3593 (HY000): You cannot use the window function FUNCTION_NAME in this context` error message, which causes incompatibility with MySQL [#11535](https://github.com/pingcap/tidb/pull/11535) +* Fix the issue that the unimplemented `IGNORE NULLS` syntax in the Window function is used but no error is reported [#11593](https://github.com/pingcap/tidb/pull/11593) +* Fix the issue that the Optimizer does not correctly estimate time equal conditions [#11512](https://github.com/pingcap/tidb/pull/11512) +* Support updating the Top-N statistics based on the feedback information [#11507](https://github.com/pingcap/tidb/pull/11507) + +### SQL Execution Engine +* Fix the issue that the returned value is not `NULL` when the `INSERT` function contains `NULL` in parameters [#11248](https://github.com/pingcap/tidb/pull/11248) +* Fix the issue that the computing result might be wrong when the partitioned table is checked by the `ADMIN CHECKSUM` operation [#11266](https://github.com/pingcap/tidb/pull/11266) +* Fix the issue that the result might be wrong when INDEX JOIN uses the prefix index [#11246](https://github.com/pingcap/tidb/pull/11246) +* Fix the issue that result might be wrong caused by incorrectly aligning fractions when the `DATE_ADD` function does subtraction on date numbers involving microseconds [#11288](https://github.com/pingcap/tidb/pull/11288) +* Fix the wrong result caused by the `DATE_ADD` function incorrectly processing the negative numbers in `INTERVAL` [#11325](https://github.com/pingcap/tidb/pull/11325) +* Fix the issue that the number of fractional digits returned by `Mod`(`%`), `Multiple`(`*`) or `Minus`(`-`) is different from that in MySQL when `Mod`, `Multiple` or `Minus` returns 0 and the number of fractional digits is large (like +`select 0.000 % 0.11234500000000000000`) [#11251](https://github.com/pingcap/tidb/pull/11251) +* Fix the issue that `NULL` with a warning is incorrectly returned when the length of the result returned by `CONCAT` and `CONCAT_WS` functions exceeds `max_allowed_packet` [#11275](https://github.com/pingcap/tidb/pull/11275) +* Fix the issue that `NULL` with a warning is incorrectly returned when parameters in the `SUBTIME` and `ADDTIME` functions are invalid [#11337](https://github.com/pingcap/tidb/pull/11337) +* Fix the issue that `NULL` is incorrectly returned when parameters in the `CONVERT_TZ` function are invalid [#11359](https://github.com/pingcap/tidb/pull/11359) +* Add the `MEMORY` column to the result returned by `EXPLAIN ANALYZE` to show the memory usage of this query [#11418](https://github.com/pingcap/tidb/pull/11418) +* Add `CARTESIAN JOIN` to the result of `EXPLAIN` [#11429](https://github.com/pingcap/tidb/pull/11429) +* Fix the incorrect data of auto-increment columns of the float and double types [#11385](https://github.com/pingcap/tidb/pull/11385) +* Fix the panic issue caused by some `nil` information when pseudo statistics are dumped [#11460](https://github.com/pingcap/tidb/pull/11460) +* Fix the incorrect query result of `SELECT … CASE WHEN … ELSE NULL ...` caused by constant folding optimization [#11441](https://github.com/pingcap/tidb/pull/11441) +* Fix the issue that `floatStrToIntStr` does not correctly parse the input such as `+999.9999e2` [#11473](https://github.com/pingcap/tidb/pull/11473) +* Fix the issue that `NULL` is not returned in some cases when the result of the `DATE_ADD` and `DATE_SUB` function overflows [#11476](https://github.com/pingcap/tidb/pull/11476) +* Fix the issue that the conversion result is different from that in MySQL if the string contains an invalid character when a long string is converted to an integer [#11469](https://github.com/pingcap/tidb/pull/11469) +* Fix the issue that the result of the `REGEXP BINARY` function is incompatible with MySQL caused by case sensitiveness of this function [#11504](https://github.com/pingcap/tidb/pull/11504) +* Fix the issue that an error is reported when the `GRANT ROLE` statement receives `CURRENT_ROLE`; fix the issue that the `REVOKE ROLE` statement does not correctly revoke the `mysql.default_role` privilege [#11356](https://github.com/pingcap/tidb/pull/11356) +* Fix the display format issue of the `Incorrect datetime value` warning information when executing statements like `SELECT ADDDATE('2008-01-34', -1)` [#11447](https://github.com/pingcap/tidb/pull/11447) +* Fix the issue that the error message reports `constant … overflows float` rather than `constant … overflows bigint` if the result overflows when a float field of the JSON data is converted to an integer [#11534](https://github.com/pingcap/tidb/pull/11534) +* Fix the issue that the result might be wrong caused by incorrect type conversion when the `DATE_ADD` function receives `FLOAT`, `DOUBLE` and `DECIMAL` column parameters [#11527](https://github.com/pingcap/tidb/pull/11527) +* Fix the wrong result caused by incorrectly processing the sign of the INTERVAL fraction in the `DATE_ADD` function [#11615](https://github.com/pingcap/tidb/pull/11615) +* Fix the incorrect query result when Index Lookup Join contains the prefix index caused by `Ranger` not correctly handling the prefix index [#11565](https://github.com/pingcap/tidb/pull/11565) +* Fix the issue that the "Incorrect arguments to NAME_CONST" message is reported if the `NAME_CONST` function is executed when the second parameter of `NAME_CONST` is a negative number [#11268](https://github.com/pingcap/tidb/pull/11268) +* Fix the issue that the result is incompatible with MySQL when an SQL statement involves computing the current time and the value is fetched multiple times; use the same value when fetching the current time for the same SQL statement [#11394](https://github.com/pingcap/tidb/pull/11394) +* Fix the issue that `Close` is not called for ChildExecutor when the `Close` of baseExecutor reports an error. This issue might lead to Goroutine leaks when the `KILL` statements do not take effect and ChildExecutor is not closed [#11576](https://github.com/pingcap/tidb/pull/11576) + +### Server +* Fix the issue that the auto-added value is 0 instead of the current timestamp when `LOAD DATA` processes the missing `TIMESTAMP` field in the CSV file [#11250](https://github.com/pingcap/tidb/pull/11250) +* Fix issues that the `SHOW CREATE USER` statement does not correctly check related privileges, and `USER` and `HOST` returned by `SHOW CREATE USER CURRENT_USER()` might be wrong [#11229](https://github.com/pingcap/tidb/pull/11229) +* Fix the issue that the returned result might be wrong when `executeBatch` is used in JDBC [#11290](https://github.com/pingcap/tidb/pull/11290) +* Reduce printing the log information of the streaming client when changing the TiKV server's port [#11370](https://github.com/pingcap/tidb/pull/11370) +* Optimize the logic of reconnecting the streaming client to the TiKV server so that the streaming client will not be blocked for a long time [#11372](https://github.com/pingcap/tidb/pull/11372) +* Add `REGION_ID` in `INFORMATION_SCHEMA.TIDB_HOT_REGIONS` [#11350](https://github.com/pingcap/tidb/pull/11350) +* Cancel the timeout duration of obtaining Region information from the PD API to ensure that obtaining Region information will not end in a failure when TiDB API http://{TiDBIP}:10080/regions/hot is called due to PD timeout when the number of Regions is large [#11383](https://github.com/pingcap/tidb/pull/11383) +* Fix the issue that Region related requests do not return partitioned table-related Regions in the HTTP API [#11466](https://github.com/pingcap/tidb/pull/11466) +* Modify some default parameters related to pessimistic locks, these modifications reduce the probability of locking timeout caused by slow operations when the user manually validates pessimistic locking [#11521](https://github.com/pingcap/tidb/pull/11521) + * Increase the default TTL of pessimistic locking from 30 seconds to 40 seconds + * Increase the maximum TTL from 60 seconds to 120 seconds + * Calculate the pessimistic locking duration from the first `LockKeys` request +* Change the `SendRequest` function logic in the TiKV client: try to immediately connect to another peer instead of keeping waiting when the connect cannot be built [#11531](https://github.com/pingcap/tidb/pull/11531) +* Optimize the Region cache: label the removed store as invalid when a store is moved while another store goes online with a same address, to update the store information in the cache as soon as possible [#11567](https://github.com/pingcap/tidb/pull/11567) +* Add the Region ID to the result returned by the `http://{TiDB_ADDRESS:TIDB_IP}/mvcc/key/{db}/{table}/{handle}` API [#11557](https://github.com/pingcap/tidb/pull/11557) +* Fix the issue that Scatter Table does not work caused by the Scatter Table API not escaping the Range key [#11298](https://github.com/pingcap/tidb/pull/11298) +* Optimize the Region cache: label the store where the Region exists as invalid when the correspondent store is inaccessible to avoid reduced query performance caused by accessing this store [#11498](https://github.com/pingcap/tidb/pull/11498) +* Fix the error that the table schema can still be obtained through the HTTP API after dropping the database with the same name multiple times [#11585](https://github.com/pingcap/tidb/pull/11585) + +### DDL +* Fix the issue that an error occurs when a non-string column with a zero length is being indexed [#11214](https://github.com/pingcap/tidb/pull/11214) +* Disallow modifying the columns with foreign key constraints and full-text indexes (Note: TiDB still supports foreign key constraints and full-text indexes in syntax only) [#11274](https://github.com/pingcap/tidb/pull/11274) +* Fix the issue that the index offset of the column might be wrong because the position changed by the `ALTER TABLE` statement and the default value of the column are used concurrently [#11346](https://github.com/pingcap/tidb/pull/11346) +* Fix two issues that occur when parsing JSON files: + * `int64` is used as the intermediate parsing result of `uint64` in `ConvertJSONToFloat`, which leads to the precision overflow error [#11433](https://github.com/pingcap/tidb/pull/11433) + * `int64` is used as the intermediate parsing result of `uint64` in `ConvertJSONToInt`, which leads to the precision overflow error [#11551](https://github.com/pingcap/tidb/pull/11551) +* Disallow dropping indexes on the auto-increment column to avoid that the auto-increment column might get an incorrect result [#11399](https://github.com/pingcap/tidb/pull/11399) +* Fix the following issues [#11492](https://github.com/pingcap/tidb/pull/11492): + * The character set and the collation of the column are not consistent when explicitly specifying the collation but not the character set + * The error is not correctly reported when there is a conflict between the character set and the collation that are specified by `ALTER TABLE … MODIFY COLUMN` + * Incompatibility with MySQL when using `ALTER TABLE … MODIFY COLUMN` to specify character sets and collations multiple times +* Add the trace details of the subquery to the result of the `TRACE` query [#11458](https://github.com/pingcap/tidb/pull/11458) +* Optimize the performance of executing `ADMIN CHECK TABLE` and greatly reduce its execution time [#11547](https://github.com/pingcap/tidb/pull/11547) +* Add the result returned by `SPLIT TABLE … REGIONS/INDEX` and make `TOTAL_SPLIT_REGION` and `SCATTER_FINISH_RATIO` display the number of Regions that have been split successfully before timeout in the result [#11484](https://github.com/pingcap/tidb/pull/11484) +* Fix the issue that the precision displayed by statements like `SHOW CREATE TABLE` is incomplete when `ON UPDATE CURRENT_TIMESTAMP` is the column attribute and the float precision is specified [#11591](https://github.com/pingcap/tidb/pull/11591) +* Fix the issue that the index result of the column cannot be correctly calculated when the expression of a virtual generated column contains another virtual generated column [#11475](https://github.com/pingcap/tidb/pull/11475) +* Fix the issue that the minus sign cannot be added after `VALUE LESS THAN` in the `ALTER TABLE … ADD PARTITION … ` statement [#11581](https://github.com/pingcap/tidb/pull/11581) + +### Monitor +* Fix the issue that data is not collected and reported because the `TiKVTxnCmdCounter` monitoring metric is not registered [#11316](https://github.com/pingcap/tidb/pull/11316) +* Add the `BindUsageCounter`, `BindTotalGauge` and `BindMemoryUsage` monitoring metrics for the Bind Info [#11467](https://github.com/pingcap/tidb/pull/11467) + + +## [3.0.1] 2019-07-16 +* Add the `tidb_wait_split_region_finish_backoff` session variable to control the backoff time of splitting Regions [#11166](https://github.com/pingcap/tidb/pull/11166) +* Support automatically adjusting the auto-incremental ID allocation step based on the load, and the auto-adjustment scope of the step is 1000~2000000 [#11006](https://github.com/pingcap/tidb/pull/11006) +* Add the `ADMIN PLUGINS ENABLE`/`ADMIN PLUGINS DISABLE` SQL statement to dynamically enable or disable plugins [#11157](https://github.com/pingcap/tidb/pull/11157) +* Add the session connection information in the audit plugin [#11013](https://github.com/pingcap/tidb/pull/11013) +* Add optimizer hint `MAX_EXECUTION_TIME`, which places a limit N (a timeout value in milliseconds) on how long a `SELECT` statement is permitted to execute before the server terminates it: [#11026](https://github.com/pingcap/tidb/pull/11026) +* Change the default behavior during the period of splitting Regions to wait for PD to finish scheduling [#11166](https://github.com/pingcap/tidb/pull/11166) +* Prohibit Window Functions from being cached in Prepare Plan Cache to avoid incorrect results in some cases [#11048](https://github.com/pingcap/tidb/pull/11048) +* Prohibit `ALTER` statements from modifying the definition of stored generated columns [#11068](https://github.com/pingcap/tidb/pull/11068) +* Disallow changing virtual generated columns to stored generated columns [#11068](https://github.com/pingcap/tidb/pull/11068) +* Disallow changing the generated column expression with indexes [#11068](https://github.com/pingcap/tidb/pull/11068) +* Support compiling TiDB on the ARM64 architecture [#11150](https://github.com/pingcap/tidb/pull/11150) +* Support modifying the collation of a database or a table, but the character set of the database/table has to be UTF-8 or utf8mb4 [#11086](https://github.com/pingcap/tidb/pull/11086) +* Fix the issue that an error is reported when the `SELECT` subquery in the `UPDATE … SELECT` statement fails to resolve the column in the `UPDATE` expression [#11252](https://github.com/pingcap/tidb/pull/11252) +* Fix the panic issue that happens when a column is queried on multiple times and the returned result is NULL during point queries [#11226](https://github.com/pingcap/tidb/pull/11226) +* Fix the data race issue caused by non-thread safe `rand.Rand` when using the `RAND` function [#11169](https://github.com/pingcap/tidb/pull/11169) +* Fix the bug that the memory usage of a SQL statement exceeds the threshold but the execution of this statement is not canceled in some cases when `oom-action="cancel"` is configured [#11004](https://github.com/pingcap/tidb/pull/11004) +* Fix the issue that when a query ends, `SHOW PROCESSLIST` shows that the memory usage is not `0` because the memory usage of MemTracker was not correctly cleaned [#10970](https://github.com/pingcap/tidb/pull/10970) +* Fix the bug that the result of comparing integers and non-integers is not correct in some cases [#11194](https://github.com/pingcap/tidb/pull/11194) +* Fix the bug that the query result is not correct when the query on table partitions contains a predicate in explicit transactions [#11196](https://github.com/pingcap/tidb/pull/11196) +* Fix the DDL job panic issue because `infoHandle` might be `NULL` [#11022](https://github.com/pingcap/tidb/pull/11022) +* Fix the issue that the query result is not correct because the queried column is not referenced in the subquery and is then wrongly pruned when running a nested aggregation query [#11020](https://github.com/pingcap/tidb/pull/11020) +* Fix the issue that the `Sleep` function does not respond to the `KILL` statement in time [#11028](https://github.com/pingcap/tidb/pull/11028) +* Fix the issue that the `DB` and `INFO` columns shown by the `SHOW PROCESSLIST` command are incompatible with MySQL [#11003](https://github.com/pingcap/tidb/pull/11003) +* Fix the system panic issue caused by the `FLUSH PRIVILEGES` statement when `skip-grant-table=true` is configured [#11027](https://github.com/pingcap/tidb/pull/11027) +* Fix the issue that the primary key statistics collected by `FAST ANALYZE` are not correct when the table primary key is an `UNSIGNED` integer [#11099](https://github.com/pingcap/tidb/pull/11099) +* Fix the issue that the "invalid key" error is reported by the `FAST ANALYZE` statement in some cases [#11098](https://github.com/pingcap/tidb/pull/11098) +* Fix the issue that the precision shown by the `SHOW CREATE TABLE` statement is incomplete when `CURRENT_TIMESTAMP` is used as the default value of the column and the decimal precision is specified [#11088](https://github.com/pingcap/tidb/pull/11088) +* Fix the issue that the function name is not in lowercase when window functions report an error to make it compatible with MySQL [#11118](https://github.com/pingcap/tidb/pull/11118) +* Fix the issue that TiDB fails to connect to TiKV and thus cannot provide service after the background thread of TiKV Client Batch gRPC panics [#11101](https://github.com/pingcap/tidb/pull/11101) +* Fix the issue that the variable is set incorrectly by `SetVar` because of the shallow copy of the string [#11044](https://github.com/pingcap/tidb/pull/11044) +* Fix the issue that the execution fails and an error is reported when the `INSERT … ON DUPLICATE` statement is applied on table partitions [#11231](https://github.com/pingcap/tidb/pull/11231) +* Pessimistic locking (experimental feature) + - Fix the issue that an incorrect result is returned because of the invalid lock on the row when point queries are run using the pessimistic locking and the returned data is empty [#10976](https://github.com/pingcap/tidb/pull/10976) + - Fix the issue that the query result is not correct because `SELECT … FOR UPDATE` does not use the correct TSO when using the pessimistic locking in the query [#11015](https://github.com/pingcap/tidb/pull/11015) +* Change the detection behavior from immediate conflict detection to waiting when an optimistic transaction meets a pessimistic lock to avoid worsening the lock conflict [#11051](https://github.com/pingcap/tidb/pull/11051) + + +## [3.0.0] 2019-06-28 +## New Features +* Support Window Functions; compatible with all window functions in MySQL 8.0, including `NTILE`, `LEAD`, `LAG`, `PERCENT_RANK`, `NTH_VALUE`, `CUME_DIST`, `FIRST_VALUE` , `LAST_VALUE`, `RANK`, `DENSE_RANK`, and `ROW_NUMBER` +* Support Views (Experimental) +* Improve Table Partition + - Support Range Partition + - Support Hash Partition +* Add the plug-in framework, supporting plugins such as IP Whitelist (Enterprise feature) and Audit Log (Enterprise feature). +* Support the SQL Plan Management function to create SQL execution plan binding to ensure query stability (Experimental) + +## SQL Optimizer +* Optimize the `NOT EXISTS` subquery and convert it to `Anti Semi Join` to improve performance +* Optimize the constant propagation on the `Outer Join`, and add the optimization rule of `Outer Join` elimination to reduce non-effective computations and improve performance +* Optimize the `IN` subquery to execute `Inner Join` after aggregation to improve performance +* Optimize `Index Join` to adapt to more scenarios +* Improve the Partition Pruning optimization rule of Range Partition +* Optimize the query logic for `_tidb_rowid`to avoid full table scan and improve performance +* Match more prefix columns of the indexes when extracting access conditions of composite indexes if there are relevant columns in the filter to improve performance +* Improve the accuracy of cost estimates by using order correlation between columns +* Optimize `Join Reorder` based on the Greedy algorithm and the dynamic planning algorithm to improve accuracy for index selection using `Join` +* Support Skyline Pruning, with some rules to prevent the execution plan from relying too heavily on statistics to improve query stability +* Improve the accuracy of row count estimation for single-column indexes with NULL values +* Support `FAST ANALYZE` that randomly samples in each Region to avoid full table scan and improve performance with statistics collection +* Support the incremental Analyze operation on monotonically increasing index columns to improve performance with statistics collection +* Support using subqueries in the `DO` statement +* Support using `Index Join` in transactions +* Optimize `prepare`/`execute` to support DDL statements with no parameters +* Modify the system behaviour to auto load statistics when the `stats-lease` variable value is 0 +* Support exporting historical statistics +* Support the `dump`/`load` correlation of histograms + +## SQL Execution Engine +* Optimize log output: `EXECUTE` outputs user variables and `COMMIT` outputs slow query logs to facilitate troubleshooting +* Support the `EXPLAIN ANALYZE` function to improve SQL tuning usability +* Support the `admin show next_row_id` command to get the ID of the next row +* Add six built-in functions: `JSON_QUOTE`, `JSON_ARRAY_APPEND`, `JSON_MERGE_PRESERVE`, `BENCHMARK` ,`COALESCE`, and `NAME_CONST` +* Optimize control logics on the chunk size to dynamically adjust based on the query context, to reduce the SQL execution time and resource consumption +* Support tracking and controlling memory usage in three operators - `TableReader`, `IndexReader` and `IndexLookupReader` +* Optimize the Merge Join operator to support an empty `ON` condition +* Optimize write performance for single tables that contains too many columns +* Improve the performance of `admin show ddl jobs` by supporting scanning data in reverse order +* Add the `split table region` statement to manually split the table Region to alleviate the hotspot issue +* Add the `split index region` statement to manually split the index Region to alleviate the hotspot issue +* Add a blacklist to prohibit pushing down expressions to Coprocessor +* Optimize the `Expensive Query` log to print the SQL query in the log when it exceeds the configured limit of execution time or memory + +## DDL +* Support migrating from character set `utf8` to `utf8mb4` +* Change the default character set from`utf8` to `utf8mb4` +* Add the `alter schema` statement to modify the character set and the collation of the database +* Support ALTER algorithm `INPLACE`/`INSTANT` +* Support `SHOW CREATE VIEW` +* Support `SHOW CREATE USER` +* Support fast recovery of mistakenly deleted tables +* Support adjusting the number of concurrencies of ADD INDEX dynamically +* Add the `pre_split_regions` option that pre-allocates Regions when creating the table using the `CREATE TABLE` statement, to relieve write hot Regions caused by lots of writes after the table creation +* Support splitting Regions by the index and range of the table specified using SQL statements to relieve hotspot issues +* Add the `ddl_error_count_limit` global variable to limit the number of DDL task retries +* Add a feature to use `SHARD_ROW_ID_BITS` to scatter row IDs when the column contains an AUTO_INCREMENT attribute to relieve the hotspot issue +* Optimize the lifetime of invalid DDL metadata to speed up recovering the normal execution of DDL operations after upgrading the TiDB cluster + +## Transactions +* Support the pessimistic transaction model (Experimental) +* Optimize transaction processing logics to adapt to more scenarios: + - Change the default value `tidb_disable_txn_auto_retry` to `on`, which means non-auto committed transactions will not be retried + - Add the `tidb_batch_commit` system variable to split a transaction into multiple ones to be executed concurrently + - Add the `tidb_low_resolution_tso` system variable to control the number of TSOs to obtain in batches and reduce the number of times that transactions request for TSOs, to improve performance in scenarios with relatively low requirement of consistency + - Add the `tidb_skip_isolation_level_check` variable to control whether to report errors when the isolation level is set to SERIALIZABLE + - Modify the `tidb_disable_txn_auto_retry` system variable to make it work on all retryable errors + +## Permission Management +* Perform permission check on the `ANALYZE`, `USE`, `SET GLOBAL`, and `SHOW PROCESSLIST` statements +* Support Role Based Access Control (RBAC) (Experimental) + +## Server +* Optimize slow query logs + - Restructure the log format + - Optimize the log content + - Optimize the log query method to support using the `INFORMATION_SCHEMA.SLOW_QUERY` and `ADMIN SHOW SLOW` statements of the memory table to query slow query logs +* Develop a unified log format specification with restructured log system to facilitate collection and analysis by tools +* Support using SQL statements to manage Binlog services, including querying status, enabling Binlog, maintaining and sending Binlog strategies. +* Support using `unix_socket` to connect to the database +* Support `Trace` for SQL statements +* Support getting information for a TiDB instance via the `/debug/zip` HTTP interface to facilitate troubleshooting. +* Optimize monitoring items to facilitate troubleshooting: + - Add the `high_error_rate_feedback_total` monitoring item to monitor the difference between the actual data volume and the estimated data volume based on statistics + - Add a QPS monitoring item in the database dimension +* Optimize the system initialization process to only allow the DDL owner to perform the initialization. This reduces the startup time for initialization or upgrading. +* Optimize the execution logic of `kill query` to improve performance and ensure resource is release properly +* Add a startup option `config-check` to check the validity of the configuration file +* Add the `tidb_back_off_weight` system variable to control the backoff time of internal error retries +* Add the `wait_timeout`and `interactive_timeout` system variables to control the maximum idle connections allowed +* Add the connection pool for TiKV to shorten the connection establishing time + +## Compatibility +* Support the `ALLOW_INVALID_DATES` SQL mode +* Support the MySQL 320 Handshake protocol +* Support manifesting unsigned BIGINT columns as auto-increment columns +* Support the `SHOW CREATE DATABASE IF NOT EXISTS` syntax +* Optimize the fault tolerance of `load data` for CSV files +* Abandon the predicate pushdown operation when the filtering condition contains a user variable to improve the compatibility with MySQL's behavior of using user variables to simulate Window Functions + + +## [3.0.0-rc.3] 2019-06-21 +## SQL Optimizer +* Remove the feature of collecting virtual generated column statistics[#10629](https://github.com/pingcap/tidb/pull/10629) +* Fix the issue that the primary key constant overflows during point queries [#10699](https://github.com/pingcap/tidb/pull/10699) +* Fix the issue that using uninitialized information in `fast analyze` causes panic [#10691](https://github.com/pingcap/tidb/pull/10691) +* Fix the issue that executing the `create view` statement using `prepare` causes panic because of wrong column information [#10713](https://github.com/pingcap/tidb/pull/10713) +* Fix the issue that the column information is not cloned when handling window functions [#10720](https://github.com/pingcap/tidb/pull/10720) +* Fix the wrong estimation for the selectivity rate of the inner table selection in index join [#10854](https://github.com/pingcap/tidb/pull/10854) +* Support automatic loading statistics when the `stats-lease` variable value is 0 [#10811](https://github.com/pingcap/tidb/pull/10811) + +## Execution Engine +* Fix the issue that resources are not correctly released when calling the `Close` function in `StreamAggExec` [#10636](https://github.com/pingcap/tidb/pull/10636) +* Fix the issue that the order of `table_option` and `partition_options` is incorrect in the result of executing the `show create table` statement for partitioned tables [#10689](https://github.com/pingcap/tidb/pull/10689) +* Improve the performance of `admin show ddl jobs` by supporting scanning data in reverse order [#10687](https://github.com/pingcap/tidb/pull/10687) +* Fix the issue that the result of the `show grants` statement in RBAC is incompatible with that of MySQL when this statement has the `current_user` field [#10684](https://github.com/pingcap/tidb/pull/10684) +* Fix the issue that UUIDs might generate duplicate values ​​on multiple nodes [#10712](https://github.com/pingcap/tidb/pull/10712) +* Fix the issue that the `show view` privilege is not considered in `explain` [#10635](https://github.com/pingcap/tidb/pull/10635) +* Add the `split table region` statement to manually split the table Region to alleviate the hotspot issue [#10765](https://github.com/pingcap/tidb/pull/10765) +* Add the `split index region` statement to manually split the index Region to alleviate the hotspot issue [#10764](https://github.com/pingcap/tidb/pull/10764) +* Fix the incorrect execution issue when you execute multiple statements such as `create user`, `grant`, or `revoke` consecutively [#10737](https://github.com/pingcap/tidb/pull/10737) +* Add a blacklist to prohibit pushing down expressions to Coprocessor [#10791](https://github.com/pingcap/tidb/pull/10791) +* Add the feature of printing the `expensive query` log when a query exceeds the memory configuration limit [#10849](https://github.com/pingcap/tidb/pull/10849) +* Add the `bind-info-lease` configuration item to control the update time of the modified binding execution plan [#10727](https://github.com/pingcap/tidb/pull/10727) +* Fix the OOM issue in high concurrent scenarios caused by the failure to quickly release Coprocessor resources, resulted from the `execdetails.ExecDetails` pointer [#10832](https://github.com/pingcap/tidb/pull/10832) +* Fix the panic issue caused by the `kill` statement in some cases [#10876](https://github.com/pingcap/tidb/pull/10876) +## Server +* Fix the issue that goroutine might leak when repairing GC [#10683](https://github.com/pingcap/tidb/pull/10683) +* Support displaying the `host` information in slow queries [#10693](https://github.com/pingcap/tidb/pull/10693) +* Support reusing idle links that interact with TiKV [#10632](https://github.com/pingcap/tidb/pull/10632) +* Fix the support for enabling the `skip-grant-table` option in RBAC [#10738](https://github.com/pingcap/tidb/pull/10738) +* Fix the issue that `pessimistic-txn` configuration goes invalid [#10825](https://github.com/pingcap/tidb/pull/10825) +* Fix the issue that the actively cancelled ticlient requests are still retried [#10850](https://github.com/pingcap/tidb/pull/10850) +* Improve performance in the case where pessimistic transactions conflict with optimistic transactions [#10881](https://github.com/pingcap/tidb/pull/10881) +## DDL +* Fix the issue that modifying charset using `alter table` causes the `blob` type change [#10698](https://github.com/pingcap/tidb/pull/10698) +* Add a feature to use `SHARD_ROW_ID_BITS` to scatter row IDs when the column contains an `AUTO_INCREMENT` attribute to alleviate the hotspot issue [#10794](https://github.com/pingcap/tidb/pull/10794) +* Prohibit adding stored generated columns by using the `alter table` statement [#10808](https://github.com/pingcap/tidb/pull/10808) +* Optimize the invalid survival time of DDL metadata to shorten the period during which the DDL operation is slower after cluster upgrade [#10795](https://github.com/pingcap/tidb/pull/10795) + +## [3.0.0-rc.2] 2019-05-28 +### SQL Optimizer +* Support Index Join in more scenarios +[#10540](https://github.com/pingcap/tidb/pull/10540) +* Support exporting historical statistics [#10291](https://github.com/pingcap/tidb/pull/10291) +* Support the incremental `Analyze` operation on monotonically increasing index columns +[#10355](https://github.com/pingcap/tidb/pull/10355) +* Neglect the NULL value in the `Order By` clause [#10488](https://github.com/pingcap/tidb/pull/10488) +* Fix the wrong schema information calculation of the `UnionAll` logical operator when simplifying the column information [#10384](https://github.com/pingcap/tidb/pull/10384) +* Avoid modifying the original expression when pushing down the `Not` operator [#10363](https://github.com/pingcap/tidb/pull/10363/files) +* Support the `dump`/`load` correlation of histograms [#10573](https://github.com/pingcap/tidb/pull/10573) +### Execution Engine +* Handle virtual columns with a unique index properly when fetching duplicate rows in `batchChecker` [#10370](https://github.com/pingcap/tidb/pull/10370) +* Fix the scanning range calculation issue for the `CHAR` column [#10124](https://github.com/pingcap/tidb/pull/10124) +* Fix the issue of `PointGet` incorrectly processing negative numbers [#10113](https://github.com/pingcap/tidb/pull/10113) +* Merge `Window` functions with the same name to improve execution efficiency [#9866](https://github.com/pingcap/tidb/pull/9866) +* Allow the `RANGE` frame in a `Window` function to contain no `OrderBy` clause [#10496](https://github.com/pingcap/tidb/pull/10496) + +### Server +Fix the issue that TiDB continuously creates a new connection to TiKV when a fault occurs in TiKV [#10301](https://github.com/pingcap/tidb/pull/10301) +Make `tidb_disable_txn_auto_retry` affect all retryable errors instead of only write conflict errors [#10339](https://github.com/pingcap/tidb/pull/10339) +Allow DDL statements without parameters to be executed using `prepare`/`execute` [#10144](https://github.com/pingcap/tidb/pull/10144) +Add the `tidb_back_off_weight` variable to control the backoff time [#10266](https://github.com/pingcap/tidb/pull/10266) +Prohibit TiDB retrying non-automatically committed transactions in default conditions by setting the default value of `tidb_disable_txn_auto_retry` to `on` [#10266](https://github.com/pingcap/tidb/pull/10266) +Fix the database privilege judgment of `role` in `RBAC` [#10261](https://github.com/pingcap/tidb/pull/10261) +Support the pessimistic transaction model (experimental) [#10297](https://github.com/pingcap/tidb/pull/10297) +Reduce the wait time for handling lock conflicts in some cases [#10006](https://github.com/pingcap/tidb/pull/10006) +Make the Region cache able to visit follower nodes when a fault occurs in the leader node [#10256](https://github.com/pingcap/tidb/pull/10256) +Add the `tidb_low_resolution_tso` variable to control the number of TSOs obtained in batches and reduce the times of transactions obtaining TSO to adapt for scenarios where data consistency is not so strictly required [#10428](https://github.com/pingcap/tidb/pull/10428) + +### DDL +Fix the uppercase issue of the charset name in the storage of the old version of TiDB +[#10272](https://github.com/pingcap/tidb/pull/10272) +Support `preSplit` of table partition, which pre-allocates table Regions when creating a table to avoid write hotspots after the table is created +[#10221](https://github.com/pingcap/tidb/pull/10221) +Fix the issue that TiDB incorrectly updates the version information in PD in some cases [#10324](https://github.com/pingcap/tidb/pull/10324) +Support modifying the charset and collation using the `ALTER DATABASE` statement +[#10393](https://github.com/pingcap/tidb/pull/10393) +Support splitting Regions based on the index and range of the specified table to relieve hotspot issues +[#10203](https://github.com/pingcap/tidb/pull/10203) +Prohibit modifying the precision of the decimal column using the `alter table` statement +[#10433](https://github.com/pingcap/tidb/pull/10433) +Fix the restriction for expressions and functions in hash partition +[#10273](https://github.com/pingcap/tidb/pull/10273) +Fix the issue that adding indexes in a table that contains partitions will in some cases cause TiDB panic +[#10475](https://github.com/pingcap/tidb/pull/10475) +Validate table information before executing the DDL to avoid invalid table schemas +[#10464](https://github.com/pingcap/tidb/pull/10464) +Enable hash partition by default; and enable range columns partition when there is only one column in the partition definition +[#9936](https://github.com/pingcap/tidb/pull/9936) + + ## [3.0.0-rc.1] 2019-05-10 ### SQL Optimizer @@ -146,7 +629,7 @@ All notable changes to this project will be documented in this file. See also [R * Support the `show pump status` and `show drainer status` SQL statements to check the Pump or Drainer status [9456](https://github.com/pingcap/tidb/pull/9456) * Support modifying the Pump or Drainer status by using SQL statements [#9789](https://github.com/pingcap/tidb/pull/9789) * Support adding HASH fingerprints to SQL text for easy tracking of slow SQL statements [#9662](https://github.com/pingcap/tidb/pull/9662) -* Add the `log_bin` system variable (“0” by default) to control the enabling state of binlog; only support checking the state currently [#9343](https://github.com/pingcap/tidb/pull/9343) +* Add the `log_bin` system variable ("0" by default) to control the enabling state of binlog; only support checking the state currently [#9343](https://github.com/pingcap/tidb/pull/9343) * Support managing the sending binlog strategy by using the configuration file [#9864](https://github.com/pingcap/tidb/pull/9864) * Support querying the slow log by using the `INFORMATION_SCHEMA.SLOW_QUERY` memory table [#9290](https://github.com/pingcap/tidb/pull/9290) * Change the MySQL version displayed in TiDB from 5.7.10 to 5.7.25 [#9553](https://github.com/pingcap/tidb/pull/9553) @@ -154,7 +637,7 @@ All notable changes to this project will be documented in this file. See also [R * Add the `high_error_rate_feedback_total` monitoring item to record the difference between the actual data volume and the estimated data volume based on statistics [#9209](https://github.com/pingcap/tidb/pull/9209) * Add the QPS monitoring item in the database dimension, which can be enabled by using a configuration item [#9151](https://github.com/pingcap/tidb/pull/9151) ### DDL -* Add the `ddl_error_count_limit` global variable (“512” by default) to limit the number of DDL task retries (If this number exceeds the limit, the DDL task is canceled) [#9295](https://github.com/pingcap/tidb/pull/9295) +* Add the `ddl_error_count_limit` global variable ("512" by default) to limit the number of DDL task retries (If this number exceeds the limit, the DDL task is canceled) [#9295](https://github.com/pingcap/tidb/pull/9295) * Support ALTER ALGORITHM `INPLACE`/`INSTANT` [#8811](https://github.com/pingcap/tidb/pull/8811) * Support the `SHOW CREATE VIEW` statement [#9309](https://github.com/pingcap/tidb/pull/9309) * Support the `SHOW CREATE USER` statement [#9240](https://github.com/pingcap/tidb/pull/9240) @@ -335,7 +818,7 @@ All notable changes to this project will be documented in this file. See also [R * Support the MySQL 320 handshake protocol [#8812](https://github.com/pingcap/tidb/pull/8812) * Support using the unsigned bigint column as the auto-increment column [#8181](https://github.com/pingcap/tidb/pull/8181) * Support the `SHOW CREATE DATABASE IF NOT EXISTS` syntax [#8926](https://github.com/pingcap/tidb/pull/8926) -* Abandon the predicate pushdown operation when the filtering condition contains a user variable to improve the compatibility with MySQL’s behavior of using user variables to mock the Window Function behavior [#8412](https://github.com/pingcap/tidb/pull/8412) +* Abandon the predicate pushdown operation when the filtering condition contains a user variable to improve the compatibility with MySQL's behavior of using user variables to mock the Window Function behavior [#8412](https://github.com/pingcap/tidb/pull/8412) ### DDL @@ -511,7 +994,7 @@ All notable changes to this project will be documented in this file. See also [R * Refactor Latch to avoid misjudgment of transaction conflicts and improve the execution performance of concurrent transactions [#7711](https://github.com/pingcap/tidb/pull/7711) * Fix the panic issue caused by collecting slow queries in some cases [#7874](https://github.com/pingcap/tidb/pull/7847) * Fix the panic issue when `ESCAPED BY` is an empty string in the `LOAD DATA` statement [#8005](https://github.com/pingcap/tidb/pull/8005) -* Complete the “coprocessor error” log information [#8006](https://github.com/pingcap/tidb/pull/8006) +* Complete the "coprocessor error" log information [#8006](https://github.com/pingcap/tidb/pull/8006) ### Compatibility * Set the `Command` field of the `SHOW PROCESSLIST` result to `Sleep` when the query is empty [#7839](https://github.com/pingcap/tidb/pull/7839) ### Expressions @@ -536,7 +1019,7 @@ All notable changes to this project will be documented in this file. See also [R * Optimize the performance of Hash aggregate operators [#7541](https://github.com/pingcap/tidb/pull/7541) * Optimize the performance of Join operators [#7493](https://github.com/pingcap/tidb/pull/7493), [#7433](https://github.com/pingcap/tidb/pull/7433) * Fix the issue that the result of `UPDATE JOIN` is incorrect when the Join order is changed [#7571](https://github.com/pingcap/tidb/pull/7571) -* Improve the performance of Chunk’s iterator [#7585](https://github.com/pingcap/tidb/pull/7585) +* Improve the performance of Chunk's iterator [#7585](https://github.com/pingcap/tidb/pull/7585) ### Statistics * Fix the issue that the auto Analyze work repeatedly analyzes the statistics [#7550](https://github.com/pingcap/tidb/pull/7550) * Fix the statistics update error that occurs when there is no statistics change [#7530](https://github.com/pingcap/tidb/pull/7530) @@ -559,7 +1042,7 @@ All notable changes to this project will be documented in this file. See also [R * Use different labels to filter internal SQL and user SQL in monitoring metrics [#7631](https://github.com/pingcap/tidb/pull/7631) * Store the top 30 slow queries in the last week to the TiDB server [#7646](https://github.com/pingcap/tidb/pull/7646) * Put forward a proposal of setting the global system time zone for the TiDB cluster [#7656](https://github.com/pingcap/tidb/pull/7656) -* Enrich the error message of “GC life time is shorter than transaction duration” [#7658](https://github.com/pingcap/tidb/pull/7658) +* Enrich the error message of "GC life time is shorter than transaction duration" [#7658](https://github.com/pingcap/tidb/pull/7658) * Set the global system time zone when starting the TiDB cluster [#7638](https://github.com/pingcap/tidb/pull/7638) ### Compatibility * Add the unsigned flag for the `Year` type [#7542](https://github.com/pingcap/tidb/pull/7542) diff --git a/Dockerfile b/Dockerfile index 9d303c3ab0440..20a183efed942 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Builder image -FROM golang:1.12-alpine as builder +FROM golang:1.13-alpine as builder RUN apk add --no-cache \ wget \ @@ -34,4 +34,4 @@ WORKDIR / EXPOSE 4000 -ENTRYPOINT ["/usr/local/bin/dumb-init", "/tidb-server"] \ No newline at end of file +ENTRYPOINT ["/usr/local/bin/dumb-init", "/tidb-server"] diff --git a/Makefile b/Makefile index 2af63a547d3b9..a1b9914816172 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,8 @@ path_to_add := $(addsuffix /bin,$(subst :,/bin:,$(GOPATH))):$(PWD)/tools/bin export PATH := $(path_to_add):$(PATH) GO := GO111MODULE=on go -GOBUILD := CGO_ENABLED=1 $(GO) build $(BUILD_FLAG) +GOBUILD := CGO_ENABLED=1 $(GO) build $(BUILD_FLAG) -trimpath +GOBUILDCOVERAGE := GOPATH=$(GOPATH) CGO_ENABLED=1 cd tidb-server; $(GO) test -coverpkg="../..." -c . GOTEST := CGO_ENABLED=1 $(GO) test -p 4 OVERALLS := CGO_ENABLED=1 GO111MODULE=on overalls @@ -34,6 +35,7 @@ LDFLAGS += -X "github.com/pingcap/tidb/util/printer.TiDBGitBranch=$(shell git re LDFLAGS += -X "github.com/pingcap/tidb/util/printer.GoVersion=$(shell go version)" TEST_LDFLAGS = -X "github.com/pingcap/tidb/config.checkBeforeDropLDFlag=1" +COVERAGE_SERVER_LDFLAGS = -X "github.com/pingcap/tidb/tidb-server.isCoverageServer=1" CHECK_LDFLAGS += $(LDFLAGS) ${TEST_LDFLAGS} @@ -61,7 +63,7 @@ build: # Install the check tools. check-setup:tools/bin/revive tools/bin/goword tools/bin/gometalinter tools/bin/gosec -check: fmt errcheck lint tidy check-static vet +check: fmt errcheck lint tidy testSuite check-static vet # These need to be fixed before they can be ran regularly check-fail: goword check-slow @@ -106,6 +108,10 @@ tidy: @echo "go mod tidy" ./tools/check/check-tidy.sh +testSuite: + @echo "testSuite" + ./tools/check/check_testSuite.sh + clean: $(GO) clean -i ./... rm -rf *.out @@ -129,8 +135,8 @@ endif gotest: failpoint-enable ifeq ("$(TRAVIS_COVERAGE)", "1") @echo "Running in TRAVIS_COVERAGE mode." - @export log_level=error; \ $(GO) get github.com/go-playground/overalls + @export log_level=error; \ $(OVERALLS) -project=github.com/pingcap/tidb \ -covermode=count \ -ignore='.git,vendor,cmd,docs,LICENSES' \ @@ -183,6 +189,13 @@ else $(GOBUILD) $(RACE_FLAG) -ldflags '$(CHECK_LDFLAGS)' -o '$(TARGET)' tidb-server/main.go endif +server_coverage: +ifeq ($(TARGET), "") + $(GOBUILDCOVERAGE) $(RACE_FLAG) -ldflags '$(LDFLAGS) $(COVERAGE_SERVER_LDFLAGS) $(CHECK_FLAG)' -o ../bin/tidb-server-coverage +else + $(GOBUILDCOVERAGE) $(RACE_FLAG) -ldflags '$(LDFLAGS) $(COVERAGE_SERVER_LDFLAGS) $(CHECK_FLAG)' -o '$(TARGET)' +endif + benchkv: $(GOBUILD) -ldflags '$(LDFLAGS)' -o bin/benchkv cmd/benchkv/main.go diff --git a/bindinfo/bind_test.go b/bindinfo/bind_test.go index d6e1a12a9a5f2..596422d0b5621 100644 --- a/bindinfo/bind_test.go +++ b/bindinfo/bind_test.go @@ -19,7 +19,6 @@ import ( "fmt" "os" "testing" - "time" . "github.com/pingcap/check" "github.com/pingcap/parser" @@ -27,6 +26,7 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/mockstore/mocktikv" @@ -34,6 +34,7 @@ import ( "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" + dto "github.com/prometheus/client_model/go" ) func TestT(t *testing.T) { @@ -73,7 +74,7 @@ func (s *testSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() } d, err := session.BootstrapSession(s.store) c.Assert(err, IsNil) @@ -133,6 +134,11 @@ func (s *testSuite) TestBindParse(c *C) { c.Check(bindData.Collation, Equals, "utf8mb4_bin") c.Check(bindData.CreateTime, NotNil) c.Check(bindData.UpdateTime, NotNil) + + // Test fields with quotes or slashes. + sql = `CREATE GLOBAL BINDING FOR select * from t where a BETWEEN "a" and "b" USING select * from t use index(idx) where a BETWEEN "a\nb\rc\td\0e" and 'x'` + tk.MustExec(sql) + tk.MustExec(`DROP global binding for select * from t use index(idx) where a BETWEEN "a\nb\rc\td\0e" and "x"`) } func (s *testSuite) TestGlobalBinding(c *C) { @@ -145,13 +151,21 @@ func (s *testSuite) TestGlobalBinding(c *C) { tk.MustExec("create table t1(i int, s varchar(20))") tk.MustExec("create index index_t on t(i,s)") + metrics.BindTotalGauge.Reset() + metrics.BindMemoryUsage.Reset() + _, err := tk.Exec("create global binding for select * from t where i>100 using select * from t use index(index_t) where i>100") c.Assert(err, IsNil, Commentf("err %v", err)) - time.Sleep(time.Second * 1) _, err = tk.Exec("create global binding for select * from t where i>99 using select * from t use index(index_t) where i>99") c.Assert(err, IsNil) + pb := &dto.Metric{} + metrics.BindTotalGauge.WithLabelValues(metrics.ScopeGlobal, bindinfo.Using).Write(pb) + c.Assert(pb.GetGauge().GetValue(), Equals, float64(1)) + metrics.BindMemoryUsage.WithLabelValues(metrics.ScopeGlobal, bindinfo.Using).Write(pb) + c.Assert(pb.GetGauge().GetValue(), Equals, float64(161)) + sql, hash := parser.NormalizeDigest("select * from t where i > ?") bindData := s.domain.BindHandle().GetBindRecord(hash, sql, "test") @@ -167,7 +181,7 @@ func (s *testSuite) TestGlobalBinding(c *C) { rs, err := tk.Exec("show global bindings") c.Assert(err, IsNil) - chk := rs.NewRecordBatch() + chk := rs.NewChunk() err = rs.Next(context.TODO(), chk) c.Check(err, IsNil) c.Check(chk.NumRows(), Equals, 1) @@ -202,6 +216,12 @@ func (s *testSuite) TestGlobalBinding(c *C) { bindData = s.domain.BindHandle().GetBindRecord(hash, sql, "test") c.Check(bindData, IsNil) + metrics.BindTotalGauge.WithLabelValues(metrics.ScopeGlobal, bindinfo.Using).Write(pb) + c.Assert(pb.GetGauge().GetValue(), Equals, float64(0)) + metrics.BindMemoryUsage.WithLabelValues(metrics.ScopeGlobal, bindinfo.Using).Write(pb) + // From newly created global bind handle. + c.Assert(pb.GetGauge().GetValue(), Equals, float64(161)) + bindHandle = bindinfo.NewBindHandle(tk.Se) err = bindHandle.Update(true) c.Check(err, IsNil) @@ -212,7 +232,7 @@ func (s *testSuite) TestGlobalBinding(c *C) { rs, err = tk.Exec("show global bindings") c.Assert(err, IsNil) - chk = rs.NewRecordBatch() + chk = rs.NewChunk() err = rs.Next(context.TODO(), chk) c.Check(err, IsNil) c.Check(chk.NumRows(), Equals, 0) @@ -234,13 +254,21 @@ func (s *testSuite) TestSessionBinding(c *C) { tk.MustExec("create table t1(i int, s varchar(20))") tk.MustExec("create index index_t on t(i,s)") + metrics.BindTotalGauge.Reset() + metrics.BindMemoryUsage.Reset() + _, err := tk.Exec("create session binding for select * from t where i>100 using select * from t use index(index_t) where i>100") c.Assert(err, IsNil, Commentf("err %v", err)) - time.Sleep(time.Second * 1) _, err = tk.Exec("create session binding for select * from t where i>99 using select * from t use index(index_t) where i>99") c.Assert(err, IsNil) + pb := &dto.Metric{} + metrics.BindTotalGauge.WithLabelValues(metrics.ScopeSession, bindinfo.Using).Write(pb) + c.Assert(pb.GetGauge().GetValue(), Equals, float64(1)) + metrics.BindMemoryUsage.WithLabelValues(metrics.ScopeSession, bindinfo.Using).Write(pb) + c.Assert(pb.GetGauge().GetValue(), Equals, float64(161)) + handle := tk.Se.Value(bindinfo.SessionBindInfoKeyType).(*bindinfo.SessionHandle) bindData := handle.GetBindRecord("select * from t where i > ?", "test") c.Check(bindData, NotNil) @@ -255,14 +283,14 @@ func (s *testSuite) TestSessionBinding(c *C) { rs, err := tk.Exec("show global bindings") c.Assert(err, IsNil) - chk := rs.NewRecordBatch() + chk := rs.NewChunk() err = rs.Next(context.TODO(), chk) c.Check(err, IsNil) c.Check(chk.NumRows(), Equals, 0) rs, err = tk.Exec("show session bindings") c.Assert(err, IsNil) - chk = rs.NewRecordBatch() + chk = rs.NewChunk() err = rs.Next(context.TODO(), chk) c.Check(err, IsNil) c.Check(chk.NumRows(), Equals, 1) @@ -282,6 +310,11 @@ func (s *testSuite) TestSessionBinding(c *C) { c.Check(bindData, NotNil) c.Check(bindData.OriginalSQL, Equals, "select * from t where i > ?") c.Check(bindData.Status, Equals, "deleted") + + metrics.BindTotalGauge.WithLabelValues(metrics.ScopeSession, bindinfo.Using).Write(pb) + c.Assert(pb.GetGauge().GetValue(), Equals, float64(0)) + metrics.BindMemoryUsage.WithLabelValues(metrics.ScopeSession, bindinfo.Using).Write(pb) + c.Assert(pb.GetGauge().GetValue(), Equals, float64(0)) } func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) { @@ -317,6 +350,7 @@ func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) { tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id") + metrics.BindUsageCounter.Reset() tk.MustQuery("explain SELECT * from t1,t2 where t1.id = t2.id").Check(testkit.Rows( "MergeJoin_7 12487.50 root inner join, left key:test.t1.id, right key:test.t2.id", "├─Sort_11 9990.00 root test.t1.id:asc", @@ -328,6 +362,9 @@ func (s *testSuite) TestGlobalAndSessionBindingBothExist(c *C) { " └─Selection_13 9990.00 cop not(isnull(test.t2.id))", " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) + pb := &dto.Metric{} + metrics.BindUsageCounter.WithLabelValues(metrics.ScopeGlobal).Write(pb) + c.Assert(pb.GetCounter().GetValue(), Equals, float64(1)) tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id") @@ -423,7 +460,7 @@ func (s *testSuite) TestErrorBind(c *C) { rs, err := tk.Exec("show global bindings") c.Assert(err, IsNil) - chk := rs.NewRecordBatch() + chk := rs.NewChunk() err = rs.Next(context.TODO(), chk) c.Check(err, IsNil) c.Check(chk.NumRows(), Equals, 0) diff --git a/bindinfo/cache.go b/bindinfo/cache.go index a4c2785eb9c64..5b74a8c316832 100644 --- a/bindinfo/cache.go +++ b/bindinfo/cache.go @@ -14,7 +14,10 @@ package bindinfo import ( + "unsafe" + "github.com/pingcap/parser/ast" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" ) @@ -64,3 +67,19 @@ func newBindRecord(row chunk.Row) *BindRecord { Collation: row.GetString(7), } } + +// size calculates the memory size of a bind meta. +func (m *BindRecord) size() float64 { + res := len(m.OriginalSQL) + len(m.BindSQL) + len(m.Db) + len(m.Status) + 2*int(unsafe.Sizeof(m.CreateTime)) + len(m.Charset) + len(m.Collation) + return float64(res) +} + +func (m *BindRecord) updateMetrics(scope string, inc bool) { + if inc { + metrics.BindMemoryUsage.WithLabelValues(scope, m.Status).Add(float64(m.size())) + metrics.BindTotalGauge.WithLabelValues(scope, m.Status).Inc() + } else { + metrics.BindMemoryUsage.WithLabelValues(scope, m.Status).Sub(float64(m.size())) + metrics.BindTotalGauge.WithLabelValues(scope, m.Status).Dec() + } +} diff --git a/bindinfo/handle.go b/bindinfo/handle.go index 2d0712bb5fdb8..739a29726b400 100644 --- a/bindinfo/handle.go +++ b/bindinfo/handle.go @@ -14,7 +14,6 @@ package bindinfo import ( - "bytes" "context" "fmt" "go.uber.org/zap" @@ -25,6 +24,8 @@ import ( "github.com/pingcap/parser" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/types" @@ -72,6 +73,9 @@ type BindHandle struct { lastUpdateTime types.Time } +// Lease influences the duration of loading bind info and handling invalid bind. +var Lease = 3 * time.Second + type invalidBindRecordMap struct { bindRecord *BindRecord droppedTime time.Time @@ -120,9 +124,10 @@ func (h *BindHandle) Update(fullLoad bool) (err error) { continue } - newCache.removeStaleBindMetas(hash, meta) + newCache.removeStaleBindMetas(hash, meta, metrics.ScopeGlobal) if meta.Status == Using { newCache[hash] = append(newCache[hash], meta) + metrics.BindMemoryUsage.WithLabelValues(metrics.ScopeGlobal, meta.Status).Add(meta.size()) } } return nil @@ -183,7 +188,6 @@ func (h *BindHandle) AddBindRecord(record *BindRecord) (err error) { } record.UpdateTime = record.CreateTime record.Status = Using - record.BindSQL = h.getEscapeCharacter(record.BindSQL) // insert the BindRecord to the storage. _, err = exec.Execute(context.TODO(), h.insertBindInfoSQL(record)) @@ -251,6 +255,7 @@ func (h *BindHandle) DropInvalidBindRecord() { if time.Since(invalidBindRecord.droppedTime) > 6*time.Second { delete(invalidBindRecordMap, key) + invalidBindRecord.bindRecord.updateMetrics(metrics.ScopeGlobal, false) } } h.invalidBindRecordMap.Store(invalidBindRecordMap) @@ -272,6 +277,7 @@ func (h *BindHandle) AddDropInvalidBindTask(invalidBindRecord *BindRecord) { bindRecord: invalidBindRecord, } h.invalidBindRecordMap.Store(newMap) + invalidBindRecord.updateMetrics(metrics.ScopeGlobal, true) } // Size return the size of bind info cache. @@ -317,8 +323,9 @@ func newBindMetaWithoutAst(record *BindRecord) (hash string, meta *BindMeta) { // removed from the cache after this operation. func (h *BindHandle) appendBindMeta(hash string, meta *BindMeta) { newCache := h.bindInfo.Value.Load().(cache).copy() - newCache.removeStaleBindMetas(hash, meta) + newCache.removeStaleBindMetas(hash, meta, metrics.ScopeGlobal) newCache[hash] = append(newCache[hash], meta) + meta.updateMetrics(metrics.ScopeGlobal, true) h.bindInfo.Value.Store(newCache) } @@ -331,18 +338,19 @@ func (h *BindHandle) removeBindMeta(hash string, meta *BindMeta) { h.bindInfo.Unlock() }() - newCache.removeDeletedBindMeta(hash, meta) + newCache.removeDeletedBindMeta(hash, meta, metrics.ScopeGlobal) } // removeDeletedBindMeta removes all the BindMeta which originSQL and db are the same with the parameter's meta. -func (c cache) removeDeletedBindMeta(hash string, meta *BindMeta) { +func (c cache) removeDeletedBindMeta(hash string, meta *BindMeta, scope string) { metas, ok := c[hash] if !ok { return } for i := len(metas) - 1; i >= 0; i-- { - if meta.isSame(meta) { + if metas[i].isSame(meta) { + metas[i].updateMetrics(scope, false) metas = append(metas[:i], metas[i+1:]...) if len(metas) == 0 { delete(c, hash) @@ -353,15 +361,15 @@ func (c cache) removeDeletedBindMeta(hash string, meta *BindMeta) { } // removeStaleBindMetas removes all the stale BindMeta in the cache. -func (c cache) removeStaleBindMetas(hash string, meta *BindMeta) { +func (c cache) removeStaleBindMetas(hash string, meta *BindMeta, scope string) { metas, ok := c[hash] if !ok { return } - // remove stale bindMetas. for i := len(metas) - 1; i >= 0; i-- { if metas[i].isStale(meta) { + metas[i].updateMetrics(scope, false) metas = append(metas[:i], metas[i+1:]...) if len(metas) == 0 { delete(c, hash) @@ -411,40 +419,29 @@ func (m *BindMeta) isSame(other *BindMeta) bool { func (h *BindHandle) deleteBindInfoSQL(normdOrigSQL, db string) string { return fmt.Sprintf( - "DELETE FROM mysql.bind_info WHERE original_sql='%s' AND default_db='%s'", - normdOrigSQL, - db, + `DELETE FROM mysql.bind_info WHERE original_sql=%s AND default_db=%s`, + expression.Quote(normdOrigSQL), + expression.Quote(db), ) } func (h *BindHandle) insertBindInfoSQL(record *BindRecord) string { - return fmt.Sprintf(`INSERT INTO mysql.bind_info VALUES ('%s', '%s', '%s', '%s', '%s', '%s','%s', '%s')`, - record.OriginalSQL, - record.BindSQL, - record.Db, - record.Status, - record.CreateTime, - record.UpdateTime, - record.Charset, - record.Collation, + return fmt.Sprintf(`INSERT INTO mysql.bind_info VALUES (%s, %s, %s, %s, %s, %s,%s, %s)`, + expression.Quote(record.OriginalSQL), + expression.Quote(record.BindSQL), + expression.Quote(record.Db), + expression.Quote(record.Status), + expression.Quote(record.CreateTime.String()), + expression.Quote(record.UpdateTime.String()), + expression.Quote(record.Charset), + expression.Quote(record.Collation), ) } func (h *BindHandle) logicalDeleteBindInfoSQL(normdOrigSQL, db string, updateTs types.Time) string { - return fmt.Sprintf(`UPDATE mysql.bind_info SET status='%s',update_time='%s' WHERE original_sql='%s' and default_db='%s'`, - deleted, - updateTs, - normdOrigSQL, - db) -} - -func (h *BindHandle) getEscapeCharacter(str string) string { - var buffer bytes.Buffer - for _, v := range str { - if v == '\'' || v == '"' || v == '\\' { - buffer.WriteString("\\") - } - buffer.WriteString(string(v)) - } - return buffer.String() + return fmt.Sprintf(`UPDATE mysql.bind_info SET status=%s,update_time=%s WHERE original_sql=%s and default_db=%s`, + expression.Quote(deleted), + expression.Quote(updateTs.String()), + expression.Quote(normdOrigSQL), + expression.Quote(db)) } diff --git a/bindinfo/session_handle.go b/bindinfo/session_handle.go index f343b3ca8e24d..f52c7d0f92e22 100644 --- a/bindinfo/session_handle.go +++ b/bindinfo/session_handle.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/parser" "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/types" ) @@ -38,8 +39,9 @@ func NewSessionBindHandle(parser *parser.Parser) *SessionHandle { // removed from the cache after this operation. func (h *SessionHandle) appendBindMeta(hash string, meta *BindMeta) { // Make sure there is only one goroutine writes the cache. - h.ch.removeStaleBindMetas(hash, meta) + h.ch.removeStaleBindMetas(hash, meta, metrics.ScopeSession) h.ch[hash] = append(h.ch[hash], meta) + meta.updateMetrics(metrics.ScopeSession, true) } func (h *SessionHandle) newBindMeta(record *BindRecord) (hash string, meta *BindMeta, err error) { @@ -74,7 +76,7 @@ func (h *SessionHandle) DropBindRecord(record *BindRecord) { meta := &BindMeta{BindRecord: record} meta.Status = deleted hash := parser.DigestHash(record.OriginalSQL) - h.ch.removeDeletedBindMeta(hash, meta) + h.ch.removeDeletedBindMeta(hash, meta, metrics.ScopeSession) h.appendBindMeta(hash, meta) } @@ -100,6 +102,15 @@ func (h *SessionHandle) GetAllBindRecord() (bindRecords []*BindMeta) { return bindRecords } +// Close closes the session handle. +func (h *SessionHandle) Close() { + for _, bindRecords := range h.ch { + for _, bindRecord := range bindRecords { + bindRecord.updateMetrics(metrics.ScopeSession, false) + } + } +} + // sessionBindInfoKeyType is a dummy type to avoid naming collision in context. type sessionBindInfoKeyType int diff --git a/cmd/benchdb/main.go b/cmd/benchdb/main.go index 38d906b44f99a..427e04809612e 100644 --- a/cmd/benchdb/main.go +++ b/cmd/benchdb/main.go @@ -117,7 +117,7 @@ func (ut *benchDB) mustExec(sql string) { if len(rss) > 0 { ctx := context.Background() rs := rss[0] - req := rs.NewRecordBatch() + req := rs.NewChunk() for { err := rs.Next(ctx, req) if err != nil { diff --git a/cmd/explaintest/main.go b/cmd/explaintest/main.go index 34eb979683f68..17f5f40a6dff5 100644 --- a/cmd/explaintest/main.go +++ b/cmd/explaintest/main.go @@ -357,10 +357,9 @@ func (t *tester) execute(query query) error { gotBuf := t.buf.Bytes()[offset:] buf := make([]byte, t.buf.Len()-offset) - if _, err = t.resultFD.ReadAt(buf, int64(offset)); err != nil { + if _, err = t.resultFD.ReadAt(buf, int64(offset)); !(err == nil || err == io.EOF) { return errors.Trace(errors.Errorf("run \"%v\" at line %d err, we got \n%s\nbut read result err %s", st.Text(), query.Line, gotBuf, err)) } - if !bytes.Equal(gotBuf, buf) { return errors.Trace(errors.Errorf("run \"%v\" at line %d err, we need:\n%s\nbut got:\n%s\n", query.Query, query.Line, buf, gotBuf)) } diff --git a/cmd/explaintest/r/access_path_selection.result b/cmd/explaintest/r/access_path_selection.result index 3e857d0b1d028..d178e09f8e03c 100644 --- a/cmd/explaintest/r/access_path_selection.result +++ b/cmd/explaintest/r/access_path_selection.result @@ -15,9 +15,9 @@ IndexReader_6 3323.33 root index:IndexScan_5 └─IndexScan_5 3323.33 cop table:access_path_selection, index:a, b, range:[-inf,3), keep order:false, stats:pseudo explain select a, b from access_path_selection where b < 3; id count task operator info -IndexLookUp_10 3323.33 root -├─IndexScan_8 3323.33 cop table:access_path_selection, index:b, range:[-inf,3), keep order:false, stats:pseudo -└─TableScan_9 3323.33 cop table:access_path_selection, keep order:false, stats:pseudo +IndexLookUp_7 3323.33 root +├─IndexScan_5 3323.33 cop table:access_path_selection, index:b, range:[-inf,3), keep order:false, stats:pseudo +└─TableScan_6 3323.33 cop table:access_path_selection, keep order:false, stats:pseudo explain select a, b from access_path_selection where a < 3 and b < 3; id count task operator info IndexReader_11 1104.45 root index:Selection_10 diff --git a/cmd/explaintest/r/black_list.result b/cmd/explaintest/r/black_list.result new file mode 100644 index 0000000000000..e85d21339a3bf --- /dev/null +++ b/cmd/explaintest/r/black_list.result @@ -0,0 +1,56 @@ +use test; +drop table if exists t; +create table t (a int); +explain select * from t where a < 1; +id count task operator info +TableReader_7 3323.33 root data:Selection_6 +└─Selection_6 3323.33 cop lt(test.t.a, 1) + └─TableScan_5 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo +insert into mysql.opt_rule_blacklist values('predicate_push_down'); +admin reload opt_rule_blacklist; + +explain select * from t where a < 1; +id count task operator info +Selection_5 8000.00 root lt(test.t.a, 1) +└─TableReader_7 10000.00 root data:TableScan_6 + └─TableScan_6 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo +delete from mysql.opt_rule_blacklist where name='predicate_push_down'; +admin reload opt_rule_blacklist; + +explain select * from t where a < 1; +id count task operator info +TableReader_7 3323.33 root data:Selection_6 +└─Selection_6 3323.33 cop lt(test.t.a, 1) + └─TableScan_5 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo +insert into mysql.expr_pushdown_blacklist values('<'); +admin reload expr_pushdown_blacklist; + +explain select * from t where a < 1; +id count task operator info +Selection_5 8000.00 root lt(test.t.a, 1) +└─TableReader_7 10000.00 root data:TableScan_6 + └─TableScan_6 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo +delete from mysql.expr_pushdown_blacklist where name='<'; +admin reload expr_pushdown_blacklist; + +explain select * from t where a < 1; +id count task operator info +TableReader_7 3323.33 root data:Selection_6 +└─Selection_6 3323.33 cop lt(test.t.a, 1) + └─TableScan_5 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo +insert into mysql.expr_pushdown_blacklist values('lt'); +admin reload expr_pushdown_blacklist; + +explain select * from t where a < 1; +id count task operator info +Selection_5 8000.00 root lt(test.t.a, 1) +└─TableReader_7 10000.00 root data:TableScan_6 + └─TableScan_6 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo +delete from mysql.expr_pushdown_blacklist where name='lt'; +admin reload expr_pushdown_blacklist; + +explain select * from t where a < 1; +id count task operator info +TableReader_7 3323.33 root data:Selection_6 +└─Selection_6 3323.33 cop lt(test.t.a, 1) + └─TableScan_5 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo diff --git a/cmd/explaintest/r/explain_complex.result b/cmd/explaintest/r/explain_complex.result index 64da664d9adb6..6f273706d1e94 100644 --- a/cmd/explaintest/r/explain_complex.result +++ b/cmd/explaintest/r/explain_complex.result @@ -156,9 +156,9 @@ Projection_10 0.00 root test.dt.id, test.dt.aid, test.dt.pt, test.dt.dic, test.d ├─TableReader_41 0.00 root data:Selection_40 │ └─Selection_40 0.00 cop eq(test.dt.bm, 0), eq(test.dt.pt, "ios"), gt(test.dt.t, 1478185592), not(isnull(test.dt.dic)) │ └─TableScan_39 10000.00 cop table:dt, range:[0,+inf], keep order:false, stats:pseudo - └─IndexLookUp_18 3.33 root + └─IndexLookUp_18 0.00 root ├─IndexScan_15 10.00 cop table:rr, index:aid, dic, range: decided by [eq(test.rr.aid, test.dt.aid) eq(test.rr.dic, test.dt.dic)], keep order:false, stats:pseudo - └─Selection_17 3.33 cop eq(test.rr.pt, "ios"), gt(test.rr.t, 1478185592) + └─Selection_17 0.00 cop eq(test.rr.pt, "ios"), gt(test.rr.t, 1478185592) └─TableScan_16 10.00 cop table:rr, keep order:false, stats:pseudo explain select pc,cr,count(DISTINCT uid) as pay_users,count(oid) as pay_times,sum(am) as am from pp where ps=2 and ppt>=1478188800 and ppt<1478275200 and pi in ('510017','520017') and uid in ('18089709','18090780') group by pc,cr; id count task operator info @@ -200,3 +200,62 @@ HashAgg_34 72000.00 root group by:col_1, funcs:sum(col_0) │ └─TableScan_58 10000.00 cop table:tbl_008, range:[-inf,+inf], keep order:false, stats:pseudo └─TableReader_62 10000.00 root data:TableScan_61 └─TableScan_61 10000.00 cop table:tbl_009, range:[-inf,+inf], keep order:false, stats:pseudo +CREATE TABLE org_department ( +id int(11) NOT NULL AUTO_INCREMENT, +ctx int(11) DEFAULT '0' COMMENT 'organization id', +name varchar(128) DEFAULT NULL, +left_value int(11) DEFAULT NULL, +right_value int(11) DEFAULT NULL, +depth int(11) DEFAULT NULL, +leader_id bigint(20) DEFAULT NULL, +status int(11) DEFAULT '1000', +created_on datetime DEFAULT NULL, +updated_on datetime DEFAULT NULL, +PRIMARY KEY (id), +UNIQUE KEY org_department_id_uindex (id), +KEY org_department_leader_id_index (leader_id), +KEY org_department_ctx_index (ctx) +); +CREATE TABLE org_position ( +id int(11) NOT NULL AUTO_INCREMENT, +ctx int(11) DEFAULT NULL, +name varchar(128) DEFAULT NULL, +left_value int(11) DEFAULT NULL, +right_value int(11) DEFAULT NULL, +depth int(11) DEFAULT NULL, +department_id int(11) DEFAULT NULL, +status int(2) DEFAULT NULL, +created_on datetime DEFAULT NULL, +updated_on datetime DEFAULT NULL, +PRIMARY KEY (id), +UNIQUE KEY org_position_id_uindex (id), +KEY org_position_department_id_index (department_id) +) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8; +CREATE TABLE org_employee_position ( +hotel_id int(11) DEFAULT NULL, +user_id bigint(20) DEFAULT NULL, +position_id int(11) DEFAULT NULL, +status int(11) DEFAULT NULL, +created_on datetime DEFAULT NULL, +updated_on datetime DEFAULT NULL, +UNIQUE KEY org_employee_position_pk (hotel_id,user_id,position_id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; +explain SELECT d.id, d.ctx, d.name, d.left_value, d.right_value, d.depth, d.leader_id, d.status, d.created_on, d.updated_on FROM org_department AS d LEFT JOIN org_position AS p ON p.department_id = d.id AND p.status = 1000 LEFT JOIN org_employee_position AS ep ON ep.position_id = p.id AND ep.status = 1000 WHERE (d.ctx = 1 AND (ep.user_id = 62 OR d.id = 20 OR d.id = 20) AND d.status = 1000) GROUP BY d.id ORDER BY d.left_value; +id count task operator info +Sort_10 1.00 root test.d.left_value:asc +└─HashAgg_15 1.00 root group by:test.d.id, funcs:firstrow(test.d.id), firstrow(test.d.ctx), firstrow(test.d.name), firstrow(test.d.left_value), firstrow(test.d.right_value), firstrow(test.d.depth), firstrow(test.d.leader_id), firstrow(test.d.status), firstrow(test.d.created_on), firstrow(test.d.updated_on) + └─Selection_20 0.01 root or(eq(test.ep.user_id, 62), or(eq(test.d.id, 20), eq(test.d.id, 20))) + └─HashLeftJoin_21 0.02 root left outer join, inner:TableReader_55, equal:[eq(test.p.id, test.ep.position_id)] + ├─IndexJoin_29 0.01 root left outer join, inner:IndexLookUp_28, outer key:test.d.id, inner key:test.p.department_id + │ ├─IndexLookUp_45 0.01 root + │ │ ├─IndexScan_42 10.00 cop table:d, index:ctx, range:[1,1], keep order:false, stats:pseudo + │ │ └─Selection_44 0.01 cop eq(test.d.status, 1000) + │ │ └─TableScan_43 10.00 cop table:org_department, keep order:false, stats:pseudo + │ └─IndexLookUp_28 0.01 root + │ ├─Selection_26 9.99 cop not(isnull(test.p.department_id)) + │ │ └─IndexScan_24 10.00 cop table:p, index:department_id, range: decided by [eq(test.p.department_id, test.d.id)], keep order:false, stats:pseudo + │ └─Selection_27 0.01 cop eq(test.p.status, 1000) + │ └─TableScan_25 9.99 cop table:org_position, keep order:false, stats:pseudo + └─TableReader_55 9.99 root data:Selection_54 + └─Selection_54 9.99 cop eq(test.ep.status, 1000), not(isnull(test.ep.position_id)) + └─TableScan_53 10000.00 cop table:ep, range:[-inf,+inf], keep order:false, stats:pseudo diff --git a/cmd/explaintest/r/explain_complex_stats.result b/cmd/explaintest/r/explain_complex_stats.result index 32e2eb653c670..419d1ba3834d8 100644 --- a/cmd/explaintest/r/explain_complex_stats.result +++ b/cmd/explaintest/r/explain_complex_stats.result @@ -115,13 +115,13 @@ PRIMARY KEY (aid,dic) load stats 's/explain_complex_stats_rr.json'; explain SELECT ds, p1, p2, p3, p4, p5, p6_md5, p7_md5, count(dic) as install_device FROM dt use index (cm) WHERE (ds >= '2016-09-01') AND (ds <= '2016-11-03') AND (cm IN ('1062', '1086', '1423', '1424', '1425', '1426', '1427', '1428', '1429', '1430', '1431', '1432', '1433', '1434', '1435', '1436', '1437', '1438', '1439', '1440', '1441', '1442', '1443', '1444', '1445', '1446', '1447', '1448', '1449', '1450', '1451', '1452', '1488', '1489', '1490', '1491', '1492', '1493', '1494', '1495', '1496', '1497', '1550', '1551', '1552', '1553', '1554', '1555', '1556', '1557', '1558', '1559', '1597', '1598', '1599', '1600', '1601', '1602', '1603', '1604', '1605', '1606', '1607', '1608', '1609', '1610', '1611', '1612', '1613', '1614', '1615', '1616', '1623', '1624', '1625', '1626', '1627', '1628', '1629', '1630', '1631', '1632', '1709', '1719', '1720', '1843', '2813', '2814', '2815', '2816', '2817', '2818', '2819', '2820', '2821', '2822', '2823', '2824', '2825', '2826', '2827', '2828', '2829', '2830', '2831', '2832', '2833', '2834', '2835', '2836', '2837', '2838', '2839', '2840', '2841', '2842', '2843', '2844', '2845', '2846', '2847', '2848', '2849', '2850', '2851', '2852', '2853', '2854', '2855', '2856', '2857', '2858', '2859', '2860', '2861', '2862', '2863', '2864', '2865', '2866', '2867', '2868', '2869', '2870', '2871', '2872', '3139', '3140', '3141', '3142', '3143', '3144', '3145', '3146', '3147', '3148', '3149', '3150', '3151', '3152', '3153', '3154', '3155', '3156', '3157', '3158', '3386', '3387', '3388', '3389', '3390', '3391', '3392', '3393', '3394', '3395', '3664', '3665', '3666', '3667', '3668', '3670', '3671', '3672', '3673', '3674', '3676', '3677', '3678', '3679', '3680', '3681', '3682', '3683', '3684', '3685', '3686', '3687', '3688', '3689', '3690', '3691', '3692', '3693', '3694', '3695', '3696', '3697', '3698', '3699', '3700', '3701', '3702', '3703', '3704', '3705', '3706', '3707', '3708', '3709', '3710', '3711', '3712', '3713', '3714', '3715', '3960', '3961', '3962', '3963', '3964', '3965', '3966', '3967', '3968', '3978', '3979', '3980', '3981', '3982', '3983', '3984', '3985', '3986', '3987', '4208', '4209', '4210', '4211', '4212', '4304', '4305', '4306', '4307', '4308', '4866', '4867', '4868', '4869', '4870', '4871', '4872', '4873', '4874', '4875')) GROUP BY ds, p1, p2, p3, p4, p5, p6_md5, p7_md5 ORDER BY ds2 DESC; id count task operator info -Projection_7 21.40 root test.dt.ds, test.dt.p1, test.dt.p2, test.dt.p3, test.dt.p4, test.dt.p5, test.dt.p6_md5, test.dt.p7_md5, install_device -└─Sort_8 21.40 root test.dt.ds2:desc - └─HashAgg_16 21.40 root group by:col_10, col_11, col_12, col_13, col_14, col_15, col_16, col_17, funcs:count(col_0), firstrow(col_1), firstrow(col_2), firstrow(col_3), firstrow(col_4), firstrow(col_5), firstrow(col_6), firstrow(col_7), firstrow(col_8), firstrow(col_9) - └─IndexLookUp_17 21.40 root +Projection_7 21.53 root test.dt.ds, test.dt.p1, test.dt.p2, test.dt.p3, test.dt.p4, test.dt.p5, test.dt.p6_md5, test.dt.p7_md5, install_device +└─Sort_8 21.53 root test.dt.ds2:desc + └─HashAgg_16 21.53 root group by:col_10, col_11, col_12, col_13, col_14, col_15, col_16, col_17, funcs:count(col_0), firstrow(col_1), firstrow(col_2), firstrow(col_3), firstrow(col_4), firstrow(col_5), firstrow(col_6), firstrow(col_7), firstrow(col_8), firstrow(col_9) + └─IndexLookUp_17 21.53 root ├─IndexScan_13 128.32 cop table:dt, index:cm, range:[1062,1062], [1086,1086], [1423,1423], [1424,1424], [1425,1425], [1426,1426], [1427,1427], [1428,1428], [1429,1429], [1430,1430], [1431,1431], [1432,1432], [1433,1433], [1434,1434], [1435,1435], [1436,1436], [1437,1437], [1438,1438], [1439,1439], [1440,1440], [1441,1441], [1442,1442], [1443,1443], [1444,1444], [1445,1445], [1446,1446], [1447,1447], [1448,1448], [1449,1449], [1450,1450], [1451,1451], [1452,1452], [1488,1488], [1489,1489], [1490,1490], [1491,1491], [1492,1492], [1493,1493], [1494,1494], [1495,1495], [1496,1496], [1497,1497], [1550,1550], [1551,1551], [1552,1552], [1553,1553], [1554,1554], [1555,1555], [1556,1556], [1557,1557], [1558,1558], [1559,1559], [1597,1597], [1598,1598], [1599,1599], [1600,1600], [1601,1601], [1602,1602], [1603,1603], [1604,1604], [1605,1605], [1606,1606], [1607,1607], [1608,1608], [1609,1609], [1610,1610], [1611,1611], [1612,1612], [1613,1613], [1614,1614], [1615,1615], [1616,1616], [1623,1623], [1624,1624], [1625,1625], [1626,1626], [1627,1627], [1628,1628], [1629,1629], [1630,1630], [1631,1631], [1632,1632], [1709,1709], [1719,1719], [1720,1720], [1843,1843], [2813,2813], [2814,2814], [2815,2815], [2816,2816], [2817,2817], [2818,2818], [2819,2819], [2820,2820], [2821,2821], [2822,2822], [2823,2823], [2824,2824], [2825,2825], [2826,2826], [2827,2827], [2828,2828], [2829,2829], [2830,2830], [2831,2831], [2832,2832], [2833,2833], [2834,2834], [2835,2835], [2836,2836], [2837,2837], [2838,2838], [2839,2839], [2840,2840], [2841,2841], [2842,2842], [2843,2843], [2844,2844], [2845,2845], [2846,2846], [2847,2847], [2848,2848], [2849,2849], [2850,2850], [2851,2851], [2852,2852], [2853,2853], [2854,2854], [2855,2855], [2856,2856], [2857,2857], [2858,2858], [2859,2859], [2860,2860], [2861,2861], [2862,2862], [2863,2863], [2864,2864], [2865,2865], [2866,2866], [2867,2867], [2868,2868], [2869,2869], [2870,2870], [2871,2871], [2872,2872], [3139,3139], [3140,3140], [3141,3141], [3142,3142], [3143,3143], [3144,3144], [3145,3145], [3146,3146], [3147,3147], [3148,3148], [3149,3149], [3150,3150], [3151,3151], [3152,3152], [3153,3153], [3154,3154], [3155,3155], [3156,3156], [3157,3157], [3158,3158], [3386,3386], [3387,3387], [3388,3388], [3389,3389], [3390,3390], [3391,3391], [3392,3392], [3393,3393], [3394,3394], [3395,3395], [3664,3664], [3665,3665], [3666,3666], [3667,3667], [3668,3668], [3670,3670], [3671,3671], [3672,3672], [3673,3673], [3674,3674], [3676,3676], [3677,3677], [3678,3678], [3679,3679], [3680,3680], [3681,3681], [3682,3682], [3683,3683], [3684,3684], [3685,3685], [3686,3686], [3687,3687], [3688,3688], [3689,3689], [3690,3690], [3691,3691], [3692,3692], [3693,3693], [3694,3694], [3695,3695], [3696,3696], [3697,3697], [3698,3698], [3699,3699], [3700,3700], [3701,3701], [3702,3702], [3703,3703], [3704,3704], [3705,3705], [3706,3706], [3707,3707], [3708,3708], [3709,3709], [3710,3710], [3711,3711], [3712,3712], [3713,3713], [3714,3714], [3715,3715], [3960,3960], [3961,3961], [3962,3962], [3963,3963], [3964,3964], [3965,3965], [3966,3966], [3967,3967], [3968,3968], [3978,3978], [3979,3979], [3980,3980], [3981,3981], [3982,3982], [3983,3983], [3984,3984], [3985,3985], [3986,3986], [3987,3987], [4208,4208], [4209,4209], [4210,4210], [4211,4211], [4212,4212], [4304,4304], [4305,4305], [4306,4306], [4307,4307], [4308,4308], [4866,4866], [4867,4867], [4868,4868], [4869,4869], [4870,4870], [4871,4871], [4872,4872], [4873,4873], [4874,4874], [4875,4875], keep order:false - └─HashAgg_11 21.40 cop group by:test.dt.ds, test.dt.p1, test.dt.p2, test.dt.p3, test.dt.p4, test.dt.p5, test.dt.p6_md5, test.dt.p7_md5, funcs:count(test.dt.dic), firstrow(test.dt.ds), firstrow(test.dt.ds2), firstrow(test.dt.p1), firstrow(test.dt.p2), firstrow(test.dt.p3), firstrow(test.dt.p4), firstrow(test.dt.p5), firstrow(test.dt.p6_md5), firstrow(test.dt.p7_md5) - └─Selection_15 21.43 cop ge(test.dt.ds, 2016-09-01 00:00:00.000000), le(test.dt.ds, 2016-11-03 00:00:00.000000) + └─HashAgg_11 21.53 cop group by:test.dt.ds, test.dt.p1, test.dt.p2, test.dt.p3, test.dt.p4, test.dt.p5, test.dt.p6_md5, test.dt.p7_md5, funcs:count(test.dt.dic), firstrow(test.dt.ds), firstrow(test.dt.ds2), firstrow(test.dt.p1), firstrow(test.dt.p2), firstrow(test.dt.p3), firstrow(test.dt.p4), firstrow(test.dt.p5), firstrow(test.dt.p6_md5), firstrow(test.dt.p7_md5) + └─Selection_15 21.56 cop ge(test.dt.ds, 2016-09-01 00:00:00.000000), le(test.dt.ds, 2016-11-03 00:00:00.000000) └─TableScan_14 128.32 cop table:dt, keep order:false explain select gad.id as gid,sdk.id as sid,gad.aid as aid,gad.cm as cm,sdk.dic as dic,sdk.ip as ip, sdk.t as t, gad.p1 as p1, gad.p2 as p2, gad.p3 as p3, gad.p4 as p4, gad.p5 as p5, gad.p6_md5 as p6, gad.p7_md5 as p7, gad.ext as ext, gad.t as gtime from st gad join (select id, aid, pt, dic, ip, t from dd where pt = 'android' and bm = 0 and t > 1478143908) sdk on gad.aid = sdk.aid and gad.ip = sdk.ip and sdk.t > gad.t where gad.t > 1478143908 and gad.bm = 0 and gad.pt = 'android' group by gad.aid, sdk.dic limit 2500; id count task operator info @@ -132,9 +132,9 @@ Projection_13 424.00 root test.gad.id, test.dd.id, test.gad.aid, test.gad.cm, te ├─TableReader_29 424.00 root data:Selection_28 │ └─Selection_28 424.00 cop eq(test.gad.bm, 0), eq(test.gad.pt, "android"), gt(test.gad.t, 1478143908), not(isnull(test.gad.ip)) │ └─TableScan_27 1999.00 cop table:gad, range:[0,+inf], keep order:false - └─IndexLookUp_23 455.80 root + └─IndexLookUp_23 0.23 root ├─IndexScan_20 1.00 cop table:dd, index:aid, dic, range: decided by [eq(test.dd.aid, test.gad.aid)], keep order:false - └─Selection_22 455.80 cop eq(test.dd.bm, 0), eq(test.dd.pt, "android"), gt(test.dd.t, 1478143908), not(isnull(test.dd.ip)), not(isnull(test.dd.t)) + └─Selection_22 0.23 cop eq(test.dd.bm, 0), eq(test.dd.pt, "android"), gt(test.dd.t, 1478143908), not(isnull(test.dd.ip)), not(isnull(test.dd.t)) └─TableScan_21 1.00 cop table:dd, keep order:false explain select gad.id as gid,sdk.id as sid,gad.aid as aid,gad.cm as cm,sdk.dic as dic,sdk.ip as ip, sdk.t as t, gad.p1 as p1, gad.p2 as p2, gad.p3 as p3, gad.p4 as p4, gad.p5 as p5, gad.p6_md5 as p6, gad.p7_md5 as p7, gad.ext as ext from st gad join dd sdk on gad.aid = sdk.aid and gad.dic = sdk.mac and gad.t < sdk.t where gad.t > 1477971479 and gad.bm = 0 and gad.pt = 'ios' and gad.dit = 'mac' and sdk.t > 1477971479 and sdk.bm = 0 and sdk.pt = 'ios' limit 3000; id count task operator info @@ -144,9 +144,9 @@ Projection_10 170.34 root test.gad.id, test.sdk.id, test.gad.aid, test.gad.cm, t ├─TableReader_23 170.34 root data:Selection_22 │ └─Selection_22 170.34 cop eq(test.gad.bm, 0), eq(test.gad.dit, "mac"), eq(test.gad.pt, "ios"), gt(test.gad.t, 1477971479), not(isnull(test.gad.dic)) │ └─TableScan_21 1999.00 cop table:gad, range:[0,+inf], keep order:false - └─IndexLookUp_17 509.04 root + └─IndexLookUp_17 0.25 root ├─IndexScan_14 1.00 cop table:sdk, index:aid, dic, range: decided by [eq(test.sdk.aid, test.gad.aid)], keep order:false - └─Selection_16 509.04 cop eq(test.sdk.bm, 0), eq(test.sdk.pt, "ios"), gt(test.sdk.t, 1477971479), not(isnull(test.sdk.mac)), not(isnull(test.sdk.t)) + └─Selection_16 0.25 cop eq(test.sdk.bm, 0), eq(test.sdk.pt, "ios"), gt(test.sdk.t, 1477971479), not(isnull(test.sdk.mac)), not(isnull(test.sdk.t)) └─TableScan_15 1.00 cop table:dd, keep order:false explain SELECT cm, p1, p2, p3, p4, p5, p6_md5, p7_md5, count(1) as click_pv, count(DISTINCT ip) as click_ip FROM st WHERE (t between 1478188800 and 1478275200) and aid='cn.sbkcq' and pt='android' GROUP BY cm, p1, p2, p3, p4, p5, p6_md5, p7_md5; id count task operator info @@ -164,9 +164,9 @@ Projection_10 428.32 root test.dt.id, test.dt.aid, test.dt.pt, test.dt.dic, test ├─TableReader_41 428.32 root data:Selection_40 │ └─Selection_40 428.32 cop eq(test.dt.bm, 0), eq(test.dt.pt, "ios"), gt(test.dt.t, 1478185592), not(isnull(test.dt.dic)) │ └─TableScan_39 2000.00 cop table:dt, range:[0,+inf], keep order:false - └─IndexLookUp_18 970.00 root + └─IndexLookUp_18 0.48 root ├─IndexScan_15 1.00 cop table:rr, index:aid, dic, range: decided by [eq(test.rr.aid, test.dt.aid) eq(test.rr.dic, test.dt.dic)], keep order:false - └─Selection_17 970.00 cop eq(test.rr.pt, "ios"), gt(test.rr.t, 1478185592) + └─Selection_17 0.48 cop eq(test.rr.pt, "ios"), gt(test.rr.t, 1478185592) └─TableScan_16 1.00 cop table:rr, keep order:false explain select pc,cr,count(DISTINCT uid) as pay_users,count(oid) as pay_times,sum(am) as am from pp where ps=2 and ppt>=1478188800 and ppt<1478275200 and pi in ('510017','520017') and uid in ('18089709','18090780') group by pc,cr; id count task operator info diff --git a/cmd/explaintest/r/explain_easy.result b/cmd/explaintest/r/explain_easy.result index 4806f6be3d3b7..387073b759936 100644 --- a/cmd/explaintest/r/explain_easy.result +++ b/cmd/explaintest/r/explain_easy.result @@ -45,10 +45,10 @@ id count task operator info IndexJoin_12 4166.67 root left outer join, inner:IndexLookUp_11, outer key:test.t1.c2, inner key:test.t2.c1 ├─TableReader_24 3333.33 root data:TableScan_23 │ └─TableScan_23 3333.33 cop table:t1, range:(1,+inf], keep order:false, stats:pseudo -└─IndexLookUp_11 0.00 root - ├─Selection_10 0.00 cop not(isnull(test.t2.c1)) +└─IndexLookUp_11 9.99 root + ├─Selection_10 9.99 cop not(isnull(test.t2.c1)) │ └─IndexScan_8 10.00 cop table:t2, index:c1, range: decided by [eq(test.t2.c1, test.t1.c2)], keep order:false, stats:pseudo - └─TableScan_9 0.00 cop table:t2, keep order:false, stats:pseudo + └─TableScan_9 9.99 cop table:t2, keep order:false, stats:pseudo explain update t1 set t1.c2 = 2 where t1.c1 = 1; id count task operator info Point_Get_1 1.00 root table:t1, handle:1 @@ -90,11 +90,11 @@ explain select sum(t1.c1 in (select c1 from t2)) from t1; id count task operator info StreamAgg_12 1.00 root funcs:sum(col_0) └─Projection_19 10000.00 root cast(5_aux_0) - └─HashLeftJoin_18 10000.00 root left outer semi join, inner:TableReader_17, other cond:eq(test.t1.c1, test.t2.c1) - ├─TableReader_15 10000.00 root data:TableScan_14 - │ └─TableScan_14 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo - └─TableReader_17 10000.00 root data:TableScan_16 - └─TableScan_16 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo + └─HashLeftJoin_18 10000.00 root CARTESIAN left outer semi join, inner:IndexReader_17, other cond:eq(test.t1.c1, test.t2.c1) + ├─IndexReader_15 10000.00 root index:IndexScan_14 + │ └─IndexScan_14 10000.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false, stats:pseudo + └─IndexReader_17 10000.00 root index:IndexScan_16 + └─IndexScan_16 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:false, stats:pseudo explain select c1 from t1 where c1 in (select c2 from t2); id count task operator info Projection_9 9990.00 root test.t1.c1 @@ -122,15 +122,14 @@ MemTableScan_4 10000.00 root explain select c2 = (select c2 from t2 where t1.c1 = t2.c1 order by c1 limit 1) from t1; id count task operator info Projection_12 10000.00 root eq(test.t1.c2, test.t2.c2) -└─Apply_14 10000.00 root left outer join, inner:Limit_21 - ├─TableReader_16 10000.00 root data:TableScan_15 - │ └─TableScan_15 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo - └─Limit_21 1.00 root offset:0, count:1 - └─Projection_41 1.00 root test.t2.c1, test.t2.c2 - └─IndexLookUp_40 1.00 root - ├─Limit_39 1.00 cop offset:0, count:1 - │ └─IndexScan_37 1.00 cop table:t2, index:c1, range: decided by [eq(test.t1.c1, test.t2.c1)], keep order:true, stats:pseudo - └─TableScan_38 1.00 cop table:t2, keep order:false, stats:pseudo +└─Apply_14 10000.00 root CARTESIAN left outer join, inner:Projection_41 + ├─IndexReader_16 10000.00 root index:IndexScan_15 + │ └─IndexScan_15 10000.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false, stats:pseudo + └─Projection_41 1.00 root test.t2.c1, test.t2.c2 + └─IndexLookUp_40 1.00 root limit embedded(offset:0, count:1) + ├─Limit_39 1.00 cop offset:0, count:1 + │ └─IndexScan_37 1.00 cop table:t2, index:c1, range: decided by [eq(test.t1.c1, test.t2.c1)], keep order:true, stats:pseudo + └─TableScan_38 1.00 cop table:t2, keep order:false, stats:pseudo explain select * from t1 order by c1 desc limit 1; id count task operator info Limit_10 1.00 root offset:0, count:1 @@ -155,12 +154,12 @@ Limit_8 1.00 root offset:0, count:1 └─TableScan_11 3.00 cop table:t4, range:(1,+inf], keep order:false, stats:pseudo explain select ifnull(null, t1.c1) from t1; id count task operator info -TableReader_5 10000.00 root data:TableScan_4 -└─TableScan_4 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +IndexReader_5 10000.00 root index:IndexScan_4 +└─IndexScan_4 10000.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false, stats:pseudo explain select if(10, t1.c1, t1.c2) from t1; id count task operator info -TableReader_5 10000.00 root data:TableScan_4 -└─TableScan_4 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +IndexReader_5 10000.00 root index:IndexScan_4 +└─IndexScan_4 10000.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false, stats:pseudo explain select c1 from t2 union select c1 from t2 union all select c1 from t2; id count task operator info Union_17 26000.00 root @@ -174,8 +173,8 @@ Union_17 26000.00 root │ └─IndexReader_50 8000.00 root index:StreamAgg_41 │ └─StreamAgg_41 8000.00 cop group by:test.t2.c1, funcs:firstrow(test.t2.c1), firstrow(test.t2.c1) │ └─IndexScan_48 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:true, stats:pseudo -└─TableReader_55 10000.00 root data:TableScan_54 - └─TableScan_54 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +└─IndexReader_55 10000.00 root index:IndexScan_54 + └─IndexScan_54 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:false, stats:pseudo explain select c1 from t2 union all select c1 from t2 union select c1 from t2; id count task operator info HashAgg_18 24000.00 root group by:c1, funcs:firstrow(join_agg_0) @@ -192,22 +191,51 @@ HashAgg_18 24000.00 root group by:c1, funcs:firstrow(join_agg_0) └─IndexReader_62 8000.00 root index:StreamAgg_53 └─StreamAgg_53 8000.00 cop group by:test.t2.c1, funcs:firstrow(test.t2.c1), firstrow(test.t2.c1) └─IndexScan_60 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:true, stats:pseudo +explain select count(1) from (select count(1) from (select * from t1 where c3 = 100) k) k2; +id count task operator info +StreamAgg_13 1.00 root funcs:count(1) +└─StreamAgg_28 1.00 root funcs:firstrow(col_0) + └─TableReader_29 1.00 root data:StreamAgg_17 + └─StreamAgg_17 1.00 cop funcs:firstrow(1) + └─Selection_27 10.00 cop eq(test.t1.c3, 100) + └─TableScan_26 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +explain select 1 from (select count(c2), count(c3) from t1) k; +id count task operator info +Projection_5 1.00 root 1 +└─StreamAgg_17 1.00 root funcs:firstrow(col_0) + └─IndexReader_18 1.00 root index:StreamAgg_9 + └─StreamAgg_9 1.00 cop funcs:firstrow(1) + └─IndexScan_16 10000.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false, stats:pseudo +explain select count(1) from (select max(c2), count(c3) as m from t1) k; +id count task operator info +StreamAgg_11 1.00 root funcs:count(1) +└─StreamAgg_23 1.00 root funcs:firstrow(col_0) + └─IndexReader_24 1.00 root index:StreamAgg_15 + └─StreamAgg_15 1.00 cop funcs:firstrow(1) + └─IndexScan_22 10000.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false, stats:pseudo +explain select count(1) from (select count(c2) from t1 group by c3) k; +id count task operator info +StreamAgg_11 1.00 root funcs:count(1) +└─HashAgg_23 8000.00 root group by:col_1, funcs:firstrow(col_0) + └─TableReader_24 8000.00 root data:HashAgg_20 + └─HashAgg_20 8000.00 cop group by:test.t1.c3, funcs:firstrow(1) + └─TableScan_15 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo set @@session.tidb_opt_insubq_to_join_and_agg=0; explain select sum(t1.c1 in (select c1 from t2)) from t1; id count task operator info StreamAgg_12 1.00 root funcs:sum(col_0) └─Projection_19 10000.00 root cast(5_aux_0) - └─HashLeftJoin_18 10000.00 root left outer semi join, inner:TableReader_17, other cond:eq(test.t1.c1, test.t2.c1) - ├─TableReader_15 10000.00 root data:TableScan_14 - │ └─TableScan_14 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo - └─TableReader_17 10000.00 root data:TableScan_16 - └─TableScan_16 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo + └─HashLeftJoin_18 10000.00 root CARTESIAN left outer semi join, inner:IndexReader_17, other cond:eq(test.t1.c1, test.t2.c1) + ├─IndexReader_15 10000.00 root index:IndexScan_14 + │ └─IndexScan_14 10000.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false, stats:pseudo + └─IndexReader_17 10000.00 root index:IndexScan_16 + └─IndexScan_16 10000.00 cop table:t2, index:c1, range:[NULL,+inf], keep order:false, stats:pseudo explain select 1 in (select c2 from t2) from t1; id count task operator info Projection_6 10000.00 root 5_aux_0 -└─HashLeftJoin_7 10000.00 root left outer semi join, inner:TableReader_12 - ├─TableReader_9 10000.00 root data:TableScan_8 - │ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +└─HashLeftJoin_7 10000.00 root CARTESIAN left outer semi join, inner:TableReader_12 + ├─IndexReader_9 10000.00 root index:IndexScan_8 + │ └─IndexScan_8 10000.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false, stats:pseudo └─TableReader_12 10.00 root data:Selection_11 └─Selection_11 10.00 cop eq(1, test.t2.c2) └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo @@ -215,9 +243,9 @@ explain select sum(6 in (select c2 from t2)) from t1; id count task operator info StreamAgg_12 1.00 root funcs:sum(col_0) └─Projection_20 10000.00 root cast(5_aux_0) - └─HashLeftJoin_19 10000.00 root left outer semi join, inner:TableReader_18 - ├─TableReader_15 10000.00 root data:TableScan_14 - │ └─TableScan_14 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─HashLeftJoin_19 10000.00 root CARTESIAN left outer semi join, inner:TableReader_18 + ├─IndexReader_15 10000.00 root index:IndexScan_14 + │ └─IndexScan_14 10000.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false, stats:pseudo └─TableReader_18 10.00 root data:Selection_17 └─Selection_17 10.00 cop eq(6, test.t2.c2) └─TableScan_16 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo @@ -231,23 +259,23 @@ color=black label = "root" "StreamAgg_12" -> "Projection_19" "Projection_19" -> "HashLeftJoin_18" -"HashLeftJoin_18" -> "TableReader_15" -"HashLeftJoin_18" -> "TableReader_17" +"HashLeftJoin_18" -> "IndexReader_15" +"HashLeftJoin_18" -> "IndexReader_17" } subgraph cluster14{ node [style=filled, color=lightgrey] color=black label = "cop" -"TableScan_14" +"IndexScan_14" } subgraph cluster16{ node [style=filled, color=lightgrey] color=black label = "cop" -"TableScan_16" +"IndexScan_16" } -"TableReader_15" -> "TableScan_14" -"TableReader_17" -> "TableScan_16" +"IndexReader_15" -> "IndexScan_14" +"IndexReader_17" -> "IndexScan_16" } explain format="dot" select 1 in (select c2 from t2) from t1; @@ -259,14 +287,14 @@ node [style=filled, color=lightgrey] color=black label = "root" "Projection_6" -> "HashLeftJoin_7" -"HashLeftJoin_7" -> "TableReader_9" +"HashLeftJoin_7" -> "IndexReader_9" "HashLeftJoin_7" -> "TableReader_12" } subgraph cluster8{ node [style=filled, color=lightgrey] color=black label = "cop" -"TableScan_8" +"IndexScan_8" } subgraph cluster11{ node [style=filled, color=lightgrey] @@ -274,7 +302,7 @@ color=black label = "cop" "Selection_11" -> "TableScan_10" } -"TableReader_9" -> "TableScan_8" +"IndexReader_9" -> "IndexScan_8" "TableReader_12" -> "Selection_11" } @@ -284,7 +312,7 @@ create table t(a int primary key, b int, c int, index idx(b)); explain select t.c in (select count(*) from t s ignore index(idx), t t1 where s.a = t.a and s.a = t1.a) from t; id count task operator info Projection_11 10000.00 root 9_aux_0 -└─Apply_13 10000.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) +└─Apply_13 10000.00 root CARTESIAN left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) ├─TableReader_15 10000.00 root data:TableScan_14 │ └─TableScan_14 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo └─StreamAgg_20 1.00 root funcs:count(1) @@ -297,7 +325,7 @@ Projection_11 10000.00 root 9_aux_0 explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.a = t1.a) from t; id count task operator info Projection_11 10000.00 root 9_aux_0 -└─Apply_13 10000.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) +└─Apply_13 10000.00 root CARTESIAN left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) ├─TableReader_15 10000.00 root data:TableScan_14 │ └─TableScan_14 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo └─StreamAgg_20 1.00 root funcs:count(1) @@ -309,7 +337,7 @@ Projection_11 10000.00 root 9_aux_0 explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = t.a and s.c = t1.a) from t; id count task operator info Projection_11 10000.00 root 9_aux_0 -└─Apply_13 10000.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) +└─Apply_13 10000.00 root CARTESIAN left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) ├─TableReader_15 10000.00 root data:TableScan_14 │ └─TableScan_14 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo └─StreamAgg_20 1.00 root funcs:count(1) @@ -325,7 +353,7 @@ analyze table t; explain select t.c in (select count(*) from t s, t t1 where s.b = t.a and s.b = 3 and s.a = t1.a) from t; id count task operator info Projection_11 5.00 root 9_aux_0 -└─Apply_13 5.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) +└─Apply_13 5.00 root CARTESIAN left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) ├─TableReader_15 5.00 root data:TableScan_14 │ └─TableScan_14 5.00 cop table:t, range:[-inf,+inf], keep order:false └─StreamAgg_20 1.00 root funcs:count(1) @@ -339,7 +367,7 @@ Projection_11 5.00 root 9_aux_0 explain select t.c in (select count(*) from t s left join t t1 on s.a = t1.a where 3 = t.a and s.b = 3) from t; id count task operator info Projection_10 5.00 root 9_aux_0 -└─Apply_12 5.00 root left outer semi join, inner:StreamAgg_19, other cond:eq(test.t.c, 7_col_0) +└─Apply_12 5.00 root CARTESIAN left outer semi join, inner:StreamAgg_19, other cond:eq(test.t.c, 7_col_0) ├─TableReader_14 5.00 root data:TableScan_13 │ └─TableScan_13 5.00 cop table:t, range:[-inf,+inf], keep order:false └─StreamAgg_19 1.00 root funcs:count(1) @@ -353,7 +381,7 @@ Projection_10 5.00 root 9_aux_0 explain select t.c in (select count(*) from t s right join t t1 on s.a = t1.a where 3 = t.a and t1.b = 3) from t; id count task operator info Projection_10 5.00 root 9_aux_0 -└─Apply_12 5.00 root left outer semi join, inner:StreamAgg_19, other cond:eq(test.t.c, 7_col_0) +└─Apply_12 5.00 root CARTESIAN left outer semi join, inner:StreamAgg_19, other cond:eq(test.t.c, 7_col_0) ├─TableReader_14 5.00 root data:TableScan_13 │ └─TableScan_13 5.00 cop table:t, range:[-inf,+inf], keep order:false └─StreamAgg_19 1.00 root funcs:count(1) @@ -394,9 +422,9 @@ IndexReader_6 10.00 root index:IndexScan_5 └─IndexScan_5 10.00 cop table:t, index:a, b, range:[1,1], keep order:false, stats:pseudo explain select * from t where b in (1, 2) and b in (1, 3); id count task operator info -TableReader_7 10.00 root data:Selection_6 -└─Selection_6 10.00 cop in(test.t.b, 1, 2), in(test.t.b, 1, 3) - └─TableScan_5 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo +IndexReader_7 10.00 root index:Selection_6 +└─Selection_6 10.00 cop eq(test.t.b, 1) + └─IndexScan_5 10000.00 cop table:t, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo explain select * from t where a = 1 and a = 1; id count task operator info IndexReader_6 10.00 root index:IndexScan_5 @@ -410,20 +438,20 @@ TableDual_5 0.00 root rows:0 explain select * from t t1 join t t2 where t1.b = t2.b and t2.b is null; id count task operator info Projection_7 0.00 root test.t1.a, test.t1.b, test.t2.a, test.t2.b -└─HashRightJoin_9 0.00 root inner join, inner:TableReader_12, equal:[eq(test.t2.b, test.t1.b)] - ├─TableReader_12 0.00 root data:Selection_11 +└─HashRightJoin_9 0.00 root inner join, inner:IndexReader_12, equal:[eq(test.t2.b, test.t1.b)] + ├─IndexReader_12 0.00 root index:Selection_11 │ └─Selection_11 0.00 cop isnull(test.t2.b), not(isnull(test.t2.b)) - │ └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo - └─TableReader_15 9990.00 root data:Selection_14 + │ └─IndexScan_10 10000.00 cop table:t2, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo + └─IndexReader_15 9990.00 root index:Selection_14 └─Selection_14 9990.00 cop not(isnull(test.t1.b)) - └─TableScan_13 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_13 10000.00 cop table:t1, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo explain select * from t t1 where not exists (select * from t t2 where t1.b = t2.b); id count task operator info -HashLeftJoin_9 8000.00 root anti semi join, inner:TableReader_13, equal:[eq(test.t1.b, test.t2.b)] -├─TableReader_11 10000.00 root data:TableScan_10 -│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_13 10000.00 root data:TableScan_12 - └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +HashLeftJoin_9 8000.00 root anti semi join, inner:IndexReader_13, equal:[eq(test.t1.b, test.t2.b)] +├─IndexReader_11 10000.00 root index:IndexScan_10 +│ └─IndexScan_10 10000.00 cop table:t1, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_13 10000.00 root index:IndexScan_12 + └─IndexScan_12 10000.00 cop table:t2, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo drop table if exists t; create table t(a bigint primary key); explain select * from t where a = 1 and a = 2; @@ -434,6 +462,9 @@ id count task operator info Projection_3 10000.00 root or(NULL, gt(test.t.a, 1)) └─TableReader_5 10000.00 root data:TableScan_4 └─TableScan_4 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo +explain select * from t where a = 1 for update; +id count task operator info +Point_Get_1 1.00 root table:t, handle:1, lock drop table if exists ta, tb; create table ta (a varchar(20)); create table tb (a varchar(20)); @@ -452,14 +483,12 @@ create table t1(a int, b int, c int, primary key(a, b)); create table t2(a int, b int, c int, primary key(a)); explain select t1.a, t1.b from t1 left outer join t2 on t1.a = t2.a; id count task operator info -TableReader_7 10000.00 root data:TableScan_6 -└─TableScan_6 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo +IndexReader_7 10000.00 root index:IndexScan_6 +└─IndexScan_6 10000.00 cop table:t1, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo explain select distinct t1.a, t1.b from t1 left outer join t2 on t1.a = t2.a; id count task operator info -StreamAgg_18 8000.00 root group by:col_2, col_3, funcs:firstrow(col_0), firstrow(col_1) -└─IndexReader_19 8000.00 root index:StreamAgg_10 - └─StreamAgg_10 8000.00 cop group by:test.t1.a, test.t1.b, funcs:firstrow(test.t1.a), firstrow(test.t1.b) - └─IndexScan_17 10000.00 cop table:t1, index:a, b, range:[NULL,+inf], keep order:true, stats:pseudo +IndexReader_9 10000.00 root index:IndexScan_8 +└─IndexScan_8 10000.00 cop table:t1, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo drop table if exists t; create table t(a int, nb int not null, nc int not null); explain select ifnull(a, 0) from t; @@ -530,7 +559,7 @@ HashRightJoin_9 4166.67 root inner join, inner:TableReader_12, equal:[eq(test.ta explain select ifnull(t.nc, 1) in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t; id count task operator info Projection_12 10000.00 root 9_aux_0 -└─Apply_14 10000.00 root left outer semi join, inner:HashAgg_19, other cond:eq(test.t.nc, 7_col_0) +└─Apply_14 10000.00 root CARTESIAN left outer semi join, inner:HashAgg_19, other cond:eq(test.t.nc, 7_col_0) ├─TableReader_16 10000.00 root data:TableScan_15 │ └─TableScan_15 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo └─HashAgg_19 1.00 root funcs:count(join_agg_0) @@ -563,7 +592,7 @@ HashRightJoin_7 8000.00 root right outer join, inner:TableReader_10, equal:[eq(t explain select ifnull(t.a, 1) in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t; id count task operator info Projection_12 10000.00 root 9_aux_0 -└─Apply_14 10000.00 root left outer semi join, inner:HashAgg_19, other cond:eq(ifnull(test.t.a, 1), 7_col_0) +└─Apply_14 10000.00 root CARTESIAN left outer semi join, inner:HashAgg_19, other cond:eq(ifnull(test.t.a, 1), 7_col_0) ├─TableReader_16 10000.00 root data:TableScan_15 │ └─TableScan_15 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo └─HashAgg_19 1.00 root funcs:count(join_agg_0) @@ -660,3 +689,17 @@ Projection_8 8320.83 root test.t.a, test.t1.a └─TableScan_15 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo rollback; drop table if exists t; +create table t(a time, b date); +insert into t values (1, "1000-01-01"), (2, "1000-01-02"), (3, "1000-01-03"); +analyze table t; +explain select * from t where a = 1; +id count task operator info +TableReader_7 1.00 root data:Selection_6 +└─Selection_6 1.00 cop eq(test.t.a, 00:00:01.000000) + └─TableScan_5 3.00 cop table:t, range:[-inf,+inf], keep order:false +explain select * from t where b = "1000-01-01"; +id count task operator info +TableReader_7 1.00 root data:Selection_6 +└─Selection_6 1.00 cop eq(test.t.b, 1000-01-01 00:00:00.000000) + └─TableScan_5 3.00 cop table:t, range:[-inf,+inf], keep order:false +drop table t; diff --git a/cmd/explaintest/r/explain_easy_stats.result b/cmd/explaintest/r/explain_easy_stats.result index 44f693e962177..4c0329fd824de 100644 --- a/cmd/explaintest/r/explain_easy_stats.result +++ b/cmd/explaintest/r/explain_easy_stats.result @@ -108,15 +108,14 @@ MemTableScan_4 10000.00 root explain select c2 = (select c2 from t2 where t1.c1 = t2.c1 order by c1 limit 1) from t1; id count task operator info Projection_12 1999.00 root eq(test.t1.c2, test.t2.c2) -└─Apply_14 1999.00 root left outer join, inner:Limit_21 - ├─TableReader_16 1999.00 root data:TableScan_15 - │ └─TableScan_15 1999.00 cop table:t1, range:[-inf,+inf], keep order:false - └─Limit_21 1.00 root offset:0, count:1 - └─Projection_41 1.00 root test.t2.c1, test.t2.c2 - └─IndexLookUp_40 1.00 root - ├─Limit_39 1.00 cop offset:0, count:1 - │ └─IndexScan_37 1.25 cop table:t2, index:c1, range: decided by [eq(test.t1.c1, test.t2.c1)], keep order:true - └─TableScan_38 1.00 cop table:t2, keep order:false, stats:pseudo +└─Apply_14 1999.00 root CARTESIAN left outer join, inner:Projection_41 + ├─IndexReader_16 1999.00 root index:IndexScan_15 + │ └─IndexScan_15 1999.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false + └─Projection_41 1.00 root test.t2.c1, test.t2.c2 + └─IndexLookUp_40 1.00 root limit embedded(offset:0, count:1) + ├─Limit_39 1.00 cop offset:0, count:1 + │ └─IndexScan_37 1.25 cop table:t2, index:c1, range: decided by [eq(test.t1.c1, test.t2.c1)], keep order:true + └─TableScan_38 1.00 cop table:t2, keep order:false, stats:pseudo explain select * from t1 order by c1 desc limit 1; id count task operator info Limit_10 1.00 root offset:0, count:1 @@ -127,9 +126,9 @@ set @@session.tidb_opt_insubq_to_join_and_agg=0; explain select 1 in (select c2 from t2) from t1; id count task operator info Projection_6 1999.00 root 5_aux_0 -└─HashLeftJoin_7 1999.00 root left outer semi join, inner:TableReader_12 - ├─TableReader_9 1999.00 root data:TableScan_8 - │ └─TableScan_8 1999.00 cop table:t1, range:[-inf,+inf], keep order:false +└─HashLeftJoin_7 1999.00 root CARTESIAN left outer semi join, inner:TableReader_12 + ├─IndexReader_9 1999.00 root index:IndexScan_8 + │ └─IndexScan_8 1999.00 cop table:t1, index:c2, range:[NULL,+inf], keep order:false └─TableReader_12 0.00 root data:Selection_11 └─Selection_11 0.00 cop eq(1, test.t2.c2) └─TableScan_10 1985.00 cop table:t2, range:[-inf,+inf], keep order:false @@ -142,14 +141,14 @@ node [style=filled, color=lightgrey] color=black label = "root" "Projection_6" -> "HashLeftJoin_7" -"HashLeftJoin_7" -> "TableReader_9" +"HashLeftJoin_7" -> "IndexReader_9" "HashLeftJoin_7" -> "TableReader_12" } subgraph cluster8{ node [style=filled, color=lightgrey] color=black label = "cop" -"TableScan_8" +"IndexScan_8" } subgraph cluster11{ node [style=filled, color=lightgrey] @@ -157,7 +156,7 @@ color=black label = "cop" "Selection_11" -> "TableScan_10" } -"TableReader_9" -> "TableScan_8" +"IndexReader_9" -> "IndexScan_8" "TableReader_12" -> "Selection_11" } @@ -169,18 +168,16 @@ id count task operator info TableDual_5 0.00 root rows:0 explain select * from index_prune WHERE a = 1010010404050976781 AND b = 26467085526790 LIMIT 1, 1; id count task operator info -Limit_9 1.00 root offset:1, count:1 -└─IndexLookUp_14 1.00 root - ├─Limit_13 1.00 cop offset:0, count:2 - │ └─IndexScan_11 1.00 cop table:index_prune, index:a, b, range:[1010010404050976781 26467085526790,1010010404050976781 26467085526790], keep order:false - └─TableScan_12 1.00 cop table:index_prune, keep order:false, stats:pseudo +IndexLookUp_14 1.00 root limit embedded(offset:1, count:1) +├─Limit_13 1.00 cop offset:0, count:2 +│ └─IndexScan_11 1.00 cop table:index_prune, index:a, b, range:[1010010404050976781 26467085526790,1010010404050976781 26467085526790], keep order:false +└─TableScan_12 1.00 cop table:index_prune, keep order:false, stats:pseudo explain select * from index_prune WHERE a = 1010010404050976781 AND b = 26467085526790 LIMIT 1, 0; id count task operator info -Limit_9 0.00 root offset:1, count:0 -└─IndexLookUp_14 0.00 root - ├─Limit_13 0.00 cop offset:0, count:1 - │ └─IndexScan_11 1.00 cop table:index_prune, index:a, b, range:[1010010404050976781 26467085526790,1010010404050976781 26467085526790], keep order:false - └─TableScan_12 0.00 cop table:index_prune, keep order:false, stats:pseudo +IndexLookUp_14 0.00 root limit embedded(offset:1, count:0) +├─Limit_13 0.00 cop offset:0, count:1 +│ └─IndexScan_11 1.00 cop table:index_prune, index:a, b, range:[1010010404050976781 26467085526790,1010010404050976781 26467085526790], keep order:false +└─TableScan_12 0.00 cop table:index_prune, keep order:false, stats:pseudo explain select * from index_prune WHERE a = 1010010404050976781 AND b = 26467085526790 LIMIT 0, 1; id count task operator info Point_Get_1 1.00 root table:index_prune, index:a b diff --git a/cmd/explaintest/r/generated_columns.result b/cmd/explaintest/r/generated_columns.result index 5c10b13bbf593..b8089dfb3a4f1 100644 --- a/cmd/explaintest/r/generated_columns.result +++ b/cmd/explaintest/r/generated_columns.result @@ -32,9 +32,9 @@ IndexReader_6 3323.33 root index:IndexScan_5 └─IndexScan_5 3323.33 cop table:sgc, index:a, b, range:[-inf,3), keep order:false, stats:pseudo EXPLAIN SELECT a, b from sgc where b < 3; id count task operator info -IndexLookUp_10 3323.33 root -├─IndexScan_8 3323.33 cop table:sgc, index:b, range:[-inf,3), keep order:false, stats:pseudo -└─TableScan_9 3323.33 cop table:sgc, keep order:false, stats:pseudo +IndexLookUp_7 3323.33 root +├─IndexScan_5 3323.33 cop table:sgc, index:b, range:[-inf,3), keep order:false, stats:pseudo +└─TableScan_6 3323.33 cop table:sgc, keep order:false, stats:pseudo EXPLAIN SELECT a, b from sgc where a < 3 and b < 3; id count task operator info IndexReader_11 1104.45 root index:Selection_10 @@ -72,10 +72,10 @@ ANALYZE TABLE sgc1, sgc2; EXPLAIN SELECT /*+ TIDB_INLJ(sgc1, sgc2) */ * from sgc1 join sgc2 on sgc1.a=sgc2.a; id count task operator info IndexJoin_17 5.00 root inner join, inner:IndexLookUp_16, outer key:test.sgc2.a, inner key:test.sgc1.a -├─IndexLookUp_16 0.00 root -│ ├─Selection_15 0.00 cop not(isnull(test.sgc1.a)) +├─IndexLookUp_16 5.00 root +│ ├─Selection_15 5.00 cop not(isnull(test.sgc1.a)) │ │ └─IndexScan_13 5.00 cop table:sgc1, index:a, range: decided by [eq(test.sgc1.a, test.sgc2.a)], keep order:false -│ └─TableScan_14 0.00 cop table:sgc1, keep order:false, stats:pseudo +│ └─TableScan_14 5.00 cop table:sgc1, keep order:false, stats:pseudo └─TableReader_20 1.00 root data:Selection_19 └─Selection_19 1.00 cop not(isnull(test.sgc2.a)) └─TableScan_18 1.00 cop table:sgc2, range:[-inf,+inf], keep order:false @@ -86,10 +86,10 @@ Projection_6 5.00 root test.sgc1.j1, test.sgc1.j2, test.sgc1.a, test.sgc1.b, tes ├─TableReader_39 1.00 root data:Selection_38 │ └─Selection_38 1.00 cop not(isnull(test.sgc2.a)) │ └─TableScan_37 1.00 cop table:sgc2, range:[-inf,+inf], keep order:false - └─IndexLookUp_12 0.00 root - ├─Selection_11 0.00 cop not(isnull(test.sgc1.a)) + └─IndexLookUp_12 5.00 root + ├─Selection_11 5.00 cop not(isnull(test.sgc1.a)) │ └─IndexScan_9 5.00 cop table:sgc1, index:a, range: decided by [eq(test.sgc1.a, test.sgc2.a)], keep order:false - └─TableScan_10 0.00 cop table:sgc1, keep order:false, stats:pseudo + └─TableScan_10 5.00 cop table:sgc1, keep order:false, stats:pseudo DROP TABLE IF EXISTS sgc3; CREATE TABLE sgc3 ( j JSON, diff --git a/cmd/explaintest/r/index_join.result b/cmd/explaintest/r/index_join.result index 6d5555bc8993e..b8cac2cbdafba 100644 --- a/cmd/explaintest/r/index_join.result +++ b/cmd/explaintest/r/index_join.result @@ -7,10 +7,10 @@ analyze table t1, t2; explain select /*+ TIDB_INLJ(t1, t2) */ * from t1 join t2 on t1.a=t2.a; id count task operator info IndexJoin_16 5.00 root inner join, inner:IndexLookUp_15, outer key:test.t2.a, inner key:test.t1.a -├─IndexLookUp_15 0.00 root -│ ├─Selection_14 0.00 cop not(isnull(test.t1.a)) +├─IndexLookUp_15 5.00 root +│ ├─Selection_14 5.00 cop not(isnull(test.t1.a)) │ │ └─IndexScan_12 5.00 cop table:t1, index:a, range: decided by [eq(test.t1.a, test.t2.a)], keep order:false -│ └─TableScan_13 0.00 cop table:t1, keep order:false, stats:pseudo +│ └─TableScan_13 5.00 cop table:t1, keep order:false, stats:pseudo └─TableReader_19 1.00 root data:Selection_18 └─Selection_18 1.00 cop not(isnull(test.t2.a)) └─TableScan_17 1.00 cop table:t2, range:[-inf,+inf], keep order:false @@ -21,7 +21,7 @@ Projection_6 5.00 root test.t1.a, test.t1.b, test.t2.a, test.t2.b ├─TableReader_30 1.00 root data:Selection_29 │ └─Selection_29 1.00 cop not(isnull(test.t2.a)) │ └─TableScan_28 1.00 cop table:t2, range:[-inf,+inf], keep order:false - └─IndexLookUp_11 0.00 root - ├─Selection_10 0.00 cop not(isnull(test.t1.a)) + └─IndexLookUp_11 5.00 root + ├─Selection_10 5.00 cop not(isnull(test.t1.a)) │ └─IndexScan_8 5.00 cop table:t1, index:a, range: decided by [eq(test.t1.a, test.t2.a)], keep order:false - └─TableScan_9 0.00 cop table:t1, keep order:false, stats:pseudo + └─TableScan_9 5.00 cop table:t1, keep order:false, stats:pseudo diff --git a/cmd/explaintest/r/partition_pruning.result b/cmd/explaintest/r/partition_pruning.result index d4430618a0342..247f7037ded6c 100644 --- a/cmd/explaintest/r/partition_pruning.result +++ b/cmd/explaintest/r/partition_pruning.result @@ -1040,490 +1040,490 @@ INSERT INTO t1 VALUES (1, '2009-01-01'), (1, '2009-04-01'), (2, '2009-04-01'), EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-03' AS DATETIME); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-03' AS DATETIME); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-03' AS DATETIME); id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-03' AS DATETIME); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-03' AS DATETIME); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop le(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop le(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_6 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-03' AS DATE); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-03' AS DATE); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-03' AS DATE); id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-03) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-03' AS DATE); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-03' AS DATE); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < '2009-04-03 00:00:00'; id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= '2009-04-03 00:00:00'; id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = '2009-04-03 00:00:00'; id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= '2009-04-03 00:00:00'; id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > '2009-04-03 00:00:00'; id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < '2009-04-02 23:59:59'; id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= '2009-04-02 23:59:59'; id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop le(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop le(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = '2009-04-02 23:59:59'; id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_6 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= '2009-04-02 23:59:59'; id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > '2009-04-02 23:59:59'; id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < '2009-04-03'; id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= '2009-04-03'; id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = '2009-04-03'; id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= '2009-04-03'; id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > '2009-04-03'; id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop le(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop le(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_6 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo DROP TABLE t1; # Test with DATE column NOT NULL CREATE TABLE t1 ( @@ -1543,490 +1543,490 @@ INSERT INTO t1 VALUES (1, '2009-01-01'), (1, '2009-04-01'), (2, '2009-04-01'), EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-03' AS DATETIME); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-03' AS DATETIME); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-03' AS DATETIME); id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-03' AS DATETIME); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-03' AS DATETIME); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop le(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop le(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info -TableReader_8 0.00 root data:Selection_7 +IndexReader_8 0.00 root index:Selection_7 └─Selection_7 0.00 cop eq(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_6 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-02 23:59:59' AS DATETIME); id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-03' AS DATE); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-03' AS DATE); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-03' AS DATE); id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-03) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-03' AS DATE); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-03' AS DATE); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < '2009-04-03 00:00:00'; id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= '2009-04-03 00:00:00'; id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = '2009-04-03 00:00:00'; id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= '2009-04-03 00:00:00'; id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > '2009-04-03 00:00:00'; id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < '2009-04-02 23:59:59'; id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= '2009-04-02 23:59:59'; id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop le(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop le(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = '2009-04-02 23:59:59'; id count task operator info -TableReader_8 0.00 root data:Selection_7 +IndexReader_8 0.00 root index:Selection_7 └─Selection_7 0.00 cop eq(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_6 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= '2009-04-02 23:59:59'; id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > '2009-04-02 23:59:59'; id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:59.000000) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < '2009-04-03'; id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= '2009-04-03'; id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = '2009-04-03'; id count task operator info -TableReader_8 10.00 root data:Selection_7 +IndexReader_8 10.00 root index:Selection_7 └─Selection_7 10.00 cop eq(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= '2009-04-03'; id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > '2009-04-03'; id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:00.000000) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop lt(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info Union_9 9970.00 root -├─TableReader_12 3323.33 root data:Selection_11 +├─IndexReader_12 3323.33 root index:Selection_11 │ └─Selection_11 3323.33 cop le(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_10 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_15 3323.33 root data:Selection_14 +│ └─IndexScan_10 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_15 3323.33 root index:Selection_14 │ └─Selection_14 3323.33 cop le(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_13 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_18 3323.33 root data:Selection_17 +│ └─IndexScan_13 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_18 3323.33 root index:Selection_17 └─Selection_17 3323.33 cop le(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_16 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_16 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info -TableReader_8 0.00 root data:Selection_7 +IndexReader_8 0.00 root index:Selection_7 └─Selection_7 0.00 cop eq(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_6 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop ge(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-03 00:00:01' AS DATETIME); id count task operator info Union_10 13333.33 root -├─TableReader_13 3333.33 root data:Selection_12 +├─IndexReader_13 3333.33 root index:Selection_12 │ └─Selection_12 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_11 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_16 3333.33 root data:Selection_15 +│ └─IndexScan_11 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_16 3333.33 root index:Selection_15 │ └─Selection_15 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_14 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_19 3333.33 root data:Selection_18 +│ └─IndexScan_14 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_19 3333.33 root index:Selection_18 │ └─Selection_18 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:01) -│ └─TableScan_17 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_22 3333.33 root data:Selection_21 +│ └─IndexScan_17 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_22 3333.33 root index:Selection_21 └─Selection_21 3333.33 cop gt(test.t1.b, 2009-04-03 00:00:01) - └─TableScan_20 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_20 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b < CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop lt(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b <= CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info Union_8 6646.67 root -├─TableReader_11 3323.33 root data:Selection_10 +├─IndexReader_11 3323.33 root index:Selection_10 │ └─Selection_10 3323.33 cop le(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_9 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_14 3323.33 root data:Selection_13 +│ └─IndexScan_9 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_14 3323.33 root index:Selection_13 └─Selection_13 3323.33 cop le(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_12 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_12 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b = CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info -TableReader_8 0.00 root data:Selection_7 +IndexReader_8 0.00 root index:Selection_7 └─Selection_7 0.00 cop eq(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_6 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_6 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b >= CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop ge(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo EXPLAIN SELECT * FROM t1 WHERE b > CAST('2009-04-02 23:59:58' AS DATETIME); id count task operator info Union_11 16666.67 root -├─TableReader_14 3333.33 root data:Selection_13 +├─IndexReader_14 3333.33 root index:Selection_13 │ └─Selection_13 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_12 10000.00 cop table:t1, partition:p20090401, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_17 3333.33 root data:Selection_16 +│ └─IndexScan_12 10000.00 cop table:t1, partition:p20090401, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_17 3333.33 root index:Selection_16 │ └─Selection_16 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_15 10000.00 cop table:t1, partition:p20090402, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_20 3333.33 root data:Selection_19 +│ └─IndexScan_15 10000.00 cop table:t1, partition:p20090402, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_20 3333.33 root index:Selection_19 │ └─Selection_19 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_18 10000.00 cop table:t1, partition:p20090403, range:[-inf,+inf], keep order:false, stats:pseudo -├─TableReader_23 3333.33 root data:Selection_22 +│ └─IndexScan_18 10000.00 cop table:t1, partition:p20090403, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +├─IndexReader_23 3333.33 root index:Selection_22 │ └─Selection_22 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) -│ └─TableScan_21 10000.00 cop table:t1, partition:p20090404, range:[-inf,+inf], keep order:false, stats:pseudo -└─TableReader_26 3333.33 root data:Selection_25 +│ └─IndexScan_21 10000.00 cop table:t1, partition:p20090404, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo +└─IndexReader_26 3333.33 root index:Selection_25 └─Selection_25 3333.33 cop gt(test.t1.b, 2009-04-02 23:59:58) - └─TableScan_24 10000.00 cop table:t1, partition:p20090405, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_24 10000.00 cop table:t1, partition:p20090405, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo DROP TABLE t1; # Test with DATETIME column NULL CREATE TABLE t1 ( diff --git a/cmd/explaintest/r/select.result b/cmd/explaintest/r/select.result index 5f56ff0e3c775..f61b0656ced58 100644 --- a/cmd/explaintest/r/select.result +++ b/cmd/explaintest/r/select.result @@ -255,9 +255,9 @@ create table t (a int, b int, c int, key idx(a, b, c)); explain select count(a) from t; id count task operator info StreamAgg_16 1.00 root funcs:count(col_0) -└─TableReader_17 1.00 root data:StreamAgg_8 +└─IndexReader_17 1.00 root index:StreamAgg_8 └─StreamAgg_8 1.00 cop funcs:count(test.t.a) - └─TableScan_15 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexScan_15 10000.00 cop table:t, index:a, b, c, range:[NULL,+inf], keep order:false, stats:pseudo select count(a) from t; count(a) 0 @@ -325,7 +325,7 @@ drop table if exists t; create table t (id int primary key, a int, b int); explain select * from (t t1 left join t t2 on t1.a = t2.a) left join (t t3 left join t t4 on t3.a = t4.a) on t2.b = 1; id count task operator info -HashLeftJoin_10 155937656.25 root left outer join, inner:HashLeftJoin_17, left cond:[eq(test.t2.b, 1)] +HashLeftJoin_10 155937656.25 root CARTESIAN left outer join, inner:HashLeftJoin_17, left cond:[eq(test.t2.b, 1)] ├─HashLeftJoin_11 12487.50 root left outer join, inner:TableReader_16, equal:[eq(test.t1.a, test.t2.a)] │ ├─TableReader_13 10000.00 root data:TableScan_12 │ │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo @@ -381,7 +381,7 @@ create table t(a int, b int); explain select a != any (select a from t t2) from t t1; id count task operator info Projection_9 10000.00 root and(or(or(gt(col_count, 1), ne(test.t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 0)), and(ne(agg_col_cnt, 0), if(isnull(test.t1.a), NULL, 1))) -└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 +└─HashLeftJoin_10 10000.00 root CARTESIAN inner join, inner:StreamAgg_17 ├─TableReader_13 10000.00 root data:TableScan_12 │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1) @@ -391,7 +391,7 @@ Projection_9 10000.00 root and(or(or(gt(col_count, 1), ne(test.t1.a, col_firstro explain select a = all (select a from t t2) from t t1; id count task operator info Projection_9 10000.00 root or(and(and(le(col_count, 1), eq(test.t1.a, col_firstrow)), if(ne(agg_col_sum, 0), NULL, 1)), or(eq(agg_col_cnt, 0), if(isnull(test.t1.a), NULL, 0))) -└─HashLeftJoin_10 10000.00 root inner join, inner:StreamAgg_17 +└─HashLeftJoin_10 10000.00 root CARTESIAN inner join, inner:StreamAgg_17 ├─TableReader_13 10000.00 root data:TableScan_12 │ └─TableScan_12 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo └─StreamAgg_17 1.00 root funcs:firstrow(col_0), count(distinct col_1), sum(col_2), count(1) @@ -430,3 +430,21 @@ Projection_7 10000.00 root 6_aux_0 │ └─TableScan_9 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo └─TableReader_12 10000.00 root data:TableScan_11 └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +explain select 1 from (select sleep(1)) t; +id count task operator info +Projection_4 1.00 root 1 +└─Projection_5 1.00 root sleep(1) + └─TableDual_6 1.00 root rows:1 +drop table t; +CREATE TABLE t (id int(10) unsigned NOT NULL AUTO_INCREMENT, +i int(10) unsigned DEFAULT NULL, +x int(10) unsigned DEFAULT 0, +PRIMARY KEY (`id`) +); +explain select row_number() over( partition by i ) - x as rnk from t; +id count task operator info +Projection_8 10000.00 root minus(4_window_3, test.t.x) +└─Window_9 10000.00 root row_number() over(partition by test.t.i) + └─Sort_12 10000.00 root test.t.i:asc + └─TableReader_11 10000.00 root data:TableScan_10 + └─TableScan_10 10000.00 cop table:t, range:[0,+inf], keep order:false, stats:pseudo diff --git a/cmd/explaintest/r/subquery.result b/cmd/explaintest/r/subquery.result index f0bad21a8d2e0..c3974d52757e3 100644 --- a/cmd/explaintest/r/subquery.result +++ b/cmd/explaintest/r/subquery.result @@ -4,7 +4,7 @@ create table t1(a bigint, b bigint); create table t2(a bigint, b bigint); explain select * from t1 where t1.a in (select t1.b + t2.b from t2); id count task operator info -HashLeftJoin_8 8000.00 root semi join, inner:TableReader_12, other cond:eq(test.t1.a, plus(test.t1.b, test.t2.b)) +HashLeftJoin_8 8000.00 root CARTESIAN semi join, inner:TableReader_12, other cond:eq(test.t1.a, plus(test.t1.b, test.t2.b)) ├─TableReader_10 10000.00 root data:TableScan_9 │ └─TableScan_9 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo └─TableReader_12 10000.00 root data:TableScan_11 @@ -16,9 +16,9 @@ analyze table t; explain select t.c in (select count(*) from t s use index(idx), t t1 where s.b = 1 and s.c = 1 and s.d = t.a and s.a = t1.a) from t; id count task operator info Projection_11 5.00 root 9_aux_0 -└─Apply_13 5.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) - ├─TableReader_15 5.00 root data:TableScan_14 - │ └─TableScan_14 5.00 cop table:t, range:[-inf,+inf], keep order:false +└─Apply_13 5.00 root CARTESIAN left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0) + ├─IndexReader_15 5.00 root index:IndexScan_14 + │ └─IndexScan_14 5.00 cop table:t, index:b, c, d, range:[NULL,+inf], keep order:false └─StreamAgg_20 1.00 root funcs:count(1) └─IndexJoin_23 0.50 root inner join, inner:TableReader_22, outer key:test.s.a, inner key:test.t1.a ├─IndexReader_27 1.00 root index:IndexScan_26 diff --git a/cmd/explaintest/r/topn_push_down.result b/cmd/explaintest/r/topn_push_down.result index b6080a8720f82..9deb403ecceca 100644 --- a/cmd/explaintest/r/topn_push_down.result +++ b/cmd/explaintest/r/topn_push_down.result @@ -177,12 +177,12 @@ Projection_13 0.00 root test.te.expect_time │ │ │ └─IndexScan_70 10.00 cop table:tr, index:shop_identy, trade_status, business_type, trade_pay_status, trade_type, delivery_type, source, biz_date, range:[810094178,810094178], keep order:false, stats:pseudo │ │ └─Selection_73 0.00 cop eq(test.tr.brand_identy, 32314), eq(test.tr.domain_type, 2) │ │ └─TableScan_71 0.00 cop table:tr, keep order:false, stats:pseudo - │ └─IndexLookUp_35 250.00 root + │ └─IndexLookUp_35 0.25 root │ ├─IndexScan_32 10.00 cop table:te, index:trade_id, range: decided by [eq(test.te.trade_id, test.tr.id)], keep order:false, stats:pseudo - │ └─Selection_34 250.00 cop ge(test.te.expect_time, 2018-04-23 00:00:00.000000), le(test.te.expect_time, 2018-04-23 23:59:59.000000) + │ └─Selection_34 0.25 cop ge(test.te.expect_time, 2018-04-23 00:00:00.000000), le(test.te.expect_time, 2018-04-23 23:59:59.000000) │ └─TableScan_33 10.00 cop table:te, keep order:false, stats:pseudo - └─IndexReader_91 0.00 root index:Selection_90 - └─Selection_90 0.00 cop not(isnull(test.p.relate_id)) + └─IndexReader_91 9.99 root index:Selection_90 + └─Selection_90 9.99 cop not(isnull(test.p.relate_id)) └─IndexScan_89 10.00 cop table:p, index:relate_id, range: decided by [eq(test.p.relate_id, test.tr.id)], keep order:false, stats:pseudo desc select 1 as a from dual order by a limit 1; id count task operator info @@ -223,8 +223,8 @@ explain select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 on t1.a = t2.a limit 5 id count task operator info Limit_11 5.00 root offset:0, count:5 └─IndexJoin_15 5.00 root inner join, inner:IndexReader_14, outer key:test.t1.a, inner key:test.t2.a - ├─TableReader_17 4.00 root data:TableScan_16 - │ └─TableScan_16 4.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + ├─IndexReader_17 4.00 root index:IndexScan_16 + │ └─IndexScan_16 4.00 cop table:t1, index:a, range:[NULL,+inf], keep order:false, stats:pseudo └─IndexReader_14 10.00 root index:IndexScan_13 └─IndexScan_13 10.00 cop table:t2, index:a, range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo explain select /*+ TIDB_INLJ(t2) */ * from t t1 left join t t2 on t1.a = t2.a where t2.a is null limit 5; @@ -232,8 +232,8 @@ id count task operator info Limit_12 5.00 root offset:0, count:5 └─Selection_13 5.00 root isnull(test.t2.a) └─IndexJoin_17 5.00 root left outer join, inner:IndexReader_16, outer key:test.t1.a, inner key:test.t2.a - ├─TableReader_19 4.00 root data:TableScan_18 - │ └─TableScan_18 4.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo + ├─IndexReader_19 4.00 root index:IndexScan_18 + │ └─IndexScan_18 4.00 cop table:t1, index:a, range:[NULL,+inf], keep order:false, stats:pseudo └─IndexReader_16 10.00 root index:IndexScan_15 └─IndexScan_15 10.00 cop table:t2, index:a, range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo explain select /*+ TIDB_SMJ(t1, t2) */ * from t t1 join t t2 on t1.a = t2.a limit 5; @@ -256,17 +256,17 @@ Limit_12 5.00 root offset:0, count:5 explain select /*+ TIDB_HJ(t1, t2) */ * from t t1 join t t2 on t1.a = t2.a limit 5; id count task operator info Limit_11 5.00 root offset:0, count:5 -└─HashLeftJoin_19 5.00 root inner join, inner:TableReader_24, equal:[eq(test.t1.a, test.t2.a)] - ├─TableReader_22 4.00 root data:TableScan_21 - │ └─TableScan_21 4.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo - └─TableReader_24 10000.00 root data:TableScan_23 - └─TableScan_23 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo +└─HashLeftJoin_19 5.00 root inner join, inner:IndexReader_24, equal:[eq(test.t1.a, test.t2.a)] + ├─IndexReader_22 4.00 root index:IndexScan_21 + │ └─IndexScan_21 4.00 cop table:t1, index:a, range:[NULL,+inf], keep order:false, stats:pseudo + └─IndexReader_24 10000.00 root index:IndexScan_23 + └─IndexScan_23 10000.00 cop table:t2, index:a, range:[NULL,+inf], keep order:false, stats:pseudo explain select /*+ TIDB_HJ(t1, t2) */ * from t t1 left join t t2 on t1.a = t2.a where t2.a is null limit 5; id count task operator info Limit_12 5.00 root offset:0, count:5 └─Selection_13 5.00 root isnull(test.t2.a) - └─HashLeftJoin_18 5.00 root left outer join, inner:TableReader_22, equal:[eq(test.t1.a, test.t2.a)] - ├─TableReader_20 4.00 root data:TableScan_19 - │ └─TableScan_19 4.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo - └─TableReader_22 10000.00 root data:TableScan_21 - └─TableScan_21 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo + └─HashLeftJoin_18 5.00 root left outer join, inner:IndexReader_22, equal:[eq(test.t1.a, test.t2.a)] + ├─IndexReader_20 4.00 root index:IndexScan_19 + │ └─IndexScan_19 4.00 cop table:t1, index:a, range:[NULL,+inf], keep order:false, stats:pseudo + └─IndexReader_22 10000.00 root index:IndexScan_21 + └─IndexScan_21 10000.00 cop table:t2, index:a, range:[NULL,+inf], keep order:false, stats:pseudo diff --git a/cmd/explaintest/r/tpch.result b/cmd/explaintest/r/tpch.result index 21ce363ba67ac..434610bf692cf 100644 --- a/cmd/explaintest/r/tpch.result +++ b/cmd/explaintest/r/tpch.result @@ -124,7 +124,7 @@ Sort_6 2.94 root tpch.lineitem.l_returnflag:asc, tpch.lineitem.l_linestatus:asc └─HashAgg_14 2.94 root group by:col_13, col_14, funcs:sum(col_0), sum(col_1), sum(col_2), sum(col_3), avg(col_4, col_5), avg(col_6, col_7), avg(col_8, col_9), count(col_10), firstrow(col_11), firstrow(col_12) └─TableReader_15 2.94 root data:HashAgg_9 └─HashAgg_9 2.94 cop group by:tpch.lineitem.l_linestatus, tpch.lineitem.l_returnflag, funcs:sum(tpch.lineitem.l_quantity), sum(tpch.lineitem.l_extendedprice), sum(mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount))), sum(mul(mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), plus(1, tpch.lineitem.l_tax))), avg(tpch.lineitem.l_quantity), avg(tpch.lineitem.l_extendedprice), avg(tpch.lineitem.l_discount), count(1), firstrow(tpch.lineitem.l_returnflag), firstrow(tpch.lineitem.l_linestatus) - └─Selection_13 293683189.00 cop le(tpch.lineitem.l_shipdate, 1998-08-15) + └─Selection_13 293795345.00 cop le(tpch.lineitem.l_shipdate, 1998-08-15) └─TableScan_12 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false /* Q2 Minimum Cost Supplier Query @@ -182,38 +182,39 @@ s_name, p_partkey limit 100; id count task operator info -Projection_36 100.00 root tpch.supplier.s_acctbal, tpch.supplier.s_name, tpch.nation.n_name, tpch.part.p_partkey, tpch.part.p_mfgr, tpch.supplier.s_address, tpch.supplier.s_phone, tpch.supplier.s_comment -└─TopN_39 100.00 root tpch.supplier.s_acctbal:desc, tpch.nation.n_name:asc, tpch.supplier.s_name:asc, tpch.part.p_partkey:asc, offset:0, count:100 - └─HashRightJoin_44 155496.00 root inner join, inner:HashLeftJoin_50, equal:[eq(tpch.part.p_partkey, tpch.partsupp.ps_partkey) eq(tpch.partsupp.ps_supplycost, min(ps_supplycost))] - ├─HashLeftJoin_50 155496.00 root inner join, inner:TableReader_73, equal:[eq(tpch.partsupp.ps_partkey, tpch.part.p_partkey)] - │ ├─HashRightJoin_53 8155010.44 root inner join, inner:HashRightJoin_55, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)] - │ │ ├─HashRightJoin_55 100000.00 root inner join, inner:HashRightJoin_61, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] - │ │ │ ├─HashRightJoin_61 5.00 root inner join, inner:TableReader_66, equal:[eq(tpch.region.r_regionkey, tpch.nation.n_regionkey)] - │ │ │ │ ├─TableReader_66 1.00 root data:Selection_65 - │ │ │ │ │ └─Selection_65 1.00 cop eq(tpch.region.r_name, "ASIA") - │ │ │ │ │ └─TableScan_64 5.00 cop table:region, range:[-inf,+inf], keep order:false - │ │ │ │ └─TableReader_63 25.00 root data:TableScan_62 - │ │ │ │ └─TableScan_62 25.00 cop table:nation, range:[-inf,+inf], keep order:false - │ │ │ └─TableReader_68 500000.00 root data:TableScan_67 - │ │ │ └─TableScan_67 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false - │ │ └─TableReader_70 40000000.00 root data:TableScan_69 - │ │ └─TableScan_69 40000000.00 cop table:partsupp, range:[-inf,+inf], keep order:false - │ └─TableReader_73 155496.00 root data:Selection_72 - │ └─Selection_72 155496.00 cop eq(tpch.part.p_size, 30), like(tpch.part.p_type, "%STEEL", 92) - │ └─TableScan_71 10000000.00 cop table:part, range:[-inf,+inf], keep order:false - └─HashAgg_76 8155010.44 root group by:tpch.partsupp.ps_partkey, funcs:min(tpch.partsupp.ps_supplycost), firstrow(tpch.partsupp.ps_partkey) - └─HashRightJoin_80 8155010.44 root inner join, inner:HashRightJoin_82, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)] - ├─HashRightJoin_82 100000.00 root inner join, inner:HashRightJoin_88, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] - │ ├─HashRightJoin_88 5.00 root inner join, inner:TableReader_93, equal:[eq(tpch.region.r_regionkey, tpch.nation.n_regionkey)] - │ │ ├─TableReader_93 1.00 root data:Selection_92 - │ │ │ └─Selection_92 1.00 cop eq(tpch.region.r_name, "ASIA") - │ │ │ └─TableScan_91 5.00 cop table:region, range:[-inf,+inf], keep order:false - │ │ └─TableReader_90 25.00 root data:TableScan_89 - │ │ └─TableScan_89 25.00 cop table:nation, range:[-inf,+inf], keep order:false - │ └─TableReader_95 500000.00 root data:TableScan_94 - │ └─TableScan_94 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false - └─TableReader_97 40000000.00 root data:TableScan_96 - └─TableScan_96 40000000.00 cop table:partsupp, range:[-inf,+inf], keep order:false +Projection_37 100.00 root tpch.supplier.s_acctbal, tpch.supplier.s_name, tpch.nation.n_name, tpch.part.p_partkey, tpch.part.p_mfgr, tpch.supplier.s_address, tpch.supplier.s_phone, tpch.supplier.s_comment +└─TopN_40 100.00 root tpch.supplier.s_acctbal:desc, tpch.nation.n_name:asc, tpch.supplier.s_name:asc, tpch.part.p_partkey:asc, offset:0, count:100 + └─HashRightJoin_45 155496.00 root inner join, inner:HashLeftJoin_51, equal:[eq(tpch.part.p_partkey, tpch.partsupp.ps_partkey) eq(tpch.partsupp.ps_supplycost, min(ps_supplycost))] + ├─HashLeftJoin_51 155496.00 root inner join, inner:TableReader_74, equal:[eq(tpch.partsupp.ps_partkey, tpch.part.p_partkey)] + │ ├─HashRightJoin_54 8155010.44 root inner join, inner:HashRightJoin_56, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)] + │ │ ├─HashRightJoin_56 100000.00 root inner join, inner:HashRightJoin_62, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] + │ │ │ ├─HashRightJoin_62 5.00 root inner join, inner:TableReader_67, equal:[eq(tpch.region.r_regionkey, tpch.nation.n_regionkey)] + │ │ │ │ ├─TableReader_67 1.00 root data:Selection_66 + │ │ │ │ │ └─Selection_66 1.00 cop eq(tpch.region.r_name, "ASIA") + │ │ │ │ │ └─TableScan_65 5.00 cop table:region, range:[-inf,+inf], keep order:false + │ │ │ │ └─TableReader_64 25.00 root data:TableScan_63 + │ │ │ │ └─TableScan_63 25.00 cop table:nation, range:[-inf,+inf], keep order:false + │ │ │ └─TableReader_69 500000.00 root data:TableScan_68 + │ │ │ └─TableScan_68 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false + │ │ └─TableReader_71 40000000.00 root data:TableScan_70 + │ │ └─TableScan_70 40000000.00 cop table:partsupp, range:[-inf,+inf], keep order:false + │ └─TableReader_74 155496.00 root data:Selection_73 + │ └─Selection_73 155496.00 cop eq(tpch.part.p_size, 30), like(tpch.part.p_type, "%STEEL", 92) + │ └─TableScan_72 10000000.00 cop table:part, range:[-inf,+inf], keep order:false + └─Selection_75 6524008.35 root not(isnull(19_col_0)) + └─HashAgg_78 8155010.44 root group by:tpch.partsupp.ps_partkey, funcs:min(tpch.partsupp.ps_supplycost), firstrow(tpch.partsupp.ps_partkey) + └─HashRightJoin_82 8155010.44 root inner join, inner:HashRightJoin_84, equal:[eq(tpch.supplier.s_suppkey, tpch.partsupp.ps_suppkey)] + ├─HashRightJoin_84 100000.00 root inner join, inner:HashRightJoin_90, equal:[eq(tpch.nation.n_nationkey, tpch.supplier.s_nationkey)] + │ ├─HashRightJoin_90 5.00 root inner join, inner:TableReader_95, equal:[eq(tpch.region.r_regionkey, tpch.nation.n_regionkey)] + │ │ ├─TableReader_95 1.00 root data:Selection_94 + │ │ │ └─Selection_94 1.00 cop eq(tpch.region.r_name, "ASIA") + │ │ │ └─TableScan_93 5.00 cop table:region, range:[-inf,+inf], keep order:false + │ │ └─TableReader_92 25.00 root data:TableScan_91 + │ │ └─TableScan_91 25.00 cop table:nation, range:[-inf,+inf], keep order:false + │ └─TableReader_97 500000.00 root data:TableScan_96 + │ └─TableScan_96 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false + └─TableReader_99 40000000.00 root data:TableScan_98 + └─TableScan_98 40000000.00 cop table:partsupp, range:[-inf,+inf], keep order:false /* Q3 Shipping Priority Query This query retrieves the 10 unshipped orders with the highest value. @@ -250,7 +251,7 @@ limit 10; id count task operator info Projection_14 10.00 root tpch.lineitem.l_orderkey, 7_col_0, tpch.orders.o_orderdate, tpch.orders.o_shippriority └─TopN_17 10.00 root 7_col_0:desc, tpch.orders.o_orderdate:asc, offset:0, count:10 - └─HashAgg_23 40227041.09 root group by:col_4, col_5, col_6, funcs:sum(col_0), firstrow(col_1), firstrow(col_2), firstrow(col_3) + └─HashAgg_23 40252367.98 root group by:col_4, col_5, col_6, funcs:sum(col_0), firstrow(col_1), firstrow(col_2), firstrow(col_3) └─Projection_59 91515927.49 root mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), tpch.orders.o_orderdate, tpch.orders.o_shippriority, tpch.lineitem.l_orderkey, tpch.lineitem.l_orderkey, tpch.orders.o_orderdate, tpch.orders.o_shippriority └─IndexJoin_29 91515927.49 root inner join, inner:IndexLookUp_28, outer key:tpch.orders.o_orderkey, inner key:tpch.lineitem.l_orderkey ├─HashRightJoin_49 22592975.51 root inner join, inner:TableReader_55, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] @@ -260,9 +261,9 @@ Projection_14 10.00 root tpch.lineitem.l_orderkey, 7_col_0, tpch.orders.o_orderd │ └─TableReader_52 36870000.00 root data:Selection_51 │ └─Selection_51 36870000.00 cop lt(tpch.orders.o_orderdate, 1995-03-13 00:00:00.000000) │ └─TableScan_50 75000000.00 cop table:orders, range:[-inf,+inf], keep order:false - └─IndexLookUp_28 162945114.27 root + └─IndexLookUp_28 0.54 root ├─IndexScan_25 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [eq(tpch.lineitem.l_orderkey, tpch.orders.o_orderkey)], keep order:false - └─Selection_27 162945114.27 cop gt(tpch.lineitem.l_shipdate, 1995-03-13 00:00:00.000000) + └─Selection_27 0.54 cop gt(tpch.lineitem.l_shipdate, 1995-03-13 00:00:00.000000) └─TableScan_26 1.00 cop table:lineitem, keep order:false /* Q4 Order Priority Checking Query @@ -301,9 +302,9 @@ Sort_10 1.00 root tpch.orders.o_orderpriority:asc ├─TableReader_33 2925937.50 root data:Selection_32 │ └─Selection_32 2925937.50 cop ge(tpch.orders.o_orderdate, 1995-01-01 00:00:00.000000), lt(tpch.orders.o_orderdate, 1995-04-01) │ └─TableScan_31 75000000.00 cop table:orders, range:[-inf,+inf], keep order:false - └─IndexLookUp_20 240004648.80 root + └─IndexLookUp_20 0.80 root ├─IndexScan_17 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [eq(tpch.lineitem.l_orderkey, tpch.orders.o_orderkey)], keep order:false - └─Selection_19 240004648.80 cop lt(tpch.lineitem.l_commitdate, tpch.lineitem.l_receiptdate) + └─Selection_19 0.80 cop lt(tpch.lineitem.l_commitdate, tpch.lineitem.l_receiptdate) └─TableScan_18 1.00 cop table:lineitem, keep order:false /* Q5 Local Supplier Volume Query @@ -443,9 +444,9 @@ supp_nation, cust_nation, l_year; id count task operator info -Sort_22 768.91 root tpch.shipping.supp_nation:asc, tpch.shipping.cust_nation:asc, shipping.l_year:asc -└─Projection_24 768.91 root tpch.shipping.supp_nation, tpch.shipping.cust_nation, shipping.l_year, 14_col_0 - └─HashAgg_27 768.91 root group by:shipping.l_year, tpch.shipping.cust_nation, tpch.shipping.supp_nation, funcs:sum(shipping.volume), firstrow(tpch.shipping.supp_nation), firstrow(tpch.shipping.cust_nation), firstrow(shipping.l_year) +Sort_22 769.96 root tpch.shipping.supp_nation:asc, tpch.shipping.cust_nation:asc, shipping.l_year:asc +└─Projection_24 769.96 root tpch.shipping.supp_nation, tpch.shipping.cust_nation, shipping.l_year, 14_col_0 + └─HashAgg_27 769.96 root group by:shipping.l_year, tpch.shipping.cust_nation, tpch.shipping.supp_nation, funcs:sum(shipping.volume), firstrow(tpch.shipping.supp_nation), firstrow(tpch.shipping.cust_nation), firstrow(shipping.l_year) └─Projection_28 1957240.42 root tpch.n1.n_name, tpch.n2.n_name, extract("YEAR", tpch.lineitem.l_shipdate), mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)) └─HashLeftJoin_33 1957240.42 root inner join, inner:TableReader_68, equal:[eq(tpch.customer.c_nationkey, tpch.n2.n_nationkey)], other cond:or(and(eq(tpch.n1.n_name, "JAPAN"), eq(tpch.n2.n_name, "INDIA")), and(eq(tpch.n1.n_name, "INDIA"), eq(tpch.n2.n_name, "JAPAN"))) ├─IndexJoin_37 24465505.20 root inner join, inner:TableReader_36, outer key:tpch.orders.o_custkey, inner key:tpch.customer.c_custkey @@ -457,8 +458,8 @@ Sort_22 768.91 root tpch.shipping.supp_nation:asc, tpch.shipping.cust_nation:asc │ │ │ │ │ └─TableScan_56 25.00 cop table:n1, range:[-inf,+inf], keep order:false │ │ │ │ └─TableReader_55 500000.00 root data:TableScan_54 │ │ │ │ └─TableScan_54 500000.00 cop table:supplier, range:[-inf,+inf], keep order:false - │ │ │ └─TableReader_61 91321768.29 root data:Selection_60 - │ │ │ └─Selection_60 91321768.29 cop ge(tpch.lineitem.l_shipdate, 1995-01-01 00:00:00.000000), le(tpch.lineitem.l_shipdate, 1996-12-31 00:00:00.000000) + │ │ │ └─TableReader_61 91446230.29 root data:Selection_60 + │ │ │ └─Selection_60 91446230.29 cop ge(tpch.lineitem.l_shipdate, 1995-01-01 00:00:00.000000), le(tpch.lineitem.l_shipdate, 1996-12-31 00:00:00.000000) │ │ │ └─TableScan_59 300005811.00 cop table:lineitem, range:[-inf,+inf], keep order:false │ │ └─TableReader_42 1.00 root data:TableScan_41 │ │ └─TableScan_41 1.00 cop table:orders, range: decided by [tpch.lineitem.l_orderkey], keep order:false @@ -515,16 +516,16 @@ o_year order by o_year; id count task operator info -Sort_29 718.01 root all_nations.o_year:asc -└─Projection_31 718.01 root all_nations.o_year, div(18_col_0, 18_col_1) - └─HashAgg_34 718.01 root group by:col_3, funcs:sum(col_0), sum(col_1), firstrow(col_2) - └─Projection_89 562348.12 root case(eq(tpch.all_nations.nation, "INDIA"), all_nations.volume, 0), all_nations.volume, all_nations.o_year, all_nations.o_year - └─Projection_35 562348.12 root extract("YEAR", tpch.orders.o_orderdate), mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), tpch.n2.n_name - └─HashLeftJoin_39 562348.12 root inner join, inner:TableReader_87, equal:[eq(tpch.supplier.s_nationkey, tpch.n2.n_nationkey)] - ├─IndexJoin_43 562348.12 root inner join, inner:TableReader_42, outer key:tpch.lineitem.l_suppkey, inner key:tpch.supplier.s_suppkey - │ ├─HashLeftJoin_50 562348.12 root inner join, inner:TableReader_83, equal:[eq(tpch.lineitem.l_partkey, tpch.part.p_partkey)] - │ │ ├─IndexJoin_56 90661378.61 root inner join, inner:IndexLookUp_55, outer key:tpch.orders.o_orderkey, inner key:tpch.lineitem.l_orderkey - │ │ │ ├─HashRightJoin_60 22382008.93 root inner join, inner:HashRightJoin_62, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] +Sort_29 719.02 root all_nations.o_year:asc +└─Projection_31 719.02 root all_nations.o_year, div(18_col_0, 18_col_1) + └─HashAgg_34 719.02 root group by:col_3, funcs:sum(col_0), sum(col_1), firstrow(col_2) + └─Projection_89 563136.02 root case(eq(tpch.all_nations.nation, "INDIA"), all_nations.volume, 0), all_nations.volume, all_nations.o_year, all_nations.o_year + └─Projection_35 563136.02 root extract("YEAR", tpch.orders.o_orderdate), mul(tpch.lineitem.l_extendedprice, minus(1, tpch.lineitem.l_discount)), tpch.n2.n_name + └─HashLeftJoin_39 563136.02 root inner join, inner:TableReader_87, equal:[eq(tpch.supplier.s_nationkey, tpch.n2.n_nationkey)] + ├─IndexJoin_43 563136.02 root inner join, inner:TableReader_42, outer key:tpch.lineitem.l_suppkey, inner key:tpch.supplier.s_suppkey + │ ├─HashLeftJoin_50 563136.02 root inner join, inner:TableReader_83, equal:[eq(tpch.lineitem.l_partkey, tpch.part.p_partkey)] + │ │ ├─IndexJoin_56 90788402.51 root inner join, inner:IndexLookUp_55, outer key:tpch.orders.o_orderkey, inner key:tpch.lineitem.l_orderkey + │ │ │ ├─HashRightJoin_60 22413367.93 root inner join, inner:HashRightJoin_62, equal:[eq(tpch.customer.c_custkey, tpch.orders.o_custkey)] │ │ │ │ ├─HashRightJoin_62 1500000.00 root inner join, inner:HashRightJoin_68, equal:[eq(tpch.n1.n_nationkey, tpch.customer.c_nationkey)] │ │ │ │ │ ├─HashRightJoin_68 5.00 root inner join, inner:TableReader_73, equal:[eq(tpch.region.r_regionkey, tpch.n1.n_regionkey)] │ │ │ │ │ │ ├─TableReader_73 1.00 root data:Selection_72 @@ -534,8 +535,8 @@ Sort_29 718.01 root all_nations.o_year:asc │ │ │ │ │ │ └─TableScan_69 25.00 cop table:n1, range:[-inf,+inf], keep order:false │ │ │ │ │ └─TableReader_75 7500000.00 root data:TableScan_74 │ │ │ │ │ └─TableScan_74 7500000.00 cop table:customer, range:[-inf,+inf], keep order:false - │ │ │ │ └─TableReader_78 22382008.93 root data:Selection_77 - │ │ │ │ └─Selection_77 22382008.93 cop ge(tpch.orders.o_orderdate, 1995-01-01 00:00:00.000000), le(tpch.orders.o_orderdate, 1996-12-31 00:00:00.000000) + │ │ │ │ └─TableReader_78 22413367.93 root data:Selection_77 + │ │ │ │ └─Selection_77 22413367.93 cop ge(tpch.orders.o_orderdate, 1995-01-01 00:00:00.000000), le(tpch.orders.o_orderdate, 1996-12-31 00:00:00.000000) │ │ │ │ └─TableScan_76 75000000.00 cop table:orders, range:[-inf,+inf], keep order:false │ │ │ └─IndexLookUp_55 1.00 root │ │ │ ├─IndexScan_53 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [eq(tpch.lineitem.l_orderkey, tpch.orders.o_orderkey)], keep order:false @@ -672,9 +673,9 @@ Projection_17 20.00 root tpch.customer.c_custkey, tpch.customer.c_name, 9_col_0, │ └─TableReader_48 3017307.69 root data:Selection_47 │ └─Selection_47 3017307.69 cop ge(tpch.orders.o_orderdate, 1993-08-01 00:00:00.000000), lt(tpch.orders.o_orderdate, 1993-11-01) │ └─TableScan_46 75000000.00 cop table:orders, range:[-inf,+inf], keep order:false - └─IndexLookUp_31 73916005.00 root + └─IndexLookUp_31 0.25 root ├─IndexScan_28 1.00 cop table:lineitem, index:L_ORDERKEY, L_LINENUMBER, range: decided by [eq(tpch.lineitem.l_orderkey, tpch.orders.o_orderkey)], keep order:false - └─Selection_30 73916005.00 cop eq(tpch.lineitem.l_returnflag, "R") + └─Selection_30 0.25 cop eq(tpch.lineitem.l_returnflag, "R") └─TableScan_29 1.00 cop table:lineitem, keep order:false /* Q11 Important Stock Identification Query @@ -1222,7 +1223,7 @@ id count task operator info Projection_25 1.00 root tpch.supplier.s_name, 17_col_0 └─TopN_28 1.00 root 17_col_0:desc, tpch.supplier.s_name:asc, offset:0, count:100 └─HashAgg_34 1.00 root group by:tpch.supplier.s_name, funcs:count(1), firstrow(tpch.supplier.s_name) - └─IndexJoin_40 7828961.66 root anti semi join, inner:IndexLookUp_39, outer key:tpch.l1.l_orderkey, inner key:tpch.l3.l_orderkey, other cond:ne(tpch.l3.l_suppkey, tpch.l1.l_suppkey), ne(tpch.l3.l_suppkey, tpch.supplier.s_suppkey) + └─IndexJoin_40 7828961.66 root anti semi join, inner:IndexLookUp_39, outer key:tpch.l1.l_orderkey, inner key:tpch.l3.l_orderkey, other cond:ne(tpch.l3.l_suppkey, tpch.l1.l_suppkey) ├─IndexJoin_56 9786202.08 root semi join, inner:IndexLookUp_55, outer key:tpch.l1.l_orderkey, inner key:tpch.l2.l_orderkey, other cond:ne(tpch.l2.l_suppkey, tpch.l1.l_suppkey), ne(tpch.l2.l_suppkey, tpch.supplier.s_suppkey) │ ├─IndexJoin_62 12232752.60 root inner join, inner:TableReader_61, outer key:tpch.l1.l_orderkey, inner key:tpch.orders.o_orderkey │ │ ├─HashRightJoin_66 12232752.60 root inner join, inner:HashRightJoin_72, equal:[eq(tpch.supplier.s_suppkey, tpch.l1.l_suppkey)] @@ -1241,9 +1242,9 @@ Projection_25 1.00 root tpch.supplier.s_name, 17_col_0 │ └─IndexLookUp_55 1.00 root │ ├─IndexScan_53 1.00 cop table:l2, index:L_ORDERKEY, L_LINENUMBER, range: decided by [eq(tpch.l2.l_orderkey, tpch.l1.l_orderkey)], keep order:false │ └─TableScan_54 1.00 cop table:lineitem, keep order:false - └─IndexLookUp_39 240004648.80 root + └─IndexLookUp_39 0.80 root ├─IndexScan_36 1.00 cop table:l3, index:L_ORDERKEY, L_LINENUMBER, range: decided by [eq(tpch.l3.l_orderkey, tpch.l1.l_orderkey)], keep order:false - └─Selection_38 240004648.80 cop gt(tpch.l3.l_receiptdate, tpch.l3.l_commitdate) + └─Selection_38 0.80 cop gt(tpch.l3.l_receiptdate, tpch.l3.l_commitdate) └─TableScan_37 1.00 cop table:lineitem, keep order:false /* Q22 Global Sales Opportunity Query diff --git a/cmd/explaintest/r/window_function.result b/cmd/explaintest/r/window_function.result index 3f4e6132897fe..6096ce1ab4023 100644 --- a/cmd/explaintest/r/window_function.result +++ b/cmd/explaintest/r/window_function.result @@ -6,8 +6,8 @@ explain select sum(a) over() from t; id count task operator info Projection_7 10000.00 root sum(a) over() └─Window_8 10000.00 root sum(cast(test.t.a)) over() - └─TableReader_10 10000.00 root data:TableScan_9 - └─TableScan_9 10000.00 cop table:t, range:[-inf,+inf], keep order:false, stats:pseudo + └─IndexReader_10 10000.00 root index:IndexScan_9 + └─IndexScan_9 10000.00 cop table:t, index:a, range:[NULL,+inf], keep order:false, stats:pseudo explain select sum(a) over(partition by a) from t; id count task operator info Projection_7 10000.00 root sum(a) over(partition by a) diff --git a/cmd/explaintest/t/black_list.test b/cmd/explaintest/t/black_list.test new file mode 100644 index 0000000000000..6e93cb4943046 --- /dev/null +++ b/cmd/explaintest/t/black_list.test @@ -0,0 +1,41 @@ +use test; +drop table if exists t; +create table t (a int); + +explain select * from t where a < 1; + +insert into mysql.opt_rule_blacklist values('predicate_push_down'); + +admin reload opt_rule_blacklist; + +explain select * from t where a < 1; + +delete from mysql.opt_rule_blacklist where name='predicate_push_down'; + +admin reload opt_rule_blacklist; + +explain select * from t where a < 1; + +insert into mysql.expr_pushdown_blacklist values('<'); + +admin reload expr_pushdown_blacklist; + +explain select * from t where a < 1; + +delete from mysql.expr_pushdown_blacklist where name='<'; + +admin reload expr_pushdown_blacklist; + +explain select * from t where a < 1; + +insert into mysql.expr_pushdown_blacklist values('lt'); + +admin reload expr_pushdown_blacklist; + +explain select * from t where a < 1; + +delete from mysql.expr_pushdown_blacklist where name='lt'; + +admin reload expr_pushdown_blacklist; + +explain select * from t where a < 1; diff --git a/cmd/explaintest/t/explain_complex.test b/cmd/explaintest/t/explain_complex.test index e39412ba4a114..aae669cc9274f 100644 --- a/cmd/explaintest/t/explain_complex.test +++ b/cmd/explaintest/t/explain_complex.test @@ -131,3 +131,46 @@ CREATE TABLE `tbl_008` (`a` int, `b` int); CREATE TABLE `tbl_009` (`a` int, `b` int); explain select sum(a) from (select * from tbl_001 union all select * from tbl_002 union all select * from tbl_003 union all select * from tbl_004 union all select * from tbl_005 union all select * from tbl_006 union all select * from tbl_007 union all select * from tbl_008 union all select * from tbl_009) x group by b; + +CREATE TABLE org_department ( + id int(11) NOT NULL AUTO_INCREMENT, + ctx int(11) DEFAULT '0' COMMENT 'organization id', + name varchar(128) DEFAULT NULL, + left_value int(11) DEFAULT NULL, + right_value int(11) DEFAULT NULL, + depth int(11) DEFAULT NULL, + leader_id bigint(20) DEFAULT NULL, + status int(11) DEFAULT '1000', + created_on datetime DEFAULT NULL, + updated_on datetime DEFAULT NULL, + PRIMARY KEY (id), + UNIQUE KEY org_department_id_uindex (id), + KEY org_department_leader_id_index (leader_id), + KEY org_department_ctx_index (ctx) +); +CREATE TABLE org_position ( + id int(11) NOT NULL AUTO_INCREMENT, + ctx int(11) DEFAULT NULL, + name varchar(128) DEFAULT NULL, + left_value int(11) DEFAULT NULL, + right_value int(11) DEFAULT NULL, + depth int(11) DEFAULT NULL, + department_id int(11) DEFAULT NULL, + status int(2) DEFAULT NULL, + created_on datetime DEFAULT NULL, + updated_on datetime DEFAULT NULL, + PRIMARY KEY (id), + UNIQUE KEY org_position_id_uindex (id), + KEY org_position_department_id_index (department_id) +) ENGINE=InnoDB AUTO_INCREMENT=22 DEFAULT CHARSET=utf8; + CREATE TABLE org_employee_position ( + hotel_id int(11) DEFAULT NULL, + user_id bigint(20) DEFAULT NULL, + position_id int(11) DEFAULT NULL, + status int(11) DEFAULT NULL, + created_on datetime DEFAULT NULL, + updated_on datetime DEFAULT NULL, + UNIQUE KEY org_employee_position_pk (hotel_id,user_id,position_id) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +explain SELECT d.id, d.ctx, d.name, d.left_value, d.right_value, d.depth, d.leader_id, d.status, d.created_on, d.updated_on FROM org_department AS d LEFT JOIN org_position AS p ON p.department_id = d.id AND p.status = 1000 LEFT JOIN org_employee_position AS ep ON ep.position_id = p.id AND ep.status = 1000 WHERE (d.ctx = 1 AND (ep.user_id = 62 OR d.id = 20 OR d.id = 20) AND d.status = 1000) GROUP BY d.id ORDER BY d.left_value; diff --git a/cmd/explaintest/t/explain_easy.test b/cmd/explaintest/t/explain_easy.test index a8cd985164413..3510b4fabb8f6 100644 --- a/cmd/explaintest/t/explain_easy.test +++ b/cmd/explaintest/t/explain_easy.test @@ -35,6 +35,12 @@ explain select if(10, t1.c1, t1.c2) from t1; explain select c1 from t2 union select c1 from t2 union all select c1 from t2; explain select c1 from t2 union all select c1 from t2 union select c1 from t2; +# https://github.com/pingcap/tidb/issues/9125 +explain select count(1) from (select count(1) from (select * from t1 where c3 = 100) k) k2; +explain select 1 from (select count(c2), count(c3) from t1) k; +explain select count(1) from (select max(c2), count(c3) as m from t1) k; +explain select count(1) from (select count(c2) from t1 group by c3) k; + set @@session.tidb_opt_insubq_to_join_and_agg=0; explain select sum(t1.c1 in (select c1 from t2)) from t1; @@ -79,6 +85,7 @@ drop table if exists t; create table t(a bigint primary key); explain select * from t where a = 1 and a = 2; explain select null or a > 1 from t; +explain select * from t where a = 1 for update; drop table if exists ta, tb; create table ta (a varchar(20)); @@ -145,3 +152,10 @@ insert into t values (1); explain select * from t left outer join t t1 on t.a = t1.a where t.a not between 1 and 2; rollback; drop table if exists t; + +create table t(a time, b date); +insert into t values (1, "1000-01-01"), (2, "1000-01-02"), (3, "1000-01-03"); +analyze table t; +explain select * from t where a = 1; +explain select * from t where b = "1000-01-01"; +drop table t; diff --git a/cmd/explaintest/t/select.test b/cmd/explaintest/t/select.test index 758bfb7466970..af708fb20404d 100644 --- a/cmd/explaintest/t/select.test +++ b/cmd/explaintest/t/select.test @@ -201,3 +201,15 @@ explain select a in (select a+b from t t2 where t2.b = t1.b) from t t1; drop table t; create table t(a int not null, b int); explain select a in (select a from t t2 where t2.b = t1.b) from t t1; + +# test sleep in subquery +explain select 1 from (select sleep(1)) t; + +# test fields with windows function +drop table t; +CREATE TABLE t (id int(10) unsigned NOT NULL AUTO_INCREMENT, + i int(10) unsigned DEFAULT NULL, + x int(10) unsigned DEFAULT 0, + PRIMARY KEY (`id`) +); +explain select row_number() over( partition by i ) - x as rnk from t; diff --git a/config/config.go b/config/config.go index 85384bce1f56c..d6babffa84bb1 100644 --- a/config/config.go +++ b/config/config.go @@ -35,7 +35,9 @@ import ( // Config number limitations const ( - MaxLogFileSize = 4096 // MB + MaxLogFileSize = 4096 // MB + MinPessimisticTTL = time.Second * 15 + MaxPessimisticTTL = time.Second * 120 ) // Valid config maps @@ -82,11 +84,13 @@ type Config struct { Binlog Binlog `toml:"binlog" json:"binlog"` CompatibleKillQuery bool `toml:"compatible-kill-query" json:"compatible-kill-query"` Plugin Plugin `toml:"plugin" json:"plugin"` - PessimisticTxn PessimisticTxn `toml:"pessimistic-txn" json:"pessimistic_txn"` + PessimisticTxn PessimisticTxn `toml:"pessimistic-txn" json:"pessimistic-txn"` CheckMb4ValueInUTF8 bool `toml:"check-mb4-value-in-utf8" json:"check-mb4-value-in-utf8"` // TreatOldVersionUTF8AsUTF8MB4 is use to treat old version table/column UTF8 charset as UTF8MB4. This is for compatibility. // Currently not support dynamic modify, because this need to reload all old version schema. - TreatOldVersionUTF8AsUTF8MB4 bool `toml:"treat-old-version-utf8-as-utf8mb4" json:"treat-old-version-utf8-as-utf8mb4"` + TreatOldVersionUTF8AsUTF8MB4 bool `toml:"treat-old-version-utf8-as-utf8mb4" json:"treat-old-version-utf8-as-utf8mb4"` + SplitRegionMaxNum uint64 `toml:"split-region-max-num" json:"split-region-max-num"` + StmtSummary StmtSummary `toml:"stmt-summary" json:"stmt-summary"` } // Log is the log section of config. @@ -100,10 +104,11 @@ type Log struct { // File log config. File logutil.FileLogConfig `toml:"file" json:"file"` - SlowQueryFile string `toml:"slow-query-file" json:"slow-query-file"` - SlowThreshold uint64 `toml:"slow-threshold" json:"slow-threshold"` - ExpensiveThreshold uint `toml:"expensive-threshold" json:"expensive-threshold"` - QueryLogMaxLen uint64 `toml:"query-log-max-len" json:"query-log-max-len"` + SlowQueryFile string `toml:"slow-query-file" json:"slow-query-file"` + SlowThreshold uint64 `toml:"slow-threshold" json:"slow-threshold"` + ExpensiveThreshold uint `toml:"expensive-threshold" json:"expensive-threshold"` + QueryLogMaxLen uint64 `toml:"query-log-max-len" json:"query-log-max-len"` + RecordPlanInSlowLog uint32 `toml:"record-plan-in-slow-log" json:"record-plan-in-slow-log"` } // Security is the security section of the config. @@ -187,6 +192,7 @@ type Performance struct { QueryFeedbackLimit uint `toml:"query-feedback-limit" json:"query-feedback-limit"` PseudoEstimateRatio float64 `toml:"pseudo-estimate-ratio" json:"pseudo-estimate-ratio"` ForcePriority string `toml:"force-priority" json:"force-priority"` + BindInfoLease string `toml:"bind-info-lease" json:"bind-info-lease"` } // PlanCache is the PlanCache section of the config. @@ -271,6 +277,9 @@ type TiKVClient struct { MaxBatchWaitTime time.Duration `toml:"max-batch-wait-time" json:"max-batch-wait-time"` // BatchWaitSize is the max wait size for batch. BatchWaitSize uint `toml:"batch-wait-size" json:"batch-wait-size"` + // If a Region has not been accessed for more than the given duration (in seconds), it + // will be reloaded from the PD. + RegionCacheTTL uint `toml:"region-cache-ttl" json:"region-cache-ttl"` } // Binlog is the config for binlog. @@ -296,14 +305,20 @@ type Plugin struct { type PessimisticTxn struct { // Enable must be true for 'begin lock' or session variable to start a pessimistic transaction. Enable bool `toml:"enable" json:"enable"` - // Starts a pessimistic transaction by default when Enable is true. - Default bool `toml:"default" json:"default"` // The max count of retry for a single statement in a pessimistic transaction. MaxRetryCount uint `toml:"max-retry-count" json:"max-retry-count"` // The pessimistic lock ttl. TTL string `toml:"ttl" json:"ttl"` } +// StmtSummary is the config for statement summary. +type StmtSummary struct { + // The maximum number of statements kept in memory. + MaxStmtCount uint `toml:"max-stmt-count" json:"max-stmt-count"` + // The maximum length of displayed normalized SQL and sample SQL. + MaxSQLLength uint `toml:"max-sql-length" json:"max-sql-length"` +} + var defaultConf = Config{ Host: "0.0.0.0", AdvertiseAddress: "", @@ -320,19 +335,21 @@ var defaultConf = Config{ EnableStreaming: false, CheckMb4ValueInUTF8: true, TreatOldVersionUTF8AsUTF8MB4: true, + SplitRegionMaxNum: 1000, TxnLocalLatches: TxnLocalLatches{ - Enabled: true, + Enabled: false, Capacity: 2048000, }, LowerCaseTableNames: 2, Log: Log{ - Level: "info", - Format: "text", - File: logutil.NewFileLogConfig(true, logutil.DefaultLogMaxSize), - SlowQueryFile: "tidb-slow.log", - SlowThreshold: logutil.DefaultSlowThreshold, - ExpensiveThreshold: 10000, - QueryLogMaxLen: logutil.DefaultQueryLogMaxLen, + Level: "info", + Format: "text", + File: logutil.NewFileLogConfig(true, logutil.DefaultLogMaxSize), + SlowQueryFile: "tidb-slow.log", + SlowThreshold: logutil.DefaultSlowThreshold, + ExpensiveThreshold: 10000, + QueryLogMaxLen: logutil.DefaultQueryLogMaxLen, + RecordPlanInSlowLog: logutil.DefaultRecordPlanInSlowLog, }, Status: Status{ ReportStatus: true, @@ -352,6 +369,7 @@ var defaultConf = Config{ QueryFeedbackLimit: 1024, PseudoEstimateRatio: 0.8, ForcePriority: "NO_PRIORITY", + BindInfoLease: "3s", }, ProxyProtocol: ProxyProtocol{ Networks: "", @@ -382,16 +400,21 @@ var defaultConf = Config{ OverloadThreshold: 200, MaxBatchWaitTime: 0, BatchWaitSize: 8, + + RegionCacheTTL: 600, }, Binlog: Binlog{ WriteTimeout: "15s", Strategy: "range", }, PessimisticTxn: PessimisticTxn{ - Enable: false, - Default: false, + Enable: true, MaxRetryCount: 256, - TTL: "30s", + TTL: "40s", + }, + StmtSummary: StmtSummary{ + MaxStmtCount: 100, + MaxSQLLength: 4096, }, } @@ -428,6 +451,11 @@ func GetGlobalConfig() *Config { return globalConf.Load().(*Config) } +// StoreGlobalConfig stores a new config to the globalConf. It mostly uses in the test to avoid some data races. +func StoreGlobalConfig(config *Config) { + globalConf.Store(config) +} + // ReloadGlobalConfig reloads global configuration for this server. func ReloadGlobalConfig() error { confReloadLock.Lock() @@ -537,6 +565,9 @@ func (c *Config) Valid() error { return fmt.Errorf("invalid max log file size=%v which is larger than max=%v", c.Log.File.MaxSize, MaxLogFileSize) } c.OOMAction = strings.ToLower(c.OOMAction) + if c.OOMAction != OOMActionLog && c.OOMAction != OOMActionCancel { + return fmt.Errorf("unsupported OOMAction %v, TiDB only supports [%v, %v]", c.OOMAction, OOMActionLog, OOMActionCancel) + } // lower_case_table_names is allowed to be 0, 1, 2 if c.LowerCaseTableNames < 0 || c.LowerCaseTableNames > 2 { @@ -559,10 +590,9 @@ func (c *Config) Valid() error { if err != nil { return err } - minDur := time.Second * 15 - maxDur := time.Second * 60 - if dur < minDur || dur > maxDur { - return fmt.Errorf("pessimistic transaction ttl %s out of range [%s, %s]", dur, minDur, maxDur) + if dur < MinPessimisticTTL || dur > MaxPessimisticTTL { + return fmt.Errorf("pessimistic transaction ttl %s out of range [%s, %s]", + dur, MinPessimisticTTL, MaxPessimisticTTL) } } return nil diff --git a/config/config.toml.example b/config/config.toml.example index 08b490b60df52..e0d700f467b40 100644 --- a/config/config.toml.example +++ b/config/config.toml.example @@ -54,6 +54,9 @@ check-mb4-value-in-utf8 = true # treat-old-version-utf8-as-utf8mb4 use for upgrade compatibility. Set to true will treat old version table/column UTF8 charset as UTF8MB4. treat-old-version-utf8-as-utf8mb4 = true +# Maximum number of the splitting region, which is used by the split region statement. +split-region-max-num = 1000 + [log] # Log level: debug, info, warn, error, fatal. level = "info" @@ -70,6 +73,10 @@ slow-query-file = "tidb-slow.log" # Queries with execution time greater than this value will be logged. (Milliseconds) slow-threshold = 300 +# record-plan-in-slow-log is used to enable record query plan in slow log. +# 0 is disable. 1 is enable. +record-plan-in-slow-log = 1 + # Queries with internal result greater than this value will be logged. expensive-threshold = 10000 @@ -119,14 +126,16 @@ report-status = true # TiDB status host. status-host = "0.0.0.0" -# Prometheus pushgateway address, leaves it empty will disable prometheus push. +## status-host is the HTTP address for reporting the internal status of a TiDB server, for example: +## API for prometheus: http://${status-host}:${status_port}/metrics +## API for pprof: http://${status-host}:${status_port}/debug/pprof # TiDB status port. status-port = 10080 -# Prometheus pushgateway address, leaves it empty will disable prometheus push. +# Prometheus pushgateway address, leaves it empty will disable push to pushgateway. metrics-addr = "" -# Prometheus client push interval in second, set \"0\" to disable prometheus push. +# Prometheus client push interval in second, set \"0\" to disable push to pushgateway. metrics-interval = 15 # Record statements qps by database name if it is enabled. @@ -168,6 +177,9 @@ pseudo-estimate-ratio = 0.8 # The value could be "NO_PRIORITY", "LOW_PRIORITY", "HIGH_PRIORITY" or "DELAYED". force-priority = "NO_PRIORITY" +# Bind info lease duration, which influences the duration of loading bind info and handling invalid bind. +bind-info-lease = "3s" + [proxy-protocol] # PROXY protocol acceptable client networks. # Empty string means disable PROXY protocol, * means all networks. @@ -261,10 +273,14 @@ max-batch-wait-time = 0 # Batch wait size, to avoid waiting too long. batch-wait-size = 8 +# If a Region has not been accessed for more than the given duration (in seconds), it +# will be reloaded from the PD. +region-cache-ttl = 600 + [txn-local-latches] # Enable local latches for transactions. Enable it when # there are lots of conflicts between transactions. -enabled = true +enabled = false capacity = 2048000 [binlog] @@ -286,14 +302,18 @@ strategy = "range" [pessimistic-txn] # enable pessimistic transaction. -enable = false - -# start pessimistic transaction by default. -default = false +enable = true # max retry count for a statement in a pessimistic transaction. max-retry-count = 256 # default TTL in milliseconds for pessimistic lock. -# The value must between "15s" and "60s". -ttl = "30s" +# The value must between "15s" and "120s". +ttl = "40s" + +[stmt-summary] +# max number of statements kept in memory. +max-stmt-count = 100 + +# max length of displayed normalized sql and sample sql. +max-sql-length = 4096 diff --git a/config/config_test.go b/config/config_test.go index 173878c62c0ed..8d5d166e936f2 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -39,6 +39,7 @@ func (s *testConfigSuite) TestConfig(c *C) { conf.Binlog.IgnoreError = true conf.Binlog.Strategy = "hash" conf.TiKVClient.CommitTimeout = "10s" + conf.TiKVClient.RegionCacheTTL = 600 configFile := "config.toml" _, localFile, _, _ := runtime.Caller(0) configFile = path.Join(path.Dir(localFile), configFile) @@ -60,10 +61,15 @@ unrecognized-option-test = true _, err = f.WriteString(` token-limit = 0 +split-region-max-num=10000 [performance] [tikv-client] commit-timeout="41s" max-batch-size=128 +region-cache-ttl=6000 +[stmt-summary] +max-stmt-count=1000 +max-sql-length=1024 `) c.Assert(err, IsNil) @@ -77,7 +83,11 @@ max-batch-size=128 c.Assert(conf.TiKVClient.CommitTimeout, Equals, "41s") c.Assert(conf.TiKVClient.MaxBatchSize, Equals, uint(128)) + c.Assert(conf.TiKVClient.RegionCacheTTL, Equals, uint(6000)) c.Assert(conf.TokenLimit, Equals, uint(1000)) + c.Assert(conf.SplitRegionMaxNum, Equals, uint64(10000)) + c.Assert(conf.StmtSummary.MaxStmtCount, Equals, uint(1000)) + c.Assert(conf.StmtSummary.MaxSQLLength, Equals, uint(1024)) c.Assert(f.Close(), IsNil) c.Assert(os.Remove(configFile), IsNil) @@ -203,11 +213,29 @@ func (s *testConfigSuite) TestValid(c *C) { }{ {"14s", false}, {"15s", true}, - {"60s", true}, - {"61s", false}, + {"120s", true}, + {"121s", false}, } for _, tt := range tests { c1.PessimisticTxn.TTL = tt.ttl c.Assert(c1.Valid() == nil, Equals, tt.valid) } } + +func (s *testConfigSuite) TestOOMActionValid(c *C) { + c1 := NewConfig() + tests := []struct { + oomAction string + valid bool + }{ + {"log", true}, + {"Log", true}, + {"Cancel", true}, + {"cANceL", true}, + {"quit", false}, + } + for _, tt := range tests { + c1.OOMAction = tt.oomAction + c.Assert(c1.Valid() == nil, Equals, tt.valid) + } +} diff --git a/ddl/column.go b/ddl/column.go index 3a0df926f62f1..078fa8415b419 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -25,7 +25,6 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/ddl/util" - "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/sessionctx" @@ -324,7 +323,7 @@ func onSetDefaultValue(t *meta.Meta, job *model.Job) (ver int64, _ error) { return ver, errors.Trace(err) } - return updateColumn(t, job, newCol, &newCol.Name) + return updateColumnDefaultValue(t, job, newCol, &newCol.Name) } func (w *worker) onModifyColumn(t *meta.Meta, job *model.Job) (ver int64, _ error) { @@ -481,7 +480,7 @@ func checkForNullValue(ctx sessionctx.Context, isDataTruncated bool, schema, tab return nil } -func updateColumn(t *meta.Meta, job *model.Job, newCol *model.ColumnInfo, oldColName *model.CIStr) (ver int64, _ error) { +func updateColumnDefaultValue(t *meta.Meta, job *model.Job, newCol *model.ColumnInfo, oldColName *model.CIStr) (ver int64, _ error) { tblInfo, err := getTableInfoAndCancelFaultJob(t, job, job.SchemaID) if err != nil { return ver, errors.Trace(err) @@ -491,7 +490,10 @@ func updateColumn(t *meta.Meta, job *model.Job, newCol *model.ColumnInfo, oldCol job.State = model.JobStateCancelled return ver, infoschema.ErrColumnNotExists.GenWithStackByArgs(newCol.Name, tblInfo.Name) } - *oldCol = *newCol + // The newCol's offset may be the value of the old schema version, so we can't use newCol directly. + oldCol.DefaultValue = newCol.DefaultValue + oldCol.DefaultValueBit = newCol.DefaultValueBit + oldCol.Flag = newCol.Flag ver, err = updateVersionAndTableInfo(t, job, tblInfo, true) if err != nil { @@ -519,8 +521,8 @@ func allocateColumnID(tblInfo *model.TableInfo) int64 { return tblInfo.MaxColumnID } -func checkAddColumnTooManyColumns(oldCols int) error { - if uint32(oldCols) > atomic.LoadUint32(&TableColumnCountLimit) { +func checkAddColumnTooManyColumns(colNum int) error { + if uint32(colNum) > atomic.LoadUint32(&TableColumnCountLimit) { return errTooManyFields } return nil @@ -586,9 +588,9 @@ func generateOriginDefaultValue(col *model.ColumnInfo) (interface{}, error) { return odValue, nil } -func findColumnInIndexCols(c *expression.Column, cols []*ast.IndexColName) bool { +func findColumnInIndexCols(c string, cols []*ast.IndexColName) bool { for _, c1 := range cols { - if c.ColName.L == c1.Column.Name.L { + if c == c1.Column.Name.L { return true } } diff --git a/ddl/column_test.go b/ddl/column_test.go index 2ec1e2a5d3584..451d32089d4f0 100644 --- a/ddl/column_test.go +++ b/ddl/column_test.go @@ -954,7 +954,7 @@ func (s *testColumnSuite) colDefStrToFieldType(c *C, str string) *types.FieldTyp stmt, err := parser.New().ParseOneStmt(sqlA, "", "") c.Assert(err, IsNil) colDef := stmt.(*ast.AlterTableStmt).Specs[0].NewColumns[0] - col, _, err := buildColumnAndConstraint(nil, 0, colDef, nil, mysql.DefaultCharset, mysql.DefaultCharset) + col, _, err := buildColumnAndConstraint(nil, 0, colDef, nil, mysql.DefaultCharset, "", mysql.DefaultCharset, "") c.Assert(err, IsNil) return &col.FieldType } diff --git a/ddl/db_change_test.go b/ddl/db_change_test.go index 33539ea5e1c75..3d7449fac4724 100644 --- a/ddl/db_change_test.go +++ b/ddl/db_change_test.go @@ -82,6 +82,7 @@ func (s *testStateChangeSuite) TestShowCreateTable(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("create table t (id int)") + tk.MustExec("create table t2 (a int, b varchar(10)) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci") var checkErr error testCases := []struct { @@ -94,6 +95,10 @@ func (s *testStateChangeSuite) TestShowCreateTable(c *C) { "CREATE TABLE `t` (\n `id` int(11) DEFAULT NULL,\n KEY `idx` (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"}, {"alter table t add column c int", "CREATE TABLE `t` (\n `id` int(11) DEFAULT NULL,\n KEY `idx` (`id`),\n KEY `idx1` (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin"}, + {"alter table t2 add column c varchar(1)", + "CREATE TABLE `t2` (\n `a` int(11) DEFAULT NULL,\n `b` varchar(10) COLLATE utf8mb4_general_ci DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci"}, + {"alter table t2 add column d varchar(1)", + "CREATE TABLE `t2` (\n `a` int(11) DEFAULT NULL,\n `b` varchar(10) COLLATE utf8mb4_general_ci DEFAULT NULL,\n `c` varchar(1) COLLATE utf8mb4_general_ci DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci"}, } prevState := model.StateNone callback := &ddl.TestDDLCallback{} @@ -106,7 +111,13 @@ func (s *testStateChangeSuite) TestShowCreateTable(c *C) { currTestCaseOffset++ } if job.SchemaState != model.StatePublic { - result := tk.MustQuery("show create table t") + var result *testkit.Result + tbl2 := testGetTableByName(c, tk.Se, "test", "t2") + if job.TableID == tbl2.Meta().ID { + result = tk.MustQuery("show create table t2") + } else { + result = tk.MustQuery("show create table t") + } got := result.Rows()[0][1] expected := testCases[currTestCaseOffset].expectedRet if got != expected { @@ -682,6 +693,30 @@ func (s *testStateChangeSuite) TestParallelAlterModifyColumn(c *C) { // s.testControlParallelExecSQL(c, sql1, sql2, f) // } +func (s *testStateChangeSuite) TestParallelAddColumAndSetDefaultValue(c *C) { + _, err := s.se.Execute(context.Background(), "use test_db_state") + c.Assert(err, IsNil) + _, err = s.se.Execute(context.Background(), `create table tx ( + c1 varchar(64), + c2 enum('N','Y') not null default 'N', + primary key idx2 (c2, c1))`) + c.Assert(err, IsNil) + _, err = s.se.Execute(context.Background(), "insert into tx values('a', 'N')") + c.Assert(err, IsNil) + defer s.se.Execute(context.Background(), "drop table tx") + + sql1 := "alter table tx add column cx int after c1" + sql2 := "alter table tx alter c2 set default 'N'" + + f := func(c *C, err1, err2 error) { + c.Assert(err1, IsNil) + c.Assert(err2, IsNil) + _, err := s.se.Execute(context.Background(), "delete from tx where c1='a'") + c.Assert(err, IsNil) + } + s.testControlParallelExecSQL(c, sql1, sql2, f) +} + func (s *testStateChangeSuite) TestParallelChangeColumnName(c *C) { sql1 := "ALTER TABLE t CHANGE a aa int;" sql2 := "ALTER TABLE t CHANGE b aa int;" @@ -730,6 +765,16 @@ func (s *testStateChangeSuite) TestParallelDropColumn(c *C) { s.testControlParallelExecSQL(c, sql, sql, f) } +func (s *testStateChangeSuite) TestParallelDropIndex(c *C) { + sql1 := "alter table t drop index idx1 ;" + sql2 := "alter table t drop index idx2 ;" + f := func(c *C, err1, err2 error) { + c.Assert(err1, IsNil) + c.Assert(err2.Error(), Equals, "[autoid:1075]Incorrect table definition; there can be only one auto column and it must be defined as a key") + } + s.testControlParallelExecSQL(c, sql1, sql2, f) +} + func (s *testStateChangeSuite) TestParallelCreateAndRename(c *C) { sql1 := "create table t_exists(c int);" sql2 := "alter table t rename to t_exists;" @@ -746,7 +791,7 @@ type checkRet func(c *C, err1, err2 error) func (s *testStateChangeSuite) testControlParallelExecSQL(c *C, sql1, sql2 string, f checkRet) { _, err := s.se.Execute(context.Background(), "use test_db_state") c.Assert(err, IsNil) - _, err = s.se.Execute(context.Background(), "create table t(a int, b int, c int)") + _, err = s.se.Execute(context.Background(), "create table t(a int, b int, c int, d int auto_increment,e int, index idx1(d), index idx2(d,e))") c.Assert(err, IsNil) defer s.se.Execute(context.Background(), "drop table t") diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index 8cee4d1a97bc6..7353e57abf6f8 100644 --- a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -42,8 +42,10 @@ import ( "github.com/pingcap/tidb/store/mockstore/mocktikv" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/israce" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testutil" ) var _ = Suite(&testIntegrationSuite1{&testIntegrationSuite{}}) @@ -82,7 +84,7 @@ func setupIntegrationSuite(s *testIntegrationSuite, c *C) { ) c.Assert(err, IsNil) session.SetSchemaLease(s.lease) - session.SetStatsLease(0) + session.DisableStats4Test() s.dom, err = session.BootstrapSession(s.store) c.Assert(err, IsNil) @@ -278,13 +280,38 @@ func (s *testIntegrationSuite2) TestIssue6101(c *C) { tk.MustExec("drop table t1") } +func (s *testIntegrationSuite1) TestIndexLength(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table idx_len(a int(0), b timestamp(0), c datetime(0), d time(0), f float(0), g decimal(0))") + tk.MustExec("create index idx on idx_len(a)") + tk.MustExec("alter table idx_len add index idxa(a)") + tk.MustExec("create index idx1 on idx_len(b)") + tk.MustExec("alter table idx_len add index idxb(b)") + tk.MustExec("create index idx2 on idx_len(c)") + tk.MustExec("alter table idx_len add index idxc(c)") + tk.MustExec("create index idx3 on idx_len(d)") + tk.MustExec("alter table idx_len add index idxd(d)") + tk.MustExec("create index idx4 on idx_len(f)") + tk.MustExec("alter table idx_len add index idxf(f)") + tk.MustExec("create index idx5 on idx_len(g)") + tk.MustExec("alter table idx_len add index idxg(g)") + tk.MustExec("create table idx_len1(a int(0), b timestamp(0), c datetime(0), d time(0), f float(0), g decimal(0), index(a), index(b), index(c), index(d), index(f), index(g))") +} + func (s *testIntegrationSuite4) TestIssue3833(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.MustExec("create table issue3833 (b char(0))") + tk.MustExec("create table issue3833 (b char(0), c binary(0), d varchar(0))") assertErrorCode(c, tk, "create index idx on issue3833 (b)", tmysql.ErrWrongKeyColumn) assertErrorCode(c, tk, "alter table issue3833 add index idx (b)", tmysql.ErrWrongKeyColumn) - assertErrorCode(c, tk, "create table issue3833_2 (b char(0), index (b))", tmysql.ErrWrongKeyColumn) + assertErrorCode(c, tk, "create table issue3833_2 (b char(0), c binary(0), d varchar(0), index(b))", tmysql.ErrWrongKeyColumn) + assertErrorCode(c, tk, "create index idx on issue3833 (c)", tmysql.ErrWrongKeyColumn) + assertErrorCode(c, tk, "alter table issue3833 add index idx (c)", tmysql.ErrWrongKeyColumn) + assertErrorCode(c, tk, "create table issue3833_2 (b char(0), c binary(0), d varchar(0), index(c))", tmysql.ErrWrongKeyColumn) + assertErrorCode(c, tk, "create index idx on issue3833 (d)", tmysql.ErrWrongKeyColumn) + assertErrorCode(c, tk, "alter table issue3833 add index idx (d)", tmysql.ErrWrongKeyColumn) + assertErrorCode(c, tk, "create table issue3833_2 (b char(0), c binary(0), d varchar(0), index(d))", tmysql.ErrWrongKeyColumn) } func (s *testIntegrationSuite10) TestIssue2858And2717(c *C) { @@ -399,6 +426,16 @@ func (s *testIntegrationSuite5) TestMySQLErrorCode(c *C) { assertErrorCode(c, tk, sql, tmysql.ErrPrimaryCantHaveNull) sql = "create table t2 (id int null, age int, primary key(id));" assertErrorCode(c, tk, sql, tmysql.ErrPrimaryCantHaveNull) + sql = "create table t2 (id int auto_increment);" + assertErrorCode(c, tk, sql, tmysql.ErrWrongAutoKey) + sql = "create table t2 (a datetime(2) default current_timestamp(3))" + assertErrorCode(c, tk, sql, tmysql.ErrInvalidDefault) + sql = "create table t2 (a datetime(2) default current_timestamp(2) on update current_timestamp)" + assertErrorCode(c, tk, sql, tmysql.ErrInvalidOnUpdate) + sql = "create table t2 (a datetime default current_timestamp on update current_timestamp(2))" + assertErrorCode(c, tk, sql, tmysql.ErrInvalidOnUpdate) + sql = "create table t2 (a datetime(2) default current_timestamp(2) on update current_timestamp(3))" + assertErrorCode(c, tk, sql, tmysql.ErrInvalidOnUpdate) sql = "create table t2 (id int primary key , age int);" tk.MustExec(sql) @@ -561,6 +598,31 @@ func (s *testIntegrationSuite7) TestNullGeneratedColumn(c *C) { tk.MustExec("drop table t") } +func (s *testIntegrationSuite7) TestDependedGeneratedColumnPrior2GeneratedColumn(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("CREATE TABLE `t` (" + + "`a` int(11) DEFAULT NULL," + + "`b` int(11) GENERATED ALWAYS AS (`a` + 1) VIRTUAL," + + "`c` int(11) GENERATED ALWAYS AS (`b` + 1) VIRTUAL" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin") + // should check unknown column first, then the prior ones. + sql := "alter table t add column d int as (c + f + 1) first" + assertErrorCode(c, tk, sql, mysql.ErrBadField) + + // depended generated column should be prior to generated column self + sql = "alter table t add column d int as (c+1) first" + assertErrorCode(c, tk, sql, mysql.ErrGeneratedColumnNonPrior) + + // correct case + tk.MustExec("alter table t add column d int as (c+1) after c") + + // check position nil case + tk.MustExec("alter table t add column(e int as (c+1))") + tk.MustExec("drop table if exists t") +} + func (s *testIntegrationSuite9) TestChangingCharsetToUtf8(c *C) { tk := testkit.NewTestKit(c, s.store) @@ -666,17 +728,17 @@ func (s *testIntegrationSuite10) TestChangingTableCharset(c *C) { tk.MustExec("drop table t;") tk.MustExec("create table t(a varchar(10)) charset utf8") tk.MustExec("alter table t convert to charset utf8mb4;") - checkCharset := func() { + checkCharset := func(chs, coll string) { tbl := testGetTableByName(c, s.ctx, "test", "t") c.Assert(tbl, NotNil) - c.Assert(tbl.Meta().Charset, Equals, charset.CharsetUTF8MB4) - c.Assert(tbl.Meta().Collate, Equals, charset.CollationUTF8MB4) + c.Assert(tbl.Meta().Charset, Equals, chs) + c.Assert(tbl.Meta().Collate, Equals, coll) for _, col := range tbl.Meta().Columns { - c.Assert(col.Charset, Equals, charset.CharsetUTF8MB4) - c.Assert(col.Collate, Equals, charset.CollationUTF8MB4) + c.Assert(col.Charset, Equals, chs) + c.Assert(col.Collate, Equals, coll) } } - checkCharset() + checkCharset(charset.CharsetUTF8MB4, charset.CollationUTF8MB4) // Test when column charset can not convert to the target charset. tk.MustExec("drop table t;") @@ -685,11 +747,16 @@ func (s *testIntegrationSuite10) TestChangingTableCharset(c *C) { c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "[ddl:210]unsupported modify charset from ascii to utf8mb4") + tk.MustExec("drop table t;") + tk.MustExec("create table t(a varchar(10) character set utf8) charset utf8") + tk.MustExec("alter table t convert to charset utf8 collate utf8_general_ci;") + checkCharset(charset.CharsetUTF8, "utf8_general_ci") + // Test when table charset is equal to target charset but column charset is not equal. tk.MustExec("drop table t;") tk.MustExec("create table t(a varchar(10) character set utf8) charset utf8mb4") - tk.MustExec("alter table t convert to charset utf8mb4;") - checkCharset() + tk.MustExec("alter table t convert to charset utf8mb4 collate utf8mb4_general_ci;") + checkCharset(charset.CharsetUTF8MB4, "utf8mb4_general_ci") // Mock table info with charset is "". Old TiDB maybe create table with charset is "". db, ok := domain.GetDomain(s.ctx).InfoSchema().SchemaByName(model.NewCIStr("test")) @@ -722,7 +789,7 @@ func (s *testIntegrationSuite10) TestChangingTableCharset(c *C) { c.Assert(tbl.Meta().Collate, Equals, "") // Test when table charset is "", this for compatibility. tk.MustExec("alter table t convert to charset utf8mb4;") - checkCharset() + checkCharset(charset.CharsetUTF8MB4, charset.CollationUTF8MB4) // Test when column charset is "". tbl = testGetTableByName(c, s.ctx, "test", "t") @@ -737,10 +804,110 @@ func (s *testIntegrationSuite10) TestChangingTableCharset(c *C) { c.Assert(tbl.Meta().Columns[0].Charset, Equals, "") c.Assert(tbl.Meta().Columns[0].Collate, Equals, "") tk.MustExec("alter table t convert to charset utf8mb4;") - checkCharset() + checkCharset(charset.CharsetUTF8MB4, charset.CollationUTF8MB4) + + tk.MustExec("drop table t") + tk.MustExec("create table t (a blob) character set utf8;") + tk.MustExec("alter table t charset=utf8mb4 collate=utf8mb4_bin;") + tk.MustQuery("show create table t").Check(testutil.RowsWithSep("|", + "t CREATE TABLE `t` (\n"+ + " `a` blob DEFAULT NULL\n"+ + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", + )) + +} + +func (s *testIntegrationSuite5) TestModifyingColumnOption(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("create database if not exists test") + tk.MustExec("use test") + + errMsg := "[ddl:203]" // unsupported modify column with references + assertErrCode := func(sql string, errCodeStr string) { + _, err := tk.Exec(sql) + c.Assert(err, NotNil) + c.Assert(err.Error()[:len(errCodeStr)], Equals, errCodeStr) + } + + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (b char(1) default null) engine=InnoDB default charset=utf8mb4 collate=utf8mb4_general_ci") + tk.MustExec("alter table t1 modify column b char(1) character set utf8mb4 collate utf8mb4_general_ci") + + tk.MustExec("drop table t1") + tk.MustExec("create table t1 (b char(1) collate utf8mb4_general_ci)") + tk.MustExec("alter table t1 modify b char(1) character set utf8mb4 collate utf8mb4_general_ci") + + tk.MustExec("drop table t1") + tk.MustExec("drop table if exists t2") + tk.MustExec("create table t1 (a int)") + tk.MustExec("create table t2 (b int, c int)") + assertErrCode("alter table t2 modify column c int references t1(a)", errMsg) +} + +func (s *testIntegrationSuite1) TestIndexOnMultipleGeneratedColumn(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("create database if not exists test_mul_gen_col") + tk.MustExec("use test_mul_gen_col") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int as (a + 1), c int as (b + 1))") + tk.MustExec("insert into t (a) values (1)") + tk.MustExec("create index idx on t (c)") + tk.MustQuery("select * from t where c > 1").Check(testkit.Rows("1 2 3")) + res := tk.MustQuery("select * from t use index(idx) where c > 1") + tk.MustQuery("select * from t ignore index(idx) where c > 1").Check(res.Rows()) + tk.MustExec("admin check table t") + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int, b int as (a + 1), c int as (b + 1), d int as (c + 1))") + tk.MustExec("insert into t (a) values (1)") + tk.MustExec("create index idx on t (d)") + tk.MustQuery("select * from t where d > 2").Check(testkit.Rows("1 2 3 4")) + res = tk.MustQuery("select * from t use index(idx) where d > 2") + tk.MustQuery("select * from t ignore index(idx) where d > 2").Check(res.Rows()) + tk.MustExec("admin check table t") + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a bigint, b decimal as (a+1), c varchar(20) as (b*2), d float as (a*23+b-1+length(c)))") + tk.MustExec("insert into t (a) values (1)") + tk.MustExec("create index idx on t (d)") + tk.MustQuery("select * from t where d > 2").Check(testkit.Rows("1 2 4 25")) + res = tk.MustQuery("select * from t use index(idx) where d > 2") + tk.MustQuery("select * from t ignore index(idx) where d > 2").Check(res.Rows()) + tk.MustExec("admin check table t") + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a varchar(10), b float as (length(a)+123), c varchar(20) as (right(a, 2)), d float as (b+b-7+1-3+3*ASCII(c)))") + tk.MustExec("insert into t (a) values ('adorable')") + tk.MustExec("create index idx on t (d)") + tk.MustQuery("select * from t where d > 2").Check(testkit.Rows("adorable 131 le 577")) // 131+131-7+1-3+3*108 + res = tk.MustQuery("select * from t use index(idx) where d > 2") + tk.MustQuery("select * from t ignore index(idx) where d > 2").Check(res.Rows()) + tk.MustExec("admin check table t") + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a bigint, b decimal as (a), c int(10) as (a+b), d float as (a+b+c), e decimal as (a+b+c+d))") + tk.MustExec("insert into t (a) values (1)") + tk.MustExec("create index idx on t (d)") + tk.MustQuery("select * from t where d > 2").Check(testkit.Rows("1 1 2 4 8")) + res = tk.MustQuery("select * from t use index(idx) where d > 2") + tk.MustQuery("select * from t ignore index(idx) where d > 2").Check(res.Rows()) + tk.MustExec("admin check table t") + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a bigint, b bigint as (a+1) virtual, c bigint as (b+1) virtual)") + tk.MustExec("alter table t add index idx_b(b)") + tk.MustExec("alter table t add index idx_c(c)") + tk.MustExec("insert into t(a) values(1)") + tk.MustExec("alter table t add column(d bigint as (c+1) virtual)") + tk.MustExec("alter table t add index idx_d(d)") + tk.MustQuery("select * from t where d > 2").Check(testkit.Rows("1 2 3 4")) + res = tk.MustQuery("select * from t use index(idx_d) where d > 2") + tk.MustQuery("select * from t ignore index(idx_d) where d > 2").Check(res.Rows()) + tk.MustExec("admin check table t") } -func (s *testIntegrationSuite7) TestCaseInsensitiveCharsetAndCollate(c *C) { +func (s *testIntegrationSuite2) TestCaseInsensitiveCharsetAndCollate(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("create database if not exists test_charset_collate") @@ -830,6 +997,14 @@ func (s *testIntegrationSuite6) TestBitDefaultValue(c *C) { tk.MustExec("insert into t_bit set c2=1;") tk.MustQuery("select bin(c1),c2 from t_bit").Check(testkit.Rows("11111010 1")) tk.MustExec("drop table t_bit") + + tk.MustExec("create table t_bit (a int)") + tk.MustExec("insert into t_bit value (1)") + tk.MustExec("alter table t_bit add column c bit(16) null default b'1100110111001'") + tk.MustQuery("select c from t_bit").Check(testkit.Rows("\x19\xb9")) + tk.MustExec("update t_bit set c = b'11100000000111'") + tk.MustQuery("select c from t_bit").Check(testkit.Rows("\x38\x07")) + tk.MustExec(`create table testalltypes1 ( field_1 bit default 1, field_2 tinyint null default null @@ -1389,13 +1564,40 @@ func (s *testIntegrationSuite4) TestAlterColumn(c *C) { createSQL = result.Rows()[0][1] expected = "CREATE TABLE `mc` (\n `a` bigint(20) NOT NULL AUTO_INCREMENT,\n `b` int(11) DEFAULT NULL,\n PRIMARY KEY (`a`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin" c.Assert(createSQL, Equals, expected) - s.tk.MustExec("alter table mc modify column a bigint") // Drops auto_increment + _, err = s.tk.Exec("alter table mc modify column a bigint") // Droppping auto_increment is not allow when @@tidb_allow_remove_auto_inc == 'off' + c.Assert(err, NotNil) + s.tk.MustExec("set @@tidb_allow_remove_auto_inc = on") + s.tk.MustExec("alter table mc modify column a bigint") // Dropping auto_increment is ok when @@tidb_allow_remove_auto_inc == 'on' result = s.tk.MustQuery("show create table mc") createSQL = result.Rows()[0][1] expected = "CREATE TABLE `mc` (\n `a` bigint(20) NOT NULL,\n `b` int(11) DEFAULT NULL,\n PRIMARY KEY (`a`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin" c.Assert(createSQL, Equals, expected) + _, err = s.tk.Exec("alter table mc modify column a bigint auto_increment") // Adds auto_increment should throw error c.Assert(err, NotNil) + + s.tk.MustExec("drop table if exists t") + // TODO: fix me, below sql should execute successfully. Currently, the result of calculate key length is wrong. + //s.tk.MustExec("create table t1 (a varchar(10),b varchar(100),c tinyint,d varchar(3071),index(a),index(a,b),index (c,d));") + s.tk.MustExec("create table t1 (a varchar(10),b varchar(100),c tinyint,d varchar(3068),index(a),index(a,b),index (c,d));") + _, err = s.tk.Exec("alter table t1 modify column a varchar(3000);") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:1071]Specified key was too long; max key length is 3072 bytes") + // check modify column with rename column. + _, err = s.tk.Exec("alter table t1 change column a x varchar(3000);") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:1071]Specified key was too long; max key length is 3072 bytes") + _, err = s.tk.Exec("alter table t1 modify column c bigint;") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:1071]Specified key was too long; max key length is 3072 bytes") + + s.tk.MustExec("drop table if exists multi_unique") + s.tk.MustExec("create table multi_unique (a int unique unique)") + s.tk.MustExec("drop table multi_unique") + s.tk.MustExec("create table multi_unique (a int key primary key unique unique)") + s.tk.MustExec("drop table multi_unique") + s.tk.MustExec("create table multi_unique (a int key unique unique key unique)") + s.tk.MustExec("drop table multi_unique") } func (s *testIntegrationSuite) assertWarningExec(c *C, sql string, expectedWarn *terror.Error) { @@ -1456,17 +1658,17 @@ func (s *testIntegrationSuite3) TestAlterAlgorithm(c *C) { s.tk.MustExec("alter table t rename index idx_c to idx_c1, ALGORITHM=DEFAULT") // partition. - s.assertAlterWarnExec(c, "alter table t truncate partition p1, ALGORITHM=COPY") - s.assertAlterErrorExec(c, "alter table t truncate partition p2, ALGORITHM=INPLACE") - s.tk.MustExec("alter table t truncate partition p3, ALGORITHM=INSTANT") + s.assertAlterWarnExec(c, "alter table t ALGORITHM=COPY, truncate partition p1") + s.assertAlterErrorExec(c, "alter table t ALGORITHM=INPLACE, truncate partition p2") + s.tk.MustExec("alter table t ALGORITHM=INSTANT, truncate partition p3") s.assertAlterWarnExec(c, "alter table t add partition (partition p4 values less than (2002)), ALGORITHM=COPY") s.assertAlterErrorExec(c, "alter table t add partition (partition p5 values less than (3002)), ALGORITHM=INPLACE") s.tk.MustExec("alter table t add partition (partition p6 values less than (4002)), ALGORITHM=INSTANT") - s.assertAlterWarnExec(c, "alter table t drop partition p4, ALGORITHM=COPY") - s.assertAlterErrorExec(c, "alter table t drop partition p5, ALGORITHM=INPLACE") - s.tk.MustExec("alter table t drop partition p6, ALGORITHM=INSTANT") + s.assertAlterWarnExec(c, "alter table t ALGORITHM=COPY, drop partition p4") + s.assertAlterErrorExec(c, "alter table t ALGORITHM=INPLACE, drop partition p5") + s.tk.MustExec("alter table t ALGORITHM=INSTANT, drop partition p6") // Table options s.assertAlterWarnExec(c, "alter table t comment = 'test', ALGORITHM=COPY") @@ -1478,6 +1680,34 @@ func (s *testIntegrationSuite3) TestAlterAlgorithm(c *C) { s.tk.MustExec("alter table t default charset = utf8mb4, ALGORITHM=INSTANT") } +func (s *testIntegrationSuite3) TestAlterTableAddUniqueOnPartionRangeColumn(c *C) { + s.tk = testkit.NewTestKit(c, s.store) + s.tk.MustExec("use test") + s.tk.MustExec("drop table if exists t") + defer s.tk.MustExec("drop table if exists t") + + s.tk.MustExec(`create table t( + a int, + b varchar(100), + c int, + INDEX idx_c(c)) + PARTITION BY RANGE COLUMNS( a ) ( + PARTITION p0 VALUES LESS THAN (6), + PARTITION p1 VALUES LESS THAN (11), + PARTITION p2 VALUES LESS THAN (16), + PARTITION p3 VALUES LESS THAN (21) + )`) + s.tk.MustExec("insert into t values (4, 'xxx', 4)") + s.tk.MustExec("insert into t values (4, 'xxx', 9)") // Note the repeated 4 + s.tk.MustExec("insert into t values (17, 'xxx', 12)") + assertErrorCode(c, s.tk, "alter table t add unique index idx_a(a)", mysql.ErrDupEntry) + + s.tk.MustExec("delete from t where a = 4") + s.tk.MustExec("alter table t add unique index idx_a(a)") + s.tk.MustExec("alter table t add unique index idx_ac(a, c)") + assertErrorCode(c, s.tk, "alter table t add unique index idx_b(b)", mysql.ErrUniqueKeyNeedAllFieldsInPf) +} + func (s *testIntegrationSuite5) TestFulltextIndexIgnore(c *C) { s.tk = testkit.NewTestKit(c, s.store) s.tk.MustExec("use test") @@ -1495,6 +1725,9 @@ func (s *testIntegrationSuite5) TestFulltextIndexIgnore(c *C) { } func (s *testIntegrationSuite1) TestTreatOldVersionUTF8AsUTF8MB4(c *C) { + if israce.RaceEnabled { + c.Skip("skip race test") + } s.tk = testkit.NewTestKit(c, s.store) s.tk.MustExec("use test") s.tk.MustExec("drop table if exists t") @@ -1734,10 +1967,39 @@ func (s *testIntegrationSuite11) TestChangingDBCharset(c *C) { for _, fc := range failedCases { c.Assert(tk.ExecToErr(fc.stmt).Error(), Equals, fc.errMsg, Commentf("%v", fc.stmt)) } + tk.MustExec("ALTER SCHEMA CHARACTER SET = 'utf8' COLLATE = 'utf8_unicode_ci'") + verifyDBCharsetAndCollate("alterdb2", "utf8", "utf8_unicode_ci") tk.MustExec("ALTER SCHEMA CHARACTER SET = 'utf8mb4'") verifyDBCharsetAndCollate("alterdb2", "utf8mb4", "utf8mb4_bin") - err := tk.ExecToErr("ALTER SCHEMA CHARACTER SET = 'utf8mb4' COLLATE = 'utf8mb4_general_ci'") - c.Assert(err.Error(), Equals, "[ddl:210]unsupported modify collate from utf8mb4_bin to utf8mb4_general_ci") + tk.MustExec("ALTER SCHEMA CHARACTER SET = 'utf8mb4' COLLATE = 'utf8mb4_general_ci'") + verifyDBCharsetAndCollate("alterdb2", "utf8mb4", "utf8mb4_general_ci") +} + +func (s *testIntegrationSuite4) TestDropAutoIncrementIndex(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("create database if not exists test") + tk.MustExec("use test") + + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (a int auto_increment, unique key (a))") + dropIndexSQL := "alter table t1 drop index a" + assertErrorCode(c, tk, dropIndexSQL, mysql.ErrWrongAutoKey) + + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1 (a int(11) not null auto_increment, b int(11), c bigint, unique key (a, b, c))") + dropIndexSQL = "alter table t1 drop index a" + assertErrorCode(c, tk, dropIndexSQL, mysql.ErrWrongAutoKey) +} + +func (s *testIntegrationSuite3) TestParserIssue284(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table test.t_parser_issue_284(c1 int not null primary key)") + _, err := tk.Exec("create table test.t_parser_issue_284_2(id int not null primary key, c1 int not null, constraint foreign key (c1) references t_parser_issue_284(c1))") + c.Assert(err, IsNil) + + tk.MustExec("drop table test.t_parser_issue_284") + tk.MustExec("drop table test.t_parser_issue_284_2") } diff --git a/ddl/db_partition_test.go b/ddl/db_partition_test.go index 47db0af006b78..1aa5d59e3c81d 100644 --- a/ddl/db_partition_test.go +++ b/ddl/db_partition_test.go @@ -23,6 +23,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/errors" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" tmysql "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -169,6 +170,19 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) { );` assertErrorCode(c, tk, sql9, tmysql.ErrPartitionFunctionIsNotAllowed) + assertErrorCode(c, tk, `create TABLE t10 (c1 int,c2 int) partition by range(c1 / c2 ) (partition p0 values less than (2));`, tmysql.ErrPartitionFunctionIsNotAllowed) + _, err = tk.Exec(`CREATE TABLE t9 ( + a INT NOT NULL, + b INT NOT NULL, + c INT NOT NULL + ) + partition by range columns(a) ( + partition p0 values less than (10), + partition p2 values less than (20), + partition p3 values less than (20) + );`) + c.Assert(ddl.ErrRangeNotIncreasing.Equal(err), IsTrue) + assertErrorCode(c, tk, `create TABLE t10 (c1 int,c2 int) partition by range(c1 / c2 ) (partition p0 values less than (2));`, tmysql.ErrPartitionFunctionIsNotAllowed) tk.MustExec(`create TABLE t11 (c1 int,c2 int) partition by range(c1 div c2 ) (partition p0 values less than (2));`) @@ -324,14 +338,18 @@ create table log_message_1 ( cases := []testCase{ { "create table t (id int) partition by range columns (id);", - ddl.ErrPartitionsMustBeDefined, + ast.ErrPartitionsMustBeDefined, }, { "create table t (id int) partition by range columns (id) (partition p0 values less than (1, 2));", - ddl.ErrPartitionColumnList, + ast.ErrPartitionColumnList, }, { "create table t (a int) partition by range columns (b) (partition p0 values less than (1, 2));", + ast.ErrPartitionColumnList, + }, + { + "create table t (a int) partition by range columns (b) (partition p0 values less than (1));", ddl.ErrFieldNotFoundPart, }, { @@ -368,10 +386,20 @@ create table log_message_1 ( "partition p1 values less than (1, 'a'))", ddl.ErrRangeNotIncreasing, }, + { + "create table t (col datetime not null default '2000-01-01')" + + "partition by range columns (col) (" + + "PARTITION p0 VALUES LESS THAN (20190905)," + + "PARTITION p1 VALUES LESS THAN (20190906));", + ddl.ErrWrongTypeColumnValue, + }, } for i, t := range cases { _, err := tk.Exec(t.sql) - c.Assert(t.err.Equal(err), IsTrue, Commentf("case %d fail, sql = %s", i, t.sql)) + c.Assert(t.err.Equal(err), IsTrue, Commentf( + "case %d fail, sql = `%s`\nexpected error = `%v`\n actual error = `%v`", + i, t.sql, t.err, err, + )) } tk.MustExec("create table t1 (a int, b char(3)) partition by range columns (a, b) (" + @@ -495,6 +523,45 @@ func (s *testIntegrationSuite5) TestAlterTableAddPartition(c *C) { partition p5 values less than maxvalue );` assertErrorCode(c, tk, sql7, tmysql.ErrSameNamePartition) + + sql8 := "alter table table3 add partition (partition p6);" + assertErrorCode(c, tk, sql8, tmysql.ErrPartitionRequiresValues) + + sql9 := "alter table table3 add partition (partition p7 values in (2018));" + assertErrorCode(c, tk, sql9, tmysql.ErrPartitionWrongValues) + + sql10 := "alter table table3 add partition partitions 4;" + assertErrorCode(c, tk, sql10, tmysql.ErrPartitionsMustBeDefined) + + tk.MustExec("alter table table3 add partition (partition p3 values less than (2001 + 10))") + + // less than value can be negative or expression. + tk.MustExec(`CREATE TABLE tt5 ( + c3 bigint(20) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin + PARTITION BY RANGE ( c3 ) ( + PARTITION p0 VALUES LESS THAN (-3), + PARTITION p1 VALUES LESS THAN (-2) + );`) + tk.MustExec(`ALTER TABLE tt5 add partition ( partition p2 values less than (-1) );`) + tk.MustExec(`ALTER TABLE tt5 add partition ( partition p3 values less than (5-1) );`) + + // Test add partition for the table partition by range columns. + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (a datetime) partition by range columns (a) (partition p1 values less than ('2019-06-01'), partition p2 values less than ('2019-07-01'));") + sql := "alter table t add partition ( partition p3 values less than ('2019-07-01'));" + assertErrorCode(c, tk, sql, tmysql.ErrRangeNotIncreasing) + tk.MustExec("alter table t add partition ( partition p3 values less than ('2019-08-01'));") + + // Add partition value's type should be the same with the column's type. + tk.MustExec("drop table if exists t;") + tk.MustExec(`create table t ( + col date not null default '2000-01-01') + partition by range columns (col) ( + PARTITION p0 VALUES LESS THAN ('20190905'), + PARTITION p1 VALUES LESS THAN ('20190906'));`) + sql = "alter table t add partition (partition p2 values less than (20190907));" + assertErrorCode(c, tk, sql, tmysql.ErrWrongTypeColumnValue) } func (s *testIntegrationSuite6) TestAlterTableDropPartition(c *C) { @@ -632,6 +699,9 @@ func (s *testIntegrationSuite6) TestAlterTableDropPartition(c *C) { tk.MustExec("alter table table4 drop partition PAR5;") sql4 := "alter table table4 drop partition PAR0;" assertErrorCode(c, tk, sql4, tmysql.ErrDropPartitionNonExistent) + + tk.MustExec("CREATE TABLE t1 (a int(11), b varchar(64)) PARTITION BY HASH(a) PARTITIONS 3") + assertErrorCode(c, tk, "alter table t1 drop partition p2", tmysql.ErrOnlyOnRangeListPartition) } func (s *testIntegrationSuite11) TestAddPartitionTooManyPartitions(c *C) { @@ -797,7 +867,7 @@ func (s *testIntegrationSuite6) TestTruncatePartitionAndDropTable(c *C) { tk.MustExec("drop table if exists t5;") tk.MustExec("set @@session.tidb_enable_table_partition=1;") tk.MustExec(`create table t5( - id int, name varchar(50), + id int, name varchar(50), purchased date ) partition by range( year(purchased) ) ( @@ -1047,6 +1117,38 @@ func (s *testIntegrationSuite5) TestPartitionUniqueKeyNeedAllFieldsInPf(c *C) { partition p2 values less than (15) )` assertErrorCode(c, tk, sql9, tmysql.ErrUniqueKeyNeedAllFieldsInPf) + + sql10 := `create table part8 ( + a int not null, + b int not null, + c int default null, + d int default null, + e int default null, + primary key (a, b), + unique key (c, d) + ) + partition by range columns (b) ( + partition p0 values less than (4), + partition p1 values less than (7), + partition p2 values less than (11) + )` + assertErrorCode(c, tk, sql10, tmysql.ErrUniqueKeyNeedAllFieldsInPf) + + sql11 := `create table part9 ( + a int not null, + b int not null, + c int default null, + d int default null, + e int default null, + primary key (a, b), + unique key (b, c, d) + ) + partition by range columns (b, c) ( + partition p0 values less than (4, 5), + partition p1 values less than (7, 9), + partition p2 values less than (11, 22) + )` + assertErrorCode(c, tk, sql11, tmysql.ErrUniqueKeyNeedAllFieldsInPf) } func (s *testIntegrationSuite3) TestPartitionDropIndex(c *C) { @@ -1154,6 +1256,10 @@ func (s *testIntegrationSuite2) TestPartitionCancelAddIndex(c *C) { var checkErr error var c3IdxInfo *model.IndexInfo hook := &ddl.TestDDLCallback{} + originBatchSize := tk.MustQuery("select @@global.tidb_ddl_reorg_batch_size") + // Set batch size to lower try to slow down add-index reorganization, This if for hook to cancel this ddl job. + tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 32") + defer tk.MustExec(fmt.Sprintf("set @@global.tidb_ddl_reorg_batch_size = %v", originBatchSize.Rows()[0][0])) hook.OnJobUpdatedExported, c3IdxInfo, checkErr = backgroundExecOnJobUpdatedExported(c, s.store, s.ctx, hook) originHook := s.dom.DDL().GetHook() defer s.dom.DDL().(ddl.DDLForTest).SetHook(originHook) @@ -1429,6 +1535,9 @@ func (s *testIntegrationSuite4) TestPartitionErrorCode(c *C) { _, err := tk.Exec("alter table employees add partition partitions 8;") c.Assert(ddl.ErrUnsupportedAddPartition.Equal(err), IsTrue) + _, err = tk.Exec("alter table employees add partition (partition p5 values less than (42));") + c.Assert(ddl.ErrUnsupportedAddPartition.Equal(err), IsTrue) + // coalesce partition tk.MustExec(`create table clients ( id int, @@ -1449,3 +1558,34 @@ func (s *testIntegrationSuite4) TestPartitionErrorCode(c *C) { _, err = tk.Exec("alter table t_part coalesce partition 4;") c.Assert(ddl.ErrCoalesceOnlyOnHashPartition.Equal(err), IsTrue) } + +func (s *testIntegrationSuite3) TestUnsupportedPartitionManagementDDLs(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists test_1465;") + tk.MustExec(` + create table test_1465 (a int) + partition by range(a) ( + partition p1 values less than (10), + partition p2 values less than (20), + partition p3 values less than (30) + ); + `) + + _, err := tk.Exec("alter table test_1465 truncate partition p1, p2") + c.Assert(err, ErrorMatches, ".*can't run multi schema change") + _, err = tk.Exec("alter table test_1465 drop partition p1, p2") + c.Assert(err, ErrorMatches, ".*can't run multi schema change") + + _, err = tk.Exec("alter table test_1465 partition by hash(a)") + c.Assert(err, ErrorMatches, ".*alter table partition is unsupported") +} + +func (s *testIntegrationSuite8) TestTruncateTableWithPartition(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists test_4414;") + tk.MustExec("create table test_4414(a int, b int) partition by hash(a) partitions 10;") + tk.MustExec("truncate table test_4414;") + tk.MustQuery("select * from test_4414 partition (p0)").Check(testkit.Rows()) +} diff --git a/ddl/db_test.go b/ddl/db_test.go index 1bcdcca44fff5..a3f800163b0cc 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -27,6 +27,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" tmysql "github.com/pingcap/parser/mysql" @@ -86,7 +87,7 @@ func setUpSuite(s *testDBSuite, c *C) { s.lease = 100 * time.Millisecond session.SetSchemaLease(s.lease) - session.SetStatsLease(0) + session.DisableStats4Test() s.schemaName = "test_db" s.autoIDStep = autoid.GetStep() ddl.WaitTimeWhenErrorOccured = 0 @@ -295,6 +296,10 @@ func (s *testDBSuite3) TestCancelAddIndex(c *C) { var c3IdxInfo *model.IndexInfo hook := &ddl.TestDDLCallback{} + originBatchSize := s.tk.MustQuery("select @@global.tidb_ddl_reorg_batch_size") + // Set batch size to lower try to slow down add-index reorganization, This if for hook to cancel this ddl job. + s.tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 32") + defer s.tk.MustExec(fmt.Sprintf("set @@global.tidb_ddl_reorg_batch_size = %v", originBatchSize.Rows()[0][0])) // let hook.OnJobUpdatedExported has chance to cancel the job. // the hook.OnJobUpdatedExported is called when the job is updated, runReorgJob will wait ddl.ReorgWaitTimeout, then return the ddl.runDDLJob. // After that ddl call d.hook.OnJobUpdated(job), so that we can canceled the job in this test case. @@ -1684,6 +1689,13 @@ func (s *testDBSuite5) TestCreateTableWithLike(c *C) { c.Assert(err, IsNil) c.Assert(tbl1.Meta().ForeignKeys, IsNil) + // for table partition + s.tk.MustExec("use ctwl_db") + s.tk.MustExec("create table pt1 (id int) partition by range columns (id) (partition p0 values less than (10))") + s.tk.MustExec("insert into pt1 values (1),(2),(3),(4);") + s.tk.MustExec("create table ctwl_db1.pt1 like ctwl_db.pt1;") + s.tk.MustQuery("select * from ctwl_db1.pt1").Check(testkit.Rows()) + // for failure cases failSQL := fmt.Sprintf("create table t1 like test_not_exist.t") assertErrorCode(c, s.tk, failSQL, tmysql.ErrNoSuchTable) @@ -1794,7 +1806,21 @@ func (s *testDBSuite1) TestCreateTable(c *C) { _, err = s.tk.Exec("CREATE TABLE `t` (`a` int) DEFAULT CHARSET=abcdefg") c.Assert(err, NotNil) + _, err = s.tk.Exec("CREATE TABLE `collateTest` (`a` int, `b` varchar(10)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_slovak_ci") + c.Assert(err, IsNil) + result := s.tk.MustQuery("show create table collateTest") + got := result.Rows()[0][1] + c.Assert(got, Equals, "CREATE TABLE `collateTest` (\n `a` int(11) DEFAULT NULL,\n `b` varchar(10) COLLATE utf8_slovak_ci DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_slovak_ci") + + s.tk.MustExec("create database test2 default charset utf8 collate utf8_general_ci") + s.tk.MustExec("use test2") + s.tk.MustExec("create table dbCollateTest (a varchar(10))") + result = s.tk.MustQuery("show create table dbCollateTest") + got = result.Rows()[0][1] + c.Assert(got, Equals, "CREATE TABLE `dbCollateTest` (\n `a` varchar(10) COLLATE utf8_general_ci DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_general_ci") + // test for enum column + s.tk.MustExec("use test") failSQL := "create table t_enum (a enum('e','e'));" assertErrorCode(c, s.tk, failSQL, tmysql.ErrDuplicatedValueInType) failSQL = "create table t_enum (a enum('e','E'));" @@ -1812,6 +1838,59 @@ func (s *testDBSuite1) TestCreateTable(c *C) { c.Assert(err.Error(), Equals, "[types:1291]Column 'a' has duplicated value 'B' in ENUM") } +func (s *testDBSuite2) TestCreateTableWithSetCol(c *C) { + s.tk = testkit.NewTestKitWithInit(c, s.store) + s.tk.MustExec("create table t_set (a int, b set('e') default '');") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` set('e') DEFAULT ''\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('a', 'b', 'c', 'd') default 'a,C,c');") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('a','b','c','d') DEFAULT 'a,c'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + // It's for failure cases. + // The type of default value is string. + s.tk.MustExec("drop table t_set") + failedSQL := "create table t_set (a set('1', '4', '10') default '3');" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + failedSQL = "create table t_set (a set('1', '4', '10') default '1,4,11');" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + failedSQL = "create table t_set (a set('1', '4', '10') default '1 ,4');" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + // The type of default value is int. + failedSQL = "create table t_set (a set('1', '4', '10') default 0);" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + failedSQL = "create table t_set (a set('1', '4', '10') default 8);" + assertErrorCode(c, s.tk, failedSQL, tmysql.ErrInvalidDefault) + + // The type of default value is int. + // It's for successful cases + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 1);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '1'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 2);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '4'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 3);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '1,4'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("drop table t_set") + s.tk.MustExec("create table t_set (a set('1', '4', '10', '21') default 15);") + s.tk.MustQuery("show create table t_set").Check(testkit.Rows("t_set CREATE TABLE `t_set` (\n" + + " `a` set('1','4','10','21') DEFAULT '1,4,10,21'\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + s.tk.MustExec("insert into t_set value()") + s.tk.MustQuery("select * from t_set").Check(testkit.Rows("1,4,10,21")) +} + func (s *testDBSuite2) TestTableForeignKey(c *C) { s.tk = testkit.NewTestKit(c, s.store) s.tk.MustExec("use test") @@ -2025,71 +2104,101 @@ func (s *testDBSuite3) TestGeneratedColumnDDL(c *C) { s.tk = testkit.NewTestKit(c, s.store) s.tk.MustExec("use test") - // Check create table with virtual generated column. - s.tk.MustExec(`CREATE TABLE test_gv_ddl(a int, b int as (a+8) virtual)`) + // Check create table with virtual and stored generated columns. + s.tk.MustExec(`CREATE TABLE test_gv_ddl(a int, b int as (a+8) virtual, c int as (b + 2) stored)`) - // Check desc table with virtual generated column. + // Check desc table with virtual and stored generated columns. result := s.tk.MustQuery(`DESC test_gv_ddl`) - result.Check(testkit.Rows(`a int(11) YES `, `b int(11) YES VIRTUAL GENERATED`)) + result.Check(testkit.Rows(`a int(11) YES `, `b int(11) YES VIRTUAL GENERATED`, `c int(11) YES STORED GENERATED`)) - // Check show create table with virtual generated column. + // Check show create table with virtual and stored generated columns. result = s.tk.MustQuery(`show create table test_gv_ddl`) result.Check(testkit.Rows( - "test_gv_ddl CREATE TABLE `test_gv_ddl` (\n `a` int(11) DEFAULT NULL,\n `b` int(11) GENERATED ALWAYS AS (`a` + 8) VIRTUAL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", + "test_gv_ddl CREATE TABLE `test_gv_ddl` (\n `a` int(11) DEFAULT NULL,\n `b` int(11) GENERATED ALWAYS AS (`a` + 8) VIRTUAL,\n `c` int(11) GENERATED ALWAYS AS (`b` + 2) STORED\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", )) - // Check alter table add a stored generated column. - s.tk.MustExec(`alter table test_gv_ddl add column c int as (b+2) stored`) - result = s.tk.MustQuery(`DESC test_gv_ddl`) - result.Check(testkit.Rows(`a int(11) YES `, `b int(11) YES VIRTUAL GENERATED`, `c int(11) YES STORED GENERATED`)) - // Check generated expression with blanks. - s.tk.MustExec("create table table_with_gen_col_blanks (a int, b char(20) as (cast( \r\n\t a \r\n\tas char)))") + s.tk.MustExec("create table table_with_gen_col_blanks (a int, b char(20) as (cast( \r\n\t a \r\n\tas char)), c int as (a+100))") result = s.tk.MustQuery(`show create table table_with_gen_col_blanks`) result.Check(testkit.Rows("table_with_gen_col_blanks CREATE TABLE `table_with_gen_col_blanks` (\n" + " `a` int(11) DEFAULT NULL,\n" + - " `b` char(20) GENERATED ALWAYS AS (CAST(`a` AS CHAR)) VIRTUAL\n" + + " `b` char(20) GENERATED ALWAYS AS (cast(`a` as char)) VIRTUAL,\n" + + " `c` int(11) GENERATED ALWAYS AS (`a` + 100) VIRTUAL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + // Check generated expression with charset latin1 ("latin1" != mysql.DefaultCharset). + s.tk.MustExec("create table table_with_gen_col_latin1 (a int, b char(20) as (cast( \r\n\t a \r\n\tas char charset latin1)), c int as (a+100))") + result = s.tk.MustQuery(`show create table table_with_gen_col_latin1`) + result.Check(testkit.Rows("table_with_gen_col_latin1 CREATE TABLE `table_with_gen_col_latin1` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` char(20) GENERATED ALWAYS AS (cast(`a` as char charset latin1)) VIRTUAL,\n" + + " `c` int(11) GENERATED ALWAYS AS (`a` + 100) VIRTUAL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + // Check generated expression with string (issue 9457). + s.tk.MustExec("create table table_with_gen_col_string (first_name varchar(10), last_name varchar(10), full_name varchar(255) AS (CONCAT(first_name,' ',last_name)))") + result = s.tk.MustQuery(`show create table table_with_gen_col_string`) + result.Check(testkit.Rows("table_with_gen_col_string CREATE TABLE `table_with_gen_col_string` (\n" + + " `first_name` varchar(10) DEFAULT NULL,\n" + + " `last_name` varchar(10) DEFAULT NULL,\n" + + " `full_name` varchar(255) GENERATED ALWAYS AS (concat(`first_name`, ' ', `last_name`)) VIRTUAL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) + + s.tk.MustExec("alter table table_with_gen_col_string modify column full_name varchar(255) GENERATED ALWAYS AS (CONCAT(last_name,' ' ,first_name) ) VIRTUAL") + result = s.tk.MustQuery(`show create table table_with_gen_col_string`) + result.Check(testkit.Rows("table_with_gen_col_string CREATE TABLE `table_with_gen_col_string` (\n" + + " `first_name` varchar(10) DEFAULT NULL,\n" + + " `last_name` varchar(10) DEFAULT NULL,\n" + + " `full_name` varchar(255) GENERATED ALWAYS AS (concat(`last_name`, ' ', `first_name`)) VIRTUAL\n" + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin")) genExprTests := []struct { stmt string err int }{ - // drop/rename columns dependent by other column. + // Drop/rename columns dependent by other column. {`alter table test_gv_ddl drop column a`, mysql.ErrDependentByGeneratedColumn}, {`alter table test_gv_ddl change column a anew int`, mysql.ErrBadField}, - // modify/change stored status of generated columns. + // Modify/change stored status of generated columns. {`alter table test_gv_ddl modify column b bigint`, mysql.ErrUnsupportedOnGeneratedColumn}, {`alter table test_gv_ddl change column c cnew bigint as (a+100)`, mysql.ErrUnsupportedOnGeneratedColumn}, - // modify/change generated columns breaking prior. + // Modify/change generated columns breaking prior. {`alter table test_gv_ddl modify column b int as (c+100)`, mysql.ErrGeneratedColumnNonPrior}, {`alter table test_gv_ddl change column b bnew int as (c+100)`, mysql.ErrGeneratedColumnNonPrior}, - // refer not exist columns in generation expression. + // Refer not exist columns in generation expression. {`create table test_gv_ddl_bad (a int, b int as (c+8))`, mysql.ErrBadField}, - // refer generated columns non prior. + // Refer generated columns non prior. {`create table test_gv_ddl_bad (a int, b int as (c+1), c int as (a+1))`, mysql.ErrGeneratedColumnNonPrior}, - // virtual generated columns cannot be primary key. + // Virtual generated columns cannot be primary key. {`create table test_gv_ddl_bad (a int, b int, c int as (a+b) primary key)`, mysql.ErrUnsupportedOnGeneratedColumn}, {`create table test_gv_ddl_bad (a int, b int, c int as (a+b), primary key(c))`, mysql.ErrUnsupportedOnGeneratedColumn}, {`create table test_gv_ddl_bad (a int, b int, c int as (a+b), primary key(a, c))`, mysql.ErrUnsupportedOnGeneratedColumn}, + + // Add stored generated column through alter table. + {`alter table test_gv_ddl add column d int as (b+2) stored`, mysql.ErrUnsupportedOnGeneratedColumn}, + {`alter table test_gv_ddl modify column b int as (a + 8) stored`, mysql.ErrUnsupportedOnGeneratedColumn}, } for _, tt := range genExprTests { assertErrorCode(c, s.tk, tt.stmt, tt.err) } // Check alter table modify/change generated column. - s.tk.MustExec(`alter table test_gv_ddl modify column c bigint as (b+200) stored`) + modStoredColErrMsg := "[ddl:3106]'modifying a stored column' is not supported for generated columns." + _, err := s.tk.Exec(`alter table test_gv_ddl modify column c bigint as (b+200) stored`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, modStoredColErrMsg) + result = s.tk.MustQuery(`DESC test_gv_ddl`) - result.Check(testkit.Rows(`a int(11) YES `, `b int(11) YES VIRTUAL GENERATED`, `c bigint(20) YES STORED GENERATED`)) + result.Check(testkit.Rows(`a int(11) YES `, `b int(11) YES VIRTUAL GENERATED`, `c int(11) YES STORED GENERATED`)) s.tk.MustExec(`alter table test_gv_ddl change column b b bigint as (a+100) virtual`) result = s.tk.MustQuery(`DESC test_gv_ddl`) - result.Check(testkit.Rows(`a int(11) YES `, `b bigint(20) YES VIRTUAL GENERATED`, `c bigint(20) YES STORED GENERATED`)) + result.Check(testkit.Rows(`a int(11) YES `, `b bigint(20) YES VIRTUAL GENERATED`, `c int(11) YES STORED GENERATED`)) s.tk.MustExec(`alter table test_gv_ddl change column c cnew bigint`) result = s.tk.MustQuery(`DESC test_gv_ddl`) @@ -2126,7 +2235,11 @@ func (s *testDBSuite4) TestComment(c *C) { s.tk.MustExec("drop table if exists ct, ct1") } -func (s *testDBSuite5) TestRebaseAutoID(c *C) { +func (s *testDBSuite4) TestRebaseAutoID(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange"), IsNil) + }() s.tk = testkit.NewTestKit(c, s.store) s.tk.MustExec("use " + s.schemaName) @@ -2246,16 +2359,17 @@ func (s *testDBSuite2) TestAddNotNullColumnWhileInsertOnDupUpdate(c *C) { return default: } - _, tk2Err = tk2.Exec("insert nn (a, b) values (1, 1) on duplicate key update a = 1, b = b + 1") + _, tk2Err = tk2.Exec("insert nn (a, b) values (1, 1) on duplicate key update a = 1, b = values(b) + 1") if tk2Err != nil { return } } }() - tk1.MustExec("alter table nn add column c int not null default 0") + tk1.MustExec("alter table nn add column c int not null default 3 after a") close(closeCh) wg.Wait() c.Assert(tk2Err, IsNil) + tk1.MustQuery("select * from nn").Check(testkit.Rows("1 3 2")) } func (s *testDBSuite3) TestColumnModifyingDefinition(c *C) { @@ -2678,7 +2792,75 @@ func (s *testDBSuite5) TestAddIndexForGeneratedColumn(c *C) { s.tk.MustExec("admin check table gcai_table") } -func (s *testDBSuite6) TestIssue9100(c *C) { +func (s *testDBSuite5) TestModifyGeneratedColumn(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("create database if not exists test;") + tk.MustExec("use test") + modIdxColErrMsg := "[ddl:3106]'modifying an indexed column' is not supported for generated columns." + modStoredColErrMsg := "[ddl:3106]'modifying a stored column' is not supported for generated columns." + + // Modify column with single-col-index. + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1 (a int, b int as (a+1), index idx(b));") + tk.MustExec("insert into t1 set a=1;") + _, err := tk.Exec("alter table t1 modify column b int as (a+2);") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, modIdxColErrMsg) + tk.MustExec("drop index idx on t1;") + tk.MustExec("alter table t1 modify b int as (a+2);") + tk.MustQuery("select * from t1").Check(testkit.Rows("1 3")) + + // Modify column with multi-col-index. + tk.MustExec("drop table t1;") + tk.MustExec("create table t1 (a int, b int as (a+1), index idx(a, b));") + tk.MustExec("insert into t1 set a=1;") + _, err = tk.Exec("alter table t1 modify column b int as (a+2);") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, modIdxColErrMsg) + tk.MustExec("drop index idx on t1;") + tk.MustExec("alter table t1 modify b int as (a+2);") + tk.MustQuery("select * from t1").Check(testkit.Rows("1 3")) + + // Modify column with stored status to a different expression. + tk.MustExec("drop table t1;") + tk.MustExec("create table t1 (a int, b int as (a+1) stored);") + tk.MustExec("insert into t1 set a=1;") + _, err = tk.Exec("alter table t1 modify column b int as (a+2) stored;") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, modStoredColErrMsg) + + // Modify column with stored status to the same expression. + tk.MustExec("drop table t1;") + tk.MustExec("create table t1 (a int, b int as (a+1) stored);") + tk.MustExec("insert into t1 set a=1;") + tk.MustExec("alter table t1 modify column b bigint as (a+1) stored;") + tk.MustExec("alter table t1 modify column b bigint as (a + 1) stored;") + tk.MustQuery("select * from t1").Check(testkit.Rows("1 2")) + + // Modify column with index to the same expression. + tk.MustExec("drop table t1;") + tk.MustExec("create table t1 (a int, b int as (a+1), index idx(b));") + tk.MustExec("insert into t1 set a=1;") + tk.MustExec("alter table t1 modify column b bigint as (a+1);") + tk.MustExec("alter table t1 modify column b bigint as (a + 1);") + tk.MustQuery("select * from t1").Check(testkit.Rows("1 2")) + + // Modify column from non-generated to stored generated. + tk.MustExec("drop table t1;") + tk.MustExec("create table t1 (a int, b int);") + _, err = tk.Exec("alter table t1 modify column b bigint as (a+1) stored;") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, modStoredColErrMsg) + + // Modify column from stored generated to non-generated. + tk.MustExec("drop table t1;") + tk.MustExec("create table t1 (a int, b int as (a+1) stored);") + tk.MustExec("insert into t1 set a=1;") + tk.MustExec("alter table t1 modify column b int;") + tk.MustQuery("select * from t1").Check(testkit.Rows("1 2")) +} + +func (s *testDBSuite4) TestIssue9100(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test_db") tk.MustExec("create table employ (a int, b int) partition by range (b) (partition p0 values less than (1));") @@ -2721,7 +2903,12 @@ func (s *testDBSuite1) TestModifyColumnCharset(c *C) { } -func (s *testDBSuite2) TestAlterShardRowIDBits(c *C) { +func (s *testDBSuite4) TestAlterShardRowIDBits(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange"), IsNil) + }() + s.tk = testkit.NewTestKit(c, s.store) tk := s.tk @@ -2769,10 +2956,10 @@ func (s *testDBSuite2) TestDDLWithInvalidTableInfo(c *C) { _, err := s.tk.Exec(`CREATE TABLE t ( c0 int(11) , c1 int(11), - c2 decimal(16,4) GENERATED ALWAYS AS ((case when (c0 = 0) then 0 when (c0 > 0) then (c1 / c0) end)) + c2 decimal(16,4) GENERATED ALWAYS AS ((case when (c0 = 0) then 0when (c0 > 0) then (c1 / c0) end)) );`) c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[parser:1064]You have an error in your SQL syntax; check the manual that corresponds to your TiDB version for the right syntax to use line 1 column 55 near \"THEN (`c1` / `c0`) END)\" ") + c.Assert(err.Error(), Equals, "[parser:1064]You have an error in your SQL syntax; check the manual that corresponds to your TiDB version for the right syntax to use line 4 column 88 near \"then (c1 / c0) end))\n\t);\" ") tk.MustExec("create table t (a bigint, b int, c int generated always as (b+1)) partition by hash(a) partitions 4;") // Test drop partition column. @@ -2780,11 +2967,11 @@ func (s *testDBSuite2) TestDDLWithInvalidTableInfo(c *C) { c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "[expression:1054]Unknown column 'a' in 'expression'") // Test modify column with invalid expression. - _, err = tk.Exec("alter table t modify column c int GENERATED ALWAYS AS ((case when (a = 0) then 0 when (a > 0) then (b / a) end));") + _, err = tk.Exec("alter table t modify column c int GENERATED ALWAYS AS ((case when (a = 0) then 0when (a > 0) then (b / a) end));") c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[parser:1064]You have an error in your SQL syntax; check the manual that corresponds to your TiDB version for the right syntax to use line 1 column 53 near \"THEN (`b` / `a`) END)\" ") + c.Assert(err.Error(), Equals, "[parser:1064]You have an error in your SQL syntax; check the manual that corresponds to your TiDB version for the right syntax to use line 1 column 97 near \"then (b / a) end));\" ") // Test add column with invalid expression. - _, err = tk.Exec("alter table t add column d int GENERATED ALWAYS AS ((case when (a = 0) then 0 when (a > 0) then (b / a) end));") + _, err = tk.Exec("alter table t add column d int GENERATED ALWAYS AS ((case when (a = 0) then 0when (a > 0) then (b / a) end));") c.Assert(err, NotNil) - c.Assert(err.Error(), Equals, "[parser:1064]You have an error in your SQL syntax; check the manual that corresponds to your TiDB version for the right syntax to use line 1 column 53 near \"THEN (`b` / `a`) END)\" ") + c.Assert(err.Error(), Equals, "[parser:1064]You have an error in your SQL syntax; check the manual that corresponds to your TiDB version for the right syntax to use line 1 column 94 near \"then (b / a) end));\" ") } diff --git a/ddl/ddl.go b/ddl/ddl.go index e23c3168ae089..5bda00d4c9fde 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -24,6 +24,7 @@ import ( "time" "github.com/coreos/etcd/clientv3" + "github.com/google/uuid" "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -44,7 +45,6 @@ import ( "github.com/pingcap/tidb/table" tidbutil "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/logutil" - "github.com/twinj/uuid" "go.uber.org/zap" ) @@ -98,7 +98,7 @@ var ( "unsupported drop integer primary key") errUnsupportedCharset = terror.ClassDDL.New(codeUnsupportedCharset, "unsupported charset %s collate %s") - errUnsupportedShardRowIDBits = terror.ClassDDL.New(codeUnsupportedShardRowIDBits, "unsupported shard_row_id_bits for table with auto_increment column.") + errUnsupportedShardRowIDBits = terror.ClassDDL.New(codeUnsupportedShardRowIDBits, "unsupported shard_row_id_bits for table with primary key as row id.") errBlobKeyWithoutLength = terror.ClassDDL.New(codeBlobKeyWithoutLength, "index for BLOB/TEXT column must specify a key length") errIncorrectPrefixKey = terror.ClassDDL.New(codeIncorrectPrefixKey, "Incorrect prefix key; the used key part isn't a string, the used length is longer than the key part, or the storage engine doesn't support unique prefix keys") errTooLongKey = terror.ClassDDL.New(codeTooLongKey, @@ -109,12 +109,12 @@ var ( errInvalidJobVersion = terror.ClassDDL.New(codeInvalidJobVersion, "DDL job with version %d greater than current %d") errFileNotFound = terror.ClassDDL.New(codeFileNotFound, "Can't find file: './%s/%s.frm'") errErrorOnRename = terror.ClassDDL.New(codeErrorOnRename, "Error on rename of './%s/%s' to './%s/%s'") - errBadField = terror.ClassDDL.New(codeBadField, "Unknown column '%s' in '%s'") errInvalidUseOfNull = terror.ClassDDL.New(codeInvalidUseOfNull, "Invalid use of NULL value") errTooManyFields = terror.ClassDDL.New(codeTooManyFields, "Too many columns") errInvalidSplitRegionRanges = terror.ClassDDL.New(codeInvalidRanges, "Failed to split region ranges") errReorgPanic = terror.ClassDDL.New(codeReorgWorkerPanic, "reorg worker panic.") + errOnlyOnRangeListPartition = terror.ClassDDL.New(codeOnlyOnRangeListPartition, mysql.MySQLErrName[mysql.ErrOnlyOnRangeListPartition]) // errWrongKeyColumn is for table column cannot be indexed. errWrongKeyColumn = terror.ClassDDL.New(codeWrongKeyColumn, mysql.MySQLErrName[mysql.ErrWrongKeyColumn]) // errUnsupportedOnGeneratedColumn is for unsupported actions on generated columns. @@ -161,6 +161,8 @@ var ( // ErrColumnBadNull returns for a bad null value. ErrColumnBadNull = terror.ClassDDL.New(codeBadNull, "column cann't be null") + // ErrBadField forbids to refer to unknown column. + ErrBadField = terror.ClassDDL.New(codeBadField, "Unknown column '%s' in '%s'") // ErrCantRemoveAllFields returns for deleting all columns. ErrCantRemoveAllFields = terror.ClassDDL.New(codeCantRemoveAllFields, "can't delete all columns with ALTER TABLE") // ErrCantDropFieldOrKey returns for dropping a non-existent field or key. @@ -192,8 +194,6 @@ var ( // ErrNotAllowedTypeInPartition returns not allowed type error when creating table partiton with unsupport expression type. ErrNotAllowedTypeInPartition = terror.ClassDDL.New(codeErrFieldTypeNotAllowedAsPartitionField, mysql.MySQLErrName[mysql.ErrFieldTypeNotAllowedAsPartitionField]) - // ErrPartitionsMustBeDefined returns each partition must be defined. - ErrPartitionsMustBeDefined = terror.ClassDDL.New(codePartitionsMustBeDefined, "For RANGE partitions each partition must be defined") // ErrPartitionMgmtOnNonpartitioned returns it's not a partition table. ErrPartitionMgmtOnNonpartitioned = terror.ClassDDL.New(codePartitionMgmtOnNonpartitioned, "Partition management on a not partitioned table is not possible") // ErrDropPartitionNonExistent returns error in list of partition. @@ -204,14 +204,10 @@ var ( ErrRangeNotIncreasing = terror.ClassDDL.New(codeRangeNotIncreasing, "VALUES LESS THAN value must be strictly increasing for each partition") // ErrPartitionMaxvalue returns maxvalue can only be used in last partition definition. ErrPartitionMaxvalue = terror.ClassDDL.New(codePartitionMaxvalue, "MAXVALUE can only be used in last partition definition") - // ErrTooManyValues returns cannot have more than one value for this type of partitioning. - ErrTooManyValues = terror.ClassDDL.New(codeErrTooManyValues, mysql.MySQLErrName[mysql.ErrTooManyValues]) //ErrDropLastPartition returns cannot remove all partitions, use drop table instead. ErrDropLastPartition = terror.ClassDDL.New(codeDropLastPartition, mysql.MySQLErrName[mysql.ErrDropLastPartition]) //ErrTooManyPartitions returns too many partitions were defined. ErrTooManyPartitions = terror.ClassDDL.New(codeTooManyPartitions, mysql.MySQLErrName[mysql.ErrTooManyPartitions]) - //ErrNoParts returns no partition were defined. - ErrNoParts = terror.ClassDDL.New(codeNoParts, mysql.MySQLErrName[mysql.ErrNoParts]) //ErrPartitionFunctionIsNotAllowed returns this partition function is not allowed. ErrPartitionFunctionIsNotAllowed = terror.ClassDDL.New(codePartitionFunctionIsNotAllowed, mysql.MySQLErrName[mysql.ErrPartitionFunctionIsNotAllowed]) // ErrPartitionFuncNotAllowed returns partition function returns the wrong type. @@ -233,8 +229,8 @@ var ( ErrTableCantHandleFt = terror.ClassDDL.New(codeErrTableCantHandleFt, mysql.MySQLErrName[mysql.ErrTableCantHandleFt]) // ErrFieldNotFoundPart returns an error when 'partition by columns' are not found in table columns. ErrFieldNotFoundPart = terror.ClassDDL.New(codeFieldNotFoundPart, mysql.MySQLErrName[mysql.ErrFieldNotFoundPart]) - // ErrPartitionColumnList returns "Inconsistency in usage of column lists for partitioning". - ErrPartitionColumnList = terror.ClassDDL.New(codePartitionColumnList, mysql.MySQLErrName[mysql.ErrPartitionColumnList]) + // ErrWrongTypeColumnValue returns 'Partition column values of incorrect type' + ErrWrongTypeColumnValue = terror.ClassDDL.New(codeWrongTypeColumnValue, mysql.MySQLErrName[mysql.ErrWrongTypeColumnValue]) ) // DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache. @@ -266,7 +262,7 @@ type DDL interface { // RegisterEventCh registers event channel for ddl. RegisterEventCh(chan<- *util.Event) // SchemaSyncer gets the schema syncer. - SchemaSyncer() SchemaSyncer + SchemaSyncer() util.SchemaSyncer // OwnerManager gets the owner manager. OwnerManager() owner.Manager // GetID gets the ddl ID. @@ -295,7 +291,7 @@ type ddlCtx struct { uuid string store kv.Storage ownerManager owner.Manager - schemaSyncer SchemaSyncer + schemaSyncer util.SchemaSyncer ddlJobDoneCh chan struct{} ddlEventCh chan<- *util.Event lease time.Duration // lease is schema lease. @@ -359,10 +355,11 @@ func newDDL(ctx context.Context, etcdCli *clientv3.Client, store kv.Storage, if hook == nil { hook = &BaseCallback{} } - id := uuid.NewV4().String() + + id := uuid.New().String() ctx, cancelFunc := context.WithCancel(ctx) var manager owner.Manager - var syncer SchemaSyncer + var syncer util.SchemaSyncer if etcdCli == nil { // The etcdCli is nil if the store is localstore which is only used for testing. // So we use mockOwnerManager and MockSchemaSyncer. @@ -370,7 +367,7 @@ func newDDL(ctx context.Context, etcdCli *clientv3.Client, store kv.Storage, syncer = NewMockSchemaSyncer() } else { manager = owner.NewOwnerManager(etcdCli, ddlPrompt, id, DDLOwnerKey, cancelFunc) - syncer = NewSchemaSyncer(etcdCli, id) + syncer = util.NewSchemaSyncer(etcdCli, id, manager) } ddlCtx := &ddlCtx{ @@ -453,6 +450,17 @@ func (d *ddl) start(ctx context.Context, ctxPool *pools.ResourcePool) { // checks owner firstly and try to find whether a job exists and run. asyncNotify(worker.ddlJobCh) } + + go tidbutil.WithRecovery( + func() { d.schemaSyncer.StartCleanWork() }, + func(r interface{}) { + if r != nil { + logutil.Logger(ddlLogCtx).Error("[ddl] DDL syncer clean worker meet panic", + zap.String("ID", d.uuid), zap.Reflect("r", r), zap.Stack("stack trace")) + metrics.PanicCounter.WithLabelValues(metrics.LabelDDLSyncer).Inc() + } + }) + metrics.DDLCounter.WithLabelValues(fmt.Sprintf("%s", metrics.StartCleanWork)).Inc() } } @@ -464,6 +472,7 @@ func (d *ddl) close() { startTime := time.Now() close(d.quitCh) d.ownerManager.Cancel() + d.schemaSyncer.CloseCleanWork() err := d.schemaSyncer.RemoveSelfVersionPath() if err != nil { logutil.Logger(ddlLogCtx).Error("[ddl] remove self version path failed", zap.Error(err)) @@ -479,7 +488,7 @@ func (d *ddl) close() { d.delRangeMgr.clear() } - logutil.Logger(ddlLogCtx).Info("[ddl] closing DDL", zap.String("ID", d.uuid), zap.Duration("takeTime", time.Since(startTime))) + logutil.Logger(ddlLogCtx).Info("[ddl] DDL closed", zap.String("ID", d.uuid), zap.Duration("take time", time.Since(startTime))) } // GetLease implements DDL.GetLease interface. @@ -501,26 +510,32 @@ func (d *ddl) GetInfoSchemaWithInterceptor(ctx sessionctx.Context) infoschema.In return d.mu.interceptor.OnGetInfoSchema(ctx, is) } -func (d *ddl) genGlobalID() (int64, error) { - var globalID int64 +func (d *ddl) genGlobalIDs(count int) ([]int64, error) { + ret := make([]int64, count) err := kv.RunInNewTxn(d.store, true, func(txn kv.Transaction) error { var err error failpoint.Inject("mockGenGlobalIDFail", func(val failpoint.Value) { if val.(bool) { - failpoint.Return(errors.New("gofail genGlobalID error")) + failpoint.Return(errors.New("gofail genGlobalIDs error")) } }) - globalID, err = meta.NewMeta(txn).GenGlobalID() - return errors.Trace(err) + m := meta.NewMeta(txn) + for i := 0; i < count; i++ { + ret[i], err = m.GenGlobalID() + if err != nil { + return err + } + } + return nil }) - return globalID, errors.Trace(err) + return ret, err } // SchemaSyncer implements DDL.SchemaSyncer interface. -func (d *ddl) SchemaSyncer() SchemaSyncer { +func (d *ddl) SchemaSyncer() util.SchemaSyncer { return d.schemaSyncer } @@ -729,6 +744,16 @@ const ( codeNotSupportedAlterOperation = terror.ErrCode(mysql.ErrAlterOperationNotSupportedReason) codeFieldNotFoundPart = terror.ErrCode(mysql.ErrFieldNotFoundPart) codePartitionColumnList = terror.ErrCode(mysql.ErrPartitionColumnList) + codeOnlyOnRangeListPartition = terror.ErrCode(mysql.ErrOnlyOnRangeListPartition) + codePartitionRequiresValues = terror.ErrCode(mysql.ErrPartitionRequiresValues) + codePartitionWrongNoPart = terror.ErrCode(mysql.ErrPartitionWrongNoPart) + codePartitionWrongNoSubpart = terror.ErrCode(mysql.ErrPartitionWrongNoSubpart) + codePartitionWrongValues = terror.ErrCode(mysql.ErrPartitionWrongValues) + codeRowSinglePartitionField = terror.ErrCode(mysql.ErrRowSinglePartitionField) + codeSubpartition = terror.ErrCode(mysql.ErrSubpartition) + codeSystemVersioningWrongPartitions = terror.ErrCode(mysql.ErrSystemVersioningWrongPartitions) + codeWrongPartitionTypeExpectedSystemTime = terror.ErrCode(mysql.ErrWrongPartitionTypeExpectedSystemTime) + codeWrongTypeColumnValue = terror.ErrCode(mysql.ErrWrongTypeColumnValue) ) func init() { @@ -791,6 +816,16 @@ func init() { codePartitionColumnList: mysql.ErrPartitionColumnList, codeInvalidDefaultValue: mysql.ErrInvalidDefault, codeErrGeneratedColumnRefAutoInc: mysql.ErrGeneratedColumnRefAutoInc, + codeOnlyOnRangeListPartition: mysql.ErrOnlyOnRangeListPartition, + codePartitionRequiresValues: mysql.ErrPartitionRequiresValues, + codePartitionWrongNoPart: mysql.ErrPartitionWrongNoPart, + codePartitionWrongNoSubpart: mysql.ErrPartitionWrongNoSubpart, + codePartitionWrongValues: mysql.ErrPartitionWrongValues, + codeRowSinglePartitionField: mysql.ErrRowSinglePartitionField, + codeSubpartition: mysql.ErrSubpartition, + codeSystemVersioningWrongPartitions: mysql.ErrSystemVersioningWrongPartitions, + codeWrongPartitionTypeExpectedSystemTime: mysql.ErrWrongPartitionTypeExpectedSystemTime, + codeWrongTypeColumnValue: mysql.ErrWrongTypeColumnValue, } terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes } diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index f94454c93d24b..d2837f4b12f4a 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -30,18 +30,20 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" + "github.com/pingcap/parser/format" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" field_types "github.com/pingcap/parser/types" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/types/parser_driver" + driver "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/mock" @@ -60,10 +62,11 @@ func (d *ddl) CreateSchema(ctx sessionctx.Context, schema model.CIStr, charsetIn return errors.Trace(err) } - schemaID, err := d.genGlobalID() + genIDs, err := d.genGlobalIDs(1) if err != nil { return errors.Trace(err) } + schemaID := genIDs[0] dbInfo := &model.DBInfo{ Name: schema, } @@ -231,7 +234,7 @@ func setColumnFlagWithConstraint(colMap map[string]*table.Column, v *ast.Constra } func buildColumnsAndConstraints(ctx sessionctx.Context, colDefs []*ast.ColumnDef, - constraints []*ast.Constraint, tblCharset, dbCharset string) ([]*table.Column, []*ast.Constraint, error) { + constraints []*ast.Constraint, tblCharset, tblCollate, dbCharset, dbCollate string) ([]*table.Column, []*ast.Constraint, error) { colMap := map[string]*table.Column{} // outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id)); var outPriKeyConstraint *ast.Constraint @@ -243,7 +246,7 @@ func buildColumnsAndConstraints(ctx sessionctx.Context, colDefs []*ast.ColumnDef } cols := make([]*table.Column, 0, len(colDefs)) for i, colDef := range colDefs { - col, cts, err := buildColumnAndConstraint(ctx, i, colDef, outPriKeyConstraint, tblCharset, dbCharset) + col, cts, err := buildColumnAndConstraint(ctx, i, colDef, outPriKeyConstraint, tblCharset, tblCollate, dbCharset, dbCollate) if err != nil { return nil, nil, errors.Trace(err) } @@ -259,23 +262,32 @@ func buildColumnsAndConstraints(ctx sessionctx.Context, colDefs []*ast.ColumnDef return cols, constraints, nil } -// ResolveCharsetCollation will resolve the charset by the order: table charset > database charset > server default charset. -func ResolveCharsetCollation(tblCharset, dbCharset string) (string, string, error) { +// ResolveCharsetCollation will resolve the charset by the order: table charset > database charset > server default charset, +// and it will also resolve the collate by the order: table collate > database collate > server default collate. +func ResolveCharsetCollation(tblCharset, tblCollate, dbCharset, dbCollate string) (string, string, error) { if len(tblCharset) != 0 { - defCollate, err := charset.GetDefaultCollation(tblCharset) - if err != nil { - // return terror is better. - return "", "", ErrUnknownCharacterSet.GenWithStackByArgs(tblCharset) + // tblCollate is not specified by user. + if len(tblCollate) == 0 { + defCollate, err := charset.GetDefaultCollation(tblCharset) + if err != nil { + // return terror is better. + return "", "", ErrUnknownCharacterSet.GenWithStackByArgs(tblCharset) + } + return tblCharset, defCollate, nil } - return tblCharset, defCollate, nil + return tblCharset, tblCollate, nil } if len(dbCharset) != 0 { - defCollate, err := charset.GetDefaultCollation(dbCharset) - if err != nil { - return "", "", ErrUnknownCharacterSet.GenWithStackByArgs(dbCharset) + // dbCollate is not specified by user. + if len(dbCollate) == 0 { + defCollate, err := charset.GetDefaultCollation(dbCharset) + if err != nil { + return "", "", ErrUnknownCharacterSet.GenWithStackByArgs(dbCharset) + } + return dbCharset, defCollate, nil } - return dbCharset, defCollate, errors.Trace(err) + return dbCharset, dbCollate, nil } charset, collate := charset.GetDefaultCharsetAndCollate() @@ -292,15 +304,34 @@ func typesNeedCharset(tp byte) bool { return false } -func setCharsetCollationFlenDecimal(tp *types.FieldType, tblCharset string, dbCharset string) error { +func setCharsetCollationFlenDecimal(tp *types.FieldType, specifiedCollates []string, tblCharset, tblCollate, dbCharset, dbCollate string) error { tp.Charset = strings.ToLower(tp.Charset) tp.Collate = strings.ToLower(tp.Collate) if len(tp.Charset) == 0 { if typesNeedCharset(tp.Tp) { - var err error - tp.Charset, tp.Collate, err = ResolveCharsetCollation(tblCharset, dbCharset) - if err != nil { - return errors.Trace(err) + if len(specifiedCollates) == 0 { + // Both the charset and collate are not specified. + var err error + tp.Charset, tp.Collate, err = ResolveCharsetCollation(tblCharset, tblCollate, dbCharset, dbCollate) + if err != nil { + return errors.Trace(err) + } + } else { + // The charset is not specified but the collate is. + // We should derive charset from it's collate specified rather than getting from table and db. + // It is handled like mysql's logic, use derived charset to judge conflict with next collate. + for _, spc := range specifiedCollates { + derivedCollation, err := charset.GetCollationByName(spc) + if err != nil { + return errors.Trace(err) + } + if len(tp.Charset) == 0 { + tp.Charset = derivedCollation.CharsetName + } else if tp.Charset != derivedCollation.CharsetName { + return ErrCollationCharsetMismatch.GenWithStackByArgs(derivedCollation.Name, tp.Charset) + } + tp.Collate = derivedCollation.Name + } } } else { tp.Charset = charset.CharsetBin @@ -311,10 +342,25 @@ func setCharsetCollationFlenDecimal(tp *types.FieldType, tblCharset string, dbCh return errUnsupportedCharset.GenWithStackByArgs(tp.Charset, tp.Collate) } if len(tp.Collate) == 0 { - var err error - tp.Collate, err = charset.GetDefaultCollation(tp.Charset) - if err != nil { - return errors.Trace(err) + if len(specifiedCollates) == 0 { + // The charset is specified, but the collate is not. + var err error + tp.Collate, err = charset.GetDefaultCollation(tp.Charset) + if err != nil { + return errors.Trace(err) + } + } else { + // Both the charset and collate are specified. + for _, spc := range specifiedCollates { + derivedCollation, err := charset.GetCollationByName(spc) + if err != nil { + return errors.Trace(err) + } + if tp.Charset != derivedCollation.CharsetName { + return ErrCollationCharsetMismatch.GenWithStackByArgs(derivedCollation.Name, tp.Charset) + } + tp.Collate = derivedCollation.Name + } } } } @@ -337,8 +383,11 @@ func setCharsetCollationFlenDecimal(tp *types.FieldType, tblCharset string, dbCh // outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id)); func buildColumnAndConstraint(ctx sessionctx.Context, offset int, - colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint, tblCharset, dbCharset string) (*table.Column, []*ast.Constraint, error) { - if err := setCharsetCollationFlenDecimal(colDef.Tp, tblCharset, dbCharset); err != nil { + colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint, tblCharset, tblCollate, dbCharset, dbCollate string) (*table.Column, []*ast.Constraint, error) { + // specifiedCollates refers to collates in colDef.Options, should handle them together. + specifiedCollates := extractCollateFromOption(colDef) + + if err := setCharsetCollationFlenDecimal(colDef.Tp, specifiedCollates, tblCharset, tblCollate, dbCharset, dbCollate); err != nil { return nil, nil, errors.Trace(err) } col, cts, err := columnDefToCol(ctx, offset, colDef, outPriKeyConstraint) @@ -451,6 +500,11 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o }, } + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + for _, v := range colDef.Options { switch v.Tp { case ast.ColumnOptionNotNull: @@ -462,13 +516,19 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o case ast.ColumnOptionAutoIncrement: col.Flag |= mysql.AutoIncrementFlag case ast.ColumnOptionPrimaryKey: - constraint := &ast.Constraint{Tp: ast.ConstraintPrimaryKey, Keys: keys} - constraints = append(constraints, constraint) - col.Flag |= mysql.PriKeyFlag + // Check PriKeyFlag first to avoid extra duplicate constraints. + if col.Flag&mysql.PriKeyFlag == 0 { + constraint := &ast.Constraint{Tp: ast.ConstraintPrimaryKey, Keys: keys} + constraints = append(constraints, constraint) + col.Flag |= mysql.PriKeyFlag + } case ast.ColumnOptionUniqKey: - constraint := &ast.Constraint{Tp: ast.ConstraintUniqKey, Name: colDef.Name.Name.O, Keys: keys} - constraints = append(constraints, constraint) - col.Flag |= mysql.UniqueKeyFlag + // Check UniqueFlag first to avoid extra duplicate constraints. + if col.Flag&mysql.UniqueFlag == 0 { + constraint := &ast.Constraint{Tp: ast.ConstraintUniqKey, Keys: keys} + constraints = append(constraints, constraint) + col.Flag |= mysql.UniqueKeyFlag + } case ast.ColumnOptionDefaultValue: hasDefaultValue, err = setDefaultValue(ctx, col, v) if err != nil { @@ -478,7 +538,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o case ast.ColumnOptionOnUpdate: // TODO: Support other time functions. if col.Tp == mysql.TypeTimestamp || col.Tp == mysql.TypeDatetime { - if !expression.IsCurrentTimestampExpr(v.Expr) { + if !expression.IsValidCurrentTimestampExpr(v.Expr, colDef.Tp) { return nil, nil, ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) } } else { @@ -492,9 +552,12 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o return nil, nil, errors.Trace(err) } case ast.ColumnOptionGenerated: - var buf = bytes.NewBuffer([]byte{}) - v.Expr.Format(buf) - col.GeneratedExprString = buf.String() + sb.Reset() + err = v.Expr.Restore(restoreCtx) + if err != nil { + return nil, nil, errors.Trace(err) + } + col.GeneratedExprString = sb.String() col.GeneratedStored = v.Stored _, dependColNames := findDependedColumnNames(colDef) col.Dependences = dependColNames @@ -526,6 +589,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o col.Flag &= ^mysql.BinaryFlag col.Flag |= mysql.ZerofillFlag } + // If you specify ZEROFILL for a numeric column, MySQL automatically adds the UNSIGNED attribute to the column. // See https://dev.mysql.com/doc/refman/5.7/en/numeric-type-overview.html for more details. // But some types like bit and year, won't show its unsigned flag in `show create table`. @@ -551,8 +615,8 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o return col, constraints, nil } -func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption, t *types.FieldType) (interface{}, error) { - tp, fsp := t.Tp, t.Decimal +func getDefaultValue(ctx sessionctx.Context, col *table.Column, c *ast.ColumnOption) (interface{}, error) { + tp, fsp := col.FieldType.Tp, col.FieldType.Decimal if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime { switch x := c.Expr.(type) { case *ast.FuncCallExpr: @@ -564,14 +628,14 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption } } if defaultFsp != fsp { - return nil, ErrInvalidDefaultValue.GenWithStackByArgs(colName) + return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) } } } vd, err := expression.GetTimeValue(ctx, c.Expr, tp, fsp) value := vd.GetValue() if err != nil { - return nil, ErrInvalidDefaultValue.GenWithStackByArgs(colName) + return nil, ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) } // Value is nil means `default null`. @@ -612,14 +676,14 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption return strconv.FormatUint(value, 10), nil } - if tp == mysql.TypeDuration { - var err error - if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, t); err != nil { + switch tp { + case mysql.TypeSet: + return setSetDefaultValue(v, col) + case mysql.TypeDuration: + if v, err = v.ConvertTo(ctx.GetSessionVars().StmtCtx, &col.FieldType); err != nil { return "", errors.Trace(err) } - } - - if tp == mysql.TypeBit { + case mysql.TypeBit: if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 { // For BIT fields, convert int into BinaryLiteral. return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), nil @@ -629,6 +693,58 @@ func getDefaultValue(ctx sessionctx.Context, colName string, c *ast.ColumnOption return v.ToString() } +// setSetDefaultValue sets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html. +func setSetDefaultValue(v types.Datum, col *table.Column) (string, error) { + if v.Kind() == types.KindInt64 { + setCnt := len(col.Elems) + maxLimit := int64(1< maxLimit { + return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + setVal, err := types.ParseSetValue(col.Elems, uint64(val)) + if err != nil { + return "", errors.Trace(err) + } + v.SetMysqlSet(setVal) + return v.ToString() + } + + str, err := v.ToString() + if err != nil { + return "", errors.Trace(err) + } + if str == "" { + return str, nil + } + + valMap := make(map[string]struct{}, len(col.Elems)) + dVals := strings.Split(strings.ToLower(str), ",") + for _, dv := range dVals { + valMap[dv] = struct{}{} + } + var existCnt int + for dv := range valMap { + for i := range col.Elems { + e := strings.ToLower(col.Elems[i]) + if e == dv { + existCnt++ + break + } + } + } + if existCnt != len(valMap) { + return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + setVal, err := types.ParseSetName(col.Elems, str) + if err != nil { + return "", ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + v.SetMysqlSet(setVal) + + return v.ToString() +} + func removeOnUpdateNowFlag(c *table.Column) { // For timestamp Col, if it is set null or default value, // OnUpdateNowFlag should be removed. @@ -967,10 +1083,11 @@ func buildTableInfo(ctx sessionctx.Context, d *ddl, tableName model.CIStr, cols // When this function is called by MockTableInfo, we should set a particular table id. // So the `ddl` structure may be nil. if d != nil { - tbInfo.ID, err = d.genGlobalID() + genIDs, err := d.genGlobalIDs(1) if err != nil { return nil, errors.Trace(err) } + tbInfo.ID = genIDs[0] } for _, v := range cols { v.ID = allocateColumnID(tbInfo) @@ -1097,10 +1214,22 @@ func (d *ddl) CreateTableWithLike(ctx sessionctx.Context, ident, referIdent ast. } tblInfo := buildTableInfoWithLike(ident, referTbl.Meta()) - tblInfo.ID, err = d.genGlobalID() + count := 1 + if tblInfo.Partition != nil { + count += len(tblInfo.Partition.Definitions) + } + var genIDs []int64 + genIDs, err = d.genGlobalIDs(count) if err != nil { return errors.Trace(err) } + tblInfo.ID = genIDs[0] + if tblInfo.Partition != nil { + for i := 0; i < len(tblInfo.Partition.Definitions); i++ { + tblInfo.Partition.Definitions[i].ID = genIDs[i+1] + } + } + job := &model.Job{ SchemaID: schema.ID, TableID: tblInfo.ID, @@ -1135,6 +1264,12 @@ func buildTableInfoWithLike(ident ast.Ident, referTblInfo *model.TableInfo) mode tblInfo.Name = ident.Name tblInfo.AutoIncID = 0 tblInfo.ForeignKeys = nil + if referTblInfo.Partition != nil { + pi := *referTblInfo.Partition + pi.Definitions = make([]model.PartitionDefinition, len(referTblInfo.Partition.Definitions)) + copy(pi.Definitions, referTblInfo.Partition.Definitions) + tblInfo.Partition = &pi + } return tblInfo } @@ -1142,10 +1277,10 @@ func buildTableInfoWithLike(ident ast.Ident, referTblInfo *model.TableInfo) mode // The SQL string should be a create table statement. // Don't use this function to build a partitioned table. func BuildTableInfoFromAST(s *ast.CreateTableStmt) (*model.TableInfo, error) { - return buildTableInfoWithCheck(mock.NewContext(), nil, s, mysql.DefaultCharset) + return buildTableInfoWithCheck(mock.NewContext(), nil, s, mysql.DefaultCharset, "") } -func buildTableInfoWithCheck(ctx sessionctx.Context, d *ddl, s *ast.CreateTableStmt, dbCharset string) (*model.TableInfo, error) { +func buildTableInfoWithCheck(ctx sessionctx.Context, d *ddl, s *ast.CreateTableStmt, dbCharset, dbCollate string) (*model.TableInfo, error) { ident := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name} colDefs := s.Cols colObjects := make([]interface{}, 0, len(colDefs)) @@ -1172,9 +1307,9 @@ func buildTableInfoWithCheck(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS return nil, errors.Trace(err) } - tableCharset := findTableOptionCharset(s.Options) + tableCharset, tableCollate := findTableOptionCharsetAndCollate(s.Options) // The column charset haven't been resolved here. - cols, newConstraints, err := buildColumnsAndConstraints(ctx, colDefs, s.Constraints, tableCharset, dbCharset) + cols, newConstraints, err := buildColumnsAndConstraints(ctx, colDefs, s.Constraints, tableCharset, tableCollate, dbCharset, dbCollate) if err != nil { return nil, errors.Trace(err) } @@ -1198,11 +1333,7 @@ func buildTableInfoWithCheck(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS if pi != nil { switch pi.Type { case model.PartitionTypeRange: - if len(pi.Columns) == 0 { - err = checkPartitionByRange(ctx, tbInfo, pi, s, cols, newConstraints) - } else { - err = checkPartitionByRangeColumn(ctx, tbInfo, pi, s) - } + err = checkPartitionByRange(ctx, tbInfo, pi, cols, s) case model.PartitionTypeHash: err = checkPartitionByHash(ctx, pi, s, cols, tbInfo) } @@ -1220,7 +1351,7 @@ func buildTableInfoWithCheck(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS return nil, errors.Trace(err) } - if err = resolveDefaultTableCharsetAndCollation(tbInfo, dbCharset); err != nil { + if err = resolveDefaultTableCharsetAndCollation(tbInfo, dbCharset, dbCollate); err != nil { return nil, errors.Trace(err) } @@ -1250,7 +1381,7 @@ func (d *ddl) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) (err e return infoschema.ErrTableExists.GenWithStackByArgs(ident) } - tbInfo, err := buildTableInfoWithCheck(ctx, d, s, schema.Charset) + tbInfo, err := buildTableInfoWithCheck(ctx, d, s, schema.Charset, schema.Collate) if err != nil { return errors.Trace(err) } @@ -1271,23 +1402,29 @@ func (d *ddl) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) (err e err = d.doDDLJob(ctx, job) if err == nil { - var preSplitAndScatter func() // do pre-split and scatter. - if tbInfo.ShardRowIDBits > 0 && tbInfo.PreSplitRegions > 0 { - preSplitAndScatter = func() { preSplitTableRegion(d.store, tbInfo, ctx.GetSessionVars().WaitTableSplitFinish) } - } else if atomic.LoadUint32(&EnableSplitTableRegion) != 0 { + sp, ok := d.store.(kv.SplitableStore) + if ok && atomic.LoadUint32(&EnableSplitTableRegion) != 0 { + var ( + preSplit func() + scatterRegion bool + ) + val, err := variable.GetGlobalSystemVar(ctx.GetSessionVars(), variable.TiDBScatterRegion) + if err != nil { + logutil.Logger(context.Background()).Warn("[ddl] won't scatter region", zap.Error(err)) + } else { + scatterRegion = variable.TiDBOptOn(val) + } pi := tbInfo.GetPartitionInfo() if pi != nil { - preSplitAndScatter = func() { splitPartitionTableRegion(d.store, pi) } + preSplit = func() { splitPartitionTableRegion(sp, pi, scatterRegion) } } else { - preSplitAndScatter = func() { splitTableRegion(d.store, tbInfo.ID) } + preSplit = func() { splitTableRegion(sp, tbInfo, scatterRegion) } } - } - if preSplitAndScatter != nil { - if ctx.GetSessionVars().WaitTableSplitFinish { - preSplitAndScatter() + if scatterRegion { + preSplit() } else { - go preSplitAndScatter() + go preSplit() } } @@ -1356,11 +1493,22 @@ func (d *ddl) CreateView(ctx sessionctx.Context, s *ast.CreateViewStmt) (err err if err = checkTooLongTable(ident.Name); err != nil { return err } - viewInfo, cols := buildViewInfoWithTableColumns(ctx, s) + viewInfo, err := buildViewInfo(ctx, s) + if err != nil { + return err + } - colObjects := make([]interface{}, 0, len(viewInfo.Cols)) - for _, col := range viewInfo.Cols { - colObjects = append(colObjects, col) + cols := make([]*table.Column, len(s.Cols)) + colObjects := make([]interface{}, 0, len(s.Cols)) + + for i, v := range s.Cols { + cols[i] = table.ToColumn(&model.ColumnInfo{ + Name: v, + ID: int64(i), + Offset: i, + State: model.StatePublic, + }) + colObjects = append(colObjects, v) } if err = checkTooLongColumn(colObjects); err != nil { @@ -1398,40 +1546,16 @@ func (d *ddl) CreateView(ctx sessionctx.Context, s *ast.CreateViewStmt) (err err return d.callHookOnChanged(err) } -func buildViewInfoWithTableColumns(ctx sessionctx.Context, s *ast.CreateViewStmt) (*model.ViewInfo, []*table.Column) { - viewInfo := &model.ViewInfo{Definer: s.Definer, Algorithm: s.Algorithm, - Security: s.Security, SelectStmt: s.Select.Text(), CheckOption: s.CheckOption} - - var schemaCols = s.Select.(*ast.SelectStmt).Fields.Fields - viewInfo.Cols = make([]model.CIStr, len(schemaCols)) - for i, v := range schemaCols { - viewInfo.Cols[i] = v.AsName +func buildViewInfo(ctx sessionctx.Context, s *ast.CreateViewStmt) (*model.ViewInfo, error) { + // Always Use `format.RestoreNameBackQuotes` to restore `SELECT` statement despite the `ANSI_QUOTES` SQL Mode is enabled or not. + restoreFlag := format.RestoreStringSingleQuotes | format.RestoreKeyWordUppercase | format.RestoreNameBackQuotes + var sb strings.Builder + if err := s.Select.Restore(format.NewRestoreCtx(restoreFlag, &sb)); err != nil { + return nil, err } - var tableColumns = make([]*table.Column, len(schemaCols)) - if s.Cols == nil { - for i, v := range schemaCols { - tableColumns[i] = table.ToColumn(&model.ColumnInfo{ - Name: v.AsName, - ID: int64(i), - Offset: i, - State: model.StatePublic, - Version: model.CurrLatestColumnInfoVersion, - }) - } - } else { - for i, v := range s.Cols { - tableColumns[i] = table.ToColumn(&model.ColumnInfo{ - Name: v, - ID: int64(i), - Offset: i, - State: model.StatePublic, - Version: model.CurrLatestColumnInfoVersion, - }) - } - } - - return viewInfo, tableColumns + return &model.ViewInfo{Definer: s.Definer, Algorithm: s.Algorithm, + Security: s.Security, SelectStmt: sb.String(), CheckOption: s.CheckOption, Cols: nil}, nil } func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, tbInfo *model.TableInfo) error { @@ -1447,12 +1571,8 @@ func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *as return checkPartitionFuncType(ctx, s, cols, tbInfo) } -func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, newConstraints []*ast.Constraint) error { - if err := checkPartitionNameUnique(tbInfo, pi); err != nil { - return err - } - - if err := checkCreatePartitionValue(ctx, tbInfo, pi, cols); err != nil { +func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, cols []*table.Column, s *ast.CreateTableStmt) error { + if err := checkPartitionNameUnique(pi); err != nil { return err } @@ -1464,31 +1584,37 @@ func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi * return err } - if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil { - return err - } + if len(pi.Columns) == 0 { + if err := checkCreatePartitionValue(ctx, tbInfo, pi, cols); err != nil { + return err + } - return checkPartitionFuncType(ctx, s, cols, tbInfo) -} + // s maybe nil when add partition. + if s == nil { + return nil + } -func checkPartitionByRangeColumn(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt) error { - if err := checkPartitionNameUnique(tbInfo, pi); err != nil { - return err + if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil { + return err + } + return checkPartitionFuncType(ctx, s, cols, tbInfo) } + // Check for range columns partition. if err := checkRangeColumnsPartitionType(tbInfo, pi.Columns); err != nil { return err } - if err := checkRangeColumnsPartitionValue(ctx, tbInfo, pi); err != nil { - return err - } - - if err := checkNoRangePartitions(len(pi.Definitions)); err != nil { - return errors.Trace(err) + if s != nil { + for _, def := range s.Partition.Definitions { + exprs := def.Clause.(*ast.PartitionDefinitionClauseLessThan).Exprs + if err := checkRangeColumnsTypeAndValuesMatch(ctx, tbInfo, pi.Columns, exprs); err != nil { + return err + } + } } - return checkAddPartitionTooManyPartitions(uint64(len(pi.Definitions))) + return checkRangeColumnsPartitionValue(ctx, tbInfo, pi) } func checkRangeColumnsPartitionType(tbInfo *model.TableInfo, columns []model.CIStr) error { @@ -1517,15 +1643,16 @@ func checkRangeColumnsPartitionValue(ctx sessionctx.Context, tbInfo *model.Table // Range columns partition key supports multiple data types with integer、datetime、string. defs := pi.Definitions if len(defs) < 1 { - return errors.Trace(ErrPartitionsMustBeDefined) + return ast.ErrPartitionsMustBeDefined.GenWithStackByArgs("RANGE") } curr := &defs[0] if len(curr.LessThan) != len(pi.Columns) { - return errors.Trace(ErrPartitionColumnList) + return errors.Trace(ast.ErrPartitionColumnList) } + var prev *model.PartitionDefinition for i := 1; i < len(defs); i++ { - prev, curr := curr, &defs[i] + prev, curr = curr, &defs[i] succ, err := checkTwoRangeColumns(ctx, curr, prev, pi, tbInfo) if err != nil { return err @@ -1539,7 +1666,7 @@ func checkRangeColumnsPartitionValue(ctx sessionctx.Context, tbInfo *model.Table func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDefinition, pi *model.PartitionInfo, tbInfo *model.TableInfo) (bool, error) { if len(curr.LessThan) != len(pi.Columns) { - return false, errors.Trace(ErrPartitionColumnList) + return false, errors.Trace(ast.ErrPartitionColumnList) } for i := 0; i < len(pi.Columns); i++ { // Special handling for MAXVALUE. @@ -1610,8 +1737,8 @@ func (d *ddl) handleAutoIncID(tbInfo *model.TableInfo, schemaID int64) error { return nil } -func resolveDefaultTableCharsetAndCollation(tbInfo *model.TableInfo, dbCharset string) (err error) { - chr, collate, err := ResolveCharsetCollation(tbInfo.Charset, dbCharset) +func resolveDefaultTableCharsetAndCollation(tbInfo *model.TableInfo, dbCharset, dbCollate string) (err error) { + chr, collate, err := ResolveCharsetCollation(tbInfo.Charset, tbInfo.Collate, dbCharset, dbCollate) if err != nil { return errors.Trace(err) } @@ -1625,18 +1752,30 @@ func resolveDefaultTableCharsetAndCollation(tbInfo *model.TableInfo, dbCharset s return } -func findTableOptionCharset(options []*ast.TableOption) string { - var tableCharset string +func findTableOptionCharsetAndCollate(options []*ast.TableOption) (tableCharset, tableCollate string) { + var findCnt int for i := len(options) - 1; i >= 0; i-- { op := options[i] - if op.Tp == ast.TableOptionCharset { + if len(tableCharset) == 0 && op.Tp == ast.TableOptionCharset { // find the last one. tableCharset = op.StrValue - break + findCnt++ + if findCnt == 2 { + break + } + continue + } + if len(tableCollate) == 0 && op.Tp == ast.TableOptionCollate { + // find the last one. + tableCollate = op.StrValue + findCnt++ + if findCnt == 2 { + break + } + continue } } - - return tableCharset + return tableCharset, tableCollate } // handleTableOptions updates tableInfo according to table options. @@ -1654,8 +1793,7 @@ func handleTableOptions(options []*ast.TableOption, tbInfo *model.TableInfo) err case ast.TableOptionCompression: tbInfo.Compression = op.StrValue case ast.TableOptionShardRowID: - ok, _ := hasAutoIncrementColumn(tbInfo) - if ok && op.UintValue != 0 { + if op.UintValue > 0 && tbInfo.PKIsHandle { return errUnsupportedShardRowIDBits } tbInfo.ShardRowIDBits = op.UintValue @@ -1674,15 +1812,6 @@ func handleTableOptions(options []*ast.TableOption, tbInfo *model.TableInfo) err return nil } -func hasAutoIncrementColumn(tbInfo *model.TableInfo) (bool, string) { - for _, col := range tbInfo.Columns { - if mysql.HasAutoIncrementFlag(col.Flag) { - return true, col.Name.L - } - } - return false, "" -} - // isIgnorableSpec checks if the spec type is ignorable. // Some specs are parsed by ignored. This is for compatibility. func isIgnorableSpec(tp ast.AlterTableType) bool { @@ -1747,8 +1876,7 @@ func resolveAlterTableSpec(ctx sessionctx.Context, specs []*ast.AlterTableSpec) validSpecs = append(validSpecs, spec) } - if len(validSpecs) != 1 { - // TODO: Hanlde len(validSpecs) == 0. + if len(validSpecs) > 1 { // Now we only allow one schema changing at the same time. return nil, errRunMultiSchemaChanges } @@ -1835,6 +1963,9 @@ func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.A err = ErrUnsupportedModifyPrimaryKey.GenWithStackByArgs("drop") case ast.AlterTableRenameIndex: err = d.RenameIndex(ctx, ident, spec) + case ast.AlterTablePartition: + // Prevent silent succeed if user executes ALTER TABLE x PARTITION BY ... + err = errors.New("alter table partition is unsupported") case ast.AlterTableOption: for i, opt := range spec.Options { switch opt.Tp { @@ -1912,14 +2043,13 @@ func (d *ddl) ShardRowID(ctx sessionctx.Context, tableIdent ast.Ident, uVal uint if err != nil { return errors.Trace(err) } - ok, _ := hasAutoIncrementColumn(t.Meta()) - if ok && uVal != 0 { - return errUnsupportedShardRowIDBits - } if uVal == t.Meta().ShardRowIDBits { // Nothing need to do. return nil } + if uVal > 0 && t.Meta().PKIsHandle { + return errUnsupportedShardRowIDBits + } err = verifyNoOverflowShardBits(d.sessPool, t, uVal) if err != nil { return err @@ -1949,7 +2079,7 @@ func (d *ddl) getSchemaAndTableByIdent(ctx sessionctx.Context, tableIdent ast.Id return schema, t, nil } -func checkColumnConstraint(col *ast.ColumnDef, ti ast.Ident) error { +func checkUnsupportedColumnConstraint(col *ast.ColumnDef, ti ast.Ident) error { for _, constraint := range col.Options { switch constraint.Tp { case ast.ColumnOptionAutoIncrement: @@ -1967,8 +2097,8 @@ func checkColumnConstraint(col *ast.ColumnDef, ti ast.Ident) error { // AddColumn will add a new column to the table. func (d *ddl) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTableSpec) error { specNewColumn := spec.NewColumns[0] - // Check whether the added column constraints are supported. - err := checkColumnConstraint(specNewColumn, ti) + + err := checkUnsupportedColumnConstraint(specNewColumn, ti) if err != nil { return errors.Trace(err) } @@ -1992,23 +2122,33 @@ func (d *ddl) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTab } // If new column is a generated column, do validation. - // NOTE: Because now we can only append columns to table, - // we don't need check whether the column refers other - // generated columns occurring later in table. + // NOTE: we do check whether the column refers other generated + // columns occurring later in a table, but we don't handle the col offset. for _, option := range specNewColumn.Options { if option.Tp == ast.ColumnOptionGenerated { if err := checkIllegalFn4GeneratedColumn(specNewColumn.Name.Name.L, option.Expr); err != nil { return errors.Trace(err) } - referableColNames := make(map[string]struct{}, len(t.Cols())) - for _, col := range t.Cols() { - referableColNames[col.Name.L] = struct{}{} + + if option.Stored { + return errUnsupportedOnGeneratedColumn.GenWithStackByArgs("Adding generated stored column through ALTER TABLE") } + _, dependColNames := findDependedColumnNames(specNewColumn) if err = checkAutoIncrementRef(specNewColumn.Name.Name.L, dependColNames, t.Meta()); err != nil { return errors.Trace(err) } - if err = columnNamesCover(referableColNames, dependColNames); err != nil { + duplicateColNames := make(map[string]struct{}, len(dependColNames)) + for k := range dependColNames { + duplicateColNames[k] = struct{}{} + } + cols := t.Cols() + + if err = checkDependedColExist(dependColNames, cols); err != nil { + return errors.Trace(err) + } + + if err = verifyColumnGenerationSingle(duplicateColNames, cols, spec.Position); err != nil { return errors.Trace(err) } } @@ -2021,7 +2161,7 @@ func (d *ddl) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTab // Ignore table constraints now, maybe return error later. // We use length(t.Cols()) as the default offset firstly, we will change the // column's offset later. - col, _, err = buildColumnAndConstraint(ctx, len(t.Cols()), specNewColumn, nil, t.Meta().Charset, schema.Charset) + col, _, err = buildColumnAndConstraint(ctx, len(t.Cols()), specNewColumn, nil, t.Meta().Charset, t.Meta().Collate, schema.Charset, schema.Collate) if err != nil { return errors.Trace(err) } @@ -2057,30 +2197,21 @@ func (d *ddl) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec * } meta := t.Meta() - if meta.GetPartitionInfo() == nil { + pi := meta.GetPartitionInfo() + if pi == nil { return errors.Trace(ErrPartitionMgmtOnNonpartitioned) } - // We don't support add hash type partition now. - if meta.Partition.Type == model.PartitionTypeHash { - return errors.Trace(ErrUnsupportedAddPartition) - } - - partInfo, err := buildPartitionInfo(meta, d, spec) - if err != nil { - return errors.Trace(err) - } - - err = checkAddPartitionTooManyPartitions(uint64(len(meta.Partition.Definitions) + len(partInfo.Definitions))) - if err != nil { - return errors.Trace(err) - } - err = checkPartitionNameUnique(meta, partInfo) + partInfo, err := buildPartitionInfo(ctx, meta, d, spec) if err != nil { return errors.Trace(err) } - err = checkAddPartitionValue(meta, partInfo) + // partInfo contains only the new added partition, we have to combine it with the + // old partitions to check all partitions is strictly increasing. + tmp := *partInfo + tmp.Definitions = append(pi.Definitions, tmp.Definitions...) + err = checkPartitionByRange(ctx, meta, &tmp, t.Cols(), nil) if err != nil { return errors.Trace(err) } @@ -2115,20 +2246,28 @@ func (d *ddl) CoalescePartitions(ctx sessionctx.Context, ident ast.Ident, spec * return errors.Trace(ErrPartitionMgmtOnNonpartitioned) } - // Coalesce partition can only be used on hash/key partitions. - if meta.Partition.Type == model.PartitionTypeRange { - return errors.Trace(ErrCoalesceOnlyOnHashPartition) - } - + switch meta.Partition.Type { // We don't support coalesce partitions hash type partition now. - if meta.Partition.Type == model.PartitionTypeHash { + case model.PartitionTypeHash: return errors.Trace(ErrUnsupportedCoalescePartition) + + // Key type partition cannot be constructed currently, ignoring it for now. + case model.PartitionTypeKey: + + // Coalesce partition can only be used on hash/key partitions. + default: + return errors.Trace(ErrCoalesceOnlyOnHashPartition) } return errors.Trace(err) } func (d *ddl) TruncateTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + // TODO: Support truncate multiple partitions + if len(spec.PartitionNames) != 1 { + return errRunMultiSchemaChanges + } + is := d.infoHandle.Get() schema, ok := is.SchemaByName(ident.Schema) if !ok { @@ -2144,7 +2283,7 @@ func (d *ddl) TruncateTablePartition(ctx sessionctx.Context, ident ast.Ident, sp } var pid int64 - pid, err = tables.FindPartitionByName(meta, spec.Name) + pid, err = tables.FindPartitionByName(meta, spec.PartitionNames[0].L) if err != nil { return errors.Trace(err) } @@ -2166,6 +2305,11 @@ func (d *ddl) TruncateTablePartition(ctx sessionctx.Context, ident ast.Ident, sp } func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + // TODO: Support drop multiple partitions + if len(spec.PartitionNames) != 1 { + return errRunMultiSchemaChanges + } + is := d.infoHandle.Get() schema, ok := is.SchemaByName(ident.Schema) if !ok { @@ -2179,7 +2323,9 @@ func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec * if meta.GetPartitionInfo() == nil { return errors.Trace(ErrPartitionMgmtOnNonpartitioned) } - err = checkDropTablePartition(meta, spec.Name) + + partName := spec.PartitionNames[0].L + err = checkDropTablePartition(meta, partName) if err != nil { return errors.Trace(err) } @@ -2189,7 +2335,7 @@ func (d *ddl) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec * TableID: meta.ID, Type: model.ActionDropTablePartition, BinlogInfo: &model.HistoryInfo{}, - Args: []interface{}{spec.Name}, + Args: []interface{}{partName}, } err = d.doDDLJob(ctx, job) @@ -2240,8 +2386,10 @@ func modifiableCharsetAndCollation(toCharset, toCollate, origCharset, origCollat if !charset.ValidCharsetAndCollation(toCharset, toCollate) { return ErrUnknownCharacterSet.GenWithStack("Unknown character set: '%s', collation: '%s'", toCharset, toCollate) } - if toCharset == charset.CharsetUTF8MB4 && origCharset == charset.CharsetUTF8 { - // TiDB only allow utf8 to be changed to utf8mb4. + if (origCharset == charset.CharsetUTF8 && toCharset == charset.CharsetUTF8MB4) || + (origCharset == charset.CharsetUTF8 && toCharset == charset.CharsetUTF8) || + (origCharset == charset.CharsetUTF8MB4 && toCharset == charset.CharsetUTF8MB4) { + // TiDB only allow utf8 to be changed to utf8mb4, or changing the collation when the charset is utf8/utf8mb4. return nil } @@ -2324,7 +2472,7 @@ func modifiable(origin *types.FieldType, to *types.FieldType) error { func setDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (bool, error) { hasDefaultValue := false - value, err := getDefaultValue(ctx, col.Name.L, option, &col.FieldType) + value, err := getDefaultValue(ctx, col, option) if err != nil { return hasDefaultValue, errors.Trace(err) } @@ -2354,9 +2502,11 @@ func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.Col // processColumnOptions is only used in getModifiableColumnJob. func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []*ast.ColumnOption) error { - if len(options) == 0 { - return nil - } + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + var hasDefaultValue, setOnUpdateNow bool var err error for _, opt := range options { @@ -2382,7 +2532,7 @@ func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []* case ast.ColumnOptionOnUpdate: // TODO: Support other time functions. if col.Tp == mysql.TypeTimestamp || col.Tp == mysql.TypeDatetime { - if !expression.IsCurrentTimestampExpr(opt.Expr) { + if !expression.IsValidCurrentTimestampExpr(opt.Expr, &col.FieldType) { return ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) } } else { @@ -2391,18 +2541,26 @@ func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []* col.Flag |= mysql.OnUpdateNowFlag setOnUpdateNow = true case ast.ColumnOptionGenerated: - var buf = bytes.NewBuffer([]byte{}) - opt.Expr.Format(buf) - col.GeneratedExprString = buf.String() + sb.Reset() + err = opt.Expr.Restore(restoreCtx) + if err != nil { + return errors.Trace(err) + } + col.GeneratedExprString = sb.String() col.GeneratedStored = opt.Stored col.Dependences = make(map[string]struct{}) col.GeneratedExpr = opt.Expr for _, colName := range findColumnNamesInExpr(opt.Expr) { col.Dependences[colName.Name.L] = struct{}{} } + case ast.ColumnOptionCollate: + col.Collate = opt.StrValue + case ast.ColumnOptionReference: + return errors.Trace(errUnsupportedModifyColumn.GenWithStackByArgs("with references")) + case ast.ColumnOptionFulltext: + return errors.Trace(errUnsupportedModifyColumn.GenWithStackByArgs("with full text")) default: - // TODO: Support other types. - return errors.Trace(errUnsupportedModifyColumn.GenWithStackByArgs(opt.Tp)) + return errors.Trace(errUnsupportedModifyColumn.GenWithStackByArgs(fmt.Sprintf("unknown column option type: %d", opt.Tp))) } } @@ -2412,6 +2570,10 @@ func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []* // it is `not null` and not an `AUTO_INCREMENT` field or `TIMESTAMP` field. setNoDefaultValueFlag(col, hasDefaultValue) + if col.Tp == mysql.TypeBit { + col.Flag |= mysql.UnsignedFlag + } + if hasDefaultValue { return errors.Trace(checkDefaultValue(ctx, col, true)) } @@ -2480,15 +2642,20 @@ func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, or newCol.FieldType.Charset = col.FieldType.Charset newCol.FieldType.Collate = col.FieldType.Collate } - err = setCharsetCollationFlenDecimal(&newCol.FieldType, t.Meta().Charset, schema.Charset) + // specifiedCollates refers to collates in colDef.Option. When setting charset and collate here we + // should take the collate in colDef.Option into consideration rather than handling it separately + specifiedCollates := extractCollateFromOption(specNewColumn) + + err = setCharsetCollationFlenDecimal(&newCol.FieldType, specifiedCollates, t.Meta().Charset, t.Meta().Collate, schema.Charset, schema.Collate) if err != nil { return nil, errors.Trace(err) } - err = modifiable(&col.FieldType, &newCol.FieldType) - if err != nil { + + if err = processColumnOptions(ctx, newCol, specNewColumn.Options); err != nil { return nil, errors.Trace(err) } - if err = processColumnOptions(ctx, newCol, specNewColumn.Options); err != nil { + + if err = modifiable(&col.FieldType, &newCol.FieldType); err != nil { return nil, errors.Trace(err) } @@ -2504,6 +2671,10 @@ func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, or if !mysql.HasAutoIncrementFlag(col.Flag) && mysql.HasAutoIncrementFlag(newCol.Flag) { return nil, errUnsupportedModifyColumn.GenWithStackByArgs("set auto_increment") } + // Disallow modifying column from auto_increment to not auto_increment if the session variable `AllowRemoveAutoInc` is false. + if !ctx.GetSessionVars().AllowRemoveAutoInc && mysql.HasAutoIncrementFlag(col.Flag) && !mysql.HasAutoIncrementFlag(newCol.Flag) { + return nil, errUnsupportedModifyColumn.GenWithStackByArgs("to remove auto_increment without @@tidb_allow_remove_auto_inc enabled") + } // We support modifying the type definitions of 'null' to 'not null' now. var modifyColumnTp byte @@ -2516,11 +2687,15 @@ func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, or } if err = checkColumnFieldLength(newCol); err != nil { - return nil, errors.Trace(err) + return nil, err + } + + if err = checkColumnWithIndexConstraint(t.Meta(), col.ColumnInfo, newCol.ColumnInfo); err != nil { + return nil, err } // As same with MySQL, we don't support modifying the stored status for generated columns. - if err = checkModifyGeneratedColumn(t.Cols(), col, newCol); err != nil { + if err = checkModifyGeneratedColumn(t, col, newCol, specNewColumn); err != nil { return nil, errors.Trace(err) } @@ -2534,6 +2709,43 @@ func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, or return job, nil } +// checkColumnWithIndexConstraint is used to check the related index constraint of the modified column. +// Index has a max-prefix-length constraint. eg: a varchar(100), index idx(a), modifying column a to a varchar(4000) +// will cause index idx to break the max-prefix-length constraint. +func checkColumnWithIndexConstraint(tbInfo *model.TableInfo, originalCol, newCol *model.ColumnInfo) error { + var columns []*model.ColumnInfo + for _, indexInfo := range tbInfo.Indices { + containColumn := false + for _, col := range indexInfo.Columns { + if col.Name.L == originalCol.Name.L { + containColumn = true + break + } + } + if containColumn == false { + continue + } + if columns == nil { + columns = make([]*model.ColumnInfo, 0, len(tbInfo.Columns)) + columns = append(columns, tbInfo.Columns...) + // replace old column with new column. + for i, col := range columns { + if col.Name.L != originalCol.Name.L { + continue + } + columns[i] = newCol.Clone() + columns[i].Name = originalCol.Name + break + } + } + err := checkIndexPrefixLength(columns, indexInfo.Columns) + if err != nil { + return err + } + } + return nil +} + // ChangeColumn renames an existing column and modifies the column's definition, // currently we only support limited kind of changes // that do not need to change or check data on the table. @@ -2573,20 +2785,6 @@ func (d *ddl) ModifyColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Al return ErrWrongTableName.GenWithStackByArgs(specNewColumn.Name.Table.O) } - // If the modified column is generated, check whether it refers to any auto-increment columns. - for _, option := range specNewColumn.Options { - if option.Tp == ast.ColumnOptionGenerated { - _, t, err := d.getSchemaAndTableByIdent(ctx, ident) - if err != nil { - return errors.Trace(err) - } - _, dependColNames := findDependedColumnNames(specNewColumn) - if err := checkAutoIncrementRef(specNewColumn.Name.Name.L, dependColNames, t.Meta()); err != nil { - return errors.Trace(err) - } - } - } - originalColName := specNewColumn.Name.Name job, err := d.getModifiableColumnJob(ctx, ident, originalColName, spec) if err != nil { @@ -2614,7 +2812,7 @@ func (d *ddl) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Alt // Check whether alter column has existed. col := table.FindCol(t.Cols(), colName.L) if col == nil { - return errBadField.GenWithStackByArgs(colName, ident.Name) + return ErrBadField.GenWithStackByArgs(colName, ident.Name) } // Clean the NoDefaultValueFlag value. @@ -2753,7 +2951,7 @@ func checkAlterTableCharset(tblInfo *model.TableInfo, dbInfo *model.DBInfo, toCh if len(origCharset) == 0 { // The table charset may be "", if the table is create in old TiDB version, such as v2.0.8. // This DDL will update the table charset to default charset. - origCharset, origCollate, err = ResolveCharsetCollation("", dbInfo.Charset) + origCharset, origCollate, err = ResolveCharsetCollation("", "", dbInfo.Charset, dbInfo.Collate) if err != nil { return doNothing, err } @@ -2864,10 +3062,11 @@ func (d *ddl) TruncateTable(ctx sessionctx.Context, ti ast.Ident) error { if err != nil { return errors.Trace(err) } - newTableID, err := d.genGlobalID() + genIDs, err := d.genGlobalIDs(1) if err != nil { return errors.Trace(err) } + newTableID := genIDs[0] job := &model.Job{ SchemaID: schema.ID, TableID: tb.Meta().ID, @@ -2962,7 +3161,7 @@ func (d *ddl) CreateIndex(ctx sessionctx.Context, ti ast.Ident, unique bool, ind } tblInfo := t.Meta() - // Check before put the job is put to the queue. + // Check before the job is put to the queue. // This check is redudant, but useful. If DDL check fail before the job is put // to job queue, the fail path logic is super fast. // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. @@ -2972,7 +3171,7 @@ func (d *ddl) CreateIndex(ctx sessionctx.Context, ti ast.Ident, unique bool, ind return errors.Trace(err) } if unique && tblInfo.GetPartitionInfo() != nil { - if err := checkPartitionKeysConstraint(ctx, tblInfo.GetPartitionInfo().Expr, idxColNames, tblInfo); err != nil { + if err := checkPartitionKeysConstraint(tblInfo.GetPartitionInfo(), idxColNames, tblInfo); err != nil { return err } } @@ -3094,10 +3293,17 @@ func (d *ddl) DropIndex(ctx sessionctx.Context, ti ast.Ident, indexName model.CI return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) } - if indexInfo := t.Meta().FindIndexByName(indexName.L); indexInfo == nil { + indexInfo := t.Meta().FindIndexByName(indexName.L) + if indexInfo == nil { return ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", indexName) } + // Check for drop index on auto_increment column. + err = checkDropIndexOnAutoIncrementColumn(t.Meta(), indexInfo) + if err != nil { + return errors.Trace(err) + } + job := &model.Job{ SchemaID: schema.ID, TableID: t.Meta().ID, @@ -3146,45 +3352,48 @@ func validateCommentLength(vars *variable.SessionVars, comment string, maxLen in return comment, nil } -func buildPartitionInfo(meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) { - if meta.Partition.Type == model.PartitionTypeRange && len(spec.PartDefinitions) == 0 { - return nil, errors.Trace(ErrPartitionsMustBeDefined) +func buildPartitionInfo(ctx sessionctx.Context, meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) { + if meta.Partition.Type == model.PartitionTypeRange { + if len(spec.PartDefinitions) == 0 { + return nil, ast.ErrPartitionsMustBeDefined.GenWithStackByArgs(meta.Partition.Type) + } + } else { + // we don't support ADD PARTITION for all other partition types yet. + return nil, errors.Trace(ErrUnsupportedAddPartition) } + part := &model.PartitionInfo{ Type: meta.Partition.Type, Expr: meta.Partition.Expr, Columns: meta.Partition.Columns, Enable: meta.Partition.Enable, } - buf := new(bytes.Buffer) - for _, def := range spec.PartDefinitions { - for _, expr := range def.LessThan { - tp := expr.GetType().Tp - if len(part.Columns) == 0 { - // Partition by range. - if !(tp == mysql.TypeLong || tp == mysql.TypeLonglong) { - expr.Format(buf) - if strings.EqualFold(buf.String(), "MAXVALUE") { - continue - } - buf.Reset() - return nil, infoschema.ErrColumnNotExists.GenWithStackByArgs(buf.String(), "partition function") - } - } - // Partition by range columns if len(part.Columns) != 0. + + genIDs, err := d.genGlobalIDs(len(spec.PartDefinitions)) + if err != nil { + return nil, err + } + for ith, def := range spec.PartDefinitions { + if err := def.Clause.Validate(part.Type, len(part.Columns)); err != nil { + return nil, err } - pid, err1 := d.genGlobalID() - if err1 != nil { - return nil, errors.Trace(err1) + // For RANGE partition only VALUES LESS THAN should be possible. + clause := def.Clause.(*ast.PartitionDefinitionClauseLessThan) + if len(part.Columns) > 0 { + if err := checkRangeColumnsTypeAndValuesMatch(ctx, meta, part.Columns, clause.Exprs); err != nil { + return nil, err + } } + + comment, _ := def.Comment() piDef := model.PartitionDefinition{ Name: def.Name, - ID: pid, - Comment: def.Comment, + ID: genIDs[ith], + Comment: comment, } buf := new(bytes.Buffer) - for _, expr := range def.LessThan { + for _, expr := range clause.Exprs { expr.Format(buf) piDef.LessThan = append(piDef.LessThan, buf.String()) buf.Reset() @@ -3193,3 +3402,54 @@ func buildPartitionInfo(meta *model.TableInfo, d *ddl, spec *ast.AlterTableSpec) } return part, nil } + +func checkRangeColumnsTypeAndValuesMatch(ctx sessionctx.Context, meta *model.TableInfo, colNames []model.CIStr, exprs []ast.ExprNode) error { + // Validate() has already checked len(colNames) = len(exprs) + // create table ... partition by range columns (cols) + // partition p0 values less than (expr) + // check the type of cols[i] and expr is consistent. + for i, colExpr := range exprs { + if _, ok := colExpr.(*ast.MaxValueExpr); ok { + continue + } + + colName := colNames[i] + colInfo := getColumnInfoByName(meta, colName.L) + if colInfo == nil { + return errors.Trace(ErrFieldNotFoundPart) + } + colType := &colInfo.FieldType + + val, err := expression.EvalAstExpr(ctx, colExpr) + if err != nil { + return err + } + + // Check val.ConvertTo(colType) doesn't work, so we need this case by case check. + switch colType.Tp { + case mysql.TypeDate, mysql.TypeDatetime: + switch val.Kind() { + case types.KindString, types.KindBytes: + default: + return ErrWrongTypeColumnValue.GenWithStackByArgs() + } + } + } + return nil +} + +// extractCollateFromOption take collates(may multiple) in option into consideration +// when handle charset and collate of a column, rather than handling it separately. +func extractCollateFromOption(def *ast.ColumnDef) []string { + specifiedCollates := make([]string, 0, 0) + for i := 0; i < len(def.Options); i++ { + op := def.Options[i] + if op.Tp == ast.ColumnOptionCollate { + specifiedCollates = append(specifiedCollates, op.StrValue) + def.Options = append(def.Options[:i], def.Options[i+1:]...) + // maintain the correct index + i-- + } + } + return specifiedCollates +} diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index ff543a280cfdb..1b188e95cec8e 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -21,6 +21,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/model" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/ddl/util" @@ -30,6 +31,7 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/binloginfo" "github.com/pingcap/tidb/sessionctx/variable" + tidbutil "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/admin" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" @@ -40,6 +42,8 @@ var ( RunWorker = true // ddlWorkerID is used for generating the next DDL worker ID. ddlWorkerID = int32(0) + // WaitTimeWhenErrorOccured is waiting interval when processing DDL jobs encounter errors. + WaitTimeWhenErrorOccured = 1 * time.Second ) type workerType byte @@ -103,9 +107,10 @@ func (w *worker) String() string { } func (w *worker) close() { + startTime := time.Now() close(w.quitCh) w.wg.Wait() - logutil.Logger(w.logCtx).Info("[ddl] close DDL worker") + logutil.Logger(w.logCtx).Info("[ddl] DDL worker closed", zap.Duration("take time", time.Since(startTime))) } // start is used for async online schema changing, it will try to become the owner firstly, @@ -396,7 +401,14 @@ func (w *worker) handleDDLJobQueue(d *ddlCtx) error { // If running job meets error, we will save this error in job Error // and retry later if the job is not cancelled. - schemaVer, runJobErr = w.runDDLJob(d, t, job) + tidbutil.WithRecovery(func() { + schemaVer, runJobErr = w.runDDLJob(d, t, job) + }, func(r interface{}) { + if r != nil { + // If run ddl job panic, just cancel the ddl jobs. + job.State = model.JobStateCancelling + } + }) if job.IsCancelled() { txn.Reset() err = w.finishDDLJob(t, job) @@ -468,6 +480,9 @@ func chooseLeaseTime(t, max time.Duration) time.Duration { // runDDLJob runs a DDL job. It returns the current schema version in this transaction and the error. func (w *worker) runDDLJob(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + // Mock for run ddl job panic. + failpoint.Inject("mockPanicInRunDDLJob", func(_ failpoint.Value) {}) + logutil.Logger(w.logCtx).Info("[ddl] run DDL job", zap.String("job", job.String())) timeStart := time.Now() defer func() { @@ -632,12 +647,17 @@ func (w *worker) waitSchemaChanged(ctx context.Context, d *ddlCtx, waitTime time if terror.ErrorEqual(err, context.DeadlineExceeded) { return } + d.schemaSyncer.NotifyCleanExpiredPaths() + // Wait until timeout. select { case <-ctx.Done(): return } } - logutil.Logger(w.logCtx).Info("[ddl] wait latest schema version changed", zap.Int64("ver", latestSchemaVersion), zap.Duration("takeTime", time.Since(timeStart)), zap.String("job", job.String())) + logutil.Logger(w.logCtx).Info("[ddl] wait latest schema version changed", + zap.Int64("ver", latestSchemaVersion), + zap.Duration("take time", time.Since(timeStart)), + zap.String("job", job.String())) } // waitSchemaSynced handles the following situation: diff --git a/ddl/ddl_worker_test.go b/ddl/ddl_worker_test.go index 16663c6199220..9c7e61842b16b 100644 --- a/ddl/ddl_worker_test.go +++ b/ddl/ddl_worker_test.go @@ -564,7 +564,7 @@ func (s *testDDLSuite) TestCancelJob(c *C) { Tp: &types.FieldType{Tp: mysql.TypeLonglong}, Options: []*ast.ColumnOption{}, } - col, _, err := buildColumnAndConstraint(ctx, 2, newColumnDef, nil, mysql.DefaultCharset, mysql.DefaultCharset) + col, _, err := buildColumnAndConstraint(ctx, 2, newColumnDef, nil, mysql.DefaultCharset, "", mysql.DefaultCharset, "") c.Assert(err, IsNil) addColumnArgs := []interface{}{col, &ast.ColumnPosition{Tp: ast.ColumnPositionNone}, 0} diff --git a/ddl/failtest/fail_db_test.go b/ddl/failtest/fail_db_test.go index 23ef4d5d9d676..8488cda6cddf7 100644 --- a/ddl/failtest/fail_db_test.go +++ b/ddl/failtest/fail_db_test.go @@ -110,7 +110,7 @@ func (s *testFailDBSuite) TestHalfwayCancelOperations(c *C) { // Make sure that the table's data has not been deleted. rs, err := s.se.Execute(context.Background(), "select count(*) from t") c.Assert(err, IsNil) - req := rs[0].NewRecordBatch() + req := rs[0].NewChunk() err = rs[0].Next(context.Background(), req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -146,7 +146,7 @@ func (s *testFailDBSuite) TestHalfwayCancelOperations(c *C) { // Make sure that the table's data has not been deleted. rs, err = s.se.Execute(context.Background(), "select count(*) from tx") c.Assert(err, IsNil) - req = rs[0].NewRecordBatch() + req = rs[0].NewChunk() err = rs[0].Next(context.Background(), req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -386,3 +386,17 @@ LOOP: tk.MustExec("admin check table test_add_index") tk.MustExec("drop table test_add_index") } + +// TestRunDDLJobPanic tests recover panic when run ddl job panic. +func (s *testFailDBSuite) TestRunDDLJobPanic(c *C) { + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/ddl/mockPanicInRunDDLJob"), IsNil) + }() + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + c.Assert(failpoint.Enable("github.com/pingcap/tidb/ddl/mockPanicInRunDDLJob", `1*panic("panic test")`), IsNil) + _, err := tk.Exec("create table t(c1 int, c2 int)") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[ddl:12]cancelled DDL job") +} diff --git a/ddl/generated_column.go b/ddl/generated_column.go index efee4f1744af4..dab4e0b3f50a1 100644 --- a/ddl/generated_column.go +++ b/ddl/generated_column.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/table" ) @@ -36,12 +37,12 @@ func verifyColumnGeneration(colName2Generation map[string]columnGenerationInDDL, if attr, ok := colName2Generation[depCol]; ok { if attr.generated && attribute.position <= attr.position { // A generated column definition can refer to other - // generated columns occurring earilier in the table. + // generated columns occurring earlier in the table. err := errGeneratedColumnNonPrior.GenWithStackByArgs() return errors.Trace(err) } } else { - err := errBadField.GenWithStackByArgs(depCol, "generated column function") + err := ErrBadField.GenWithStackByArgs(depCol, "generated column function") return errors.Trace(err) } } @@ -49,18 +50,67 @@ func verifyColumnGeneration(colName2Generation map[string]columnGenerationInDDL, return nil } -// columnNamesCover checks whether dependColNames is covered by normalColNames or not. -// it's only for alter table add column because before alter, we can make sure that all -// columns in table are verified already. -func columnNamesCover(normalColNames map[string]struct{}, dependColNames map[string]struct{}) error { - for name := range dependColNames { - if _, ok := normalColNames[name]; !ok { - return errBadField.GenWithStackByArgs(name, "generated column function") +// verifyColumnGenerationSingle is for ADD GENERATED COLUMN, we just need verify one column itself. +func verifyColumnGenerationSingle(dependColNames map[string]struct{}, cols []*table.Column, position *ast.ColumnPosition) error { + // Since the added column does not exist yet, we should derive it's offset from ColumnPosition. + pos, err := findPositionRelativeColumn(cols, position) + if err != nil { + return errors.Trace(err) + } + // should check unknown column first, then the prior ones. + for _, col := range cols { + if _, ok := dependColNames[col.Name.L]; ok { + if col.IsGenerated() && col.Offset >= pos { + // Generated column can refer only to generated columns defined prior to it. + return errGeneratedColumnNonPrior.GenWithStackByArgs() + } + } + } + return nil +} + +// checkDependedColExist ensure all depended columns exist. +// NOTE: this will MODIFY parameter `dependCols`. +func checkDependedColExist(dependCols map[string]struct{}, cols []*table.Column) error { + for _, col := range cols { + delete(dependCols, col.Name.L) + } + if len(dependCols) != 0 { + for arbitraryCol := range dependCols { + return ErrBadField.GenWithStackByArgs(arbitraryCol, "generated column function") } } return nil } +// findPositionRelativeColumn returns a pos relative to added generated column position. +func findPositionRelativeColumn(cols []*table.Column, pos *ast.ColumnPosition) (int, error) { + position := len(cols) + // Get the column position, default is cols's length means appending. + // For "alter table ... add column(...)", the position will be nil. + // For "alter table ... add column ... ", the position will be default one. + if pos == nil { + return position, nil + } + if pos.Tp == ast.ColumnPositionFirst { + position = 0 + } else if pos.Tp == ast.ColumnPositionAfter { + var col *table.Column + for _, c := range cols { + if c.Name.L == pos.RelativeColumn.Name.L { + col = c + break + } + } + if col == nil { + return -1, ErrBadField.GenWithStackByArgs(pos.RelativeColumn, "generated column function") + } + // Inserted position is after the mentioned column. + position = col.Offset + 1 + } + return position, nil +} + // findDependedColumnNames returns a set of string, which indicates // the names of the columns that are depended by colDef. func findDependedColumnNames(colDef *ast.ColumnDef) (generated bool, colsMap map[string]struct{}) { @@ -106,19 +156,18 @@ func (c *generatedColumnChecker) Leave(inNode ast.Node) (node ast.Node, ok bool) // 1. the modification can't change stored status; // 2. if the new is generated, check its refer rules. // 3. check if the modified expr contains non-deterministic functions -func checkModifyGeneratedColumn(originCols []*table.Column, oldCol, newCol *table.Column) error { +// 4. check whether new column refers to any auto-increment columns. +// 5. check if the new column is indexed or stored +func checkModifyGeneratedColumn(tbl table.Table, oldCol, newCol *table.Column, newColDef *ast.ColumnDef) error { // rule 1. - var stored = [2]bool{false, false} - var cols = [2]*table.Column{oldCol, newCol} - for i, col := range cols { - if !col.IsGenerated() || col.GeneratedStored { - stored[i] = true - } - } - if stored[0] != stored[1] { + oldColIsStored := !oldCol.IsGenerated() || oldCol.GeneratedStored + newColIsStored := !newCol.IsGenerated() || newCol.GeneratedStored + if oldColIsStored != newColIsStored { return errUnsupportedOnGeneratedColumn.GenWithStackByArgs("Changing the STORED status") } + // rule 2. + originCols := tbl.Cols() var colName2Generation = make(map[string]columnGenerationInDDL, len(originCols)) for i, column := range originCols { // We can compare the pointers simply. @@ -155,11 +204,21 @@ func checkModifyGeneratedColumn(originCols []*table.Column, oldCol, newCol *tabl } } - // rule 3 if newCol.IsGenerated() { + // rule 3. if err := checkIllegalFn4GeneratedColumn(newCol.Name.L, newCol.GeneratedExpr); err != nil { return errors.Trace(err) } + + // rule 4. + if err := checkGeneratedWithAutoInc(tbl.Meta(), newColDef); err != nil { + return errors.Trace(err) + } + + // rule 5. + if err := checkIndexOrStored(tbl, oldCol, newCol); err != nil { + return errors.Trace(err) + } } return nil } @@ -195,10 +254,38 @@ func checkIllegalFn4GeneratedColumn(colName string, expr ast.ExprNode) error { return nil } +// Check whether newColumnDef refers to any auto-increment columns. +func checkGeneratedWithAutoInc(tableInfo *model.TableInfo, newColumnDef *ast.ColumnDef) error { + _, dependColNames := findDependedColumnNames(newColumnDef) + if err := checkAutoIncrementRef(newColumnDef.Name.Name.L, dependColNames, tableInfo); err != nil { + return errors.Trace(err) + } + return nil +} + +func checkIndexOrStored(tbl table.Table, oldCol, newCol *table.Column) error { + if oldCol.GeneratedExprString == newCol.GeneratedExprString { + return nil + } + + if newCol.GeneratedStored { + return errUnsupportedOnGeneratedColumn.GenWithStackByArgs("modifying a stored column") + } + + for _, idx := range tbl.Indices() { + for _, col := range idx.Meta().Columns { + if col.Name.L == newCol.Name.L { + return errUnsupportedOnGeneratedColumn.GenWithStackByArgs("modifying an indexed column") + } + } + } + return nil +} + // checkAutoIncrementRef checks if an generated column depends on an auto-increment column and raises an error if so. // See https://dev.mysql.com/doc/refman/5.7/en/create-table-generated-columns.html for details. func checkAutoIncrementRef(name string, dependencies map[string]struct{}, tbInfo *model.TableInfo) error { - exists, autoIncrementColumn := hasAutoIncrementColumn(tbInfo) + exists, autoIncrementColumn := infoschema.HasAutoIncrementColumn(tbInfo) if exists { if _, found := dependencies[autoIncrementColumn]; found { return ErrGeneratedColumnRefAutoInc.GenWithStackByArgs(name) diff --git a/ddl/index.go b/ddl/index.go index 1aaab3e0ce5af..3921499ade8ec 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" + "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -54,73 +55,21 @@ func buildIndexColumns(columns []*model.ColumnInfo, idxColNames []*ast.IndexColN // The sum of length of all index columns. sumLength := 0 - for _, ic := range idxColNames { - col := model.FindColumnInfo(columns, ic.Column.Name.O) + col := model.FindColumnInfo(columns, ic.Column.Name.L) if col == nil { return nil, errKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ic.Column.Name) } - if col.Flen == 0 { - return nil, errors.Trace(errWrongKeyColumn.GenWithStackByArgs(ic.Column.Name)) - } - - // JSON column cannot index. - if col.FieldType.Tp == mysql.TypeJSON { - return nil, errors.Trace(errJSONUsedAsKey.GenWithStackByArgs(col.Name.O)) - } - - // Length must be specified for BLOB and TEXT column indexes. - if types.IsTypeBlob(col.FieldType.Tp) && ic.Length == types.UnspecifiedLength { - return nil, errors.Trace(errBlobKeyWithoutLength) - } - - // Length can only be specified for specifiable types. - if ic.Length != types.UnspecifiedLength && !types.IsTypePrefixable(col.FieldType.Tp) { - return nil, errors.Trace(errIncorrectPrefixKey) - } - - // Key length must be shorter or equal to the column length. - if ic.Length != types.UnspecifiedLength && - types.IsTypeChar(col.FieldType.Tp) && col.Flen < ic.Length { - return nil, errors.Trace(errIncorrectPrefixKey) - } - - // Specified length must be shorter than the max length for prefix. - if ic.Length > maxPrefixLength { - return nil, errors.Trace(errTooLongKey) + if err := checkIndexColumn(col, ic); err != nil { + return nil, err } - // Take care of the sum of length of all index columns. - if ic.Length != types.UnspecifiedLength { - sumLength += ic.Length - } else { - // Specified data types. - if col.Flen != types.UnspecifiedLength { - // Special case for the bit type. - if col.FieldType.Tp == mysql.TypeBit { - sumLength += (col.Flen + 7) >> 3 - } else { - sumLength += col.Flen - } - } else { - if length, ok := mysql.DefaultLengthOfMysqlTypes[col.FieldType.Tp]; ok { - sumLength += length - } else { - return nil, errUnknownTypeLength.GenWithStackByArgs(col.FieldType.Tp) - } - - // Special case for time fraction. - if types.IsTypeFractionable(col.FieldType.Tp) && - col.FieldType.Decimal != types.UnspecifiedLength { - if length, ok := mysql.DefaultLengthOfTimeFraction[col.FieldType.Decimal]; ok { - sumLength += length - } else { - return nil, errUnknownFractionLength.GenWithStackByArgs(col.FieldType.Tp, col.FieldType.Decimal) - } - } - } + indexColumnLength, err := getIndexColumnLength(col, ic.Length) + if err != nil { + return nil, err } + sumLength += indexColumnLength // The sum of all lengths must be shorter than the max length for prefix. if sumLength > maxPrefixLength { @@ -137,6 +86,93 @@ func buildIndexColumns(columns []*model.ColumnInfo, idxColNames []*ast.IndexColN return idxColumns, nil } +func checkIndexPrefixLength(columns []*model.ColumnInfo, idxColumns []*model.IndexColumn) error { + // The sum of length of all index columns. + sumLength := 0 + for _, ic := range idxColumns { + col := model.FindColumnInfo(columns, ic.Name.L) + if col == nil { + return errKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ic.Name) + } + + indexColumnLength, err := getIndexColumnLength(col, ic.Length) + if err != nil { + return err + } + sumLength += indexColumnLength + // The sum of all lengths must be shorter than the max length for prefix. + if sumLength > maxPrefixLength { + return errors.Trace(errTooLongKey) + } + } + return nil +} + +func checkIndexColumn(col *model.ColumnInfo, ic *ast.IndexColName) error { + if col.Flen == 0 && (types.IsTypeChar(col.FieldType.Tp) || types.IsTypeVarchar(col.FieldType.Tp)) { + return errors.Trace(errWrongKeyColumn.GenWithStackByArgs(ic.Column.Name)) + } + + // JSON column cannot index. + if col.FieldType.Tp == mysql.TypeJSON { + return errors.Trace(errJSONUsedAsKey.GenWithStackByArgs(col.Name.O)) + } + + // Length must be specified for BLOB and TEXT column indexes. + if types.IsTypeBlob(col.FieldType.Tp) && ic.Length == types.UnspecifiedLength { + return errors.Trace(errBlobKeyWithoutLength) + } + + // Length can only be specified for specifiable types. + if ic.Length != types.UnspecifiedLength && !types.IsTypePrefixable(col.FieldType.Tp) { + return errors.Trace(errIncorrectPrefixKey) + } + + // Key length must be shorter or equal to the column length. + if ic.Length != types.UnspecifiedLength && + types.IsTypeChar(col.FieldType.Tp) && col.Flen < ic.Length { + return errors.Trace(errIncorrectPrefixKey) + } + + // Specified length must be shorter than the max length for prefix. + if ic.Length > maxPrefixLength { + return errors.Trace(errTooLongKey) + } + return nil +} + +func getIndexColumnLength(col *model.ColumnInfo, colLen int) (int, error) { + // Take care of the sum of length of all index columns. + if colLen != types.UnspecifiedLength { + return colLen, nil + } + // Specified data types. + if col.Flen != types.UnspecifiedLength { + // Special case for the bit type. + if col.FieldType.Tp == mysql.TypeBit { + return (col.Flen + 7) >> 3, nil + } + return col.Flen, nil + + } + + length, ok := mysql.DefaultLengthOfMysqlTypes[col.FieldType.Tp] + if !ok { + return length, errUnknownTypeLength.GenWithStackByArgs(col.FieldType.Tp) + } + + // Special case for time fraction. + if types.IsTypeFractionable(col.FieldType.Tp) && + col.FieldType.Decimal != types.UnspecifiedLength { + decimalLength, ok := mysql.DefaultLengthOfTimeFraction[col.FieldType.Decimal] + if !ok { + return length, errUnknownFractionLength.GenWithStackByArgs(col.FieldType.Tp, col.FieldType.Decimal) + } + length += decimalLength + } + return length, nil +} + func buildIndexInfo(tblInfo *model.TableInfo, indexName model.CIStr, idxColNames []*ast.IndexColName, state model.SchemaState) (*model.IndexInfo, error) { idxColumns, err := buildIndexColumns(tblInfo.Columns, idxColNames) if err != nil { @@ -430,9 +466,40 @@ func checkDropIndex(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.Inde job.State = model.JobStateCancelled return nil, nil, ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", indexName) } + + // Double check for drop index on auto_increment column. + err = checkDropIndexOnAutoIncrementColumn(tblInfo, indexInfo) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, autoid.ErrWrongAutoKey + } + return tblInfo, indexInfo, nil } +func checkDropIndexOnAutoIncrementColumn(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) error { + cols := tblInfo.Columns + for _, idxCol := range indexInfo.Columns { + if !mysql.HasAutoIncrementFlag(cols[idxCol.Offset].Flag) { + continue + } + // check the count of index on auto_increment column. + count := 0 + for _, idx := range tblInfo.Indices { + for _, c := range idx.Columns { + if c.Name.L == idxCol.Name.L { + count++ + break + } + } + } + if count < 2 { + return autoid.ErrWrongAutoKey + } + } + return nil +} + func checkRenameIndex(t *meta.Meta, job *model.Job) (*model.TableInfo, model.CIStr, model.CIStr, error) { var from, to model.CIStr schemaID := job.SchemaID @@ -802,7 +869,7 @@ func (w *addIndexWorker) backfillIndexInTxn(handleRange reorgIndexTask) (taskCtx // Lock the row key to notify us that someone delete or update the row, // then we should not backfill the index of it, otherwise the adding index is redundant. - err := txn.LockKeys(context.Background(), 0, idxRecord.key) + err := txn.LockKeys(context.Background(), nil, 0, idxRecord.key) if err != nil { return errors.Trace(err) } @@ -827,6 +894,8 @@ func (w *addIndexWorker) backfillIndexInTxn(handleRange reorgIndexTask) (taskCtx return } +var addIndexSpeedCounter = metrics.AddIndexTotalCounter.WithLabelValues("speed") + // handleBackfillTask backfills range [task.startHandle, task.endHandle) handle's index to table. func (w *addIndexWorker) handleBackfillTask(d *ddlCtx, task *reorgIndexTask) *addIndexResult { handleRange := *task @@ -852,6 +921,7 @@ func (w *addIndexWorker) handleBackfillTask(d *ddlCtx, task *reorgIndexTask) *ad return result } + addIndexSpeedCounter.Add(float64(taskCtx.addedCount)) mergeAddIndexCtxToResult(&taskCtx, result) w.ddlWorker.reorgCtx.increaseRowCount(int64(taskCtx.addedCount)) @@ -908,27 +978,23 @@ func (w *addIndexWorker) run(d *ddlCtx) { func makeupDecodeColMap(sessCtx sessionctx.Context, t table.Table, indexInfo *model.IndexInfo) (map[int64]decoder.Column, error) { cols := t.Cols() - decodeColMap := make(map[int64]decoder.Column, len(indexInfo.Columns)) - for _, v := range indexInfo.Columns { - col := cols[v.Offset] - tpExpr := decoder.Column{ - Col: col, - } - if col.IsGenerated() && !col.GeneratedStored { - for _, c := range cols { - if _, ok := col.Dependences[c.Name.L]; ok { - decodeColMap[c.ID] = decoder.Column{ - Col: c, - } - } - } - e, err := expression.ParseSimpleExprCastWithTableInfo(sessCtx, col.GeneratedExprString, t.Meta(), &col.FieldType) - if err != nil { - return nil, errors.Trace(err) - } - tpExpr.GenExpr = e - } - decodeColMap[col.ID] = tpExpr + indexedCols := make([]*table.Column, len(indexInfo.Columns)) + for i, v := range indexInfo.Columns { + indexedCols[i] = cols[v.Offset] + } + + var containsVirtualCol bool + decodeColMap, err := decoder.BuildFullDecodeColMap(indexedCols, t, func(genCol *table.Column) (expression.Expression, error) { + containsVirtualCol = true + return expression.ParseSimpleExprCastWithTableInfo(sessCtx, genCol.GeneratedExprString, t.Meta(), &genCol.FieldType) + }) + if err != nil { + return nil, err + } + + if containsVirtualCol { + decoder.SubstituteGenColsInDecodeColMap(decodeColMap) + decoder.RemoveUnusedVirtualCols(decodeColMap, indexedCols) } return decodeColMap, nil } diff --git a/ddl/mock.go b/ddl/mock.go index 1911f8aeec704..fbac094fbcfaa 100644 --- a/ddl/mock.go +++ b/ddl/mock.go @@ -22,10 +22,11 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/sessionctx" ) -var _ SchemaSyncer = &MockSchemaSyncer{} +var _ util.SchemaSyncer = &MockSchemaSyncer{} const mockCheckVersInterval = 2 * time.Millisecond @@ -37,7 +38,7 @@ type MockSchemaSyncer struct { } // NewMockSchemaSyncer creates a new mock SchemaSyncer. -func NewMockSchemaSyncer() SchemaSyncer { +func NewMockSchemaSyncer() util.SchemaSyncer { return &MockSchemaSyncer{} } @@ -113,6 +114,15 @@ func (s *MockSchemaSyncer) OwnerCheckAllVersions(ctx context.Context, latestVer } } +// NotifyCleanExpiredPaths implements SchemaSyncer.NotifyCleanExpiredPaths interface. +func (s *MockSchemaSyncer) NotifyCleanExpiredPaths() bool { return true } + +// StartCleanWork implements SchemaSyncer.StartCleanWork interface. +func (s *MockSchemaSyncer) StartCleanWork() {} + +// CloseCleanWork implements SchemaSyncer.CloseCleanWork interface. +func (s *MockSchemaSyncer) CloseCleanWork() {} + type mockDelRange struct { } @@ -139,7 +149,7 @@ func (dr *mockDelRange) clear() {} // MockTableInfo mocks a table info by create table stmt ast and a specified table id. func MockTableInfo(ctx sessionctx.Context, stmt *ast.CreateTableStmt, tableID int64) (*model.TableInfo, error) { - cols, newConstraints, err := buildColumnsAndConstraints(ctx, stmt.Cols, stmt.Constraints, "", "") + cols, newConstraints, err := buildColumnsAndConstraints(ctx, stmt.Cols, stmt.Constraints, "", "", "", "") if err != nil { return nil, errors.Trace(err) } @@ -154,7 +164,7 @@ func MockTableInfo(ctx sessionctx.Context, stmt *ast.CreateTableStmt, tableID in return nil, errors.Trace(err) } - if err = resolveDefaultTableCharsetAndCollation(tbl, ""); err != nil { + if err = resolveDefaultTableCharsetAndCollation(tbl, "", ""); err != nil { return nil, errors.Trace(err) } diff --git a/ddl/partition.go b/ddl/partition.go index 4826057be4bbb..14354ec825b74 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -20,11 +20,13 @@ import ( "strings" "github.com/pingcap/errors" + "github.com/pingcap/parser" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" @@ -42,6 +44,17 @@ func buildTablePartitionInfo(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS if s.Partition == nil { return nil, nil } + + // force-discard the unsupported types, even when @@tidb_enable_table_partition = 'on' + switch s.Partition.Tp { + case model.PartitionTypeKey: + // can't create a warning for KEY partition, it will fail an integration test :/ + return nil, nil + case model.PartitionTypeList, model.PartitionTypeSystemTime: + ctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedCreatePartition) + return nil, nil + } + var enable bool switch ctx.GetSessionVars().EnableTablePartition { case "on": @@ -91,7 +104,6 @@ func buildTablePartitionInfo(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS } } - // TODO: generate multiple global ID for paritions, reduce the times of obtaining the global ID from the storage. if s.Partition.Tp == model.PartitionTypeRange { if err := buildRangePartitionDefinitions(ctx, d, s, pi); err != nil { return nil, errors.Trace(err) @@ -105,37 +117,41 @@ func buildTablePartitionInfo(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS } func buildHashPartitionDefinitions(ctx sessionctx.Context, d *ddl, s *ast.CreateTableStmt, pi *model.PartitionInfo) error { + genIDs, err := d.genGlobalIDs(int(pi.Num)) + if err != nil { + return errors.Trace(err) + } defs := make([]model.PartitionDefinition, pi.Num) for i := 0; i < len(defs); i++ { - pid, err := d.genGlobalID() - if err != nil { - return errors.Trace(err) + defs[i].ID = genIDs[i] + if len(s.Partition.Definitions) == 0 { + defs[i].Name = model.NewCIStr(fmt.Sprintf("p%v", i)) + } else { + def := s.Partition.Definitions[i] + defs[i].Name = def.Name + defs[i].Comment, _ = def.Comment() } - defs[i].ID = pid - defs[i].Name = model.NewCIStr(fmt.Sprintf("p%v", i)) } pi.Definitions = defs return nil } func buildRangePartitionDefinitions(ctx sessionctx.Context, d *ddl, s *ast.CreateTableStmt, pi *model.PartitionInfo) error { - for _, def := range s.Partition.Definitions { - pid, err := d.genGlobalID() - if err != nil { - return errors.Trace(err) - } + genIDs, err := d.genGlobalIDs(int(pi.Num)) + if err != nil { + return err + } + for ith, def := range s.Partition.Definitions { + comment, _ := def.Comment() piDef := model.PartitionDefinition{ Name: def.Name, - ID: pid, - Comment: def.Comment, + ID: genIDs[ith], + Comment: comment, } - if s.Partition.ColumnNames == nil && len(def.LessThan) != 1 { - return ErrTooManyValues.GenWithStackByArgs(s.Partition.Tp.String()) - } buf := new(bytes.Buffer) // Range columns partitions support multi-column partitions. - for _, expr := range def.LessThan { + for _, expr := range def.Clause.(*ast.PartitionDefinitionClauseLessThan).Exprs { expr.Format(buf) piDef.LessThan = append(piDef.LessThan, buf.String()) buf.Reset() @@ -145,7 +161,19 @@ func buildRangePartitionDefinitions(ctx sessionctx.Context, d *ddl, s *ast.Creat return nil } -func checkPartitionNameUnique(tbInfo *model.TableInfo, pi *model.PartitionInfo) error { +func checkPartitionNameUnique(pi *model.PartitionInfo) error { + partNames := make(map[string]struct{}) + newPars := pi.Definitions + for _, newPar := range newPars { + if _, ok := partNames[newPar.Name.L]; ok { + return ErrSameNamePartition.GenWithStackByArgs(newPar.Name) + } + partNames[newPar.Name.L] = struct{}{} + } + return nil +} + +func checkAddPartitionNameUnique(tbInfo *model.TableInfo, pi *model.PartitionInfo) error { partNames := make(map[string]struct{}) if tbInfo.Partition != nil { oldPars := tbInfo.Partition.Definitions @@ -333,7 +361,11 @@ func validRangePartitionType(col *table.Column) bool { // checkDropTablePartition checks if the partition exists and does not allow deleting the last existing partition in the table. func checkDropTablePartition(meta *model.TableInfo, partName string) error { - oldDefs := meta.Partition.Definitions + pi := meta.Partition + if pi.Type != model.PartitionTypeRange && pi.Type != model.PartitionTypeList { + return errOnlyOnRangeListPartition.GenWithStackByArgs("DROP") + } + oldDefs := pi.Definitions for _, def := range oldDefs { if strings.EqualFold(def.Name.L, strings.ToLower(partName)) { if len(oldDefs) == 1 { @@ -447,14 +479,14 @@ func checkAddPartitionTooManyPartitions(piDefs uint64) error { func checkNoHashPartitions(ctx sessionctx.Context, partitionNum uint64) error { if partitionNum == 0 { - return ErrNoParts.GenWithStackByArgs("partitions") + return ast.ErrNoParts.GenWithStackByArgs("partitions") } return nil } func checkNoRangePartitions(partitionNum int) error { if partitionNum == 0 { - return errors.Trace(ErrPartitionsMustBeDefined) + return ast.ErrPartitionsMustBeDefined.GenWithStackByArgs("RANGE") } return nil } @@ -473,17 +505,22 @@ func getPartitionIDs(table *model.TableInfo) []int64 { // checkRangePartitioningKeysConstraints checks that the range partitioning key is included in the table constraint. func checkRangePartitioningKeysConstraints(sctx sessionctx.Context, s *ast.CreateTableStmt, tblInfo *model.TableInfo, constraints []*ast.Constraint) error { // Returns directly if there is no constraint in the partition table. - // TODO: Remove the test 's.Partition.Expr == nil' when we support 'PARTITION BY RANGE COLUMNS' - if len(constraints) == 0 || s.Partition.Expr == nil { + if len(constraints) == 0 { return nil } - // Parse partitioning key, extract the column names in the partitioning key to slice. - buf := new(bytes.Buffer) - s.Partition.Expr.Format(buf) - partCols, err := extractPartitionColumns(sctx, buf.String(), tblInfo) - if err != nil { - return err + var partCols stringSlice + if s.Partition.Expr != nil { + // Parse partitioning key, extract the column names in the partitioning key to slice. + buf := new(bytes.Buffer) + s.Partition.Expr.Format(buf) + partColumns, err := extractPartitionColumns(buf.String(), tblInfo) + if err != nil { + return err + } + partCols = columnInfoSlice(partColumns) + } else if len(s.Partition.ColumnNames) > 0 { + partCols = columnNameSlice(s.Partition.ColumnNames) } // Checks that the partitioning key is included in the constraint. @@ -503,32 +540,90 @@ func checkRangePartitioningKeysConstraints(sctx sessionctx.Context, s *ast.Creat return nil } -func checkPartitionKeysConstraint(sctx sessionctx.Context, partExpr string, idxColNames []*ast.IndexColName, tblInfo *model.TableInfo) error { - // Parse partitioning key, extract the column names in the partitioning key to slice. - partCols, err := extractPartitionColumns(sctx, partExpr, tblInfo) - if err != nil { - return err +func checkPartitionKeysConstraint(pi *model.PartitionInfo, idxColNames []*ast.IndexColName, tblInfo *model.TableInfo) error { + var ( + partCols []*model.ColumnInfo + err error + ) + // The expr will be an empty string if the partition is defined by: + // CREATE TABLE t (...) PARTITION BY RANGE COLUMNS(...) + if partExpr := pi.Expr; partExpr != "" { + // Parse partitioning key, extract the column names in the partitioning key to slice. + partCols, err = extractPartitionColumns(partExpr, tblInfo) + if err != nil { + return err + } + } else { + partCols = make([]*model.ColumnInfo, 0, len(pi.Columns)) + for _, col := range pi.Columns { + colInfo := getColumnInfoByName(tblInfo, col.L) + if colInfo == nil { + return infoschema.ErrColumnNotExists.GenWithStackByArgs(col, tblInfo.Name) + } + partCols = append(partCols, colInfo) + } } // Every unique key on the table must use every column in the table's partitioning expression. // See https://dev.mysql.com/doc/refman/5.7/en/partitioning-limitations-partitioning-keys-unique-keys.html - if !checkUniqueKeyIncludePartKey(partCols, idxColNames) { + if !checkUniqueKeyIncludePartKey(columnInfoSlice(partCols), idxColNames) { return ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("UNIQUE INDEX") } return nil } -func extractPartitionColumns(sctx sessionctx.Context, partExpr string, tblInfo *model.TableInfo) ([]*expression.Column, error) { - e, err := expression.ParseSimpleExprWithTableInfo(sctx, partExpr, tblInfo) +type columnNameExtractor struct { + extractedColumns []*model.ColumnInfo + tblInfo *model.TableInfo + err error +} + +func (cne *columnNameExtractor) Enter(node ast.Node) (ast.Node, bool) { + return node, false +} + +func (cne *columnNameExtractor) Leave(node ast.Node) (ast.Node, bool) { + if c, ok := node.(*ast.ColumnNameExpr); ok { + for _, info := range cne.tblInfo.Columns { + if info.Name.L == c.Name.Name.L { + cne.extractedColumns = append(cne.extractedColumns, info) + return node, true + } + } + cne.err = ErrBadField.GenWithStackByArgs(c.Name.Name.O, "expression") + return nil, false + } + return node, true +} + +func extractPartitionColumns(partExpr string, tblInfo *model.TableInfo) ([]*model.ColumnInfo, error) { + partExpr = "select " + partExpr + stmts, _, err := parser.New().Parse(partExpr, "", "") if err != nil { - return nil, errors.Trace(err) + return nil, err } - return expression.ExtractColumns(e), nil + extractor := &columnNameExtractor{ + tblInfo: tblInfo, + extractedColumns: make([]*model.ColumnInfo, 0), + } + stmts[0].Accept(extractor) + if extractor.err != nil { + return nil, extractor.err + } + return extractor.extractedColumns, nil +} + +// stringSlice is defined for checkUniqueKeyIncludePartKey. +// if Go supports covariance, the code shouldn't be so complex. +type stringSlice interface { + Len() int + At(i int) string } // checkUniqueKeyIncludePartKey checks that the partitioning key is included in the constraint. -func checkUniqueKeyIncludePartKey(partCols []*expression.Column, idxCols []*ast.IndexColName) bool { - for _, partCol := range partCols { +func checkUniqueKeyIncludePartKey(partCols stringSlice, idxCols []*ast.IndexColName) bool { + for i := 0; i < partCols.Len(); i++ { + partCol := partCols.At(i) if !findColumnInIndexCols(partCol, idxCols) { return false } @@ -536,6 +631,28 @@ func checkUniqueKeyIncludePartKey(partCols []*expression.Column, idxCols []*ast. return true } +// columnInfoSlice implements the stringSlice interface. +type columnInfoSlice []*model.ColumnInfo + +func (cis columnInfoSlice) Len() int { + return len(cis) +} + +func (cis columnInfoSlice) At(i int) string { + return cis[i].Name.L +} + +// columnNameSlice implements the stringSlice interface. +type columnNameSlice []*ast.ColumnName + +func (cns columnNameSlice) Len() int { + return len(cns) +} + +func (cns columnNameSlice) At(i int) string { + return cns[i].Name.L +} + // isRangePartitionColUnsignedBigint returns true if the partitioning key column type is unsigned bigint type. func isRangePartitionColUnsignedBigint(cols []*table.Column, pi *model.PartitionInfo) bool { for _, col := range cols { @@ -555,20 +672,8 @@ func truncateTableByReassignPartitionIDs(t *meta.Meta, tblInfo *model.TableInfo) if err != nil { return errors.Trace(err) } - - var newDef model.PartitionDefinition - if tblInfo.Partition.Type == model.PartitionTypeHash { - newDef = model.PartitionDefinition{ - ID: pid, - } - } else if tblInfo.Partition.Type == model.PartitionTypeRange { - newDef = model.PartitionDefinition{ - ID: pid, - Name: def.Name, - LessThan: def.LessThan, - Comment: def.Comment, - } - } + newDef := def + newDef.ID = pid newDefs = append(newDefs, newDef) } tblInfo.Partition.Definitions = newDefs diff --git a/ddl/schema_test.go b/ddl/schema_test.go index c6eb59bf29eda..bc7e06ace9317 100644 --- a/ddl/schema_test.go +++ b/ddl/schema_test.go @@ -40,13 +40,12 @@ func (s *testSchemaSuite) TearDownSuite(c *C) { } func testSchemaInfo(c *C, d *ddl, name string) *model.DBInfo { - var err error dbInfo := &model.DBInfo{ Name: model.NewCIStr(name), } - - dbInfo.ID, err = d.genGlobalID() + genIDs, err := d.genGlobalIDs(1) c.Assert(err, IsNil) + dbInfo.ID = genIDs[0] return dbInfo } @@ -206,8 +205,9 @@ func (s *testSchemaSuite) TestSchemaWaitJob(c *C) { // d2 must not be owner. c.Assert(d2.ownerManager.IsOwner(), IsFalse) - schemaID, err := d2.genGlobalID() + genIDs, err := d2.genGlobalIDs(1) c.Assert(err, IsNil) + schemaID := genIDs[0] doDDLJobErr(c, schemaID, 0, model.ActionCreateSchema, []interface{}{dbInfo}, ctx, d2) } diff --git a/ddl/serial_test.go b/ddl/serial_test.go index 665337cdd31a5..90aaa7c3069e3 100644 --- a/ddl/serial_test.go +++ b/ddl/serial_test.go @@ -47,7 +47,7 @@ type testSerialSuite struct { func (s *testSerialSuite) SetUpSuite(c *C) { session.SetSchemaLease(200 * time.Millisecond) - session.SetStatsLease(0) + session.DisableStats4Test() ddl.WaitTimeWhenErrorOccured = 1 * time.Microsecond var err error @@ -241,6 +241,10 @@ func (s *testSerialSuite) TestRecoverTableByJobID(c *C) { } func (s *testSerialSuite) TestRecoverTableByTableName(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange"), IsNil) + }() tk := testkit.NewTestKit(c, s.store) tk.MustExec("create database if not exists test_recover") tk.MustExec("use test_recover") diff --git a/ddl/split_region.go b/ddl/split_region.go new file mode 100644 index 0000000000000..18e6cbad2afc1 --- /dev/null +++ b/ddl/split_region.go @@ -0,0 +1,128 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/tablecodec" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" +) + +func splitPartitionTableRegion(store kv.SplitableStore, pi *model.PartitionInfo, scatter bool) { + // Max partition count is 4096, should we sample and just choose some of the partition to split? + regionIDs := make([]uint64, 0, len(pi.Definitions)) + for _, def := range pi.Definitions { + regionIDs = append(regionIDs, splitRecordRegion(store, def.ID, scatter)) + } + if scatter { + waitScatterRegionFinish(store, regionIDs...) + } +} + +func splitTableRegion(store kv.SplitableStore, tbInfo *model.TableInfo, scatter bool) { + if tbInfo.ShardRowIDBits > 0 && tbInfo.PreSplitRegions > 0 { + splitPreSplitedTable(store, tbInfo, scatter) + } else { + regionID := splitRecordRegion(store, tbInfo.ID, scatter) + if scatter { + waitScatterRegionFinish(store, regionID) + } + } +} + +func splitPreSplitedTable(store kv.SplitableStore, tbInfo *model.TableInfo, scatter bool) { + // Example: + // ShardRowIDBits = 4 + // PreSplitRegions = 2 + // + // then will pre-split 2^2 = 4 regions. + // + // in this code: + // max = 1 << tblInfo.ShardRowIDBits = 16 + // step := int64(1 << (tblInfo.ShardRowIDBits - tblInfo.PreSplitRegions)) = 1 << (4-2) = 4; + // + // then split regionID is below: + // 4 << 59 = 2305843009213693952 + // 8 << 59 = 4611686018427387904 + // 12 << 59 = 6917529027641081856 + // + // The 4 pre-split regions range is below: + // 0 ~ 2305843009213693952 + // 2305843009213693952 ~ 4611686018427387904 + // 4611686018427387904 ~ 6917529027641081856 + // 6917529027641081856 ~ 9223372036854775807 ( (1 << 63) - 1 ) + // + // And the max _tidb_rowid is 9223372036854775807, it won't be negative number. + + // Split table region. + step := int64(1 << (tbInfo.ShardRowIDBits - tbInfo.PreSplitRegions)) + max := int64(1 << tbInfo.ShardRowIDBits) + splitTableKeys := make([][]byte, 0, 1<<(tbInfo.PreSplitRegions)) + for p := int64(step); p < max; p += step { + recordID := p << (64 - tbInfo.ShardRowIDBits - 1) + recordPrefix := tablecodec.GenTableRecordPrefix(tbInfo.ID) + key := tablecodec.EncodeRecordKey(recordPrefix, recordID) + splitTableKeys = append(splitTableKeys, key) + } + var err error + regionIDs, err := store.SplitRegions(context.Background(), splitTableKeys, scatter) + if err != nil { + logutil.Logger(context.Background()).Warn("[ddl] pre split table region failed", + zap.Stringer("table", tbInfo.Name), zap.Int("successful region count", len(regionIDs)), zap.Error(err)) + } + regionIDs = append(regionIDs, splitIndexRegion(store, tbInfo, scatter)...) + if scatter { + waitScatterRegionFinish(store, regionIDs...) + } +} + +func splitRecordRegion(store kv.SplitableStore, tableID int64, scatter bool) uint64 { + tableStartKey := tablecodec.GenTablePrefix(tableID) + regionIDs, err := store.SplitRegions(context.Background(), [][]byte{tableStartKey}, scatter) + if err != nil { + // It will be automatically split by TiKV later. + logutil.Logger(context.Background()).Warn("[ddl] split table region failed", zap.Error(err)) + } + if len(regionIDs) == 1 { + return regionIDs[0] + } + return 0 +} + +func splitIndexRegion(store kv.SplitableStore, tblInfo *model.TableInfo, scatter bool) []uint64 { + splitKeys := make([][]byte, 0, len(tblInfo.Indices)) + for _, idx := range tblInfo.Indices { + indexPrefix := tablecodec.EncodeTableIndexPrefix(tblInfo.ID, idx.ID) + splitKeys = append(splitKeys, indexPrefix) + } + regionIDs, err := store.SplitRegions(context.Background(), splitKeys, scatter) + if err != nil { + logutil.Logger(context.Background()).Warn("[ddl] pre split table index region failed", + zap.Stringer("table", tblInfo.Name), zap.Int("successful region count", len(regionIDs)), zap.Error(err)) + } + return regionIDs +} + +func waitScatterRegionFinish(store kv.SplitableStore, regionIDs ...uint64) { + for _, regionID := range regionIDs { + err := store.WaitScatterRegionFinish(regionID, 0) + if err != nil { + logutil.Logger(context.Background()).Warn("[ddl] wait scatter region failed", zap.Uint64("regionID", regionID), zap.Error(err)) + } + } +} diff --git a/ddl/table.go b/ddl/table.go index c81c51a227e07..447cd9fa41316 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/parser/charset" "github.com/pingcap/parser/model" + field_types "github.com/pingcap/parser/types" "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" @@ -32,8 +33,6 @@ import ( "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/gcutil" - "github.com/pingcap/tidb/util/logutil" - "go.uber.org/zap" ) func onCreateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { @@ -339,99 +338,6 @@ func checkSafePoint(w *worker, snapshotTS uint64) error { return gcutil.ValidateSnapshot(ctx, snapshotTS) } -type splitableStore interface { - SplitRegion(splitKey kv.Key) error - SplitRegionAndScatter(splitKey kv.Key) (uint64, error) - WaitScatterRegionFinish(regionID uint64) error -} - -func splitPartitionTableRegion(store kv.Storage, pi *model.PartitionInfo) { - // Max partition count is 4096, should we sample and just choose some of the partition to split? - for _, def := range pi.Definitions { - splitTableRegion(store, def.ID) - } -} - -func splitTableRegion(store kv.Storage, tableID int64) { - s, ok := store.(splitableStore) - if !ok { - return - } - tableStartKey := tablecodec.GenTablePrefix(tableID) - if err := s.SplitRegion(tableStartKey); err != nil { - // It will be automatically split by TiKV later. - logutil.Logger(ddlLogCtx).Warn("[ddl] split table region failed", zap.Error(err)) - } -} - -func preSplitTableRegion(store kv.Storage, tblInfo *model.TableInfo, waitTableSplitFinish bool) { - s, ok := store.(splitableStore) - if !ok { - return - } - regionIDs := make([]uint64, 0, 1<<(tblInfo.PreSplitRegions-1)+len(tblInfo.Indices)) - - // Example: - // ShardRowIDBits = 5 - // PreSplitRegions = 3 - // - // then will pre-split 2^(3-1) = 4 regions. - // - // in this code: - // max = 1 << (tblInfo.ShardRowIDBits - 1) = 1 << (5-1) = 16 - // step := int64(1 << (tblInfo.ShardRowIDBits - tblInfo.PreSplitRegions)) = 1 << (5-3) = 4; - // - // then split regionID is below: - // 4 << 59 = 2305843009213693952 - // 8 << 59 = 4611686018427387904 - // 12 << 59 = 6917529027641081856 - // - // The 4 pre-split regions range is below: - // 0 ~ 2305843009213693952 - // 2305843009213693952 ~ 4611686018427387904 - // 4611686018427387904 ~ 6917529027641081856 - // 6917529027641081856 ~ 9223372036854775807 ( (1 << 63) - 1 ) - // - // And the max _tidb_rowid is 9223372036854775807, it won't be negative number. - - // Split table region. - step := int64(1 << (tblInfo.ShardRowIDBits - tblInfo.PreSplitRegions)) - // The highest bit is the symbol bit,and alloc _tidb_rowid will always be positive number. - // So we only need to split the region for the positive number. - max := int64(1 << (tblInfo.ShardRowIDBits - 1)) - for p := int64(step); p < max; p += step { - recordID := p << (64 - tblInfo.ShardRowIDBits) - recordPrefix := tablecodec.GenTableRecordPrefix(tblInfo.ID) - key := tablecodec.EncodeRecordKey(recordPrefix, recordID) - regionID, err := s.SplitRegionAndScatter(key) - if err != nil { - logutil.Logger(ddlLogCtx).Warn("[ddl] pre split table region failed", zap.Int64("recordID", recordID), zap.Error(err)) - } else { - regionIDs = append(regionIDs, regionID) - } - } - - // Split index region. - for _, idx := range tblInfo.Indices { - indexPrefix := tablecodec.EncodeTableIndexPrefix(tblInfo.ID, idx.ID) - regionID, err := s.SplitRegionAndScatter(indexPrefix) - if err != nil { - logutil.Logger(ddlLogCtx).Warn("[ddl] pre split table index region failed", zap.String("index", idx.Name.L), zap.Error(err)) - } else { - regionIDs = append(regionIDs, regionID) - } - } - if !waitTableSplitFinish { - return - } - for _, regionID := range regionIDs { - err := s.WaitScatterRegionFinish(regionID) - if err != nil { - logutil.Logger(ddlLogCtx).Warn("[ddl] wait scatter region failed", zap.Uint64("regionID", regionID), zap.Error(err)) - } - } -} - func getTable(store kv.Storage, schemaID int64, tblInfo *model.TableInfo) (table.Table, error) { alloc := autoid.NewAllocator(store, tblInfo.GetDBID(schemaID), tblInfo.IsAutoIncColUnsigned()) tbl, err := table.TableFromMeta(alloc, tblInfo) @@ -745,7 +651,7 @@ func onModifyTableCharsetAndCollate(t *meta.Meta, job *model.Job) (ver int64, _ tblInfo.Collate = toCollate // update column charset. for _, col := range tblInfo.Columns { - if typesNeedCharset(col.Tp) { + if field_types.HasCharset(&col.FieldType) { col.Charset = toCharset col.Collate = toCollate } else { @@ -764,7 +670,7 @@ func onModifyTableCharsetAndCollate(t *meta.Meta, job *model.Job) (ver int64, _ func checkTableNotExists(d *ddlCtx, t *meta.Meta, schemaID int64, tableName string) error { // d.infoHandle maybe nil in some test. - if d.infoHandle == nil { + if d.infoHandle == nil || !d.infoHandle.IsValid() { return checkTableNotExistsFromStore(t, schemaID, tableName) } // Try to use memory schema info to check first. @@ -866,7 +772,7 @@ func onAddTablePartition(t *meta.Meta, job *model.Job) (ver int64, _ error) { return ver, errors.Trace(err) } - err = checkPartitionNameUnique(tblInfo, partInfo) + err = checkAddPartitionNameUnique(tblInfo, partInfo) if err != nil { job.State = model.JobStateCancelled return ver, errors.Trace(err) diff --git a/ddl/table_split_test.go b/ddl/table_split_test.go index afb33d5b4044f..da4a1f96c83de 100644 --- a/ddl/table_split_test.go +++ b/ddl/table_split_test.go @@ -38,7 +38,7 @@ func (s *testDDLTableSplitSuite) TestTableSplit(c *C) { c.Assert(err, IsNil) defer store.Close() session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() atomic.StoreUint32(&ddl.EnableSplitTableRegion, 1) dom, err := session.BootstrapSession(store) c.Assert(err, IsNil) diff --git a/ddl/table_test.go b/ddl/table_test.go index f9a2cfbee3b56..ca189139b7d9d 100644 --- a/ddl/table_test.go +++ b/ddl/table_test.go @@ -42,12 +42,12 @@ type testTableSuite struct { // testTableInfo creates a test table with num int columns and with no index. func testTableInfo(c *C, d *ddl, name string, num int) *model.TableInfo { - var err error tblInfo := &model.TableInfo{ Name: model.NewCIStr(name), } - tblInfo.ID, err = d.genGlobalID() + genIDs, err := d.genGlobalIDs(1) c.Assert(err, IsNil) + tblInfo.ID = genIDs[0] cols := make([]*model.ColumnInfo, num) for i := range cols { @@ -71,8 +71,9 @@ func testTableInfo(c *C, d *ddl, name string, num int) *model.TableInfo { // testTableInfo creates a test table with num int columns and with no index. func testTableInfoWithPartition(c *C, d *ddl, name string, num int) *model.TableInfo { tblInfo := testTableInfo(c, d, name, num) - pid, err := d.genGlobalID() + genIDs, err := d.genGlobalIDs(1) c.Assert(err, IsNil) + pid := genIDs[0] tblInfo.Partition = &model.PartitionInfo{ Type: model.PartitionTypeRange, Expr: tblInfo.Columns[0].Name.L, @@ -89,12 +90,12 @@ func testTableInfoWithPartition(c *C, d *ddl, name string, num int) *model.Table // testViewInfo creates a test view with num int columns. func testViewInfo(c *C, d *ddl, name string, num int) *model.TableInfo { - var err error tblInfo := &model.TableInfo{ Name: model.NewCIStr(name), } - tblInfo.ID, err = d.genGlobalID() + genIDs, err := d.genGlobalIDs(1) c.Assert(err, IsNil) + tblInfo.ID = genIDs[0] cols := make([]*model.ColumnInfo, num) viewCols := make([]model.CIStr, num) @@ -196,8 +197,9 @@ func testDropTable(c *C, ctx sessionctx.Context, d *ddl, dbInfo *model.DBInfo, t } func testTruncateTable(c *C, ctx sessionctx.Context, d *ddl, dbInfo *model.DBInfo, tblInfo *model.TableInfo) *model.Job { - newTableID, err := d.genGlobalID() + genIDs, err := d.genGlobalIDs(1) c.Assert(err, IsNil) + newTableID := genIDs[0] job := &model.Job{ SchemaID: dbInfo.ID, TableID: tblInfo.ID, diff --git a/ddl/syncer.go b/ddl/util/syncer.go similarity index 75% rename from ddl/syncer.go rename to ddl/util/syncer.go index ca53899537240..61305e7fb27ab 100644 --- a/ddl/syncer.go +++ b/ddl/util/syncer.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package ddl +package util import ( "context" @@ -25,7 +25,9 @@ import ( "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3/concurrency" + "github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes" "github.com/pingcap/errors" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/owner" "github.com/pingcap/tidb/util/logutil" @@ -48,6 +50,8 @@ const ( keyOpDefaultTimeout = 2 * time.Second keyOpRetryInterval = 30 * time.Millisecond checkVersInterval = 20 * time.Millisecond + + ddlPrompt = "ddl-syncer" ) var ( @@ -56,9 +60,9 @@ var ( CheckVersFirstWaitTime = 50 * time.Millisecond // SyncerSessionTTL is the etcd session's TTL in seconds. // and it's an exported variable for testing. - SyncerSessionTTL = 10 * 60 - // WaitTimeWhenErrorOccured is waiting interval when processing DDL jobs encounter errors. - WaitTimeWhenErrorOccured = 1 * time.Second + SyncerSessionTTL = 90 + // ddlLogCtx uses for log. + ddlLogCtx = context.Background() ) // SchemaSyncer is used to synchronize schema version between the DDL worker leader and followers through etcd. @@ -86,6 +90,17 @@ type SchemaSyncer interface { // the latest schema version. If the result is false, wait for a while and check again util the processing time reach 2 * lease. // It returns until all servers' versions are equal to the latest version or the ctx is done. OwnerCheckAllVersions(ctx context.Context, latestVer int64) error + // NotifyCleanExpiredPaths informs to clean up expired paths. + // The returned value is used for testing. + NotifyCleanExpiredPaths() bool + // StartCleanWork starts to clean up tasks. + StartCleanWork() + // CloseCleanWork ends cleanup tasks. + CloseCleanWork() +} + +type ownerChecker interface { + IsOwner() bool } type schemaVersionSyncer struct { @@ -96,13 +111,21 @@ type schemaVersionSyncer struct { sync.RWMutex globalVerCh clientv3.WatchChan } + + // for clean worker + ownerChecker ownerChecker + notifyCleanExpiredPathsCh chan struct{} + quiteCh chan struct{} } // NewSchemaSyncer creates a new SchemaSyncer. -func NewSchemaSyncer(etcdCli *clientv3.Client, id string) SchemaSyncer { +func NewSchemaSyncer(etcdCli *clientv3.Client, id string, oc ownerChecker) SchemaSyncer { return &schemaVersionSyncer{ - etcdCli: etcdCli, - selfSchemaVerPath: fmt.Sprintf("%s/%s", DDLAllSchemaVersions, id), + etcdCli: etcdCli, + selfSchemaVerPath: fmt.Sprintf("%s/%s", DDLAllSchemaVersions, id), + ownerChecker: oc, + notifyCleanExpiredPathsCh: make(chan struct{}, 1), + quiteCh: make(chan struct{}), } } @@ -380,3 +403,106 @@ func (s *schemaVersionSyncer) OwnerCheckAllVersions(ctx context.Context, latestV time.Sleep(checkVersInterval) } } + +const ( + opDefaultRetryCnt = 10 + failedGetTTLLimit = 20 + opDefaultTimeout = 3 * time.Second + opRetryInterval = 500 * time.Millisecond +) + +// NeededCleanTTL is exported for testing. +var NeededCleanTTL = int64(-60) + +func (s *schemaVersionSyncer) StartCleanWork() { + for { + select { + case <-s.notifyCleanExpiredPathsCh: + if !s.ownerChecker.IsOwner() { + continue + } + + for i := 0; i < opDefaultRetryCnt; i++ { + childCtx, cancelFunc := context.WithTimeout(context.Background(), opDefaultTimeout) + resp, err := s.etcdCli.Leases(childCtx) + cancelFunc() + if err != nil { + logutil.Logger(ddlLogCtx).Info("[ddl] syncer clean expired paths, failed to get leases.", zap.Error(err)) + continue + } + + if isFinished := s.doCleanExpirePaths(resp.Leases); isFinished { + break + } + time.Sleep(opRetryInterval) + } + case <-s.quiteCh: + return + } + } +} + +func (s *schemaVersionSyncer) CloseCleanWork() { + close(s.quiteCh) +} + +func (s *schemaVersionSyncer) NotifyCleanExpiredPaths() bool { + var isNotified bool + var err error + startTime := time.Now() + select { + case s.notifyCleanExpiredPathsCh <- struct{}{}: + isNotified = true + default: + err = errors.New("channel is full, failed to notify clean expired paths") + } + metrics.OwnerHandleSyncerHistogram.WithLabelValues(metrics.OwnerNotifyCleanExpirePaths, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + return isNotified +} + +func (s *schemaVersionSyncer) doCleanExpirePaths(leases []clientv3.LeaseStatus) bool { + failedGetIDs := 0 + failedRevokeIDs := 0 + startTime := time.Now() + + defer func() { + metrics.OwnerHandleSyncerHistogram.WithLabelValues(metrics.OwnerCleanExpirePaths, metrics.RetLabel(nil)).Observe(time.Since(startTime).Seconds()) + }() + // TODO: Now LeaseStatus only has lease ID. + for _, lease := range leases { + // The DDL owner key uses '%x', so here print it too. + leaseID := fmt.Sprintf("%x, %d", lease.ID, lease.ID) + childCtx, cancelFunc := context.WithTimeout(context.Background(), opDefaultTimeout) + ttlResp, err := s.etcdCli.TimeToLive(childCtx, lease.ID) + cancelFunc() + if err != nil { + logutil.Logger(ddlLogCtx).Info("[ddl] syncer clean expired paths, failed to get one TTL.", zap.String("leaseID", leaseID), zap.Error(err)) + failedGetIDs++ + continue + } + + if failedGetIDs > failedGetTTLLimit { + return false + } + if ttlResp.TTL >= NeededCleanTTL { + continue + } + + st := time.Now() + childCtx, cancelFunc = context.WithTimeout(context.Background(), opDefaultTimeout) + _, err = s.etcdCli.Revoke(childCtx, lease.ID) + cancelFunc() + if err != nil && terror.ErrorEqual(err, rpctypes.ErrLeaseNotFound) { + logutil.Logger(ddlLogCtx).Warn("[ddl] syncer clean expired paths, failed to revoke lease.", zap.String("leaseID", leaseID), + zap.Int64("TTL", ttlResp.TTL), zap.Error(err)) + failedRevokeIDs++ + } + logutil.Logger(ddlLogCtx).Warn("[ddl] syncer clean expired paths,", zap.String("leaseID", leaseID), zap.Int64("TTL", ttlResp.TTL)) + metrics.OwnerHandleSyncerHistogram.WithLabelValues(metrics.OwnerCleanOneExpirePath, metrics.RetLabel(err)).Observe(time.Since(st).Seconds()) + } + + if failedGetIDs == 0 && failedRevokeIDs == 0 { + return true + } + return false +} diff --git a/ddl/util/syncer_test.go b/ddl/util/syncer_test.go new file mode 100644 index 0000000000000..9199ba2ac2857 --- /dev/null +++ b/ddl/util/syncer_test.go @@ -0,0 +1,249 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package util_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/etcdserver" + "github.com/coreos/etcd/integration" + "github.com/coreos/etcd/mvcc/mvccpb" + "github.com/pingcap/errors" + "github.com/pingcap/parser/terror" + . "github.com/pingcap/tidb/ddl" + . "github.com/pingcap/tidb/ddl/util" + "github.com/pingcap/tidb/owner" + "github.com/pingcap/tidb/store/mockstore" + goctx "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" +) + +const minInterval = 10 * time.Nanosecond // It's used to test timeout. + +func TestSyncerSimple(t *testing.T) { + testLease := 5 * time.Millisecond + origin := CheckVersFirstWaitTime + CheckVersFirstWaitTime = 0 + defer func() { + CheckVersFirstWaitTime = origin + }() + + store, err := mockstore.NewMockTikvStore() + if err != nil { + t.Fatal(err) + } + defer store.Close() + + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1}) + defer clus.Terminate(t) + cli := clus.RandClient() + ctx := goctx.Background() + d := NewDDL(ctx, cli, store, nil, nil, testLease, nil) + defer d.Stop() + + // for init function + if err = d.SchemaSyncer().Init(ctx); err != nil { + t.Fatalf("schema version syncer init failed %v", err) + } + resp, err := cli.Get(ctx, DDLAllSchemaVersions, clientv3.WithPrefix()) + if err != nil { + t.Fatalf("client get version failed %v", err) + } + key := DDLAllSchemaVersions + "/" + d.OwnerManager().ID() + checkRespKV(t, 1, key, InitialVersion, resp.Kvs...) + // for MustGetGlobalVersion function + globalVer, err := d.SchemaSyncer().MustGetGlobalVersion(ctx) + if err != nil { + t.Fatalf("client get global version failed %v", err) + } + if InitialVersion != fmt.Sprintf("%d", globalVer) { + t.Fatalf("client get global version %d isn't equal to init version %s", globalVer, InitialVersion) + } + childCtx, _ := goctx.WithTimeout(ctx, minInterval) + _, err = d.SchemaSyncer().MustGetGlobalVersion(childCtx) + if !isTimeoutError(err) { + t.Fatalf("client get global version result not match, err %v", err) + } + + d1 := NewDDL(ctx, cli, store, nil, nil, testLease, nil) + defer d1.Stop() + if err = d1.SchemaSyncer().Init(ctx); err != nil { + t.Fatalf("schema version syncer init failed %v", err) + } + + // for watchCh + wg := sync.WaitGroup{} + wg.Add(1) + currentVer := int64(123) + go func() { + defer wg.Done() + select { + case resp := <-d.SchemaSyncer().GlobalVersionCh(): + if len(resp.Events) < 1 { + t.Fatalf("get chan events count less than 1") + } + checkRespKV(t, 1, DDLGlobalSchemaVersion, fmt.Sprintf("%v", currentVer), resp.Events[0].Kv) + case <-time.After(100 * time.Millisecond): + t.Fatalf("get udpate version failed") + } + }() + + // for update latestSchemaVersion + err = d.SchemaSyncer().OwnerUpdateGlobalVersion(ctx, currentVer) + if err != nil { + t.Fatalf("update latest schema version failed %v", err) + } + + wg.Wait() + + // for CheckAllVersions + childCtx, cancel := goctx.WithTimeout(ctx, 20*time.Millisecond) + err = d.SchemaSyncer().OwnerCheckAllVersions(childCtx, currentVer) + if err == nil { + t.Fatalf("check result not match") + } + cancel() + + // for UpdateSelfVersion + childCtx, cancel = goctx.WithTimeout(ctx, 100*time.Millisecond) + err = d.SchemaSyncer().UpdateSelfVersion(childCtx, currentVer) + if err != nil { + t.Fatalf("update self version failed %v", errors.ErrorStack(err)) + } + cancel() + childCtx, cancel = goctx.WithTimeout(ctx, 100*time.Millisecond) + err = d1.SchemaSyncer().UpdateSelfVersion(childCtx, currentVer) + if err != nil { + t.Fatalf("update self version failed %v", errors.ErrorStack(err)) + } + cancel() + childCtx, _ = goctx.WithTimeout(ctx, minInterval) + err = d1.SchemaSyncer().UpdateSelfVersion(childCtx, currentVer) + if !isTimeoutError(err) { + t.Fatalf("update self version result not match, err %v", err) + } + + // for CheckAllVersions + childCtx, _ = goctx.WithTimeout(ctx, 100*time.Millisecond) + err = d.SchemaSyncer().OwnerCheckAllVersions(childCtx, currentVer-1) + if err != nil { + t.Fatalf("check all versions failed %v", err) + } + childCtx, _ = goctx.WithTimeout(ctx, 100*time.Millisecond) + err = d.SchemaSyncer().OwnerCheckAllVersions(childCtx, currentVer) + if err != nil { + t.Fatalf("check all versions failed %v", err) + } + childCtx, _ = goctx.WithTimeout(ctx, minInterval) + err = d.SchemaSyncer().OwnerCheckAllVersions(childCtx, currentVer) + if !isTimeoutError(err) { + t.Fatalf("check all versions result not match, err %v", err) + } + + // for StartCleanWork + go d.SchemaSyncer().StartCleanWork() + ttl := 10 + // Make sure NeededCleanTTL > ttl, then we definitely clean the ttl. + NeededCleanTTL = int64(11) + ttlKey := "session_ttl_key" + ttlVal := "session_ttl_val" + session, err := owner.NewSession(ctx, "", cli, owner.NewSessionDefaultRetryCnt, ttl) + if err != nil { + t.Fatalf("new session failed %v", err) + } + childCtx, cancel = context.WithTimeout(ctx, 100*time.Millisecond) + err = PutKVToEtcd(childCtx, cli, 5, ttlKey, ttlVal, clientv3.WithLease(session.Lease())) + if err != nil { + t.Fatalf("put kv to etcd failed %v", err) + } + cancel() + // Make sure the ttlKey is exist in etcd. + resp, err = cli.Get(ctx, ttlKey) + if err != nil { + t.Fatalf("client get version failed %v", err) + } + checkRespKV(t, 1, ttlKey, ttlVal, resp.Kvs...) + d.SchemaSyncer().NotifyCleanExpiredPaths() + // Make sure the clean worker is done. + notifiedCnt := 1 + for i := 0; i < 100; i++ { + isNotified := d.SchemaSyncer().NotifyCleanExpiredPaths() + if isNotified { + notifiedCnt++ + } + // notifyCleanExpiredPathsCh's length is 1, + // so when notifiedCnt is 3, we can make sure the clean worker is done at least once. + if notifiedCnt == 3 { + break + } + time.Sleep(20 * time.Millisecond) + } + if notifiedCnt != 3 { + t.Fatal("clean worker don't finish") + } + // Make sure the ttlKey is removed in etcd. + resp, err = cli.Get(ctx, ttlKey) + if err != nil { + t.Fatalf("client get version failed %v", err) + } + checkRespKV(t, 0, ttlKey, "", resp.Kvs...) + + // for RemoveSelfVersionPath + resp, err = cli.Get(goctx.Background(), key) + if err != nil { + t.Fatalf("get key %s failed %v", key, err) + } + currVer := fmt.Sprintf("%v", currentVer) + checkRespKV(t, 1, key, currVer, resp.Kvs...) + d.SchemaSyncer().RemoveSelfVersionPath() + resp, err = cli.Get(goctx.Background(), key) + if err != nil { + t.Fatalf("get key %s failed %v", key, err) + } + if len(resp.Kvs) != 0 { + t.Fatalf("remove key %s failed %v", key, err) + } +} + +func isTimeoutError(err error) bool { + if terror.ErrorEqual(err, goctx.DeadlineExceeded) || grpc.Code(errors.Cause(err)) == codes.DeadlineExceeded || + terror.ErrorEqual(err, etcdserver.ErrTimeout) { + return true + } + return false +} + +func checkRespKV(t *testing.T, kvCount int, key, val string, + kvs ...*mvccpb.KeyValue) { + if len(kvs) != kvCount { + t.Fatalf("resp key %s kvs %v length is != %d", key, kvs, kvCount) + } + if kvCount == 0 { + return + } + + kv := kvs[0] + if string(kv.Key) != key { + t.Fatalf("key resp %s, exported %s", kv.Key, key) + } + if val != val { + t.Fatalf("val resp %s, exported %s", kv.Value, val) + } +} diff --git a/ddl/util/util.go b/ddl/util/util.go index b47a711fc170b..1161522714bcb 100644 --- a/ddl/util/util.go +++ b/ddl/util/util.go @@ -69,8 +69,8 @@ func loadDeleteRangesFromTable(ctx sessionctx.Context, table string, safePoint u } rs := rss[0] - req := rs.NewRecordBatch() - it := chunk.NewIterator4Chunk(req.Chunk) + req := rs.NewChunk() + it := chunk.NewIterator4Chunk(req) for { err = rs.Next(context.TODO(), req) if err != nil { diff --git a/distsql/distsql_test.go b/distsql/distsql_test.go index 478881e47720f..20cd07619d6b3 100644 --- a/distsql/distsql_test.go +++ b/distsql/distsql_test.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/execdetails" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tipb/go-tipb" ) @@ -42,6 +43,8 @@ func (s *testSuite) createSelectNormal(batch, totalRows int, c *C, planIDs []str SetDesc(false). SetKeepOrder(false). SetFromSessionVars(variable.NewSessionVars()). + SetMemTracker(memory.NewTracker(stringutil.StringerStr("testSuite.createSelectNormal"), + s.sctx.GetSessionVars().MemQuotaDistSQL)). Build() c.Assert(err, IsNil) @@ -106,6 +109,21 @@ func (s *testSuite) TestSelectNormal(c *C) { c.Assert(numAllRows, Equals, 2) err := response.Close() c.Assert(err, IsNil) + c.Assert(response.memTracker.BytesConsumed(), Equals, int64(0)) +} + +func (s *testSuite) TestSelectMemTracker(c *C) { + response, colTypes := s.createSelectNormal(2, 6, c, nil) + response.Fetch(context.TODO()) + + // Test Next. + chk := chunk.New(colTypes, 3, 3) + err := response.Next(context.TODO(), chk) + c.Assert(err, IsNil) + c.Assert(chk.IsFull(), Equals, true) + err = response.Close() + c.Assert(err, IsNil) + c.Assert(response.memTracker.BytesConsumed(), Equals, int64(0)) } func (s *testSuite) TestSelectNormalChunkSize(c *C) { @@ -113,6 +131,7 @@ func (s *testSuite) TestSelectNormalChunkSize(c *C) { response.Fetch(context.TODO()) s.testChunkSize(response, colTypes, c) c.Assert(response.Close(), IsNil) + c.Assert(response.memTracker.BytesConsumed(), Equals, int64(0)) } func (s *testSuite) TestSelectWithRuntimeStats(c *C) { diff --git a/distsql/request_builder.go b/distsql/request_builder.go index 92532de0b7d1d..6818579c5a2a5 100644 --- a/distsql/request_builder.go +++ b/distsql/request_builder.go @@ -14,12 +14,10 @@ package distsql import ( - "fmt" "math" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" @@ -44,10 +42,8 @@ func (builder *RequestBuilder) Build() (*kv.Request, error) { } // SetMemTracker sets a memTracker for this request. -func (builder *RequestBuilder) SetMemTracker(sctx sessionctx.Context, label fmt.Stringer) *RequestBuilder { - t := memory.NewTracker(label, sctx.GetSessionVars().MemQuotaDistSQL) - t.AttachTo(sctx.GetSessionVars().StmtCtx.MemTracker) - builder.Request.MemTracker = t +func (builder *RequestBuilder) SetMemTracker(tracker *memory.Tracker) *RequestBuilder { + builder.Request.MemTracker = tracker return builder } diff --git a/distsql/request_builder_test.go b/distsql/request_builder_test.go index b64dd63218892..a2b472b5ad833 100644 --- a/distsql/request_builder_test.go +++ b/distsql/request_builder_test.go @@ -26,8 +26,10 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/ranger" + "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tipb/go-tipb" ) @@ -49,6 +51,9 @@ type testSuite struct { func (s *testSuite) SetUpSuite(c *C) { ctx := mock.NewContext() + ctx.GetSessionVars().StmtCtx = &stmtctx.StatementContext{ + MemTracker: memory.NewTracker(stringutil.StringerStr("testSuite"), variable.DefTiDBMemQuotaDistSQL), + } ctx.Store = &mock.Store{ Client: &mock.Client{ MockResponse: &mockResponse{ diff --git a/distsql/select_result.go b/distsql/select_result.go index 72a283d1b2f8f..3935df52ad450 100644 --- a/distsql/select_result.go +++ b/distsql/select_result.go @@ -66,8 +66,9 @@ type selectResult struct { fieldTypes []*types.FieldType ctx sessionctx.Context - selectResp *tipb.SelectResponse - respChkIdx int + selectResp *tipb.SelectResponse + selectRespSize int // record the selectResp.Size() when it is initialized. + respChkIdx int feedback *statistics.QueryFeedback partialCount int64 // number of partial results. @@ -103,20 +104,25 @@ func (r *selectResult) fetch(ctx context.Context) { if err != nil { result.err = err } else if resultSubset == nil { + // If the result is drained, the resultSubset would be nil return } else { result.result = resultSubset - if r.memTracker != nil { - r.memTracker.Consume(int64(resultSubset.MemSize())) - } + r.memConsume(int64(resultSubset.MemSize())) } select { case r.results <- result: case <-r.closed: // If selectResult called Close() already, make fetch goroutine exit. + if resultSubset != nil { + r.memConsume(-int64(resultSubset.MemSize())) + } return case <-ctx.Done(): + if resultSubset != nil { + r.memConsume(-int64(resultSubset.MemSize())) + } return } } @@ -161,24 +167,21 @@ func (r *selectResult) getSelectResp() error { if re.err != nil { return errors.Trace(re.err) } - if r.memTracker != nil && r.selectResp != nil { - r.memTracker.Consume(-int64(r.selectResp.Size())) + if r.selectResp != nil { + r.memConsume(-int64(r.selectRespSize)) } if re.result == nil { r.selectResp = nil return nil } - if r.memTracker != nil { - r.memTracker.Consume(-int64(re.result.MemSize())) - } + r.memConsume(-int64(re.result.MemSize())) r.selectResp = new(tipb.SelectResponse) err := r.selectResp.Unmarshal(re.result.GetData()) if err != nil { return errors.Trace(err) } - if r.memTracker != nil && r.selectResp != nil { - r.memTracker.Consume(int64(r.selectResp.Size())) - } + r.selectRespSize = r.selectResp.Size() + r.memConsume(int64(r.selectRespSize)) if err := r.selectResp.Error; err != nil { return terror.ClassTiKV.New(terror.ErrCode(err.Code), err.Msg) } @@ -234,13 +237,27 @@ func (r *selectResult) readRowsData(chk *chunk.Chunk) (err error) { return nil } +func (r *selectResult) memConsume(bytes int64) { + if r.memTracker != nil { + r.memTracker.Consume(bytes) + } +} + // Close closes selectResult. func (r *selectResult) Close() error { - // Close this channel tell fetch goroutine to exit. if r.feedback.Actual() >= 0 { metrics.DistSQLScanKeysHistogram.Observe(float64(r.feedback.Actual())) } metrics.DistSQLPartialCountHistogram.Observe(float64(r.partialCount)) + // Close this channel to tell the fetch goroutine to exit. close(r.closed) + for re := range r.results { + if re.result != nil { + r.memConsume(-int64(re.result.MemSize())) + } + } + if r.selectResp != nil { + r.memConsume(-int64(r.selectRespSize)) + } return r.resp.Close() } diff --git a/docs/design/2019-04-11-indexmerge.md b/docs/design/2019-04-11-indexmerge.md index aa0108d62487a..14a0e7eb5e96a 100644 --- a/docs/design/2019-04-11-indexmerge.md +++ b/docs/design/2019-04-11-indexmerge.md @@ -1,48 +1,48 @@ -# Proposal: scan a table using IndexMerge -- Author(s) : WHU -- Last updated : May 10 -- Discussion at : - - -## Abstract - -The proposal proposes to use multiple indexes to scan a table if possible. In some cases, using multiple indexes will improve performance. - -## Background - -In present TiDB, a SQL statement with conditions involving multiple indexed attributes only uses one of the conditions as the index filter to build access condition, while others are regarded as table filters. Firstly, use index scan (at most one index) to get handles (rowid in TiDB). Then use the handles to get rows and check whether the rows satisfy the conditions of table filters. Some relational databases implement a table access path using multiple indexes. In some cases, this way will improve performance. - -We take an example to explain it. We define the table schema as : - -``` -CREATE TABLE t1 (a int, b int, c int); -CREATE INDEX t1a on t1(a); -CREATE INDEX t1b on t1(b); -CREATE INDEX t1c on t1(c); -``` -And use a test SQL statement `SELECT * FROM t1 where a < 2 or b > 50`. Currently, TiDB does a table scan and puts `a < 2 or b > 50` as a Selection on top of it. If the selectivity of `a < 2 ` and `b > 50` is low, a better approach would be using indexes on columns `a` and `b` to retrieve rows respectively, and applying a union operation on the result sets. - - -## Proposal -In short, we need to consider access paths using multiple indexes. - -### Planner -We propose to add new `IndexMergeReader / PhysicalIndexMergeReader` and `IndexMergeLookUpReader / PhysicalIndexMergeLookUpReader` operators. - -Now we just consider the following two kinds of queries: - -(1)Conditions in CNF, e.g, `select * from t1 where c1 and c2 and c3 and …` - +# Proposal: scan a table using IndexMerge +- Author(s) : WHU +- Last updated : May 10 +- Discussion at : + + +## Abstract + +The proposal proposes to use multiple indexes to scan a table if possible. In some cases, using multiple indexes will improve performance. + +## Background + +In present TiDB, a SQL statement with conditions involving multiple indexed attributes only uses one of the conditions as the index filter to build access condition, while others are regarded as table filters. Firstly, use index scan (at most one index) to get handles (rowid in TiDB). Then use the handles to get rows and check whether the rows satisfy the conditions of table filters. Some relational databases implement a table access path using multiple indexes. In some cases, this way will improve performance. + +We take an example to explain it. We define the table schema as : + +``` +CREATE TABLE t1 (a int, b int, c int); +CREATE INDEX t1a on t1(a); +CREATE INDEX t1b on t1(b); +CREATE INDEX t1c on t1(c); +``` +And use a test SQL statement `SELECT * FROM t1 where a < 2 or b > 50`. Currently, TiDB does a table scan and puts `a < 2 or b > 50` as a Selection on top of it. If the selectivity of `a < 2 ` and `b > 50` is low, a better approach would be using indexes on columns `a` and `b` to retrieve rows respectively, and applying a union operation on the result sets. + + +## Proposal +In short, we need to consider access paths using multiple indexes. + +### Planner +We propose to add new `IndexMergeReader / PhysicalIndexMergeReader` and `IndexMergeLookUpReader / PhysicalIndexMergeLookUpReader` operators. + +Now we just consider the following two kinds of queries: + +(1)Conditions in CNF, e.g, `select * from t1 where c1 and c2 and c3 and …` + In this form, each CNF item can be covered by a single index respectively. For example, if we have single column indexes for `t1.a`, `ta.b` and `t1.c` respectively, for SQL `select * from t1 where (a < 10 or a > 100) and b < 10 and c > 1000`, we can use all the three indexes to read the table handles. The result plan for it is like: - -``` -PhysicalIndexMergeLookUpReader(IndexMergeIntersect) - IndexScan(t1a) - IndexScan(t1b) - IndexScan(t1c) + +``` +PhysicalIndexMergeLookUpReader(IndexMergeIntersect) + IndexScan(t1a) + IndexScan(t1b) + IndexScan(t1c) TableScan ``` - + For the CNF items not covered by any index, we take them as table filters and convert them to selection on top of the scan node. For SQL `select * from t1 where (a < 10 or c >100) and b < 10`, only item `b < 10` can be used as index access condition, so we can only consider single index lookup reader. We set up a experiment for the CNF form to compare our demo implement with the master branch. The schema and test sql form we define are following: @@ -64,17 +64,17 @@ We load two million rows into `T200M` with one to two million sequence for all c CNF 200 -**Note:** `SELECTIVITY`is for the single column. - -(2) Conditions in DNF, e.g, `select * from t1 where c1 or c2 or c3 or …` - -In this form, every DNF item must be covered by a single index. If any DNF item cannot be covered by a single index, we cannot choose IndexMerge scan. For example, SQL `select * from t1 where a > 1 or ( b >1 and b <10)` will generate a possible plan like: - -``` -PhysicalIndexMergeLookUpReader(IndexMergeUnion) - IndexScan(t1a) - IndexScan(t1b) - TableScan +**Note:** `SELECTIVITY`is for the single column. + +(2) Conditions in DNF, e.g, `select * from t1 where c1 or c2 or c3 or …` + +In this form, every DNF item must be covered by a single index. If any DNF item cannot be covered by a single index, we cannot choose IndexMerge scan. For example, SQL `select * from t1 where a > 1 or ( b >1 and b <10)` will generate a possible plan like: + +``` +PhysicalIndexMergeLookUpReader(IndexMergeUnion) + IndexScan(t1a) + IndexScan(t1b) + TableScan ``` We set up a experiment for the DNF form to compare our demo implement with the master branch. The schema and test sql form we define are following: @@ -87,235 +87,235 @@ Table Schema: Test SQL Form: SELECT * FROM T200 WHERE a < $1 OR b > $2; -``` +``` We load two million rows into `T200` with one to two million sequence for all columns. We alter the value of `$1` and `$2` in test sql form to obtain the accurate selectivities. The result can be seen in the following graph: DNF 200 - - -We design PhysicalIndexMergeLookUpReader structure as: - -``` -// PhysicalIndexMergeLookUpReader -type PhysicalIndexMergeLookUpReader struct { - physicalSchemaProducer - - //Follow two plans flat to construct executor pb. - IndexPlans []PhysicalPlan - TablePlans []PhysicalPlan - - indexPlans []PhysicalPlan - tablePlan PhysicalPlan - - IndexMergeType int -} -``` - - -- The field `IndexMergeType` indicates the operations on results of multiple index scans, and has the following possible values: - - 0: not an IndexMerge scan; - - 1: intersection operation on result sets, and with a table scan; - - 2: intersection operation on result sets, without the table scan; - - 3: union operation on result sets, must have a table scan; - + + +We design PhysicalIndexMergeLookUpReader structure as: + +``` +// PhysicalIndexMergeLookUpReader +type PhysicalIndexMergeLookUpReader struct { + physicalSchemaProducer + + //Follow two plans flat to construct executor pb. + IndexPlans []PhysicalPlan + TablePlans []PhysicalPlan + + indexPlans []PhysicalPlan + tablePlan PhysicalPlan + + IndexMergeType int +} +``` + + +- The field `IndexMergeType` indicates the operations on results of multiple index scans, and has the following possible values: + - 0: not an IndexMerge scan; + - 1: intersection operation on result sets, and with a table scan; + - 2: intersection operation on result sets, without the table scan; + - 3: union operation on result sets, must have a table scan; + In first version, we just take `PhysicalIndexMergeLookUpReader` and `PhysicalIndexMergeReader` together. - -### IndexMergePath Generate -Now, we first generate all possible IndexMergeOr paths, then generate possible IndexMergeIntersection path. - -``` -type IndexMergePath struct { - IndexPath[] - tableFilters - IndexMergeType -} -``` - - -``` -GetIndexMergeUnionPaths(IndexInfos, PushdownConditions){ - var results = nil - foreach cond in PushdownConditions { - if !isOrCondition(cond) { - continue - } - args = flatten(cond,'or') - foreach arg in args { - var indexAccessPaths, imPaths - // Common index paths would be merged later in `CreateIndexMergeUnionPath` - if isAndCondition(arg) { - andArgs = flatten(arg,'and') - indexAccessPaths = buildAccessPath(andArgs, IndexInfos) - } else { - tempArgs = []{arg} - indexAccessPaths = buildAccessPath(tempArgs, IndexInfos) - } - if indexAccessPaths == nil { - imPaths = nil - break - } - imPartialPath = GetIndexMergePartialPath(IndexInfos, indexAccessPaths) - imPaths = append(imPaths, imPartialPath) - } - if imPaths != nil { - possiblePath = CreateIndexMergeUnionPath(imPaths,PushdownConditions,con,IndexInfos) - results = append(results, possiblePath) - } - } - return results -} - - -buildAccessPath(Conditions, IndexInfos){ - var results - for index in IndexInfos { - res = detachCNFCondAndBuildRangeForIndex(Conditions, index, considerDNF = true) - if res.accessCondition = nil { - continue - } - indexPath = CreateIndexAccessPath(index, res) - results = append(results, indexPath) - } - return results -} - -// This function will get a best indexPath for a con from some alternative paths. -// now we just take the index which has more columns. -// for exmple: -// (1) -// index1(a,b,c) index2(a,b) index3(a) -// condition: a = 1 will choose index1; a = 1 and b = 2 will also choose index1 -// (2) -// index1(a) index2(b) -// condition: a = 1 and b = 1 -// random choose??? -GetIndexMergePartialPath(IndexInfos, indexAccessPaths) { -} - -// (1)maybe we will merge some indexPaths -// for example: index1(a) index2(b) -// condition : a < 1 or a > 2 or b < 1 or b > 10 -// imPaths will be [a<1,a>2,b<1,b>10] and we can merge it and get [a<1 or a >2 , b < 1 or b > 10] -// (2)IndexMergePath.tableFilters: -// <1> remove cond from PushdownConditions and the remain will be added to tableFitler. -// <2> after merge operation, if any indexPath's tableFilter is not nil, we should add it into tableFilters - -CreateIndexMergeUnionPath(imPaths,PushdownConditions,cond,IndexInfos) { -} - -``` - - -``` -GetIndexMergeIntersectionPaths(pushDownConditions, usedConditionsInOr, indexInfos) { - var partialPaths - - if len(pushDownConditions) - len(usedConditionsInOr) < 2 { - return nil - } - tableFilters := append(tableFilters, usedConditionsInOr...) - newConsiderConditions := remove(pushDownConditions, usedConditionsInOr) - for cond in newConsiderConditions { - indexPaths = buildAccessPath([]{cond}, indexInfos) - if indexPaths == nil { - tableFiltes = append(tableFilters,cond) - continue - } - indexPath := GetIndexMergePartialPath(indexPaths,indexInfos) - partialPaths = append(partialPaths, indexPath) - } - if len(partialPaths) < 2 { - return nil - } - return CreateIndexMergeIntersectionPath(partialPaths, tableFilters) -} - -// Now, we just use all path in partialPaths to generate a IndexMergeIntersection. -// We also need to merge possible paths. -// For example: -// index: ix1(a) -// condition: a > 1 and a < 10 -// we will get two partial paths and they all use index ix1. -// IndexMergePath.tableFilters: -// <1> tableFilters -// <2> after merge operation, if any indexPath's tableFilter is not nil, we -// should add indexPath’s tableFilter into IndexMergePath.tableFilters -CreateIndexMergeIntersectionPath(partialPaths, tableFilters) { -} - -``` - -### Executor -Graph bellow illustrates execution of IndexMerge scan. -Execution Model - -Every index plan in `PhysicalIndexMergeLookUpReader` will start an `IndexWorker` to execute the IndexScan plan and send handles to AndOrWorker. AndOrWorker is responsible for doing set operations (and, or) to get final handles. Then `AndOrWoker` sends final handles to `TableWokers` to get rows from TiKV. - -We can take some tricks to make execution at pipeline mode without considering the order. - - -(1) IndexMergeIntersection + +### IndexMergePath Generate +Now, we first generate all possible IndexMergeOr paths, then generate possible IndexMergeIntersection path. + +``` +type IndexMergePath struct { + IndexPath[] + tableFilters + IndexMergeType +} +``` + + +``` +GetIndexMergeUnionPaths(IndexInfos, PushdownConditions){ + var results = nil + foreach cond in PushdownConditions { + if !isOrCondition(cond) { + continue + } + args = flatten(cond,'or') + foreach arg in args { + var indexAccessPaths, imPaths + // Common index paths would be merged later in `CreateIndexMergeUnionPath` + if isAndCondition(arg) { + andArgs = flatten(arg,'and') + indexAccessPaths = buildAccessPath(andArgs, IndexInfos) + } else { + tempArgs = []{arg} + indexAccessPaths = buildAccessPath(tempArgs, IndexInfos) + } + if indexAccessPaths == nil { + imPaths = nil + break + } + imPartialPath = GetIndexMergePartialPath(IndexInfos, indexAccessPaths) + imPaths = append(imPaths, imPartialPath) + } + if imPaths != nil { + possiblePath = CreateIndexMergeUnionPath(imPaths,PushdownConditions,con,IndexInfos) + results = append(results, possiblePath) + } + } + return results +} + + +buildAccessPath(Conditions, IndexInfos){ + var results + for index in IndexInfos { + res = detachCNFCondAndBuildRangeForIndex(Conditions, index, considerDNF = true) + if res.accessCondition = nil { + continue + } + indexPath = CreateIndexAccessPath(index, res) + results = append(results, indexPath) + } + return results +} + +// This function will get a best indexPath for a con from some alternative paths. +// now we just take the index which has more columns. +// for exmple: +// (1) +// index1(a,b,c) index2(a,b) index3(a) +// condition: a = 1 will choose index1; a = 1 and b = 2 will also choose index1 +// (2) +// index1(a) index2(b) +// condition: a = 1 and b = 1 +// random choose??? +GetIndexMergePartialPath(IndexInfos, indexAccessPaths) { +} + +// (1)maybe we will merge some indexPaths +// for example: index1(a) index2(b) +// condition : a < 1 or a > 2 or b < 1 or b > 10 +// imPaths will be [a<1,a>2,b<1,b>10] and we can merge it and get [a<1 or a >2 , b < 1 or b > 10] +// (2)IndexMergePath.tableFilters: +// <1> remove cond from PushdownConditions and the remain will be added to tableFitler. +// <2> after merge operation, if any indexPath's tableFilter is not nil, we should add it into tableFilters + +CreateIndexMergeUnionPath(imPaths,PushdownConditions,cond,IndexInfos) { +} + +``` + + +``` +GetIndexMergeIntersectionPaths(pushDownConditions, usedConditionsInOr, indexInfos) { + var partialPaths + + if len(pushDownConditions) - len(usedConditionsInOr) < 2 { + return nil + } + tableFilters := append(tableFilters, usedConditionsInOr...) + newConsiderConditions := remove(pushDownConditions, usedConditionsInOr) + for cond in newConsiderConditions { + indexPaths = buildAccessPath([]{cond}, indexInfos) + if indexPaths == nil { + tableFiltes = append(tableFilters,cond) + continue + } + indexPath := GetIndexMergePartialPath(indexPaths,indexInfos) + partialPaths = append(partialPaths, indexPath) + } + if len(partialPaths) < 2 { + return nil + } + return CreateIndexMergeIntersectionPath(partialPaths, tableFilters) +} + +// Now, we just use all path in partialPaths to generate a IndexMergeIntersection. +// We also need to merge possible paths. +// For example: +// index: ix1(a) +// condition: a > 1 and a < 10 +// we will get two partial paths and they all use index ix1. +// IndexMergePath.tableFilters: +// <1> tableFilters +// <2> after merge operation, if any indexPath's tableFilter is not nil, we +// should add indexPath’s tableFilter into IndexMergePath.tableFilters +CreateIndexMergeIntersectionPath(partialPaths, tableFilters) { +} + +``` + +### Executor +Graph bellow illustrates execution of IndexMerge scan. +Execution Model + +Every index plan in `PhysicalIndexMergeLookUpReader` will start an `IndexWorker` to execute the IndexScan plan and send handles to AndOrWorker. AndOrWorker is responsible for doing set operations (and, or) to get final handles. Then `AndOrWoker` sends final handles to `TableWokers` to get rows from TiKV. + +We can take some tricks to make execution at pipeline mode without considering the order. + + +(1) IndexMergeIntersection - IndexMergeIntersection - We take an example to explain it. Use set1 to record rowids which are returned by ix1 but not sent to tableWorker. Use set2 to record the same thing for ix2. -If new rowid comes from ix1, first we check if it is in set2. If so, we delete it from set2 and send it to tableWorker. Otherwise, we add it into set1. For the above figure, we use the following table to show the processing. - -| new rowid | set1 | set2 | sent to TableWorker | -| :------:| :------: | :------: | :------: | -| 2(ix1) | [2] | [ ] | [ ] | -| 1(ix2) | [2] | [1] | [ ] | -| 5(ix1) | [2,5] |[1] | [ ] | -| 5(ix2) | [2] | [1] | [5] | -| 7(ix1) | [2,7] |[1] | [ ] | -| 2(ix2) | [7] | [1] | [2] | - - - -(2) IndexMergeUnion - - - We take a structure(we call it set) to record which rowids are accessed. If a new rowid returned by IndexScan, check if it is in set. If in it, we just skip it. Otherwise, we add it into set and send it to tableWorker. - -### Cost Model -Cost model will consider three factors: IO, CPU, and Network. - -- `IndexMergeType` = 1 - - IO Cost = (totalRowCount + mergedRowCount) * scanFactor - - Network Cost = (totalRowCount + mergedRowCount) * networkFactor - - Cpu Memory Cost = totalRowCount * cpuFactor + totalRowCount * memoryFactor - -- `IndexMergeType` = 2 - - IO Cost = (totalRowCount) * scanFactor - - Network Cost = totalRowCount * networkFactor - - Cpu Memory Cost = totalRowCount * cpuFactor + totalRowCount * memoryFactor - -- `IndexMergeType` = 3 - - IO Cost = (totalRowCount + mergedRowCount) * scanFactor - - Network Cost = (totalRowCount + mergedRowCount) * networkFactor - - Cpu Memory Cost = totalRowCount * cpuFactor + mergedRowCount * memoryFactor - - -**Note**: - -- totalRowCount: sum of handles collected from every index scan. - -- mergedRowCount: number of handles after set operating. - -## Compatibility -This proposal has no effect on the compatibility. - -## Implementation -1. Implement planner operators -1. Enhance `explain` to display the plan -3. Implement executor operators -4. Testing + IndexMergeIntersection + We take an example to explain it. Use set1 to record rowids which are returned by ix1 but not sent to tableWorker. Use set2 to record the same thing for ix2. +If new rowid comes from ix1, first we check if it is in set2. If so, we delete it from set2 and send it to tableWorker. Otherwise, we add it into set1. For the above figure, we use the following table to show the processing. + +| new rowid | set1 | set2 | sent to TableWorker | +| :------:| :------: | :------: | :------: | +| 2(ix1) | [2] | [ ] | [ ] | +| 1(ix2) | [2] | [1] | [ ] | +| 5(ix1) | [2,5] |[1] | [ ] | +| 5(ix2) | [2] | [1] | [5] | +| 7(ix1) | [2,7] |[1] | [ ] | +| 2(ix2) | [7] | [1] | [2] | + + + +(2) IndexMergeUnion + + + We take a structure(we call it set) to record which rowids are accessed. If a new rowid returned by IndexScan, check if it is in set. If in it, we just skip it. Otherwise, we add it into set and send it to tableWorker. + +### Cost Model +Cost model will consider three factors: IO, CPU, and Network. + +- `IndexMergeType` = 1 + + IO Cost = (totalRowCount + mergedRowCount) * scanFactor + + Network Cost = (totalRowCount + mergedRowCount) * networkFactor + + Cpu Memory Cost = totalRowCount * cpuFactor + totalRowCount * memoryFactor + +- `IndexMergeType` = 2 + + IO Cost = (totalRowCount) * scanFactor + + Network Cost = totalRowCount * networkFactor + + Cpu Memory Cost = totalRowCount * cpuFactor + totalRowCount * memoryFactor + +- `IndexMergeType` = 3 + + IO Cost = (totalRowCount + mergedRowCount) * scanFactor + + Network Cost = (totalRowCount + mergedRowCount) * networkFactor + + Cpu Memory Cost = totalRowCount * cpuFactor + mergedRowCount * memoryFactor + + +**Note**: + +- totalRowCount: sum of handles collected from every index scan. + +- mergedRowCount: number of handles after set operating. + +## Compatibility +This proposal has no effect on the compatibility. + +## Implementation +1. Implement planner operators +1. Enhance `explain` to display the plan +3. Implement executor operators +4. Testing diff --git a/domain/domain.go b/domain/domain.go index 8708162e6653d..a06d83aeaa5c6 100644 --- a/domain/domain.go +++ b/domain/domain.go @@ -44,6 +44,7 @@ import ( "github.com/pingcap/tidb/statistics/handle" "github.com/pingcap/tidb/store/tikv" "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/expensivequery" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/sqlexec" "go.uber.org/zap" @@ -54,23 +55,24 @@ import ( // Domain represents a storage space. Different domains can use the same database name. // Multiple domains can be used in parallel without synchronization. type Domain struct { - store kv.Storage - infoHandle *infoschema.Handle - privHandle *privileges.Handle - bindHandle *bindinfo.BindHandle - statsHandle unsafe.Pointer - statsLease time.Duration - statsUpdating sync2.AtomicInt32 - ddl ddl.DDL - info *InfoSyncer - m sync.Mutex - SchemaValidator SchemaValidator - sysSessionPool *sessionPool - exit chan struct{} - etcdClient *clientv3.Client - wg sync.WaitGroup - gvc GlobalVariableCache - slowQuery *topNSlowQueries + store kv.Storage + infoHandle *infoschema.Handle + privHandle *privileges.Handle + bindHandle *bindinfo.BindHandle + statsHandle unsafe.Pointer + statsLease time.Duration + statsUpdating sync2.AtomicInt32 + ddl ddl.DDL + info *InfoSyncer + m sync.Mutex + SchemaValidator SchemaValidator + sysSessionPool *sessionPool + exit chan struct{} + etcdClient *clientv3.Client + wg sync.WaitGroup + gvc GlobalVariableCache + slowQuery *topNSlowQueries + expensiveQueryHandle *expensivequery.Handle } // loadInfoSchema loads infoschema at startTS into handle, usedSchemaVersion is the currently used @@ -535,6 +537,7 @@ func (do *Domain) mustReload() (exitLoop bool) { // Close closes the Domain and release its resource. func (do *Domain) Close() { + startTime := time.Now() if do.ddl != nil { terror.Log(do.ddl.Stop()) } @@ -548,7 +551,7 @@ func (do *Domain) Close() { do.sysSessionPool.Close() do.slowQuery.Close() do.wg.Wait() - logutil.Logger(context.Background()).Info("domain closed") + logutil.Logger(context.Background()).Info("domain closed", zap.Duration("take time", time.Since(startTime))) } type ddlCallback struct { @@ -793,7 +796,7 @@ func (do *Domain) LoadBindInfoLoop(ctx sessionctx.Context) error { ctx.GetSessionVars().InRestrictedSQL = true do.bindHandle = bindinfo.NewBindHandle(ctx) err := do.bindHandle.Update(true) - if err != nil { + if err != nil || bindinfo.Lease == 0 { return err } @@ -803,7 +806,6 @@ func (do *Domain) LoadBindInfoLoop(ctx sessionctx.Context) error { } func (do *Domain) loadBindInfoLoop() { - duration := 3 * time.Second do.wg.Add(1) go func() { defer do.wg.Done() @@ -812,7 +814,7 @@ func (do *Domain) loadBindInfoLoop() { select { case <-do.exit: return - case <-time.After(duration): + case <-time.After(bindinfo.Lease): } err := do.bindHandle.Update(false) if err != nil { @@ -823,7 +825,6 @@ func (do *Domain) loadBindInfoLoop() { } func (do *Domain) handleInvalidBindTaskLoop() { - handleInvalidTaskDuration := 3 * time.Second do.wg.Add(1) go func() { defer do.wg.Done() @@ -832,7 +833,7 @@ func (do *Domain) handleInvalidBindTaskLoop() { select { case <-do.exit: return - case <-time.After(handleInvalidTaskDuration): + case <-time.After(bindinfo.Lease): } do.bindHandle.DropInvalidBindRecord() } @@ -874,6 +875,11 @@ func (do *Domain) UpdateTableStatsLoop(ctx sessionctx.Context) error { statsHandle := handle.NewHandle(ctx, do.statsLease) atomic.StorePointer(&do.statsHandle, unsafe.Pointer(statsHandle)) do.ddl.RegisterEventCh(statsHandle.DDLEventCh()) + // Negative stats lease indicates that it is in test, it does not need update. + if do.statsLease >= 0 { + do.wg.Add(1) + go do.loadStatsWorker() + } if do.statsLease <= 0 { return nil } @@ -905,22 +911,15 @@ func (do *Domain) newStatsOwner() owner.Manager { return statsOwner } -func (do *Domain) updateStatsWorker(ctx sessionctx.Context, owner owner.Manager) { - defer recoverInDomain("updateStatsWorker", false) +func (do *Domain) loadStatsWorker() { + defer recoverInDomain("loadStatsWorker", false) + defer do.wg.Done() lease := do.statsLease - deltaUpdateDuration := lease * 20 + if lease == 0 { + lease = 3 * time.Second + } loadTicker := time.NewTicker(lease) defer loadTicker.Stop() - deltaUpdateTicker := time.NewTicker(deltaUpdateDuration) - defer deltaUpdateTicker.Stop() - loadHistogramTicker := time.NewTicker(lease) - defer loadHistogramTicker.Stop() - gcStatsTicker := time.NewTicker(100 * lease) - defer gcStatsTicker.Stop() - dumpFeedbackTicker := time.NewTicker(200 * lease) - defer dumpFeedbackTicker.Stop() - loadFeedbackTicker := time.NewTicker(5 * lease) - defer loadFeedbackTicker.Stop() statsHandle := do.StatsHandle() t := time.Now() err := statsHandle.InitStats(do.InfoSchema()) @@ -929,10 +928,6 @@ func (do *Domain) updateStatsWorker(ctx sessionctx.Context, owner owner.Manager) } else { logutil.Logger(context.Background()).Info("init stats info time", zap.Duration("take time", time.Since(t))) } - defer func() { - do.SetStatsUpdating(false) - do.wg.Done() - }() for { select { case <-loadTicker.C: @@ -940,37 +935,60 @@ func (do *Domain) updateStatsWorker(ctx sessionctx.Context, owner owner.Manager) if err != nil { logutil.Logger(context.Background()).Debug("update stats info failed", zap.Error(err)) } + err = statsHandle.LoadNeededHistograms() + if err != nil { + logutil.Logger(context.Background()).Debug("load histograms failed", zap.Error(err)) + } + case <-do.exit: + return + } + } +} + +func (do *Domain) updateStatsWorker(ctx sessionctx.Context, owner owner.Manager) { + defer recoverInDomain("updateStatsWorker", false) + lease := do.statsLease + deltaUpdateTicker := time.NewTicker(20 * lease) + defer deltaUpdateTicker.Stop() + gcStatsTicker := time.NewTicker(100 * lease) + defer gcStatsTicker.Stop() + dumpFeedbackTicker := time.NewTicker(200 * lease) + defer dumpFeedbackTicker.Stop() + loadFeedbackTicker := time.NewTicker(5 * lease) + defer loadFeedbackTicker.Stop() + statsHandle := do.StatsHandle() + defer func() { + do.SetStatsUpdating(false) + do.wg.Done() + }() + for { + select { case <-do.exit: statsHandle.FlushStats() return // This channel is sent only by ddl owner. case t := <-statsHandle.DDLEventCh(): - err = statsHandle.HandleDDLEvent(t) + err := statsHandle.HandleDDLEvent(t) if err != nil { logutil.Logger(context.Background()).Debug("handle ddl event failed", zap.Error(err)) } case <-deltaUpdateTicker.C: - err = statsHandle.DumpStatsDeltaToKV(handle.DumpDelta) + err := statsHandle.DumpStatsDeltaToKV(handle.DumpDelta) if err != nil { logutil.Logger(context.Background()).Debug("dump stats delta failed", zap.Error(err)) } statsHandle.UpdateErrorRate(do.InfoSchema()) - case <-loadHistogramTicker.C: - err = statsHandle.LoadNeededHistograms() - if err != nil { - logutil.Logger(context.Background()).Debug("load histograms failed", zap.Error(err)) - } case <-loadFeedbackTicker.C: statsHandle.UpdateStatsByLocalFeedback(do.InfoSchema()) if !owner.IsOwner() { continue } - err = statsHandle.HandleUpdateStats(do.InfoSchema()) + err := statsHandle.HandleUpdateStats(do.InfoSchema()) if err != nil { logutil.Logger(context.Background()).Debug("update stats using feedback failed", zap.Error(err)) } case <-dumpFeedbackTicker.C: - err = statsHandle.DumpStatsFeedbackToKV() + err := statsHandle.DumpStatsFeedbackToKV() if err != nil { logutil.Logger(context.Background()).Debug("dump stats feedback failed", zap.Error(err)) } @@ -978,7 +996,7 @@ func (do *Domain) updateStatsWorker(ctx sessionctx.Context, owner owner.Manager) if !owner.IsOwner() { continue } - err = statsHandle.GCStats(do.InfoSchema(), do.DDL().GetLease()) + err := statsHandle.GCStats(do.InfoSchema(), do.DDL().GetLease()) if err != nil { logutil.Logger(context.Background()).Debug("GC stats failed", zap.Error(err)) } @@ -1006,6 +1024,16 @@ func (do *Domain) autoAnalyzeWorker(owner owner.Manager) { } } +// ExpensiveQueryHandle returns the expensive query handle. +func (do *Domain) ExpensiveQueryHandle() *expensivequery.Handle { + return do.expensiveQueryHandle +} + +// InitExpensiveQueryHandle init the expensive query handler. +func (do *Domain) InitExpensiveQueryHandle() { + do.expensiveQueryHandle = expensivequery.NewExpensiveQueryHandle(do.exit) +} + const privilegeKey = "/tidb/privilege" // NotifyUpdatePrivilege updates privilege key in etcd, TiDB client that watches diff --git a/domain/global_vars_cache.go b/domain/global_vars_cache.go index 89cc772f8fec7..93218b7546ce9 100644 --- a/domain/global_vars_cache.go +++ b/domain/global_vars_cache.go @@ -18,7 +18,9 @@ import ( "time" "github.com/pingcap/parser/ast" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/stmtsummary" ) // GlobalVariableCache caches global variables. @@ -41,6 +43,8 @@ func (gvc *GlobalVariableCache) Update(rows []chunk.Row, fields []*ast.ResultFie gvc.rows = rows gvc.fields = fields gvc.Unlock() + + checkEnableStmtSummary(rows, fields) } // Get gets the global variables from cache. @@ -63,6 +67,28 @@ func (gvc *GlobalVariableCache) Disable() { return } +// checkEnableStmtSummary looks for TiDBEnableStmtSummary and notifies StmtSummary +func checkEnableStmtSummary(rows []chunk.Row, fields []*ast.ResultField) { + for _, row := range rows { + varName := row.GetString(0) + if varName == variable.TiDBEnableStmtSummary { + varVal := row.GetDatum(1, &fields[1].Column.FieldType) + + sVal := "" + if !varVal.IsNull() { + var err error + sVal, err = varVal.ToString() + if err != nil { + return + } + } + + stmtsummary.StmtSummaryByDigestMap.SetEnabled(sVal, false) + break + } + } +} + // GetGlobalVarsCache gets the global variable cache. func (do *Domain) GetGlobalVarsCache() *GlobalVariableCache { return &do.gvc diff --git a/domain/global_vars_cache_test.go b/domain/global_vars_cache_test.go index 11cb0c95a32f1..f3e3d7d654bae 100644 --- a/domain/global_vars_cache_test.go +++ b/domain/global_vars_cache_test.go @@ -21,9 +21,11 @@ import ( "github.com/pingcap/parser/charset" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/stmtsummary" "github.com/pingcap/tidb/util/testleak" ) @@ -96,3 +98,47 @@ func getResultField(colName string, id, offset int) *ast.ResultField { DBName: model.NewCIStr("test"), } } + +func (gvcSuite *testGVCSuite) TestCheckEnableStmtSummary(c *C) { + defer testleak.AfterTest(c)() + testleak.BeforeTest() + + store, err := mockstore.NewMockTikvStore() + c.Assert(err, IsNil) + defer store.Close() + ddlLease := 50 * time.Millisecond + dom := NewDomain(store, ddlLease, 0, mockFactory) + err = dom.Init(ddlLease, sysMockFactory) + c.Assert(err, IsNil) + defer dom.Close() + + gvc := dom.GetGlobalVarsCache() + + rf := getResultField("c", 1, 0) + rf1 := getResultField("c1", 2, 1) + ft := &types.FieldType{ + Tp: mysql.TypeString, + Charset: charset.CharsetBin, + Collate: charset.CollationBin, + } + ft1 := &types.FieldType{ + Tp: mysql.TypeString, + Charset: charset.CharsetBin, + Collate: charset.CollationBin, + } + + stmtsummary.StmtSummaryByDigestMap.SetEnabled("0", false) + ck := chunk.NewChunkWithCapacity([]*types.FieldType{ft, ft1}, 1024) + ck.AppendString(0, variable.TiDBEnableStmtSummary) + ck.AppendString(1, "1") + row := ck.GetRow(0) + gvc.Update([]chunk.Row{row}, []*ast.ResultField{rf, rf1}) + c.Assert(stmtsummary.StmtSummaryByDigestMap.Enabled(), Equals, true) + + ck = chunk.NewChunkWithCapacity([]*types.FieldType{ft, ft1}, 1024) + ck.AppendString(0, variable.TiDBEnableStmtSummary) + ck.AppendString(1, "0") + row = ck.GetRow(0) + gvc.Update([]chunk.Row{row}, []*ast.ResultField{rf, rf1}) + c.Assert(stmtsummary.StmtSummaryByDigestMap.Enabled(), Equals, false) +} diff --git a/domain/info.go b/domain/info.go index 49d614c295ab3..0831c267f8a7f 100644 --- a/domain/info.go +++ b/domain/info.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/config" - "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/owner" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" @@ -129,7 +129,7 @@ func (is *InfoSyncer) storeServerInfo(ctx context.Context) error { return errors.Trace(err) } str := string(hack.String(infoBuf)) - err = ddl.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, is.serverInfoPath, str, clientv3.WithLease(is.session.Lease())) + err = util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, is.serverInfoPath, str, clientv3.WithLease(is.session.Lease())) return err } @@ -138,7 +138,7 @@ func (is *InfoSyncer) RemoveServerInfo() { if is.etcdCli == nil { return } - err := ddl.DeleteKeyFromEtcd(is.serverInfoPath, is.etcdCli, keyOpDefaultRetryCnt, keyOpDefaultTimeout) + err := util.DeleteKeyFromEtcd(is.serverInfoPath, is.etcdCli, keyOpDefaultRetryCnt, keyOpDefaultTimeout) if err != nil { logutil.Logger(context.Background()).Error("remove server info failed", zap.Error(err)) } diff --git a/domain/schema_validator.go b/domain/schema_validator.go index 53a9bf4977857..17c6763d28f1d 100644 --- a/domain/schema_validator.go +++ b/domain/schema_validator.go @@ -18,6 +18,8 @@ import ( "sync" "time" + "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" @@ -73,7 +75,7 @@ func NewSchemaValidator(lease time.Duration) SchemaValidator { return &schemaValidator{ isStarted: true, lease: lease, - deltaSchemaInfos: make([]deltaSchemaInfo, 0, maxNumberOfDiffsToLoad), + deltaSchemaInfos: make([]deltaSchemaInfo, 0, variable.DefTiDBMaxDeltaSchemaCount), } } @@ -86,26 +88,29 @@ func (s *schemaValidator) IsStarted() bool { func (s *schemaValidator) Stop() { logutil.Logger(context.Background()).Info("the schema validator stops") + metrics.LoadSchemaCounter.WithLabelValues(metrics.SchemaValidatorStop).Inc() s.mux.Lock() defer s.mux.Unlock() s.isStarted = false s.latestSchemaVer = 0 - s.deltaSchemaInfos = make([]deltaSchemaInfo, 0, maxNumberOfDiffsToLoad) + s.deltaSchemaInfos = s.deltaSchemaInfos[:0] } func (s *schemaValidator) Restart() { logutil.Logger(context.Background()).Info("the schema validator restarts") + metrics.LoadSchemaCounter.WithLabelValues(metrics.SchemaValidatorRestart).Inc() s.mux.Lock() defer s.mux.Unlock() s.isStarted = true } func (s *schemaValidator) Reset() { + metrics.LoadSchemaCounter.WithLabelValues(metrics.SchemaValidatorReset).Inc() s.mux.Lock() defer s.mux.Unlock() s.isStarted = true s.latestSchemaVer = 0 - s.deltaSchemaInfos = make([]deltaSchemaInfo, 0, maxNumberOfDiffsToLoad) + s.deltaSchemaInfos = s.deltaSchemaInfos[:0] } func (s *schemaValidator) Update(leaseGrantTS uint64, oldVer, currVer int64, changedTableIDs []int64) { @@ -147,13 +152,17 @@ func hasRelatedTableID(relatedTableIDs, updateTableIDs []int64) bool { // NOTE, this function should be called under lock! func (s *schemaValidator) isRelatedTablesChanged(currVer int64, tableIDs []int64) bool { if len(s.deltaSchemaInfos) == 0 { + metrics.LoadSchemaCounter.WithLabelValues(metrics.SchemaValidatorCacheEmpty).Inc() logutil.Logger(context.Background()).Info("schema change history is empty", zap.Int64("currVer", currVer)) return true } newerDeltas := s.findNewerDeltas(currVer) if len(newerDeltas) == len(s.deltaSchemaInfos) { - logutil.Logger(context.Background()).Info("the schema version is much older than the latest version", zap.Int64("currVer", currVer), - zap.Int64("latestSchemaVer", s.latestSchemaVer)) + metrics.LoadSchemaCounter.WithLabelValues(metrics.SchemaValidatorCacheMiss).Inc() + logutil.Logger(context.Background()).Info("the schema version is much older than the latest version", + zap.Int64("currVer", currVer), + zap.Int64("latestSchemaVer", s.latestSchemaVer), + zap.Reflect("deltas", newerDeltas)) return true } for _, item := range newerDeltas { @@ -209,8 +218,54 @@ func (s *schemaValidator) Check(txnTS uint64, schemaVer int64, relatedTableIDs [ } func (s *schemaValidator) enqueue(schemaVersion int64, relatedTableIDs []int64) { - s.deltaSchemaInfos = append(s.deltaSchemaInfos, deltaSchemaInfo{schemaVersion, relatedTableIDs}) - if len(s.deltaSchemaInfos) > maxNumberOfDiffsToLoad { + maxCnt := int(variable.GetMaxDeltaSchemaCount()) + if maxCnt <= 0 { + logutil.Logger(context.Background()).Info("the schema validator enqueue", zap.Int("delta max count", maxCnt)) + return + } + + delta := deltaSchemaInfo{schemaVersion, relatedTableIDs} + if len(s.deltaSchemaInfos) == 0 { + s.deltaSchemaInfos = append(s.deltaSchemaInfos, delta) + return + } + + lastOffset := len(s.deltaSchemaInfos) - 1 + // The first item we needn't to merge, because we hope to cover more versions. + if lastOffset != 0 && ids(s.deltaSchemaInfos[lastOffset].relatedTableIDs).containIn(delta.relatedTableIDs) { + s.deltaSchemaInfos[lastOffset] = delta + } else { + s.deltaSchemaInfos = append(s.deltaSchemaInfos, delta) + } + + if len(s.deltaSchemaInfos) > maxCnt { + logutil.Logger(context.Background()).Info("the schema validator enqueue, queue is too long", + zap.Int("delta max count", maxCnt), zap.Int64("remove schema version", s.deltaSchemaInfos[0].schemaVersion)) s.deltaSchemaInfos = s.deltaSchemaInfos[1:] } } + +type ids []int64 + +// containIn is checks if a is included in b. +func (a ids) containIn(b []int64) bool { + if len(a) > len(b) { + return false + } + + var isEqual bool + for _, i := range a { + isEqual = false + for _, j := range b { + if i == j { + isEqual = true + break + } + } + if !isEqual { + return false + } + } + + return true +} diff --git a/domain/schema_validator_test.go b/domain/schema_validator_test.go index d5e2193b25b39..46332b1742cb6 100644 --- a/domain/schema_validator_test.go +++ b/domain/schema_validator_test.go @@ -19,6 +19,7 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/util/testleak" ) @@ -141,3 +142,61 @@ func serverFunc(lease time.Duration, requireLease chan leaseGrantItem, oracleCh } } } + +func (*testSuite) TestEnqueue(c *C) { + lease := 10 * time.Millisecond + originalCnt := variable.GetMaxDeltaSchemaCount() + defer variable.SetMaxDeltaSchemaCount(originalCnt) + + validator := NewSchemaValidator(lease).(*schemaValidator) + c.Assert(validator.IsStarted(), IsTrue) + // maxCnt is 0. + variable.SetMaxDeltaSchemaCount(0) + validator.enqueue(1, []int64{11}) + c.Assert(validator.deltaSchemaInfos, HasLen, 0) + + // maxCnt is 10. + variable.SetMaxDeltaSchemaCount(10) + ds := []deltaSchemaInfo{ + {0, []int64{1}}, + {1, []int64{1}}, + {2, []int64{1}}, + {3, []int64{2, 2}}, + {4, []int64{2}}, + {5, []int64{1, 4}}, + {6, []int64{1, 4}}, + {7, []int64{3, 1, 3}}, + {8, []int64{1, 2, 3}}, + {9, []int64{1, 2, 3}}, + } + for _, d := range ds { + validator.enqueue(d.schemaVersion, d.relatedTableIDs) + } + validator.enqueue(10, []int64{1}) + ret := []deltaSchemaInfo{ + {0, []int64{1}}, + {2, []int64{1}}, + {3, []int64{2, 2}}, + {4, []int64{2}}, + {6, []int64{1, 4}}, + {9, []int64{1, 2, 3}}, + {10, []int64{1}}, + } + c.Assert(validator.deltaSchemaInfos, DeepEquals, ret) + // The Items' relatedTableIDs have different order. + validator.enqueue(11, []int64{1, 2, 3, 4}) + validator.enqueue(12, []int64{4, 1, 2, 3, 1}) + validator.enqueue(13, []int64{4, 1, 3, 2, 5}) + ret[len(ret)-1] = deltaSchemaInfo{13, []int64{4, 1, 3, 2, 5}} + c.Assert(validator.deltaSchemaInfos, DeepEquals, ret) + // The length of deltaSchemaInfos is greater then maxCnt. + validator.enqueue(14, []int64{1}) + validator.enqueue(15, []int64{2}) + validator.enqueue(16, []int64{3}) + validator.enqueue(17, []int64{4}) + ret = append(ret, deltaSchemaInfo{14, []int64{1}}) + ret = append(ret, deltaSchemaInfo{15, []int64{2}}) + ret = append(ret, deltaSchemaInfo{16, []int64{3}}) + ret = append(ret, deltaSchemaInfo{17, []int64{4}}) + c.Assert(validator.deltaSchemaInfos, DeepEquals, ret[1:]) +} diff --git a/domain/topn_slow_query.go b/domain/topn_slow_query.go index ce1da7b885ede..bc9e2d84f4030 100644 --- a/domain/topn_slow_query.go +++ b/domain/topn_slow_query.go @@ -213,17 +213,17 @@ func (q *topNSlowQueries) Close() { // SlowQueryInfo is a struct to record slow query info. type SlowQueryInfo struct { - SQL string - Start time.Time - Duration time.Duration - Detail execdetails.ExecDetails - Succ bool - ConnID uint64 - TxnTS uint64 - User string - DB string - TableIDs string - IndexIDs string - Internal bool - Digest string + SQL string + Start time.Time + Duration time.Duration + Detail execdetails.ExecDetails + Succ bool + ConnID uint64 + TxnTS uint64 + User string + DB string + TableIDs string + IndexNames string + Internal bool + Digest string } diff --git a/executor/adapter.go b/executor/adapter.go index e172812bdda11..deb91247c48b2 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -40,16 +40,21 @@ import ( "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/store/tikv" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/sqlexec" + "github.com/pingcap/tidb/util/stmtsummary" + "github.com/pingcap/tidb/util/stringutil" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) // processinfoSetter is the interface use to set current running process info. type processinfoSetter interface { - SetProcessInfo(string, time.Time, byte) + SetProcessInfo(string, time.Time, byte, uint64) } // recordSet wraps an executor, implements sqlexec.RecordSet interface @@ -99,13 +104,13 @@ func schema2ResultFields(schema *expression.Schema, defaultDB string) (rfs []*as // The reason we need update is that chunk with 0 rows indicating we already finished current query, we need prepare for // next query. // If stmt is not nil and chunk with some rows inside, we simply update last query found rows by the number of row in chunk. -func (a *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (a *recordSet) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("recordSet.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } - err := a.executor.Next(ctx, req) + err := Next(ctx, a.executor, req) if err != nil { a.lastErr = err return err @@ -123,26 +128,35 @@ func (a *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error { return nil } -// NewRecordBatch create a recordBatch base on top-level executor's newFirstChunk(). -func (a *recordSet) NewRecordBatch() *chunk.RecordBatch { - return chunk.NewRecordBatch(a.executor.newFirstChunk()) +// NewChunk create a chunk base on top-level executor's newFirstChunk(). +func (a *recordSet) NewChunk() *chunk.Chunk { + return newFirstChunk(a.executor) } func (a *recordSet) Close() error { err := a.executor.Close() - a.stmt.LogSlowQuery(a.txnStartTS, a.lastErr == nil) + a.stmt.LogSlowQuery(a.txnStartTS, a.lastErr == nil, false) + sessVars := a.stmt.Ctx.GetSessionVars() + pps := types.CloneRow(sessVars.PreparedParams) + sessVars.PrevStmt = FormatSQL(a.stmt.OriginText(), pps) a.stmt.logAudit() + a.stmt.SummaryStmt() return err } +// OnFetchReturned implements commandLifeCycle#OnFetchReturned +func (a *recordSet) OnFetchReturned() { + a.stmt.LogSlowQuery(a.txnStartTS, a.lastErr == nil, true) +} + // ExecStmt implements the sqlexec.Statement interface, it builds a planner.Plan to an sqlexec.Statement. type ExecStmt struct { // InfoSchema stores a reference to the schema information. InfoSchema infoschema.InfoSchema // Plan stores a reference to the final physical plan. Plan plannercore.Plan - // Expensive represents whether this query is an expensive one. - Expensive bool + // LowerPriority represents whether to lower the execution priority of a query. + LowerPriority bool // Cacheable represents whether the physical plan can be cached. Cacheable bool // Text represents the origin query text. @@ -150,9 +164,7 @@ type ExecStmt struct { StmtNode ast.StmtNode - Ctx sessionctx.Context - // StartTime stands for the starting time when executing the statement. - StartTime time.Time + Ctx sessionctx.Context isPreparedStmt bool isSelectForUpdate bool retryCount uint @@ -185,13 +197,18 @@ func (a *ExecStmt) IsReadOnly(vars *variable.SessionVars) bool { // RebuildPlan rebuilds current execute statement plan. // It returns the current information schema version that 'a' is using. -func (a *ExecStmt) RebuildPlan() (int64, error) { +func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) { + startTime := time.Now() + defer func() { + a.Ctx.GetSessionVars().DurationCompile = time.Since(startTime) + }() + is := GetInfoSchema(a.Ctx) a.InfoSchema = is if err := plannercore.Preprocess(a.Ctx, a.StmtNode, is, plannercore.InTxnRetry); err != nil { return 0, err } - p, err := planner.Optimize(a.Ctx, a.StmtNode, is) + p, err := planner.Optimize(ctx, a.Ctx, a.StmtNode, is) if err != nil { return 0, err } @@ -202,8 +219,19 @@ func (a *ExecStmt) RebuildPlan() (int64, error) { // Exec builds an Executor from a plan. If the Executor doesn't return result, // like the INSERT, UPDATE statements, it executes in this function, if the Executor returns // result, execution is done after this function returns, in the returned sqlexec.RecordSet Next method. -func (a *ExecStmt) Exec(ctx context.Context) (sqlexec.RecordSet, error) { - a.StartTime = time.Now() +func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) { + defer func() { + r := recover() + if r == nil { + return + } + if str, ok := r.(string); !ok || !strings.HasPrefix(str, memory.PanicMemoryExceed) { + panic(r) + } + err = errors.Errorf("%v", r) + logutil.Logger(ctx).Error("execute sql panic", zap.String("sql", a.Text), zap.Stack("stack")) + }() + sctx := a.Ctx if _, ok := a.Plan.(*plannercore.Analyze); ok && sctx.GetSessionVars().InRestrictedSQL { oriStats, _ := sctx.GetSessionVars().GetSystemVar(variable.TiDBBuildStatsConcurrency) @@ -244,9 +272,12 @@ func (a *ExecStmt) Exec(ctx context.Context) (sqlexec.RecordSet, error) { sql = ss.SecureText() } } + maxExecutionTime := getMaxExecutionTime(sctx, a.StmtNode) // Update processinfo, ShowProcess() will use it. - pi.SetProcessInfo(sql, time.Now(), cmd) - a.Ctx.GetSessionVars().StmtCtx.StmtType = GetStmtLabel(a.StmtNode) + pi.SetProcessInfo(sql, time.Now(), cmd, maxExecutionTime) + if a.Ctx.GetSessionVars().StmtCtx.StmtType == "" { + a.Ctx.GetSessionVars().StmtCtx.StmtType = GetStmtLabel(a.StmtNode) + } } isPessimistic := sctx.GetSessionVars().TxnCtx.IsPessimistic @@ -284,6 +315,20 @@ func (a *ExecStmt) Exec(ctx context.Context) (sqlexec.RecordSet, error) { }, nil } +// getMaxExecutionTime get the max execution timeout value. +func getMaxExecutionTime(sctx sessionctx.Context, stmtNode ast.StmtNode) uint64 { + ret := sctx.GetSessionVars().MaxExecutionTime + if sel, ok := stmtNode.(*ast.SelectStmt); ok { + for _, hint := range sel.TableHints { + if hint.HintName.L == variable.MaxExecutionTime { + ret = hint.MaxExecutionTime + break + } + } + } + return ret +} + type chunkRowRecordSet struct { rows []chunk.Row idx int @@ -295,8 +340,7 @@ func (c *chunkRowRecordSet) Fields() []*ast.ResultField { return c.fields } -func (c *chunkRowRecordSet) Next(ctx context.Context, req *chunk.RecordBatch) error { - chk := req.Chunk +func (c *chunkRowRecordSet) Next(ctx context.Context, chk *chunk.Chunk) error { chk.Reset() for !chk.IsFull() && c.idx < len(c.rows) { chk.AppendRow(c.rows[c.idx]) @@ -305,8 +349,8 @@ func (c *chunkRowRecordSet) Next(ctx context.Context, req *chunk.RecordBatch) er return nil } -func (c *chunkRowRecordSet) NewRecordBatch() *chunk.RecordBatch { - return chunk.NewRecordBatch(c.e.newFirstChunk()) +func (c *chunkRowRecordSet) NewChunk() *chunk.Chunk { + return newFirstChunk(c.e) } func (c *chunkRowRecordSet) Close() error { @@ -338,7 +382,7 @@ func (a *ExecStmt) runPessimisticSelectForUpdate(ctx context.Context, e Executor var rows []chunk.Row var err error fields := rs.Fields() - req := rs.NewRecordBatch() + req := rs.NewChunk() for { err = rs.Next(ctx, req) if err != nil { @@ -348,11 +392,11 @@ func (a *ExecStmt) runPessimisticSelectForUpdate(ctx context.Context, e Executor if req.NumRows() == 0 { return &chunkRowRecordSet{rows: rows, fields: fields, e: e}, nil } - iter := chunk.NewIterator4Chunk(req.Chunk) + iter := chunk.NewIterator4Chunk(req) for r := iter.Begin(); r != iter.End(); r = iter.Next() { rows = append(rows, r) } - req.Chunk = chunk.Renew(req.Chunk, a.Ctx.GetSessionVars().MaxChunkSize) + req = chunk.Renew(req, a.Ctx.GetSessionVars().MaxChunkSize) } return nil, err } @@ -384,7 +428,7 @@ func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, e Executor) (sqlex a.logAudit() }() - err = e.Next(ctx, chunk.NewRecordBatch(e.newFirstChunk())) + err = Next(ctx, e, newFirstChunk(e)) if err != nil { return nil, err } @@ -401,7 +445,12 @@ func (a *ExecStmt) handlePessimisticDML(ctx context.Context, e Executor) error { for { _, err = a.handleNoDelayExecutor(ctx, e) if err != nil { - return err + // It is possible the DML has point get plan that locks the key. + e, err = a.handlePessimisticLockError(ctx, err) + if err != nil { + return err + } + continue } keys, err1 := txn.(pessimisticTxn).KeysNeedToLock() if err1 != nil { @@ -411,56 +460,72 @@ func (a *ExecStmt) handlePessimisticDML(ctx context.Context, e Executor) error { return nil } forUpdateTS := txnCtx.GetForUpdateTS() - err = txn.LockKeys(ctx, forUpdateTS, keys...) + err = txn.LockKeys(ctx, &sctx.GetSessionVars().Killed, forUpdateTS, keys...) + if err == nil { + return nil + } e, err = a.handlePessimisticLockError(ctx, err) if err != nil { return err } - if e == nil { - return nil - } } } // handlePessimisticLockError updates TS and rebuild executor if the err is write conflict. func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, err error) (Executor, error) { - if err == nil { - return nil, nil - } - if !terror.ErrorEqual(kv.ErrWriteConflict, err) { + txnCtx := a.Ctx.GetSessionVars().TxnCtx + var newForUpdateTS uint64 + if deadlock, ok := errors.Cause(err).(*tikv.ErrDeadlock); ok { + if !deadlock.IsRetryable { + return nil, ErrDeadlock + } + logutil.Logger(ctx).Info("single statement deadlock, retry statement", + zap.Uint64("txn", txnCtx.StartTS), + zap.Uint64("lockTS", deadlock.LockTs), + zap.Binary("lockKey", deadlock.LockKey), + zap.Uint64("deadlockKeyHash", deadlock.DeadlockKeyHash)) + } else if terror.ErrorEqual(kv.ErrWriteConflict, err) { + conflictCommitTS := extractConflictCommitTS(err.Error()) + if conflictCommitTS == 0 { + logutil.Logger(ctx).Warn("failed to extract conflictCommitTS from a conflict error") + } + forUpdateTS := txnCtx.GetForUpdateTS() + logutil.Logger(ctx).Info("pessimistic write conflict, retry statement", + zap.Uint64("txn", txnCtx.StartTS), + zap.Uint64("forUpdateTS", forUpdateTS), + zap.Uint64("conflictCommitTS", conflictCommitTS)) + if conflictCommitTS > forUpdateTS { + newForUpdateTS = conflictCommitTS + } + } else { return nil, err } if a.retryCount >= config.GetGlobalConfig().PessimisticTxn.MaxRetryCount { return nil, errors.New("pessimistic lock retry limit reached") } a.retryCount++ - conflictCommitTS := extractConflictCommitTS(err.Error()) - if conflictCommitTS == 0 { - logutil.Logger(ctx).Warn("failed to extract conflictCommitTS from a conflict error") - } - sctx := a.Ctx - txnCtx := sctx.GetSessionVars().TxnCtx - forUpdateTS := txnCtx.GetForUpdateTS() - logutil.Logger(ctx).Info("pessimistic write conflict, retry statement", - zap.Uint64("txn", txnCtx.StartTS), - zap.Uint64("forUpdateTS", forUpdateTS), - zap.Uint64("conflictCommitTS", conflictCommitTS)) - if conflictCommitTS > txnCtx.GetForUpdateTS() { - txnCtx.SetForUpdateTS(conflictCommitTS) - } else { - ts, err1 := sctx.GetStore().GetOracle().GetTimestamp(ctx) - if err1 != nil { - return nil, err1 + if newForUpdateTS == 0 { + newForUpdateTS, err = a.Ctx.GetStore().GetOracle().GetTimestamp(ctx) + if err != nil { + return nil, err } - txnCtx.SetForUpdateTS(ts) } + txnCtx.SetForUpdateTS(newForUpdateTS) + txn, err := a.Ctx.Txn(true) + if err != nil { + return nil, err + } + txn.SetOption(kv.SnapshotTS, newForUpdateTS) e, err := a.buildExecutor() if err != nil { return nil, err } // Rollback the statement change before retry it. - sctx.StmtRollback() - sctx.GetSessionVars().StmtCtx.ResetForRetry() + a.Ctx.StmtRollback() + a.Ctx.GetSessionVars().StmtCtx.ResetForRetry() + a.Ctx.GetSessionVars().StartTime = time.Now() + a.Ctx.GetSessionVars().DurationCompile = time.Duration(0) + a.Ctx.GetSessionVars().DurationParse = time.Duration(0) if err = e.Open(ctx); err != nil { return nil, err @@ -519,7 +584,7 @@ func (a *ExecStmt) buildExecutor() (Executor, error) { switch { case useMaxTS: stmtCtx.Priority = kv.PriorityHigh - case a.Expensive: + case a.LowerPriority: stmtCtx.Priority = kv.PriorityLow } } @@ -542,6 +607,9 @@ func (a *ExecStmt) buildExecutor() (Executor, error) { } a.isPreparedStmt = true a.Plan = executorExec.plan + if executorExec.lowerPriority { + ctx.GetSessionVars().StmtCtx.Priority = kv.PriorityLow + } e = executorExec.stmtExec } a.isSelectForUpdate = b.isSelectForUpdate @@ -560,7 +628,8 @@ func (a *ExecStmt) logAudit() { audit := plugin.DeclareAuditManifest(p.Manifest) if audit.OnGeneralEvent != nil { cmd := mysql.Command2Str[byte(atomic.LoadUint32(&a.Ctx.GetSessionVars().CommandValue))] - audit.OnGeneralEvent(context.Background(), sessVars, plugin.Log, cmd) + ctx := context.WithValue(context.Background(), plugin.ExecStartTimeCtxKey, a.Ctx.GetSessionVars().StartTime) + audit.OnGeneralEvent(ctx, sessVars, plugin.Log, cmd) } return nil }) @@ -569,42 +638,66 @@ func (a *ExecStmt) logAudit() { } } +// FormatSQL is used to format the original SQL, e.g. truncating long SQL, appending prepared arguments. +func FormatSQL(sql string, pps variable.PreparedParams) stringutil.StringerFunc { + return func() string { + cfg := config.GetGlobalConfig() + length := len(sql) + if maxQueryLen := atomic.LoadUint64(&cfg.Log.QueryLogMaxLen); uint64(length) > maxQueryLen { + sql = fmt.Sprintf("%.*q(len:%d)", maxQueryLen, sql, length) + } + return QueryReplacer.Replace(sql) + pps.String() + } +} + // LogSlowQuery is used to print the slow query in the log files. -func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool) { +func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { sessVars := a.Ctx.GetSessionVars() level := log.GetLevel() - if level > zapcore.WarnLevel { - return - } cfg := config.GetGlobalConfig() - costTime := time.Since(a.StartTime) + costTime := time.Since(a.Ctx.GetSessionVars().StartTime) threshold := time.Duration(atomic.LoadUint64(&cfg.Log.SlowThreshold)) * time.Millisecond if costTime < threshold && level > zapcore.DebugLevel { return } - sql := a.Text - if maxQueryLen := atomic.LoadUint64(&cfg.Log.QueryLogMaxLen); uint64(len(sql)) > maxQueryLen { - sql = fmt.Sprintf("%.*q(len:%d)", maxQueryLen, sql, len(a.Text)) - } - sql = QueryReplacer.Replace(sql) + sessVars.GetExecuteArgumentsInfo() + sql := FormatSQL(a.Text, sessVars.PreparedParams) - var tableIDs, indexIDs string + var tableIDs, indexNames string if len(sessVars.StmtCtx.TableIDs) > 0 { tableIDs = strings.Replace(fmt.Sprintf("%v", a.Ctx.GetSessionVars().StmtCtx.TableIDs), " ", ",", -1) } - if len(sessVars.StmtCtx.IndexIDs) > 0 { - indexIDs = strings.Replace(fmt.Sprintf("%v", a.Ctx.GetSessionVars().StmtCtx.IndexIDs), " ", ",", -1) + if len(sessVars.StmtCtx.IndexNames) > 0 { + indexNames = strings.Replace(fmt.Sprintf("%v", a.Ctx.GetSessionVars().StmtCtx.IndexNames), " ", ",", -1) } execDetail := sessVars.StmtCtx.GetExecDetails() copTaskInfo := sessVars.StmtCtx.CopTasksDetails() - statsInfos := a.getStatsInfo() + statsInfos := plannercore.GetStatsInfo(a.Plan) memMax := sessVars.StmtCtx.MemTracker.MaxConsumed() + _, digest := sessVars.StmtCtx.SQLDigest() + slowItems := &variable.SlowQueryLogItems{ + TxnTS: txnTS, + SQL: sql.String(), + Digest: digest, + TimeTotal: costTime, + TimeParse: a.Ctx.GetSessionVars().DurationParse, + TimeCompile: a.Ctx.GetSessionVars().DurationCompile, + IndexNames: indexNames, + StatsInfos: statsInfos, + CopTasks: copTaskInfo, + ExecDetail: execDetail, + MemMax: memMax, + Succ: succ, + Plan: getPlanTree(a.Plan), + Prepared: a.isPreparedStmt, + HasMoreResults: hasMoreResults, + } + if _, ok := a.StmtNode.(*ast.CommitStmt); ok { + slowItems.PrevStmt = sessVars.PrevStmt.String() + } if costTime < threshold { - _, digest := sessVars.StmtCtx.SQLDigest() - logutil.SlowQueryLogger.Debug(sessVars.SlowLogFormat(txnTS, costTime, execDetail, indexIDs, digest, statsInfos, copTaskInfo, memMax, sql)) + logutil.SlowQueryLogger.Debug(sessVars.SlowLogFormat(slowItems)) } else { - _, digest := sessVars.StmtCtx.SQLDigest() - logutil.SlowQueryLogger.Warn(sessVars.SlowLogFormat(txnTS, costTime, execDetail, indexIDs, digest, statsInfos, copTaskInfo, memMax, sql)) + logutil.SlowQueryLogger.Warn(sessVars.SlowLogFormat(slowItems)) metrics.TotalQueryProcHistogram.Observe(costTime.Seconds()) metrics.TotalCopProcHistogram.Observe(execDetail.ProcessTime.Seconds()) metrics.TotalCopWaitHistogram.Observe(execDetail.WaitTime.Seconds()) @@ -613,43 +706,71 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool) { userString = sessVars.User.String() } domain.GetDomain(a.Ctx).LogSlowQuery(&domain.SlowQueryInfo{ - SQL: sql, - Digest: digest, - Start: a.StartTime, - Duration: costTime, - Detail: sessVars.StmtCtx.GetExecDetails(), - Succ: succ, - ConnID: sessVars.ConnectionID, - TxnTS: txnTS, - User: userString, - DB: sessVars.CurrentDB, - TableIDs: tableIDs, - IndexIDs: indexIDs, - Internal: sessVars.InRestrictedSQL, + SQL: sql.String(), + Digest: digest, + Start: a.Ctx.GetSessionVars().StartTime, + Duration: costTime, + Detail: sessVars.StmtCtx.GetExecDetails(), + Succ: succ, + ConnID: sessVars.ConnectionID, + TxnTS: txnTS, + User: userString, + DB: sessVars.CurrentDB, + TableIDs: tableIDs, + IndexNames: indexNames, + Internal: sessVars.InRestrictedSQL, }) } } -func (a *ExecStmt) getStatsInfo() map[string]uint64 { - var physicalPlan plannercore.PhysicalPlan - switch p := a.Plan.(type) { - case *plannercore.Insert: - physicalPlan = p.SelectPlan - case *plannercore.Update: - physicalPlan = p.SelectPlan - case *plannercore.Delete: - physicalPlan = p.SelectPlan - case plannercore.PhysicalPlan: - physicalPlan = p +// getPlanTree will try to get the select plan tree if the plan is select or the select plan of delete/update/insert statement. +func getPlanTree(p plannercore.Plan) string { + cfg := config.GetGlobalConfig() + if atomic.LoadUint32(&cfg.Log.RecordPlanInSlowLog) == 0 { + return "" } - - if physicalPlan == nil { - return nil + var selectPlan plannercore.PhysicalPlan + if physicalPlan, ok := p.(plannercore.PhysicalPlan); ok { + selectPlan = physicalPlan + } else { + switch x := p.(type) { + case *plannercore.Delete: + selectPlan = x.SelectPlan + case *plannercore.Update: + selectPlan = x.SelectPlan + case *plannercore.Insert: + selectPlan = x.SelectPlan + } } + if selectPlan == nil { + return "" + } + planTree := plannercore.EncodePlan(selectPlan) + if len(planTree) == 0 { + return planTree + } + return variable.SlowLogPlanPrefix + planTree + variable.SlowLogPlanSuffix +} - statsInfos := make(map[string]uint64) - statsInfos = plannercore.CollectPlanStatsVersion(physicalPlan, statsInfos) - return statsInfos +// SummaryStmt collects statements for performance_schema.events_statements_summary_by_digest +func (a *ExecStmt) SummaryStmt() { + sessVars := a.Ctx.GetSessionVars() + if sessVars.InRestrictedSQL || !stmtsummary.StmtSummaryByDigestMap.Enabled() { + return + } + stmtCtx := sessVars.StmtCtx + normalizedSQL, digest := stmtCtx.SQLDigest() + costTime := time.Since(sessVars.StartTime) + stmtsummary.StmtSummaryByDigestMap.AddStatement(&stmtsummary.StmtExecInfo{ + SchemaName: sessVars.CurrentDB, + OriginalSQL: a.Text, + NormalizedSQL: normalizedSQL, + Digest: digest, + TotalLatency: uint64(costTime.Nanoseconds()), + AffectedRows: stmtCtx.AffectedRows(), + SentRows: 0, + StartTime: sessVars.StartTime, + }) } // IsPointGetWithPKOrUniqueKeyByAutoCommit returns true when meets following conditions: diff --git a/executor/adapter_test.go b/executor/adapter_test.go new file mode 100644 index 0000000000000..32f288e28d8f4 --- /dev/null +++ b/executor/adapter_test.go @@ -0,0 +1,37 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/testkit" +) + +func (s *testSuite) TestQueryTime(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + costTime := time.Since(tk.Se.GetSessionVars().StartTime) + c.Assert(costTime < 1*time.Second, IsTrue) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t values(1), (1), (1), (1), (1)") + tk.MustExec("select * from t t1 join t t2 on t1.a = t2.a") + + costTime = time.Since(tk.Se.GetSessionVars().StartTime) + c.Assert(costTime < 1*time.Second, IsTrue) +} diff --git a/executor/admin.go b/executor/admin.go index e6721dd976f89..a177c9fff1e5f 100644 --- a/executor/admin.go +++ b/executor/admin.go @@ -60,7 +60,7 @@ type CheckIndexRangeExec struct { } // Next implements the Executor Next interface. -func (e *CheckIndexRangeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *CheckIndexRangeExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() handleIdx := e.schema.Len() - 1 for { @@ -102,7 +102,7 @@ func (e *CheckIndexRangeExec) Open(ctx context.Context) error { FieldType: *colTypeForHandle, }) - e.srcChunk = e.newFirstChunk() + e.srcChunk = newFirstChunk(e) dagPB, err := e.buildDAGPB() if err != nil { return err @@ -431,7 +431,7 @@ func (e *RecoverIndexExec) backfillIndexInTxn(ctx context.Context, txn kv.Transa } recordKey := e.table.RecordKey(row.handle) - err := txn.LockKeys(ctx, 0, recordKey) + err := txn.LockKeys(ctx, nil, 0, recordKey) if err != nil { return result, err } @@ -446,7 +446,7 @@ func (e *RecoverIndexExec) backfillIndexInTxn(ctx context.Context, txn kv.Transa } // Next implements the Executor Next interface. -func (e *RecoverIndexExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *RecoverIndexExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.done { return nil @@ -582,7 +582,7 @@ func (e *CleanupIndexExec) fetchIndex(ctx context.Context, txn kv.Transaction) e } // Next implements the Executor Next interface. -func (e *CleanupIndexExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *CleanupIndexExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.done { return nil diff --git a/executor/admin_plugins.go b/executor/admin_plugins.go new file mode 100644 index 0000000000000..440c1c0852306 --- /dev/null +++ b/executor/admin_plugins.go @@ -0,0 +1,52 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/plugin" + "github.com/pingcap/tidb/util/chunk" +) + +// AdminPluginsExec indicates AdminPlugins executor. +type AdminPluginsExec struct { + baseExecutor + Action core.AdminPluginsAction + Plugins []string +} + +// Next implements the Executor Next interface. +func (e *AdminPluginsExec) Next(ctx context.Context, _ *chunk.Chunk) error { + switch e.Action { + case core.Enable: + return e.changeDisableFlagAndFlush(false) + case core.Disable: + return e.changeDisableFlagAndFlush(true) + } + return nil +} + +func (e *AdminPluginsExec) changeDisableFlagAndFlush(disabled bool) error { + dom := domain.GetDomain(e.ctx) + for _, pluginName := range e.Plugins { + err := plugin.ChangeDisableFlagAndFlush(dom, pluginName, disabled) + if err != nil { + return err + } + } + return nil +} diff --git a/executor/admin_test.go b/executor/admin_test.go index f2317570e23b7..85eed69a04177 100644 --- a/executor/admin_test.go +++ b/executor/admin_test.go @@ -91,10 +91,10 @@ func (s *testSuite2) TestAdminRecoverIndex(c *C) { c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) - _, err = tk.Exec("admin check table admin_test") + err = tk.ExecToErr("admin check table admin_test") c.Assert(err, NotNil) c.Assert(executor.ErrAdminCheckTable.Equal(err), IsTrue) - _, err = tk.Exec("admin check index admin_test c2") + err = tk.ExecToErr("admin check index admin_test c2") c.Assert(err, NotNil) r = tk.MustQuery("SELECT COUNT(*) FROM admin_test USE INDEX(c2)") @@ -115,7 +115,7 @@ func (s *testSuite2) TestAdminRecoverIndex(c *C) { err = txn.Commit(context.Background()) c.Assert(err, IsNil) - _, err = tk.Exec("admin check index admin_test c2") + err = tk.ExecToErr("admin check index admin_test c2") c.Assert(err, NotNil) r = tk.MustQuery("admin recover index admin_test c2") r.Check(testkit.Rows("1 5")) @@ -137,15 +137,15 @@ func (s *testSuite2) TestAdminRecoverIndex(c *C) { err = txn.Commit(context.Background()) c.Assert(err, IsNil) - _, err = tk.Exec("admin check table admin_test") + err = tk.ExecToErr("admin check table admin_test") c.Assert(err, NotNil) - _, err = tk.Exec("admin check index admin_test c2") + err = tk.ExecToErr("admin check index admin_test c2") c.Assert(err, NotNil) r = tk.MustQuery("SELECT COUNT(*) FROM admin_test USE INDEX(c2)") r.Check(testkit.Rows("0")) - r = tk.MustQuery("SELECT COUNT(*) FROM admin_test") + r = tk.MustQuery("SELECT COUNT(*) FROM admin_test USE INDEX()") r.Check(testkit.Rows("5")) r = tk.MustQuery("admin recover index admin_test c2") @@ -261,9 +261,9 @@ func (s *testSuite2) TestAdminCleanupIndex(c *C) { err = txn.Commit(context.Background()) c.Assert(err, IsNil) - _, err = tk.Exec("admin check table admin_test") + err = tk.ExecToErr("admin check table admin_test") c.Assert(err, NotNil) - _, err = tk.Exec("admin check index admin_test c2") + err = tk.ExecToErr("admin check index admin_test c2") c.Assert(err, NotNil) r = tk.MustQuery("SELECT COUNT(*) FROM admin_test USE INDEX(c2)") r.Check(testkit.Rows("11")) @@ -273,9 +273,9 @@ func (s *testSuite2) TestAdminCleanupIndex(c *C) { r.Check(testkit.Rows("6")) tk.MustExec("admin check index admin_test c2") - _, err = tk.Exec("admin check table admin_test") + err = tk.ExecToErr("admin check table admin_test") c.Assert(err, NotNil) - _, err = tk.Exec("admin check index admin_test c3") + err = tk.ExecToErr("admin check index admin_test c3") c.Assert(err, NotNil) r = tk.MustQuery("SELECT COUNT(*) FROM admin_test USE INDEX(c3)") r.Check(testkit.Rows("9")) @@ -322,9 +322,9 @@ func (s *testSuite2) TestAdminCleanupIndexPKNotHandle(c *C) { err = txn.Commit(context.Background()) c.Assert(err, IsNil) - _, err = tk.Exec("admin check table admin_test") + err = tk.ExecToErr("admin check table admin_test") c.Assert(err, NotNil) - _, err = tk.Exec("admin check index admin_test `primary`") + err = tk.ExecToErr("admin check index admin_test `primary`") c.Assert(err, NotNil) r = tk.MustQuery("SELECT COUNT(*) FROM admin_test USE INDEX(`primary`)") r.Check(testkit.Rows("6")) @@ -374,13 +374,13 @@ func (s *testSuite2) TestAdminCleanupIndexMore(c *C) { err = txn.Commit(context.Background()) c.Assert(err, IsNil) - _, err = tk.Exec("admin check table admin_test") + err = tk.ExecToErr("admin check table admin_test") c.Assert(err, NotNil) - _, err = tk.Exec("admin check index admin_test c1") + err = tk.ExecToErr("admin check index admin_test c1") c.Assert(err, NotNil) - _, err = tk.Exec("admin check index admin_test c2") + err = tk.ExecToErr("admin check index admin_test c2") c.Assert(err, NotNil) - r := tk.MustQuery("SELECT COUNT(*) FROM admin_test") + r := tk.MustQuery("SELECT COUNT(*) FROM admin_test USE INDEX()") r.Check(testkit.Rows("3")) r = tk.MustQuery("SELECT COUNT(*) FROM admin_test USE INDEX(c1)") r.Check(testkit.Rows("2003")) @@ -399,6 +399,207 @@ func (s *testSuite2) TestAdminCleanupIndexMore(c *C) { tk.MustExec("admin check table admin_test") } +func (s *testSuite2) TestAdminCheckTableFailed(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists admin_test") + tk.MustExec("create table admin_test (c1 int, c2 int, c3 varchar(255) default '1', primary key(c1), key(c3), unique key(c2), key(c2, c3))") + tk.MustExec("insert admin_test (c1, c2, c3) values (-10, -20, 'y'), (-1, -10, 'z'), (1, 11, 'a'), (2, 12, 'b'), (5, 15, 'c'), (10, 20, 'd'), (20, 30, 'e')") + + // Make some corrupted index. Build the index information. + s.ctx = mock.NewContext() + s.ctx.Store = s.store + is := s.domain.InfoSchema() + dbName := model.NewCIStr("test") + tblName := model.NewCIStr("admin_test") + tbl, err := is.TableByName(dbName, tblName) + c.Assert(err, IsNil) + tblInfo := tbl.Meta() + idxInfo := tblInfo.Indices[1] + indexOpr := tables.NewIndex(tblInfo.ID, tblInfo, idxInfo) + sc := s.ctx.GetSessionVars().StmtCtx + tk.Se.GetSessionVars().IndexLookupSize = 3 + tk.Se.GetSessionVars().MaxChunkSize = 3 + + // Reduce one row of index. + // Table count > index count. + // Index c2 is missing 11. + txn, err := s.store.Begin() + c.Assert(err, IsNil) + err = indexOpr.Delete(sc, txn, types.MakeDatums(-10), -1, nil) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + err = tk.ExecToErr("admin check table admin_test") + c.Assert(err.Error(), Equals, + "[executor:8003]admin_test err:[admin:1]index: != record:&admin.RecordData{Handle:-1, Values:[]types.Datum{types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:-10, b:[]uint8(nil), x:interface {}(nil)}}}") + c.Assert(executor.ErrAdminCheckTable.Equal(err), IsTrue) + r := tk.MustQuery("admin recover index admin_test c2") + r.Check(testkit.Rows("1 7")) + tk.MustExec("admin check table admin_test") + + // Add one row of index. + // Table count < index count. + // Index c2 has one more values ​​than table data: 0, and the handle 0 hasn't correlative record. + txn, err = s.store.Begin() + c.Assert(err, IsNil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(0), 0) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + err = tk.ExecToErr("admin check table admin_test") + c.Assert(err.Error(), Equals, "handle 0, index:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:0, b:[]uint8(nil), x:interface {}(nil)} != record:") + + // Add one row of index. + // Table count < index count. + // Index c2 has two more values ​​than table data: 10, 13, and these handles have correlative record. + txn, err = s.store.Begin() + c.Assert(err, IsNil) + err = indexOpr.Delete(sc, txn, types.MakeDatums(0), 0, nil) + c.Assert(err, IsNil) + // Make sure the index value "19" is smaller "21". Then we scan to "19" before "21". + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(19), 10) + c.Assert(err, IsNil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(13), 2) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + err = tk.ExecToErr("admin check table admin_test") + c.Assert(err.Error(), Equals, "col c2, handle 2, index:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:13, b:[]uint8(nil), x:interface {}(nil)} != record:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:12, b:[]uint8(nil), x:interface {}(nil)}") + + // Table count = index count. + // Two indices have the same handle. + txn, err = s.store.Begin() + c.Assert(err, IsNil) + err = indexOpr.Delete(sc, txn, types.MakeDatums(13), 2, nil) + c.Assert(err, IsNil) + err = indexOpr.Delete(sc, txn, types.MakeDatums(12), 2, nil) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + err = tk.ExecToErr("admin check table admin_test") + c.Assert(err.Error(), Equals, "col c2, handle 10, index:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:19, b:[]uint8(nil), x:interface {}(nil)} != record:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:20, b:[]uint8(nil), x:interface {}(nil)}") + + // Table count = index count. + // Index c2 has one line of data is 19, the corresponding table data is 20. + txn, err = s.store.Begin() + c.Assert(err, IsNil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(12), 2) + c.Assert(err, IsNil) + err = indexOpr.Delete(sc, txn, types.MakeDatums(20), 10, nil) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + err = tk.ExecToErr("admin check table admin_test") + c.Assert(err.Error(), Equals, "col c2, handle 10, index:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:19, b:[]uint8(nil), x:interface {}(nil)} != record:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:20, b:[]uint8(nil), x:interface {}(nil)}") + + // Recover records. + txn, err = s.store.Begin() + c.Assert(err, IsNil) + err = indexOpr.Delete(sc, txn, types.MakeDatums(19), 10, nil) + c.Assert(err, IsNil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(20), 10) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + tk.MustExec("admin check table admin_test") +} + +func (s *testSuite2) TestAdminCheckPartitionTableFailed(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists admin_test_p") + tk.MustExec("create table admin_test_p (c1 int key,c2 int,c3 int,index idx(c2)) partition by hash(c1) partitions 4") + tk.MustExec("insert admin_test_p (c1, c2, c3) values (0,0,0), (1,1,1),(2,2,2),(3,3,3),(4,4,4),(5,5,5)") + tk.MustExec("admin check table admin_test_p") + + // Make some corrupted index. Build the index information. + s.ctx = mock.NewContext() + s.ctx.Store = s.store + is := s.domain.InfoSchema() + dbName := model.NewCIStr("test") + tblName := model.NewCIStr("admin_test_p") + tbl, err := is.TableByName(dbName, tblName) + c.Assert(err, IsNil) + tblInfo := tbl.Meta() + idxInfo := tblInfo.Indices[0] + sc := s.ctx.GetSessionVars().StmtCtx + tk.Se.GetSessionVars().IndexLookupSize = 3 + tk.Se.GetSessionVars().MaxChunkSize = 3 + + // Reduce one row of index on partitions. + // Table count > index count. + for i := 0; i <= 5; i++ { + partitionIdx := i % len(tblInfo.GetPartitionInfo().Definitions) + indexOpr := tables.NewIndex(tblInfo.GetPartitionInfo().Definitions[partitionIdx].ID, tblInfo, idxInfo) + txn, err := s.store.Begin() + c.Assert(err, IsNil) + err = indexOpr.Delete(sc, txn, types.MakeDatums(i), int64(i), nil) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + err = tk.ExecToErr("admin check table admin_test_p") + c.Assert(err.Error(), Equals, fmt.Sprintf("[executor:8003]admin_test_p err:[admin:1]index: != record:&admin.RecordData{Handle:%d, Values:[]types.Datum{types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:%d, b:[]uint8(nil), x:interface {}(nil)}}}", i, i)) + c.Assert(executor.ErrAdminCheckTable.Equal(err), IsTrue) + // TODO: fix admin recover for partition table. + //r := tk.MustQuery("admin recover index admin_test_p idx") + //r.Check(testkit.Rows("0 0")) + //tk.MustExec("admin check table admin_test_p") + // Manual recover index. + txn, err = s.store.Begin() + c.Assert(err, IsNil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(i), int64(i)) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + tk.MustExec("admin check table admin_test_p") + } + + // Add one row of index on partitions. + // Table count < index count. + for i := 0; i <= 5; i++ { + partitionIdx := i % len(tblInfo.GetPartitionInfo().Definitions) + indexOpr := tables.NewIndex(tblInfo.GetPartitionInfo().Definitions[partitionIdx].ID, tblInfo, idxInfo) + txn, err := s.store.Begin() + c.Assert(err, IsNil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(i+8), int64(i+8)) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + err = tk.ExecToErr("admin check table admin_test_p") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, fmt.Sprintf("handle %d, index:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:%d, b:[]uint8(nil), x:interface {}(nil)} != record:", i+8, i+8)) + // TODO: fix admin recover for partition table. + txn, err = s.store.Begin() + c.Assert(err, IsNil) + err = indexOpr.Delete(sc, txn, types.MakeDatums(i+8), int64(i+8), nil) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + tk.MustExec("admin check table admin_test_p") + } + + // Table count = index count, but the index value was wrong. + for i := 0; i <= 5; i++ { + partitionIdx := i % len(tblInfo.GetPartitionInfo().Definitions) + indexOpr := tables.NewIndex(tblInfo.GetPartitionInfo().Definitions[partitionIdx].ID, tblInfo, idxInfo) + txn, err := s.store.Begin() + c.Assert(err, IsNil) + _, err = indexOpr.Create(s.ctx, txn, types.MakeDatums(i+8), int64(i)) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + err = tk.ExecToErr("admin check table admin_test_p") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, fmt.Sprintf("col c2, handle %d, index:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:%d, b:[]uint8(nil), x:interface {}(nil)} != record:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:%d, b:[]uint8(nil), x:interface {}(nil)}", i, i+8, i)) + // TODO: fix admin recover for partition table. + txn, err = s.store.Begin() + c.Assert(err, IsNil) + err = indexOpr.Delete(sc, txn, types.MakeDatums(i+8), int64(i), nil) + c.Assert(err, IsNil) + err = txn.Commit(context.Background()) + tk.MustExec("admin check table admin_test_p") + } +} + func (s *testSuite1) TestAdminCheckTable(c *C) { // test NULL value. tk := testkit.NewTestKit(c, s.store) @@ -419,8 +620,8 @@ func (s *testSuite1) TestAdminCheckTable(c *C) { tk.MustExec(`drop table if exists test`) tk.MustExec(`create table test ( a time, - PRIMARY KEY (a) - );`) + PRIMARY KEY (a) + );`) tk.MustExec(`insert into test set a='12:10:36';`) tk.MustExec(`admin check table test`) @@ -455,22 +656,22 @@ func (s *testSuite1) TestAdminCheckTable(c *C) { // Test index in virtual generated column. tk.MustExec(`drop table if exists test`) - tk.MustExec(`create table test ( b json , c int as (JSON_EXTRACT(b,'$.d')) , index idxc(c));`) + tk.MustExec(`create table test ( b json , c int as (JSON_EXTRACT(b,'$.d')), index idxc(c));`) tk.MustExec(`INSERT INTO test set b='{"d": 100}';`) tk.MustExec(`admin check table test;`) // Test prefix index. tk.MustExec(`drop table if exists t`) tk.MustExec(`CREATE TABLE t ( - ID CHAR(32) NOT NULL, - name CHAR(32) NOT NULL, - value CHAR(255), - INDEX indexIDname (ID(8),name(8)));`) + ID CHAR(32) NOT NULL, + name CHAR(32) NOT NULL, + value CHAR(255), + INDEX indexIDname (ID(8),name(8)));`) tk.MustExec(`INSERT INTO t VALUES ('keyword','urlprefix','text/ /text');`) tk.MustExec(`admin check table t;`) tk.MustExec("use mysql") tk.MustExec(`admin check table test.t;`) - _, err := tk.Exec("admin check table t") + err := tk.ExecToErr("admin check table t") c.Assert(err, NotNil) // test add index on time type column which have default value @@ -510,7 +711,7 @@ func (s *testSuite1) TestAdminCheckTable(c *C) { tk.MustExec(`drop table if exists t1`) tk.MustExec(`create table t1 (a decimal(2,1), index(a))`) tk.MustExec(`insert into t1 set a='1.9'`) - _, err = tk.Exec(`alter table t1 modify column a decimal(3,2);`) + err = tk.ExecToErr(`alter table t1 modify column a decimal(3,2);`) c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "[ddl:203]unsupported modify decimal column precision") tk.MustExec(`delete from t1;`) @@ -553,9 +754,9 @@ func (s *testSuite2) TestAdminCheckWithSnapshot(c *C) { c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) - _, err = tk.Exec("admin check table admin_t_s") + err = tk.ExecToErr("admin check table admin_t_s") c.Assert(err, NotNil) - _, err = tk.Exec("admin check index admin_t_s a") + err = tk.ExecToErr("admin check index admin_t_s a") c.Assert(err, NotNil) // For mocktikv, safe point is not initialized, we manually insert it for snapshot to use. @@ -572,9 +773,9 @@ func (s *testSuite2) TestAdminCheckWithSnapshot(c *C) { tk.MustExec("admin check index admin_t_s a;") tk.MustExec("set @@tidb_snapshot = ''") - _, err = tk.Exec("admin check table admin_t_s") + err = tk.ExecToErr("admin check table admin_t_s") c.Assert(err, NotNil) - _, err = tk.Exec("admin check index admin_t_s a") + err = tk.ExecToErr("admin check index admin_t_s a") c.Assert(err, NotNil) r := tk.MustQuery("admin cleanup index admin_t_s a") diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index c61bd792eebf2..7f9d582d5bf81 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ b/executor/aggfuncs/aggfunc_test.go @@ -82,7 +82,8 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { if p.funcName == ast.AggFuncGroupConcat { args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)}) } - desc := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + c.Assert(err, IsNil) partialDesc, finalDesc := desc.Split([]int{0, 1}) // build partial func for partial phase. @@ -183,7 +184,8 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { if p.funcName == ast.AggFuncGroupConcat { args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)}) } - desc := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + c.Assert(err, IsNil) finalFunc := aggfuncs.Build(s.ctx, desc, 0) finalPr := finalFunc.AllocPartialResult() resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) @@ -208,7 +210,8 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { c.Assert(result, Equals, 0) // test the agg func with distinct - desc = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true) + desc, err = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true) + c.Assert(err, IsNil) finalFunc = aggfuncs.Build(s.ctx, desc, 0) finalPr = finalFunc.AllocPartialResult() diff --git a/executor/aggfuncs/func_cume_dist.go b/executor/aggfuncs/func_cume_dist.go index 37e1ffb1636a5..f486c16908812 100644 --- a/executor/aggfuncs/func_cume_dist.go +++ b/executor/aggfuncs/func_cume_dist.go @@ -30,11 +30,11 @@ type partialResult4CumeDist struct { } func (r *cumeDist) AllocPartialResult() PartialResult { - return PartialResult(&partialResult4Rank{}) + return PartialResult(&partialResult4CumeDist{}) } func (r *cumeDist) ResetPartialResult(pr PartialResult) { - p := (*partialResult4Rank)(pr) + p := (*partialResult4CumeDist)(pr) p.curIdx = 0 p.lastRank = 0 p.rows = p.rows[:0] diff --git a/executor/aggfuncs/func_lead_lag_test.go b/executor/aggfuncs/func_lead_lag_test.go new file mode 100644 index 0000000000000..fd4e5aa23dfcb --- /dev/null +++ b/executor/aggfuncs/func_lead_lag_test.go @@ -0,0 +1,114 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/types" +) + +func (s *testSuite) TestLeadLag(c *C) { + zero := expression.Zero + one := expression.One + two := &expression.Constant{ + Value: types.NewDatum(2), + RetType: types.NewFieldType(mysql.TypeTiny), + } + three := &expression.Constant{ + Value: types.NewDatum(3), + RetType: types.NewFieldType(mysql.TypeTiny), + } + million := &expression.Constant{ + Value: types.NewDatum(1000000), + RetType: types.NewFieldType(mysql.TypeLong), + } + defaultArg := &expression.Column{RetType: types.NewFieldType(mysql.TypeLonglong), Index: 0} + + numRows := 3 + tests := []windowTest{ + // lag(field0, N) + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{zero}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{one}, 0, numRows, nil, 0, 1), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{two}, 0, numRows, nil, nil, 0), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{three}, 0, numRows, nil, nil, nil), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{million}, 0, numRows, nil, nil, nil), + // lag(field0, N, 1000000) + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{zero, million}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{one, million}, 0, numRows, 1000000, 0, 1), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{two, million}, 0, numRows, 1000000, 1000000, 0), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{three, million}, 0, numRows, 1000000, 1000000, 1000000), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{million, million}, 0, numRows, 1000000, 1000000, 1000000), + // lag(field0, N, field0) + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{zero, defaultArg}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{one, defaultArg}, 0, numRows, 0, 0, 1), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{two, defaultArg}, 0, numRows, 0, 1, 0), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{three, defaultArg}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLag, mysql.TypeLonglong, + []expression.Expression{million, defaultArg}, 0, numRows, 0, 1, 2), + + // lead(field0, N) + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{zero}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{one}, 0, numRows, 1, 2, nil), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{two}, 0, numRows, 2, nil, nil), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{three}, 0, numRows, nil, nil, nil), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{million}, 0, numRows, nil, nil, nil), + // lead(field0, N, 1000000) + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{zero, million}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{one, million}, 0, numRows, 1, 2, 1000000), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{two, million}, 0, numRows, 2, 1000000, 1000000), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{three, million}, 0, numRows, 1000000, 1000000, 1000000), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{million, million}, 0, numRows, 1000000, 1000000, 1000000), + // lead(field0, N, field0) + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{zero, defaultArg}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{one, defaultArg}, 0, numRows, 1, 2, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{two, defaultArg}, 0, numRows, 2, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{three, defaultArg}, 0, numRows, 0, 1, 2), + buildWindowTesterWithArgs(ast.WindowFuncLead, mysql.TypeLonglong, + []expression.Expression{million, defaultArg}, 0, numRows, 0, 1, 2), + } + for _, test := range tests { + s.testWindowFunc(c, test) + } +} diff --git a/executor/aggfuncs/func_ntile.go b/executor/aggfuncs/func_ntile.go index 1adbb326d7609..eefa0a9cd6b83 100644 --- a/executor/aggfuncs/func_ntile.go +++ b/executor/aggfuncs/func_ntile.go @@ -31,7 +31,7 @@ type partialResult4Ntile struct { curGroupIdx uint64 remainder uint64 quotient uint64 - rows []chunk.Row + numRows uint64 } func (n *ntile) AllocPartialResult() PartialResult { @@ -42,16 +42,16 @@ func (n *ntile) ResetPartialResult(pr PartialResult) { p := (*partialResult4Ntile)(pr) p.curIdx = 0 p.curGroupIdx = 1 - p.rows = p.rows[:0] + p.numRows = 0 } func (n *ntile) UpdatePartialResult(_ sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { p := (*partialResult4Ntile)(pr) - p.rows = append(p.rows, rowsInGroup...) + p.numRows += uint64(len(rowsInGroup)) // Update the quotient and remainder. if n.n != 0 { - p.quotient = uint64(len(p.rows)) / n.n - p.remainder = uint64(len(p.rows)) % n.n + p.quotient = p.numRows / n.n + p.remainder = p.numRows % n.n } return nil } diff --git a/executor/aggfuncs/window_func_test.go b/executor/aggfuncs/window_func_test.go index d6c140d596d40..f6d879596e264 100644 --- a/executor/aggfuncs/window_func_test.go +++ b/executor/aggfuncs/window_func_test.go @@ -44,18 +44,22 @@ func (s *testSuite) testWindowFunc(c *C, p windowTest) { srcChk.AppendDatum(0, &dt) } - desc := aggregation.NewAggFuncDesc(s.ctx, p.funcName, p.args, false) + desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, p.args, false) + c.Assert(err, IsNil) finalFunc := aggfuncs.BuildWindowFunctions(s.ctx, desc, 0, p.orderByCols) finalPr := finalFunc.AllocPartialResult() resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) iter := chunk.NewIterator4Chunk(srcChk) for row := iter.Begin(); row != iter.End(); row = iter.Next() { - finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) + err = finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr) + c.Assert(err, IsNil) } + c.Assert(p.numRows, Equals, len(p.results)) for i := 0; i < p.numRows; i++ { - finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) + err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) + c.Assert(err, IsNil) dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[i]) c.Assert(err, IsNil) @@ -65,6 +69,26 @@ func (s *testSuite) testWindowFunc(c *C, p windowTest) { finalFunc.ResetPartialResult(finalPr) } +func buildWindowTesterWithArgs(funcName string, tp byte, args []expression.Expression, orderByCols int, numRows int, results ...interface{}) windowTest { + pt := windowTest{ + dataType: types.NewFieldType(tp), + numRows: numRows, + funcName: funcName, + } + if funcName != ast.WindowFuncNtile { + pt.args = append(pt.args, &expression.Column{RetType: pt.dataType, Index: 0}) + } + pt.args = append(pt.args, args...) + if orderByCols > 0 { + pt.orderByCols = append(pt.orderByCols, &expression.Column{RetType: pt.dataType, Index: 0}) + } + + for _, result := range results { + pt.results = append(pt.results, types.NewDatum(result)) + } + return pt +} + func buildWindowTester(funcName string, tp byte, constantArg uint64, orderByCols int, numRows int, results ...interface{}) windowTest { pt := windowTest{ dataType: types.NewFieldType(tp), @@ -89,6 +113,7 @@ func buildWindowTester(funcName string, tp byte, constantArg uint64, orderByCols func (s *testSuite) TestWindowFunctions(c *C) { tests := []windowTest{ + buildWindowTester(ast.WindowFuncCumeDist, mysql.TypeLonglong, 0, 1, 1, 1), buildWindowTester(ast.WindowFuncCumeDist, mysql.TypeLonglong, 0, 0, 2, 1, 1), buildWindowTester(ast.WindowFuncCumeDist, mysql.TypeLonglong, 0, 1, 4, 0.25, 0.5, 0.75, 1), @@ -104,23 +129,19 @@ func (s *testSuite) TestWindowFunctions(c *C) { buildWindowTester(ast.WindowFuncFirstValue, mysql.TypeDuration, 0, 1, 2, types.Duration{Duration: time.Duration(0)}, types.Duration{Duration: time.Duration(0)}), buildWindowTester(ast.WindowFuncFirstValue, mysql.TypeJSON, 0, 1, 2, json.CreateBinary(int64(0)), json.CreateBinary(int64(0))), - buildWindowTester(ast.WindowFuncLag, mysql.TypeLonglong, 1, 0, 3, nil, 0, 1), - buildWindowTester(ast.WindowFuncLag, mysql.TypeLonglong, 2, 1, 4, nil, nil, 0, 1), - buildWindowTester(ast.WindowFuncLastValue, mysql.TypeLonglong, 1, 0, 2, 1, 1), - buildWindowTester(ast.WindowFuncLead, mysql.TypeLonglong, 1, 0, 3, 1, 2, nil), - buildWindowTester(ast.WindowFuncLead, mysql.TypeLonglong, 2, 0, 4, 2, 3, nil, nil), - buildWindowTester(ast.WindowFuncNthValue, mysql.TypeLonglong, 2, 0, 3, 1, 1, 1), buildWindowTester(ast.WindowFuncNthValue, mysql.TypeLonglong, 5, 0, 3, nil, nil, nil), buildWindowTester(ast.WindowFuncNtile, mysql.TypeLonglong, 3, 0, 4, 1, 1, 2, 3), buildWindowTester(ast.WindowFuncNtile, mysql.TypeLonglong, 5, 0, 3, 1, 2, 3), + buildWindowTester(ast.WindowFuncPercentRank, mysql.TypeLonglong, 0, 1, 1, 0), buildWindowTester(ast.WindowFuncPercentRank, mysql.TypeLonglong, 0, 0, 3, 0, 0, 0), buildWindowTester(ast.WindowFuncPercentRank, mysql.TypeLonglong, 0, 1, 4, 0, 0.3333333333333333, 0.6666666666666666, 1), + buildWindowTester(ast.WindowFuncRank, mysql.TypeLonglong, 0, 1, 1, 1), buildWindowTester(ast.WindowFuncRank, mysql.TypeLonglong, 0, 0, 3, 1, 1, 1), buildWindowTester(ast.WindowFuncRank, mysql.TypeLonglong, 0, 1, 4, 1, 2, 3, 4), diff --git a/executor/aggregate.go b/executor/aggregate.go index fde12cb6bab8f..1663736c49804 100644 --- a/executor/aggregate.go +++ b/executor/aggregate.go @@ -163,6 +163,7 @@ type HashAggExec struct { isChildReturnEmpty bool childResult *chunk.Chunk + executed bool } // HashAggInput indicates the input of hash agg exec. @@ -207,6 +208,9 @@ func (e *HashAggExec) Close() error { for _, ch := range e.partialOutputChs { close(ch) } + for _, ch := range e.partialInputChs { + close(ch) + } close(e.finalOutputCh) } close(e.finishCh) @@ -214,8 +218,13 @@ func (e *HashAggExec) Close() error { for range ch { } } + for _, ch := range e.partialInputChs { + for range ch { + } + } for range e.finalOutputCh { } + e.executed = false return e.baseExecutor.Close() } @@ -239,7 +248,7 @@ func (e *HashAggExec) initForUnparallelExec() { e.partialResultMap = make(aggPartialResultMapper) e.groupKeyBuffer = make([]byte, 0, 8) e.groupValDatums = make([]types.Datum, 0, len(e.groupKeyBuffer)) - e.childResult = e.children[0].newFirstChunk() + e.childResult = newFirstChunk(e.children[0]) } func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { @@ -275,12 +284,12 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { partialResultsMap: make(aggPartialResultMapper), groupByItems: e.GroupByItems, groupValDatums: make([]types.Datum, 0, len(e.GroupByItems)), - chk: e.children[0].newFirstChunk(), + chk: newFirstChunk(e.children[0]), } e.partialWorkers[i] = w e.inputCh <- &HashAggInput{ - chk: e.children[0].newFirstChunk(), + chk: newFirstChunk(e.children[0]), giveBackCh: w.inputCh, } } @@ -295,7 +304,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) { outputCh: e.finalOutputCh, finalResultHolderCh: e.finalInputCh, rowBuffer: make([]types.Datum, 0, e.Schema().Len()), - mutableRow: chunk.MutRowFromTypes(e.retTypes()), + mutableRow: chunk.MutRowFromTypes(retTypes(e)), } } } @@ -514,7 +523,7 @@ func (w *HashAggFinalWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitGro } // Next implements the Executor Next interface. -func (e *HashAggExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *HashAggExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("hashagg.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -525,9 +534,9 @@ func (e *HashAggExec) Next(ctx context.Context, req *chunk.RecordBatch) error { } req.Reset() if e.isUnparallelExec { - return e.unparallelExec(ctx, req.Chunk) + return e.unparallelExec(ctx, req) } - return e.parallelExec(ctx, req.Chunk) + return e.parallelExec(ctx, req) } func (e *HashAggExec) fetchChildData(ctx context.Context) { @@ -555,7 +564,7 @@ func (e *HashAggExec) fetchChildData(ctx context.Context) { } chk = input.chk } - err = e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) + err = Next(ctx, e.children[0], chk) if err != nil { e.finalOutputCh <- &AfFinalResult{err: err} return @@ -614,10 +623,14 @@ func (e *HashAggExec) parallelExec(ctx context.Context, chk *chunk.Chunk) error } }) + if e.executed { + return nil + } for !chk.IsFull() { e.finalInputCh <- chk result, ok := <-e.finalOutputCh if !ok { // all finalWorkers exited + e.executed = true if chk.NumRows() > 0 { // but there are some data left return nil } @@ -681,7 +694,7 @@ func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) erro func (e *HashAggExec) execute(ctx context.Context) (err error) { inputIter := chunk.NewIterator4Chunk(e.childResult) for { - err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) + err := Next(ctx, e.children[0], e.childResult) if err != nil { return err } @@ -772,7 +785,7 @@ func (e *StreamAggExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } - e.childResult = e.children[0].newFirstChunk() + e.childResult = newFirstChunk(e.children[0]) e.executed = false e.isChildReturnEmpty = true e.inputIter = chunk.NewIterator4Chunk(e.childResult) @@ -789,11 +802,12 @@ func (e *StreamAggExec) Open(ctx context.Context) error { // Close implements the Executor Close interface. func (e *StreamAggExec) Close() error { e.childResult = nil + e.groupChecker.reset() return e.baseExecutor.Close() } // Next implements the Executor Next interface. -func (e *StreamAggExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *StreamAggExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("streamAgg.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -804,7 +818,7 @@ func (e *StreamAggExec) Next(ctx context.Context, req *chunk.RecordBatch) error } req.Reset() for !e.executed && !req.IsFull() { - err := e.consumeOneGroup(ctx, req.Chunk) + err := e.consumeOneGroup(ctx, req) if err != nil { e.executed = true return err @@ -869,7 +883,7 @@ func (e *StreamAggExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Ch return err } - err = e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) + err = Next(ctx, e.children[0], e.childResult) if err != nil { return err } @@ -954,3 +968,12 @@ func (e *groupChecker) meetNewGroup(row chunk.Row) (bool, error) { } return !firstGroup, nil } + +func (e *groupChecker) reset() { + if e.curGroupKey != nil { + e.curGroupKey = e.curGroupKey[:0] + } + if e.tmpGroupKey != nil { + e.tmpGroupKey = e.tmpGroupKey[:0] + } +} diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 06826616d5564..fca05420430ad 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -15,6 +15,7 @@ package executor_test import ( . "github.com/pingcap/check" + "github.com/pingcap/errors" "github.com/pingcap/parser/terror" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/util/testkit" @@ -334,6 +335,22 @@ func (s *testSuite1) TestAggregation(c *C) { tk.MustExec("insert into t value(0), (-0.9871), (-0.9871)") tk.MustQuery("select 10 from t group by a").Check(testkit.Rows("10", "10")) tk.MustQuery("select sum(a) from (select a from t union all select a from t) tmp").Check(testkit.Rows("-3.9484")) + _, err = tk.Exec("select std(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: std") + _, err = tk.Exec("select stddev(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev") + _, err = tk.Exec("select stddev_pop(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: stddev_pop") + _, err = tk.Exec("select std_samp(a) from t") + // TODO: Fix this error message. + c.Assert(errors.Cause(err).Error(), Equals, "[expression:1305]FUNCTION test.std_samp does not exist") + _, err = tk.Exec("select variance(a) from t") + // TODO: Fix this error message. + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: var_pop") + _, err = tk.Exec("select var_pop(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: var_pop") + _, err = tk.Exec("select var_samp(a) from t") + c.Assert(errors.Cause(err).Error(), Equals, "unsupported agg function: var_samp") } func (s *testSuite1) TestStreamAggPushDown(c *C) { @@ -710,3 +727,14 @@ func (s *testSuite1) TestIssue10098(c *C) { tk.MustExec("insert into t values('1', '222'), ('12', '22')") tk.MustQuery("select group_concat(distinct a, b) from t").Check(testkit.Rows("1222,1222")) } + +func (s *testSuite1) TestIssue10608(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec(`drop table if exists t, s;`) + tk.MustExec("create table t(a int)") + tk.MustExec("create table s(a int, b int)") + tk.MustExec("insert into s values(100292, 508931), (120002, 508932)") + tk.MustExec("insert into t values(508931), (508932)") + tk.MustQuery("select (select group_concat(concat(123,'-')) from t where t.a = s.b group by t.a) as t from s;").Check(testkit.Rows("123-", "123-")) + +} diff --git a/executor/analyze.go b/executor/analyze.go index 637b05d00a511..4cbee01c0df48 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -27,6 +27,7 @@ import ( "github.com/cznic/mathutil" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/debugpb" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" @@ -62,7 +63,7 @@ type AnalyzeExec struct { var ( // MaxSampleSize is the size of samples for once analyze. // It's public for test. - MaxSampleSize = 10000 + MaxSampleSize = int64(10000) // RandSeed is the seed for randing package. // It's public for test. RandSeed = int64(1) @@ -76,7 +77,7 @@ const ( ) // Next implements the Executor Next interface. -func (e *AnalyzeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *AnalyzeExec) Next(ctx context.Context, req *chunk.Chunk) error { concurrency, err := getBuildStatsConcurrency(e.ctx) if err != nil { return err @@ -164,12 +165,8 @@ type analyzeTask struct { var errAnalyzeWorkerPanic = errors.New("analyze worker panic") func (e *AnalyzeExec) analyzeWorker(taskCh <-chan *analyzeTask, resultCh chan<- analyzeResult, isCloseChanThread bool) { + var task *analyzeTask defer func() { - e.wg.Done() - if isCloseChanThread { - e.wg.Wait() - close(resultCh) - } if r := recover(); r != nil { buf := make([]byte, 4096) stackSize := runtime.Stack(buf, false) @@ -178,11 +175,18 @@ func (e *AnalyzeExec) analyzeWorker(taskCh <-chan *analyzeTask, resultCh chan<- metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() resultCh <- analyzeResult{ Err: errAnalyzeWorkerPanic, + job: task.job, } } + e.wg.Done() + if isCloseChanThread { + e.wg.Wait() + close(resultCh) + } }() for { - task, ok := <-taskCh + var ok bool + task, ok = <-taskCh if !ok { break } @@ -295,6 +299,11 @@ func (e *AnalyzeIndexExec) open(ranges []*ranger.Range, considerNull bool) error } func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, needCMS bool) (*statistics.Histogram, *statistics.CMSketch, error) { + failpoint.Inject("buildStatsFromResult", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, nil, errors.New("mock buildStatsFromResult error")) + } + }) hist := &statistics.Histogram{} var cms *statistics.CMSketch if needCMS { @@ -322,7 +331,7 @@ func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, nee if needCMS { if resp.Cms == nil { logutil.Logger(context.TODO()).Warn("nil CMS in response", zap.String("table", e.idxInfo.Table.O), zap.String("index", e.idxInfo.Name.O)) - } else if err := cms.MergeCMSketch(statistics.CMSketchFromProto(resp.Cms)); err != nil { + } else if err := cms.MergeCMSketch(statistics.CMSketchFromProto(resp.Cms), 0); err != nil { return nil, nil, err } } @@ -335,7 +344,10 @@ func (e *AnalyzeIndexExec) buildStats(ranges []*ranger.Range, considerNull bool) return nil, nil, err } defer func() { - err = closeAll(e.result, e.countNullRes) + err1 := closeAll(e.result, e.countNullRes) + if err == nil { + err = err1 + } }() hist, cms, err = e.buildStatsFromResult(e.result, true) if err != nil { @@ -452,7 +464,7 @@ func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range) (hists []*statis collectors[i] = &statistics.SampleCollector{ IsMerger: true, FMSketch: statistics.NewFMSketch(maxSketchSize), - MaxSampleSize: int64(MaxSampleSize), + MaxSampleSize: atomic.LoadInt64(&MaxSampleSize), CMSketch: statistics.NewCMSketch(defaultCMSketchDepth, defaultCMSketchWidth), } } @@ -690,7 +702,6 @@ func (e *AnalyzeFastExec) getNextSampleKey(bo *tikv.Backoffer, startKey kv.Key) func (e *AnalyzeFastExec) buildSampTask() (needRebuild bool, err error) { // Do get regions row count. bo := tikv.NewBackoffer(context.Background(), 500) - e.rowCount = 0 needRebuildForRoutine := make([]bool, e.concurrency) errs := make([]error, e.concurrency) sampTasksForRoutine := make([][]*AnalyzeFastTask, e.concurrency) @@ -722,6 +733,13 @@ func (e *AnalyzeFastExec) buildSampTask() (needRebuild bool, err error) { if err != nil { return false, err } + e.rowCount = 0 + for _, task := range e.sampTasks { + cnt := task.EndOffset - task.BeginOffset + task.BeginOffset = e.rowCount + task.EndOffset = e.rowCount + cnt + e.rowCount += cnt + } for { // Search for the region which contains the targetKey. loc, err := e.cache.LocateKey(bo, targetKey) @@ -795,6 +813,9 @@ func (e *AnalyzeFastExec) updateCollectorSamples(sValue []byte, sKey kv.Key, sam } v = types.NewIntDatum(key) } + if mysql.HasUnsignedFlag(e.pkInfo.Flag) { + v.SetUint64(uint64(v.GetInt64())) + } if e.collectors[0].Samples[samplePos] == nil { e.collectors[0].Samples[samplePos] = &statistics.SampleItem{} } @@ -937,14 +958,16 @@ func (e *AnalyzeFastExec) handleSampTasks(bo *tikv.Backoffer, workID int, err *e keys = append(keys, tablecodec.EncodeRowKeyWithHandle(tableID, randKey)) } - var kvMap map[string][]byte + kvMap := make(map[string][]byte, len(keys)) for _, key := range keys { var iter kv.Iterator iter, *err = snapshot.Iter(key, endKey) if *err != nil { return } - kvMap[string(iter.Key())] = iter.Value() + if iter.Valid() { + kvMap[string(iter.Key())] = iter.Value() + } } *err = e.handleBatchSeekResponse(kvMap) @@ -954,11 +977,7 @@ func (e *AnalyzeFastExec) handleSampTasks(bo *tikv.Backoffer, workID int, err *e } } -func (e *AnalyzeFastExec) buildHist(ID int64, collector *statistics.SampleCollector, tp *types.FieldType) (*statistics.Histogram, error) { - // build collector properties. - collector.Samples = collector.Samples[:e.sampCursor] - sort.Slice(collector.Samples, func(i, j int) bool { return collector.Samples[i].RowID < collector.Samples[j].RowID }) - collector.CalcTotalSize() +func (e *AnalyzeFastExec) buildColumnStats(ID int64, collector *statistics.SampleCollector, tp *types.FieldType, rowCount int64) (*statistics.Histogram, *statistics.CMSketch, error) { data := make([][]byte, 0, len(collector.Samples)) for i, sample := range collector.Samples { sample.Ordinal = i @@ -968,24 +987,49 @@ func (e *AnalyzeFastExec) buildHist(ID int64, collector *statistics.SampleCollec } bytes, err := tablecodec.EncodeValue(e.ctx.GetSessionVars().StmtCtx, sample.Value) if err != nil { - return nil, err + return nil, nil, err } data = append(data, bytes) } - stats := domain.GetDomain(e.ctx).StatsHandle() - rowCount := int64(e.rowCount) - if stats.Lease > 0 { - rowCount = mathutil.MinInt64(stats.GetTableStats(e.tblInfo).Count, rowCount) - } - // build CMSketch - var ndv, scaleRatio uint64 - collector.CMSketch, ndv, scaleRatio = statistics.NewCMSketchWithTopN(defaultCMSketchDepth, defaultCMSketchWidth, data, 20, uint64(rowCount)) - // build Histogram + // Build CMSketch. + cmSketch, ndv, scaleRatio := statistics.NewCMSketchWithTopN(defaultCMSketchDepth, defaultCMSketchWidth, data, 20, uint64(rowCount)) + // Build Histogram. hist, err := statistics.BuildColumnHist(e.ctx, int64(e.maxNumBuckets), ID, collector, tp, rowCount, int64(ndv), collector.NullCount*int64(scaleRatio)) - if err != nil { - return nil, err + return hist, cmSketch, err +} + +func (e *AnalyzeFastExec) buildIndexStats(idxInfo *model.IndexInfo, collector *statistics.SampleCollector, rowCount int64) (*statistics.Histogram, *statistics.CMSketch, error) { + data := make([][][]byte, len(idxInfo.Columns), len(idxInfo.Columns)) + for _, sample := range collector.Samples { + var preLen int + remained := sample.Value.GetBytes() + // We need to insert each prefix values into CM Sketch. + for i := 0; i < len(idxInfo.Columns); i++ { + var err error + var value []byte + value, remained, err = codec.CutOne(remained) + if err != nil { + return nil, nil, err + } + preLen += len(value) + data[i] = append(data[i], sample.Value.GetBytes()[:preLen]) + } + } + numTop := uint32(20) + cmSketch, ndv, scaleRatio := statistics.NewCMSketchWithTopN(defaultCMSketchDepth, defaultCMSketchWidth, data[0], numTop, uint64(rowCount)) + // Build CM Sketch for each prefix and merge them into one. + for i := 1; i < len(idxInfo.Columns); i++ { + var curCMSketch *statistics.CMSketch + // `ndv` should be the ndv of full index, so just rewrite it here. + curCMSketch, ndv, scaleRatio = statistics.NewCMSketchWithTopN(defaultCMSketchDepth, defaultCMSketchWidth, data[i], numTop, uint64(rowCount)) + err := cmSketch.MergeCMSketch(curCMSketch, numTop) + if err != nil { + return nil, nil, err + } } - return hist, nil + // Build Histogram. + hist, err := statistics.BuildColumnHist(e.ctx, int64(e.maxNumBuckets), idxInfo.ID, collector, types.NewFieldType(mysql.TypeBlob), rowCount, int64(ndv), collector.NullCount*int64(scaleRatio)) + return hist, cmSketch, err } func (e *AnalyzeFastExec) runTasks() ([]*statistics.Histogram, []*statistics.CMSketch, error) { @@ -1021,19 +1065,33 @@ func (e *AnalyzeFastExec) runTasks() ([]*statistics.Histogram, []*statistics.CMS return nil, nil, err } + handle := domain.GetDomain(e.ctx).StatsHandle() + tblStats := handle.GetTableStats(e.tblInfo) + rowCount := int64(e.rowCount) + if handle.Lease() > 0 && !tblStats.Pseudo { + rowCount = mathutil.MinInt64(tblStats.Count, rowCount) + } + // Adjust the row count in case the count of `tblStats` is not accurate and too small. + rowCount = mathutil.MaxInt64(rowCount, int64(e.sampCursor)) hists, cms := make([]*statistics.Histogram, length), make([]*statistics.CMSketch, length) for i := 0; i < length; i++ { + // Build collector properties. + collector := e.collectors[i] + collector.Samples = collector.Samples[:e.sampCursor] + sort.Slice(collector.Samples, func(i, j int) bool { return collector.Samples[i].RowID < collector.Samples[j].RowID }) + collector.CalcTotalSize() + // Scale the total column size. + collector.TotalSize *= rowCount / int64(len(collector.Samples)) if i < hasPKInfo { - hists[i], err = e.buildHist(e.pkInfo.ID, e.collectors[i], &e.pkInfo.FieldType) + hists[i], cms[i], err = e.buildColumnStats(e.pkInfo.ID, e.collectors[i], &e.pkInfo.FieldType, rowCount) } else if i < hasPKInfo+len(e.colsInfo) { - hists[i], err = e.buildHist(e.colsInfo[i-hasPKInfo].ID, e.collectors[i], &e.colsInfo[i-hasPKInfo].FieldType) + hists[i], cms[i], err = e.buildColumnStats(e.colsInfo[i-hasPKInfo].ID, e.collectors[i], &e.colsInfo[i-hasPKInfo].FieldType, rowCount) } else { - hists[i], err = e.buildHist(e.idxsInfo[i-hasPKInfo-len(e.colsInfo)].ID, e.collectors[i], types.NewFieldType(mysql.TypeBlob)) + hists[i], cms[i], err = e.buildIndexStats(e.idxsInfo[i-hasPKInfo-len(e.colsInfo)], e.collectors[i], rowCount) } if err != nil { return nil, nil, err } - cms[i] = e.collectors[i].CMSketch } return hists, cms, nil } @@ -1075,7 +1133,7 @@ func (e *AnalyzeFastExec) buildStats() (hists []*statistics.Histogram, cms []*st } randPos := make([]uint64, 0, MaxSampleSize+1) - for i := 0; i < MaxSampleSize; i++ { + for i := 0; i < int(MaxSampleSize); i++ { randPos = append(randPos, uint64(rander.Int63n(int64(e.rowCount)))) } sort.Slice(randPos, func(i, j int) bool { return randPos[i] < randPos[j] }) @@ -1125,7 +1183,7 @@ type analyzeIndexIncrementalExec struct { func analyzeIndexIncremental(idxExec *analyzeIndexIncrementalExec) analyzeResult { startPos := idxExec.oldHist.GetUpper(idxExec.oldHist.Len() - 1) - values, err := codec.DecodeRange(startPos.GetBytes(), len(idxExec.idxInfo.Columns)) + values, _, err := codec.DecodeRange(startPos.GetBytes(), len(idxExec.idxInfo.Columns)) if err != nil { return analyzeResult{Err: err, job: idxExec.job} } diff --git a/executor/analyze_test.go b/executor/analyze_test.go index edbfd808877c3..b53d5545a560b 100644 --- a/executor/analyze_test.go +++ b/executor/analyze_test.go @@ -18,9 +18,11 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "time" . "github.com/pingcap/check" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/domain" @@ -152,13 +154,13 @@ func (s *testSuite1) TestAnalyzeFastSample(c *C) { ) c.Assert(err, IsNil) var dom *domain.Domain - session.SetStatsLease(0) + session.DisableStats4Test() session.SetSchemaLease(0) dom, err = session.BootstrapSession(store) c.Assert(err, IsNil) tk := testkit.NewTestKit(c, store) - executor.MaxSampleSize = 20 - executor.RandSeed = 123 + atomic.StoreInt64(&executor.MaxSampleSize, 20) + atomic.StoreInt64(&executor.RandSeed, 123) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -223,30 +225,30 @@ func (s *testSuite1) TestFastAnalyze(c *C) { ) c.Assert(err, IsNil) var dom *domain.Domain - session.SetStatsLease(0) + session.DisableStats4Test() session.SetSchemaLease(0) dom, err = session.BootstrapSession(store) c.Assert(err, IsNil) tk := testkit.NewTestKit(c, store) - executor.MaxSampleSize = 1000 - executor.RandSeed = 123 + atomic.StoreInt64(&executor.MaxSampleSize, 6) + atomic.StoreInt64(&executor.RandSeed, 123) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int primary key, b int, index index_b(b))") + tk.MustExec("create table t(a int primary key, b int, c char(10), index index_b(b))") tk.MustExec("set @@session.tidb_enable_fast_analyze=1") tk.MustExec("set @@session.tidb_build_stats_concurrency=1") - for i := 0; i < 3000; i++ { - tk.MustExec(fmt.Sprintf("insert into t values (%d, %d)", i, i)) - } tblInfo, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) tid := tblInfo.Meta().ID - // construct 5 regions split by {600, 1200, 1800, 2400} - splitKeys := generateTableSplitKeyForInt(tid, []int{600, 1200, 1800, 2400}) + // construct 6 regions split by {10, 20, 30, 40, 50} + splitKeys := generateTableSplitKeyForInt(tid, []int{10, 20, 30, 40, 50}) manipulateCluster(cluster, splitKeys) + for i := 0; i < 20; i++ { + tk.MustExec(fmt.Sprintf(`insert into t values (%d, %d, "char")`, i*3, i*3)) + } tk.MustExec("analyze table t with 5 buckets") is := executor.GetInfoSchema(tk.Se.(sessionctx.Context)) @@ -254,49 +256,42 @@ func (s *testSuite1) TestFastAnalyze(c *C) { c.Assert(err, IsNil) tableInfo := table.Meta() tbl := dom.StatsHandle().GetTableStats(tableInfo) - sTbl := fmt.Sprintln(tbl) - matched := false - if sTbl == "Table:39 Count:3000\n"+ - "column:1 ndv:3000 totColSize:0\n"+ - "num: 603 lower_bound: 6 upper_bound: 612 repeats: 1\n"+ - "num: 603 lower_bound: 621 upper_bound: 1205 repeats: 1\n"+ - "num: 603 lower_bound: 1207 upper_bound: 1830 repeats: 1\n"+ - "num: 603 lower_bound: 1831 upper_bound: 2387 repeats: 1\n"+ - "num: 588 lower_bound: 2390 upper_bound: 2997 repeats: 1\n"+ - "column:2 ndv:3000 totColSize:0\n"+ - "num: 603 lower_bound: 6 upper_bound: 612 repeats: 1\n"+ - "num: 603 lower_bound: 621 upper_bound: 1205 repeats: 1\n"+ - "num: 603 lower_bound: 1207 upper_bound: 1830 repeats: 1\n"+ - "num: 603 lower_bound: 1831 upper_bound: 2387 repeats: 1\n"+ - "num: 588 lower_bound: 2390 upper_bound: 2997 repeats: 1\n"+ - "index:1 ndv:3000\n"+ - "num: 603 lower_bound: 6 upper_bound: 612 repeats: 1\n"+ - "num: 603 lower_bound: 621 upper_bound: 1205 repeats: 1\n"+ - "num: 603 lower_bound: 1207 upper_bound: 1830 repeats: 1\n"+ - "num: 603 lower_bound: 1831 upper_bound: 2387 repeats: 1\n"+ - "num: 588 lower_bound: 2390 upper_bound: 2997 repeats: 1\n" || - sTbl == "Table:39 Count:3000\n"+ - "column:2 ndv:3000 totColSize:0\n"+ - "num: 603 lower_bound: 6 upper_bound: 612 repeats: 1\n"+ - "num: 603 lower_bound: 621 upper_bound: 1205 repeats: 1\n"+ - "num: 603 lower_bound: 1207 upper_bound: 1830 repeats: 1\n"+ - "num: 603 lower_bound: 1831 upper_bound: 2387 repeats: 1\n"+ - "num: 588 lower_bound: 2390 upper_bound: 2997 repeats: 1\n"+ - "column:1 ndv:3000 totColSize:0\n"+ - "num: 603 lower_bound: 6 upper_bound: 612 repeats: 1\n"+ - "num: 603 lower_bound: 621 upper_bound: 1205 repeats: 1\n"+ - "num: 603 lower_bound: 1207 upper_bound: 1830 repeats: 1\n"+ - "num: 603 lower_bound: 1831 upper_bound: 2387 repeats: 1\n"+ - "num: 588 lower_bound: 2390 upper_bound: 2997 repeats: 1\n"+ - "index:1 ndv:3000\n"+ - "num: 603 lower_bound: 6 upper_bound: 612 repeats: 1\n"+ - "num: 603 lower_bound: 621 upper_bound: 1205 repeats: 1\n"+ - "num: 603 lower_bound: 1207 upper_bound: 1830 repeats: 1\n"+ - "num: 603 lower_bound: 1831 upper_bound: 2387 repeats: 1\n"+ - "num: 588 lower_bound: 2390 upper_bound: 2997 repeats: 1\n" { - matched = true - } - c.Assert(matched, Equals, true) + c.Assert(tbl.String(), Equals, "Table:43 Count:20\n"+ + "column:1 ndv:20 totColSize:0\n"+ + "num: 6 lower_bound: 3 upper_bound: 15 repeats: 1\n"+ + "num: 7 lower_bound: 18 upper_bound: 33 repeats: 1\n"+ + "num: 7 lower_bound: 39 upper_bound: 57 repeats: 1\n"+ + "column:2 ndv:20 totColSize:0\n"+ + "num: 6 lower_bound: 3 upper_bound: 15 repeats: 1\n"+ + "num: 7 lower_bound: 18 upper_bound: 33 repeats: 1\n"+ + "num: 7 lower_bound: 39 upper_bound: 57 repeats: 1\n"+ + "column:3 ndv:1 totColSize:72\n"+ + "num: 20 lower_bound: char upper_bound: char repeats: 18\n"+ + "index:1 ndv:20\n"+ + "num: 6 lower_bound: 3 upper_bound: 15 repeats: 1\n"+ + "num: 7 lower_bound: 18 upper_bound: 33 repeats: 1\n"+ + "num: 7 lower_bound: 39 upper_bound: 57 repeats: 1") + + // Test CM Sketch built from fast analyze. + tk.MustExec("create table t1(a int, b int, index idx(a, b))") + tk.MustExec("insert into t1 values (1,1),(1,1),(1,2),(1,2)") + tk.MustExec("analyze table t1") + tk.MustQuery("explain select a from t1 where a = 1").Check(testkit.Rows( + "IndexReader_6 4.00 root index:IndexScan_5", + "└─IndexScan_5 4.00 cop table:t1, index:a, b, range:[1,1], keep order:false")) + tk.MustQuery("explain select a, b from t1 where a = 1 and b = 1").Check(testkit.Rows( + "IndexReader_6 2.00 root index:IndexScan_5", + "└─IndexScan_5 2.00 cop table:t1, index:a, b, range:[1 1,1 1], keep order:false")) + tk.MustQuery("explain select a, b from t1 where a = 1 and b = 2").Check(testkit.Rows( + "IndexReader_6 2.00 root index:IndexScan_5", + "└─IndexScan_5 2.00 cop table:t1, index:a, b, range:[1 2,1 2], keep order:false")) + + tk.MustExec("create table t2 (a bigint unsigned, primary key(a))") + tk.MustExec("insert into t2 values (0), (18446744073709551615)") + tk.MustExec("analyze table t2") + tk.MustQuery("show stats_buckets where table_name = 't2'").Check(testkit.Rows( + "test t2 a 0 0 1 1 0 0", + "test t2 a 0 1 2 1 18446744073709551615 18446744073709551615")) } func (s *testSuite1) TestAnalyzeIncremental(c *C) { @@ -414,7 +409,8 @@ func (s *testFastAnalyze) TestFastAnalyzeRetryRowCount(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int primary key, b int, index index_b(b))") + tk.MustExec("create table t(a int primary key)") + c.Assert(s.dom.StatsHandle().Update(s.dom.InfoSchema()), IsNil) tk.MustExec("set @@session.tidb_enable_fast_analyze=1") tk.MustExec("set @@session.tidb_build_stats_concurrency=1") tblInfo, err := s.dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) @@ -424,10 +420,23 @@ func (s *testFastAnalyze) TestFastAnalyzeRetryRowCount(c *C) { splitKeys := generateTableSplitKeyForInt(tid, []int{6, 12, 18, 24, 30}) regionIDs := manipulateCluster(s.cluster, splitKeys) for i := 0; i < 30; i++ { - tk.MustExec(fmt.Sprintf("insert into t values (%d, %d)", i, i)) + tk.MustExec(fmt.Sprintf("insert into t values (%d)", i)) } s.cli.setFailRegion(regionIDs[4]) tk.MustExec("analyze table t") // 4 regions will be sampled, and it will retry the last failed region. c.Assert(s.cli.mu.count, Equals, int64(5)) + row := tk.MustQuery(`show stats_meta where db_name = "test" and table_name = "t"`).Rows()[0] + c.Assert(row[5], Equals, "30") +} + +func (s *testSuite1) TestFailedAnalyzeRequest(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int primary key, b int, index index_b(b))") + c.Assert(failpoint.Enable("github.com/pingcap/tidb/executor/buildStatsFromResult", `return(true)`), IsNil) + _, err := tk.Exec("analyze table t") + c.Assert(err.Error(), Equals, "mock buildStatsFromResult error") + c.Assert(failpoint.Disable("github.com/pingcap/tidb/executor/buildStatsFromResult"), IsNil) } diff --git a/executor/batch_checker.go b/executor/batch_checker.go index 12d2b7c63a834..9b15daf15c415 100644 --- a/executor/batch_checker.go +++ b/executor/batch_checker.go @@ -14,6 +14,8 @@ package executor import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/expression" @@ -284,6 +286,94 @@ func (b *batchChecker) deleteDupKeys(ctx sessionctx.Context, t table.Table, rows return nil } +// getOldRowNew gets the table record row from storage for batch check. +// t could be a normal table or a partition, but it must not be a PartitionedTable. +func (b *batchChecker) getOldRowNew(sctx sessionctx.Context, txn kv.Transaction, t table.Table, handle int64, + genExprs []expression.Expression) ([]types.Datum, error) { + oldValue, err := txn.Get(t.RecordKey(handle)) + if err != nil { + return nil, err + } + + cols := t.WritableCols() + oldRow, oldRowMap, err := tables.DecodeRawRowData(sctx, t.Meta(), handle, cols, oldValue) + if err != nil { + return nil, err + } + // Fill write-only and write-reorg columns with originDefaultValue if not found in oldValue. + gIdx := 0 + for _, col := range cols { + if col.State != model.StatePublic && oldRow[col.Offset].IsNull() { + _, found := oldRowMap[col.ID] + if !found { + oldRow[col.Offset], err = table.GetColOriginDefaultValue(sctx, col.ToInfo()) + if err != nil { + return nil, err + } + } + } + if col.IsGenerated() { + // only the virtual column needs fill back. + if !col.GeneratedStored { + val, err := genExprs[gIdx].Eval(chunk.MutRowFromDatums(oldRow).ToRow()) + if err != nil { + return nil, err + } + oldRow[col.Offset], err = table.CastValue(sctx, val, col.ToInfo()) + if err != nil { + return nil, err + } + } + gIdx++ + } + } + return oldRow, nil +} + +// getOldRow gets the table record row from storage for batch check. +// t could be a normal table or a partition, but it must not be a PartitionedTable. +func getOldRow(ctx context.Context, sctx sessionctx.Context, txn kv.Transaction, t table.Table, handle int64, + genExprs []expression.Expression) ([]types.Datum, error) { + oldValue, err := txn.Get(t.RecordKey(handle)) + if err != nil { + return nil, err + } + + cols := t.WritableCols() + oldRow, oldRowMap, err := tables.DecodeRawRowData(sctx, t.Meta(), handle, cols, oldValue) + if err != nil { + return nil, err + } + // Fill write-only and write-reorg columns with originDefaultValue if not found in oldValue. + gIdx := 0 + for _, col := range cols { + if col.State != model.StatePublic && oldRow[col.Offset].IsNull() { + _, found := oldRowMap[col.ID] + if !found { + oldRow[col.Offset], err = table.GetColOriginDefaultValue(sctx, col.ToInfo()) + if err != nil { + return nil, err + } + } + } + if col.IsGenerated() { + // only the virtual column needs fill back. + if !col.GeneratedStored { + val, err := genExprs[gIdx].Eval(chunk.MutRowFromDatums(oldRow).ToRow()) + if err != nil { + return nil, err + } + oldRow[col.Offset], err = table.CastValue(sctx, val, col.ToInfo()) + if err != nil { + return nil, err + } + } + gIdx++ + } + } + return oldRow, nil +} + // getOldRow gets the table record row from storage for batch check. // t could be a normal table or a partition, but it must not be a PartitionedTable. func (b *batchChecker) getOldRow(ctx sessionctx.Context, t table.Table, handle int64, diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index e13e0eaf6f137..aff3cd4e375a3 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -18,6 +18,7 @@ import ( "fmt" "math/rand" "sort" + "strings" "testing" "github.com/pingcap/parser/ast" @@ -25,6 +26,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" @@ -87,6 +89,8 @@ func (mds *mockDataSource) genColDatums(col int) (results []interface{}) { return results[i].(int64) < results[j].(int64) case mysql.TypeDouble: return results[i].(float64) < results[j].(float64) + case mysql.TypeVarString: + return results[i].(string) < results[j].(string) default: panic("not implement") } @@ -102,6 +106,8 @@ func (mds *mockDataSource) randDatum(typ *types.FieldType) interface{} { return int64(rand.Int()) case mysql.TypeDouble: return rand.Float64() + case mysql.TypeVarString: + return rawData default: panic("not implement") } @@ -115,13 +121,13 @@ func (mds *mockDataSource) prepareChunks() { mds.chunkPtr = 0 } -func (mds *mockDataSource) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (mds *mockDataSource) Next(ctx context.Context, req *chunk.Chunk) error { if mds.chunkPtr >= len(mds.chunks) { req.Reset() return nil } dataChk := mds.chunks[mds.chunkPtr] - dataChk.SwapColumns(req.Chunk) + dataChk.SwapColumns(req) mds.chunkPtr++ return nil } @@ -129,7 +135,7 @@ func (mds *mockDataSource) Next(ctx context.Context, req *chunk.RecordBatch) err func buildMockDataSource(opt mockDataSourceParameters) *mockDataSource { baseExec := newBaseExecutor(opt.ctx, opt.schema, nil) m := &mockDataSource{baseExec, opt, nil, nil, 0} - types := m.retTypes() + types := retTypes(m) colData := make([][]interface{}, len(types)) for i := 0; i < len(types); i++ { colData[i] = m.genColDatums(i) @@ -137,18 +143,20 @@ func buildMockDataSource(opt mockDataSourceParameters) *mockDataSource { m.genData = make([]*chunk.Chunk, (m.p.rows+m.initCap-1)/m.initCap) for i := range m.genData { - m.genData[i] = chunk.NewChunkWithCapacity(m.retTypes(), m.ctx.GetSessionVars().MaxChunkSize) + m.genData[i] = chunk.NewChunkWithCapacity(retTypes(m), m.ctx.GetSessionVars().MaxChunkSize) } for i := 0; i < m.p.rows; i++ { idx := i / m.maxChunkSize - retTypes := m.retTypes() + retTypes := retTypes(m) for colIdx := 0; colIdx < len(types); colIdx++ { switch retTypes[colIdx].Tp { case mysql.TypeLong, mysql.TypeLonglong: m.genData[idx].AppendInt64(colIdx, colData[colIdx][i].(int64)) case mysql.TypeDouble: m.genData[idx].AppendFloat64(colIdx, colData[colIdx][i].(float64)) + case mysql.TypeVarString: + m.genData[idx].AppendString(colIdx, colData[colIdx][i].(string)) default: panic("not implement") } @@ -171,11 +179,12 @@ type aggTestCase struct { func (a aggTestCase) columns() []*expression.Column { return []*expression.Column{ {Index: 0, RetType: types.NewFieldType(mysql.TypeDouble)}, - {Index: 1, RetType: types.NewFieldType(mysql.TypeLonglong)}} + {Index: 1, RetType: types.NewFieldType(mysql.TypeLonglong)}, + } } func (a aggTestCase) String() string { - return fmt.Sprintf("(execType:%v, aggFunc:%v, groupByNDV:%v, hasDistinct:%v, rows:%v, concruuency:%v)", + return fmt.Sprintf("(execType:%v, aggFunc:%v, ndv:%v, hasDistinct:%v, rows:%v, concruuency:%v)", a.execType, a.aggFunc, a.groupByNDV, a.hasDistinct, a.rows, a.concurrency) } @@ -228,7 +237,10 @@ func buildAggExecutor(b *testing.B, testCase *aggTestCase, child Executor) Execu childCols := testCase.columns() schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc := aggregation.NewAggFuncDesc(testCase.ctx, testCase.aggFunc, []expression.Expression{childCols[0]}, testCase.hasDistinct) + aggFunc, err := aggregation.NewAggFuncDesc(testCase.ctx, testCase.aggFunc, []expression.Expression{childCols[0]}, testCase.hasDistinct) + if err != nil { + b.Fatal(err) + } aggFuncs := []*aggregation.AggFuncDesc{aggFunc} var aggExec Executor @@ -259,16 +271,15 @@ func benchmarkAggExecWithCase(b *testing.B, casTest *aggTestCase) { b.StopTimer() // prepare a new agg-executor aggExec := buildAggExecutor(b, casTest, dataSource) tmpCtx := context.Background() - chk := aggExec.newFirstChunk() + chk := newFirstChunk(aggExec) dataSource.prepareChunks() b.StartTimer() if err := aggExec.Open(tmpCtx); err != nil { b.Fatal(err) } - batch := chunk.NewRecordBatch(chk) for { - if err := aggExec.Next(tmpCtx, batch); err != nil { + if err := aggExec.Next(tmpCtx, chk); err != nil { b.Fatal(b) } if chk.NumRows() == 0 { @@ -348,3 +359,147 @@ func BenchmarkAggDistinct(b *testing.B) { } } } + +func buildWindowExecutor(ctx sessionctx.Context, windowFunc string, src Executor, schema *expression.Schema, partitionBy []*expression.Column) Executor { + plan := new(core.PhysicalWindow) + + var args []expression.Expression + switch windowFunc { + case ast.WindowFuncNtile: + args = append(args, &expression.Constant{Value: types.NewUintDatum(2)}) + case ast.WindowFuncNthValue: + args = append(args, partitionBy[0], &expression.Constant{Value: types.NewUintDatum(2)}) + default: + args = append(args, partitionBy[0]) + } + desc, _ := aggregation.NewWindowFuncDesc(ctx, windowFunc, args) + plan.WindowFuncDescs = []*aggregation.WindowFuncDesc{desc} + for _, col := range partitionBy { + plan.PartitionBy = append(plan.PartitionBy, property.Item{Col: col}) + } + plan.OrderBy = nil + plan.SetSchema(schema) + plan.Init(ctx, nil) + plan.SetChildren(nil) + b := newExecutorBuilder(ctx, nil) + exec := b.build(plan) + window := exec.(*WindowExec) + window.children[0] = src + return exec +} + +type windowTestCase struct { + // The test table's schema is fixed (col Double, partitionBy LongLong, rawData VarString(5128), col LongLong). + windowFunc string + ndv int // the number of distinct group-by keys + rows int + ctx sessionctx.Context +} + +var rawData = strings.Repeat("x", 5*1024) + +func (a windowTestCase) columns() []*expression.Column { + rawDataTp := new(types.FieldType) + types.DefaultTypeForValue(rawData, rawDataTp) + return []*expression.Column{ + {Index: 0, RetType: types.NewFieldType(mysql.TypeDouble)}, + {Index: 1, RetType: types.NewFieldType(mysql.TypeLonglong)}, + {Index: 2, RetType: rawDataTp}, + {Index: 3, RetType: types.NewFieldType(mysql.TypeLonglong)}, + } +} + +func (a windowTestCase) String() string { + return fmt.Sprintf("(func:%v, ndv:%v, rows:%v)", + a.windowFunc, a.ndv, a.rows) +} + +func defaultWindowTestCase() *windowTestCase { + ctx := mock.NewContext() + ctx.GetSessionVars().InitChunkSize = variable.DefInitChunkSize + ctx.GetSessionVars().MaxChunkSize = variable.DefMaxChunkSize + return &windowTestCase{ast.WindowFuncRowNumber, 1000, 10000000, ctx} +} + +func benchmarkWindowExecWithCase(b *testing.B, casTest *windowTestCase) { + cols := casTest.columns() + dataSource := buildMockDataSource(mockDataSourceParameters{ + schema: expression.NewSchema(cols...), + ndvs: []int{0, casTest.ndv, 0, 0}, + orders: []bool{false, true, false, false}, + rows: casTest.rows, + ctx: casTest.ctx, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() // prepare a new window-executor + childCols := casTest.columns() + schema := expression.NewSchema(childCols...) + windowExec := buildWindowExecutor(casTest.ctx, casTest.windowFunc, dataSource, schema, childCols[1:2]) + tmpCtx := context.Background() + chk := newFirstChunk(windowExec) + dataSource.prepareChunks() + + b.StartTimer() + if err := windowExec.Open(tmpCtx); err != nil { + b.Fatal(err) + } + for { + if err := windowExec.Next(tmpCtx, chk); err != nil { + b.Fatal(b) + } + if chk.NumRows() == 0 { + break + } + } + + if err := windowExec.Close(); err != nil { + b.Fatal(err) + } + b.StopTimer() + } +} + +func BenchmarkWindowRows(b *testing.B) { + b.ReportAllocs() + rows := []int{1000, 100000} + ndvs := []int{10, 1000} + for _, row := range rows { + for _, ndv := range ndvs { + cas := defaultWindowTestCase() + cas.rows = row + cas.ndv = ndv + cas.windowFunc = ast.WindowFuncRowNumber // cheapest + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkWindowExecWithCase(b, cas) + }) + } + } +} + +func BenchmarkWindowFunctions(b *testing.B) { + b.ReportAllocs() + windowFuncs := []string{ + ast.WindowFuncRowNumber, + ast.WindowFuncRank, + ast.WindowFuncDenseRank, + ast.WindowFuncCumeDist, + ast.WindowFuncPercentRank, + ast.WindowFuncNtile, + ast.WindowFuncLead, + ast.WindowFuncLag, + ast.WindowFuncFirstValue, + ast.WindowFuncLastValue, + ast.WindowFuncNthValue, + } + for _, windowFunc := range windowFuncs { + cas := defaultWindowTestCase() + cas.rows = 100000 + cas.ndv = 1000 + cas.windowFunc = windowFunc + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkWindowExecWithCase(b, cas) + }) + } +} diff --git a/executor/bind.go b/executor/bind.go index d2c2034851bde..6849e4eec9b76 100644 --- a/executor/bind.go +++ b/executor/bind.go @@ -39,7 +39,7 @@ type SQLBindExec struct { } // Next implements the Executor Next interface. -func (e *SQLBindExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *SQLBindExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("SQLBindExec.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() diff --git a/executor/builder.go b/executor/builder.go index e04a51245b824..7a78b50d4b548 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -106,6 +106,12 @@ func (b *executorBuilder) build(p plannercore.Plan) Executor { return b.buildCheckIndexRange(v) case *plannercore.ChecksumTable: return b.buildChecksumTable(v) + case *plannercore.ReloadExprPushdownBlacklist: + return b.buildReloadExprPushdownBlacklist(v) + case *plannercore.ReloadOptRuleBlacklist: + return b.buildReloadOptRuleBlacklist(v) + case *plannercore.AdminPlugins: + return b.buildAdminPlugins(v) case *plannercore.DDL: return b.buildDDL(v) case *plannercore.Deallocate: @@ -194,8 +200,8 @@ func (b *executorBuilder) build(p plannercore.Plan) Executor { return b.buildWindow(v) case *plannercore.SQLBindPlan: return b.buildSQLBindExec(v) - case *plannercore.SplitIndexRegion: - return b.buildSplitIndexRegion(v) + case *plannercore.SplitRegion: + return b.buildSplitRegion(v) default: if mp, ok := p.(MockPhysicalPlan); ok { return mp.GetExecutor() @@ -226,7 +232,8 @@ func (b *executorBuilder) buildCancelDDLJobs(v *plannercore.CancelDDLJobs) Execu func (b *executorBuilder) buildChange(v *plannercore.Change) Executor { return &ChangeExec{ - ChangeStmt: v.ChangeStmt, + baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID()), + ChangeStmt: v.ChangeStmt, } } @@ -302,8 +309,8 @@ func (b *executorBuilder) buildCheckIndex(v *plannercore.CheckIndex) Executor { b.err = err return nil } - readerExec.ranges = ranger.FullRange() - readerExec.isCheckOp = true + + buildIndexLookUpChecker(b, v.IndexLookUpReader, readerExec) e := &CheckIndexExec{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID()), @@ -316,12 +323,59 @@ func (b *executorBuilder) buildCheckIndex(v *plannercore.CheckIndex) Executor { return e } +// buildIndexLookUpChecker builds check information to IndexLookUpReader. +func buildIndexLookUpChecker(b *executorBuilder, readerPlan *plannercore.PhysicalIndexLookUpReader, + readerExec *IndexLookUpExecutor) { + is := readerPlan.IndexPlans[0].(*plannercore.PhysicalIndexScan) + readerExec.dagPB.OutputOffsets = make([]uint32, 0, len(is.Index.Columns)) + for i := 0; i <= len(is.Index.Columns); i++ { + readerExec.dagPB.OutputOffsets = append(readerExec.dagPB.OutputOffsets, uint32(i)) + } + readerExec.ranges = ranger.FullRange() + ts := readerPlan.TablePlans[0].(*plannercore.PhysicalTableScan) + readerExec.handleIdx = ts.HandleIdx + + tps := make([]*types.FieldType, 0, len(is.Columns)+1) + for _, col := range is.Columns { + tps = append(tps, &col.FieldType) + } + tps = append(tps, types.NewFieldType(mysql.TypeLonglong)) + readerExec.checkIndexValue = &checkIndexValue{genExprs: is.GenExprs, idxColTps: tps} + + colNames := make([]string, 0, len(is.Columns)) + for _, col := range is.Columns { + colNames = append(colNames, col.Name.O) + } + var err error + readerExec.idxTblCols, err = table.FindCols(readerExec.table.Cols(), colNames, true) + if err != nil { + b.err = errors.Trace(err) + return + } +} + func (b *executorBuilder) buildCheckTable(v *plannercore.CheckTable) Executor { + readerExecs := make([]*IndexLookUpExecutor, 0, len(v.IndexLookUpReaders)) + for _, readerPlan := range v.IndexLookUpReaders { + readerExec, err := buildNoRangeIndexLookUpReader(b, readerPlan) + if err != nil { + b.err = errors.Trace(err) + return nil + } + buildIndexLookUpChecker(b, readerPlan, readerExec) + + readerExecs = append(readerExecs, readerExec) + } + e := &CheckTableExec{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID()), - tables: v.Tables, + dbName: v.DBName, + table: v.Table, + indexInfos: v.IndexInfos, is: b.is, - genExprs: v.GenExprs, + srcs: readerExecs, + exitCh: make(chan struct{}), + retCh: make(chan error, len(readerExecs)), } return e } @@ -461,6 +515,18 @@ func (b *executorBuilder) buildChecksumTable(v *plannercore.ChecksumTable) Execu return e } +func (b *executorBuilder) buildReloadExprPushdownBlacklist(v *plannercore.ReloadExprPushdownBlacklist) Executor { + return &ReloadExprPushdownBlacklistExec{baseExecutor{ctx: b.ctx}} +} + +func (b *executorBuilder) buildReloadOptRuleBlacklist(v *plannercore.ReloadOptRuleBlacklist) Executor { + return &ReloadOptRuleBlacklistExec{baseExecutor{ctx: b.ctx}} +} + +func (b *executorBuilder) buildAdminPlugins(v *plannercore.AdminPlugins) Executor { + return &AdminPluginsExec{baseExecutor: baseExecutor{ctx: b.ctx}, Action: v.Action, Plugins: v.Plugins} +} + func (b *executorBuilder) buildDeallocate(v *plannercore.Deallocate) Executor { base := newBaseExecutor(b.ctx, nil, v.ExplainID()) base.initCap = chunk.ZeroCapacity @@ -541,6 +607,7 @@ func (b *executorBuilder) buildShow(v *plannercore.Show) Executor { DBName: model.NewCIStr(v.DBName), Table: v.Table, Column: v.Column, + IndexName: v.IndexName, User: v.User, Roles: v.Roles, IfNotExists: v.IfNotExists, @@ -550,7 +617,12 @@ func (b *executorBuilder) buildShow(v *plannercore.Show) Executor { is: b.is, } if e.Tp == ast.ShowGrants && e.User == nil { - e.User = e.ctx.GetSessionVars().User + // The input is a "show grants" statement, fulfill the user and roles field. + // Note: "show grants" result are different from "show grants for current_user", + // The former determine privileges with roles, while the later doesn't. + vars := e.ctx.GetSessionVars() + e.User = vars.User + e.Roles = vars.ActiveRoles } if e.Tp == ast.ShowMasterStatus { // show master status need start ts. @@ -558,14 +630,7 @@ func (b *executorBuilder) buildShow(v *plannercore.Show) Executor { b.err = err } } - if len(v.Conditions) == 0 { - return e - } - sel := &SelectionExec{ - baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), e), - filters: v.Conditions, - } - return sel + return e } func (b *executorBuilder) buildSimple(v *plannercore.Simple) Executor { @@ -783,8 +848,9 @@ func (b *executorBuilder) buildUnionScanFromReader(reader Executor, v *plannerco // GetDirtyDB() is safe here. If this table has been modified in the transaction, non-nil DirtyTable // can be found in DirtyDB now, so GetDirtyTable is safe; if this table has not been modified in the // transaction, empty DirtyTable would be inserted into DirtyDB, it does not matter when multiple - // goroutines write empty DirtyTable to DirtyDB for this table concurrently. Thus we don't use lock - // to synchronize here. + // goroutines write empty DirtyTable to DirtyDB for this table concurrently. Although the DirtyDB looks + // safe for data race in all the cases, the map of golang will throw panic when it's accessed in parallel. + // So we lock it when getting dirty table. physicalTableID := getPhysicalTableID(x.table) us.dirty = GetDirtyDB(b.ctx).GetDirtyTable(physicalTableID) us.conditions = v.Conditions @@ -861,8 +927,8 @@ func (b *executorBuilder) buildMergeJoin(v *plannercore.PhysicalMergeJoin) Execu v.JoinType == plannercore.RightOuterJoin, defaultValues, v.OtherConditions, - leftExec.retTypes(), - rightExec.retTypes(), + retTypes(leftExec), + retTypes(rightExec), ), isOuterJoin: v.JoinType.IsOuterJoin(), } @@ -935,7 +1001,7 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo } defaultValues := v.DefaultValues - lhsTypes, rhsTypes := leftExec.retTypes(), rightExec.retTypes() + lhsTypes, rhsTypes := retTypes(leftExec), retTypes(rightExec) if v.InnerChildIdx == 0 { if len(v.LeftConditions) > 0 { b.err = errors.Annotate(ErrBuildExecutor, "join's inner condition should be empty") @@ -969,9 +1035,6 @@ func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) Executo v.OtherConditions, lhsTypes, rhsTypes) } executorCountHashJoinExec.Inc() - if e.ctx.GetSessionVars().EnableRadixJoin { - return &RadixHashJoinExec{HashJoinExec: e} - } return e } @@ -1009,7 +1072,7 @@ func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) Executor if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) { e.defaultVal = nil } else { - e.defaultVal = chunk.NewChunkWithCapacity(e.retTypes(), 1) + e.defaultVal = chunk.NewChunkWithCapacity(retTypes(e), 1) } for _, aggDesc := range v.AggFuncs { if aggDesc.HasDistinct { @@ -1068,7 +1131,7 @@ func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) Execu if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) { e.defaultVal = nil } else { - e.defaultVal = chunk.NewChunkWithCapacity(e.retTypes(), 1) + e.defaultVal = chunk.NewChunkWithCapacity(retTypes(e), 1) } for i, aggDesc := range v.AggFuncs { aggFunc := aggfuncs.Build(b.ctx, aggDesc, i) @@ -1209,7 +1272,7 @@ func (b *executorBuilder) buildApply(v *plannercore.PhysicalApply) *NestedLoopAp defaultValues = make([]types.Datum, v.Children()[v.InnerChildIdx].Schema().Len()) } tupleJoiner := newJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, - defaultValues, otherConditions, leftChild.retTypes(), rightChild.retTypes()) + defaultValues, otherConditions, retTypes(leftChild), retTypes(rightChild)) outerExec, innerExec := leftChild, rightChild outerFilter, innerFilter := v.LeftConditions, v.RightConditions if v.InnerChildIdx == 0 { @@ -1256,14 +1319,34 @@ func (b *executorBuilder) buildUnionAll(v *plannercore.PhysicalUnionAll) Executo return e } -func (b *executorBuilder) buildSplitIndexRegion(v *plannercore.SplitIndexRegion) Executor { - base := newBaseExecutor(b.ctx, nil, v.ExplainID()) - base.initCap = chunk.ZeroCapacity - return &SplitIndexRegionExec{ +func (b *executorBuilder) buildSplitRegion(v *plannercore.SplitRegion) Executor { + base := newBaseExecutor(b.ctx, v.Schema(), v.ExplainID()) + base.initCap = 1 + base.maxChunkSize = 1 + if v.IndexInfo != nil { + return &SplitIndexRegionExec{ + baseExecutor: base, + tableInfo: v.TableInfo, + indexInfo: v.IndexInfo, + lower: v.Lower, + upper: v.Upper, + num: v.Num, + valueLists: v.ValueLists, + } + } + if len(v.ValueLists) > 0 { + return &SplitTableRegionExec{ + baseExecutor: base, + tableInfo: v.TableInfo, + valueLists: v.ValueLists, + } + } + return &SplitTableRegionExec{ baseExecutor: base, - table: v.Table, - indexInfo: v.IndexInfo, - valueLists: v.ValueLists, + tableInfo: v.TableInfo, + lower: v.Lower[0], + upper: v.Upper[0], + num: v.Num, } } @@ -1673,7 +1756,7 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) if b.err != nil { return nil } - outerTypes := outerExec.retTypes() + outerTypes := retTypes(outerExec) innerPlan := v.Children()[1-v.OuterIndex] innerTypes := make([]*types.FieldType, innerPlan.Schema().Len()) for i, col := range innerPlan.Schema().Columns { @@ -1704,6 +1787,13 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) if defaultValues == nil { defaultValues = make([]types.Datum, len(innerTypes)) } + hasPrefixCol := false + for _, l := range v.IdxColLens { + if l != types.UnspecifiedLength { + hasPrefixCol = true + break + } + } e := &IndexLookUpJoin{ baseExecutor: newBaseExecutor(b.ctx, v.Schema(), v.ExplainID(), outerExec), outerCtx: outerCtx{ @@ -1713,6 +1803,8 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) innerCtx: innerCtx{ readerBuilder: &dataReaderBuilder{Plan: innerPlan, executorBuilder: b}, rowTypes: innerTypes, + colLens: v.IdxColLens, + hasPrefixCol: hasPrefixCol, }, workerWg: new(sync.WaitGroup), joiner: newJoiner(b.ctx, v.JoinType, v.OuterIndex == 1, defaultValues, v.OtherConditions, leftTypes, rightTypes), @@ -1731,7 +1823,7 @@ func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) innerKeyCols[i] = v.InnerJoinKeys[i].Index } e.innerCtx.keyCols = innerKeyCols - e.joinResult = e.newFirstChunk() + e.joinResult = newFirstChunk(e) executorCounterIndexLookUpJoin.Inc() return e } @@ -1754,7 +1846,8 @@ func buildNoRangeTableReader(b *executorBuilder, v *plannercore.PhysicalTableRea } ts := v.TablePlans[0].(*plannercore.PhysicalTableScan) tbl, _ := b.is.TableByID(ts.Table.ID) - if isPartition, physicalTableID := ts.IsPartition(); isPartition { + isPartition, physicalTableID := ts.IsPartition() + if isPartition { pt := tbl.(table.PartitionedTable) tbl = pt.GetPartition(physicalTableID) } @@ -1862,7 +1955,7 @@ func (b *executorBuilder) buildIndexReader(v *plannercore.PhysicalIndexReader) * is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) ret.ranges = is.Ranges sctx := b.ctx.GetSessionVars().StmtCtx - sctx.IndexIDs = append(sctx.IndexIDs, is.Index.ID) + sctx.IndexNames = append(sctx.IndexNames, is.Table.Name.O+":"+is.Index.Name.O) return ret } @@ -1907,6 +2000,7 @@ func buildNoRangeIndexLookUpReader(b *executorBuilder, v *plannercore.PhysicalIn colLens: is.IdxColLens, idxPlans: v.IndexPlans, tblPlans: v.TablePlans, + PushedLimit: v.PushedLimit, } if containsLimit(indexReq.Executors) { @@ -1941,7 +2035,7 @@ func (b *executorBuilder) buildIndexLookUpReader(v *plannercore.PhysicalIndexLoo ret.ranges = is.Ranges executorCounterIndexLookUpExecutor.Inc() sctx := b.ctx.GetSessionVars().StmtCtx - sctx.IndexIDs = append(sctx.IndexIDs, is.Index.ID) + sctx.IndexNames = append(sctx.IndexNames, ts.Table.Name.O+":"+is.Index.Name.O) sctx.TableIDs = append(sctx.TableIDs, ts.Table.ID) return ret } @@ -1985,7 +2079,7 @@ func (builder *dataReaderBuilder) buildUnionScanForIndexJoin(ctx context.Context return nil, err } us := e.(*UnionScanExec) - us.snapshotChunkBuffer = us.newFirstChunk() + us.snapshotChunkBuffer = newFirstChunk(us) return us, nil } @@ -2020,7 +2114,7 @@ func (builder *dataReaderBuilder) buildTableReaderFromHandles(ctx context.Contex return nil, err } e.resultHandler = &tableResultHandler{} - result, err := builder.SelectResult(ctx, builder.ctx, kvReq, e.retTypes(), e.feedback, getPhysicalPlanIDs(e.plans)) + result, err := builder.SelectResult(ctx, builder.ctx, kvReq, retTypes(e), e.feedback, getPhysicalPlanIDs(e.plans)) if err != nil { return nil, err } @@ -2122,7 +2216,11 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) *WindowExec partialResults := make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs)) resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs) for _, desc := range v.WindowFuncDescs { - aggDesc := aggregation.NewAggFuncDesc(b.ctx, desc.Name, desc.Args, false) + aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, desc.Name, desc.Args, false) + if err != nil { + b.err = err + return nil + } agg := aggfuncs.BuildWindowFunctions(b.ctx, aggDesc, resultColIdx, orderByCols) windowFuncs = append(windowFuncs, agg) partialResults = append(partialResults, agg.AllocPartialResult()) diff --git a/executor/change.go b/executor/change.go index dbf73b0c9b9e0..bb9ec0cf1cee0 100644 --- a/executor/change.go +++ b/executor/change.go @@ -31,7 +31,7 @@ type ChangeExec struct { } // Next implements the Executor Next interface. -func (e *ChangeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ChangeExec) Next(ctx context.Context, req *chunk.Chunk) error { kind := strings.ToLower(e.NodeType) urls := config.GetGlobalConfig().Path registry, err := createRegistry(urls) diff --git a/executor/checksum.go b/executor/checksum.go index e716446577300..c84579fe85ee8 100644 --- a/executor/checksum.go +++ b/executor/checksum.go @@ -83,7 +83,7 @@ func (e *ChecksumTableExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *ChecksumTableExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ChecksumTableExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.done { return nil @@ -187,27 +187,47 @@ func newChecksumContext(db *model.DBInfo, table *model.TableInfo, startTs uint64 } func (c *checksumContext) BuildRequests(ctx sessionctx.Context) ([]*kv.Request, error) { - reqs := make([]*kv.Request, 0, len(c.TableInfo.Indices)+1) - req, err := c.buildTableRequest(ctx) - if err != nil { + var partDefs []model.PartitionDefinition + if part := c.TableInfo.Partition; part != nil { + partDefs = part.Definitions + } + + reqs := make([]*kv.Request, 0, (len(c.TableInfo.Indices)+1)*(len(partDefs)+1)) + if err := c.appendRequest(ctx, c.TableInfo.ID, &reqs); err != nil { return nil, err } - reqs = append(reqs, req) + + for _, partDef := range partDefs { + if err := c.appendRequest(ctx, partDef.ID, &reqs); err != nil { + return nil, err + } + } + + return reqs, nil +} + +func (c *checksumContext) appendRequest(ctx sessionctx.Context, tableID int64, reqs *[]*kv.Request) error { + req, err := c.buildTableRequest(ctx, tableID) + if err != nil { + return err + } + + *reqs = append(*reqs, req) for _, indexInfo := range c.TableInfo.Indices { if indexInfo.State != model.StatePublic { continue } - req, err = c.buildIndexRequest(ctx, indexInfo) + req, err = c.buildIndexRequest(ctx, tableID, indexInfo) if err != nil { - return nil, err + return err } - reqs = append(reqs, req) + *reqs = append(*reqs, req) } - return reqs, nil + return nil } -func (c *checksumContext) buildTableRequest(ctx sessionctx.Context) (*kv.Request, error) { +func (c *checksumContext) buildTableRequest(ctx sessionctx.Context, tableID int64) (*kv.Request, error) { checksum := &tipb.ChecksumRequest{ StartTs: c.StartTs, ScanOn: tipb.ChecksumScanOn_Table, @@ -217,13 +237,13 @@ func (c *checksumContext) buildTableRequest(ctx sessionctx.Context) (*kv.Request ranges := ranger.FullIntRange(false) var builder distsql.RequestBuilder - return builder.SetTableRanges(c.TableInfo.ID, ranges, nil). + return builder.SetTableRanges(tableID, ranges, nil). SetChecksumRequest(checksum). SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency). Build() } -func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, indexInfo *model.IndexInfo) (*kv.Request, error) { +func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, tableID int64, indexInfo *model.IndexInfo) (*kv.Request, error) { checksum := &tipb.ChecksumRequest{ StartTs: c.StartTs, ScanOn: tipb.ChecksumScanOn_Index, @@ -233,7 +253,7 @@ func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, indexInfo *m ranges := ranger.FullRange() var builder distsql.RequestBuilder - return builder.SetIndexRanges(ctx.GetSessionVars().StmtCtx, c.TableInfo.ID, indexInfo.ID, ranges). + return builder.SetIndexRanges(ctx.GetSessionVars().StmtCtx, tableID, indexInfo.ID, ranges). SetChecksumRequest(checksum). SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency). Build() diff --git a/executor/compiler.go b/executor/compiler.go index 7f1c85c691011..455f924975fcd 100644 --- a/executor/compiler.go +++ b/executor/compiler.go @@ -15,7 +15,6 @@ package executor import ( "context" - "fmt" "strings" "github.com/opentracing/opentracing-go" @@ -76,70 +75,59 @@ func (c *Compiler) compile(ctx context.Context, stmtNode ast.StmtNode, skipBind return nil, err } - finalPlan, err := planner.Optimize(c.Ctx, stmtNode, infoSchema) + finalPlan, err := planner.Optimize(ctx, c.Ctx, stmtNode, infoSchema) if err != nil { return nil, err } CountStmtNode(stmtNode, c.Ctx.GetSessionVars().InRestrictedSQL) - isExpensive := logExpensiveQuery(stmtNode, finalPlan) - + lowerPriority := needLowerPriority(finalPlan) return &ExecStmt{ - InfoSchema: infoSchema, - Plan: finalPlan, - Expensive: isExpensive, - Cacheable: plannercore.Cacheable(stmtNode), - Text: stmtNode.Text(), - StmtNode: stmtNode, - Ctx: c.Ctx, + InfoSchema: infoSchema, + Plan: finalPlan, + LowerPriority: lowerPriority, + Cacheable: plannercore.Cacheable(stmtNode), + Text: stmtNode.Text(), + StmtNode: stmtNode, + Ctx: c.Ctx, }, nil } -func logExpensiveQuery(stmtNode ast.StmtNode, finalPlan plannercore.Plan) (expensive bool) { - expensive = isExpensiveQuery(finalPlan) - if !expensive { - return - } - - const logSQLLen = 1024 - sql := stmtNode.Text() - if len(sql) > logSQLLen { - sql = fmt.Sprintf("%s len(%d)", sql[:logSQLLen], len(sql)) - } - logutil.Logger(context.Background()).Warn("EXPENSIVE_QUERY", zap.String("SQL", sql)) - return -} - -func isExpensiveQuery(p plannercore.Plan) bool { +// needLowerPriority checks whether it's needed to lower the execution priority +// of a query. +// If the estimated output row count of any operator in the physical plan tree +// is greater than the specific threshold, we'll set it to lowPriority when +// sending it to the coprocessor. +func needLowerPriority(p plannercore.Plan) bool { switch x := p.(type) { case plannercore.PhysicalPlan: - return isPhysicalPlanExpensive(x) + return isPhysicalPlanNeedLowerPriority(x) case *plannercore.Execute: - return isExpensiveQuery(x.Plan) + return needLowerPriority(x.Plan) case *plannercore.Insert: if x.SelectPlan != nil { - return isPhysicalPlanExpensive(x.SelectPlan) + return isPhysicalPlanNeedLowerPriority(x.SelectPlan) } case *plannercore.Delete: if x.SelectPlan != nil { - return isPhysicalPlanExpensive(x.SelectPlan) + return isPhysicalPlanNeedLowerPriority(x.SelectPlan) } case *plannercore.Update: if x.SelectPlan != nil { - return isPhysicalPlanExpensive(x.SelectPlan) + return isPhysicalPlanNeedLowerPriority(x.SelectPlan) } } return false } -func isPhysicalPlanExpensive(p plannercore.PhysicalPlan) bool { - expensiveRowThreshold := int64(config.GetGlobalConfig().Log.ExpensiveThreshold) - if int64(p.StatsCount()) > expensiveRowThreshold { +func isPhysicalPlanNeedLowerPriority(p plannercore.PhysicalPlan) bool { + expensiveThreshold := int64(config.GetGlobalConfig().Log.ExpensiveThreshold) + if int64(p.StatsCount()) > expensiveThreshold { return true } for _, child := range p.Children() { - if isPhysicalPlanExpensive(child) { + if isPhysicalPlanNeedLowerPriority(child) { return true } } @@ -416,6 +404,7 @@ func addHintForSelect(hash, normdOrigSQL string, ctx sessionctx.Context, stmt as return stmt } if bindRecord.Status == bindinfo.Using { + metrics.BindUsageCounter.WithLabelValues(metrics.ScopeSession).Inc() return bindinfo.BindHint(stmt, bindRecord.Ast) } } @@ -425,6 +414,7 @@ func addHintForSelect(hash, normdOrigSQL string, ctx sessionctx.Context, stmt as bindRecord = globalHandle.GetBindRecord(hash, normdOrigSQL, "") } if bindRecord != nil { + metrics.BindUsageCounter.WithLabelValues(metrics.ScopeGlobal).Inc() return bindinfo.BindHint(stmt, bindRecord.Ast) } return stmt diff --git a/executor/ddl.go b/executor/ddl.go index c7fbc4e79781a..018bdc2bf6358 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -72,7 +72,7 @@ func (e *DDLExec) toErr(err error) error { } // Next implements the Executor Next interface. -func (e *DDLExec) Next(ctx context.Context, req *chunk.RecordBatch) (err error) { +func (e *DDLExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { if e.done { return nil } diff --git a/executor/ddl_test.go b/executor/ddl_test.go index bc43c6bd8dbf4..7a5860dcebb09 100644 --- a/executor/ddl_test.go +++ b/executor/ddl_test.go @@ -17,16 +17,21 @@ import ( "context" "fmt" "math" + "strconv" "strings" "time" . "github.com/pingcap/check" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/ddl" ddlutil "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/meta/autoid" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx" @@ -90,8 +95,8 @@ func (s *testSuite3) TestCreateTable(c *C) { rs, err := tk.Exec(`desc issue312_1`) c.Assert(err, IsNil) ctx := context.Background() - req := rs.NewRecordBatch() - it := chunk.NewIterator4Chunk(req.Chunk) + req := rs.NewChunk() + it := chunk.NewIterator4Chunk(req) for { err1 := rs.Next(ctx, req) c.Assert(err1, IsNil) @@ -104,8 +109,8 @@ func (s *testSuite3) TestCreateTable(c *C) { } rs, err = tk.Exec(`desc issue312_2`) c.Assert(err, IsNil) - req = rs.NewRecordBatch() - it = chunk.NewIterator4Chunk(req.Chunk) + req = rs.NewChunk() + it = chunk.NewIterator4Chunk(req) for { err1 := rs.Next(ctx, req) c.Assert(err1, IsNil) @@ -117,6 +122,36 @@ func (s *testSuite3) TestCreateTable(c *C) { } } + // test multiple collate specified in column when create. + tk.MustExec("drop table if exists test_multiple_column_collate;") + tk.MustExec("create table test_multiple_column_collate (a char(1) collate utf8_bin collate utf8_general_ci) charset utf8mb4 collate utf8mb4_bin") + t, err := domain.GetDomain(tk.Se).InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("test_multiple_column_collate")) + c.Assert(err, IsNil) + c.Assert(t.Cols()[0].Charset, Equals, "utf8") + c.Assert(t.Cols()[0].Collate, Equals, "utf8_general_ci") + c.Assert(t.Meta().Charset, Equals, "utf8mb4") + c.Assert(t.Meta().Collate, Equals, "utf8mb4_bin") + + tk.MustExec("drop table if exists test_multiple_column_collate;") + tk.MustExec("create table test_multiple_column_collate (a char(1) charset utf8 collate utf8_bin collate utf8_general_ci) charset utf8mb4 collate utf8mb4_bin") + t, err = domain.GetDomain(tk.Se).InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("test_multiple_column_collate")) + c.Assert(err, IsNil) + c.Assert(t.Cols()[0].Charset, Equals, "utf8") + c.Assert(t.Cols()[0].Collate, Equals, "utf8_general_ci") + c.Assert(t.Meta().Charset, Equals, "utf8mb4") + c.Assert(t.Meta().Collate, Equals, "utf8mb4_bin") + + // test Err case for multiple collate specified in column when create. + tk.MustExec("drop table if exists test_err_multiple_collate;") + _, err = tk.Exec("create table test_err_multiple_collate (a char(1) charset utf8mb4 collate utf8_unicode_ci collate utf8_general_ci) charset utf8mb4 collate utf8mb4_bin") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, ddl.ErrCollationCharsetMismatch.GenWithStackByArgs("utf8_unicode_ci", "utf8mb4").Error()) + + tk.MustExec("drop table if exists test_err_multiple_collate;") + _, err = tk.Exec("create table test_err_multiple_collate (a char(1) collate utf8_unicode_ci collate utf8mb4_general_ci) charset utf8mb4 collate utf8mb4_bin") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, ddl.ErrCollationCharsetMismatch.GenWithStackByArgs("utf8mb4_general_ci", "utf8").Error()) + // table option is auto-increment tk.MustExec("drop table if exists create_auto_increment_test;") tk.MustExec("create table create_auto_increment_test (id int not null auto_increment, name varchar(255), primary key(id)) auto_increment = 999;") @@ -187,6 +222,36 @@ func (s *testSuite3) TestCreateView(c *C) { tk.MustExec("create table if not exists t1 (a int ,b int)") _, err = tk.Exec("create or replace view t1 as select * from t1") c.Assert(err.Error(), Equals, ddl.ErrWrongObject.GenWithStackByArgs("test", "t1", "VIEW").Error()) + // create view using prepare + tk.MustExec(`prepare stmt from "create view v10 (x) as select 1";`) + tk.MustExec("execute stmt") + + // create view on union + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("drop view if exists v") + _, err = tk.Exec("create view v as select * from t1 union select * from t2") + c.Assert(terror.ErrorEqual(err, infoschema.ErrTableNotExists), IsTrue) + tk.MustExec("create table t1(a int, b int)") + tk.MustExec("create table t2(a int, b int)") + tk.MustExec("insert into t1 values(1,2), (1,1), (1,2)") + tk.MustExec("insert into t2 values(1,1),(1,3)") + tk.MustExec("create definer='root'@'localhost' view v as select * from t1 union select * from t2") + tk.MustQuery("select * from v").Sort().Check(testkit.Rows("1 1", "1 2", "1 3")) + tk.MustExec("alter table t1 drop column a") + _, err = tk.Exec("select * from v") + c.Assert(terror.ErrorEqual(err, plannercore.ErrViewInvalid), IsTrue) + tk.MustExec("alter table t1 add column a int") + tk.MustQuery("select * from v").Sort().Check(testkit.Rows("1 1", "1 3", " 1", " 2")) + tk.MustExec("alter table t1 drop column a") + tk.MustExec("alter table t2 drop column b") + _, err = tk.Exec("select * from v") + c.Assert(terror.ErrorEqual(err, plannercore.ErrViewInvalid), IsTrue) + tk.MustExec("drop view v") + + tk.MustExec("create view v as (select * from t1)") + tk.MustExec("drop view v") + tk.MustExec("create view v as (select * from t1 union select * from t2)") + tk.MustExec("drop view v") } func (s *testSuite3) TestCreateDropDatabase(c *C) { @@ -254,7 +319,7 @@ func (s *testSuite3) TestAlterTableAddColumn(c *C) { now := time.Now().Add(-time.Duration(1 * time.Millisecond)).Format(types.TimeFormat) r, err := tk.Exec("select c2 from alter_test") c.Assert(err, IsNil) - req := r.NewRecordBatch() + req := r.NewChunk() err = r.Next(context.Background(), req) c.Assert(err, IsNil) row := req.GetRow(0) @@ -294,7 +359,7 @@ func (s *testSuite3) TestAlterTableModifyColumn(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists mc") - tk.MustExec("create table mc(c1 int, c2 varchar(10))") + tk.MustExec("create table mc(c1 int, c2 varchar(10), c3 bit)") _, err := tk.Exec("alter table mc modify column c1 short") c.Assert(err, NotNil) tk.MustExec("alter table mc modify column c1 bigint") @@ -307,14 +372,52 @@ func (s *testSuite3) TestAlterTableModifyColumn(c *C) { tk.MustExec("alter table mc modify column c2 varchar(11)") tk.MustExec("alter table mc modify column c2 text(13)") tk.MustExec("alter table mc modify column c2 text") + tk.MustExec("alter table mc modify column c3 bit") result := tk.MustQuery("show create table mc") createSQL := result.Rows()[0][1] - expected := "CREATE TABLE `mc` (\n `c1` bigint(20) DEFAULT NULL,\n `c2` text DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin" + expected := "CREATE TABLE `mc` (\n `c1` bigint(20) DEFAULT NULL,\n `c2` text DEFAULT NULL,\n `c3` bit(1) DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin" c.Assert(createSQL, Equals, expected) tk.MustExec("create or replace view alter_view as select c1,c2 from mc") _, err = tk.Exec("alter table alter_view modify column c2 text") c.Assert(err.Error(), Equals, ddl.ErrWrongObject.GenWithStackByArgs("test", "alter_view", "BASE TABLE").Error()) tk.MustExec("drop view alter_view") + + // test multiple collate modification in column. + tk.MustExec("drop table if exists modify_column_multiple_collate") + tk.MustExec("create table modify_column_multiple_collate (a char(1) collate utf8_bin collate utf8_general_ci) charset utf8mb4 collate utf8mb4_bin") + _, err = tk.Exec("alter table modify_column_multiple_collate modify column a char(1) collate utf8mb4_bin;") + c.Assert(err, IsNil) + t, err := domain.GetDomain(tk.Se).InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("modify_column_multiple_collate")) + c.Assert(err, IsNil) + c.Assert(t.Cols()[0].Charset, Equals, "utf8mb4") + c.Assert(t.Cols()[0].Collate, Equals, "utf8mb4_bin") + c.Assert(t.Meta().Charset, Equals, "utf8mb4") + c.Assert(t.Meta().Collate, Equals, "utf8mb4_bin") + + tk.MustExec("drop table if exists modify_column_multiple_collate;") + tk.MustExec("create table modify_column_multiple_collate (a char(1) collate utf8_bin collate utf8_general_ci) charset utf8mb4 collate utf8mb4_bin") + _, err = tk.Exec("alter table modify_column_multiple_collate modify column a char(1) charset utf8mb4 collate utf8mb4_bin;") + c.Assert(err, IsNil) + t, err = domain.GetDomain(tk.Se).InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("modify_column_multiple_collate")) + c.Assert(err, IsNil) + c.Assert(t.Cols()[0].Charset, Equals, "utf8mb4") + c.Assert(t.Cols()[0].Collate, Equals, "utf8mb4_bin") + c.Assert(t.Meta().Charset, Equals, "utf8mb4") + c.Assert(t.Meta().Collate, Equals, "utf8mb4_bin") + + // test Err case for multiple collate modification in column. + tk.MustExec("drop table if exists err_modify_multiple_collate;") + tk.MustExec("create table err_modify_multiple_collate (a char(1) collate utf8_bin collate utf8_general_ci) charset utf8mb4 collate utf8mb4_bin") + _, err = tk.Exec("alter table err_modify_multiple_collate modify column a char(1) charset utf8mb4 collate utf8_bin;") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, ddl.ErrCollationCharsetMismatch.GenWithStackByArgs("utf8_bin", "utf8mb4").Error()) + + tk.MustExec("drop table if exists err_modify_multiple_collate;") + tk.MustExec("create table err_modify_multiple_collate (a char(1) collate utf8_bin collate utf8_general_ci) charset utf8mb4 collate utf8mb4_bin") + _, err = tk.Exec("alter table err_modify_multiple_collate modify column a char(1) collate utf8_bin collate utf8mb4_bin;") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, ddl.ErrCollationCharsetMismatch.GenWithStackByArgs("utf8mb4_bin", "utf8").Error()) + } func (s *testSuite3) TestDefaultDBAfterDropCurDB(c *C) { @@ -337,6 +440,10 @@ func (s *testSuite3) TestDefaultDBAfterDropCurDB(c *C) { } func (s *testSuite3) TestRenameTable(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange"), IsNil) + }() tk := testkit.NewTestKit(c, s.store) tk.MustExec("create database rename1") @@ -519,38 +626,98 @@ func (s *testSuite3) TestShardRowIDBits(c *C) { for i := 0; i < 100; i++ { tk.MustExec(fmt.Sprintf("insert t values (%d)", i)) } - tbl, err := domain.GetDomain(tk.Se).InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + dom := domain.GetDomain(tk.Se) + tbl, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) - var hasShardedID bool - var count int - c.Assert(tk.Se.NewTxn(context.Background()), IsNil) - err = tbl.IterRecords(tk.Se, tbl.FirstKey(), nil, func(h int64, rec []types.Datum, cols []*table.Column) (more bool, err error) { - c.Assert(h, GreaterEqual, int64(0)) - first8bits := h >> 56 - if first8bits > 0 { - hasShardedID = true - } - count++ - return true, nil + + assertCountAndShard := func(t table.Table, expectCount int) { + var hasShardedID bool + var count int + c.Assert(tk.Se.NewTxn(context.Background()), IsNil) + err = t.IterRecords(tk.Se, t.FirstKey(), nil, func(h int64, rec []types.Datum, cols []*table.Column) (more bool, err error) { + c.Assert(h, GreaterEqual, int64(0)) + first8bits := h >> 56 + if first8bits > 0 { + hasShardedID = true + } + count++ + return true, nil + }) + c.Assert(err, IsNil) + c.Assert(count, Equals, expectCount) + c.Assert(hasShardedID, IsTrue) + } + + assertCountAndShard(tbl, 100) + + // After PR 10759, shard_row_id_bits is supported with tables with auto_increment column. + tk.MustExec("create table auto (id int not null auto_increment unique) shard_row_id_bits = 4") + tk.MustExec("alter table auto shard_row_id_bits = 5") + tk.MustExec("drop table auto") + tk.MustExec("create table auto (id int not null auto_increment unique) shard_row_id_bits = 0") + tk.MustExec("alter table auto shard_row_id_bits = 5") + tk.MustExec("drop table auto") + tk.MustExec("create table auto (id int not null auto_increment unique)") + tk.MustExec("alter table auto shard_row_id_bits = 5") + tk.MustExec("drop table auto") + tk.MustExec("create table auto (id int not null auto_increment unique) shard_row_id_bits = 4") + tk.MustExec("alter table auto shard_row_id_bits = 0") + tk.MustExec("drop table auto") + + // After PR 10759, shard_row_id_bits is not supported with pk_is_handle tables. + err = tk.ExecToErr("create table auto (id int not null auto_increment primary key, b int) shard_row_id_bits = 4") + c.Assert(err.Error(), Equals, "[ddl:207]unsupported shard_row_id_bits for table with primary key as row id.") + tk.MustExec("create table auto (id int not null auto_increment primary key, b int) shard_row_id_bits = 0") + err = tk.ExecToErr("alter table auto shard_row_id_bits = 5") + c.Assert(err.Error(), Equals, "[ddl:207]unsupported shard_row_id_bits for table with primary key as row id.") + tk.MustExec("alter table auto shard_row_id_bits = 0") + + // Hack an existing table with shard_row_id_bits and primary key as handle + db, ok := dom.InfoSchema().SchemaByName(model.NewCIStr("test")) + c.Assert(ok, IsTrue) + tbl, err = dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("auto")) + tblInfo := tbl.Meta() + tblInfo.ShardRowIDBits = 5 + tblInfo.MaxShardRowIDBits = 5 + + kv.RunInNewTxn(s.store, false, func(txn kv.Transaction) error { + m := meta.NewMeta(txn) + _, err = m.GenSchemaVersion() + c.Assert(err, IsNil) + c.Assert(m.UpdateTable(db.ID, tblInfo), IsNil) + return nil }) + err = dom.Reload() c.Assert(err, IsNil) - c.Assert(count, Equals, 100) - c.Assert(hasShardedID, IsTrue) - // Test that audo_increment column can not use shard_row_id_bits. - _, err = tk.Exec("create table auto (id int not null auto_increment primary key) shard_row_id_bits = 4") - c.Assert(err, NotNil) - tk.MustExec("create table auto (id int not null auto_increment primary key) shard_row_id_bits = 0") - _, err = tk.Exec("alter table auto shard_row_id_bits = 4") - c.Assert(err, NotNil) + tk.MustExec("insert auto(b) values (1), (3), (5)") + tk.MustQuery("select id from auto order by id").Check(testkit.Rows("1", "2", "3")) + tk.MustExec("alter table auto shard_row_id_bits = 0") + tk.MustExec("drop table auto") + + // Test shard_row_id_bits with auto_increment column + tk.MustExec("create table auto (a int, b int auto_increment unique) shard_row_id_bits = 15") + for i := 0; i < 100; i++ { + tk.MustExec(fmt.Sprintf("insert auto(a) values (%d)", i)) + } + tbl, err = dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("auto")) + assertCountAndShard(tbl, 100) + prevB, err := strconv.Atoi(tk.MustQuery("select b from auto where a=0").Rows()[0][0].(string)) + c.Assert(err, IsNil) + for i := 1; i < 100; i++ { + b, err := strconv.Atoi(tk.MustQuery(fmt.Sprintf("select b from auto where a=%d", i)).Rows()[0][0].(string)) + c.Assert(err, IsNil) + c.Assert(b, Greater, prevB) + prevB = b + } // Test overflow tk.MustExec("drop table if exists t1") tk.MustExec("create table t1 (a int) shard_row_id_bits = 15") defer tk.MustExec("drop table if exists t1") - tbl, err = domain.GetDomain(tk.Se).InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t1")) + tbl, err = dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t1")) c.Assert(err, IsNil) maxID := 1<<(64-15-1) - 1 err = tbl.RebaseAutoID(tk.Se, int64(maxID)-1, false) @@ -712,11 +879,14 @@ func (s *testSuite3) TestGeneratedColumnRelatedDDL(c *C) { _, err = tk.Exec("alter table t1 add column d bigint generated always as (a + 1);") c.Assert(err.Error(), Equals, ddl.ErrGeneratedColumnRefAutoInc.GenWithStackByArgs("d").Error()) - tk.MustExec("alter table t1 add column d bigint generated always as (b + 1); ") + tk.MustExec("alter table t1 add column d bigint generated always as (b + 1);") _, err = tk.Exec("alter table t1 modify column d bigint generated always as (a + 1);") c.Assert(err.Error(), Equals, ddl.ErrGeneratedColumnRefAutoInc.GenWithStackByArgs("d").Error()) + _, err = tk.Exec("alter table t1 add column e bigint as (z + 1);") + c.Assert(err.Error(), Equals, ddl.ErrBadField.GenWithStackByArgs("z", "generated column function").Error()) + tk.MustExec("drop table t1;") } diff --git a/executor/delete.go b/executor/delete.go index c3c52363dff2b..0a10f00e2d74a 100644 --- a/executor/delete.go +++ b/executor/delete.go @@ -43,7 +43,7 @@ type DeleteExec struct { } // Next implements the Executor Next interface. -func (e *DeleteExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *DeleteExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("delete.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -100,12 +100,12 @@ func (e *DeleteExec) deleteSingleTableByChunk(ctx context.Context) error { // If tidb_batch_delete is ON and not in a transaction, we could use BatchDelete mode. batchDelete := e.ctx.GetSessionVars().BatchDelete && !e.ctx.GetSessionVars().InTxn() batchDMLSize := e.ctx.GetSessionVars().DMLBatchSize - fields := e.children[0].retTypes() - chk := e.children[0].newFirstChunk() + fields := retTypes(e.children[0]) + chk := newFirstChunk(e.children[0]) for { iter := chunk.NewIterator4Chunk(chk) - err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) + err := Next(ctx, e.children[0], chk) if err != nil { return err } @@ -183,11 +183,11 @@ func (e *DeleteExec) deleteMultiTablesByChunk(ctx context.Context) error { e.initialMultiTableTblMap() colPosInfos := e.getColPosInfos(e.children[0].Schema()) tblRowMap := make(tableRowMapType) - fields := e.children[0].retTypes() - chk := e.children[0].newFirstChunk() + fields := retTypes(e.children[0]) + chk := newFirstChunk(e.children[0]) for { iter := chunk.NewIterator4Chunk(chk) - err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) + err := Next(ctx, e.children[0], chk) if err != nil { return err } diff --git a/executor/distsql.go b/executor/distsql.go index 5bd77376292c1..f499070c91e7e 100644 --- a/executor/distsql.go +++ b/executor/distsql.go @@ -19,6 +19,7 @@ import ( "math" "runtime" "sort" + "strconv" "sync" "sync/atomic" "time" @@ -37,6 +38,7 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" @@ -62,6 +64,7 @@ type lookupTableTask struct { handles []int64 rowIdx []int // rowIdx represents the handle index for every row. Only used when keep order. rows []chunk.Row + idxRows *chunk.Chunk cursor int doneCh chan error @@ -71,6 +74,9 @@ type lookupTableTask struct { // The handles fetched from index is originally ordered by index, but we need handles to be ordered by itself // to do table request. indexOrder map[int64]int + // duplicatedIndexOrder map likes indexOrder. But it's used when checkIndexValue isn't nil and + // the same handle of index has multiple values. + duplicatedIndexOrder map[int64]int // memUsage records the memory usage of this task calculated by table worker. // memTracker is used to release memUsage after task is done and unused. @@ -252,6 +258,8 @@ type IndexReaderExecutor struct { colLens []int plans []plannercore.PhysicalPlan + memTracker *memory.Tracker + selectResultHook // for testing } @@ -268,7 +276,7 @@ func (e *IndexReaderExecutor) Close() error { } // Next implements the Executor Next interface. -func (e *IndexReaderExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *IndexReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("tableReader.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -277,7 +285,7 @@ func (e *IndexReaderExecutor) Next(ctx context.Context, req *chunk.RecordBatch) start := time.Now() defer func() { e.runtimeStats.Record(time.Since(start), req.NumRows()) }() } - err := e.result.Next(ctx, req.Chunk) + err := e.result.Next(ctx, req) if err != nil { e.feedback.Invalidate() } @@ -301,8 +309,6 @@ func (e *IndexReaderExecutor) Open(ctx context.Context) error { return e.open(ctx, kvRanges) } -var indexReaderDistSQLTrackerLabel fmt.Stringer = stringutil.StringerStr("IndexReaderDistSQLTracker") - func (e *IndexReaderExecutor) open(ctx context.Context, kvRanges []kv.KeyRange) error { var err error if e.corColInFilter { @@ -317,6 +323,8 @@ func (e *IndexReaderExecutor) open(ctx context.Context, kvRanges []kv.KeyRange) e.dagPB.CollectExecutionSummaries = &collExec } + e.memTracker = memory.NewTracker(e.id, e.ctx.GetSessionVars().MemQuotaDistSQL) + e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) var builder distsql.RequestBuilder kvReq, err := builder.SetKeyRanges(kvRanges). SetDAGRequest(e.dagPB). @@ -324,13 +332,13 @@ func (e *IndexReaderExecutor) open(ctx context.Context, kvRanges []kv.KeyRange) SetKeepOrder(e.keepOrder). SetStreaming(e.streaming). SetFromSessionVars(e.ctx.GetSessionVars()). - SetMemTracker(e.ctx, indexReaderDistSQLTrackerLabel). + SetMemTracker(e.memTracker). Build() if err != nil { e.feedback.Invalidate() return err } - e.result, err = e.SelectResult(ctx, e.ctx, kvReq, e.retTypes(), e.feedback, getPhysicalPlanIDs(e.plans)) + e.result, err = e.SelectResult(ctx, e.ctx, kvReq, retTypes(e), e.feedback, getPhysicalPlanIDs(e.plans)) if err != nil { e.feedback.Invalidate() return err @@ -373,8 +381,8 @@ type IndexLookUpExecutor struct { // memTracker is used to track the memory usage of this executor. memTracker *memory.Tracker - // isCheckOp is used to determine whether we need to check the consistency of the index data. - isCheckOp bool + // checkIndexValue is used to check the consistency of the index data. + *checkIndexValue corColInIdxSide bool idxPlans []plannercore.PhysicalPlan @@ -383,6 +391,14 @@ type IndexLookUpExecutor struct { corColInAccess bool idxCols []*expression.Column colLens []int + // PushedLimit is used to skip the preceding and tailing handles when Limit is sunk into IndexLookUpReader. + PushedLimit *plannercore.PushedDownLimit +} + +type checkIndexValue struct { + idxColTps []*types.FieldType + idxTblCols []*table.Column + genExprs map[model.TableColumnID]expression.Expression } // Open implements the Executor Open interface. @@ -455,6 +471,8 @@ func (e *IndexLookUpExecutor) startIndexWorker(ctx context.Context, kvRanges []k e.dagPB.CollectExecutionSummaries = &collExec } + tracker := memory.NewTracker(stringutil.StringerStr("IndexWorker"), e.ctx.GetSessionVars().MemQuotaIndexLookupReader) + tracker.AttachTo(e.memTracker) var builder distsql.RequestBuilder kvReq, err := builder.SetKeyRanges(kvRanges). SetDAGRequest(e.dagPB). @@ -462,26 +480,32 @@ func (e *IndexLookUpExecutor) startIndexWorker(ctx context.Context, kvRanges []k SetKeepOrder(e.keepOrder). SetStreaming(e.indexStreaming). SetFromSessionVars(e.ctx.GetSessionVars()). - SetMemTracker(e.ctx, indexLookupDistSQLTrackerLabel). + SetMemTracker(tracker). Build() if err != nil { return err } + tps := []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)} + if e.checkIndexValue != nil { + tps = e.idxColTps + } // Since the first read only need handle information. So its returned col is only 1. - result, err := distsql.SelectWithRuntimeStats(ctx, e.ctx, kvReq, []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, e.feedback, getPhysicalPlanIDs(e.idxPlans)) + result, err := distsql.SelectWithRuntimeStats(ctx, e.ctx, kvReq, tps, e.feedback, getPhysicalPlanIDs(e.idxPlans)) if err != nil { return err } result.Fetch(ctx) worker := &indexWorker{ - idxLookup: e, - workCh: workCh, - finished: e.finished, - resultCh: e.resultCh, - keepOrder: e.keepOrder, - batchSize: initBatchSize, - maxBatchSize: e.ctx.GetSessionVars().IndexLookupSize, - maxChunkSize: e.maxChunkSize, + idxLookup: e, + workCh: workCh, + finished: e.finished, + resultCh: e.resultCh, + keepOrder: e.keepOrder, + batchSize: initBatchSize, + checkIndexValue: e.checkIndexValue, + maxBatchSize: e.ctx.GetSessionVars().IndexLookupSize, + maxChunkSize: e.maxChunkSize, + PushedLimit: e.PushedLimit, } if worker.batchSize > worker.maxBatchSize { worker.batchSize = worker.maxBatchSize @@ -499,9 +523,9 @@ func (e *IndexLookUpExecutor) startIndexWorker(ctx context.Context, kvRanges []k } if e.runtimeStats != nil { copStats := e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.GetRootStats(e.idxPlans[len(e.idxPlans)-1].ExplainID().String()) - copStats.SetRowNum(count) + copStats.SetRowNum(int64(count)) copStats = e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.GetRootStats(e.tblPlans[0].ExplainID().String()) - copStats.SetRowNum(count) + copStats.SetRowNum(int64(count)) } e.ctx.StoreQueryFeedback(e.feedback) close(workCh) @@ -511,22 +535,21 @@ func (e *IndexLookUpExecutor) startIndexWorker(ctx context.Context, kvRanges []k return nil } -var tableWorkerLabel fmt.Stringer = stringutil.StringerStr("tableWorker") - // startTableWorker launchs some background goroutines which pick tasks from workCh and execute the task. func (e *IndexLookUpExecutor) startTableWorker(ctx context.Context, workCh <-chan *lookupTableTask) { lookupConcurrencyLimit := e.ctx.GetSessionVars().IndexLookupConcurrency e.tblWorkerWg.Add(lookupConcurrencyLimit) for i := 0; i < lookupConcurrencyLimit; i++ { worker := &tableWorker{ - idxLookup: e, - workCh: workCh, - finished: e.finished, - buildTblReader: e.buildTableReader, - keepOrder: e.keepOrder, - handleIdx: e.handleIdx, - isCheckOp: e.isCheckOp, - memTracker: memory.NewTracker(tableWorkerLabel, -1), + idxLookup: e, + workCh: workCh, + finished: e.finished, + buildTblReader: e.buildTableReader, + keepOrder: e.keepOrder, + handleIdx: e.handleIdx, + checkIndexValue: e.checkIndexValue, + memTracker: memory.NewTracker(stringutil.MemoizeStr(func() string { return "TableWorker_" + strconv.Itoa(i) }), + e.ctx.GetSessionVars().MemQuotaIndexLookupReader), } worker.memTracker.AttachTo(e.memTracker) ctx1, cancel := context.WithCancel(ctx) @@ -571,7 +594,6 @@ func (e *IndexLookUpExecutor) Close() error { e.tblWorkerWg.Wait() e.finished = nil e.workerStarted = false - e.memTracker.Detach() e.memTracker = nil if e.runtimeStats != nil { copStats := e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.GetRootStats(e.idxPlans[0].ExplainID().String()) @@ -581,7 +603,7 @@ func (e *IndexLookUpExecutor) Close() error { } // Next implements Exec Next interface. -func (e *IndexLookUpExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *IndexLookUpExecutor) Next(ctx context.Context, req *chunk.Chunk) error { if e.runtimeStats != nil { start := time.Now() defer func() { e.runtimeStats.Record(time.Since(start), req.NumRows()) }() @@ -642,12 +664,17 @@ type indexWorker struct { batchSize int maxBatchSize int maxChunkSize int + + // checkIndexValue is used to check the consistency of the index data. + *checkIndexValue + // PushedLimit is used to skip the preceding and tailing handles when Limit is sunk into IndexLookUpReader. + PushedLimit *plannercore.PushedDownLimit } // fetchHandles fetches a batch of handles from index data and builds the index lookup tasks. // The tasks are sent to workCh to be further processed by tableWorker, and sent to e.resultCh // at the same time to keep data ordered. -func (w *indexWorker) fetchHandles(ctx context.Context, result distsql.SelectResult) (count int64, err error) { +func (w *indexWorker) fetchHandles(ctx context.Context, result distsql.SelectResult) (count uint64, err error) { defer func() { if r := recover(); r != nil { buf := make([]byte, 4096) @@ -665,9 +692,14 @@ func (w *indexWorker) fetchHandles(ctx context.Context, result distsql.SelectRes } } }() - chk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, w.idxLookup.maxChunkSize) + var chk *chunk.Chunk + if w.checkIndexValue != nil { + chk = chunk.NewChunkWithCapacity(w.idxColTps, w.maxChunkSize) + } else { + chk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, w.idxLookup.maxChunkSize) + } for { - handles, err := w.extractTaskHandles(ctx, chk, result) + handles, retChunk, scannedKeys, err := w.extractTaskHandles(ctx, chk, result, count) if err != nil { doneCh := make(chan error, 1) doneCh <- err @@ -676,11 +708,11 @@ func (w *indexWorker) fetchHandles(ctx context.Context, result distsql.SelectRes } return count, err } + count += scannedKeys if len(handles) == 0 { return count, nil } - count += int64(len(handles)) - task := w.buildTableTask(handles) + task := w.buildTableTask(handles, retChunk) select { case <-ctx.Done(): return count, nil @@ -692,30 +724,63 @@ func (w *indexWorker) fetchHandles(ctx context.Context, result distsql.SelectRes } } -func (w *indexWorker) extractTaskHandles(ctx context.Context, chk *chunk.Chunk, idxResult distsql.SelectResult) (handles []int64, err error) { +func (w *indexWorker) extractTaskHandles(ctx context.Context, chk *chunk.Chunk, idxResult distsql.SelectResult, count uint64) ( + handles []int64, retChk *chunk.Chunk, scannedKeys uint64, err error) { + handleOffset := chk.NumCols() - 1 handles = make([]int64, 0, w.batchSize) + // PushedLimit would always be nil for CheckIndex or CheckTable, we add this check just for insurance. + checkLimit := (w.PushedLimit != nil) && (w.checkIndexValue == nil) for len(handles) < w.batchSize { - chk.SetRequiredRows(w.batchSize-len(handles), w.maxChunkSize) - err = idxResult.Next(ctx, chk) + requiredRows := w.batchSize - len(handles) + if checkLimit { + if w.PushedLimit.Offset+w.PushedLimit.Count <= scannedKeys+count { + return handles, nil, scannedKeys, nil + } + leftCnt := w.PushedLimit.Offset + w.PushedLimit.Count - scannedKeys - count + if uint64(requiredRows) > leftCnt { + requiredRows = int(leftCnt) + } + } + chk.SetRequiredRows(requiredRows, w.maxChunkSize) + err = errors.Trace(idxResult.Next(ctx, chk)) if err != nil { - return handles, err + return handles, nil, scannedKeys, err } if chk.NumRows() == 0 { - return handles, nil + return handles, retChk, scannedKeys, nil } for i := 0; i < chk.NumRows(); i++ { - handles = append(handles, chk.GetRow(i).GetInt64(0)) + scannedKeys++ + if checkLimit { + if (count + scannedKeys) <= w.PushedLimit.Offset { + // Skip the preceding Offset handles. + continue + } + if (count + scannedKeys) > (w.PushedLimit.Offset + w.PushedLimit.Count) { + // Skip the handles after Offset+Count. + return handles, nil, scannedKeys, nil + } + } + h := chk.GetRow(i).GetInt64(handleOffset) + handles = append(handles, h) + } + if w.checkIndexValue != nil { + if retChk == nil { + retChk = chunk.NewChunkWithCapacity(w.idxColTps, w.batchSize) + } + retChk.Append(chk, 0, chk.NumRows()) } } w.batchSize *= 2 if w.batchSize > w.maxBatchSize { w.batchSize = w.maxBatchSize } - return handles, nil + return handles, retChk, scannedKeys, nil } -func (w *indexWorker) buildTableTask(handles []int64) *lookupTableTask { +func (w *indexWorker) buildTableTask(handles []int64, retChk *chunk.Chunk) *lookupTableTask { var indexOrder map[int64]int + var duplicatedIndexOrder map[int64]int if w.keepOrder { // Save the index order. indexOrder = make(map[int64]int, len(handles)) @@ -723,10 +788,27 @@ func (w *indexWorker) buildTableTask(handles []int64) *lookupTableTask { indexOrder[h] = i } } + + if w.checkIndexValue != nil { + // Save the index order. + indexOrder = make(map[int64]int, len(handles)) + duplicatedIndexOrder = make(map[int64]int) + for i, h := range handles { + if _, ok := indexOrder[h]; ok { + duplicatedIndexOrder[h] = i + } else { + indexOrder[h] = i + } + } + } + task := &lookupTableTask{ - handles: handles, - indexOrder: indexOrder, + handles: handles, + indexOrder: indexOrder, + duplicatedIndexOrder: duplicatedIndexOrder, + idxRows: retChk, } + task.doneCh = make(chan error, 1) return task } @@ -743,8 +825,8 @@ type tableWorker struct { // memTracker is used to track the memory usage of this executor. memTracker *memory.Tracker - // isCheckOp is used to determine whether we need to check the consistency of the index data. - isCheckOp bool + // checkIndexValue is used to check the consistency of the index data. + *checkIndexValue } // pickAndExecTask picks tasks from workCh, and execute them. @@ -777,6 +859,66 @@ func (w *tableWorker) pickAndExecTask(ctx context.Context) { } } +func (w *tableWorker) compareData(ctx context.Context, task *lookupTableTask, tableReader Executor) error { + chk := newFirstChunk(tableReader) + tblInfo := w.idxLookup.table.Meta() + vals := make([]types.Datum, 0, len(w.idxTblCols)) + for { + err := tableReader.Next(ctx, chk) + if err != nil { + return errors.Trace(err) + } + if chk.NumRows() == 0 { + for h := range task.indexOrder { + idxRow := task.idxRows.GetRow(task.indexOrder[h]) + return errors.Errorf("handle %#v, index:%#v != record:%#v", h, idxRow.GetDatum(0, w.idxColTps[0]), nil) + } + break + } + + tblReaderExec := tableReader.(*TableReaderExecutor) + iter := chunk.NewIterator4Chunk(chk) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + handle := row.GetInt64(w.handleIdx) + offset, ok := task.indexOrder[handle] + if !ok { + offset = task.duplicatedIndexOrder[handle] + } + delete(task.indexOrder, handle) + idxRow := task.idxRows.GetRow(offset) + vals = vals[:0] + for i, col := range w.idxTblCols { + if col.IsGenerated() && !col.GeneratedStored { + expr := w.genExprs[model.TableColumnID{TableID: tblInfo.ID, ColumnID: col.ID}] + // Eval the column value + val, err := expr.Eval(row) + if err != nil { + return errors.Trace(err) + } + val, err = table.CastValue(tblReaderExec.ctx, val, col.ColumnInfo) + if err != nil { + return errors.Trace(err) + } + vals = append(vals, val) + } else { + vals = append(vals, row.GetDatum(i, &col.FieldType)) + } + } + vals = tables.TruncateIndexValuesIfNeeded(tblInfo, w.idxLookup.index, vals) + for i, val := range vals { + col := w.idxTblCols[i] + tp := &col.FieldType + ret := chunk.Compare(idxRow, i, &val) + if ret != 0 { + return errors.Errorf("col %s, handle %#v, index:%#v != record:%#v", col.Name, handle, idxRow.GetDatum(i, tp), val) + } + } + } + } + + return nil +} + // executeTask executes the table look up tasks. We will construct a table reader and send request by handles. // Then we hold the returning rows and finish this task. func (w *tableWorker) executeTask(ctx context.Context, task *lookupTableTask) error { @@ -787,6 +929,10 @@ func (w *tableWorker) executeTask(ctx context.Context, task *lookupTableTask) er } defer terror.Call(tableReader.Close) + if w.checkIndexValue != nil { + return w.compareData(ctx, task, tableReader) + } + task.memTracker = w.memTracker memUsage := int64(cap(task.handles) * 8) task.memUsage = memUsage @@ -794,8 +940,8 @@ func (w *tableWorker) executeTask(ctx context.Context, task *lookupTableTask) er handleCnt := len(task.handles) task.rows = make([]chunk.Row, 0, handleCnt) for { - chk := tableReader.newFirstChunk() - err = tableReader.Next(ctx, chunk.NewRecordBatch(chk)) + chk := newFirstChunk(tableReader) + err = tableReader.Next(ctx, chk) if err != nil { logutil.Logger(ctx).Error("table reader fetch next chunk failed", zap.Error(err)) return err @@ -828,17 +974,18 @@ func (w *tableWorker) executeTask(ctx context.Context, task *lookupTableTask) er } if handleCnt != len(task.rows) { - if w.isCheckOp { + if len(w.idxLookup.tblPlans) == 1 { obtainedHandlesMap := make(map[int64]struct{}, len(task.rows)) for _, row := range task.rows { handle := row.GetInt64(w.handleIdx) obtainedHandlesMap[handle] = struct{}{} } - return errors.Errorf("inconsistent index %s handle count %d isn't equal to value count %d, missing handles %v in a batch", - w.idxLookup.index.Name.O, handleCnt, len(task.rows), GetLackHandles(task.handles, obtainedHandlesMap)) - } - if len(w.idxLookup.tblPlans) == 1 { + logutil.Logger(ctx).Error("inconsistent index handles", zap.String("index", w.idxLookup.index.Name.O), + zap.Int("index_cnt", handleCnt), zap.Int("table_cnt", len(task.rows)), + zap.Int64s("missing_handles", GetLackHandles(task.handles, obtainedHandlesMap)), + zap.Int64s("total_handles", task.handles)) + // table scan in double read can never has conditions according to convertToIndexScan. // if this table scan has no condition, the number of rows it returns must equal to the length of handles. return errors.Errorf("inconsistent index %s handle count %d isn't equal to value count %d", diff --git a/executor/distsql_test.go b/executor/distsql_test.go index 814e6f47fe1bb..dfbf03221cebf 100644 --- a/executor/distsql_test.go +++ b/executor/distsql_test.go @@ -71,7 +71,7 @@ func (s *testSuite3) TestCopClientSend(c *C) { // Send coprocessor request when the table split. rs, err := tk.Exec("select sum(id) from copclient") c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.GetRow(0).GetMyDecimal(0).String(), Equals, "499500") @@ -86,7 +86,7 @@ func (s *testSuite3) TestCopClientSend(c *C) { // Check again. rs, err = tk.Exec("select sum(id) from copclient") c.Assert(err, IsNil) - req = rs.NewRecordBatch() + req = rs.NewChunk() err = rs.Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.GetRow(0).GetMyDecimal(0).String(), Equals, "499500") @@ -95,7 +95,7 @@ func (s *testSuite3) TestCopClientSend(c *C) { // Check there is no goroutine leak. rs, err = tk.Exec("select * from copclient order by id") c.Assert(err, IsNil) - req = rs.NewRecordBatch() + req = rs.NewChunk() err = rs.Next(ctx, req) c.Assert(err, IsNil) rs.Close() @@ -237,3 +237,18 @@ func (s *testSuite3) TestInconsistentIndex(c *C) { c.Assert(err, IsNil) } } + +func (s *testSuite3) TestPushLimitDownIndexLookUpReader(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists tbl") + tk.MustExec("create table tbl(a int, b int, c int, key idx_b_c(b,c))") + tk.MustExec("insert into tbl values(1,1,1),(2,2,2),(3,3,3),(4,4,4),(5,5,5)") + tk.MustQuery("select * from tbl use index(idx_b_c) where b > 1 limit 2,1").Check(testkit.Rows("4 4 4")) + tk.MustQuery("select * from tbl use index(idx_b_c) where b > 4 limit 2,1").Check(testkit.Rows()) + tk.MustQuery("select * from tbl use index(idx_b_c) where b > 3 limit 2,1").Check(testkit.Rows()) + tk.MustQuery("select * from tbl use index(idx_b_c) where b > 2 limit 2,1").Check(testkit.Rows("5 5 5")) + tk.MustQuery("select * from tbl use index(idx_b_c) where b > 1 limit 1").Check(testkit.Rows("2 2 2")) + tk.MustQuery("select * from tbl use index(idx_b_c) where b > 1 order by b desc limit 2,1").Check(testkit.Rows("3 3 3")) + tk.MustQuery("select * from tbl use index(idx_b_c) where b > 1 and c > 1 limit 2,1").Check(testkit.Rows("4 4 4")) +} diff --git a/executor/errors.go b/executor/errors.go index b7f8ce2ebf19f..a48152f0acdfe 100644 --- a/executor/errors.go +++ b/executor/errors.go @@ -51,6 +51,8 @@ var ( ErrBadDB = terror.ClassExecutor.New(mysql.ErrBadDB, mysql.MySQLErrName[mysql.ErrBadDB]) ErrWrongObject = terror.ClassExecutor.New(mysql.ErrWrongObject, mysql.MySQLErrName[mysql.ErrWrongObject]) ErrRoleNotGranted = terror.ClassPrivilege.New(mysql.ErrRoleNotGranted, mysql.MySQLErrName[mysql.ErrRoleNotGranted]) + ErrDeadlock = terror.ClassExecutor.New(mysql.ErrLockDeadlock, mysql.MySQLErrName[mysql.ErrLockDeadlock]) + ErrQueryInterrupted = terror.ClassExecutor.New(mysql.ErrQueryInterrupted, mysql.MySQLErrName[mysql.ErrQueryInterrupted]) ) func init() { @@ -67,6 +69,8 @@ func init() { mysql.ErrTableaccessDenied: mysql.ErrTableaccessDenied, mysql.ErrBadDB: mysql.ErrBadDB, mysql.ErrWrongObject: mysql.ErrWrongObject, + mysql.ErrLockDeadlock: mysql.ErrLockDeadlock, + mysql.ErrQueryInterrupted: mysql.ErrQueryInterrupted, } terror.ErrClassToMySQLCodes[terror.ClassExecutor] = tableMySQLErrCodes } diff --git a/executor/executor.go b/executor/executor.go index 16bf5746049dd..a9d2367ad24f0 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -39,8 +39,10 @@ import ( "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/admin" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/execdetails" @@ -86,6 +88,11 @@ type baseExecutor struct { runtimeStats *execdetails.RuntimeStats } +// base returns the baseExecutor of an executor, don't override this method! +func (e *baseExecutor) base() *baseExecutor { + return e +} + // Open initializes children recursively and "childrenResults" according to children's schemas. func (e *baseExecutor) Open(ctx context.Context) error { for _, child := range e.children { @@ -99,13 +106,13 @@ func (e *baseExecutor) Open(ctx context.Context) error { // Close closes all executors and release all resources. func (e *baseExecutor) Close() error { - for _, child := range e.children { - err := child.Close() - if err != nil { - return err + var firstErr error + for _, src := range e.children { + if err := src.Close(); err != nil && firstErr == nil { + firstErr = err } } - return nil + return firstErr } // Schema returns the current baseExecutor's schema. If it is nil, then create and return a new one. @@ -117,17 +124,19 @@ func (e *baseExecutor) Schema() *expression.Schema { } // newFirstChunk creates a new chunk to buffer current executor's result. -func (e *baseExecutor) newFirstChunk() *chunk.Chunk { - return chunk.New(e.retTypes(), e.initCap, e.maxChunkSize) +func newFirstChunk(e Executor) *chunk.Chunk { + base := e.base() + return chunk.New(base.retFieldTypes, base.initCap, base.maxChunkSize) } // retTypes returns all output column types. -func (e *baseExecutor) retTypes() []*types.FieldType { - return e.retFieldTypes +func retTypes(e Executor) []*types.FieldType { + base := e.base() + return base.retFieldTypes } // Next fills mutiple rows into a chunk. -func (e *baseExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *baseExecutor) Next(ctx context.Context, req *chunk.Chunk) error { return nil } @@ -166,13 +175,21 @@ func newBaseExecutor(ctx sessionctx.Context, schema *expression.Schema, id fmt.S // return a batch of rows, other than a single row in Volcano. // NOTE: Executors must call "chk.Reset()" before appending their results to it. type Executor interface { + base() *baseExecutor Open(context.Context) error - Next(ctx context.Context, req *chunk.RecordBatch) error + Next(ctx context.Context, req *chunk.Chunk) error Close() error Schema() *expression.Schema +} + +// Next is a wrapper function on e.Next(), it handles some common codes. +func Next(ctx context.Context, e Executor, req *chunk.Chunk) error { + sessVars := e.base().ctx.GetSessionVars() + if atomic.CompareAndSwapUint32(&sessVars.Killed, 1, 0) { + return ErrQueryInterrupted + } - retTypes() []*types.FieldType - newFirstChunk() *chunk.Chunk + return e.Next(ctx, req) } // CancelDDLJobsExec represents a cancel DDL jobs executor. @@ -185,7 +202,7 @@ type CancelDDLJobsExec struct { } // Next implements the Executor Next interface. -func (e *CancelDDLJobsExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *CancelDDLJobsExec) Next(ctx context.Context, req *chunk.Chunk) error { if e.runtimeStats != nil { start := time.Now() defer func() { e.runtimeStats.Record(time.Since(start), req.NumRows()) }() @@ -215,7 +232,7 @@ type ShowNextRowIDExec struct { } // Next implements the Executor Next interface. -func (e *ShowNextRowIDExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ShowNextRowIDExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.done { return nil @@ -255,7 +272,7 @@ type ShowDDLExec struct { } // Next implements the Executor Next interface. -func (e *ShowDDLExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ShowDDLExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.done { return nil @@ -339,7 +356,7 @@ func (e *ShowDDLJobQueriesExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *ShowDDLJobQueriesExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ShowDDLJobQueriesExec) Next(ctx context.Context, req *chunk.Chunk) error { req.GrowAndReset(e.maxChunkSize) if e.cursor >= len(e.jobs) { return nil @@ -386,7 +403,7 @@ func (e *ShowDDLJobsExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *ShowDDLJobsExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ShowDDLJobsExec) Next(ctx context.Context, req *chunk.Chunk) error { req.GrowAndReset(e.maxChunkSize) if e.cursor >= len(e.jobs) { return nil @@ -436,11 +453,14 @@ func getTableName(is infoschema.InfoSchema, id int64) string { type CheckTableExec struct { baseExecutor - tables []*ast.TableName - done bool - is infoschema.InfoSchema - - genExprs map[model.TableColumnID]expression.Expression + dbName string + table table.Table + indexInfos []*model.IndexInfo + srcs []*IndexLookUpExecutor + done bool + is infoschema.InfoSchema + exitCh chan struct{} + retCh chan error } // Open implements the Executor Open interface. @@ -448,63 +468,147 @@ func (e *CheckTableExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } + for _, src := range e.srcs { + if err := src.Open(ctx); err != nil { + return errors.Trace(err) + } + } e.done = false return nil } -// Next implements the Executor Next interface. -func (e *CheckTableExec) Next(ctx context.Context, req *chunk.RecordBatch) error { - if e.done { - return nil +// Close implements the Executor Close interface. +func (e *CheckTableExec) Close() error { + var firstErr error + for _, src := range e.srcs { + if err := src.Close(); err != nil && firstErr == nil { + firstErr = err + } } - defer func() { e.done = true }() - for _, t := range e.tables { - dbName := t.DBInfo.Name - tb, err := e.is.TableByName(dbName, t.Name) + return firstErr +} + +func (e *CheckTableExec) checkTableIndexHandle(ctx context.Context, idxInfo *model.IndexInfo) error { + // For partition table, there will be multi same index indexLookUpReaders on different partitions. + for _, src := range e.srcs { + if src.index.Name.L == idxInfo.Name.L { + err := e.checkIndexHandle(ctx, src) + if err != nil { + return err + } + } + } + return nil +} + +func (e *CheckTableExec) checkIndexHandle(ctx context.Context, src *IndexLookUpExecutor) error { + cols := src.schema.Columns + retFieldTypes := make([]*types.FieldType, len(cols)) + for i := range cols { + retFieldTypes[i] = cols[i].RetType + } + chk := chunk.New(retFieldTypes, e.initCap, e.maxChunkSize) + + var err error + for { + err = src.Next(ctx, chk) if err != nil { - return err + break } - if tb.Meta().GetPartitionInfo() != nil { - err = e.doCheckPartitionedTable(tb.(table.PartitionedTable)) - } else { - err = e.doCheckTable(tb) + if chk.NumRows() == 0 { + break } - if err != nil { - logutil.Logger(ctx).Warn("check table failed", zap.String("tableName", t.Name.O), zap.Error(err)) - if admin.ErrDataInConsistent.Equal(err) { - return ErrAdminCheckTable.GenWithStack("%v err:%v", t.Name, err) - } - return errors.Errorf("%v err:%v", t.Name, err) + select { + case <-e.exitCh: + return nil + default: } } - return nil + e.retCh <- errors.Trace(err) + return errors.Trace(err) } -func (e *CheckTableExec) doCheckPartitionedTable(tbl table.PartitionedTable) error { - info := tbl.Meta().GetPartitionInfo() - for _, def := range info.Definitions { - pid := def.ID - partition := tbl.GetPartition(pid) - if err := e.doCheckTable(partition); err != nil { - return err +func (e *CheckTableExec) handlePanic(r interface{}) { + if r != nil { + e.retCh <- errors.Errorf("%v", r) + } +} + +// Next implements the Executor Next interface. +func (e *CheckTableExec) Next(ctx context.Context, req *chunk.Chunk) error { + if e.done || len(e.srcs) == 0 { + return nil + } + defer func() { e.done = true }() + + idxNames := make([]string, 0, len(e.indexInfos)) + for _, idx := range e.indexInfos { + idxNames = append(idxNames, idx.Name.O) + } + greater, idxOffset, err := admin.CheckIndicesCount(e.ctx, e.dbName, e.table.Meta().Name.O, idxNames) + if err != nil { + if greater == admin.IdxCntGreater { + err = e.checkTableIndexHandle(ctx, e.indexInfos[idxOffset]) + } else if greater == admin.TblCntGreater { + err = e.checkTableRecord(idxOffset) + } + if err != nil && admin.ErrDataInConsistent.Equal(err) { + return ErrAdminCheckTable.GenWithStack("%v err:%v", e.table.Meta().Name, err) + } + return errors.Trace(err) + } + + // The number of table rows is equal to the number of index rows. + // TODO: Make the value of concurrency adjustable. And we can consider the number of records. + concurrency := 3 + wg := sync.WaitGroup{} + for i := range e.srcs { + wg.Add(1) + go func(num int) { + defer wg.Done() + util.WithRecovery(func() { + err1 := e.checkIndexHandle(ctx, e.srcs[num]) + if err1 != nil { + logutil.Logger(ctx).Info("check index handle failed", zap.Error(err)) + } + }, e.handlePanic) + }(i) + + if (i+1)%concurrency == 0 { + wg.Wait() + } + } + + for i := 0; i < len(e.srcs); i++ { + err = <-e.retCh + if err != nil { + return errors.Trace(err) } } return nil } -func (e *CheckTableExec) doCheckTable(tbl table.Table) error { +func (e *CheckTableExec) checkTableRecord(idxOffset int) error { + idxInfo := e.indexInfos[idxOffset] + // TODO: Fix me later, can not use genExprs in indexLookUpReader, because the schema of expression is different. + genExprs := e.srcs[idxOffset].genExprs txn, err := e.ctx.Txn(true) if err != nil { return err } - for _, idx := range tbl.Indices() { - if idx.Meta().State != model.StatePublic { - continue - } - err := admin.CompareIndexData(e.ctx, txn, tbl, idx, e.genExprs) - if err != nil { - return err + if e.table.Meta().GetPartitionInfo() == nil { + idx := tables.NewIndex(e.table.Meta().ID, e.table.Meta(), idxInfo) + return admin.CheckRecordAndIndex(e.ctx, txn, e.table, idx, genExprs) + } + + info := e.table.Meta().GetPartitionInfo() + for _, def := range info.Definitions { + pid := def.ID + partition := e.table.(table.PartitionedTable).GetPartition(pid) + idx := tables.NewIndex(def.ID, e.table.Meta(), idxInfo) + if err := admin.CheckRecordAndIndex(e.ctx, txn, partition, idx, genExprs); err != nil { + return errors.Trace(err) } } return nil @@ -542,19 +646,19 @@ func (e *CheckIndexExec) Close() error { } // Next implements the Executor Next interface. -func (e *CheckIndexExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *CheckIndexExec) Next(ctx context.Context, req *chunk.Chunk) error { if e.done { return nil } defer func() { e.done = true }() - err := admin.CheckIndicesCount(e.ctx, e.dbName, e.tableName, []string{e.idxName}) + _, _, err := admin.CheckIndicesCount(e.ctx, e.dbName, e.tableName, []string{e.idxName}) if err != nil { return err } - chk := e.src.newFirstChunk() + chk := newFirstChunk(e.src) for { - err := e.src.Next(ctx, chunk.NewRecordBatch(chk)) + err := Next(ctx, e.src, chk) if err != nil { return err } @@ -589,7 +693,7 @@ func (e *ShowSlowExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *ShowSlowExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ShowSlowExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.cursor >= len(e.result) { return nil @@ -615,7 +719,7 @@ func (e *ShowSlowExec) Next(ctx context.Context, req *chunk.RecordBatch) error { req.AppendString(7, slow.User) req.AppendString(8, slow.DB) req.AppendString(9, slow.TableIDs) - req.AppendString(10, slow.IndexIDs) + req.AppendString(10, slow.IndexNames) if slow.Internal { req.AppendInt64(11, 1) } else { @@ -637,6 +741,7 @@ type SelectLockExec struct { baseExecutor Lock ast.SelectLockType + keys []kv.Key } // Open implements the Executor Open interface. @@ -646,7 +751,6 @@ func (e *SelectLockExec) Open(ctx context.Context) error { } txnCtx := e.ctx.GetSessionVars().TxnCtx - txnCtx.ForUpdate = true for id := range e.Schema().TblID2Handle { // This operation is only for schema validator check. txnCtx.UpdateDeltaForTable(id, 0, 0, map[int64]int64{}) @@ -655,14 +759,14 @@ func (e *SelectLockExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *SelectLockExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("selectLock.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() } req.GrowAndReset(e.maxChunkSize) - err := e.children[0].Next(ctx, req) + err := Next(ctx, e.children[0], req) if err != nil { return err } @@ -670,29 +774,29 @@ func (e *SelectLockExec) Next(ctx context.Context, req *chunk.RecordBatch) error if len(e.Schema().TblID2Handle) == 0 || e.Lock != ast.SelectLockForUpdate { return nil } - txn, err := e.ctx.Txn(true) - if err != nil { - return err - } - keys := make([]kv.Key, 0, req.NumRows()) - iter := chunk.NewIterator4Chunk(req.Chunk) - forUpdateTS := e.ctx.GetSessionVars().TxnCtx.GetForUpdateTS() - for id, cols := range e.Schema().TblID2Handle { - for _, col := range cols { - keys = keys[:0] - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - keys = append(keys, tablecodec.EncodeRowKeyWithHandle(id, row.GetInt64(col.Index))) - } - if len(keys) == 0 { - continue - } - err = txn.LockKeys(ctx, forUpdateTS, keys...) - if err != nil { - return err + if req.NumRows() != 0 { + iter := chunk.NewIterator4Chunk(req) + for id, cols := range e.Schema().TblID2Handle { + for _, col := range cols { + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + e.keys = append(e.keys, tablecodec.EncodeRowKeyWithHandle(id, row.GetInt64(col.Index))) + } } } + return nil } - return nil + return doLockKeys(ctx, e.ctx, e.keys...) +} + +func doLockKeys(ctx context.Context, se sessionctx.Context, keys ...kv.Key) error { + se.GetSessionVars().TxnCtx.ForUpdate = true + // Lock keys only once when finished fetching all results. + txn, err := se.Txn(true) + if err != nil { + return err + } + forUpdateTS := se.GetSessionVars().TxnCtx.GetForUpdateTS() + return txn.LockKeys(ctx, &se.GetSessionVars().Killed, forUpdateTS, keys...) } // LimitExec represents limit executor @@ -711,7 +815,7 @@ type LimitExec struct { } // Next implements the Executor Next interface. -func (e *LimitExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *LimitExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("limit.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -727,7 +831,7 @@ func (e *LimitExec) Next(ctx context.Context, req *chunk.RecordBatch) error { for !e.meetFirstBatch { // transfer req's requiredRows to childResult and then adjust it in childResult e.childResult = e.childResult.SetRequiredRows(req.RequiredRows(), e.maxChunkSize) - err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.adjustRequiredRows(e.childResult))) + err := Next(ctx, e.children[0], e.adjustRequiredRows(e.childResult)) if err != nil { return err } @@ -751,8 +855,8 @@ func (e *LimitExec) Next(ctx context.Context, req *chunk.RecordBatch) error { } e.cursor += batchSize } - e.adjustRequiredRows(req.Chunk) - err := e.children[0].Next(ctx, req) + e.adjustRequiredRows(req) + err := Next(ctx, e.children[0], req) if err != nil { return err } @@ -774,7 +878,7 @@ func (e *LimitExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } - e.childResult = e.children[0].newFirstChunk() + e.childResult = newFirstChunk(e.children[0]) e.cursor = 0 e.meetFirstBatch = e.begin == 0 return nil @@ -808,21 +912,25 @@ func init() { // While doing optimization in the plan package, we need to execute uncorrelated subquery, // but the plan package cannot import the executor package because of the dependency cycle. // So we assign a function implemented in the executor package to the plan package to avoid the dependency cycle. - plannercore.EvalSubquery = func(p plannercore.PhysicalPlan, is infoschema.InfoSchema, sctx sessionctx.Context) (rows [][]types.Datum, err error) { + plannercore.EvalSubquery = func(ctx context.Context, p plannercore.PhysicalPlan, is infoschema.InfoSchema, sctx sessionctx.Context) (rows [][]types.Datum, err error) { + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("executor.EvalSubQuery", opentracing.ChildOf(span.Context())) + defer span1.Finish() + } + e := &executorBuilder{is: is, ctx: sctx} exec := e.build(p) if e.err != nil { - return rows, err + return rows, e.err } - ctx := context.TODO() err = exec.Open(ctx) defer terror.Call(exec.Close) if err != nil { return rows, err } - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for { - err = exec.Next(ctx, chunk.NewRecordBatch(chk)) + err = Next(ctx, exec, chk) if err != nil { return rows, err } @@ -831,7 +939,7 @@ func init() { } iter := chunk.NewIterator4Chunk(chk) for r := iter.Begin(); r != iter.End(); r = iter.Next() { - row := r.GetDatumRow(exec.retTypes()) + row := r.GetDatumRow(retTypes(exec)) rows = append(rows, row) } chk = chunk.Renew(chk, sctx.GetSessionVars().MaxChunkSize) @@ -855,7 +963,7 @@ func (e *TableDualExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *TableDualExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *TableDualExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("tableDual.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -896,7 +1004,7 @@ func (e *SelectionExec) Open(ctx context.Context) error { if err := e.baseExecutor.Open(ctx); err != nil { return err } - e.childResult = e.children[0].newFirstChunk() + e.childResult = newFirstChunk(e.children[0]) e.batched = expression.Vectorizable(e.filters) if e.batched { e.selected = make([]bool, 0, chunk.InitialCapacity) @@ -914,7 +1022,7 @@ func (e *SelectionExec) Close() error { } // Next implements the Executor Next interface. -func (e *SelectionExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *SelectionExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("selection.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -926,7 +1034,7 @@ func (e *SelectionExec) Next(ctx context.Context, req *chunk.RecordBatch) error req.GrowAndReset(e.maxChunkSize) if !e.batched { - return e.unBatchedNext(ctx, req.Chunk) + return e.unBatchedNext(ctx, req) } for { @@ -939,7 +1047,7 @@ func (e *SelectionExec) Next(ctx context.Context, req *chunk.RecordBatch) error } req.AppendRow(e.inputRow) } - err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) + err := Next(ctx, e.children[0], e.childResult) if err != nil { return err } @@ -971,7 +1079,7 @@ func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) err return nil } } - err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) + err := Next(ctx, e.children[0], e.childResult) if err != nil { return err } @@ -997,7 +1105,7 @@ type TableScanExec struct { } // Next implements the Executor Next interface. -func (e *TableScanExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *TableScanExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("tableScan.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -1008,14 +1116,14 @@ func (e *TableScanExec) Next(ctx context.Context, req *chunk.RecordBatch) error } req.GrowAndReset(e.maxChunkSize) if e.isVirtualTable { - return e.nextChunk4InfoSchema(ctx, req.Chunk) + return e.nextChunk4InfoSchema(ctx, req) } handle, found, err := e.nextHandle() if err != nil || !found { return err } - mutableRow := chunk.MutRowFromTypes(e.retTypes()) + mutableRow := chunk.MutRowFromTypes(retTypes(e)) for req.NumRows() < req.Capacity() { row, err := e.getRow(handle) if err != nil { @@ -1031,12 +1139,12 @@ func (e *TableScanExec) Next(ctx context.Context, req *chunk.RecordBatch) error func (e *TableScanExec) nextChunk4InfoSchema(ctx context.Context, chk *chunk.Chunk) error { chk.GrowAndReset(e.maxChunkSize) if e.virtualTableChunkList == nil { - e.virtualTableChunkList = chunk.NewList(e.retTypes(), e.initCap, e.maxChunkSize) + e.virtualTableChunkList = chunk.NewList(retTypes(e), e.initCap, e.maxChunkSize) columns := make([]*table.Column, e.schema.Len()) for i, colInfo := range e.columns { columns[i] = table.ToColumn(colInfo) } - mutableRow := chunk.MutRowFromTypes(e.retTypes()) + mutableRow := chunk.MutRowFromTypes(retTypes(e)) err := e.t.IterRecords(e.ctx, nil, columns, func(h int64, rec []types.Datum, cols []*table.Column) (bool, error) { mutableRow.SetDatums(rec...) e.virtualTableChunkList.AppendRow(mutableRow.ToRow()) @@ -1105,7 +1213,7 @@ func (e *MaxOneRowExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("maxOneRow.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -1119,7 +1227,7 @@ func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.RecordBatch) error return nil } e.evaluated = true - err := e.children[0].Next(ctx, req) + err := Next(ctx, e.children[0], req) if err != nil { return err } @@ -1133,8 +1241,8 @@ func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.RecordBatch) error return errors.New("subquery returns more than 1 row") } - childChunk := e.children[0].newFirstChunk() - err = e.children[0].Next(ctx, chunk.NewRecordBatch(childChunk)) + childChunk := newFirstChunk(e.children[0]) + err = Next(ctx, e.children[0], childChunk) if err != nil { return err } @@ -1198,7 +1306,7 @@ func (e *UnionExec) Open(ctx context.Context) error { return err } for _, child := range e.children { - e.childrenResults = append(e.childrenResults, child.newFirstChunk()) + e.childrenResults = append(e.childrenResults, newFirstChunk(child)) } e.stopFetchData.Store(false) e.initialized = false @@ -1245,7 +1353,7 @@ func (e *UnionExec) resultPuller(ctx context.Context, childID int) { return case result.chk = <-e.resourcePools[childID]: } - result.err = e.children[childID].Next(ctx, chunk.NewRecordBatch(result.chk)) + result.err = Next(ctx, e.children[childID], result.chk) if result.err == nil && result.chk.NumRows() == 0 { return } @@ -1258,7 +1366,7 @@ func (e *UnionExec) resultPuller(ctx context.Context, childID int) { } // Next implements the Executor Next interface. -func (e *UnionExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *UnionExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("union.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -1307,19 +1415,29 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { } switch config.GetGlobalConfig().OOMAction { case config.OOMActionCancel: - sc.MemTracker.SetActionOnExceed(&memory.PanicOnExceed{}) + action := &memory.PanicOnExceed{ConnID: ctx.GetSessionVars().ConnectionID} + action.SetLogHook(domain.GetDomain(ctx).ExpensiveQueryHandle().LogOnQueryExceedMemQuota) + sc.MemTracker.SetActionOnExceed(action) case config.OOMActionLog: - sc.MemTracker.SetActionOnExceed(&memory.LogOnExceed{}) + fallthrough default: - sc.MemTracker.SetActionOnExceed(&memory.LogOnExceed{}) + action := &memory.LogOnExceed{ConnID: ctx.GetSessionVars().ConnectionID} + action.SetLogHook(domain.GetDomain(ctx).ExpensiveQueryHandle().LogOnQueryExceedMemQuota) + sc.MemTracker.SetActionOnExceed(action) } - if execStmt, ok := s.(*ast.ExecuteStmt); ok { s, err = getPreparedStmt(execStmt, vars) if err != nil { return } } + // execute missed stmtID uses empty sql + sc.OriginalSQL = s.Text() + if explainStmt, ok := s.(*ast.ExplainStmt); ok { + sc.InExplainStmt = true + sc.CastStrToIntStrict = true + s = explainStmt.Stmt + } // TODO: Many same bool variables here. // We should set only two variables ( // IgnoreErr and StrictSQLMode) to avoid setting the same bool variables and @@ -1380,6 +1498,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.NotFillCache = !opts.SQLCache } sc.PadCharToFullLength = ctx.GetSessionVars().SQLMode.HasPadCharToFullLengthMode() + sc.CastStrToIntStrict = true case *ast.ShowStmt: sc.IgnoreTruncate = true sc.IgnoreZeroInDate = true @@ -1388,7 +1507,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { sc.InShowWarning = true sc.SetWarnings(vars.StmtCtx.GetWarnings()) } - case *ast.SplitIndexRegionStmt: + case *ast.SplitRegionStmt: sc.IgnoreTruncate = false sc.IgnoreZeroInDate = true sc.AllowInvalidDate = vars.SQLMode.HasAllowInvalidDatesMode() @@ -1423,8 +1542,6 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { if err != nil { return err } - // execute missed stmtID uses empty sql - sc.OriginalSQL = s.Text() vars.StmtCtx = sc return } diff --git a/executor/executor_pkg_test.go b/executor/executor_pkg_test.go index b237236e9b31b..cf605132e92bd 100644 --- a/executor/executor_pkg_test.go +++ b/executor/executor_pkg_test.go @@ -36,25 +36,25 @@ type testExecSuite struct { // mockSessionManager is a mocked session manager which is used for test. type mockSessionManager struct { - PS []util.ProcessInfo + PS []*util.ProcessInfo } // ShowProcessList implements the SessionManager.ShowProcessList interface. -func (msm *mockSessionManager) ShowProcessList() map[uint64]util.ProcessInfo { - ret := make(map[uint64]util.ProcessInfo) +func (msm *mockSessionManager) ShowProcessList() map[uint64]*util.ProcessInfo { + ret := make(map[uint64]*util.ProcessInfo) for _, item := range msm.PS { ret[item.ID] = item } return ret } -func (msm *mockSessionManager) GetProcessInfo(id uint64) (util.ProcessInfo, bool) { +func (msm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) { for _, item := range msm.PS { if item.ID == id { return item, true } } - return util.ProcessInfo{}, false + return &util.ProcessInfo{}, false } // Kill implements the SessionManager.Kill interface. @@ -70,8 +70,8 @@ func (s *testExecSuite) TestShowProcessList(c *C) { schema := buildSchema(names, ftypes) // Compose a mocked session manager. - ps := make([]util.ProcessInfo, 0, 1) - pi := util.ProcessInfo{ + ps := make([]*util.ProcessInfo, 0, 1) + pi := &util.ProcessInfo{ ID: 0, User: "test", Host: "127.0.0.1", @@ -98,17 +98,17 @@ func (s *testExecSuite) TestShowProcessList(c *C) { err := e.Open(ctx) c.Assert(err, IsNil) - chk := e.newFirstChunk() + chk := newFirstChunk(e) it := chunk.NewIterator4Chunk(chk) // Run test and check results. for _, p := range ps { - err = e.Next(context.Background(), chunk.NewRecordBatch(chk)) + err = e.Next(context.Background(), chk) c.Assert(err, IsNil) for row := it.Begin(); row != it.End(); row = it.Next() { c.Assert(row.GetUint64(0), Equals, p.ID) } } - err = e.Next(context.Background(), chunk.NewRecordBatch(chk)) + err = e.Next(context.Background(), chk) c.Assert(err, IsNil) c.Assert(chk.NumRows(), Equals, 0) err = e.Close() diff --git a/executor/executor_required_rows_test.go b/executor/executor_required_rows_test.go index 5cdd3cfe2898e..5883a904c09e9 100644 --- a/executor/executor_required_rows_test.go +++ b/executor/executor_required_rows_test.go @@ -66,7 +66,7 @@ func newRequiredRowsDataSource(ctx sessionctx.Context, totalRows int, expectedRo return &requiredRowsDataSource{baseExec, totalRows, 0, ctx, expectedRowsRet, 0, defaultGenerator} } -func (r *requiredRowsDataSource) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (r *requiredRowsDataSource) Next(ctx context.Context, req *chunk.Chunk) error { defer func() { if r.expectedRowsRet == nil { r.numNextCalled++ @@ -93,9 +93,9 @@ func (r *requiredRowsDataSource) Next(ctx context.Context, req *chunk.RecordBatc } func (r *requiredRowsDataSource) genOneRow() chunk.Row { - row := chunk.MutRowFromTypes(r.retTypes()) - for i := range r.retTypes() { - row.SetValue(i, r.generator(r.retTypes()[i])) + row := chunk.MutRowFromTypes(retTypes(r)) + for i, tp := range retTypes(r) { + row.SetValue(i, r.generator(tp)) } return row.ToRow() } @@ -177,10 +177,10 @@ func (s *testExecSuite) TestLimitRequiredRows(c *C) { ds := newRequiredRowsDataSource(sctx, testCase.totalRows, testCase.expectedRowsDS) exec := buildLimitExec(sctx, ds, testCase.limitOffset, testCase.limitCount) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], sctx.GetSessionVars().MaxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) } c.Assert(exec.Close(), IsNil) @@ -260,10 +260,10 @@ func (s *testExecSuite) TestSortRequiredRows(c *C) { } exec := buildSortExec(sctx, byItems, ds) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) } c.Assert(exec.Close(), IsNil) @@ -367,10 +367,10 @@ func (s *testExecSuite) TestTopNRequiredRows(c *C) { } exec := buildTopNExec(sctx, testCase.topNOffset, testCase.topNCount, byItems, ds) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) } c.Assert(exec.Close(), IsNil) @@ -460,10 +460,10 @@ func (s *testExecSuite) TestSelectionRequiredRows(c *C) { } exec := buildSelectionExec(sctx, filters, ds) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) } c.Assert(exec.Close(), IsNil) @@ -518,10 +518,10 @@ func (s *testExecSuite) TestProjectionUnparallelRequiredRows(c *C) { } exec := buildProjectionExec(sctx, exprs, ds, 0) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) } c.Assert(exec.Close(), IsNil) @@ -574,10 +574,10 @@ func (s *testExecSuite) TestProjectionParallelRequiredRows(c *C) { } exec := buildProjectionExec(sctx, exprs, ds, testCase.numWorkers) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) // wait projectionInputFetcher blocked on fetching data @@ -659,14 +659,15 @@ func (s *testExecSuite) TestStreamAggRequiredRows(c *C) { childCols := ds.Schema().Columns schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true) + aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true) + c.Assert(err, IsNil) aggFuncs := []*aggregation.AggFuncDesc{aggFunc} exec := buildStreamAggExecutor(sctx, ds, schema, aggFuncs, groupBy) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) } c.Assert(exec.Close(), IsNil) @@ -718,14 +719,15 @@ func (s *testExecSuite) TestHashAggParallelRequiredRows(c *C) { childCols := ds.Schema().Columns schema := expression.NewSchema(childCols...) groupBy := []expression.Expression{childCols[1]} - aggFunc := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct) + aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, hasDistinct) + c.Assert(err, IsNil) aggFuncs := []*aggregation.AggFuncDesc{aggFunc} exec := buildHashAggExecutor(sctx, ds, schema, aggFuncs, groupBy) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) } c.Assert(exec.Close(), IsNil) @@ -758,10 +760,10 @@ func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) { exec := buildMergeJoinExec(ctx, joinType, innerSrc, outerSrc) c.Assert(exec.Open(context.Background()), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range required { chk.SetRequiredRows(required[i], ctx.GetSessionVars().MaxChunkSize) - c.Assert(exec.Next(context.Background(), chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(context.Background(), chk), IsNil) } c.Assert(exec.Close(), IsNil) c.Assert(outerSrc.checkNumNextCalled(), IsNil) diff --git a/executor/executor_test.go b/executor/executor_test.go index d28a364aef4d3..246e58cf714d4 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -42,6 +42,7 @@ import ( "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/planner" plannercore "github.com/pingcap/tidb/planner/core" @@ -59,6 +60,7 @@ import ( "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/admin" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/logutil" @@ -87,10 +89,14 @@ var _ = Suite(&testSuite{}) var _ = Suite(&testSuite1{}) var _ = Suite(&testSuite2{}) var _ = Suite(&testSuite3{}) +var _ = Suite(&testSuite4{}) +var _ = SerialSuites(&testShowStatsSuite{testSuite{}}) var _ = Suite(&testBypassSuite{}) var _ = Suite(&testUpdateSuite{}) var _ = Suite(&testOOMSuite{}) var _ = Suite(&testPointGetSuite{}) +var _ = Suite(&testFlushSuite{}) +var _ = SerialSuites(&testShowStatsSuite{}) type testSuite struct { cluster *mocktikv.Cluster @@ -118,7 +124,7 @@ func (s *testSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() } d, err := session.BootstrapSession(s.store) c.Assert(err, IsNil) @@ -131,6 +137,19 @@ func (s *testSuite) TearDownSuite(c *C) { s.store.Close() } +func (s *testSuite) TestPessimisticSelectForUpdate(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id int primary key, a int)") + tk.MustExec("insert into t values(1, 1)") + tk.MustExec("begin PESSIMISTIC") + tk.MustQuery("select a from t where id=1 for update").Check(testkit.Rows("1")) + tk.MustExec("update t set a=a+1 where id=1") + tk.MustExec("commit") + tk.MustQuery("select a from t where id=1").Check(testkit.Rows("2")) +} + func (s *testSuite) TearDownTest(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -152,7 +171,7 @@ func (s *testSuite) TestAdmin(c *C) { // cancel DDL jobs test r, err := tk.Exec("admin cancel ddl jobs 1") c.Assert(err, IsNil, Commentf("err %v", err)) - req := r.NewRecordBatch() + req := r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) row := req.GetRow(0) @@ -163,7 +182,7 @@ func (s *testSuite) TestAdmin(c *C) { // show ddl test; r, err = tk.Exec("admin show ddl") c.Assert(err, IsNil) - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) row = req.GetRow(0) @@ -183,7 +202,7 @@ func (s *testSuite) TestAdmin(c *C) { c.Assert(row.GetString(2), Equals, serverInfo.IP+":"+ strconv.FormatUint(uint64(serverInfo.Port), 10)) c.Assert(row.GetString(3), Equals, "") - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsTrue) @@ -193,7 +212,7 @@ func (s *testSuite) TestAdmin(c *C) { // show DDL jobs test r, err = tk.Exec("admin show ddl jobs") c.Assert(err, IsNil) - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) row = req.GetRow(0) @@ -209,7 +228,7 @@ func (s *testSuite) TestAdmin(c *C) { r, err = tk.Exec("admin show ddl jobs 20") c.Assert(err, IsNil) - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) row = req.GetRow(0) @@ -225,9 +244,9 @@ func (s *testSuite) TestAdmin(c *C) { result.Check(testkit.Rows()) result = tk.MustQuery(`admin show ddl job queries 1, 2, 3, 4`) result.Check(testkit.Rows()) - historyJob, err := admin.GetHistoryDDLJobs(txn, admin.DefNumHistoryJobs) - result = tk.MustQuery(fmt.Sprintf("admin show ddl job queries %d", historyJob[0].ID)) - result.Check(testkit.Rows(historyJob[0].Query)) + historyJobs, err = admin.GetHistoryDDLJobs(txn, admin.DefNumHistoryJobs) + result = tk.MustQuery(fmt.Sprintf("admin show ddl job queries %d", historyJobs[0].ID)) + result.Check(testkit.Rows(historyJobs[0].Query)) c.Assert(err, IsNil) // check table test @@ -281,6 +300,33 @@ func (s *testSuite) TestAdmin(c *C) { tk.MustExec("ALTER TABLE t1 ADD COLUMN c4 bit(10) default 127;") tk.MustExec("ALTER TABLE t1 ADD INDEX idx3 (c4);") tk.MustExec("admin check table t1;") + + // Test for reverse scan get history ddl jobs when ddl history jobs queue has multiple regions. + txn, err = s.store.Begin() + c.Assert(err, IsNil) + historyJobs, err = admin.GetHistoryDDLJobs(txn, 20) + c.Assert(err, IsNil) + + // Split region for history ddl job queues. + m := meta.NewMeta(txn) + startKey := meta.DDLJobHistoryKey(m, 0) + endKey := meta.DDLJobHistoryKey(m, historyJobs[0].ID) + s.cluster.SplitKeys(s.mvccStore, startKey, endKey, int(historyJobs[0].ID/5)) + + historyJobs2, err := admin.GetHistoryDDLJobs(txn, 20) + c.Assert(err, IsNil) + c.Assert(historyJobs, DeepEquals, historyJobs2) +} + +func (s *testSuite) TestAdminChecksumOfPartitionedTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("USE test;") + tk.MustExec("DROP TABLE IF EXISTS admin_checksum_partition_test;") + tk.MustExec("CREATE TABLE admin_checksum_partition_test (a INT) PARTITION BY HASH(a) PARTITIONS 4;") + tk.MustExec("INSERT INTO admin_checksum_partition_test VALUES (1), (2);") + + r := tk.MustQuery("ADMIN CHECKSUM TABLE admin_checksum_partition_test;") + r.Check(testkit.Rows("test admin_checksum_partition_test 1 5 5")) } func (s *testSuite) fillData(tk *testkit.TestKit, table string) { @@ -311,7 +357,8 @@ func checkCases(tests []testCase, ld *executor.LoadDataInfo, ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true ctx.GetSessionVars().StmtCtx.InLoadDataStmt = true - data, reachLimit, err1 := ld.InsertData(tt.data1, tt.data2) + ctx.GetSessionVars().StmtCtx.InDeleteStmt = false + data, reachLimit, err1 := ld.InsertData(context.Background(), tt.data1, tt.data2) c.Assert(err1, IsNil) c.Assert(reachLimit, IsFalse) if tt.restData == nil { @@ -818,7 +865,7 @@ func (s *testSuite) TestIssue2612(c *C) { tk.MustExec(`insert into t values ('2016-02-13 15:32:24', '2016-02-11 17:23:22');`) rs, err := tk.Exec(`select timediff(finish_at, create_at) from t;`) c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(context.Background(), req) c.Assert(err, IsNil) c.Assert(req.GetRow(0).GetDuration(0, 0).String(), Equals, "-46:09:02") @@ -1812,6 +1859,9 @@ func (s *testSuite) TestTableDual(c *C) { result.Check(testkit.Rows("1")) result = tk.MustQuery("Select 1 from dual where 1") result.Check(testkit.Rows("1")) + + tk.MustExec("create table t(a int primary key)") + tk.MustQuery("select t1.* from t t1, t t2 where t1.a=t2.a and 1=0").Check(testkit.Rows()) } func (s *testSuite) TestTableScan(c *C) { @@ -1868,7 +1918,7 @@ func (s *testSuite) TestIsPointGet(c *C) { c.Check(err, IsNil) err = plannercore.Preprocess(ctx, stmtNode, infoSchema) c.Check(err, IsNil) - p, err := planner.Optimize(ctx, stmtNode, infoSchema) + p, err := planner.Optimize(context.TODO(), ctx, stmtNode, infoSchema) c.Check(err, IsNil) ret, err := executor.IsPointGetWithPKOrUniqueKeyByAutoCommit(ctx, p) c.Assert(err, IsNil) @@ -1912,6 +1962,24 @@ func (s *testSuite) TestPointGetRepeatableRead(c *C) { c.Assert(failpoint.Disable(step2), IsNil) } +func (s *testSuite) TestSplitRegionTimeout(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/MockSplitRegionTimeout", `return(true)`), IsNil) + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a varchar(100),b int, index idx1(b,a))") + tk.MustExec(`split table t index idx1 by (10000,"abcd"),(10000000);`) + tk.MustExec(`set @@tidb_wait_split_region_timeout=1`) + // result 0 0 means split 0 region and 0 region finish scatter regions before timeout. + tk.MustQuery(`split table t between (0) and (10000) regions 10`).Check(testkit.Rows("0 0")) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/MockSplitRegionTimeout"), IsNil) + + // Test scatter regions timeout. + c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/MockScatterRegionTimeout", `return(true)`), IsNil) + tk.MustQuery(`split table t between (0) and (10000) regions 10`).Check(testkit.Rows("10 1")) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/MockScatterRegionTimeout"), IsNil) +} + func (s *testSuite) TestRow(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -2040,6 +2108,13 @@ func (s *testSuite) TestColumnName(c *C) { c.Assert(fields[1].ColumnAsName.L, Equals, "num") tk.MustExec("set @@tidb_enable_window_function = 0") rs.Close() + + rs, err = tk.Exec("select if(1,c,c) from t;") + c.Check(err, IsNil) + fields = rs.Fields() + c.Assert(fields[0].Column.Name.L, Equals, "if(1,c,c)") + // It's a compatibility issue. Should be empty instead. + c.Assert(fields[0].ColumnAsName.L, Equals, "if(1,c,c)") } func (s *testSuite) TestSelectVar(c *C) { @@ -2751,7 +2826,7 @@ func (s *testSuite) TestBit(c *C) { c.Assert(err, NotNil) r, err := tk.Exec("select * from t where c1 = 2") c.Assert(err, IsNil) - req := r.NewRecordBatch() + req := r.NewChunk() err = r.Next(context.Background(), req) c.Assert(err, IsNil) c.Assert(types.BinaryLiteral(req.GetRow(0).GetBytes(0)), DeepEquals, types.NewBinaryLiteralFromUint(2, -1)) @@ -2985,7 +3060,7 @@ func (s *testSuite) TestCheckIndex(c *C) { c.Assert(err, IsNil) _, err = se.Execute(context.Background(), "admin check index t c") c.Assert(err, NotNil) - c.Assert(strings.Contains(err.Error(), "isn't equal to value count"), IsTrue) + c.Assert(err.Error(), Equals, "handle 3, index:types.Datum{k:0x1, collation:0x0, decimal:0x0, length:0x0, i:30, b:[]uint8(nil), x:interface {}(nil)} != record:") // set data to: // index data (handle, data): (1, 10), (2, 20), (3, 30), (4, 40) @@ -3380,7 +3455,7 @@ func (s *testSuite3) TestMaxOneRow(c *C) { rs, err := tk.Exec(`select (select t1.a from t1 where t1.a > t2.a) as a from t2;`) c.Assert(err, IsNil) - err = rs.Next(context.TODO(), rs.NewRecordBatch()) + err = rs.Next(context.TODO(), rs.NewChunk()) c.Assert(err.Error(), Equals, "subquery returns more than 1 row") err = rs.Close() @@ -3511,6 +3586,13 @@ func (s *testSuite3) TestSelectPartition(c *C) { c.Assert(err.Error(), Equals, "[table:1735]Unknown partition 'p4' in table 'th'") err = tk.ExecToErr("select b from tr partition (r1,r4)") c.Assert(err.Error(), Equals, "[table:1735]Unknown partition 'r4' in table 'tr'") + + // test select partition table in transaction. + tk.MustExec("begin") + tk.MustExec("insert into th values (10,10),(11,11)") + tk.MustQuery("select a, b from th where b>10").Check(testkit.Rows("11 11")) + tk.MustExec("commit") + tk.MustQuery("select a, b from th where b>10").Check(testkit.Rows("11 11")) } func (s *testSuite) TestSelectView(c *C) { @@ -3527,11 +3609,11 @@ func (s *testSuite) TestSelectView(c *C) { tk.MustExec("drop table view_t;") tk.MustExec("create table view_t(c int,d int)") err := tk.ExecToErr("select * from view1") - c.Assert(err.Error(), Equals, plannercore.ErrViewInvalid.GenWithStackByArgs("test", "view1").Error()) + c.Assert(err.Error(), Equals, "[planner:1356]View 'test.view1' references invalid table(s) or column(s) or function(s) or definer/invoker of view lack rights to use them") err = tk.ExecToErr("select * from view2") - c.Assert(err.Error(), Equals, plannercore.ErrViewInvalid.GenWithStackByArgs("test", "view2").Error()) + c.Assert(err.Error(), Equals, "[planner:1356]View 'test.view2' references invalid table(s) or column(s) or function(s) or definer/invoker of view lack rights to use them") err = tk.ExecToErr("select * from view3") - c.Assert(err.Error(), Equals, "[planner:1054]Unknown column 'a' in 'field list'") + c.Assert(err.Error(), Equals, plannercore.ErrViewInvalid.GenWithStackByArgs("test", "view3").Error()) tk.MustExec("drop table view_t;") tk.MustExec("create table view_t(a int,b int,c int)") tk.MustExec("insert into view_t values(1,2,3)") @@ -3583,7 +3665,7 @@ func (s *testSuite2) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() } d, err := session.BootstrapSession(s.store) c.Assert(err, IsNil) @@ -3634,7 +3716,7 @@ func (s *testSuite3) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() } d, err := session.BootstrapSession(s.store) c.Assert(err, IsNil) @@ -3685,7 +3767,7 @@ func (s *testSuite4) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() } d, err := session.BootstrapSession(s.store) c.Assert(err, IsNil) @@ -3714,6 +3796,9 @@ func (s *testSuite4) TearDownTest(c *C) { func (s *testSuite) TestStrToDateBuiltin(c *C) { tk := testkit.NewTestKit(c, s.store) + tk.MustQuery(`select str_to_date('20190101','%Y%m%d%!') from dual`).Check(testkit.Rows("2019-01-01")) + tk.MustQuery(`select str_to_date('20190101','%Y%m%d%f') from dual`).Check(testkit.Rows("2019-01-01 00:00:00.000000")) + tk.MustQuery(`select str_to_date('20190101','%Y%m%d%H%i%s') from dual`).Check(testkit.Rows("2019-01-01 00:00:00")) tk.MustQuery(`select str_to_date('18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22")) tk.MustQuery(`select str_to_date('a18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("")) tk.MustQuery(`select str_to_date('69/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2069-10-22")) @@ -3769,16 +3854,247 @@ func (s *testSuite) TestReadPartitionedTable(c *C) { tk.MustQuery("select a from pt where b = 3").Check(testkit.Rows("3")) } -func (s *testSuite) TestSplitIndexRegion(c *C) { +func (s *testSuite) TestSplitRegion(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t, t1") tk.MustExec("create table t(a varchar(100),b int, index idx1(b,a))") tk.MustExec(`split table t index idx1 by (10000,"abcd"),(10000000);`) _, err := tk.Exec(`split table t index idx1 by ("abcd");`) c.Assert(err, NotNil) terr := errors.Cause(err).(*terror.Error) c.Assert(terr.Code(), Equals, terror.ErrCode(mysql.WarnDataTruncated)) + + // Test for split index region. + // Check min value is more than max value. + tk.MustExec(`split table t index idx1 between (0) and (1000000000) regions 10`) + _, err = tk.Exec(`split table t index idx1 between (2,'a') and (1,'c') regions 10`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split index `idx1` region lower value (2,a) should less than the upper value (1,c)") + + // Check min value is invalid. + _, err = tk.Exec(`split table t index idx1 between () and (1) regions 10`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split index `idx1` region lower value count should more than 0") + + // Check max value is invalid. + _, err = tk.Exec(`split table t index idx1 between (1) and () regions 10`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split index `idx1` region upper value count should more than 0") + + // Check pre-split region num is too large. + _, err = tk.Exec(`split table t index idx1 between (0) and (1000000000) regions 10000`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split index region num exceeded the limit 1000") + + // Check pre-split region num 0 is invalid. + _, err = tk.Exec(`split table t index idx1 between (0) and (1000000000) regions 0`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split index region num should more than 0") + + // Test truncate error msg. + _, err = tk.Exec(`split table t index idx1 between ("aa") and (1000000000) regions 0`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[types:1265]Incorrect value: 'aa' for column 'b'") + + // Test for split table region. + tk.MustExec(`split table t between (0) and (1000000000) regions 10`) + // Check the lower value is more than the upper value. + _, err = tk.Exec(`split table t between (2) and (1) regions 10`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split table `t` region lower value 2 should less than the upper value 1") + + // Check the lower value is invalid. + _, err = tk.Exec(`split table t between () and (1) regions 10`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split table region lower value count should be 1") + + // Check upper value is invalid. + _, err = tk.Exec(`split table t between (1) and () regions 10`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split table region upper value count should be 1") + + // Check pre-split region num is too large. + _, err = tk.Exec(`split table t between (0) and (1000000000) regions 10000`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split table region num exceeded the limit 1000") + + // Check pre-split region num 0 is invalid. + _, err = tk.Exec(`split table t between (0) and (1000000000) regions 0`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split table region num should more than 0") + + // Test truncate error msg. + _, err = tk.Exec(`split table t between ("aa") and (1000000000) regions 10`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "[types:1265]Incorrect value: 'aa' for column '_tidb_rowid'") + + // Test split table region step is too small. + _, err = tk.Exec(`split table t between (0) and (100) regions 10`) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "Split table `t` region step value should more than 1000, step 10 is invalid") + + // Test split region by syntax. + tk.MustExec(`split table t by (0),(1000),(1000000)`) + + // Test split region twice to test for multiple batch split region requests. + tk.MustExec("create table t1(a int, b int)") + tk.MustQuery("split table t1 between(0) and (10000) regions 10;").Check(testkit.Rows("9 1")) + tk.MustQuery("split table t1 between(10) and (10010) regions 5;").Check(testkit.Rows("4 1")) +} + +func (s *testSuite) TestShowTableRegion(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t_regions") + tk.MustExec("create table t_regions (a int key, b int, c int, index idx(b), index idx2(c))") + + // Test show table regions. + tk.MustQuery(`split table t_regions between (-10000) and (10000) regions 4;`).Check(testkit.Rows("4 1")) + re := tk.MustQuery("show table t_regions regions") + rows := re.Rows() + // Table t_regions should have 5 regions now. + // 4 regions to store record data. + // 1 region to store index data. + c.Assert(len(rows), Equals, 5) + c.Assert(len(rows[0]), Equals, 11) + tbl := testGetTableByName(c, tk.Se, "test", "t_regions") + // Check the region start key. + c.Assert(rows[0][1], Equals, fmt.Sprintf("t_%d_r", tbl.Meta().ID)) + c.Assert(rows[1][1], Equals, fmt.Sprintf("t_%d_r_-5000", tbl.Meta().ID)) + c.Assert(rows[2][1], Equals, fmt.Sprintf("t_%d_r_0", tbl.Meta().ID)) + c.Assert(rows[3][1], Equals, fmt.Sprintf("t_%d_r_5000", tbl.Meta().ID)) + c.Assert(rows[4][2], Equals, fmt.Sprintf("t_%d_r", tbl.Meta().ID)) + + // Test show table index regions. + tk.MustQuery(`split table t_regions index idx between (-1000) and (1000) regions 4;`).Check(testkit.Rows("5 1")) + re = tk.MustQuery("show table t_regions index idx regions") + rows = re.Rows() + // The index `idx` of table t_regions should have 4 regions now. + c.Assert(len(rows), Equals, 4) + // Check the region start key. + c.Assert(rows[0][1], Equals, fmt.Sprintf("t_%d_i_1_", tbl.Meta().ID)) + c.Assert(rows[1][1], Matches, fmt.Sprintf("t_%d_i_1_.*", tbl.Meta().ID)) + c.Assert(rows[2][1], Matches, fmt.Sprintf("t_%d_i_1_.*", tbl.Meta().ID)) + c.Assert(rows[3][1], Matches, fmt.Sprintf("t_%d_i_1_.*", tbl.Meta().ID)) + + re = tk.MustQuery("show table t_regions regions") + rows = re.Rows() + // The index `idx` of table t_regions should have 9 regions now. + // 4 regions to store record data. + // 4 region to store index idx data. + // 1 region to store index idx2 data. + c.Assert(len(rows), Equals, 9) + // Check the region start key. + c.Assert(rows[0][1], Equals, fmt.Sprintf("t_%d_r", tbl.Meta().ID)) + c.Assert(rows[1][1], Equals, fmt.Sprintf("t_%d_r_-5000", tbl.Meta().ID)) + c.Assert(rows[2][1], Equals, fmt.Sprintf("t_%d_r_0", tbl.Meta().ID)) + c.Assert(rows[3][1], Equals, fmt.Sprintf("t_%d_r_5000", tbl.Meta().ID)) + c.Assert(rows[4][1], Matches, fmt.Sprintf("t_%d_i_1_.*", tbl.Meta().ID)) + c.Assert(rows[5][1], Matches, fmt.Sprintf("t_%d_i_1_.*", tbl.Meta().ID)) + c.Assert(rows[6][1], Matches, fmt.Sprintf("t_%d_i_1_.*", tbl.Meta().ID)) + c.Assert(rows[7][2], Equals, fmt.Sprintf("t_%d_i_2_", tbl.Meta().ID)) + c.Assert(rows[8][2], Equals, fmt.Sprintf("t_%d_r", tbl.Meta().ID)) + + // Test unsigned primary key and wait scatter finish. + tk.MustExec("drop table if exists t_regions") + tk.MustExec("create table t_regions (a int unsigned key, b int, index idx(b))") + + // Test show table regions. + tk.MustExec(`set @@session.tidb_wait_split_region_finish=1;`) + tk.MustQuery(`split table t_regions by (2500),(5000),(7500);`).Check(testkit.Rows("3 1")) + re = tk.MustQuery("show table t_regions regions") + rows = re.Rows() + // Table t_regions should have 4 regions now. + c.Assert(len(rows), Equals, 4) + tbl = testGetTableByName(c, tk.Se, "test", "t_regions") + // Check the region start key. + c.Assert(rows[0][1], Matches, "t_.*") + c.Assert(rows[1][1], Equals, fmt.Sprintf("t_%d_r_2500", tbl.Meta().ID)) + c.Assert(rows[2][1], Equals, fmt.Sprintf("t_%d_r_5000", tbl.Meta().ID)) + c.Assert(rows[3][1], Equals, fmt.Sprintf("t_%d_r_7500", tbl.Meta().ID)) + + // Test show table index regions. + tk.MustQuery(`split table t_regions index idx by (250),(500),(750);`).Check(testkit.Rows("4 1")) + re = tk.MustQuery("show table t_regions index idx regions") + rows = re.Rows() + // The index `idx` of table t_regions should have 4 regions now. + c.Assert(len(rows), Equals, 4) + // Check the region start key. + c.Assert(rows[0][1], Equals, fmt.Sprintf("t_%d_i_1_", tbl.Meta().ID)) + c.Assert(rows[1][1], Matches, fmt.Sprintf("t_%d_i_1_.*", tbl.Meta().ID)) + c.Assert(rows[2][1], Matches, fmt.Sprintf("t_%d_i_1_.*", tbl.Meta().ID)) + c.Assert(rows[3][1], Matches, fmt.Sprintf("t_%d_i_1_.*", tbl.Meta().ID)) + + // Test show table regions for partition table when disable split region when create table. + atomic.StoreUint32(&ddl.EnableSplitTableRegion, 0) + tk.MustExec("drop table if exists partition_t;") + tk.MustExec("set @@session.tidb_enable_table_partition = '1';") + tk.MustExec("create table partition_t (a int, b int,index(a)) partition by hash (a) partitions 3") + re = tk.MustQuery("show table partition_t regions") + rows = re.Rows() + c.Assert(len(rows), Equals, 1) + c.Assert(rows[0][1], Matches, "t_.*") + + // Test show table regions for partition table when enable split region when create table. + atomic.StoreUint32(&ddl.EnableSplitTableRegion, 1) + tk.MustExec("set @@global.tidb_scatter_region=1;") + tk.MustExec("drop table if exists partition_t;") + tk.MustExec("create table partition_t (a int, b int,index(a)) partition by hash (a) partitions 3") + re = tk.MustQuery("show table partition_t regions") + rows = re.Rows() + c.Assert(len(rows), Equals, 3) + tbl = testGetTableByName(c, tk.Se, "test", "partition_t") + partitionDef := tbl.Meta().GetPartitionInfo().Definitions + c.Assert(rows[0][1], Matches, fmt.Sprintf("t_%d_.*", partitionDef[0].ID)) + c.Assert(rows[1][1], Matches, fmt.Sprintf("t_%d_.*", partitionDef[1].ID)) + c.Assert(rows[2][1], Matches, fmt.Sprintf("t_%d_.*", partitionDef[2].ID)) + + // Test pre-split table region when create table. + tk.MustExec("drop table if exists t_pre") + tk.MustExec("create table t_pre (a int, b int) shard_row_id_bits = 2 pre_split_regions=2;") + re = tk.MustQuery("show table t_pre regions") + rows = re.Rows() + // Table t_regions should have 4 regions now. + c.Assert(len(rows), Equals, 4) + tbl = testGetTableByName(c, tk.Se, "test", "t_pre") + c.Assert(rows[1][1], Equals, fmt.Sprintf("t_%d_r_2305843009213693952", tbl.Meta().ID)) + c.Assert(rows[2][1], Equals, fmt.Sprintf("t_%d_r_4611686018427387904", tbl.Meta().ID)) + c.Assert(rows[3][1], Equals, fmt.Sprintf("t_%d_r_6917529027641081856", tbl.Meta().ID)) + atomic.StoreUint32(&ddl.EnableSplitTableRegion, 0) +} + +func (s *testSuite) TestChangePumpAndDrainer(c *C) { + tk := testkit.NewTestKit(c, s.store) + // change pump or drainer's state need connect to etcd + // so will meet error "URL scheme must be http, https, unix, or unixs: /tmp/tidb" + err := tk.ExecToErr("change pump to node_state ='paused' for node_id 'pump1'") + c.Assert(err, ErrorMatches, "URL scheme must be http, https, unix, or unixs.*") + err = tk.ExecToErr("change drainer to node_state ='paused' for node_id 'drainer1'") + c.Assert(err, ErrorMatches, "URL scheme must be http, https, unix, or unixs.*") +} + +func testGetTableByName(c *C, ctx sessionctx.Context, db, table string) table.Table { + dom := domain.GetDomain(ctx) + // Make sure the table schema is the new schema. + err := dom.Reload() + c.Assert(err, IsNil) + tbl, err := dom.InfoSchema().TableByName(model.NewCIStr(db), model.NewCIStr(table)) + c.Assert(err, IsNil) + return tbl +} + +func (s *testSuite) TestIssue10435(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t1(i int, j int, k int)") + tk.MustExec("insert into t1 VALUES (1,1,1),(2,2,2),(3,3,3),(4,4,4)") + tk.MustExec("INSERT INTO t1 SELECT 10*i,j,5*j FROM t1 UNION SELECT 20*i,j,5*j FROM t1 UNION SELECT 30*i,j,5*j FROM t1") + + tk.MustExec("set @@session.tidb_enable_window_function=1") + tk.MustQuery("SELECT SUM(i) OVER W FROM t1 WINDOW w AS (PARTITION BY j ORDER BY i) ORDER BY 1+SUM(i) OVER w").Check( + testkit.Rows("1", "2", "3", "4", "11", "22", "31", "33", "44", "61", "62", "93", "122", "124", "183", "244"), + ) } func (s *testSuite) TestUnsignedFeedback(c *C) { @@ -3855,6 +4171,44 @@ func (s *testOOMSuite) TestDistSQLMemoryControl(c *C) { tk.Se.GetSessionVars().MemQuotaDistSQL = -1 } +func setOOMAction(action string) { + newConf := config.NewConfig() + newConf.OOMAction = action + config.StoreGlobalConfig(newConf) +} + +func (s *testSuite) TestOOMPanicAction(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (a int primary key, b double);") + tk.MustExec("insert into t values (1,1)") + sm := &mockSessionManager1{ + PS: make([]*util.ProcessInfo, 0), + } + tk.Se.SetSessionManager(sm) + s.domain.ExpensiveQueryHandle().SetSessionManager(sm) + orgAction := config.GetGlobalConfig().OOMAction + setOOMAction(config.OOMActionCancel) + defer func() { + setOOMAction(orgAction) + }() + tk.MustExec("set @@tidb_mem_quota_query=1;") + err := tk.QueryToErr("select sum(b) from t group by a;") + c.Assert(err, NotNil) + c.Assert(err.Error(), Matches, "Out Of Memory Quota!.*") + + // Test insert from select oom panic. + tk.MustExec("drop table if exists t,t1") + tk.MustExec("create table t (a bigint);") + tk.MustExec("create table t1 (a bigint);") + tk.MustExec("insert into t1 values (1),(2),(3),(4),(5);") + tk.MustExec("set @@tidb_mem_quota_query=200;") + _, err = tk.Exec("insert into t select a from t1 order by a desc;") + c.Assert(err, NotNil) + c.Assert(err.Error(), Matches, "Out Of Memory Quota!.*") +} + type oomCapturer struct { zapcore.Core tracker string diff --git a/executor/explain.go b/executor/explain.go index 61ced6d564b62..8dca1e894ba41 100644 --- a/executor/explain.go +++ b/executor/explain.go @@ -46,7 +46,7 @@ func (e *ExplainExec) Close() error { } // Next implements the Executor Next interface. -func (e *ExplainExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ExplainExec) Next(ctx context.Context, req *chunk.Chunk) error { if e.rows == nil { var err error e.rows, err = e.generateExplainInfo(ctx) @@ -72,9 +72,9 @@ func (e *ExplainExec) Next(ctx context.Context, req *chunk.RecordBatch) error { func (e *ExplainExec) generateExplainInfo(ctx context.Context) ([][]string, error) { if e.analyzeExec != nil { - chk := e.analyzeExec.newFirstChunk() + chk := newFirstChunk(e.analyzeExec) for { - err := e.analyzeExec.Next(ctx, chunk.NewRecordBatch(chk)) + err := e.analyzeExec.Next(ctx, chk) if err != nil { return nil, err } diff --git a/executor/explain_test.go b/executor/explain_test.go new file mode 100644 index 0000000000000..426864d652c13 --- /dev/null +++ b/executor/explain_test.go @@ -0,0 +1,140 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + "fmt" + "strings" + + . "github.com/pingcap/check" + "github.com/pingcap/parser/auth" + plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/util/testkit" +) + +func (s *testSuite1) TestExplainPriviliges(c *C) { + se, err := session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + c.Assert(se.Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil), IsTrue) + tk := testkit.NewTestKit(c, s.store) + tk.Se = se + + tk.MustExec("create database explaindatabase") + tk.MustExec("use explaindatabase") + tk.MustExec("create table t (id int)") + tk.MustExec("create view v as select * from t") + tk.MustExec(`create user 'explain'@'%'`) + tk.MustExec(`flush privileges`) + + tk1 := testkit.NewTestKit(c, s.store) + se, err = session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + c.Assert(se.Auth(&auth.UserIdentity{Username: "explain", Hostname: "%"}, nil, nil), IsTrue) + tk1.Se = se + + tk.MustExec(`grant select on explaindatabase.v to 'explain'@'%'`) + tk.MustExec(`flush privileges`) + tk1.MustQuery("show databases").Check(testkit.Rows("INFORMATION_SCHEMA", "explaindatabase")) + + tk1.MustExec("use explaindatabase") + tk1.MustQuery("select * from v") + err = tk1.ExecToErr("explain select * from v") + c.Assert(err.Error(), Equals, plannercore.ErrViewNoExplain.Error()) + + tk.MustExec(`grant show view on explaindatabase.v to 'explain'@'%'`) + tk.MustExec(`flush privileges`) + tk1.MustQuery("explain select * from v") + + tk.MustExec(`revoke select on explaindatabase.v from 'explain'@'%'`) + tk.MustExec(`flush privileges`) + + err = tk1.ExecToErr("explain select * from v") + c.Assert(err.Error(), Equals, plannercore.ErrTableaccessDenied.GenWithStackByArgs("SELECT", "explain", "%", "v").Error()) +} + +func (s *testSuite1) TestExplainCartesianJoin(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (v int)") + + cases := []struct { + sql string + isCartesianJoin bool + }{ + {"explain select * from t t1, t t2", true}, + {"explain select * from t t1 where exists (select 1 from t t2 where t2.v > t1.v)", true}, + {"explain select * from t t1 where exists (select 1 from t t2 where t2.v in (t1.v+1, t1.v+2))", true}, + {"explain select * from t t1, t t2 where t1.v = t2.v", false}, + } + for _, ca := range cases { + rows := tk.MustQuery(ca.sql).Rows() + ok := false + for _, row := range rows { + str := fmt.Sprintf("%v", row) + if strings.Contains(str, "CARTESIAN") { + ok = true + } + } + + c.Assert(ok, Equals, ca.isCartesianJoin) + } +} + +func (s *testSuite1) TestExplainAnalyzeMemory(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (v int, k int, key(k))") + tk.MustExec("insert into t values (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)") + + s.checkMemoryInfo(c, tk, "explain analyze select * from t order by v") + s.checkMemoryInfo(c, tk, "explain analyze select * from t order by v limit 5") + s.checkMemoryInfo(c, tk, "explain analyze select /*+ TIDB_HJ(t1, t2) */ t1.k from t t1, t t2 where t1.v = t2.v+1") + s.checkMemoryInfo(c, tk, "explain analyze select /*+ TIDB_SMJ(t1, t2) */ t1.k from t t1, t t2 where t1.k = t2.k+1") + s.checkMemoryInfo(c, tk, "explain analyze select /*+ TIDB_INLJ(t1, t2) */ t1.k from t t1, t t2 where t1.k = t2.k and t1.v=1") + s.checkMemoryInfo(c, tk, "explain analyze select sum(k) from t group by v") + s.checkMemoryInfo(c, tk, "explain analyze select sum(v) from t group by k") + s.checkMemoryInfo(c, tk, "explain analyze select * from t") + s.checkMemoryInfo(c, tk, "explain analyze select k from t use index(k)") + s.checkMemoryInfo(c, tk, "explain analyze select * from t use index(k)") +} + +func (s *testSuite1) checkMemoryInfo(c *C, tk *testkit.TestKit, sql string) { + memCol := 5 + ops := []string{"Join", "Reader", "Top", "Sort", "LookUp"} + rows := tk.MustQuery(sql).Rows() + for _, row := range rows { + strs := make([]string, len(row)) + for i, c := range row { + strs[i] = c.(string) + } + if strings.Contains(strs[2], "cop") { + continue + } + + shouldHasMem := false + for _, op := range ops { + if strings.Contains(strs[0], op) { + shouldHasMem = true + break + } + } + + if shouldHasMem { + c.Assert(strs[memCol], Not(Equals), "N/A") + } else { + c.Assert(strs[memCol], Equals, "N/A") + } + } +} diff --git a/executor/explainfor_test.go b/executor/explainfor_test.go index 04a64ce3ffc2a..632874180495b 100644 --- a/executor/explainfor_test.go +++ b/executor/explainfor_test.go @@ -25,25 +25,25 @@ import ( // mockSessionManager is a mocked session manager which is used for test. type mockSessionManager1 struct { - PS []util.ProcessInfo + PS []*util.ProcessInfo } // ShowProcessList implements the SessionManager.ShowProcessList interface. -func (msm *mockSessionManager1) ShowProcessList() map[uint64]util.ProcessInfo { - ret := make(map[uint64]util.ProcessInfo) +func (msm *mockSessionManager1) ShowProcessList() map[uint64]*util.ProcessInfo { + ret := make(map[uint64]*util.ProcessInfo) for _, item := range msm.PS { ret[item.ID] = item } return ret } -func (msm *mockSessionManager1) GetProcessInfo(id uint64) (util.ProcessInfo, bool) { +func (msm *mockSessionManager1) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) { for _, item := range msm.PS { if item.ID == id { return item, true } } - return util.ProcessInfo{}, false + return &util.ProcessInfo{}, false } // Kill implements the SessionManager.Kill interface. @@ -62,7 +62,7 @@ func (s *testSuite) TestExplainFor(c *C) { tkRoot.MustQuery("select * from t1;") tkRootProcess := tkRoot.Se.ShowProcess() - ps := []util.ProcessInfo{tkRootProcess} + ps := []*util.ProcessInfo{tkRootProcess} tkRoot.Se.SetSessionManager(&mockSessionManager1{PS: ps}) tkUser.Se.SetSessionManager(&mockSessionManager1{PS: ps}) tkRoot.MustQuery(fmt.Sprintf("explain for connection %d", tkRootProcess.ID)).Check(testkit.Rows( @@ -75,7 +75,27 @@ func (s *testSuite) TestExplainFor(c *C) { c.Check(core.ErrNoSuchThread.Equal(err), IsTrue) tkRootProcess.Plan = nil - ps = []util.ProcessInfo{tkRootProcess} + ps = []*util.ProcessInfo{tkRootProcess} tkRoot.Se.SetSessionManager(&mockSessionManager1{PS: ps}) tkRoot.MustExec(fmt.Sprintf("explain for connection %d", tkRootProcess.ID)) } + +func (s *testSuite) TestIssue11124(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("create table kankan1(id int, name text);") + tk.MustExec("create table kankan2(id int, h1 text);") + tk.MustExec("insert into kankan1 values(1, 'a'), (2, 'a');") + tk.MustExec("insert into kankan2 values(2, 'z');") + tk.MustQuery("select t1.id from kankan1 t1 left join kankan2 t2 on t1.id = t2.id where (case when t1.name='b' then 'case2' when t1.name='a' then 'case1' else NULL end) = 'case1'") + tkRootProcess := tk.Se.ShowProcess() + ps := []*util.ProcessInfo{tkRootProcess} + tk.Se.SetSessionManager(&mockSessionManager1{PS: ps}) + tk2.Se.SetSessionManager(&mockSessionManager1{PS: ps}) + + rs := tk.MustQuery("explain select t1.id from kankan1 t1 left join kankan2 t2 on t1.id = t2.id where (case when t1.name='b' then 'case2' when t1.name='a' then 'case1' else NULL end) = 'case1'").Rows() + rs2 := tk2.MustQuery(fmt.Sprintf("explain for connection %d", tkRootProcess.ID)).Rows() + for i := range rs { + c.Assert(rs[i], DeepEquals, rs2[i]) + } +} diff --git a/executor/grant.go b/executor/grant.go index cd3d8065b6591..8b99bb8ffdcbb 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -53,7 +53,7 @@ type GrantExec struct { } // Next implements the Executor Next interface. -func (e *GrantExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { if e.done { return nil } @@ -64,7 +64,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.RecordBatch) error { dbName = e.ctx.GetSessionVars().CurrentDB } // Grant for each user - for _, user := range e.Users { + for idx, user := range e.Users { // Check if user exists. exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname) if err != nil { @@ -79,7 +79,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.RecordBatch) error { } user := fmt.Sprintf(`('%s', '%s', '%s')`, user.User.Hostname, user.User.Username, pwd) sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password) VALUES %s;`, mysql.SystemDB, mysql.UserTable, user) - _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.TODO(), sql) + _, err := e.ctx.(sqlexec.SQLExecutor).Execute(ctx, sql) if err != nil { return err } @@ -105,6 +105,15 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.RecordBatch) error { if e.WithGrant { privs = append(privs, &ast.PrivElem{Priv: mysql.GrantPriv}) } + + if idx == 0 { + // Commit the old transaction, like DDL. + if err := e.ctx.NewTxn(ctx); err != nil { + return err + } + defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() + } + // Grant each priv to the user. for _, priv := range privs { if len(priv.Cols) > 0 { diff --git a/executor/index_lookup_join.go b/executor/index_lookup_join.go index 3d9371d14d396..556bf0d84b12a 100644 --- a/executor/index_lookup_join.go +++ b/executor/index_lookup_join.go @@ -90,6 +90,8 @@ type innerCtx struct { readerBuilder *dataReaderBuilder rowTypes []*types.FieldType keyCols []int + colLens []int + hasPrefixCol bool } type lookUpJoinTask struct { @@ -227,7 +229,7 @@ func (e *IndexLookUpJoin) newInnerWorker(taskCh chan *lookUpJoinTask) *innerWork } // Next implements the Executor interface. -func (e *IndexLookUpJoin) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *IndexLookUpJoin) Next(ctx context.Context, req *chunk.Chunk) error { if e.runtimeStats != nil { start := time.Now() defer func() { e.runtimeStats.Record(time.Since(start), req.NumRows()) }() @@ -253,7 +255,7 @@ func (e *IndexLookUpJoin) Next(ctx context.Context, req *chunk.RecordBatch) erro outerRow := task.outerResult.GetRow(task.cursor) if e.innerIter.Current() != e.innerIter.End() { - matched, isNull, err := e.joiner.tryToMatch(outerRow, e.innerIter, req.Chunk) + matched, isNull, err := e.joiner.tryToMatch(outerRow, e.innerIter, req) if err != nil { return err } @@ -262,7 +264,7 @@ func (e *IndexLookUpJoin) Next(ctx context.Context, req *chunk.RecordBatch) erro } if e.innerIter.Current() == e.innerIter.End() { if !task.hasMatch { - e.joiner.onMissMatch(task.hasNull, outerRow, req.Chunk) + e.joiner.onMissMatch(task.hasNull, outerRow, req) } task.cursor++ task.hasMatch = false @@ -298,9 +300,6 @@ func (e *IndexLookUpJoin) getFinishedTask(ctx context.Context) (*lookUpJoinTask, return nil, nil } - if e.task != nil { - e.task.memTracker.Detach() - } e.task = task return task, nil } @@ -365,11 +364,11 @@ func (ow *outerWorker) pushToChan(ctx context.Context, task *lookUpJoinTask, dst // buildTask builds a lookUpJoinTask and read outer rows. // When err is not nil, task must not be nil to send the error to the main thread via task. func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { - ow.executor.newFirstChunk() + newFirstChunk(ow.executor) task := &lookUpJoinTask{ doneCh: make(chan error, 1), - outerResult: ow.executor.newFirstChunk(), + outerResult: newFirstChunk(ow.executor), encodedLookUpKeys: chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeBlob)}, ow.ctx.GetSessionVars().MaxChunkSize), lookupMap: mvmap.NewMVMap(), } @@ -386,7 +385,7 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { task.memTracker.Consume(task.outerResult.MemoryUsage()) for !task.outerResult.IsFull() { - err := ow.executor.Next(ctx, chunk.NewRecordBatch(ow.executorChk)) + err := Next(ctx, ow.executor, ow.executorChk) if err != nil { return task, err } @@ -495,6 +494,16 @@ func (iw *innerWorker) constructLookupContent(task *lookUpJoinTask) ([]*indexJoi } // Store the encoded lookup key in chunk, so we can use it to lookup the matched inners directly. task.encodedLookUpKeys.AppendBytes(0, keyBuf) + if iw.hasPrefixCol { + for i := range iw.outerCtx.keyCols { + // If it's a prefix column. Try to fix it. + if iw.colLens[i] != types.UnspecifiedLength { + ranger.CutDatumByPrefixLen(&dLookUpKey[i], iw.colLens[i], iw.rowTypes[iw.keyCols[i]]) + } + } + // dLookUpKey is sorted and deduplicated at sortAndDedupLookUpContents. + // So we don't need to do it here. + } lookUpContents = append(lookUpContents, &indexJoinLookUpContent{keys: dLookUpKey, row: task.outerResult.GetRow(i)}) } @@ -582,11 +591,11 @@ func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTa return err } defer terror.Call(innerExec.Close) - innerResult := chunk.NewList(innerExec.retTypes(), iw.ctx.GetSessionVars().MaxChunkSize, iw.ctx.GetSessionVars().MaxChunkSize) + innerResult := chunk.NewList(retTypes(innerExec), iw.ctx.GetSessionVars().MaxChunkSize, iw.ctx.GetSessionVars().MaxChunkSize) innerResult.GetMemTracker().SetLabel(innerResultLabel) innerResult.GetMemTracker().AttachTo(task.memTracker) for { - err := innerExec.Next(ctx, chunk.NewRecordBatch(iw.executorChk)) + err := Next(ctx, innerExec, iw.executorChk) if err != nil { return err } @@ -594,7 +603,7 @@ func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTa break } innerResult.Add(iw.executorChk) - iw.executorChk = innerExec.newFirstChunk() + iw.executorChk = newFirstChunk(innerExec) } task.innerResult = innerResult return nil @@ -643,7 +652,6 @@ func (e *IndexLookUpJoin) Close() error { e.cancelFunc() } e.workerWg.Wait() - e.memTracker.Detach() e.memTracker = nil return e.children[0].Close() } diff --git a/executor/index_lookup_join_test.go b/executor/index_lookup_join_test.go index 60ebb5e1e811b..8f281c8497961 100644 --- a/executor/index_lookup_join_test.go +++ b/executor/index_lookup_join_test.go @@ -31,7 +31,7 @@ func (s *testSuite1) TestIndexLookupJoinHang(c *C) { rs, err := tk.Exec("select /*+ TIDB_INLJ(i)*/ * from idxJoinOuter o left join idxJoinInner i on o.a = i.a where o.a in (1, 2) and (i.a - 3) > 0") c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() for i := 0; i < 5; i++ { rs.Next(context.Background(), req) } @@ -67,11 +67,11 @@ func (s *testSuite1) TestIndexJoinUnionScan(c *C) { "│ └─TableReader_17 9990.00 root data:Selection_16", "│ └─Selection_16 9990.00 cop not(isnull(test.t1.a))", "│ └─TableScan_15 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", - "└─UnionScan_12 0.00 root not(isnull(test.t2.a))", - " └─IndexLookUp_11 0.00 root ", - " ├─Selection_10 0.00 cop not(isnull(test.t2.a))", + "└─UnionScan_12 9.99 root not(isnull(test.t2.a))", + " └─IndexLookUp_11 9.99 root ", + " ├─Selection_10 9.99 cop not(isnull(test.t2.a))", " │ └─IndexScan_8 10.00 cop table:t2, index:a, range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo", - " └─TableScan_9 0.00 cop table:t2, keep order:false, stats:pseudo", + " └─TableScan_9 9.99 cop table:t2, keep order:false, stats:pseudo", )) tk.MustQuery("select /*+ TIDB_INLJ(t1, t2)*/ * from t1 join t2 on t1.a = t2.a").Check(testkit.Rows( "2 2 2 2 2", @@ -85,9 +85,9 @@ func (s *testSuite1) TestIndexJoinUnionScan(c *C) { " │ └─TableReader_16 9990.00 root data:Selection_15", " │ └─Selection_15 9990.00 cop not(isnull(test.t1.a))", " │ └─TableScan_14 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", - " └─UnionScan_11 0.00 root not(isnull(test.t2.a))", - " └─IndexReader_10 0.00 root index:Selection_9", - " └─Selection_9 0.00 cop not(isnull(test.t2.a))", + " └─UnionScan_11 9.99 root not(isnull(test.t2.a))", + " └─IndexReader_10 9.99 root index:Selection_9", + " └─Selection_9 9.99 cop not(isnull(test.t2.a))", " └─IndexScan_8 10.00 cop table:t2, index:a, range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo", )) tk.MustQuery("select /*+ TIDB_INLJ(t1, t2)*/ t1.a, t2.a from t1 join t2 on t1.a = t2.a").Check(testkit.Rows( @@ -114,9 +114,9 @@ func (s *testSuite1) TestBatchIndexJoinUnionScan(c *C) { " │ └─TableReader_22 9990.00 root data:Selection_21", " │ └─Selection_21 9990.00 cop not(isnull(test.t1.a))", " │ └─TableScan_20 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", - " └─UnionScan_26 0.00 root not(isnull(test.t2.a))", - " └─IndexReader_25 0.00 root index:Selection_24", - " └─Selection_24 0.00 cop not(isnull(test.t2.a))", + " └─UnionScan_26 9.99 root not(isnull(test.t2.a))", + " └─IndexReader_25 9.99 root index:Selection_24", + " └─Selection_24 9.99 cop not(isnull(test.t2.a))", " └─IndexScan_23 10.00 cop table:t2, index:a, range: decided by [eq(test.t2.a, test.t1.a)], keep order:false, stats:pseudo", )) tk.MustQuery("select /*+ TIDB_INLJ(t1, t2)*/ count(*) from t1 join t2 on t1.a = t2.id").Check(testkit.Rows( @@ -152,3 +152,19 @@ func (s *testSuite) TestIndexJoinOverflow(c *C) { tk.MustExec(`create table t2(a int unsigned, index idx(a));`) tk.MustQuery(`select /*+ TIDB_INLJ(t2) */ * from t1 join t2 on t1.a = t2.a;`).Check(testkit.Rows()) } + +func (s *testSuite2) TestIssue11061(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1(c varchar(30), index ix_c(c(10)))") + tk.MustExec("insert into t1 (c) values('7_chars'), ('13_characters')") + tk.MustQuery("SELECT /*+ TIDB_INLJ(t1) */ SUM(LENGTH(c)) FROM t1 WHERE c IN (SELECT t1.c FROM t1)").Check(testkit.Rows("20")) +} + +func (s *testSuite2) TestIndexJoinPartitionTable(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int not null, c int, key idx(c)) partition by hash(b) partitions 30") + tk.MustExec("insert into t values(1, 27, 2)") + tk.MustQuery("SELECT /*+ TIDB_INLJ(t1) */ count(1) FROM t t1 INNER JOIN (SELECT a, max(c) AS c FROM t WHERE b = 27 AND a = 1 GROUP BY a) t2 ON t1.a = t2.a AND t1.c = t2.c WHERE t1.b = 27").Check(testkit.Rows("1")) +} diff --git a/executor/insert.go b/executor/insert.go index e65681157d2ec..810b9af85240b 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -15,12 +15,14 @@ package executor import ( "context" + "encoding/hex" "fmt" "github.com/opentracing/opentracing-go" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" @@ -32,8 +34,12 @@ import ( // InsertExec represents an insert executor. type InsertExec struct { *InsertValues - OnDuplicate []*expression.Assignment - Priority mysql.PriorityEnum + OnDuplicate []*expression.Assignment + evalBuffer4Dup chunk.MutRow + curInsertVals chunk.MutRow + row4Update []types.Datum + + Priority mysql.PriorityEnum } func (e *InsertExec) exec(ctx context.Context, rows [][]types.Datum) error { @@ -59,7 +65,7 @@ func (e *InsertExec) exec(ctx context.Context, rows [][]types.Datum) error { // If `ON DUPLICATE KEY UPDATE` is specified, and no `IGNORE` keyword, // the to-be-insert rows will be check on duplicate keys and update to the new rows. if len(e.OnDuplicate) > 0 { - err := e.batchUpdateDupRows(rows) + err := e.batchUpdateDupRows(ctx, rows) if err != nil { return err } @@ -78,64 +84,152 @@ func (e *InsertExec) exec(ctx context.Context, rows [][]types.Datum) error { return nil } +func prefetchUniqueIndices(txn kv.Transaction, rows []toBeCheckedRow) (map[string][]byte, error) { + nKeys := 0 + for _, r := range rows { + if r.handleKey != nil { + nKeys++ + } + nKeys += len(r.uniqueKeys) + } + batchKeys := make([]kv.Key, 0, nKeys) + for _, r := range rows { + if r.handleKey != nil { + batchKeys = append(batchKeys, r.handleKey.newKV.key) + } + for _, k := range r.uniqueKeys { + batchKeys = append(batchKeys, k.newKV.key) + } + } + return txn.BatchGet(batchKeys) +} + +func prefetchConflictedOldRows(ctx context.Context, txn kv.Transaction, rows []toBeCheckedRow, values map[string][]byte) error { + batchKeys := make([]kv.Key, 0, len(rows)) + for _, r := range rows { + for _, uk := range r.uniqueKeys { + if val, found := values[string(uk.newKV.key)]; found { + handle, err := tables.DecodeHandle(val) + if err != nil { + return err + } + batchKeys = append(batchKeys, r.t.RecordKey(handle)) + } + } + } + _, err := txn.BatchGet(batchKeys) + return err +} + +func prefetchDataCache(ctx context.Context, txn kv.Transaction, rows []toBeCheckedRow) error { + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("prefetchDataCache", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + values, err := prefetchUniqueIndices(txn, rows) + if err != nil { + return err + } + return prefetchConflictedOldRows(ctx, txn, rows, values) +} + +// updateDupRow updates a duplicate row to a new row. +func (e *InsertExec) updateDupRow(ctx context.Context, txn kv.Transaction, row toBeCheckedRow, handle int64, onDuplicate []*expression.Assignment) error { + oldRow, err := getOldRow(ctx, e.ctx, txn, row.t, handle, e.GenExprs) + if err != nil { + return err + } + + _, _, _, err = e.doDupRowUpdate(handle, oldRow, row.row, e.OnDuplicate) + if e.ctx.GetSessionVars().StmtCtx.DupKeyAsWarning && kv.ErrKeyExists.Equal(err) { + e.ctx.GetSessionVars().StmtCtx.AppendWarning(err) + return nil + } + return err +} + // batchUpdateDupRows updates multi-rows in batch if they are duplicate with rows in table. -func (e *InsertExec) batchUpdateDupRows(newRows [][]types.Datum) error { - err := e.batchGetInsertKeys(e.ctx, e.Table, newRows) +func (e *InsertExec) batchUpdateDupRows(ctx context.Context, newRows [][]types.Datum) error { + // Get keys need to be checked. + toBeCheckedRows, err := e.getKeysNeedCheck(e.ctx, e.Table, newRows) if err != nil { return err } - // Batch get the to-be-updated rows in storage. - err = e.initDupOldRowValue(e.ctx, e.Table, newRows) + txn, err := e.ctx.Txn(true) if err != nil { return err } - for i, r := range e.toBeCheckedRows { + // Use BatchGet to fill cache. + // It's an optimization and could be removed without affecting correctness. + if err = prefetchDataCache(ctx, txn, toBeCheckedRows); err != nil { + return err + } + + for i, r := range toBeCheckedRows { if r.handleKey != nil { - if _, found := e.dupKVs[string(r.handleKey.newKV.key)]; found { - handle, err := tablecodec.DecodeRowKey(r.handleKey.newKV.key) - if err != nil { - return err - } - err = e.updateDupRow(r, handle, e.OnDuplicate) - if err != nil { - return err - } + handle, err := tablecodec.DecodeRowKey(r.handleKey.newKV.key) + if err != nil { + return err + } + + err = e.updateDupRow(ctx, txn, r, handle, e.OnDuplicate) + if err == nil { continue } + if !kv.IsErrNotFound(err) { + return err + } } + for _, uk := range r.uniqueKeys { - if val, found := e.dupKVs[string(uk.newKV.key)]; found { - handle, err := tables.DecodeHandle(val) - if err != nil { - return err + val, err := txn.Get(uk.newKV.key) + if err != nil { + if kv.IsErrNotFound(err) { + continue } - err = e.updateDupRow(r, handle, e.OnDuplicate) - if err != nil { - return err + return err + } + handle, err := tables.DecodeHandle(val) + if err != nil { + return err + } + + err = e.updateDupRow(ctx, txn, r, handle, e.OnDuplicate) + if err != nil { + if kv.IsErrNotFound(err) { + // Data index inconsistent? A unique key provide the handle information, but the + // handle points to nothing. + logutil.Logger(ctx).Error("get old row failed when insert on dup", + zap.String("uniqueKey", hex.EncodeToString(uk.newKV.key)), + zap.Int64("handle", handle), + zap.String("toBeInsertedRow", types.DatumsToStrNoErr(r.row))) } - newRows[i] = nil - break + return err } + + newRows[i] = nil + break } + // If row was checked with no duplicate keys, // we should do insert the row, // and key-values should be filled back to dupOldRowValues for the further row check, // due to there may be duplicate keys inside the insert statement. if newRows[i] != nil { - newHandle, err := e.addRecord(newRows[i]) + _, err := e.addRecord(newRows[i]) if err != nil { return err } - e.fillBackKeys(e.Table, r, newHandle) } } return nil } // Next implements the Executor Next interface. -func (e *InsertExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *InsertExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("insert.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -160,6 +254,9 @@ func (e *InsertExec) Close() error { // Open implements the Executor Open interface. func (e *InsertExec) Open(ctx context.Context) error { + if e.OnDuplicate != nil { + e.initEvalBuffer4Dup() + } if e.SelectExec != nil { return e.SelectExec.Open(ctx) } @@ -167,23 +264,27 @@ func (e *InsertExec) Open(ctx context.Context) error { return nil } -// updateDupRow updates a duplicate row to a new row. -func (e *InsertExec) updateDupRow(row toBeCheckedRow, handle int64, onDuplicate []*expression.Assignment) error { - oldRow, err := e.getOldRow(e.ctx, e.Table, handle, e.GenExprs) - if err != nil { - logutil.Logger(context.Background()).Error("get old row failed when insert on dup", zap.Int64("handle", handle), zap.String("toBeInsertedRow", types.DatumsToStrNoErr(row.row))) - return err +func (e *InsertExec) initEvalBuffer4Dup() { + // Use public columns for new row. + numCols := len(e.Table.Cols()) + // Use writable columns for old row for update. + numWritableCols := len(e.Table.WritableCols()) + + evalBufferTypes := make([]*types.FieldType, 0, numCols+numWritableCols) + + // Append the old row before the new row, to be consistent with "Schema4OnDuplicate" in the "Insert" PhysicalPlan. + for _, col := range e.Table.WritableCols() { + evalBufferTypes = append(evalBufferTypes, &col.FieldType) } - // Do update row. - updatedRow, handleChanged, newHandle, err := e.doDupRowUpdate(handle, oldRow, row.row, onDuplicate) - if e.ctx.GetSessionVars().StmtCtx.DupKeyAsWarning && kv.ErrKeyExists.Equal(err) { - e.ctx.GetSessionVars().StmtCtx.AppendWarning(err) - return nil + for _, col := range e.Table.Cols() { + evalBufferTypes = append(evalBufferTypes, &col.FieldType) } - if err != nil { - return err + if e.hasExtraHandle { + evalBufferTypes = append(evalBufferTypes, types.NewFieldType(mysql.TypeLonglong)) } - return e.updateDupKeyValues(handle, newHandle, handleChanged, oldRow, updatedRow) + e.evalBuffer4Dup = chunk.MutRowFromTypes(evalBufferTypes) + e.curInsertVals = chunk.MutRowFromTypes(evalBufferTypes[numWritableCols:]) + e.row4Update = make([]types.Datum, 0, len(evalBufferTypes)) } // doDupRowUpdate updates the duplicate row. @@ -191,26 +292,32 @@ func (e *InsertExec) doDupRowUpdate(handle int64, oldRow []types.Datum, newRow [ cols []*expression.Assignment) ([]types.Datum, bool, int64, error) { assignFlag := make([]bool, len(e.Table.WritableCols())) // See http://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values - e.ctx.GetSessionVars().CurrInsertValues = chunk.MutRowFromDatums(newRow).ToRow() + e.curInsertVals.SetDatums(newRow...) + e.ctx.GetSessionVars().CurrInsertValues = e.curInsertVals.ToRow() // NOTE: In order to execute the expression inside the column assignment, // we have to put the value of "oldRow" before "newRow" in "row4Update" to // be consistent with "Schema4OnDuplicate" in the "Insert" PhysicalPlan. - row4Update := make([]types.Datum, 0, len(oldRow)+len(newRow)) - row4Update = append(row4Update, oldRow...) - row4Update = append(row4Update, newRow...) + e.row4Update = e.row4Update[:0] + e.row4Update = append(e.row4Update, oldRow...) + e.row4Update = append(e.row4Update, newRow...) // Update old row when the key is duplicated. + e.evalBuffer4Dup.SetDatums(e.row4Update...) for _, col := range cols { - val, err1 := col.Expr.Eval(chunk.MutRowFromDatums(row4Update).ToRow()) + val, err1 := col.Expr.Eval(e.evalBuffer4Dup.ToRow()) + if err1 != nil { + return nil, false, 0, err1 + } + e.row4Update[col.Col.Index], err1 = table.CastValue(e.ctx, val, col.Col.ToInfo()) if err1 != nil { return nil, false, 0, err1 } - row4Update[col.Col.Index] = val + e.evalBuffer4Dup.SetDatum(col.Col.Index, e.row4Update[col.Col.Index]) assignFlag[col.Col.Index] = true } - newData := row4Update[:len(oldRow)] + newData := e.row4Update[:len(oldRow)] _, handleChanged, newHandle, err := updateRecord(e.ctx, handle, oldRow, newData, assignFlag, e.Table, true) if err != nil { return nil, false, 0, err @@ -218,29 +325,6 @@ func (e *InsertExec) doDupRowUpdate(handle int64, oldRow []types.Datum, newRow [ return newData, handleChanged, newHandle, nil } -// updateDupKeyValues updates the dupKeyValues for further duplicate key check. -func (e *InsertExec) updateDupKeyValues(oldHandle int64, newHandle int64, - handleChanged bool, oldRow []types.Datum, updatedRow []types.Datum) error { - // There is only one row per update. - fillBackKeysInRows, err := e.getKeysNeedCheck(e.ctx, e.Table, [][]types.Datum{updatedRow}) - if err != nil { - return err - } - // Delete old keys and fill back new key-values of the updated row. - err = e.deleteDupKeys(e.ctx, e.Table, [][]types.Datum{oldRow}) - if err != nil { - return err - } - - if handleChanged { - delete(e.dupOldRowValues, string(e.Table.RecordKey(oldHandle))) - e.fillBackKeys(e.Table, fillBackKeysInRows[0], newHandle) - } else { - e.fillBackKeys(e.Table, fillBackKeysInRows[0], oldHandle) - } - return nil -} - // setMessage sets info message(ERR_INSERT_INFO) generated by INSERT statement func (e *InsertExec) setMessage() { stmtCtx := e.ctx.GetSessionVars().StmtCtx diff --git a/executor/insert_common.go b/executor/insert_common.go index a9bbded3b6664..3f79b9a687c5f 100644 --- a/executor/insert_common.go +++ b/executor/insert_common.go @@ -15,6 +15,7 @@ package executor import ( "context" + "math" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" @@ -57,6 +58,12 @@ type InsertValues struct { colDefaultVals []defaultVal evalBuffer chunk.MutRow evalBufferTypes []*types.FieldType + + // Fill the autoID lazily to datum. This is used for being compatible with JDBC using getGeneratedKeys(). + // `insert|replace values` can guarantee consecutive autoID in a batch. + // Other statements like `insert select from` don't guarantee consecutive autoID. + // https://dev.mysql.com/doc/refman/8.0/en/innodb-auto-increment-handling.html + lazyFillAutoID bool } type defaultVal struct { @@ -183,15 +190,22 @@ func (e *InsertValues) insertRows(ctx context.Context, exec func(ctx context.Con batchInsert := sessVars.BatchInsert && !sessVars.InTxn() batchSize := sessVars.DMLBatchSize + e.lazyFillAutoID = true + rows := make([][]types.Datum, 0, len(e.Lists)) for i, list := range e.Lists { e.rowCount++ - row, err := e.evalRow(list, i) + row, err := e.evalRow(ctx, list, i) if err != nil { return err } rows = append(rows, row) if batchInsert && e.rowCount%uint64(batchSize) == 0 { + // Before batch insert, fill the batch allocated autoIDs. + rows, err = e.lazyAdjustAutoIncrementDatum(ctx, rows) + if err != nil { + return err + } if err = exec(ctx, rows); err != nil { return err } @@ -201,6 +215,11 @@ func (e *InsertValues) insertRows(ctx context.Context, exec func(ctx context.Con } } } + // Fill the batch allocated autoIDs. + rows, err = e.lazyAdjustAutoIncrementDatum(ctx, rows) + if err != nil { + return err + } return exec(ctx, rows) } @@ -228,7 +247,7 @@ func (e *InsertValues) handleErr(col *table.Column, val *types.Datum, rowIdx int // evalRow evaluates a to-be-inserted row. The value of the column may base on another column, // so we use setValueForRefColumn to fill the empty row some default values when needFillDefaultValues is true. -func (e *InsertValues) evalRow(list []expression.Expression, rowIdx int) ([]types.Datum, error) { +func (e *InsertValues) evalRow(ctx context.Context, list []expression.Expression, rowIdx int) ([]types.Datum, error) { rowLen := len(e.Table.Cols()) if e.hasExtraHandle { rowLen++ @@ -258,8 +277,8 @@ func (e *InsertValues) evalRow(list []expression.Expression, rowIdx int) ([]type row[offset], hasValue[offset] = *val1.Copy(), true e.evalBuffer.SetDatum(offset, val1) } - - return e.fillRow(row, hasValue) + // Row may lack of generated column, autoIncrement column, empty column here. + return e.fillRow(ctx, row, hasValue) } // setValueForRefColumn set some default values for the row to eval the row value with other columns, @@ -295,8 +314,8 @@ func (e *InsertValues) setValueForRefColumn(row []types.Datum, hasValue []bool) func (e *InsertValues) insertRowsFromSelect(ctx context.Context, exec func(ctx context.Context, rows [][]types.Datum) error) error { // process `insert|replace into ... select ... from ...` selectExec := e.children[0] - fields := selectExec.retTypes() - chk := selectExec.newFirstChunk() + fields := retTypes(selectExec) + chk := newFirstChunk(selectExec) iter := chunk.NewIterator4Chunk(chk) rows := make([][]types.Datum, 0, chk.Capacity()) @@ -309,7 +328,7 @@ func (e *InsertValues) insertRowsFromSelect(ctx context.Context, exec func(ctx c batchSize := sessVars.DMLBatchSize for { - err := selectExec.Next(ctx, chunk.NewRecordBatch(chk)) + err := selectExec.Next(ctx, chk) if err != nil { return err } @@ -320,7 +339,7 @@ func (e *InsertValues) insertRowsFromSelect(ctx context.Context, exec func(ctx c for innerChunkRow := iter.Begin(); innerChunkRow != iter.End(); innerChunkRow = iter.Next() { innerRow := types.CloneRow(innerChunkRow.GetDatumRow(fields)) e.rowCount++ - row, err := e.getRow(innerRow) + row, err := e.getRow(ctx, innerRow) if err != nil { return err } @@ -361,7 +380,7 @@ func (e *InsertValues) doBatchInsert(ctx context.Context) error { // getRow gets the row which from `insert into select from` or `load data`. // The input values from these two statements are datums instead of // expressions which are used in `insert into set x=y`. -func (e *InsertValues) getRow(vals []types.Datum) ([]types.Datum, error) { +func (e *InsertValues) getRow(ctx context.Context, vals []types.Datum) ([]types.Datum, error) { row := make([]types.Datum, len(e.Table.Cols())) hasValue := make([]bool, len(e.Table.Cols())) for i, v := range vals { @@ -375,7 +394,7 @@ func (e *InsertValues) getRow(vals []types.Datum) ([]types.Datum, error) { hasValue[offset] = true } - return e.fillRow(row, hasValue) + return e.fillRow(ctx, row, hasValue) } func (e *InsertValues) filterErr(err error) error { @@ -409,10 +428,18 @@ func (e *InsertValues) getColDefaultValue(idx int, col *table.Column) (d types.D } // fillColValue fills the column value if it is not set in the insert statement. -func (e *InsertValues) fillColValue(datum types.Datum, idx int, column *table.Column, hasValue bool) (types.Datum, +func (e *InsertValues) fillColValue(ctx context.Context, datum types.Datum, idx int, column *table.Column, hasValue bool) (types.Datum, error) { if mysql.HasAutoIncrementFlag(column.Flag) { - d, err := e.adjustAutoIncrementDatum(datum, hasValue, column) + if e.lazyFillAutoID { + // Handle hasValue info in autoIncrement column previously for lazy handle. + if !hasValue { + datum.SetNull() + } + // Store the plain datum of autoIncrement column directly for lazy handle. + return datum, nil + } + d, err := e.adjustAutoIncrementDatum(ctx, datum, hasValue, column) if err != nil { return types.Datum{}, err } @@ -430,12 +457,16 @@ func (e *InsertValues) fillColValue(datum types.Datum, idx int, column *table.Co // fillRow fills generated columns, auto_increment column and empty column. // For NOT NULL column, it will return error or use zero value based on sql_mode. -func (e *InsertValues) fillRow(row []types.Datum, hasValue []bool) ([]types.Datum, error) { +// When lazyFillAutoID is true, fill row will lazily handle auto increment datum for lazy batch allocation. +// `insert|replace values` can guarantee consecutive autoID in a batch. +// Other statements like `insert select from` don't guarantee consecutive autoID. +// https://dev.mysql.com/doc/refman/8.0/en/innodb-auto-increment-handling.html +func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue []bool) ([]types.Datum, error) { gIdx := 0 for i, c := range e.Table.Cols() { var err error // Get the default value for all no value columns, the auto increment column is different from the others. - row[i], err = e.fillColValue(row[i], i, c, hasValue[i]) + row[i], err = e.fillColValue(ctx, row[i], i, c, hasValue[i]) if err != nil { return nil, err } @@ -453,16 +484,175 @@ func (e *InsertValues) fillRow(row []types.Datum, hasValue []bool) ([]types.Datu return nil, err } } + // Handle the bad null error. Cause generated column with `not null` flag will get default value datum in fillColValue + // which should be override by generated expr first, then handle the bad null logic here. + if !e.lazyFillAutoID || (e.lazyFillAutoID && !mysql.HasAutoIncrementFlag(c.Flag)) { + if row[i], err = c.HandleBadNull(row[i], e.ctx.GetSessionVars().StmtCtx); err != nil { + return nil, err + } + } + } + return row, nil +} + +// isAutoNull can help judge whether a datum is AutoIncrement Null quickly. +// This used to help lazyFillAutoIncrement to find consecutive N datum backwards for batch autoID alloc. +func (e *InsertValues) isAutoNull(ctx context.Context, d types.Datum, col *table.Column) bool { + var err error + var recordID int64 + if !d.IsNull() { + recordID, err = getAutoRecordID(d, &col.FieldType, true) + if err != nil { + return false + } + } + // Use the value if it's not null and not 0. + if recordID != 0 { + return false + } + // Change NULL to auto id. + // Change value 0 to auto id, if NoAutoValueOnZero SQL mode is not set. + if d.IsNull() || e.ctx.GetSessionVars().SQLMode&mysql.ModeNoAutoValueOnZero == 0 { + return true + } + return false +} + +func (e *InsertValues) hasAutoIncrementColumn() (int, bool) { + colIdx := -1 + for i, c := range e.Table.Cols() { + if mysql.HasAutoIncrementFlag(c.Flag) { + colIdx = i + break + } + } + return colIdx, colIdx != -1 +} + +func (e *InsertValues) lazyAdjustAutoIncrementDatumInRetry(ctx context.Context, rows [][]types.Datum, colIdx int) ([][]types.Datum, error) { + // Get the autoIncrement column. + col := e.Table.Cols()[colIdx] + // Consider the colIdx of autoIncrement in row are the same. + length := len(rows) + for i := 0; i < length; i++ { + autoDatum := rows[i][colIdx] + + // autoID can be found in RetryInfo. + retryInfo := e.ctx.GetSessionVars().RetryInfo + if retryInfo.Retrying { + id, err := retryInfo.GetCurrAutoIncrementID() + if err != nil { + return nil, err + } + autoDatum.SetAutoID(id, col.Flag) + + if autoDatum, err = col.HandleBadNull(autoDatum, e.ctx.GetSessionVars().StmtCtx); err != nil { + return nil, err + } + rows[i][colIdx] = autoDatum + } + } + return rows, nil +} - // Handle the bad null error. - if row[i], err = c.HandleBadNull(row[i], e.ctx.GetSessionVars().StmtCtx); err != nil { +// lazyAdjustAutoIncrementDatum is quite similar to adjustAutoIncrementDatum +// except it will cache auto increment datum previously for lazy batch allocation of autoID. +func (e *InsertValues) lazyAdjustAutoIncrementDatum(ctx context.Context, rows [][]types.Datum) ([][]types.Datum, error) { + // Not in lazyFillAutoID mode means no need to fill. + if !e.lazyFillAutoID { + return rows, nil + } + // No autoIncrement column means no need to fill. + colIdx, ok := e.hasAutoIncrementColumn() + if !ok { + return rows, nil + } + // autoID can be found in RetryInfo. + retryInfo := e.ctx.GetSessionVars().RetryInfo + if retryInfo.Retrying { + return e.lazyAdjustAutoIncrementDatumInRetry(ctx, rows, colIdx) + } + // Get the autoIncrement column. + col := e.Table.Cols()[colIdx] + // Consider the colIdx of autoIncrement in row are the same. + length := len(rows) + for i := 0; i < length; i++ { + autoDatum := rows[i][colIdx] + + var err error + var recordID int64 + if !autoDatum.IsNull() { + recordID, err = getAutoRecordID(autoDatum, &col.FieldType, true) + if err != nil { + return nil, err + } + } + // Use the value if it's not null and not 0. + if recordID != 0 { + err = e.Table.RebaseAutoID(e.ctx, recordID, true) + if err != nil { + return nil, err + } + e.ctx.GetSessionVars().StmtCtx.InsertID = uint64(recordID) + retryInfo.AddAutoIncrementID(recordID) + rows[i][colIdx] = autoDatum + continue + } + + // Change NULL to auto id. + // Change value 0 to auto id, if NoAutoValueOnZero SQL mode is not set. + if autoDatum.IsNull() || e.ctx.GetSessionVars().SQLMode&mysql.ModeNoAutoValueOnZero == 0 { + // Find consecutive num. + start := i + cnt := 1 + for i+1 < length && e.isAutoNull(ctx, rows[i+1][colIdx], col) { + i++ + cnt++ + } + // Alloc batch N consecutive (min, max] autoIDs. + // max value can be derived from adding one for cnt times. + min, _, err := table.AllocBatchAutoIncrementValue(ctx, e.Table, e.ctx, cnt) + if e.filterErr(err) != nil { + return nil, err + } + // It's compatible with mysql setting the first allocated autoID to lastInsertID. + // Cause autoID may be specified by user, judge only the first row is not suitable. + if e.lastInsertID == 0 { + e.lastInsertID = uint64(min) + 1 + } + // Assign autoIDs to rows. + for j := 0; j < cnt; j++ { + offset := j + start + d := rows[offset][colIdx] + + id := int64(uint64(min) + uint64(j) + 1) + d.SetAutoID(id, col.Flag) + retryInfo.AddAutoIncrementID(id) + + // The value of d is adjusted by auto ID, so we need to cast it again. + d, err := table.CastValue(e.ctx, d, col.ToInfo()) + if err != nil { + return nil, err + } + rows[offset][colIdx] = d + } + continue + } + + autoDatum.SetAutoID(recordID, col.Flag) + retryInfo.AddAutoIncrementID(recordID) + + // the value of d is adjusted by auto ID, so we need to cast it again. + autoDatum, err = table.CastValue(e.ctx, autoDatum, col.ToInfo()) + if err != nil { return nil, err } + rows[i][colIdx] = autoDatum } - return row, nil + return rows, nil } -func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c *table.Column) (types.Datum, error) { +func (e *InsertValues) adjustAutoIncrementDatum(ctx context.Context, d types.Datum, hasValue bool, c *table.Column) (types.Datum, error) { retryInfo := e.ctx.GetSessionVars().RetryInfo if retryInfo.Retrying { id, err := retryInfo.GetCurrAutoIncrementID() @@ -479,12 +669,10 @@ func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c d.SetNull() } if !d.IsNull() { - sc := e.ctx.GetSessionVars().StmtCtx - datum, err1 := d.ConvertTo(sc, &c.FieldType) - if e.filterErr(err1) != nil { - return types.Datum{}, err1 + recordID, err = getAutoRecordID(d, &c.FieldType, true) + if err != nil { + return types.Datum{}, err } - recordID = datum.GetInt64() } // Use the value if it's not null and not 0. if recordID != 0 { @@ -494,19 +682,19 @@ func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c } e.ctx.GetSessionVars().StmtCtx.InsertID = uint64(recordID) retryInfo.AddAutoIncrementID(recordID) - d.SetAutoID(recordID, c.Flag) return d, nil } // Change NULL to auto id. // Change value 0 to auto id, if NoAutoValueOnZero SQL mode is not set. if d.IsNull() || e.ctx.GetSessionVars().SQLMode&mysql.ModeNoAutoValueOnZero == 0 { - recordID, err = e.Table.AllocAutoID(e.ctx) + recordID, err = table.AllocAutoIncrementValue(ctx, e.Table, e.ctx) if e.filterErr(err) != nil { return types.Datum{}, err } - // It's compatible with mysql. So it sets last insert id to the first row. - if e.rowCount == 1 { + // It's compatible with mysql setting the first allocated autoID to lastInsertID. + // Cause autoID may be specified by user, judge only the first row is not suitable. + if e.lastInsertID == 0 { e.lastInsertID = uint64(recordID) } } @@ -522,6 +710,26 @@ func (e *InsertValues) adjustAutoIncrementDatum(d types.Datum, hasValue bool, c return casted, nil } +func getAutoRecordID(d types.Datum, target *types.FieldType, isInsert bool) (int64, error) { + var recordID int64 + + switch target.Tp { + case mysql.TypeFloat, mysql.TypeDouble: + f := d.GetFloat64() + if isInsert { + recordID = int64(math.Round(f)) + } else { + recordID = int64(f) + } + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + recordID = d.GetInt64() + default: + return 0, errors.Errorf("unexpected field type [%v]", target.Tp) + } + + return recordID, nil +} + func (e *InsertValues) handleWarning(err error) { sc := e.ctx.GetSessionVars().StmtCtx sc.AppendWarning(err) @@ -532,26 +740,46 @@ func (e *InsertValues) handleWarning(err error) { func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, addRecord func(row []types.Datum) (int64, error)) error { // all the rows will be checked, so it is safe to set BatchCheck = true e.ctx.GetSessionVars().StmtCtx.BatchCheck = true - err := e.batchGetInsertKeys(e.ctx, e.Table, rows) + + // Get keys need to be checked. + toBeCheckedRows, err := e.getKeysNeedCheck(e.ctx, e.Table, rows) if err != nil { return err } - // append warnings and get no duplicated error rows - for i, r := range e.toBeCheckedRows { + + txn, err := e.ctx.Txn(true) + if err != nil { + return err + } + + // Fill cache using BatchGet, the following Get requests don't need to visit TiKV. + if _, err = prefetchUniqueIndices(txn, toBeCheckedRows); err != nil { + return err + } + + for i, r := range toBeCheckedRows { + // skip := false if r.handleKey != nil { - if _, found := e.dupKVs[string(r.handleKey.newKV.key)]; found { - rows[i] = nil + _, err := txn.Get(r.handleKey.newKV.key) + if err == nil { e.ctx.GetSessionVars().StmtCtx.AppendWarning(r.handleKey.dupErr) continue } + if !kv.IsErrNotFound(err) { + return err + } } for _, uk := range r.uniqueKeys { - if _, found := e.dupKVs[string(uk.newKV.key)]; found { + _, err := txn.Get(uk.newKV.key) + if err == nil { // If duplicate keys were found in BatchGet, mark row = nil. rows[i] = nil e.ctx.GetSessionVars().StmtCtx.AppendWarning(uk.dupErr) break } + if !kv.IsErrNotFound(err) { + return err + } } // If row was checked with no duplicate keys, // it should be add to values map for the further row check. @@ -562,12 +790,6 @@ func (e *InsertValues) batchCheckAndInsert(rows [][]types.Datum, addRecord func( if err != nil { return err } - if r.handleKey != nil { - e.dupKVs[string(r.handleKey.newKV.key)] = r.handleKey.newKV.value - } - for _, uk := range r.uniqueKeys { - e.dupKVs[string(uk.newKV.key)] = []byte{} - } } } return nil diff --git a/executor/insert_test.go b/executor/insert_test.go index 8c3510b044126..1a8e9c7de99eb 100644 --- a/executor/insert_test.go +++ b/executor/insert_test.go @@ -15,10 +15,13 @@ package executor_test import ( "fmt" + "strings" . "github.com/pingcap/check" "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/testkit" ) @@ -286,3 +289,476 @@ func (s *testSuite3) TestAllowInvalidDates(c *C) { runWithMode("STRICT_TRANS_TABLES,ALLOW_INVALID_DATES") runWithMode("ALLOW_INVALID_DATES") } + +func (s *testSuite3) TestInsertWithAutoidSchema(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1(id int primary key auto_increment, n int);`) + tk.MustExec(`create table t2(id int unsigned primary key auto_increment, n int);`) + tk.MustExec(`create table t3(id tinyint primary key auto_increment, n int);`) + tk.MustExec(`create table t4(id int primary key, n float auto_increment, key I_n(n));`) + tk.MustExec(`create table t5(id int primary key, n float unsigned auto_increment, key I_n(n));`) + tk.MustExec(`create table t6(id int primary key, n double auto_increment, key I_n(n));`) + tk.MustExec(`create table t7(id int primary key, n double unsigned auto_increment, key I_n(n));`) + // test for inserting multiple values + tk.MustExec(`create table t8(id int primary key auto_increment, n int);`) + + tests := []struct { + insert string + query string + result [][]interface{} + }{ + { + `insert into t1(id, n) values(1, 1)`, + `select * from t1 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t1(n) values(2)`, + `select * from t1 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t1(n) values(3)`, + `select * from t1 where id = 3`, + testkit.Rows(`3 3`), + }, + { + `insert into t1(id, n) values(-1, 4)`, + `select * from t1 where id = -1`, + testkit.Rows(`-1 4`), + }, + { + `insert into t1(n) values(5)`, + `select * from t1 where id = 4`, + testkit.Rows(`4 5`), + }, + { + `insert into t1(id, n) values('5', 6)`, + `select * from t1 where id = 5`, + testkit.Rows(`5 6`), + }, + { + `insert into t1(n) values(7)`, + `select * from t1 where id = 6`, + testkit.Rows(`6 7`), + }, + { + `insert into t1(id, n) values(7.4, 8)`, + `select * from t1 where id = 7`, + testkit.Rows(`7 8`), + }, + { + `insert into t1(id, n) values(7.5, 9)`, + `select * from t1 where id = 8`, + testkit.Rows(`8 9`), + }, + { + `insert into t1(n) values(9)`, + `select * from t1 where id = 9`, + testkit.Rows(`9 9`), + }, + // test last insert id + { + `insert into t1 values(3000, -1), (null, -2)`, + `select * from t1 where id = 3000`, + testkit.Rows(`3000 -1`), + }, + { + `;`, + `select * from t1 where id = 3001`, + testkit.Rows(`3001 -2`), + }, + { + `;`, + `select last_insert_id()`, + testkit.Rows(`3001`), + }, + { + `insert into t2(id, n) values(1, 1)`, + `select * from t2 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t2(n) values(2)`, + `select * from t2 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t2(n) values(3)`, + `select * from t2 where id = 3`, + testkit.Rows(`3 3`), + }, + { + `insert into t3(id, n) values(1, 1)`, + `select * from t3 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t3(n) values(2)`, + `select * from t3 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t3(n) values(3)`, + `select * from t3 where id = 3`, + testkit.Rows(`3 3`), + }, + { + `insert into t3(id, n) values(-1, 4)`, + `select * from t3 where id = -1`, + testkit.Rows(`-1 4`), + }, + { + `insert into t3(n) values(5)`, + `select * from t3 where id = 4`, + testkit.Rows(`4 5`), + }, + { + `insert into t4(id, n) values(1, 1)`, + `select * from t4 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t4(id) values(2)`, + `select * from t4 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t4(id, n) values(3, -1)`, + `select * from t4 where id = 3`, + testkit.Rows(`3 -1`), + }, + { + `insert into t4(id) values(4)`, + `select * from t4 where id = 4`, + testkit.Rows(`4 3`), + }, + { + `insert into t4(id, n) values(5, 5.5)`, + `select * from t4 where id = 5`, + testkit.Rows(`5 5.5`), + }, + { + `insert into t4(id) values(6)`, + `select * from t4 where id = 6`, + testkit.Rows(`6 7`), + }, + { + `insert into t4(id, n) values(7, '7.7')`, + `select * from t4 where id = 7`, + testkit.Rows(`7 7.7`), + }, + { + `insert into t4(id) values(8)`, + `select * from t4 where id = 8`, + testkit.Rows(`8 9`), + }, + { + `insert into t4(id, n) values(9, 10.4)`, + `select * from t4 where id = 9`, + testkit.Rows(`9 10.4`), + }, + { + `insert into t4(id) values(10)`, + `select * from t4 where id = 10`, + testkit.Rows(`10 11`), + }, + { + `insert into t5(id, n) values(1, 1)`, + `select * from t5 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t5(id) values(2)`, + `select * from t5 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t5(id) values(3)`, + `select * from t5 where id = 3`, + testkit.Rows(`3 3`), + }, + { + `insert into t6(id, n) values(1, 1)`, + `select * from t6 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t6(id) values(2)`, + `select * from t6 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t6(id, n) values(3, -1)`, + `select * from t6 where id = 3`, + testkit.Rows(`3 -1`), + }, + { + `insert into t6(id) values(4)`, + `select * from t6 where id = 4`, + testkit.Rows(`4 3`), + }, + { + `insert into t6(id, n) values(5, 5.5)`, + `select * from t6 where id = 5`, + testkit.Rows(`5 5.5`), + }, + { + `insert into t6(id) values(6)`, + `select * from t6 where id = 6`, + testkit.Rows(`6 7`), + }, + { + `insert into t6(id, n) values(7, '7.7')`, + `select * from t4 where id = 7`, + testkit.Rows(`7 7.7`), + }, + { + `insert into t6(id) values(8)`, + `select * from t4 where id = 8`, + testkit.Rows(`8 9`), + }, + { + `insert into t6(id, n) values(9, 10.4)`, + `select * from t6 where id = 9`, + testkit.Rows(`9 10.4`), + }, + { + `insert into t6(id) values(10)`, + `select * from t6 where id = 10`, + testkit.Rows(`10 11`), + }, + { + `insert into t7(id, n) values(1, 1)`, + `select * from t7 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `insert into t7(id) values(2)`, + `select * from t7 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `insert into t7(id) values(3)`, + `select * from t7 where id = 3`, + testkit.Rows(`3 3`), + }, + + // the following is test for insert multiple values. + { + `insert into t8(n) values(1),(2)`, + `select * from t8 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `;`, + `select * from t8 where id = 2`, + testkit.Rows(`2 2`), + }, + { + `;`, + `select last_insert_id();`, + testkit.Rows(`1`), + }, + // test user rebase and auto alloc mixture. + { + `insert into t8 values(null, 3),(-1, -1),(null,4),(null, 5)`, + `select * from t8 where id = 3`, + testkit.Rows(`3 3`), + }, + // -1 won't rebase allocator here cause -1 < base. + { + `;`, + `select * from t8 where id = -1`, + testkit.Rows(`-1 -1`), + }, + { + `;`, + `select * from t8 where id = 4`, + testkit.Rows(`4 4`), + }, + { + `;`, + `select * from t8 where id = 5`, + testkit.Rows(`5 5`), + }, + { + `;`, + `select last_insert_id();`, + testkit.Rows(`3`), + }, + { + `insert into t8 values(null, 6),(10, 7),(null, 8)`, + `select * from t8 where id = 6`, + testkit.Rows(`6 6`), + }, + // 10 will rebase allocator here. + { + `;`, + `select * from t8 where id = 10`, + testkit.Rows(`10 7`), + }, + { + `;`, + `select * from t8 where id = 11`, + testkit.Rows(`11 8`), + }, + { + `;`, + `select last_insert_id()`, + testkit.Rows(`6`), + }, + // fix bug for last_insert_id should be first allocated id in insert rows (skip the rebase id). + { + `insert into t8 values(100, 9),(null,10),(null,11)`, + `select * from t8 where id = 100`, + testkit.Rows(`100 9`), + }, + { + `;`, + `select * from t8 where id = 101`, + testkit.Rows(`101 10`), + }, + { + `;`, + `select * from t8 where id = 102`, + testkit.Rows(`102 11`), + }, + { + `;`, + `select last_insert_id()`, + testkit.Rows(`101`), + }, + // test with sql_mode: NO_AUTO_VALUE_ON_ZERO. + { + `;`, + `select @@sql_mode`, + testkit.Rows(`ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION`), + }, + { + `;`, + "set session sql_mode = `ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION,NO_AUTO_VALUE_ON_ZERO`", + nil, + }, + { + `insert into t8 values (0, 12), (null, 13)`, + `select * from t8 where id = 0`, + testkit.Rows(`0 12`), + }, + { + `;`, + `select * from t8 where id = 103`, + testkit.Rows(`103 13`), + }, + { + `;`, + `select last_insert_id()`, + testkit.Rows(`103`), + }, + // test without sql_mode: NO_AUTO_VALUE_ON_ZERO. + { + `;`, + "set session sql_mode = `ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION`", + nil, + }, + // value 0 will be substitute by autoid. + { + `insert into t8 values (0, 14), (null, 15)`, + `select * from t8 where id = 104`, + testkit.Rows(`104 14`), + }, + { + `;`, + `select * from t8 where id = 105`, + testkit.Rows(`105 15`), + }, + { + `;`, + `select last_insert_id()`, + testkit.Rows(`104`), + }, + // last test : auto increment allocation can find in retryInfo. + { + `retry : insert into t8 values (null, 16), (null, 17)`, + `select * from t8 where id = 1000`, + testkit.Rows(`1000 16`), + }, + { + `;`, + `select * from t8 where id = 1001`, + testkit.Rows(`1001 17`), + }, + { + `;`, + `select last_insert_id()`, + // this insert doesn't has the last_insert_id, should be same as the last insert case. + testkit.Rows(`104`), + }, + } + + for _, tt := range tests { + if strings.HasPrefix(tt.insert, "retry : ") { + // it's the last retry insert case, change the sessionVars. + retryInfo := &variable.RetryInfo{Retrying: true} + retryInfo.AddAutoIncrementID(1000) + retryInfo.AddAutoIncrementID(1001) + tk.Se.GetSessionVars().RetryInfo = retryInfo + tk.MustExec(tt.insert[8:]) + tk.Se.GetSessionVars().RetryInfo = &variable.RetryInfo{} + } else { + tk.MustExec(tt.insert) + } + if tt.query == "set session sql_mode = `ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION,NO_AUTO_VALUE_ON_ZERO`" || + tt.query == "set session sql_mode = `ONLY_FULL_GROUP_BY,STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION`" { + tk.MustExec(tt.query) + } else { + tk.MustQuery(tt.query).Check(tt.result) + } + } + +} + +func (s *testSuite3) TestPartitionInsertOnDuplicate(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a int,b int,primary key(a,b)) partition by range(a) (partition p0 values less than (100),partition p1 values less than (1000))`) + tk.MustExec(`insert into t1 set a=1, b=1`) + tk.MustExec(`insert into t1 set a=1,b=1 on duplicate key update a=1,b=1`) + tk.MustQuery(`select * from t1`).Check(testkit.Rows("1 1")) + + tk.MustExec(`create table t2 (a int,b int,primary key(a,b)) partition by hash(a) partitions 4`) + tk.MustExec(`insert into t2 set a=1,b=1;`) + tk.MustExec(`insert into t2 set a=1,b=1 on duplicate key update a=1,b=1`) + tk.MustQuery(`select * from t2`).Check(testkit.Rows("1 1")) + + tk.MustExec(`CREATE TABLE t3 (a int, b int, c int, d int, e int, + PRIMARY KEY (a,b), + UNIQUE KEY (b,c,d) +) PARTITION BY RANGE ( b ) ( + PARTITION p0 VALUES LESS THAN (4), + PARTITION p1 VALUES LESS THAN (7), + PARTITION p2 VALUES LESS THAN (11) +)`) + tk.MustExec("insert into t3 values (1,2,3,4,5)") + tk.MustExec("insert into t3 values (1,2,3,4,5),(6,2,3,4,6) on duplicate key update e = e + values(e)") + tk.MustQuery("select * from t3").Check(testkit.Rows("1 2 3 4 16")) +} + +func (s *testSuite3) TestBit(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1 (a bit(3))`) + _, err := tk.Exec("insert into t1 values(-1)") + c.Assert(types.ErrDataTooLong.Equal(err), IsTrue) + c.Assert(err.Error(), Matches, ".*Data too long for column 'a' at.*") + _, err = tk.Exec("insert into t1 values(9)") + c.Assert(err.Error(), Matches, ".*Data too long for column 'a' at.*") + + tk.MustExec(`create table t64 (a bit(64))`) + tk.MustExec("insert into t64 values(-1)") + tk.MustExec("insert into t64 values(18446744073709551615)") // 2^64 - 1 + _, err = tk.Exec("insert into t64 values(18446744073709551616)") // z^64 + c.Assert(err.Error(), Matches, ".*Out of range value for column 'a' at.*") + +} diff --git a/executor/join.go b/executor/join.go index 3352a3307161a..83c2619932c1b 100644 --- a/executor/join.go +++ b/executor/join.go @@ -52,7 +52,7 @@ type HashJoinExec struct { prepared bool // concurrency is the number of partition, build and join workers. concurrency uint - globalHashTable *mvmap.MVMap + hashTable *mvmap.MVMap innerFinished chan error hashJoinBuffers []*hashJoinBuffer // joinWorkerWaitGroup is for sync multiple join workers. @@ -135,7 +135,6 @@ func (e *HashJoinExec) Close() error { e.outerChkResourceCh = nil e.joinChkResourceCh = nil } - e.memTracker.Detach() e.memTracker = nil err := e.baseExecutor.Close() @@ -162,11 +161,6 @@ func (e *HashJoinExec) Open(ctx context.Context) error { e.hashJoinBuffers = append(e.hashJoinBuffers, buffer) } - e.innerKeyColIdx = make([]int, len(e.innerKeys)) - for i := range e.innerKeys { - e.innerKeyColIdx[i] = e.innerKeys[i].Index - } - e.closeCh = make(chan struct{}) e.finished.Store(false) e.joinWorkerWaitGroup = sync.WaitGroup{} @@ -178,10 +172,10 @@ func (e *HashJoinExec) getJoinKeyFromChkRow(isOuterKey bool, row chunk.Row, keyB var allTypes []*types.FieldType if isOuterKey { keyColIdx = e.outerKeyColIdx - allTypes = e.outerExec.retTypes() + allTypes = retTypes(e.outerExec) } else { keyColIdx = e.innerKeyColIdx - allTypes = e.innerExec.retTypes() + allTypes = retTypes(e.innerExec) } for _, i := range keyColIdx { @@ -202,6 +196,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) { if e.finished.Load().(bool) { return } + var outerResource *outerChkResource var ok bool select { @@ -217,7 +212,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) { required := int(atomic.LoadInt64(&e.requiredRows)) outerResult.SetRequiredRows(required, e.maxChunkSize) } - err := e.outerExec.Next(ctx, chunk.NewRecordBatch(outerResult)) + err := Next(ctx, e.outerExec, outerResult) if err != nil { e.joinResultCh <- &hashjoinWorkerResult{ err: err, @@ -244,6 +239,7 @@ func (e *HashJoinExec) fetchOuterChunks(ctx context.Context) { if outerResult.NumRows() == 0 { return } + outerResource.dest <- outerResult } } @@ -257,7 +253,7 @@ func (e *HashJoinExec) wait4Inner() (finished bool, err error) { return false, err } } - if e.innerResult.Len() == 0 && (e.joinType == plannercore.InnerJoin || e.joinType == plannercore.SemiJoin) { + if e.hashTable.Len() == 0 && (e.joinType == plannercore.InnerJoin || e.joinType == plannercore.SemiJoin) { return true, nil } return false, nil @@ -265,23 +261,41 @@ func (e *HashJoinExec) wait4Inner() (finished bool, err error) { var innerResultLabel fmt.Stringer = stringutil.StringerStr("innerResult") -// fetchInnerRows fetches all rows from inner executor, and append them to -// e.innerResult. -func (e *HashJoinExec) fetchInnerRows(ctx context.Context) error { - e.innerResult = chunk.NewList(e.innerExec.retTypes(), e.initCap, e.maxChunkSize) +// fetchInnerRows fetches all rows from inner executor, +// and append them to e.innerResult. +func (e *HashJoinExec) fetchInnerRows(ctx context.Context, chkCh chan<- *chunk.Chunk, doneCh chan struct{}) { + defer close(chkCh) + e.innerResult = chunk.NewList(e.innerExec.base().retFieldTypes, e.initCap, e.maxChunkSize) e.innerResult.GetMemTracker().AttachTo(e.memTracker) e.innerResult.GetMemTracker().SetLabel(innerResultLabel) var err error for { - if e.finished.Load().(bool) { - return nil - } - chk := e.children[e.innerIdx].newFirstChunk() - err = e.innerExec.Next(ctx, chunk.NewRecordBatch(chk)) - if err != nil || chk.NumRows() == 0 { - return err + select { + case <-doneCh: + return + case <-e.closeCh: + return + default: + if e.finished.Load().(bool) { + return + } + chk := chunk.NewChunkWithCapacity(e.children[e.innerIdx].base().retFieldTypes, e.ctx.GetSessionVars().MaxChunkSize) + err = e.innerExec.Next(ctx, chk) + if err != nil { + e.innerFinished <- errors.Trace(err) + return + } + if chk.NumRows() == 0 { + return + } + select { + case chkCh <- chk: + break + case <-e.closeCh: + return + } + e.innerResult.Add(chk) } - e.innerResult.Add(chk) } } @@ -299,7 +313,7 @@ func (e *HashJoinExec) initializeForProbe() { e.outerChkResourceCh = make(chan *outerChkResource, e.concurrency) for i := uint(0); i < e.concurrency; i++ { e.outerChkResourceCh <- &outerChkResource{ - chk: e.outerExec.newFirstChunk(), + chk: newFirstChunk(e.outerExec), dest: e.outerResultChs[i], } } @@ -309,7 +323,7 @@ func (e *HashJoinExec) initializeForProbe() { e.joinChkResourceCh = make([]chan *chunk.Chunk, e.concurrency) for i := uint(0); i < e.concurrency; i++ { e.joinChkResourceCh[i] = make(chan *chunk.Chunk, 1) - e.joinChkResourceCh[i] <- e.newFirstChunk() + e.joinChkResourceCh[i] <- newFirstChunk(e) } // e.joinResultCh is for transmitting the join result chunks to the main @@ -412,7 +426,7 @@ func (e *HashJoinExec) joinMatchedOuterRow2Chunk(workerID uint, outerRow chunk.R e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) return true, joinResult } - e.hashTableValBufs[workerID] = e.globalHashTable.Get(joinKey, e.hashTableValBufs[workerID][:0]) + e.hashTableValBufs[workerID] = e.hashTable.Get(joinKey, e.hashTableValBufs[workerID][:0]) innerPtrs := e.hashTableValBufs[workerID] if len(innerPtrs) == 0 { e.joiners[workerID].onMissMatch(false, outerRow, joinResult.chk) @@ -494,7 +508,7 @@ func (e *HashJoinExec) join2Chunk(workerID uint, outerChk *chunk.Chunk, joinResu // hash join constructs the result following these steps: // step 1. fetch data from inner child and build a hash table; // step 2. fetch data from outer child in a background goroutine and probe the hash table in multiple join workers. -func (e *HashJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) (err error) { +func (e *HashJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { if e.runtimeStats != nil { start := time.Now() defer func() { e.runtimeStats.Record(time.Since(start), req.NumRows()) }() @@ -512,6 +526,7 @@ func (e *HashJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) (err er if e.joinResultCh == nil { return nil } + result, ok := <-e.joinResultCh if !ok { return nil @@ -533,21 +548,31 @@ func (e *HashJoinExec) handleFetchInnerAndBuildHashTablePanic(r interface{}) { } func (e *HashJoinExec) fetchInnerAndBuildHashTable(ctx context.Context) { - if err := e.fetchInnerRows(ctx); err != nil { - e.innerFinished <- err - return - } + // innerResultCh transfers inner chunk from inner fetch to build hash table. + innerResultCh := make(chan *chunk.Chunk, 1) + doneCh := make(chan struct{}) + go util.WithRecovery(func() { e.fetchInnerRows(ctx, innerResultCh, doneCh) }, nil) - if err := e.buildGlobalHashTable(); err != nil { - e.innerFinished <- err + // TODO: Parallel build hash table. Currently not support because `mvmap` is not thread-safe. + err := e.buildHashTableForList(innerResultCh) + if err != nil { + e.innerFinished <- errors.Trace(err) + close(doneCh) + } + // wait fetchInnerRows be finished. + for range innerResultCh { } } -// buildGlobalHashTable builds a global hash table for the inner relation. +// buildHashTableForList builds hash table from `list`. // key of hash table: hash value of key columns // value of hash table: RowPtr of the corresponded row -func (e *HashJoinExec) buildGlobalHashTable() error { - e.globalHashTable = mvmap.NewMVMap() +func (e *HashJoinExec) buildHashTableForList(innerResultCh chan *chunk.Chunk) error { + e.hashTable = mvmap.NewMVMap() + e.innerKeyColIdx = make([]int, len(e.innerKeys)) + for i := range e.innerKeys { + e.innerKeyColIdx[i] = e.innerKeys[i].Index + } var ( hasNull bool err error @@ -555,23 +580,25 @@ func (e *HashJoinExec) buildGlobalHashTable() error { valBuf = make([]byte, 8) ) - for chkIdx := 0; chkIdx < e.innerResult.NumChunks(); chkIdx++ { + chkIdx := uint32(0) + for chk := range innerResultCh { if e.finished.Load().(bool) { return nil } - chk := e.innerResult.GetChunk(chkIdx) - for j, numRows := 0, chk.NumRows(); j < numRows; j++ { + numRows := chk.NumRows() + for j := 0; j < numRows; j++ { hasNull, keyBuf, err = e.getJoinKeyFromChkRow(false, chk.GetRow(j), keyBuf) if err != nil { - return err + return errors.Trace(err) } if hasNull { continue } - rowPtr := chunk.RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(j)} + rowPtr := chunk.RowPtr{ChkIdx: chkIdx, RowIdx: uint32(j)} *(*chunk.RowPtr)(unsafe.Pointer(&valBuf[0])) = rowPtr - e.globalHashTable.Put(keyBuf, valBuf) + e.hashTable.Put(keyBuf, valBuf) } + chkIdx++ } return nil } @@ -610,7 +637,6 @@ type NestedLoopApplyExec struct { func (e *NestedLoopApplyExec) Close() error { e.innerRows = nil - e.memTracker.Detach() e.memTracker = nil return e.outerExec.Close() } @@ -625,9 +651,9 @@ func (e *NestedLoopApplyExec) Open(ctx context.Context) error { } e.cursor = 0 e.innerRows = e.innerRows[:0] - e.outerChunk = e.outerExec.newFirstChunk() - e.innerChunk = e.innerExec.newFirstChunk() - e.innerList = chunk.NewList(e.innerExec.retTypes(), e.initCap, e.maxChunkSize) + e.outerChunk = newFirstChunk(e.outerExec) + e.innerChunk = newFirstChunk(e.innerExec) + e.innerList = chunk.NewList(retTypes(e.innerExec), e.initCap, e.maxChunkSize) e.memTracker = memory.NewTracker(e.id, e.ctx.GetSessionVars().MemQuotaNestedLoopApply) e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) @@ -642,7 +668,7 @@ func (e *NestedLoopApplyExec) fetchSelectedOuterRow(ctx context.Context, chk *ch outerIter := chunk.NewIterator4Chunk(e.outerChunk) for { if e.outerChunkCursor >= e.outerChunk.NumRows() { - err := e.outerExec.Next(ctx, chunk.NewRecordBatch(e.outerChunk)) + err := Next(ctx, e.outerExec, e.outerChunk) if err != nil { return nil, err } @@ -679,7 +705,7 @@ func (e *NestedLoopApplyExec) fetchAllInners(ctx context.Context) error { e.innerList.Reset() innerIter := chunk.NewIterator4Chunk(e.innerChunk) for { - err := e.innerExec.Next(ctx, chunk.NewRecordBatch(e.innerChunk)) + err := Next(ctx, e.innerExec, e.innerChunk) if err != nil { return err } @@ -700,7 +726,7 @@ func (e *NestedLoopApplyExec) fetchAllInners(ctx context.Context) error { } // Next implements the Executor interface. -func (e *NestedLoopApplyExec) Next(ctx context.Context, req *chunk.RecordBatch) (err error) { +func (e *NestedLoopApplyExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { if e.runtimeStats != nil { start := time.Now() defer func() { e.runtimeStats.Record(time.Since(start), req.NumRows()) }() @@ -709,9 +735,9 @@ func (e *NestedLoopApplyExec) Next(ctx context.Context, req *chunk.RecordBatch) for { if e.innerIter == nil || e.innerIter.Current() == e.innerIter.End() { if e.outerRow != nil && !e.hasMatch { - e.joiner.onMissMatch(e.hasNull, *e.outerRow, req.Chunk) + e.joiner.onMissMatch(e.hasNull, *e.outerRow, req) } - e.outerRow, err = e.fetchSelectedOuterRow(ctx, req.Chunk) + e.outerRow, err = e.fetchSelectedOuterRow(ctx, req) if e.outerRow == nil || err != nil { return err } @@ -729,7 +755,7 @@ func (e *NestedLoopApplyExec) Next(ctx context.Context, req *chunk.RecordBatch) e.innerIter.Begin() } - matched, isNull, err := e.joiner.tryToMatch(*e.outerRow, e.innerIter, req.Chunk) + matched, isNull, err := e.joiner.tryToMatch(*e.outerRow, e.innerIter, req) e.hasMatch = e.hasMatch || matched e.hasNull = e.hasNull || isNull diff --git a/executor/join_test.go b/executor/join_test.go index 255f1befb90da..6c2211538be29 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -760,7 +760,7 @@ func (s *testSuite2) TestJoinLeak(c *C) { tk.MustExec("commit") result, err := tk.Exec("select * from t t1 left join (select 1) t2 on 1") c.Assert(err, IsNil) - req := result.NewRecordBatch() + req := result.NewChunk() err = result.Next(context.Background(), req) c.Assert(err, IsNil) time.Sleep(time.Millisecond) @@ -832,8 +832,8 @@ func (s *testSuite2) TestIndexLookupJoin(c *C) { tk.MustExec("CREATE INDEX idx_s_a ON s(`a`)") tk.MustExec("INSERT INTO s VALUES (-277544960, 'fpnndsjo') , (2, 'kfpnndsjof') , (2, 'vtdiockfpn'), (-277544960, 'fpnndsjo') , (2, 'kfpnndsjof') , (6, 'ckfp')") tk.MustQuery("select /*+ TIDB_INLJ(t, s) */ t.a from t join s on t.a = s.a").Check(testkit.Rows("-277544960", "-277544960")) - tk.MustQuery("select /*+ TIDB_INLJ(t, s) */ t.a from t left join s on t.a = s.a").Check(testkit.Rows("148307968", "-1327693824", "-277544960", "-277544960")) - tk.MustQuery("select /*+ TIDB_INLJ(t, s) */ t.a from t right join s on t.a = s.a").Check(testkit.Rows("-277544960", "", "", "-277544960", "", "")) + tk.MustQuery("select /*+ TIDB_INLJ(t, s) */ t.a from t left join s on t.a = s.a").Check(testkit.Rows("-1327693824", "-277544960", "-277544960", "148307968")) + tk.MustQuery("select /*+ TIDB_INLJ(t, s) */ t.a from t right join s on t.a = s.a").Check(testkit.Rows("-277544960", "-277544960", "", "", "", "")) tk.MustExec("DROP TABLE IF EXISTS t;") tk.MustExec("CREATE TABLE t(a BIGINT PRIMARY KEY, b BIGINT);") tk.MustExec("INSERT INTO t VALUES(1, 2);") @@ -1381,3 +1381,22 @@ func (s *testSuite2) TestInjectProjOnTopN(c *C) { "2", )) } + +func (s *testSuite2) TestIssue11544(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table 11544t(a int)") + tk.MustExec("create table 11544tt(a int, b varchar(10), index idx(a, b(3)))") + tk.MustExec("insert into 11544t values(1)") + tk.MustExec("insert into 11544tt values(1, 'aaaaaaa'), (1, 'aaaabbb'), (1, 'aaaacccc')") + tk.MustQuery("select /*+ TIDB_INLJ(tt) */ * from 11544t t, 11544tt tt where t.a=tt.a and (tt.b = 'aaaaaaa' or tt.b = 'aaaabbb')").Check(testkit.Rows("1 1 aaaaaaa", "1 1 aaaabbb")) + tk.MustQuery("select /*+ TIDB_INLJ(tt) */ * from 11544t t, 11544tt tt where t.a=tt.a and tt.b in ('aaaaaaa', 'aaaabbb', 'aaaacccc')").Check(testkit.Rows("1 1 aaaaaaa", "1 1 aaaabbb", "1 1 aaaacccc")) +} + +func (s *testSuite2) TestIssue11390(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table 11390t (k1 int unsigned, k2 int unsigned, key(k1, k2))") + tk.MustExec("insert into 11390t values(1, 1)") + tk.MustQuery("select /*+ TIDB_INLJ(t1, t2) */ * from 11390t t1, 11390t t2 where t1.k2 > 0 and t1.k2 = t2.k2 and t2.k1=1;").Check(testkit.Rows("1 1 1 1")) +} diff --git a/executor/load_data.go b/executor/load_data.go index c5c7afe61df3a..8bc586a7825f2 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -53,7 +53,7 @@ func NewLoadDataInfo(ctx sessionctx.Context, row []types.Datum, tbl table.Table, } // Next implements the Executor Next interface. -func (e *LoadDataExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *LoadDataExec) Next(ctx context.Context, req *chunk.Chunk) error { req.GrowAndReset(e.maxChunkSize) // TODO: support load data without local field. if !e.IsLocal { @@ -213,7 +213,7 @@ func (e *LoadDataInfo) getLine(prevData, curData []byte) ([]byte, []byte, bool) // If it has the rest of data isn't completed the processing, then it returns without completed data. // If the number of inserted rows reaches the batchRows, then the second return value is true. // If prevData isn't nil and curData is nil, there are no other data to deal with and the isEOF is true. -func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error) { +func (e *LoadDataInfo) InsertData(ctx context.Context, prevData, curData []byte) ([]byte, bool, error) { if len(prevData) == 0 && len(curData) == 0 { return nil, false, nil } @@ -252,8 +252,10 @@ func (e *LoadDataInfo) InsertData(prevData, curData []byte) ([]byte, bool, error if err != nil { return nil, false, err } - rows = append(rows, e.colsToRow(cols)) + // rowCount will be used in fillRow(), last insert ID will be assigned according to the rowCount = 1. + // So should add first here. e.rowCount++ + rows = append(rows, e.colsToRow(ctx, cols)) if e.maxRowsInBatch != 0 && e.rowCount%e.maxRowsInBatch == 0 { reachLimit = true logutil.Logger(context.Background()).Info("batch limit hit when inserting rows", zap.Int("maxBatchRows", e.maxChunkSize), @@ -281,9 +283,15 @@ func (e *LoadDataInfo) SetMessage() { e.ctx.GetSessionVars().StmtCtx.SetMessage(msg) } -func (e *LoadDataInfo) colsToRow(cols []field) []types.Datum { +func (e *LoadDataInfo) colsToRow(ctx context.Context, cols []field) []types.Datum { + totalCols := e.Table.Cols() for i := 0; i < len(e.row); i++ { if i >= len(cols) { + // If some columns is missing and their type is time and has not null flag, they should be set as current time. + if types.IsTypeTime(totalCols[i].Tp) && mysql.HasNotNullFlag(totalCols[i].Flag) { + e.row[i].SetMysqlTime(types.CurrentTime(totalCols[i].Tp)) + continue + } e.row[i].SetNull() continue } @@ -295,7 +303,7 @@ func (e *LoadDataInfo) colsToRow(cols []field) []types.Datum { e.row[i].SetString(string(cols[i].str)) } } - row, err := e.getRow(e.row) + row, err := e.getRow(ctx, e.row) if err != nil { e.handleWarning(err) return nil @@ -324,15 +332,15 @@ type fieldWriter struct { pos int enclosedChar byte fieldTermChar byte - term *string + term string isEnclosed bool isLineStart bool isFieldStart bool - ReadBuf *[]byte + ReadBuf []byte OutputBuf []byte } -func (w *fieldWriter) Init(enclosedChar byte, fieldTermChar byte, readBuf *[]byte, term *string) { +func (w *fieldWriter) Init(enclosedChar byte, fieldTermChar byte, readBuf []byte, term string) { w.isEnclosed = false w.isLineStart = true w.isFieldStart = true @@ -347,8 +355,8 @@ func (w *fieldWriter) putback() { } func (w *fieldWriter) getChar() (bool, byte) { - if w.pos < len(*w.ReadBuf) { - ret := (*w.ReadBuf)[w.pos] + if w.pos < len(w.ReadBuf) { + ret := w.ReadBuf[w.pos] w.pos++ return true, ret } @@ -357,9 +365,9 @@ func (w *fieldWriter) getChar() (bool, byte) { func (w *fieldWriter) isTerminator() bool { chkpt, isterm := w.pos, true - for i := 1; i < len(*w.term); i++ { + for i := 1; i < len(w.term); i++ { flag, ch := w.getChar() - if !flag || ch != (*w.term)[i] { + if !flag || ch != w.term[i] { isterm = false break } @@ -473,7 +481,7 @@ func (e *LoadDataInfo) getFieldsFromLine(line []byte) ([]field, error) { return fields, nil } - reader.Init(e.FieldsInfo.Enclosed, e.FieldsInfo.Terminated[0], &line, &e.FieldsInfo.Terminated) + reader.Init(e.FieldsInfo.Enclosed, e.FieldsInfo.Terminated[0], line, e.FieldsInfo.Terminated) for { eol, f := reader.GetField() f = f.escape() diff --git a/executor/load_stats.go b/executor/load_stats.go index e55ada1e84b60..58a764748341e 100644 --- a/executor/load_stats.go +++ b/executor/load_stats.go @@ -50,7 +50,7 @@ func (k loadStatsVarKeyType) String() string { const LoadStatsVarKey loadStatsVarKeyType = 0 // Next implements the Executor Next interface. -func (e *LoadStatsExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *LoadStatsExec) Next(ctx context.Context, req *chunk.Chunk) error { req.GrowAndReset(e.maxChunkSize) if len(e.info.Path) == 0 { return errors.New("Load Stats: file path is empty") diff --git a/executor/memory_test.go b/executor/memory_test.go new file mode 100644 index 0000000000000..56dd62fb3ead2 --- /dev/null +++ b/executor/memory_test.go @@ -0,0 +1,111 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + "context" + "fmt" + "runtime" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/store/mockstore" +) + +var _ = SerialSuites(&testMemoryLeak{}) + +type testMemoryLeak struct { + store kv.Storage + domain *domain.Domain +} + +func (s *testMemoryLeak) SetUpSuite(c *C) { + var err error + s.store, err = mockstore.NewMockTikvStore() + c.Assert(err, IsNil) + s.domain, err = session.BootstrapSession(s.store) + c.Assert(err, IsNil) +} + +func (s *testMemoryLeak) TestPBMemoryLeak(c *C) { + c.Skip("too slow") + + se, err := session.CreateSession4Test(s.store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "create database test_mem") + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "use test_mem") + c.Assert(err, IsNil) + + // prepare data + totalSize := uint64(256 << 20) // 256MB + blockSize := uint64(8 << 10) // 8KB + delta := totalSize / 5 + numRows := totalSize / blockSize + _, err = se.Execute(context.Background(), fmt.Sprintf("create table t (c varchar(%v))", blockSize)) + c.Assert(err, IsNil) + defer func() { + _, err = se.Execute(context.Background(), "drop table t") + c.Assert(err, IsNil) + }() + sql := fmt.Sprintf("insert into t values (space(%v))", blockSize) + for i := uint64(0); i < numRows; i++ { + _, err = se.Execute(context.Background(), sql) + c.Assert(err, IsNil) + } + + // read data + runtime.GC() + allocatedBegin, inUseBegin := s.readMem() + records, err := se.Execute(context.Background(), "select * from t") + c.Assert(err, IsNil) + record := records[0] + rowCnt := 0 + chk := record.NewChunk() + for { + c.Assert(record.Next(context.Background(), chk), IsNil) + rowCnt += chk.NumRows() + if chk.NumRows() == 0 { + break + } + } + c.Assert(rowCnt, Equals, int(numRows)) + + // check memory before close + runtime.GC() + allocatedAfter, inUseAfter := s.readMem() + c.Assert(allocatedAfter-allocatedBegin, GreaterEqual, totalSize) + c.Assert(s.memDiff(inUseAfter, inUseBegin), Less, delta) + + se.Close() + runtime.GC() + allocatedFinal, inUseFinal := s.readMem() + c.Assert(allocatedFinal-allocatedAfter, Less, delta) + c.Assert(s.memDiff(inUseFinal, inUseAfter), Less, delta) +} + +func (s *testMemoryLeak) readMem() (allocated, heapInUse uint64) { + var stat runtime.MemStats + runtime.ReadMemStats(&stat) + return stat.TotalAlloc, stat.HeapInuse +} + +func (s *testMemoryLeak) memDiff(m1, m2 uint64) uint64 { + if m1 > m2 { + return m1 - m2 + } + return m2 - m1 +} diff --git a/executor/merge_join.go b/executor/merge_join.go index 2eca140bed902..4a0521740bc7c 100644 --- a/executor/merge_join.go +++ b/executor/merge_join.go @@ -142,7 +142,7 @@ func (t *mergeJoinInnerTable) nextRow() (chunk.Row, error) { if t.curRow == t.curIter.End() { t.reallocReaderResult() oldMemUsage := t.curResult.MemoryUsage() - err := t.reader.Next(t.ctx, chunk.NewRecordBatch(t.curResult)) + err := Next(t.ctx, t.reader, t.curResult) // error happens or no more data. if err != nil || t.curResult.NumRows() == 0 { t.curRow = t.curIter.End() @@ -185,7 +185,7 @@ func (t *mergeJoinInnerTable) reallocReaderResult() { // Create a new Chunk and append it to "resourceQueue" if there is no more // available chunk in "resourceQueue". if len(t.resourceQueue) == 0 { - newChunk := t.reader.newFirstChunk() + newChunk := newFirstChunk(t.reader) t.memTracker.Consume(newChunk.MemoryUsage()) t.resourceQueue = append(t.resourceQueue, newChunk) } @@ -201,7 +201,6 @@ func (t *mergeJoinInnerTable) reallocReaderResult() { // Close implements the Executor Close interface. func (e *MergeJoinExec) Close() error { - e.memTracker.Detach() e.childrenResults = nil e.memTracker = nil @@ -222,7 +221,7 @@ func (e *MergeJoinExec) Open(ctx context.Context) error { e.childrenResults = make([]*chunk.Chunk, 0, len(e.children)) for _, child := range e.children { - e.childrenResults = append(e.childrenResults, child.newFirstChunk()) + e.childrenResults = append(e.childrenResults, newFirstChunk(child)) } e.innerTable.memTracker = memory.NewTracker(innerTableLabel, -1) @@ -267,7 +266,7 @@ func (e *MergeJoinExec) prepare(ctx context.Context, requiredRows int) error { } // Next implements the Executor Next interface. -func (e *MergeJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *MergeJoinExec) Next(ctx context.Context, req *chunk.Chunk) error { if e.runtimeStats != nil { start := time.Now() defer func() { e.runtimeStats.Record(time.Since(start), req.NumRows()) }() @@ -280,7 +279,7 @@ func (e *MergeJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) error } for !req.IsFull() { - hasMore, err := e.joinToChunk(ctx, req.Chunk) + hasMore, err := e.joinToChunk(ctx, req) if err != nil || !hasMore { return err } @@ -389,7 +388,7 @@ func (e *MergeJoinExec) fetchNextOuterRows(ctx context.Context, requiredRows int e.outerTable.chk.SetRequiredRows(requiredRows, e.maxChunkSize) } - err = e.outerTable.reader.Next(ctx, chunk.NewRecordBatch(e.outerTable.chk)) + err = Next(ctx, e.outerTable.reader, e.outerTable.chk) if err != nil { return err } diff --git a/executor/metrics_test.go b/executor/metrics_test.go index afb165038d8b6..662858fd9928c 100644 --- a/executor/metrics_test.go +++ b/executor/metrics_test.go @@ -14,6 +14,7 @@ package executor_test import ( + "context" "fmt" . "github.com/pingcap/check" @@ -63,7 +64,7 @@ func (s *testSuite4) TestStmtLabel(c *C) { is := executor.GetInfoSchema(tk.Se) err = plannercore.Preprocess(tk.Se.(sessionctx.Context), stmtNode, is) c.Assert(err, IsNil) - _, err = planner.Optimize(tk.Se, stmtNode, is) + _, err = planner.Optimize(context.TODO(), tk.Se, stmtNode, is) c.Assert(err, IsNil) c.Assert(executor.GetStmtLabel(stmtNode), Equals, tt.label) } diff --git a/executor/opt_rule_blacklist.go b/executor/opt_rule_blacklist.go new file mode 100644 index 0000000000000..a6a4d37ec28ef --- /dev/null +++ b/executor/opt_rule_blacklist.go @@ -0,0 +1,50 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + + plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/set" + "github.com/pingcap/tidb/util/sqlexec" +) + +// ReloadOptRuleBlacklistExec indicates ReloadOptRuleBlacklist executor. +type ReloadOptRuleBlacklistExec struct { + baseExecutor +} + +// Next implements the Executor Next interface. +func (e *ReloadOptRuleBlacklistExec) Next(ctx context.Context, _ *chunk.Chunk) error { + return LoadOptRuleBlacklist(e.ctx) +} + +// LoadOptRuleBlacklist loads the latest data from table mysql.opt_rule_blacklist. +func LoadOptRuleBlacklist(ctx sessionctx.Context) (err error) { + sql := "select HIGH_PRIORITY name from mysql.opt_rule_blacklist" + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(ctx, sql) + if err != nil { + return err + } + newDisabledLogicalRules := set.NewStringSet() + for _, row := range rows { + name := row.GetString(0) + newDisabledLogicalRules.Insert(name) + } + plannercore.DefaultDisabledLogicalRulesList.Store(newDisabledLogicalRules) + return nil +} diff --git a/executor/pkg_test.go b/executor/pkg_test.go index 74a478aadce48..f06896ed7d9f5 100644 --- a/executor/pkg_test.go +++ b/executor/pkg_test.go @@ -3,22 +3,16 @@ package executor import ( "context" "fmt" - "testing" - - "github.com/klauspost/cpuid" . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/stringutil" - "github.com/spaolacci/murmur3" ) var _ = Suite(&pkgTestSuite{}) @@ -33,9 +27,9 @@ type MockExec struct { curRowIdx int } -func (m *MockExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (m *MockExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() - colTypes := m.retTypes() + colTypes := retTypes(m) for ; m.curRowIdx < len(m.Rows) && req.NumRows() < req.Capacity(); m.curRowIdx++ { curRow := m.Rows[m.curRowIdx] for i := 0; i < curRow.Len(); i++ { @@ -88,7 +82,7 @@ func (s *pkgTestSuite) TestNestedLoopApply(c *C) { innerFilter := outerFilter.Clone() otherFilter := expression.NewFunctionInternal(sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), col0, col1) joiner := newJoiner(sctx, plannercore.InnerJoin, false, - make([]types.Datum, innerExec.Schema().Len()), []expression.Expression{otherFilter}, outerExec.retTypes(), innerExec.retTypes()) + make([]types.Datum, innerExec.Schema().Len()), []expression.Expression{otherFilter}, retTypes(outerExec), retTypes(innerExec)) joinSchema := expression.NewSchema(col0, col1) join := &NestedLoopApplyExec{ baseExecutor: newBaseExecutor(sctx, joinSchema, nil), @@ -98,13 +92,13 @@ func (s *pkgTestSuite) TestNestedLoopApply(c *C) { innerFilter: []expression.Expression{innerFilter}, joiner: joiner, } - join.innerList = chunk.NewList(innerExec.retTypes(), innerExec.initCap, innerExec.maxChunkSize) - join.innerChunk = innerExec.newFirstChunk() - join.outerChunk = outerExec.newFirstChunk() - joinChk := join.newFirstChunk() + join.innerList = chunk.NewList(retTypes(innerExec), innerExec.initCap, innerExec.maxChunkSize) + join.innerChunk = newFirstChunk(innerExec) + join.outerChunk = newFirstChunk(outerExec) + joinChk := newFirstChunk(join) it := chunk.NewIterator4Chunk(joinChk) for rowIdx := 1; ; { - err := join.Next(ctx, chunk.NewRecordBatch(joinChk)) + err := join.Next(ctx, joinChk) c.Check(err, IsNil) if joinChk.NumRows() == 0 { break @@ -130,7 +124,7 @@ func prepareOneColChildExec(sctx sessionctx.Context, rowCount int) Executor { return exec } -func buildExec4RadixHashJoin(sctx sessionctx.Context, rowCount int) *RadixHashJoinExec { +func prepare4RadixPartition(sctx sessionctx.Context, rowCount int) *HashJoinExec { childExec0 := prepareOneColChildExec(sctx, rowCount) childExec1 := prepareOneColChildExec(sctx, rowCount) @@ -149,64 +143,7 @@ func buildExec4RadixHashJoin(sctx sessionctx.Context, rowCount int) *RadixHashJo innerExec: childExec0, outerExec: childExec1, } - return &RadixHashJoinExec{HashJoinExec: hashJoinExec} -} - -func (s *pkgTestSuite) TestRadixPartition(c *C) { - sctx := mock.NewContext() - hashJoinExec := buildExec4RadixHashJoin(sctx, 200) - sv := sctx.GetSessionVars() - originL2CacheSize, originEnableRadixJoin, originMaxChunkSize := sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize - sv.L2CacheSize = 100 - sv.EnableRadixJoin = true - // FIXME: use initChunkSize when join support initChunkSize. - sv.MaxChunkSize = 100 - defer func() { - sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize = originL2CacheSize, originEnableRadixJoin, originMaxChunkSize - }() - sv.StmtCtx.MemTracker = memory.NewTracker(stringutil.StringerStr("RootMemTracker"), variable.DefTiDBMemQuotaHashJoin) - - ctx := context.Background() - err := hashJoinExec.Open(ctx) - c.Assert(err, IsNil) - - hashJoinExec.fetchInnerRows(ctx) - c.Assert(hashJoinExec.innerResult.GetMemTracker().BytesConsumed(), Equals, int64(14400)) - - hashJoinExec.evalRadixBit() - // ceil(log_2(14400/(100*3/4))) = 8 - c.Assert(len(hashJoinExec.innerParts), Equals, 256) - radixBits := hashJoinExec.radixBits - c.Assert(radixBits, Equals, uint32(0x00ff)) - - err = hashJoinExec.partitionInnerRows() - c.Assert(err, IsNil) - totalRowCnt := 0 - for i, part := range hashJoinExec.innerParts { - if part == nil { - continue - } - totalRowCnt += part.NumRows() - iter := chunk.NewIterator4Chunk(part) - keyBuf := make([]byte, 0, 64) - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - hasNull, keyBuf, err := hashJoinExec.getJoinKeyFromChkRow(false, row, keyBuf) - c.Assert(err, IsNil) - c.Assert(hasNull, IsFalse) - joinHash := murmur3.Sum32(keyBuf) - c.Assert(joinHash&radixBits, Equals, uint32(i)&radixBits) - } - } - c.Assert(totalRowCnt, Equals, 200) - - for _, ch := range hashJoinExec.outerResultChs { - close(ch) - } - for _, ch := range hashJoinExec.joinChkResourceCh { - close(ch) - } - err = hashJoinExec.Close() - c.Assert(err, IsNil) + return hashJoinExec } func (s *pkgTestSuite) TestMoveInfoSchemaToFront(c *C) { @@ -238,56 +175,3 @@ func (s *pkgTestSuite) TestMoveInfoSchemaToFront(c *C) { } } } - -func BenchmarkPartitionInnerRows(b *testing.B) { - sctx := mock.NewContext() - hashJoinExec := buildExec4RadixHashJoin(sctx, 1500000) - sv := sctx.GetSessionVars() - originL2CacheSize, originEnableRadixJoin, originMaxChunkSize := sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize - sv.L2CacheSize = cpuid.CPU.Cache.L2 - sv.EnableRadixJoin = true - sv.MaxChunkSize = 1024 - defer func() { - sv.L2CacheSize, sv.EnableRadixJoin, sv.MaxChunkSize = originL2CacheSize, originEnableRadixJoin, originMaxChunkSize - }() - sv.StmtCtx.MemTracker = memory.NewTracker(stringutil.StringerStr("RootMemTracker"), variable.DefTiDBMemQuotaHashJoin) - - ctx := context.Background() - hashJoinExec.Open(ctx) - hashJoinExec.fetchInnerRows(ctx) - hashJoinExec.evalRadixBit() - b.ResetTimer() - hashJoinExec.concurrency = 16 - hashJoinExec.maxChunkSize = 1024 - hashJoinExec.initCap = 1024 - for i := 0; i < b.N; i++ { - hashJoinExec.partitionInnerRows() - hashJoinExec.innerRowPrts = hashJoinExec.innerRowPrts[:0] - } -} - -func (s *pkgTestSuite) TestParallelBuildHashTable4RadixJoin(c *C) { - sctx := mock.NewContext() - hashJoinExec := buildExec4RadixHashJoin(sctx, 200) - - sv := sctx.GetSessionVars() - sv.L2CacheSize = 100 - sv.EnableRadixJoin = true - sv.MaxChunkSize = 100 - sv.StmtCtx.MemTracker = memory.NewTracker(stringutil.StringerStr("RootMemTracker"), variable.DefTiDBMemQuotaHashJoin) - - ctx := context.Background() - err := hashJoinExec.Open(ctx) - c.Assert(err, IsNil) - - hashJoinExec.partitionInnerAndBuildHashTables(ctx) - innerParts := hashJoinExec.innerParts - c.Assert(len(hashJoinExec.hashTables), Equals, len(innerParts)) - for i := 0; i < len(innerParts); i++ { - if innerParts[i] == nil { - c.Assert(hashJoinExec.hashTables[i], IsNil) - } else { - c.Assert(hashJoinExec.hashTables[i], NotNil) - } - } -} diff --git a/executor/point_get.go b/executor/point_get.go index bf3ff62de8231..d480f1ac1f913 100644 --- a/executor/point_get.go +++ b/executor/point_get.go @@ -19,10 +19,8 @@ import ( "github.com/pingcap/failpoint" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" plannercore "github.com/pingcap/tidb/planner/core" - "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/tablecodec" @@ -38,23 +36,25 @@ func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) Executor { b.err = err return nil } - return &PointGetExecutor{ - ctx: b.ctx, - schema: p.Schema(), - tblInfo: p.TblInfo, - idxInfo: p.IndexInfo, - idxVals: p.IndexValues, - handle: p.Handle, - startTS: startTS, - done: p.UnsignedHandle && p.Handle < 0, - } + e := &PointGetExecutor{ + baseExecutor: newBaseExecutor(b.ctx, p.Schema(), p.ExplainID()), + tblInfo: p.TblInfo, + idxInfo: p.IndexInfo, + idxVals: p.IndexValues, + handle: p.Handle, + startTS: startTS, + lock: p.Lock, + } + b.isSelectForUpdate = p.IsForUpdate + e.base().initCap = 1 + e.base().maxChunkSize = 1 + return e } // PointGetExecutor executes point select query. type PointGetExecutor struct { - ctx sessionctx.Context - schema *expression.Schema - tps []*types.FieldType + baseExecutor + tblInfo *model.TableInfo handle int64 idxInfo *model.IndexInfo @@ -62,6 +62,7 @@ type PointGetExecutor struct { startTS uint64 snapshot kv.Snapshot done bool + lock bool } // Open implements the Executor interface. @@ -75,14 +76,18 @@ func (e *PointGetExecutor) Close() error { } // Next implements the Executor interface. -func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.done { return nil } e.done = true + snapshotTS := e.startTS + if e.lock { + snapshotTS = e.ctx.GetSessionVars().TxnCtx.GetForUpdateTS() + } var err error - e.snapshot, err = e.ctx.GetStore().GetSnapshot(kv.Version{Ver: e.startTS}) + e.snapshot, err = e.ctx.GetStore().GetSnapshot(kv.Version{Ver: snapshotTS}) if err != nil { return err } @@ -97,7 +102,7 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.RecordBatch) err return err1 } if len(handleVal) == 0 { - return nil + return e.lockKeyIfNeeded(ctx, idxKey) } e.handle, err1 = tables.DecodeHandle(handleVal) if err1 != nil { @@ -124,6 +129,10 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.RecordBatch) err if err != nil && !kv.ErrNotExist.Equal(err) { return err } + err = e.lockKeyIfNeeded(ctx, key) + if err != nil { + return err + } if len(val) == 0 { if e.idxInfo != nil { return kv.ErrNotExist.GenWithStack("inconsistent extra index %s, handle %d not found in table", @@ -131,7 +140,14 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.RecordBatch) err } return nil } - return e.decodeRowValToChunk(val, req.Chunk) + return e.decodeRowValToChunk(val, req) +} + +func (e *PointGetExecutor) lockKeyIfNeeded(ctx context.Context, key []byte) error { + if e.lock { + return doLockKeys(ctx, e.ctx, key) + } + return nil } func (e *PointGetExecutor) encodeIndexKey() (_ []byte, err error) { @@ -176,60 +192,43 @@ func (e *PointGetExecutor) get(key kv.Key) (val []byte, err error) { } func (e *PointGetExecutor) decodeRowValToChunk(rowVal []byte, chk *chunk.Chunk) error { - // One column could be filled for multi-times in the schema. e.g. select b, b, c, c from t where a = 1. - // We need to set the positions in the schema for the same column. - colID2DecodedPos := make(map[int64]int, e.schema.Len()) - decodedPos2SchemaPos := make([][]int, 0, e.schema.Len()) - for schemaPos, col := range e.schema.Columns { - if decodedPos, ok := colID2DecodedPos[col.ID]; !ok { - colID2DecodedPos[col.ID] = len(colID2DecodedPos) - decodedPos2SchemaPos = append(decodedPos2SchemaPos, []int{schemaPos}) - } else { - decodedPos2SchemaPos[decodedPos] = append(decodedPos2SchemaPos[decodedPos], schemaPos) + colID2CutPos := make(map[int64]int, e.schema.Len()) + for _, col := range e.schema.Columns { + if _, ok := colID2CutPos[col.ID]; !ok { + colID2CutPos[col.ID] = len(colID2CutPos) } } - decodedVals, err := tablecodec.CutRowNew(rowVal, colID2DecodedPos) + cutVals, err := tablecodec.CutRowNew(rowVal, colID2CutPos) if err != nil { return err } - if decodedVals == nil { - decodedVals = make([][]byte, len(colID2DecodedPos)) + if cutVals == nil { + cutVals = make([][]byte, len(colID2CutPos)) } decoder := codec.NewDecoder(chk, e.ctx.GetSessionVars().Location()) - for id, decodedPos := range colID2DecodedPos { - schemaPoses := decodedPos2SchemaPos[decodedPos] - firstPos := schemaPoses[0] - if e.tblInfo.PKIsHandle && mysql.HasPriKeyFlag(e.schema.Columns[firstPos].RetType.Flag) { - chk.AppendInt64(firstPos, e.handle) - // Fill other positions. - for i := 1; i < len(schemaPoses); i++ { - chk.MakeRef(firstPos, schemaPoses[i]) - } + for i, col := range e.schema.Columns { + if e.tblInfo.PKIsHandle && mysql.HasPriKeyFlag(col.RetType.Flag) { + chk.AppendInt64(i, e.handle) continue } - // ExtraHandleID is added when building plan, we can make sure that there's only one column's ID is this. - if id == model.ExtraHandleID { - chk.AppendInt64(firstPos, e.handle) + if col.ID == model.ExtraHandleID { + chk.AppendInt64(i, e.handle) continue } - if len(decodedVals[decodedPos]) == 0 { - // This branch only entered for updating and deleting. It won't have one column in multiple positions. - colInfo := getColInfoByID(e.tblInfo, id) + cutPos := colID2CutPos[col.ID] + if len(cutVals[cutPos]) == 0 { + colInfo := getColInfoByID(e.tblInfo, col.ID) d, err1 := table.GetColOriginDefaultValue(e.ctx, colInfo) if err1 != nil { return err1 } - chk.AppendDatum(firstPos, &d) + chk.AppendDatum(i, &d) continue } - _, err = decoder.DecodeOne(decodedVals[decodedPos], firstPos, e.schema.Columns[firstPos].RetType) + _, err = decoder.DecodeOne(cutVals[cutPos], i, col.RetType) if err != nil { return err } - // Fill other positions. - for i := 1; i < len(schemaPoses); i++ { - chk.MakeRef(firstPos, schemaPoses[i]) - } } return nil } @@ -242,22 +241,3 @@ func getColInfoByID(tbl *model.TableInfo, colID int64) *model.ColumnInfo { } return nil } - -// Schema implements the Executor interface. -func (e *PointGetExecutor) Schema() *expression.Schema { - return e.schema -} - -func (e *PointGetExecutor) retTypes() []*types.FieldType { - if e.tps == nil { - e.tps = make([]*types.FieldType, e.schema.Len()) - for i := range e.schema.Columns { - e.tps[i] = e.schema.Columns[i].RetType - } - } - return e.tps -} - -func (e *PointGetExecutor) newFirstChunk() *chunk.Chunk { - return chunk.New(e.retTypes(), 1, 1) -} diff --git a/executor/point_get_test.go b/executor/point_get_test.go index bab2bb1e33fde..d888c7417513f 100644 --- a/executor/point_get_test.go +++ b/executor/point_get_test.go @@ -101,6 +101,8 @@ func (s *testPointGetSuite) TestPointGet(c *C) { tk.MustQuery(`select a, a, b, a, b, c, b, c, c from t where a = 5;`).Check(testkit.Rows( `5 5 6 5 6 7 6 7 7`, )) + tk.MustQuery(`select b, b from t where a = 1`).Check(testkit.Rows( + " ")) } func (s *testPointGetSuite) TestPointGetCharPK(c *C) { @@ -160,7 +162,7 @@ func (s *testPointGetSuite) TestPointGetCharPK(c *C) { tk.MustPointGet(`select * from t where a = " ";`).Check(testkit.Rows(` `)) } -func (s *testPointGetSuite) TestIndexLookupCharPK(c *C) { +func (s *testPointGetSuite) TestPointGetAliasTableCharPK(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test;`) tk.MustExec(`drop table if exists t;`) @@ -169,46 +171,150 @@ func (s *testPointGetSuite) TestIndexLookupCharPK(c *C) { // Test truncate without sql mode `PAD_CHAR_TO_FULL_LENGTH`. tk.MustExec(`set @@sql_mode="";`) - tk.MustIndexLookup(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) - tk.MustIndexLookup(`select * from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) + tk.MustPointGet(`select * from t tmp where a = "aab";`).Check(testkit.Rows()) // Test truncate with sql mode `PAD_CHAR_TO_FULL_LENGTH`. tk.MustExec(`set @@sql_mode="PAD_CHAR_TO_FULL_LENGTH";`) + tk.MustPointGet(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) + tk.MustPointGet(`select * from t tmp where a = "aab";`).Check(testkit.Rows()) + + tk.MustExec(`truncate table t;`) + tk.MustExec(`insert into t values("a ", "b ");`) + + // Test trailing spaces without sql mode `PAD_CHAR_TO_FULL_LENGTH`. + tk.MustExec(`set @@sql_mode="";`) + tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + + // Test trailing spaces with sql mode `PAD_CHAR_TO_FULL_LENGTH`. + tk.MustExec(`set @@sql_mode="PAD_CHAR_TO_FULL_LENGTH";`) + tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + + // Test CHAR BINARY. + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a char(2) binary primary key, b char(2));`) + tk.MustExec(`insert into t values(" ", " ");`) + tk.MustExec(`insert into t values("a ", "b ");`) + + // Test trailing spaces without sql mode `PAD_CHAR_TO_FULL_LENGTH`. + tk.MustExec(`set @@sql_mode="";`) + tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows(` `)) + tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows(` `)) + tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows(` `)) + + // Test trailing spaces with sql mode `PAD_CHAR_TO_FULL_LENGTH`. + tk.MustExec(`set @@sql_mode="PAD_CHAR_TO_FULL_LENGTH";`) + tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows(` `)) + tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows(` `)) + tk.MustPointGet(`select * from t tmp where a = " ";`).Check(testkit.Rows(` `)) + + // Test both wildcard and column name exist in select field list + tk.MustExec(`set @@sql_mode="";`) + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a char(2) primary key, b char(2));`) + tk.MustExec(`insert into t values("aa", "bb");`) + tk.MustPointGet(`select *, a from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb aa`)) + + // Test using table alias in field list + tk.MustPointGet(`select tmp.* from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) + tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) + tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb aa bb`)) + tk.MustPointGet(`select tmp.* from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustPointGet(`select tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) + tk.MustPointGet(`select tmp.*, tmp.a, tmp.b from t tmp where a = "aab";`).Check(testkit.Rows()) + + // Test using table alias in where clause + tk.MustPointGet(`select * from t tmp where tmp.a = "aa";`).Check(testkit.Rows(`aa bb`)) + tk.MustPointGet(`select a, b from t tmp where tmp.a = "aa";`).Check(testkit.Rows(`aa bb`)) + tk.MustPointGet(`select *, a, b from t tmp where tmp.a = "aa";`).Check(testkit.Rows(`aa bb aa bb`)) + + // Unknown table name in where clause and field list + err := tk.ExecToErr(`select a from t where xxxxx.a = "aa"`) + c.Assert(err, ErrorMatches, ".*Unknown column 'xxxxx.a' in 'where clause'") + err = tk.ExecToErr(`select xxxxx.a from t where a = "aa"`) + c.Assert(err, ErrorMatches, ".*Unknown column 'xxxxx.a' in 'field list'") + + // When an alias is provided, it completely hides the actual name of the table. + err = tk.ExecToErr(`select a from t tmp where t.a = "aa"`) + c.Assert(err, ErrorMatches, ".*Unknown column 't.a' in 'where clause'") + err = tk.ExecToErr(`select t.a from t tmp where a = "aa"`) + c.Assert(err, ErrorMatches, ".*Unknown column 't.a' in 'field list'") + err = tk.ExecToErr(`select t.* from t tmp where a = "aa"`) + c.Assert(err, ErrorMatches, ".*Unknown table 't'") +} + +func (s *testPointGetSuite) TestIndexLookupChar(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test;`) + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a char(2), b char(2), index idx_1(a));`) + tk.MustExec(`insert into t values("aa", "bb");`) + + // Test truncate without sql mode `PAD_CHAR_TO_FULL_LENGTH`. + tk.MustExec(`set @@sql_mode="";`) + tk.MustIndexLookup(`select * from t where a = "aa";`).Check(testkit.Rows(`aa bb`)) + tk.MustIndexLookup(`select * from t where a = "aab";`).Check(testkit.Rows()) + + // Test query with table alias tk.MustIndexLookup(`select * from t tmp where a = "aa";`).Check(testkit.Rows(`aa bb`)) tk.MustIndexLookup(`select * from t tmp where a = "aab";`).Check(testkit.Rows()) + // Test truncate with sql mode `PAD_CHAR_TO_FULL_LENGTH`. + tk.MustExec(`set @@sql_mode="PAD_CHAR_TO_FULL_LENGTH";`) + tk.MustIndexLookup(`select * from t where a = "aa";`).Check(testkit.Rows(`aa bb`)) + tk.MustTableDual(`select * from t where a = "aab";`).Check(testkit.Rows()) + tk.MustExec(`truncate table t;`) tk.MustExec(`insert into t values("a ", "b ");`) // Test trailing spaces without sql mode `PAD_CHAR_TO_FULL_LENGTH`. tk.MustExec(`set @@sql_mode="";`) - tk.MustIndexLookup(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a";`).Check(testkit.Rows(`a b`)) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows()) // Test trailing spaces with sql mode `PAD_CHAR_TO_FULL_LENGTH`. tk.MustExec(`set @@sql_mode="PAD_CHAR_TO_FULL_LENGTH";`) - tk.MustIndexLookup(`select * from t tmp where a = "a";`).Check(testkit.Rows()) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustTableDual(`select * from t where a = "a";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustTableDual(`select * from t where a = "a ";`).Check(testkit.Rows()) // Test CHAR BINARY. tk.MustExec(`drop table if exists t;`) - tk.MustExec(`create table t(a char(2) binary primary key, b char(2));`) + tk.MustExec(`create table t(a char(2) binary, b char(2), index idx_1(a));`) tk.MustExec(`insert into t values(" ", " ");`) tk.MustExec(`insert into t values("a ", "b ");`) // Test trailing spaces without sql mode `PAD_CHAR_TO_FULL_LENGTH`. tk.MustExec(`set @@sql_mode="";`) - tk.MustIndexLookup(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) - tk.MustIndexLookup(`select * from t tmp where a = " ";`).Check(testkit.Rows(` `)) - tk.MustIndexLookup(`select * from t tmp where a = " ";`).Check(testkit.Rows(` `)) - tk.MustIndexLookup(`select * from t tmp where a = " ";`).Check(testkit.Rows(` `)) + tk.MustIndexLookup(`select * from t where a = "a";`).Check(testkit.Rows(`a b`)) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustIndexLookup(`select * from t where a = " ";`).Check(testkit.Rows(` `)) + tk.MustIndexLookup(`select * from t where a = " ";`).Check(testkit.Rows(` `)) + tk.MustIndexLookup(`select * from t where a = " ";`).Check(testkit.Rows(` `)) // Test trailing spaces with sql mode `PAD_CHAR_TO_FULL_LENGTH`. tk.MustExec(`set @@sql_mode="PAD_CHAR_TO_FULL_LENGTH";`) + tk.MustIndexLookup(`select * from t where a = "a";`).Check(testkit.Rows(`a b`)) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows(`a b`)) + tk.MustIndexLookup(`select * from t where a = " ";`).Check(testkit.Rows(` `)) + tk.MustIndexLookup(`select * from t where a = " ";`).Check(testkit.Rows(` `)) + tk.MustIndexLookup(`select * from t where a = " ";`).Check(testkit.Rows(` `)) + + // Test query with table alias in `PAD_CHAR_TO_FULL_LENGTH` mode + tk.MustExec(`set @@sql_mode="PAD_CHAR_TO_FULL_LENGTH";`) tk.MustIndexLookup(`select * from t tmp where a = "a";`).Check(testkit.Rows(`a b`)) tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b`)) @@ -311,7 +417,7 @@ func (s *testPointGetSuite) TestPointGetBinaryPK(c *C) { tk.MustPointGet(`select * from t where a = "a ";`).Check(testkit.Rows()) } -func (s *testPointGetSuite) TestIndexLookupBinaryPK(c *C) { +func (s *testPointGetSuite) TestPointGetAliasTableBinaryPK(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec(`use test;`) tk.MustExec(`drop table if exists t;`) @@ -319,25 +425,142 @@ func (s *testPointGetSuite) TestIndexLookupBinaryPK(c *C) { tk.MustExec(`insert into t values("a", "b");`) tk.MustExec(`set @@sql_mode="";`) - tk.MustIndexLookup(`select * from t tmp where a = "a";`).Check(testkit.Rows()) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) - tk.MustIndexLookup(`select * from t tmp where a = "a\0";`).Check(testkit.Rows("a\x00 b\x00")) + tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a\0";`).Check(testkit.Rows("a\x00 b\x00")) // `PAD_CHAR_TO_FULL_LENGTH` should not affect the result. tk.MustExec(`set @@sql_mode="PAD_CHAR_TO_FULL_LENGTH";`) + tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a\0";`).Check(testkit.Rows("a\x00 b\x00")) + + tk.MustExec(`insert into t values("a ", "b ");`) + tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b `)) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + + // `PAD_CHAR_TO_FULL_LENGTH` should not affect the result. + tk.MustPointGet(`select * from t tmp where a = "a";`).Check(testkit.Rows()) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b `)) + tk.MustPointGet(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) +} + +func (s *testPointGetSuite) TestIndexLookupBinary(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test;`) + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a binary(2), b binary(2), index idx_1(a));`) + tk.MustExec(`insert into t values("a", "b");`) + + tk.MustExec(`set @@sql_mode="";`) + tk.MustIndexLookup(`select * from t where a = "a";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a\0";`).Check(testkit.Rows("a\x00 b\x00")) + + // Test query with table alias + tk.MustExec(`set @@sql_mode="";`) tk.MustIndexLookup(`select * from t tmp where a = "a";`).Check(testkit.Rows()) tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) tk.MustIndexLookup(`select * from t tmp where a = "a\0";`).Check(testkit.Rows("a\x00 b\x00")) + // `PAD_CHAR_TO_FULL_LENGTH` should not affect the result. + tk.MustExec(`set @@sql_mode="PAD_CHAR_TO_FULL_LENGTH";`) + tk.MustIndexLookup(`select * from t where a = "a";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a\0";`).Check(testkit.Rows("a\x00 b\x00")) + tk.MustExec(`insert into t values("a ", "b ");`) - tk.MustIndexLookup(`select * from t tmp where a = "a";`).Check(testkit.Rows()) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b `)) - tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `)) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows()) // `PAD_CHAR_TO_FULL_LENGTH` should not affect the result. + tk.MustIndexLookup(`select * from t where a = "a";`).Check(testkit.Rows()) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows(`a b `)) + tk.MustIndexLookup(`select * from t where a = "a ";`).Check(testkit.Rows()) + + // Test query with table alias in `PAD_CHAR_TO_FULL_LENGTH` mode tk.MustIndexLookup(`select * from t tmp where a = "a";`).Check(testkit.Rows()) tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows(`a b `)) tk.MustIndexLookup(`select * from t tmp where a = "a ";`).Check(testkit.Rows()) } + +func (s *testPointGetSuite) TestIssue10448(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(pk int1 primary key)") + tk.MustExec("insert into t values(125)") + tk.MustQuery("desc select * from t where pk = 9223372036854775807").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 18446744073709551616").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 9223372036854775808").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 18446744073709551615").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 128").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(pk int8 primary key)") + tk.MustExec("insert into t values(9223372036854775807)") + tk.MustQuery("select * from t where pk = 9223372036854775807").Check(testkit.Rows("9223372036854775807")) + tk.MustQuery("desc select * from t where pk = 9223372036854775807").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:9223372036854775807")) + tk.MustQuery("desc select * from t where pk = 18446744073709551616").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 9223372036854775808").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 18446744073709551615").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(pk int1 unsigned primary key)") + tk.MustExec("insert into t values(255)") + tk.MustQuery("select * from t where pk = 255").Check(testkit.Rows("255")) + tk.MustQuery("desc select * from t where pk = 256").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 9223372036854775807").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 18446744073709551616").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 9223372036854775808").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 18446744073709551615").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(pk int8 unsigned primary key)") + tk.MustExec("insert into t value(18446744073709551615)") + tk.MustQuery("desc select * from t where pk = 18446744073709551615").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:18446744073709551615")) + tk.MustQuery("select * from t where pk = 18446744073709551615").Check(testkit.Rows("18446744073709551615")) + tk.MustQuery("desc select * from t where pk = 9223372036854775807").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:9223372036854775807")) + tk.MustQuery("desc select * from t where pk = 18446744073709551616").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("desc select * from t where pk = 9223372036854775808").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:9223372036854775808")) +} + +func (s *testPointGetSuite) TestIssue10677(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(pk int1 primary key)") + tk.MustExec("insert into t values(1)") + tk.MustQuery("desc select * from t where pk = 1.1").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("select * from t where pk = 1.1").Check(testkit.Rows()) + tk.MustQuery("desc select * from t where pk = '1.1'").Check(testkit.Rows("TableDual_2 0.00 root rows:0")) + tk.MustQuery("select * from t where pk = '1.1'").Check(testkit.Rows()) + tk.MustQuery("desc select * from t where pk = 1").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:1")) + tk.MustQuery("select * from t where pk = 1").Check(testkit.Rows("1")) + tk.MustQuery("desc select * from t where pk = '1'").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:1")) + tk.MustQuery("select * from t where pk = '1'").Check(testkit.Rows("1")) + tk.MustQuery("desc select * from t where pk = '1.0'").Check(testkit.Rows("Point_Get_1 1.00 root table:t, handle:1")) + tk.MustQuery("select * from t where pk = '1.0'").Check(testkit.Rows("1")) +} + +func (s *testPointGetSuite) TestForUpdateRetry(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.Exec("drop table if exists t") + tk.MustExec("create table t(pk int primary key, c int)") + tk.MustExec("insert into t values (1, 1), (2, 2)") + tk.MustExec("set @@tidb_disable_txn_auto_retry = 0") + tk.MustExec("begin") + tk.MustQuery("select * from t where pk = 1 for update") + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk2.MustExec("update t set c = c + 1 where pk = 1") + tk.MustExec("update t set c = c + 1 where pk = 2") + _, err := tk.Exec("commit") + c.Assert(session.ErrForUpdateCantRetry.Equal(err), IsTrue) +} diff --git a/executor/prepared.go b/executor/prepared.go index 73dbd092e7d46..fd98b47c33c03 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -17,6 +17,7 @@ import ( "context" "math" "sort" + "time" "github.com/pingcap/errors" "github.com/pingcap/log" @@ -100,7 +101,7 @@ func NewPrepareExec(ctx sessionctx.Context, is infoschema.InfoSchema, sqlTxt str } // Next implements the Executor Next interface. -func (e *PrepareExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error { vars := e.ctx.GetSessionVars() if e.ID != 0 { // Must be the case when we retry a prepare. @@ -167,6 +168,7 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.RecordBatch) error { } prepared := &ast.Prepared{ Stmt: stmt, + StmtType: GetStmtLabel(stmt), Params: sorter.markers, SchemaVersion: e.is.SchemaMetaVersion(), } @@ -179,7 +181,7 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.RecordBatch) error { param.InExecute = false } var p plannercore.Plan - p, err = plannercore.BuildLogicalPlan(e.ctx, stmt, e.is) + p, err = plannercore.BuildLogicalPlan(ctx, e.ctx, stmt, e.is) if err != nil { return err } @@ -201,17 +203,18 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.RecordBatch) error { type ExecuteExec struct { baseExecutor - is infoschema.InfoSchema - name string - usingVars []expression.Expression - id uint32 - stmtExec Executor - stmt ast.StmtNode - plan plannercore.Plan + is infoschema.InfoSchema + name string + usingVars []expression.Expression + id uint32 + stmtExec Executor + stmt ast.StmtNode + plan plannercore.Plan + lowerPriority bool } // Next implements the Executor Next interface. -func (e *ExecuteExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ExecuteExec) Next(ctx context.Context, req *chunk.Chunk) error { return nil } @@ -235,7 +238,7 @@ func (e *ExecuteExec) Build(b *executorBuilder) error { } e.stmtExec = stmtExec CountStmtNode(e.stmt, e.ctx.GetSessionVars().InRestrictedSQL) - logExpensiveQuery(e.stmt, e.plan) + e.lowerPriority = needLowerPriority(e.plan) return nil } @@ -247,7 +250,7 @@ type DeallocateExec struct { } // Next implements the Executor Next interface. -func (e *DeallocateExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *DeallocateExec) Next(ctx context.Context, req *chunk.Chunk) error { vars := e.ctx.GetSessionVars() id, ok := vars.PreparedStmtNameToID[e.Name] if !ok { @@ -264,17 +267,21 @@ func (e *DeallocateExec) Next(ctx context.Context, req *chunk.RecordBatch) error } // CompileExecutePreparedStmt compiles a session Execute command to a stmt.Statement. -func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...interface{}) (sqlexec.Statement, error) { +func CompileExecutePreparedStmt(ctx context.Context, sctx sessionctx.Context, ID uint32, args ...interface{}) (sqlexec.Statement, error) { + startTime := time.Now() + defer func() { + sctx.GetSessionVars().DurationCompile = time.Since(startTime) + }() execStmt := &ast.ExecuteStmt{ExecID: ID} - if err := ResetContextOfStmt(ctx, execStmt); err != nil { + if err := ResetContextOfStmt(sctx, execStmt); err != nil { return nil, err } execStmt.UsingVars = make([]ast.ExprNode, len(args)) for i, val := range args { execStmt.UsingVars[i] = ast.NewValueExpr(val) } - is := GetInfoSchema(ctx) - execPlan, err := planner.Optimize(ctx, execStmt, is) + is := GetInfoSchema(sctx) + execPlan, err := planner.Optimize(ctx, sctx, execStmt, is) if err != nil { return nil, err } @@ -283,11 +290,11 @@ func CompileExecutePreparedStmt(ctx sessionctx.Context, ID uint32, args ...inter InfoSchema: is, Plan: execPlan, StmtNode: execStmt, - Ctx: ctx, + Ctx: sctx, } - if prepared, ok := ctx.GetSessionVars().PreparedStmts[ID]; ok { + if prepared, ok := sctx.GetSessionVars().PreparedStmts[ID]; ok { stmt.Text = prepared.Stmt.Text() - ctx.GetSessionVars().StmtCtx.OriginalSQL = stmt.Text + sctx.GetSessionVars().StmtCtx.OriginalSQL = stmt.Text } return stmt, nil } diff --git a/executor/projection.go b/executor/projection.go index 01d3aff16b7b9..83da1a5857d1e 100644 --- a/executor/projection.go +++ b/executor/projection.go @@ -91,7 +91,7 @@ func (e *ProjectionExec) Open(ctx context.Context) error { } if e.isUnparallelExec() { - e.childResult = e.children[0].newFirstChunk() + e.childResult = newFirstChunk(e.children[0]) } return nil @@ -154,7 +154,7 @@ func (e *ProjectionExec) Open(ctx context.Context) error { // | | | | // +------------------------------+ +----------------------+ // -func (e *ProjectionExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ProjectionExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("projection.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -166,9 +166,9 @@ func (e *ProjectionExec) Next(ctx context.Context, req *chunk.RecordBatch) error } req.GrowAndReset(e.maxChunkSize) if e.isUnparallelExec() { - return e.unParallelExecute(ctx, req.Chunk) + return e.unParallelExecute(ctx, req) } - return e.parallelExecute(ctx, req.Chunk) + return e.parallelExecute(ctx, req) } @@ -179,7 +179,7 @@ func (e *ProjectionExec) isUnparallelExec() bool { func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk) error { // transmit the requiredRows e.childResult.SetRequiredRows(chk.RequiredRows(), e.maxChunkSize) - err := e.children[0].Next(ctx, chunk.NewRecordBatch(e.childResult)) + err := Next(ctx, e.children[0], e.childResult) if err != nil { return err } @@ -236,11 +236,11 @@ func (e *ProjectionExec) prepare(ctx context.Context) { }) e.fetcher.inputCh <- &projectionInput{ - chk: e.children[0].newFirstChunk(), + chk: newFirstChunk(e.children[0]), targetWorker: e.workers[i], } e.fetcher.outputCh <- &projectionOutput{ - chk: e.newFirstChunk(), + chk: newFirstChunk(e), done: make(chan error, 1), } } @@ -312,7 +312,7 @@ func (f *projectionInputFetcher) run(ctx context.Context) { requiredRows := atomic.LoadInt64(&f.proj.parentReqRows) input.chk.SetRequiredRows(int(requiredRows), f.proj.maxChunkSize) - err := f.child.Next(ctx, chunk.NewRecordBatch(input.chk)) + err := Next(ctx, f.child, input.chk) if err != nil || input.chk.NumRows() == 0 { output.done <- err return diff --git a/executor/radix_hash_join.go b/executor/radix_hash_join.go deleted file mode 100644 index cc0633f391bb8..0000000000000 --- a/executor/radix_hash_join.go +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - "math" - "sync" - "time" - "unsafe" - - "github.com/pingcap/errors" - "github.com/pingcap/tidb/util" - "github.com/pingcap/tidb/util/chunk" - "github.com/pingcap/tidb/util/logutil" - "github.com/pingcap/tidb/util/mvmap" - "github.com/spaolacci/murmur3" - "go.uber.org/zap" -) - -var ( - _ Executor = &RadixHashJoinExec{} -) - -// RadixHashJoinExec implements the radix partition-based hash join algorithm. -// It will partition the input relations into small pairs of partitions where -// one of the partitions typically fits into one of the caches. The overall goal -// of this method is to minimize the number of cache misses when building and -// probing hash tables. -type RadixHashJoinExec struct { - *HashJoinExec - - // radixBits indicates the bits using for radix partitioning. Inner relation - // will be split to 2^radixBitsNumber sub-relations before building the hash - // tables. If the complete inner relation can be hold in L2Cache in which - // case radixBits will be 1, we can skip the partition phase. - // Note: We actually check whether `size of sub inner relation < 3/4 * L2 - // cache size` to make sure one inner sub-relation, hash table, one outer - // sub-relation and join result of the sub-relations can be totally loaded - // in L2 cache size. `3/4` is a magic number, we may adjust it after - // benchmark. - radixBits uint32 - innerParts []partition - numNonEmptyPart int - // innerRowPrts indicates the position in corresponding partition of every - // row in innerResult. - innerRowPrts [][]partRowPtr - // hashTables stores the hash tables built from the inner relation, if there - // is no partition phase, a global hash table will be stored in - // hashTables[0]. - hashTables []*mvmap.MVMap -} - -// partition stores the sub-relations of inner relation and outer relation after -// partition phase. Every partition can be fully stored in L2 cache thus can -// reduce the cache miss ratio when building and probing the hash table. -type partition = *chunk.Chunk - -// partRowPtr stores the actual index in `innerParts` or `outerParts`. -type partRowPtr struct { - partitionIdx uint32 - rowIdx uint32 -} - -// partPtr4NullKey indicates a partition pointer which points to a row with null-join-key. -var partPtr4NullKey = partRowPtr{math.MaxUint32, math.MaxUint32} - -// Next implements the Executor Next interface. -// radix hash join constructs the result following these steps: -// step 1. fetch data from inner child -// step 2. parallel partition the inner relation into sub-relations and build an -// individual hash table for every partition -// step 3. fetch data from outer child in a background goroutine and partition -// it into sub-relations -// step 4. probe the corresponded sub-hash-table for every sub-outer-relation in -// multiple join workers -func (e *RadixHashJoinExec) Next(ctx context.Context, req *chunk.RecordBatch) (err error) { - if e.runtimeStats != nil { - start := time.Now() - defer func() { e.runtimeStats.Record(time.Now().Sub(start), req.NumRows()) }() - } - if !e.prepared { - e.innerFinished = make(chan error, 1) - go util.WithRecovery(func() { e.partitionInnerAndBuildHashTables(ctx) }, e.handleFetchInnerAndBuildHashTablePanic) - // TODO: parallel fetch outer rows, partition them and do parallel join - e.prepared = true - } - return <-e.innerFinished -} - -// partitionInnerRows re-order e.innerResults into sub-relations. -func (e *RadixHashJoinExec) partitionInnerRows() error { - e.evalRadixBit() - if err := e.preAlloc4InnerParts(); err != nil { - return err - } - - wg := sync.WaitGroup{} - defer wg.Wait() - wg.Add(int(e.concurrency)) - for i := 0; i < int(e.concurrency); i++ { - workerID := i - go util.WithRecovery(func() { - defer wg.Done() - e.doInnerPartition(workerID) - }, e.handlePartitionPanic) - } - return nil -} - -func (e *RadixHashJoinExec) handlePartitionPanic(r interface{}) { - if r != nil { - e.joinResultCh <- &hashjoinWorkerResult{err: errors.Errorf("%v", r)} - } -} - -// doInnerPartition runs concurrently, partitions and copies the inner relation -// to several pre-allocated data partitions. The input inner Chunk idx for each -// partitioner is workerId + x*numPartitioners. -func (e *RadixHashJoinExec) doInnerPartition(workerID int) { - chkIdx, chkNum := workerID, e.innerResult.NumChunks() - for ; chkIdx < chkNum; chkIdx += int(e.concurrency) { - if e.finished.Load().(bool) { - return - } - chk := e.innerResult.GetChunk(chkIdx) - for srcRowIdx, partPtr := range e.innerRowPrts[chkIdx] { - if partPtr == partPtr4NullKey { - continue - } - partIdx, destRowIdx := partPtr.partitionIdx, partPtr.rowIdx - part := e.innerParts[partIdx] - part.Insert(int(destRowIdx), chk.GetRow(srcRowIdx)) - } - } -} - -// preAlloc4InnerParts evaluates partRowPtr and pre-alloc the memory space -// for every inner row to help re-order the inner relation. -// TODO: we need to evaluate the skewness for the partitions size, if the -// skewness exceeds a threshold, we do not use partition phase. -func (e *RadixHashJoinExec) preAlloc4InnerParts() (err error) { - var hasNull bool - keyBuf := make([]byte, 0, 64) - for chkIdx, chkNum := 0, e.innerResult.NumChunks(); chkIdx < chkNum; chkIdx++ { - chk := e.innerResult.GetChunk(chkIdx) - partPtrs := make([]partRowPtr, chk.NumRows()) - for rowIdx := 0; rowIdx < chk.NumRows(); rowIdx++ { - row := chk.GetRow(rowIdx) - hasNull, keyBuf, err = e.getJoinKeyFromChkRow(false, row, keyBuf) - if err != nil { - return err - } - if hasNull { - partPtrs[rowIdx] = partPtr4NullKey - continue - } - joinHash := murmur3.Sum32(keyBuf) - partIdx := e.radixBits & joinHash - partPtrs[rowIdx].partitionIdx = partIdx - partPtrs[rowIdx].rowIdx = e.getPartition(partIdx).PreAlloc(row) - } - e.innerRowPrts = append(e.innerRowPrts, partPtrs) - } - if e.numNonEmptyPart < len(e.innerParts) { - numTotalPart := len(e.innerParts) - numEmptyPart := numTotalPart - e.numNonEmptyPart - logutil.Logger(context.Background()).Debug("empty partition in radix hash join", zap.Uint64("txnStartTS", e.ctx.GetSessionVars().TxnCtx.StartTS), - zap.Int("numEmptyParts", numEmptyPart), zap.Int("numTotalParts", numTotalPart), - zap.Float64("emptyRatio", float64(numEmptyPart)/float64(numTotalPart))) - } - return -} - -func (e *RadixHashJoinExec) getPartition(idx uint32) partition { - if e.innerParts[idx] == nil { - e.numNonEmptyPart++ - e.innerParts[idx] = chunk.New(e.innerExec.retTypes(), e.initCap, e.maxChunkSize) - } - return e.innerParts[idx] -} - -// evalRadixBit evaluates the radix bit numbers. -// To ensure that one partition of inner relation, one hash table, one partition -// of outer relation and the join result of these two partitions fit into the L2 -// cache when the input data obeys the uniform distribution, we suppose every -// sub-partition of inner relation using three quarters of the L2 cache size. -func (e *RadixHashJoinExec) evalRadixBit() { - sv := e.ctx.GetSessionVars() - innerResultSize := float64(e.innerResult.GetMemTracker().BytesConsumed()) - l2CacheSize := float64(sv.L2CacheSize) * 3 / 4 - radixBitsNum := math.Ceil(math.Log2(innerResultSize / l2CacheSize)) - if radixBitsNum <= 0 { - radixBitsNum = 1 - } - // Take the rightmost radixBitsNum bits as the bitmask. - e.radixBits = ^(math.MaxUint32 << uint(radixBitsNum)) - e.innerParts = make([]partition, 1<>": ast.RightShift, + "or": ast.LogicOr, + ">=": ast.GE, + "<=": ast.LE, + "=": ast.EQ, + "!=": ast.NE, + "<>": ast.NE, + "<": ast.LT, + ">": ast.GT, + "+": ast.Plus, + "-": ast.Minus, + "&&": ast.And, + "||": ast.Or, + "%": ast.Mod, + "xor_bit": ast.Xor, + "/": ast.Div, + "*": ast.Mul, + "!": ast.UnaryNot, + "~": ast.BitNeg, + "div": ast.IntDiv, + "xor_logic": ast.LogicXor, // Avoid name conflict with "xor_bit"., + "<=>": ast.NullEQ, + "+_unary": ast.UnaryPlus, // Avoid name conflict with `plus`., + "-_unary": ast.UnaryMinus, + "in": ast.In, + "like": ast.Like, + "case": ast.Case, + "regexp": ast.Regexp, + "is null": ast.IsNull, + "is true": ast.IsTruth, + "is false": ast.IsFalsity, + "values": ast.Values, + "bit_count": ast.BitCount, + "coalesce": ast.Coalesce, + "greatest": ast.Greatest, + "least": ast.Least, + "interval": ast.Interval, + "abs": ast.Abs, + "acos": ast.Acos, + "asin": ast.Asin, + "atan": ast.Atan, + "atan2": ast.Atan2, + "ceil": ast.Ceil, + "ceiling": ast.Ceiling, + "conv": ast.Conv, + "cos": ast.Cos, + "cot": ast.Cot, + "crc32": ast.CRC32, + "degrees": ast.Degrees, + "exp": ast.Exp, + "floor": ast.Floor, + "ln": ast.Ln, + "log": ast.Log, + "log2": ast.Log2, + "log10": ast.Log10, + "pi": ast.PI, + "pow": ast.Pow, + "power": ast.Power, + "radians": ast.Radians, + "rand": ast.Rand, + "round": ast.Round, + "sign": ast.Sign, + "sin": ast.Sin, + "sqrt": ast.Sqrt, + "tan": ast.Tan, + "truncate": ast.Truncate, + "adddate": ast.AddDate, + "addtime": ast.AddTime, + "convert_tz": ast.ConvertTz, + "curdate": ast.Curdate, + "current_date": ast.CurrentDate, + "current_time": ast.CurrentTime, + "current_timestamp": ast.CurrentTimestamp, + "curtime": ast.Curtime, + "date": ast.Date, + "date_add": ast.DateAdd, + "date_format": ast.DateFormat, + "date_sub": ast.DateSub, + "datediff": ast.DateDiff, + "day": ast.Day, + "dayname": ast.DayName, + "dayofmonth": ast.DayOfMonth, + "dayofweek": ast.DayOfWeek, + "dayofyear": ast.DayOfYear, + "extract": ast.Extract, + "from_days": ast.FromDays, + "from_unixtime": ast.FromUnixTime, + "get_format": ast.GetFormat, + "hour": ast.Hour, + "localtime": ast.LocalTime, + "localtimestamp": ast.LocalTimestamp, + "makedate": ast.MakeDate, + "maketime": ast.MakeTime, + "microsecond": ast.MicroSecond, + "minute": ast.Minute, + "month": ast.Month, + "monthname": ast.MonthName, + "now": ast.Now, + "period_add": ast.PeriodAdd, + "period_diff": ast.PeriodDiff, + "quarter": ast.Quarter, + "sec_to_time": ast.SecToTime, + "second": ast.Second, + "str_to_date": ast.StrToDate, + "subdate": ast.SubDate, + "subtime": ast.SubTime, + "sysdate": ast.Sysdate, + "time": ast.Time, + "time_format": ast.TimeFormat, + "time_to_sec": ast.TimeToSec, + "timediff": ast.TimeDiff, + "timestamp": ast.Timestamp, + "timestampadd": ast.TimestampAdd, + "timestampdiff": ast.TimestampDiff, + "to_days": ast.ToDays, + "to_seconds": ast.ToSeconds, + "unix_timestamp": ast.UnixTimestamp, + "utc_date": ast.UTCDate, + "utc_time": ast.UTCTime, + "utc_timestamp": ast.UTCTimestamp, + "week": ast.Week, + "weekday": ast.Weekday, + "weekofyear": ast.WeekOfYear, + "year": ast.Year, + "yearweek": ast.YearWeek, + "last_day": ast.LastDay, + "ascii": ast.ASCII, + "bin": ast.Bin, + "concat": ast.Concat, + "concat_ws": ast.ConcatWS, + "convert": ast.Convert, + "elt": ast.Elt, + "export_set": ast.ExportSet, + "field": ast.Field, + "format": ast.Format, + "from_base64": ast.FromBase64, + "insert_func": ast.InsertFunc, + "instr": ast.Instr, + "lcase": ast.Lcase, + "left": ast.Left, + "length": ast.Length, + "load_file": ast.LoadFile, + "locate": ast.Locate, + "lower": ast.Lower, + "lpad": ast.Lpad, + "ltrim": ast.LTrim, + "make_set": ast.MakeSet, + "mid": ast.Mid, + "oct": ast.Oct, + "ord": ast.Ord, + "position": ast.Position, + "quote": ast.Quote, + "repeat": ast.Repeat, + "replace": ast.Replace, + "reverse": ast.Reverse, + "right": ast.Right, + "rtrim": ast.RTrim, + "space": ast.Space, + "strcmp": ast.Strcmp, + "substring": ast.Substring, + "substr": ast.Substr, + "substring_index": ast.SubstringIndex, + "to_base64": ast.ToBase64, + "trim": ast.Trim, + "upper": ast.Upper, + "ucase": ast.Ucase, + "hex": ast.Hex, + "unhex": ast.Unhex, + "rpad": ast.Rpad, + "bit_length": ast.BitLength, + "char_func": ast.CharFunc, + "char_length": ast.CharLength, + "character_length": ast.CharacterLength, + "find_in_set": ast.FindInSet, + "benchmark": ast.Benchmark, + "charset": ast.Charset, + "coercibility": ast.Coercibility, + "collation": ast.Collation, + "connection_id": ast.ConnectionID, + "current_user": ast.CurrentUser, + "current_role": ast.CurrentRole, + "database": ast.Database, + "found_rows": ast.FoundRows, + "last_insert_id": ast.LastInsertId, + "row_count": ast.RowCount, + "schema": ast.Schema, + "session_user": ast.SessionUser, + "system_user": ast.SystemUser, + "user": ast.User, + "if": ast.If, + "ifnull": ast.Ifnull, + "nullif": ast.Nullif, + "any_value": ast.AnyValue, + "default_func": ast.DefaultFunc, + "inet_aton": ast.InetAton, + "inet_ntoa": ast.InetNtoa, + "inet6_aton": ast.Inet6Aton, + "inet6_ntoa": ast.Inet6Ntoa, + "is_free_lock": ast.IsFreeLock, + "is_ipv4": ast.IsIPv4, + "is_ipv4_compat": ast.IsIPv4Compat, + "is_ipv4_mapped": ast.IsIPv4Mapped, + "is_ipv6": ast.IsIPv6, + "is_used_lock": ast.IsUsedLock, + "master_pos_wait": ast.MasterPosWait, + "name_const": ast.NameConst, + "release_all_locks": ast.ReleaseAllLocks, + "sleep": ast.Sleep, + "uuid": ast.UUID, + "uuid_short": ast.UUIDShort, + "get_lock": ast.GetLock, + "release_lock": ast.ReleaseLock, + "aes_decrypt": ast.AesDecrypt, + "aes_encrypt": ast.AesEncrypt, + "compress": ast.Compress, + "decode": ast.Decode, + "des_decrypt": ast.DesDecrypt, + "des_encrypt": ast.DesEncrypt, + "encode": ast.Encode, + "encrypt": ast.Encrypt, + "md5": ast.MD5, + "old_password": ast.OldPassword, + "password_func": ast.PasswordFunc, + "random_bytes": ast.RandomBytes, + "sha1": ast.SHA1, + "sha": ast.SHA, + "sha2": ast.SHA2, + "uncompress": ast.Uncompress, + "uncompressed_length": ast.UncompressedLength, + "validate_password_strength": ast.ValidatePasswordStrength, + "json_type": ast.JSONType, + "json_extract": ast.JSONExtract, + "json_unquote": ast.JSONUnquote, + "json_array": ast.JSONArray, + "json_object": ast.JSONObject, + "json_merge": ast.JSONMerge, + "json_set": ast.JSONSet, + "json_insert": ast.JSONInsert, + "json_replace": ast.JSONReplace, + "json_remove": ast.JSONRemove, + "json_contains": ast.JSONContains, + "json_contains_path": ast.JSONContainsPath, + "json_valid": ast.JSONValid, + "json_array_append": ast.JSONArrayAppend, + "json_array_insert": ast.JSONArrayInsert, + "json_merge_patch": ast.JSONMergePatch, + "json_merge_preserve": ast.JSONMergePreserve, + "json_pretty": ast.JSONPretty, + "json_quote": ast.JSONQuote, + "json_search": ast.JSONSearch, + "json_storage_size": ast.JSONStorageSize, + "json_depth": ast.JSONDepth, + "json_keys": ast.JSONKeys, + "json_length": ast.JSONLength, +} diff --git a/executor/replace.go b/executor/replace.go index 26d3875eddb4e..352c6b85156ec 100644 --- a/executor/replace.go +++ b/executor/replace.go @@ -182,7 +182,7 @@ func (e *ReplaceExec) exec(ctx context.Context, newRows [][]types.Datum) error { } // Next implements the Executor Next interface. -func (e *ReplaceExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ReplaceExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("replace.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() diff --git a/executor/revoke.go b/executor/revoke.go index ada1a06f3dbeb..4a5c5f97ef76f 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -51,14 +51,14 @@ type RevokeExec struct { } // Next implements the Executor Next interface. -func (e *RevokeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { if e.done { return nil } e.done = true // Revoke for each user. - for _, user := range e.Users { + for idx, user := range e.Users { // Check if user exists. exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname) if err != nil { @@ -68,6 +68,13 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.RecordBatch) error { return errors.Errorf("Unknown user: %s", user.User) } + if idx == 0 { + // Commit the old transaction, like DDL. + if err := e.ctx.NewTxn(ctx); err != nil { + return err + } + defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() + } err = e.revokeOneUser(user.User.Username, user.User.Hostname) if err != nil { return err diff --git a/executor/seqtest/prepared_test.go b/executor/seqtest/prepared_test.go index 46f4f64a5a41d..f6140b074d09f 100644 --- a/executor/seqtest/prepared_test.go +++ b/executor/seqtest/prepared_test.go @@ -130,17 +130,17 @@ func (s *seqTestSuite) TestPrepared(c *C) { tk.ResultSetToResult(rs, Commentf("%v", rs)).Check(testkit.Rows()) // Check that ast.Statement created by executor.CompileExecutePreparedStmt has query text. - stmt, err := executor.CompileExecutePreparedStmt(tk.Se, stmtID, 1) + stmt, err := executor.CompileExecutePreparedStmt(context.TODO(), tk.Se, stmtID, 1) c.Assert(err, IsNil) c.Assert(stmt.OriginText(), Equals, query) // Check that rebuild plan works. tk.Se.PrepareTxnCtx(ctx) - _, err = stmt.RebuildPlan() + _, err = stmt.RebuildPlan(ctx) c.Assert(err, IsNil) rs, err = stmt.Exec(ctx) c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(ctx, req) c.Assert(err, IsNil) c.Assert(rs.Close(), IsNil) diff --git a/executor/seqtest/seq_executor_test.go b/executor/seqtest/seq_executor_test.go index d739b9657d2b7..631c5ecea3f19 100644 --- a/executor/seqtest/seq_executor_test.go +++ b/executor/seqtest/seq_executor_test.go @@ -25,6 +25,7 @@ import ( "runtime/pprof" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -35,6 +36,7 @@ import ( pb "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/parser" "github.com/pingcap/parser/model" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/executor" @@ -62,6 +64,7 @@ func TestT(t *testing.T) { } var _ = Suite(&seqTestSuite{}) +var _ = Suite(&seqTestSuite1{}) type seqTestSuite struct { cluster *mocktikv.Cluster @@ -89,7 +92,7 @@ func (s *seqTestSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() } d, err := session.BootstrapSession(s.store) c.Assert(err, IsNil) @@ -129,7 +132,7 @@ func (s *seqTestSuite) TestEarlyClose(c *C) { rss, err1 := tk.Se.Execute(ctx, "select * from earlyclose order by id") c.Assert(err1, IsNil) rs := rss[0] - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(ctx, req) c.Assert(err, IsNil) rs.Close() @@ -143,7 +146,7 @@ func (s *seqTestSuite) TestEarlyClose(c *C) { rss, err := tk.Se.Execute(ctx, "select * from earlyclose") c.Assert(err, IsNil) rs := rss[0] - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(ctx, req) c.Assert(err, NotNil) rs.Close() @@ -585,7 +588,7 @@ func (s *seqTestSuite) TestShow(c *C) { "c4|varchar(6)|YES||1|", "c5|varchar(6)|YES||'C6'|", "c6|enum('s','m','l','xl')|YES||xl|", - "c7|set('a','b','c','d')|YES||a,c,c|", + "c7|set('a','b','c','d')|YES||a,c|", "c8|datetime|YES||CURRENT_TIMESTAMP|DEFAULT_GENERATED on update CURRENT_TIMESTAMP", "c9|year(4)|YES||2014|", )) @@ -639,7 +642,7 @@ func (s *seqTestSuite) TestIndexDoubleReadClose(c *C) { rs, err := tk.Exec("select * from dist where c_idx between 0 and 100") c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(context.Background(), req) c.Assert(err, IsNil) c.Assert(err, IsNil) @@ -671,7 +674,7 @@ func (s *seqTestSuite) TestParallelHashAggClose(c *C) { rss, err := tk.Se.Execute(ctx, "select sum(a) from (select cast(t.a as signed) as a, b from t) t group by b;") c.Assert(err, IsNil) rs := rss[0] - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(ctx, req) c.Assert(err.Error(), Equals, "HashAggExec.parallelExec error") } @@ -692,7 +695,7 @@ func (s *seqTestSuite) TestUnparallelHashAggClose(c *C) { rss, err := tk.Se.Execute(ctx, "select sum(distinct a) from (select cast(t.a as signed) as a, b from t) t group by b;") c.Assert(err, IsNil) rs := rss[0] - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(ctx, req) c.Assert(err.Error(), Equals, "HashAggExec.unparallelExec error") } @@ -706,6 +709,10 @@ func checkGoroutineExists(keyword string) bool { } func (s *seqTestSuite) TestAdminShowNextID(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange"), IsNil) + }() step := int64(10) autoIDStep := autoid.GetStep() autoid.SetStep(step) @@ -787,9 +794,130 @@ func (s *seqTestSuite) TestCartesianProduct(c *C) { plannercore.AllowCartesianProduct.Store(true) } +func (s *seqTestSuite) TestBatchInsertDelete(c *C) { + originLimit := atomic.LoadUint64(&kv.TxnEntryCountLimit) + defer func() { + atomic.StoreUint64(&kv.TxnEntryCountLimit, originLimit) + }() + // Set the limitation to a small value, make it easier to reach the limitation. + atomic.StoreUint64(&kv.TxnEntryCountLimit, 100) + + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists batch_insert") + tk.MustExec("create table batch_insert (c int)") + tk.MustExec("drop table if exists batch_insert_on_duplicate") + tk.MustExec("create table batch_insert_on_duplicate (id int primary key, c int)") + // Insert 10 rows. + tk.MustExec("insert into batch_insert values (1),(1),(1),(1),(1),(1),(1),(1),(1),(1)") + r := tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("10")) + // Insert 10 rows. + tk.MustExec("insert into batch_insert (c) select * from batch_insert;") + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("20")) + // Insert 20 rows. + tk.MustExec("insert into batch_insert (c) select * from batch_insert;") + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("40")) + // Insert 40 rows. + tk.MustExec("insert into batch_insert (c) select * from batch_insert;") + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("80")) + // Insert 80 rows. + tk.MustExec("insert into batch_insert (c) select * from batch_insert;") + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("160")) + // for on duplicate key + for i := 0; i < 160; i++ { + tk.MustExec(fmt.Sprintf("insert into batch_insert_on_duplicate values(%d, %d);", i, i)) + } + r = tk.MustQuery("select count(*) from batch_insert_on_duplicate;") + r.Check(testkit.Rows("160")) + + // This will meet txn too large error. + _, err := tk.Exec("insert into batch_insert (c) select * from batch_insert;") + c.Assert(err, NotNil) + c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue) + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("160")) + + // for on duplicate key + _, err = tk.Exec(`insert into batch_insert_on_duplicate select * from batch_insert_on_duplicate as tt + on duplicate key update batch_insert_on_duplicate.id=batch_insert_on_duplicate.id+1000;`) + c.Assert(err, NotNil) + c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue, Commentf("%v", err)) + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("160")) + + // Change to batch inset mode and batch size to 50. + tk.MustExec("set @@session.tidb_batch_insert=1;") + tk.MustExec("set @@session.tidb_dml_batch_size=50;") + tk.MustExec("insert into batch_insert (c) select * from batch_insert;") + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("320")) + + // Enlarge the batch size to 150 which is larger than the txn limitation (100). + // So the insert will meet error. + tk.MustExec("set @@session.tidb_dml_batch_size=150;") + _, err = tk.Exec("insert into batch_insert (c) select * from batch_insert;") + c.Assert(err, NotNil) + c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue) + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("320")) + // Set it back to 50. + tk.MustExec("set @@session.tidb_dml_batch_size=50;") + + // for on duplicate key + _, err = tk.Exec(`insert into batch_insert_on_duplicate select * from batch_insert_on_duplicate as tt + on duplicate key update batch_insert_on_duplicate.id=batch_insert_on_duplicate.id+1000;`) + c.Assert(err, IsNil) + r = tk.MustQuery("select count(*) from batch_insert_on_duplicate;") + r.Check(testkit.Rows("160")) + + // Disable BachInsert mode in transition. + tk.MustExec("begin;") + _, err = tk.Exec("insert into batch_insert (c) select * from batch_insert;") + c.Assert(err, NotNil) + c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue) + tk.MustExec("rollback;") + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("320")) + + tk.MustExec("drop table if exists com_batch_insert") + tk.MustExec("create table com_batch_insert (c int)") + sql := "insert into com_batch_insert values " + values := make([]string, 0, 200) + for i := 0; i < 200; i++ { + values = append(values, "(1)") + } + sql = sql + strings.Join(values, ",") + tk.MustExec(sql) + tk.MustQuery("select count(*) from com_batch_insert;").Check(testkit.Rows("200")) + + // Test case for batch delete. + // This will meet txn too large error. + _, err = tk.Exec("delete from batch_insert;") + c.Assert(err, NotNil) + c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue) + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("320")) + // Enable batch delete and set batch size to 50. + tk.MustExec("set @@session.tidb_batch_delete=on;") + tk.MustExec("set @@session.tidb_dml_batch_size=50;") + tk.MustExec("delete from batch_insert;") + // Make sure that all rows are gone. + r = tk.MustQuery("select count(*) from batch_insert;") + r.Check(testkit.Rows("0")) +} + type checkPrioClient struct { tikv.Client priority pb.CommandPri + mu struct { + sync.RWMutex + checkPrio bool + } } func (c *checkPrioClient) setCheckPriority(priority pb.CommandPri) { @@ -802,10 +930,16 @@ func (c *checkPrioClient) getCheckPriority() pb.CommandPri { func (c *checkPrioClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { resp, err := c.Client.SendRequest(ctx, addr, req, timeout) - switch req.Type { - case tikvrpc.CmdCop: - if c.getCheckPriority() != req.Priority { - return nil, errors.New("fail to set priority") + c.mu.RLock() + defer func() { + c.mu.RUnlock() + }() + if c.mu.checkPrio { + switch req.Type { + case tikvrpc.CmdCop: + if c.getCheckPriority() != req.Priority { + return nil, errors.New("fail to set priority") + } } } return resp, err @@ -854,6 +988,10 @@ func (s *seqTestSuite1) TestCoprocessorPriority(c *C) { } cli := s.cli + cli.mu.Lock() + cli.mu.checkPrio = true + cli.mu.Unlock() + cli.setCheckPriority(pb.CommandPri_High) tk.MustQuery("select id from t where id = 1") tk.MustQuery("select * from t1 where id = 1") @@ -889,4 +1027,55 @@ func (s *seqTestSuite1) TestCoprocessorPriority(c *C) { cli.setCheckPriority(pb.CommandPri_Low) tk.MustQuery("select LOW_PRIORITY id from t where id = 1") + + cli.mu.Lock() + cli.mu.checkPrio = false + cli.mu.Unlock() +} + +func (s *seqTestSuite) TestAutoIDInRetry(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t (id int not null auto_increment primary key)") + + tk.MustExec("set @@tidb_disable_txn_auto_retry = 0") + tk.MustExec("begin") + tk.MustExec("insert into t values ()") + tk.MustExec("insert into t values (),()") + tk.MustExec("insert into t values ()") + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/session/mockCommitRetryForAutoID", `return(true)`), IsNil) + tk.MustExec("commit") + c.Assert(failpoint.Disable("github.com/pingcap/tidb/session/mockCommitRetryForAutoID"), IsNil) + + tk.MustExec("insert into t values ()") + tk.MustQuery(`select * from t`).Check(testkit.Rows("1", "2", "3", "4", "5")) +} + +func (s *seqTestSuite) TestMaxDeltaSchemaCount(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + c.Assert(variable.GetMaxDeltaSchemaCount(), Equals, int64(variable.DefTiDBMaxDeltaSchemaCount)) + gvc := domain.GetDomain(tk.Se).GetGlobalVarsCache() + gvc.Disable() + + tk.MustExec("set @@global.tidb_max_delta_schema_count= -1") + tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect tidb_max_delta_schema_count value: '-1'")) + // Make sure a new session will load global variables. + tk.Se = nil + tk.MustExec("use test") + c.Assert(variable.GetMaxDeltaSchemaCount(), Equals, int64(100)) + tk.MustExec(fmt.Sprintf("set @@global.tidb_max_delta_schema_count= %v", uint64(math.MaxInt64))) + tk.MustQuery("show warnings;").Check(testkit.Rows(fmt.Sprintf("Warning 1292 Truncated incorrect tidb_max_delta_schema_count value: '%d'", uint64(math.MaxInt64)))) + tk.Se = nil + tk.MustExec("use test") + c.Assert(variable.GetMaxDeltaSchemaCount(), Equals, int64(16384)) + _, err := tk.Exec("set @@global.tidb_max_delta_schema_count= invalid_val") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err)) + + tk.MustExec("set @@global.tidb_max_delta_schema_count= 2048") + tk.Se = nil + tk.MustExec("use test") + c.Assert(variable.GetMaxDeltaSchemaCount(), Equals, int64(2048)) + tk.MustQuery("select @@global.tidb_max_delta_schema_count").Check(testkit.Rows("2048")) } diff --git a/executor/set.go b/executor/set.go index 1867b2b8d09a2..3400e62257a9d 100644 --- a/executor/set.go +++ b/executor/set.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/gcutil" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/stmtsummary" "go.uber.org/zap" ) @@ -42,7 +43,7 @@ type SetExecutor struct { } // Next implements the Executor Next interface. -func (e *SetExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *SetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.done { return nil @@ -119,6 +120,7 @@ func (e *SetExecutor) setSysVariable(name string, v *expression.VarAssignment) e if sysVar.Scope == variable.ScopeNone { return errors.Errorf("Variable '%s' is a read only variable", name) } + var valStr string if v.IsGlobal { // Set global scope system variable. if sysVar.Scope&variable.ScopeGlobal == 0 { @@ -131,18 +133,18 @@ func (e *SetExecutor) setSysVariable(name string, v *expression.VarAssignment) e if value.IsNull() { value.SetString("") } - svalue, err := value.ToString() + valStr, err = value.ToString() if err != nil { return err } - err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(name, svalue) + err = sessionVars.GlobalVarsAccessor.SetGlobalSysVar(name, valStr) if err != nil { return err } err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { auditPlugin := plugin.DeclareAuditManifest(p.Manifest) if auditPlugin.OnGlobalVariableEvent != nil { - auditPlugin.OnGlobalVariableEvent(context.Background(), e.ctx.GetSessionVars(), name, svalue) + auditPlugin.OnGlobalVariableEvent(context.Background(), e.ctx.GetSessionVars(), name, valStr) } return nil }) @@ -179,7 +181,6 @@ func (e *SetExecutor) setSysVariable(name string, v *expression.VarAssignment) e sessionVars.SnapshotTS = oldSnapshotTS return err } - var valStr string if value.IsNull() { valStr = "NULL" } else { @@ -187,7 +188,17 @@ func (e *SetExecutor) setSysVariable(name string, v *expression.VarAssignment) e valStr, err = value.ToString() terror.Log(err) } - logutil.Logger(context.Background()).Info("set session var", zap.Uint64("conn", sessionVars.ConnectionID), zap.String("name", name), zap.String("val", valStr)) + if name != variable.AutoCommit { + logutil.Logger(context.Background()).Info("set session var", zap.Uint64("conn", sessionVars.ConnectionID), zap.String("name", name), zap.String("val", valStr)) + } else { + // Some applications will set `autocommit` variable before query. + // This will print too many unnecessary log info. + logutil.Logger(context.Background()).Debug("set session var", zap.Uint64("conn", sessionVars.ConnectionID), zap.String("name", name), zap.String("val", valStr)) + } + } + + if name == variable.TiDBEnableStmtSummary { + stmtsummary.StmtSummaryByDigestMap.SetEnabled(valStr, !v.IsGlobal) } return nil diff --git a/executor/set_test.go b/executor/set_test.go index 4f55bab2bd553..8f0f1d78cc03a 100644 --- a/executor/set_test.go +++ b/executor/set_test.go @@ -15,6 +15,7 @@ package executor_test import ( "context" + "strconv" . "github.com/pingcap/check" "github.com/pingcap/parser/terror" @@ -340,23 +341,51 @@ func (s *testSuite2) TestSetVar(c *C) { _, err = tk.Exec("set global read_only = abc") c.Assert(err, NotNil) - // test for tidb_wait_table_split_finish - tk.MustQuery(`select @@session.tidb_wait_table_split_finish;`).Check(testkit.Rows("0")) - tk.MustExec("set tidb_wait_table_split_finish = 1") - tk.MustQuery(`select @@session.tidb_wait_table_split_finish;`).Check(testkit.Rows("1")) - tk.MustExec("set tidb_wait_table_split_finish = 0") - tk.MustQuery(`select @@session.tidb_wait_table_split_finish;`).Check(testkit.Rows("0")) - - tk.MustExec("set session tidb_back_off_weight = 3") - tk.MustQuery("select @@session.tidb_back_off_weight;").Check(testkit.Rows("3")) - tk.MustExec("set session tidb_back_off_weight = 20") - tk.MustQuery("select @@session.tidb_back_off_weight;").Check(testkit.Rows("20")) - _, err = tk.Exec("set session tidb_back_off_weight = -1") + // test for tidb_wait_split_region_finish + tk.MustQuery(`select @@session.tidb_wait_split_region_finish;`).Check(testkit.Rows("1")) + tk.MustExec("set tidb_wait_split_region_finish = 1") + tk.MustQuery(`select @@session.tidb_wait_split_region_finish;`).Check(testkit.Rows("1")) + tk.MustExec("set tidb_wait_split_region_finish = 0") + tk.MustQuery(`select @@session.tidb_wait_split_region_finish;`).Check(testkit.Rows("0")) + + // test for tidb_scatter_region + tk.MustQuery(`select @@global.tidb_scatter_region;`).Check(testkit.Rows("0")) + tk.MustExec("set global tidb_scatter_region = 1") + tk.MustQuery(`select @@global.tidb_scatter_region;`).Check(testkit.Rows("1")) + tk.MustExec("set global tidb_scatter_region = 0") + tk.MustQuery(`select @@global.tidb_scatter_region;`).Check(testkit.Rows("0")) + _, err = tk.Exec("set session tidb_scatter_region = 0") c.Assert(err, NotNil) - _, err = tk.Exec("set global tidb_back_off_weight = 0") + _, err = tk.Exec(`select @@session.tidb_scatter_region;`) c.Assert(err, NotNil) - tk.MustExec("set global tidb_back_off_weight = 10") - tk.MustQuery("select @@global.tidb_back_off_weight;").Check(testkit.Rows("10")) + + // test for tidb_wait_split_region_timeout + tk.MustQuery(`select @@session.tidb_wait_split_region_timeout;`).Check(testkit.Rows(strconv.Itoa(variable.DefWaitSplitRegionTimeout))) + tk.MustExec("set tidb_wait_split_region_timeout = 1") + tk.MustQuery(`select @@session.tidb_wait_split_region_timeout;`).Check(testkit.Rows("1")) + _, err = tk.Exec("set tidb_wait_split_region_timeout = 0") + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "tidb_wait_split_region_timeout(0) cannot be smaller than 1") + tk.MustQuery(`select @@session.tidb_wait_split_region_timeout;`).Check(testkit.Rows("1")) + + tk.MustExec("set session tidb_backoff_weight = 3") + tk.MustQuery("select @@session.tidb_backoff_weight;").Check(testkit.Rows("3")) + tk.MustExec("set session tidb_backoff_weight = 20") + tk.MustQuery("select @@session.tidb_backoff_weight;").Check(testkit.Rows("20")) + _, err = tk.Exec("set session tidb_backoff_weight = -1") + c.Assert(err, NotNil) + _, err = tk.Exec("set global tidb_backoff_weight = 0") + c.Assert(err, NotNil) + tk.MustExec("set global tidb_backoff_weight = 10") + tk.MustQuery("select @@global.tidb_backoff_weight;").Check(testkit.Rows("10")) + + tk.MustExec("set @@tidb_expensive_query_time_threshold=70") + tk.MustQuery("select @@tidb_expensive_query_time_threshold;").Check(testkit.Rows("70")) + + tk.MustExec("set @@tidb_record_plan_in_slow_log = 1") + tk.MustQuery("select @@tidb_record_plan_in_slow_log;").Check(testkit.Rows("1")) + tk.MustExec("set @@tidb_record_plan_in_slow_log = 0") + tk.MustQuery("select @@tidb_record_plan_in_slow_log;").Check(testkit.Rows("0")) } func (s *testSuite2) TestSetCharset(c *C) { @@ -460,6 +489,19 @@ func (s *testSuite2) TestValidateSetVar(c *C) { _, err = tk.Exec("set @@global.max_connections='hello'") c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + tk.MustExec("set @@global.thread_pool_size=65") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect thread_pool_size value: '65'")) + result = tk.MustQuery("select @@global.thread_pool_size;") + result.Check(testkit.Rows("64")) + + tk.MustExec("set @@global.thread_pool_size=-1") + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect thread_pool_size value: '-1'")) + result = tk.MustQuery("select @@global.thread_pool_size;") + result.Check(testkit.Rows("1")) + + _, err = tk.Exec("set @@global.thread_pool_size='hello'") + c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue) + tk.MustExec("set @@global.max_allowed_packet=-1") tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1292|Truncated incorrect max_allowed_packet value: '-1'")) result = tk.MustQuery("select @@global.max_allowed_packet;") diff --git a/executor/show.go b/executor/show.go index e6e01a2fbcbfc..fb725af202d94 100644 --- a/executor/show.go +++ b/executor/show.go @@ -37,11 +37,13 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/store/tikv" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" @@ -60,9 +62,10 @@ type ShowExec struct { DBName model.CIStr Table *ast.TableName // Used for showing columns. Column *ast.ColumnName // Used for `desc table column`. + IndexName model.CIStr // Used for show table regions. Flag int // Some flag parsed from sql, such as FULL. Full bool - User *auth.UserIdentity // Used for show grants. + User *auth.UserIdentity // Used by show grants, show create user. Roles []*auth.RoleIdentity // Used for show grants. IfNotExists bool // Used for `show create database if not exists` @@ -76,11 +79,11 @@ type ShowExec struct { } // Next implements the Executor Next interface. -func (e *ShowExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *ShowExec) Next(ctx context.Context, req *chunk.Chunk) error { req.GrowAndReset(e.maxChunkSize) if e.result == nil { - e.result = e.newFirstChunk() - err := e.fetchAll() + e.result = newFirstChunk(e) + err := e.fetchAll(ctx) if err != nil { return errors.Trace(err) } @@ -106,14 +109,14 @@ func (e *ShowExec) Next(ctx context.Context, req *chunk.RecordBatch) error { return nil } -func (e *ShowExec) fetchAll() error { +func (e *ShowExec) fetchAll(ctx context.Context) error { switch e.Tp { case ast.ShowCharset: return e.fetchShowCharset() case ast.ShowCollation: return e.fetchShowCollation() case ast.ShowColumns: - return e.fetchShowColumns() + return e.fetchShowColumns(ctx) case ast.ShowCreateTable: return e.fetchShowCreateTable() case ast.ShowCreateUser: @@ -178,6 +181,8 @@ func (e *ShowExec) fetchAll() error { case ast.ShowAnalyzeStatus: e.fetchShowAnalyzeStatus() return nil + case ast.ShowRegions: + return e.fetchShowTableRegions() } return nil } @@ -269,7 +274,7 @@ func (e *ShowExec) fetchShowProcessList() error { if !hasProcessPriv && pi.User != loginUser.Username { continue } - row := pi.ToRow(e.Full) + row := pi.ToRowForShow(e.Full) e.appendRow(row) } return nil @@ -362,7 +367,7 @@ func createOptions(tb *model.TableInfo) string { return "" } -func (e *ShowExec) fetchShowColumns() error { +func (e *ShowExec) fetchShowColumns(ctx context.Context) error { tb, err := e.getTable() if err != nil { @@ -379,7 +384,7 @@ func (e *ShowExec) fetchShowColumns() error { // Because view's undertable's column could change or recreate, so view's column type may change overtime. // To avoid this situation we need to generate a logical plan and extract current column types from Schema. planBuilder := plannercore.NewPlanBuilder(e.ctx, e.is) - viewLogicalPlan, err := planBuilder.BuildDataSourceFromView(e.DBName, tb.Meta()) + viewLogicalPlan, err := planBuilder.BuildDataSourceFromView(ctx, e.DBName, tb.Meta()) if err != nil { return err } @@ -402,7 +407,7 @@ func (e *ShowExec) fetchShowColumns() error { // SHOW COLUMNS result expects string value defaultValStr := fmt.Sprintf("%v", desc.DefaultValue) // If column is timestamp, and default value is not current_timestamp, should convert the default value to the current session time zone. - if col.Tp == mysql.TypeTimestamp && defaultValStr != types.ZeroDatetimeStr && strings.ToUpper(defaultValStr) != strings.ToUpper(ast.CurrentTimestamp) { + if col.Tp == mysql.TypeTimestamp && defaultValStr != types.ZeroDatetimeStr && !strings.HasPrefix(strings.ToUpper(defaultValStr), strings.ToUpper(ast.CurrentTimestamp)) { timeValue, err := table.GetColDefaultValue(e.ctx, col.ToInfo()) if err != nil { return errors.Trace(err) @@ -659,8 +664,16 @@ func (e *ShowExec) fetchShowCreateTable() error { for i, col := range tb.Cols() { fmt.Fprintf(&buf, " %s %s", escape(col.Name, sqlMode), col.GetTypeDesc()) if col.Charset != "binary" { - if col.Charset != tblCharset || col.Collate != tblCollate { - fmt.Fprintf(&buf, " CHARACTER SET %s COLLATE %s", col.Charset, col.Collate) + if col.Charset != tblCharset { + fmt.Fprintf(&buf, " CHARACTER SET %s", col.Charset) + } + if col.Collate != tblCollate { + fmt.Fprintf(&buf, " COLLATE %s", col.Collate) + } else { + defcol, err := charset.GetDefaultCollation(col.Charset) + if err == nil && defcol != col.Collate { + fmt.Fprintf(&buf, " COLLATE %s", col.Collate) + } } } if col.IsGenerated() { @@ -692,6 +705,9 @@ func (e *ShowExec) fetchShowCreateTable() error { } case "CURRENT_TIMESTAMP": buf.WriteString(" DEFAULT CURRENT_TIMESTAMP") + if col.Decimal > 0 { + buf.WriteString(fmt.Sprintf("(%d)", col.Decimal)) + } default: defaultValStr := fmt.Sprintf("%v", defaultValue) // If column is timestamp, and default value is not current_timestamp, should convert the default value to the current session time zone. @@ -713,6 +729,7 @@ func (e *ShowExec) fetchShowCreateTable() error { } if mysql.HasOnUpdateNowFlag(col.Flag) { buf.WriteString(" ON UPDATE CURRENT_TIMESTAMP") + buf.WriteString(table.OptionalFsp(&col.FieldType)) } } if len(col.Comment) > 0 { @@ -784,9 +801,6 @@ func (e *ShowExec) fetchShowCreateTable() error { fmt.Fprintf(&buf, " COMPRESSION='%s'", tb.Meta().Compression) } - // add partition info here. - appendPartitionInfo(tb.Meta().Partition, &buf) - if hasAutoIncID { autoIncID, err := tb.Allocator(e.ctx).NextGlobalAutoID(tb.Meta().ID) if err != nil { @@ -809,6 +823,9 @@ func (e *ShowExec) fetchShowCreateTable() error { if len(tb.Meta().Comment) > 0 { fmt.Fprintf(&buf, " COMMENT='%s'", format.OutputFormat(tb.Meta().Comment)) } + // add partition info here. + appendPartitionInfo(tb.Meta().Partition, &buf) + e.appendRow([]interface{}{tb.Meta().Name.O, buf.String()}) return nil } @@ -938,8 +955,23 @@ func (e *ShowExec) fetchShowCreateUser() error { if checker == nil { return errors.New("miss privilege checker") } + + userName, hostName := e.User.Username, e.User.Hostname + sessVars := e.ctx.GetSessionVars() + if e.User.CurrentUser { + userName = sessVars.User.AuthUsername + hostName = sessVars.User.AuthHostname + } else { + // Show create user requires the SELECT privilege on mysql.user. + // Ref https://dev.mysql.com/doc/refman/5.7/en/show-create-user.html + activeRoles := sessVars.ActiveRoles + if !checker.RequestVerification(activeRoles, mysql.SystemDB, mysql.UserTable, "", mysql.SelectPriv) { + return e.tableAccessDenied("SELECT", mysql.UserTable) + } + } + sql := fmt.Sprintf(`SELECT * FROM %s.%s WHERE User='%s' AND Host='%s';`, - mysql.SystemDB, mysql.UserTable, e.User.Username, e.User.Hostname) + mysql.SystemDB, mysql.UserTable, userName, hostName) rows, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql) if err != nil { return errors.Trace(err) @@ -948,9 +980,8 @@ func (e *ShowExec) fetchShowCreateUser() error { return ErrCannotUser.GenWithStackByArgs("SHOW CREATE USER", fmt.Sprintf("'%s'@'%s'", e.User.Username, e.User.Hostname)) } - showStr := fmt.Sprintf("CREATE USER '%s'@'%s' IDENTIFIED WITH 'mysql_native_password' AS '%s' %s", - e.User.Username, e.User.Hostname, checker.GetEncodedPassword(e.User.Username, e.User.Hostname), - "REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK") + showStr := fmt.Sprintf("CREATE USER '%s'@'%s' IDENTIFIED WITH 'mysql_native_password' AS '%s' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK", + e.User.Username, e.User.Hostname, checker.GetEncodedPassword(e.User.Username, e.User.Hostname)) e.appendRow([]interface{}{showStr}) return nil } @@ -1027,7 +1058,7 @@ func (e *ShowExec) fetchShowPlugins() error { tiPlugins := plugin.GetAll() for _, ps := range tiPlugins { for _, p := range ps { - e.appendRow([]interface{}{p.Name, p.State.String(), p.Kind.String(), p.Path, p.License, strconv.Itoa(int(p.Version))}) + e.appendRow([]interface{}{p.Name, p.StateValue(), p.Kind.String(), p.Path, p.License, strconv.Itoa(int(p.Version))}) } } return nil @@ -1163,3 +1194,113 @@ func (e *ShowExec) appendRow(row []interface{}) { } } } + +func (e *ShowExec) fetchShowTableRegions() error { + store := e.ctx.GetStore() + tikvStore, ok := store.(tikv.Storage) + if !ok { + return nil + } + splitStore, ok := store.(kv.SplitableStore) + if !ok { + return nil + } + + tb, err := e.getTable() + if err != nil { + return errors.Trace(err) + } + + // Get table regions from from pd, not from regionCache, because the region cache maybe outdated. + var regions []regionMeta + if len(e.IndexName.L) != 0 { + indexInfo := tb.Meta().FindIndexByName(e.IndexName.L) + if indexInfo == nil { + return plannercore.ErrKeyDoesNotExist.GenWithStackByArgs(e.IndexName, tb.Meta().Name) + } + regions, err = getTableIndexRegions(tb, indexInfo, tikvStore, splitStore) + } else { + regions, err = getTableRegions(tb, tikvStore, splitStore) + } + + if err != nil { + return err + } + e.fillRegionsToChunk(regions) + return nil +} + +func getTableRegions(tb table.Table, tikvStore tikv.Storage, splitStore kv.SplitableStore) ([]regionMeta, error) { + if info := tb.Meta().GetPartitionInfo(); info != nil { + return getPartitionTableRegions(info, tb.(table.PartitionedTable), tikvStore, splitStore) + } + return getPhysicalTableRegions(tb.Meta().ID, tb.Meta(), tikvStore, splitStore, nil) +} + +func getTableIndexRegions(tb table.Table, indexInfo *model.IndexInfo, tikvStore tikv.Storage, splitStore kv.SplitableStore) ([]regionMeta, error) { + if info := tb.Meta().GetPartitionInfo(); info != nil { + return getPartitionIndexRegions(info, tb.(table.PartitionedTable), indexInfo, tikvStore, splitStore) + } + return getPhysicalIndexRegions(tb.Meta().ID, indexInfo, tikvStore, splitStore, nil) +} + +func getPartitionTableRegions(info *model.PartitionInfo, tbl table.PartitionedTable, tikvStore tikv.Storage, splitStore kv.SplitableStore) ([]regionMeta, error) { + regions := make([]regionMeta, 0, len(info.Definitions)) + uniqueRegionMap := make(map[uint64]struct{}) + for _, def := range info.Definitions { + pid := def.ID + partition := tbl.GetPartition(pid) + partition.GetPhysicalID() + partitionRegions, err := getPhysicalTableRegions(partition.GetPhysicalID(), tbl.Meta(), tikvStore, splitStore, uniqueRegionMap) + if err != nil { + return nil, err + } + regions = append(regions, partitionRegions...) + } + return regions, nil +} + +func getPartitionIndexRegions(info *model.PartitionInfo, tbl table.PartitionedTable, indexInfo *model.IndexInfo, tikvStore tikv.Storage, splitStore kv.SplitableStore) ([]regionMeta, error) { + var regions []regionMeta + uniqueRegionMap := make(map[uint64]struct{}) + for _, def := range info.Definitions { + pid := def.ID + partition := tbl.GetPartition(pid) + partition.GetPhysicalID() + partitionRegions, err := getPhysicalIndexRegions(partition.GetPhysicalID(), indexInfo, tikvStore, splitStore, uniqueRegionMap) + if err != nil { + return nil, err + } + regions = append(regions, partitionRegions...) + } + return regions, nil +} + +func (e *ShowExec) fillRegionsToChunk(regions []regionMeta) { + for i := range regions { + e.result.AppendUint64(0, regions[i].region.Id) + e.result.AppendString(1, regions[i].start) + e.result.AppendString(2, regions[i].end) + e.result.AppendUint64(3, regions[i].leaderID) + e.result.AppendUint64(4, regions[i].storeID) + + peers := "" + for i, peer := range regions[i].region.Peers { + if i > 0 { + peers += ", " + } + peers += strconv.FormatUint(peer.Id, 10) + } + e.result.AppendString(5, peers) + if regions[i].scattering { + e.result.AppendInt64(6, 1) + } else { + e.result.AppendInt64(6, 0) + } + + e.result.AppendInt64(7, regions[i].writtenBytes) + e.result.AppendInt64(8, regions[i].readBytes) + e.result.AppendInt64(9, regions[i].approximateSize) + e.result.AppendInt64(10, regions[i].approximateKeys) + } +} diff --git a/executor/show_stats_test.go b/executor/show_stats_test.go index fcd2e1beb0a2b..529671f385dc8 100644 --- a/executor/show_stats_test.go +++ b/executor/show_stats_test.go @@ -19,7 +19,11 @@ import ( "github.com/pingcap/tidb/util/testkit" ) -func (s *testSuite1) TestShowStatsMeta(c *C) { +type testShowStatsSuite struct { + testSuite +} + +func (s *testShowStatsSuite) TestShowStatsMeta(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t, t1") @@ -35,7 +39,7 @@ func (s *testSuite1) TestShowStatsMeta(c *C) { c.Assert(result.Rows()[0][1], Equals, "t") } -func (s *testSuite1) TestShowStatsHistograms(c *C) { +func (s *testShowStatsSuite) TestShowStatsHistograms(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -63,7 +67,7 @@ func (s *testSuite1) TestShowStatsHistograms(c *C) { c.Assert(len(res.Rows()), Equals, 1) } -func (s *testSuite1) TestShowStatsBuckets(c *C) { +func (s *testShowStatsSuite) TestShowStatsBuckets(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("drop table if exists t") @@ -77,7 +81,7 @@ func (s *testSuite1) TestShowStatsBuckets(c *C) { result.Check(testkit.Rows("test t idx 1 0 1 1 (1, 1) (1, 1)")) } -func (s *testSuite1) TestShowStatsHasNullValue(c *C) { +func (s *testShowStatsSuite) TestShowStatsHasNullValue(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") tk.MustExec("create table t (a int, index idx(a))") @@ -137,7 +141,7 @@ func (s *testSuite1) TestShowStatsHasNullValue(c *C) { c.Assert(res.Rows()[4][7], Equals, "0") } -func (s *testSuite1) TestShowPartitionStats(c *C) { +func (s *testShowStatsSuite) TestShowPartitionStats(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("set @@session.tidb_enable_table_partition=1") tk.MustExec("use test") @@ -170,7 +174,7 @@ func (s *testSuite1) TestShowPartitionStats(c *C) { result.Check(testkit.Rows("test t p0 100")) } -func (s *testSuite1) TestShowAnalyzeStatus(c *C) { +func (s *testShowStatsSuite) TestShowAnalyzeStatus(c *C) { tk := testkit.NewTestKit(c, s.store) statistics.ClearHistoryJobs() tk.MustExec("use test") diff --git a/executor/show_test.go b/executor/show_test.go index 0e3ce9a481991..559a7f1286e17 100644 --- a/executor/show_test.go +++ b/executor/show_test.go @@ -97,16 +97,6 @@ func (s *testSuite2) TestShowDatabasesInfoSchemaFirst(c *C) { tk.MustExec(`drop database BBBB`) } -// mockSessionManager is a mocked session manager that wraps one session -// it returns only this session's current process info as processlist for test. -type mockSessionManager struct { - session.Session -} - -// Kill implements the SessionManager.Kill interface. -func (msm *mockSessionManager) Kill(cid uint64, query bool) { -} - func (s *testSuite2) TestShowWarnings(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -157,6 +147,33 @@ func (s *testSuite2) TestIssue3641(c *C) { c.Assert(err.Error(), Equals, plannercore.ErrNoDB.Error()) } +func (s *testSuite2) TestIssue10549(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("CREATE DATABASE newdb;") + tk.MustExec("CREATE ROLE 'app_developer';") + tk.MustExec("GRANT ALL ON newdb.* TO 'app_developer';") + tk.MustExec("CREATE USER 'dev';") + tk.MustExec("GRANT 'app_developer' TO 'dev';") + tk.MustExec("SET DEFAULT ROLE app_developer TO 'dev';") + + c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "dev", Hostname: "localhost", AuthUsername: "dev", AuthHostname: "localhost"}, nil, nil), IsTrue) + tk.MustQuery("SHOW DATABASES;").Check(testkit.Rows("INFORMATION_SCHEMA", "newdb")) + tk.MustQuery("SHOW GRANTS;").Check(testkit.Rows("GRANT USAGE ON *.* TO 'dev'@'%'", "GRANT ALL PRIVILEGES ON newdb.* TO 'dev'@'%'", "GRANT 'app_developer'@'%' TO 'dev'@'%'")) + tk.MustQuery("SHOW GRANTS FOR CURRENT_USER").Check(testkit.Rows("GRANT USAGE ON *.* TO 'dev'@'%'", "GRANT 'app_developer'@'%' TO 'dev'@'%'")) +} + +func (s *testSuite3) TestIssue11165(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("CREATE ROLE 'r_manager';") + tk.MustExec("CREATE USER 'manager'@'localhost';") + tk.MustExec("GRANT 'r_manager' TO 'manager'@'localhost';") + + c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "manager", Hostname: "localhost", AuthUsername: "manager", AuthHostname: "localhost"}, nil, nil), IsTrue) + tk.MustExec("SET DEFAULT ROLE ALL TO 'manager'@'localhost';") + tk.MustExec("SET DEFAULT ROLE NONE TO 'manager'@'localhost';") + tk.MustExec("SET DEFAULT ROLE 'r_manager' TO 'manager'@'localhost';") +} + // TestShow2 is moved from session_test func (s *testSuite2) TestShow2(c *C) { tk := testkit.NewTestKit(c, s.store) @@ -181,13 +198,19 @@ func (s *testSuite2) TestShow2(c *C) { c_nchar national char(1) charset ascii collate ascii_bin, c_binary binary, c_varchar varchar(1) charset ascii collate ascii_bin, + c_varchar_default varchar(20) charset ascii collate ascii_bin default 'cUrrent_tImestamp', c_nvarchar national varchar(1) charset ascii collate ascii_bin, c_varbinary varbinary(1), c_year year, c_date date, c_time time, c_datetime datetime, + c_datetime_default datetime default current_timestamp, + c_datetime_default_2 datetime(2) default current_timestamp(2), c_timestamp timestamp, + c_timestamp_default timestamp default current_timestamp, + c_timestamp_default_3 timestamp(3) default current_timestamp(3), + c_timestamp_default_4 timestamp(3) default current_timestamp(3) on update current_timestamp(3), c_blob blob, c_tinyblob tinyblob, c_mediumblob mediumblob, @@ -211,13 +234,19 @@ func (s *testSuite2) TestShow2(c *C) { "[c_nchar char(1) ascii_bin YES select,insert,update,references ]\n" + "[c_binary binary(1) YES select,insert,update,references ]\n" + "[c_varchar varchar(1) ascii_bin YES select,insert,update,references ]\n" + + "[c_varchar_default varchar(20) ascii_bin YES cUrrent_tImestamp select,insert,update,references ]\n" + "[c_nvarchar varchar(1) ascii_bin YES select,insert,update,references ]\n" + "[c_varbinary varbinary(1) YES select,insert,update,references ]\n" + "[c_year year(4) YES select,insert,update,references ]\n" + "[c_date date YES select,insert,update,references ]\n" + "[c_time time YES select,insert,update,references ]\n" + "[c_datetime datetime YES select,insert,update,references ]\n" + + "[c_datetime_default datetime YES CURRENT_TIMESTAMP select,insert,update,references ]\n" + + "[c_datetime_default_2 datetime(2) YES CURRENT_TIMESTAMP(2) select,insert,update,references ]\n" + "[c_timestamp timestamp YES select,insert,update,references ]\n" + + "[c_timestamp_default timestamp YES CURRENT_TIMESTAMP select,insert,update,references ]\n" + + "[c_timestamp_default_3 timestamp(3) YES CURRENT_TIMESTAMP(3) select,insert,update,references ]\n" + + "[c_timestamp_default_4 timestamp(3) YES CURRENT_TIMESTAMP(3) DEFAULT_GENERATED on update CURRENT_TIMESTAMP(3) select,insert,update,references ]\n" + "[c_blob blob YES select,insert,update,references ]\n" + "[c_tinyblob tinyblob YES select,insert,update,references ]\n" + "[c_mediumblob mediumblob YES select,insert,update,references ]\n" + @@ -252,7 +281,7 @@ func (s *testSuite2) TestShow2(c *C) { tk.Se.Auth(&auth.UserIdentity{Username: "root", Hostname: "192.168.0.1", AuthUsername: "root", AuthHostname: "%"}, nil, []byte("012345678901234567890")) r := tk.MustQuery("show table status from test like 't'") - r.Check(testkit.Rows(fmt.Sprintf("t InnoDB 10 Compact 0 0 0 0 0 0 0 %s utf8mb4_bin 注释", createTime))) + r.Check(testkit.Rows(fmt.Sprintf("t InnoDB 10 Compact 0 0 0 0 0 0 %s utf8mb4_bin 注释", createTime))) tk.MustQuery("show databases like 'test'").Check(testkit.Rows("test")) @@ -263,7 +292,7 @@ func (s *testSuite2) TestShow2(c *C) { tk.MustQuery("show grants for current_user").Check(testkit.Rows(`GRANT ALL PRIVILEGES ON *.* TO 'root'@'%'`)) } -func (s *testSuite2) TestShow3(c *C) { +func (s *testSuite2) TestShowCreateUser(c *C) { tk := testkit.NewTestKit(c, s.store) // Create a new user. tk.MustExec(`CREATE USER 'test_show_create_user'@'%' IDENTIFIED BY 'root';`) @@ -281,6 +310,27 @@ func (s *testSuite2) TestShow3(c *C) { // Case: a user that doesn't exist err = tk.QueryToErr("show create user 'aaa'@'localhost';") c.Assert(err.Error(), Equals, executor.ErrCannotUser.GenWithStackByArgs("SHOW CREATE USER", "'aaa'@'localhost'").Error()) + + tk.Se.Auth(&auth.UserIdentity{Username: "root", Hostname: "127.0.0.1", AuthUsername: "root", AuthHostname: "%"}, nil, nil) + rows := tk.MustQuery("show create user current_user") + rows.Check(testkit.Rows("CREATE USER 'root'@'127.0.0.1' IDENTIFIED WITH 'mysql_native_password' AS '' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK")) + + rows = tk.MustQuery("show create user current_user()") + rows.Check(testkit.Rows("CREATE USER 'root'@'127.0.0.1' IDENTIFIED WITH 'mysql_native_password' AS '' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK")) + + tk.MustExec("create user 'check_priv'") + + // "show create user" for other user requires the SELECT privilege on mysql database. + tk1 := testkit.NewTestKit(c, s.store) + tk1.MustExec("use mysql") + succ := tk1.Se.Auth(&auth.UserIdentity{Username: "check_priv", Hostname: "127.0.0.1", AuthUsername: "test_show", AuthHostname: "asdf"}, nil, nil) + c.Assert(succ, IsTrue) + err = tk1.QueryToErr("show create user 'root'@'%'") + c.Assert(err, NotNil) + + // "show create user" for current user doesn't check privileges. + rows = tk1.MustQuery("show create user current_user") + rows.Check(testkit.Rows("CREATE USER 'check_priv'@'127.0.0.1' IDENTIFIED WITH 'mysql_native_password' AS '' REQUIRE NONE PASSWORD EXPIRE DEFAULT ACCOUNT UNLOCK")) } func (s *testSuite2) TestUnprivilegedShow(c *C) { @@ -310,7 +360,7 @@ func (s *testSuite2) TestUnprivilegedShow(c *C) { c.Assert(err, IsNil) createTime := model.TSConvert2Time(tblInfo.Meta().UpdateTS).Format("2006-01-02 15:04:05") - tk.MustQuery("show table status from testshow").Check(testkit.Rows(fmt.Sprintf("t1 InnoDB 10 Compact 0 0 0 0 0 0 0 %s utf8mb4_bin ", createTime))) + tk.MustQuery("show table status from testshow").Check(testkit.Rows(fmt.Sprintf("t1 InnoDB 10 Compact 0 0 0 0 0 0 %s utf8mb4_bin ", createTime))) } @@ -396,16 +446,26 @@ func (s *testSuite2) TestShowCreateTable(c *C) { tk.MustExec("create table t1(a int,b int)") tk.MustExec("drop view if exists v1") tk.MustExec("create or replace definer=`root`@`127.0.0.1` view v1 as select * from t1") - tk.MustQuery("show create table v1").Check(testutil.RowsWithSep("|", "v1|CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`127.0.0.1` SQL SECURITY DEFINER VIEW `v1` (`a`, `b`) AS select * from t1 ")) - tk.MustQuery("show create view v1").Check(testutil.RowsWithSep("|", "v1|CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`127.0.0.1` SQL SECURITY DEFINER VIEW `v1` (`a`, `b`) AS select * from t1 ")) + tk.MustQuery("show create table v1").Check(testutil.RowsWithSep("|", "v1|CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`127.0.0.1` SQL SECURITY DEFINER VIEW `v1` (`a`, `b`) AS SELECT `test`.`t1`.`a`,`test`.`t1`.`b` FROM `test`.`t1` ")) + tk.MustQuery("show create view v1").Check(testutil.RowsWithSep("|", "v1|CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`127.0.0.1` SQL SECURITY DEFINER VIEW `v1` (`a`, `b`) AS SELECT `test`.`t1`.`a`,`test`.`t1`.`b` FROM `test`.`t1` ")) tk.MustExec("drop view v1") tk.MustExec("drop table t1") tk.MustExec("drop view if exists v") tk.MustExec("create or replace definer=`root`@`127.0.0.1` view v as select JSON_MERGE('{}', '{}') as col;") - tk.MustQuery("show create view v").Check(testutil.RowsWithSep("|", "v|CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`127.0.0.1` SQL SECURITY DEFINER VIEW `v` (`col`) AS select JSON_MERGE('{}', '{}') as col; ")) + tk.MustQuery("show create view v").Check(testutil.RowsWithSep("|", "v|CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`127.0.0.1` SQL SECURITY DEFINER VIEW `v` (`col`) AS SELECT JSON_MERGE('{}', '{}') AS `col` ")) tk.MustExec("drop view if exists v") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(a int,b int)") + tk.MustExec("create or replace definer=`root`@`127.0.0.1` view v1 as select avg(a),t1.* from t1 group by a") + tk.MustQuery("show create view v1").Check(testutil.RowsWithSep("|", "v1|CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`127.0.0.1` SQL SECURITY DEFINER VIEW `v1` (`avg(a)`, `a`, `b`) AS SELECT AVG(`a`),`test`.`t1`.`a`,`test`.`t1`.`b` FROM `test`.`t1` GROUP BY `a` ")) + tk.MustExec("drop view v1") + tk.MustExec("create or replace definer=`root`@`127.0.0.1` view v1 as select a+b, t1.* , a as c from t1") + tk.MustQuery("show create view v1").Check(testutil.RowsWithSep("|", "v1|CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`127.0.0.1` SQL SECURITY DEFINER VIEW `v1` (`a+b`, `a`, `b`, `c`) AS SELECT `a`+`b`,`test`.`t1`.`a`,`test`.`t1`.`b`,`a` AS `c` FROM `test`.`t1` ")) + tk.MustExec("drop table t1") + tk.MustExec("drop view v1") + // For issue #9211 tk.MustExec("create table t(c int, b int as (c + 1))ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;") tk.MustQuery("show create table `t`").Check(testutil.RowsWithSep("|", @@ -446,6 +506,29 @@ func (s *testSuite2) TestShowCreateTable(c *C) { ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", )) + tk.MustExec("drop table if exists t") + tk.MustExec("create table `t` (\n" + + "`a` timestamp not null default current_timestamp,\n" + + "`b` timestamp(3) default current_timestamp(3),\n" + + "`c` datetime default current_timestamp,\n" + + "`d` datetime(4) default current_timestamp(4),\n" + + "`e` varchar(20) default 'cUrrent_tImestamp',\n" + + "`f` datetime(2) default current_timestamp(2) on update current_timestamp(2),\n" + + "`g` timestamp(2) default current_timestamp(2) on update current_timestamp(2))") + tk.MustQuery("show create table `t`").Check(testutil.RowsWithSep("|", + ""+ + "t CREATE TABLE `t` (\n"+ + " `a` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,\n"+ + " `b` timestamp(3) DEFAULT CURRENT_TIMESTAMP(3),\n"+ + " `c` datetime DEFAULT CURRENT_TIMESTAMP,\n"+ + " `d` datetime(4) DEFAULT CURRENT_TIMESTAMP(4),\n"+ + " `e` varchar(20) DEFAULT 'cUrrent_tImestamp',\n"+ + " `f` datetime(2) DEFAULT CURRENT_TIMESTAMP(2) ON UPDATE CURRENT_TIMESTAMP(2),\n"+ + " `g` timestamp(2) DEFAULT CURRENT_TIMESTAMP(2) ON UPDATE CURRENT_TIMESTAMP(2)\n"+ + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", + )) + tk.MustExec("drop table t") + tk.MustExec("create table t (a int, b int) shard_row_id_bits = 4 pre_split_regions=3;") tk.MustQuery("show create table `t`").Check(testutil.RowsWithSep("|", ""+ @@ -455,6 +538,76 @@ func (s *testSuite2) TestShowCreateTable(c *C) { ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin/*!90000 SHARD_ROW_ID_BITS=4 PRE_SPLIT_REGIONS=3 */", )) tk.MustExec("drop table t") + + tk.MustExec("CREATE TABLE `log` (" + + "`LOG_ID` bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT," + + "`ROUND_ID` bigint(20) UNSIGNED NOT NULL," + + "`USER_ID` int(10) UNSIGNED NOT NULL," + + "`USER_IP` int(10) UNSIGNED DEFAULT NULL," + + "`END_TIME` datetime NOT NULL," + + "`USER_TYPE` int(11) DEFAULT NULL," + + "`APP_ID` int(11) DEFAULT NULL," + + "PRIMARY KEY (`LOG_ID`,`END_TIME`)," + + "KEY `IDX_EndTime` (`END_TIME`)," + + "KEY `IDX_RoundId` (`ROUND_ID`)," + + "KEY `IDX_UserId_EndTime` (`USER_ID`,`END_TIME`)" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin AUTO_INCREMENT=505488 " + + "PARTITION BY RANGE ( month(`end_time`) ) (" + + "PARTITION p1 VALUES LESS THAN (2)," + + "PARTITION p2 VALUES LESS THAN (3)," + + "PARTITION p3 VALUES LESS THAN (4)," + + "PARTITION p4 VALUES LESS THAN (5)," + + "PARTITION p5 VALUES LESS THAN (6)," + + "PARTITION p6 VALUES LESS THAN (7)," + + "PARTITION p7 VALUES LESS THAN (8)," + + "PARTITION p8 VALUES LESS THAN (9)," + + "PARTITION p9 VALUES LESS THAN (10)," + + "PARTITION p10 VALUES LESS THAN (11)," + + "PARTITION p11 VALUES LESS THAN (12)," + + "PARTITION p12 VALUES LESS THAN (MAXVALUE))") + tk.MustQuery("show create table log").Check(testutil.RowsWithSep("|", + "log CREATE TABLE `log` (\n"+ + " `LOG_ID` bigint(20) unsigned NOT NULL AUTO_INCREMENT,\n"+ + " `ROUND_ID` bigint(20) unsigned NOT NULL,\n"+ + " `USER_ID` int(10) unsigned NOT NULL,\n"+ + " `USER_IP` int(10) unsigned DEFAULT NULL,\n"+ + " `END_TIME` datetime NOT NULL,\n"+ + " `USER_TYPE` int(11) DEFAULT NULL,\n"+ + " `APP_ID` int(11) DEFAULT NULL,\n"+ + " PRIMARY KEY (`LOG_ID`,`END_TIME`),\n"+ + " KEY `IDX_EndTime` (`END_TIME`),\n"+ + " KEY `IDX_RoundId` (`ROUND_ID`),\n"+ + " KEY `IDX_UserId_EndTime` (`USER_ID`,`END_TIME`)\n"+ + ") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin AUTO_INCREMENT=505488\n"+ + "PARTITION BY RANGE ( month(`end_time`) ) (\n"+ + " PARTITION p1 VALUES LESS THAN (2),\n"+ + " PARTITION p2 VALUES LESS THAN (3),\n"+ + " PARTITION p3 VALUES LESS THAN (4),\n"+ + " PARTITION p4 VALUES LESS THAN (5),\n"+ + " PARTITION p5 VALUES LESS THAN (6),\n"+ + " PARTITION p6 VALUES LESS THAN (7),\n"+ + " PARTITION p7 VALUES LESS THAN (8),\n"+ + " PARTITION p8 VALUES LESS THAN (9),\n"+ + " PARTITION p9 VALUES LESS THAN (10),\n"+ + " PARTITION p10 VALUES LESS THAN (11),\n"+ + " PARTITION p11 VALUES LESS THAN (12),\n"+ + " PARTITION p12 VALUES LESS THAN (MAXVALUE)\n"+ + ")")) + //for issue #11831 + tk.MustExec("create table ttt4(a varchar(123) default null collate utf8mb4_unicode_ci)engine=innodb default charset=utf8mb4 collate=utf8mb4_unicode_ci;") + tk.MustQuery("show create table `ttt4`").Check(testutil.RowsWithSep("|", + ""+ + "ttt4 CREATE TABLE `ttt4` (\n"+ + " `a` varchar(123) COLLATE utf8mb4_unicode_ci DEFAULT NULL\n"+ + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci", + )) + tk.MustExec("create table ttt5(a varchar(123) default null)engine=innodb default charset=utf8mb4 collate=utf8mb4_bin;") + tk.MustQuery("show create table `ttt5`").Check(testutil.RowsWithSep("|", + ""+ + "ttt5 CREATE TABLE `ttt5` (\n"+ + " `a` varchar(123) DEFAULT NULL\n"+ + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin", + )) } func (s *testSuite2) TestShowEscape(c *C) { diff --git a/executor/simple.go b/executor/simple.go index 4ee51342d3838..4f6ed1528969f 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -18,6 +18,7 @@ import ( "fmt" "strings" + "github.com/ngaut/pools" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" @@ -28,6 +29,7 @@ import ( "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/plugin" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" @@ -51,11 +53,38 @@ type SimpleExec struct { is infoschema.InfoSchema } +func (e *SimpleExec) getSysSession() (sessionctx.Context, error) { + dom := domain.GetDomain(e.ctx) + sysSessionPool := dom.SysSessionPool() + ctx, err := sysSessionPool.Get() + if err != nil { + return nil, err + } + restrictedCtx := ctx.(sessionctx.Context) + restrictedCtx.GetSessionVars().InRestrictedSQL = true + return restrictedCtx, nil +} + +func (e *SimpleExec) releaseSysSession(ctx sessionctx.Context) { + dom := domain.GetDomain(e.ctx) + sysSessionPool := dom.SysSessionPool() + sysSessionPool.Put(ctx.(pools.Resource)) +} + // Next implements the Executor Next interface. -func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err error) { +func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { if e.done { return nil } + + if e.autoNewTxn() { + // Commit the old transaction, like DDL. + if err := e.ctx.NewTxn(ctx); err != nil { + return err + } + defer func() { e.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusInTrans, false) }() + } + switch x := e.Statement.(type) { case *ast.GrantRoleStmt: err = e.executeGrantRole(x) @@ -70,7 +99,7 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.RecordBatch) (err erro case *ast.RollbackStmt: err = e.executeRollback(x) case *ast.CreateUserStmt: - err = e.executeCreateUser(x) + err = e.executeCreateUser(ctx, x) case *ast.AlterUserStmt: err = e.executeAlterUser(x) case *ast.DropUserStmt: @@ -221,17 +250,102 @@ func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error { return nil } -func (e *SimpleExec) executeSetDefaultRole(s *ast.SetDefaultRoleStmt) error { +func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (err error) { + checker := privilege.GetPrivilegeManager(e.ctx) + user, sql := s.UserList[0], "" + if user.Hostname == "" { + user.Hostname = "%" + } switch s.SetRoleOpt { + case ast.SetRoleNone: + sql = fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) case ast.SetRoleAll: - return e.setDefaultRoleAll(s) + sql = fmt.Sprintf("INSERT IGNORE INTO mysql.default_roles(HOST,USER,DEFAULT_ROLE_HOST,DEFAULT_ROLE_USER) "+ + "SELECT TO_HOST,TO_USER,FROM_HOST,FROM_USER FROM mysql.role_edges WHERE TO_HOST='%s' AND TO_USER='%s';", user.Hostname, user.Username) + case ast.SetRoleRegular: + sql = "INSERT IGNORE INTO mysql.default_roles values" + for i, role := range s.RoleList { + ok := checker.FindEdge(e.ctx, role, user) + if !ok { + return ErrRoleNotGranted.GenWithStackByArgs(role.String(), user.String()) + } + sql += fmt.Sprintf("('%s', '%s', '%s', '%s')", user.Hostname, user.Username, role.Hostname, role.Username) + if i != len(s.RoleList)-1 { + sql += "," + } + } + } + + restrictedCtx, err := e.getSysSession() + if err != nil { + return err + } + defer e.releaseSysSession(restrictedCtx) + + sqlExecutor := restrictedCtx.(sqlexec.SQLExecutor) + + if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { + return err + } + + deleteSQL := fmt.Sprintf("DELETE IGNORE FROM mysql.default_roles WHERE USER='%s' AND HOST='%s';", user.Username, user.Hostname) + if _, err := sqlExecutor.Execute(context.Background(), deleteSQL); err != nil { + logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { + logutil.Logger(context.Background()).Error(fmt.Sprintf("Error occur when executing %s", sql)) + if _, rollbackErr := sqlExecutor.Execute(context.Background(), "rollback"); rollbackErr != nil { + return rollbackErr + } + return err + } + if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { + return err + } + return nil +} + +func (e *SimpleExec) executeSetDefaultRole(s *ast.SetDefaultRoleStmt) (err error) { + sessionVars := e.ctx.GetSessionVars() + checker := privilege.GetPrivilegeManager(e.ctx) + if checker == nil { + return errors.New("miss privilege checker") + } + + if len(s.UserList) == 1 && sessionVars.User != nil { + u, h := s.UserList[0].Username, s.UserList[0].Hostname + if u == sessionVars.User.Username && h == sessionVars.User.AuthHostname { + err = e.setDefaultRoleForCurrentUser(s) + domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) + return + } + } + + activeRoles := sessionVars.ActiveRoles + if !checker.RequestVerification(activeRoles, mysql.SystemDB, mysql.DefaultRoleTable, "", mysql.UpdatePriv) { + if !checker.RequestVerification(activeRoles, "", "", "", mysql.CreateUserPriv) { + return core.ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") + } + } + + switch s.SetRoleOpt { + case ast.SetRoleAll: + err = e.setDefaultRoleAll(s) case ast.SetRoleNone: - return e.setDefaultRoleNone(s) + err = e.setDefaultRoleNone(s) case ast.SetRoleRegular: - return e.setDefaultRoleRegular(s) + err = e.setDefaultRoleRegular(s) } - err := domain.GetDomain(e.ctx).PrivilegeHandle().Update(e.ctx.(sessionctx.Context)) - return err + if err != nil { + return + } + domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) + return } func (e *SimpleExec) setRoleRegular(s *ast.SetRoleStmt) error { @@ -404,9 +518,6 @@ func (e *SimpleExec) executeBegin(ctx context.Context, s *ast.BeginStmt) error { txnMode := s.Mode if txnMode == "" { txnMode = e.ctx.GetSessionVars().TxnMode - if txnMode == "" && pTxnConf.Default { - txnMode = ast.Pessimistic - } } if txnMode == ast.Pessimistic { e.ctx.GetSessionVars().TxnCtx.IsPessimistic = true @@ -459,13 +570,20 @@ func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { } return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) } + sql = fmt.Sprintf(`DELETE IGNORE FROM %s.%s WHERE DEFAULT_ROLE_HOST='%s' and DEFAULT_ROLE_USER='%s' and HOST='%s' and USER='%s'`, mysql.SystemDB, mysql.DefaultRoleTable, role.Hostname, role.Username, user.Hostname, user.Username) + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + return errors.Trace(err) + } + return ErrCannotUser.GenWithStackByArgs("REVOKE ROLE", role.String()) + } } } if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "commit"); err != nil { return err } - err := domain.GetDomain(e.ctx).PrivilegeHandle().Update(e.ctx.(sessionctx.Context)) - return errors.Trace(err) + domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) + return nil } func (e *SimpleExec) executeCommit(s *ast.CommitStmt) { @@ -487,7 +605,24 @@ func (e *SimpleExec) executeRollback(s *ast.RollbackStmt) error { return nil } -func (e *SimpleExec) executeCreateUser(s *ast.CreateUserStmt) error { +func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStmt) error { + // Check `CREATE USER` privilege. + if !config.GetGlobalConfig().Security.SkipGrantTable { + checker := privilege.GetPrivilegeManager(e.ctx) + if checker == nil { + return errors.New("miss privilege checker") + } + activeRoles := e.ctx.GetSessionVars().ActiveRoles + if !checker.RequestVerification(activeRoles, mysql.SystemDB, mysql.UserTable, "", mysql.InsertPriv) { + if s.IsCreateRole && !checker.RequestVerification(activeRoles, "", "", "", mysql.CreateRolePriv) { + return core.ErrSpecificAccessDenied.GenWithStackByArgs("CREATE ROLE") + } + if !s.IsCreateRole && !checker.RequestVerification(activeRoles, "", "", "", mysql.CreateUserPriv) { + return core.ErrSpecificAccessDenied.GenWithStackByArgs("CREATE User") + } + } + } + users := make([]string, 0, len(s.Specs)) for _, spec := range s.Specs { exists, err1 := userExists(e.ctx, spec.User.Username, spec.User.Hostname) @@ -513,11 +648,12 @@ func (e *SimpleExec) executeCreateUser(s *ast.CreateUserStmt) error { if len(users) == 0 { return nil } + sql := fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) if s.IsCreateRole { sql = fmt.Sprintf(`INSERT INTO %s.%s (Host, User, Password, Account_locked) VALUES %s;`, mysql.SystemDB, mysql.UserTable, strings.Join(users, ", ")) } - _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql) + _, _, err := e.ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedSQL(e.ctx, sql) if err != nil { return err } @@ -584,6 +720,14 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { failedUsers := make([]string, 0, len(s.Users)) + sessionVars := e.ctx.GetSessionVars() + for i, user := range s.Users { + if user.CurrentUser { + s.Users[i].Username = sessionVars.User.AuthUsername + s.Users[i].Hostname = sessionVars.User.AuthHostname + } + } + for _, role := range s.Roles { exists, err := userExists(e.ctx, role.Username, role.Hostname) if err != nil { @@ -629,8 +773,32 @@ func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { } func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { + // Check privileges. + // Check `CREATE USER` privilege. + if !config.GetGlobalConfig().Security.SkipGrantTable { + checker := privilege.GetPrivilegeManager(e.ctx) + if checker == nil { + return errors.New("miss privilege checker") + } + activeRoles := e.ctx.GetSessionVars().ActiveRoles + if !checker.RequestVerification(activeRoles, mysql.SystemDB, mysql.UserTable, "", mysql.DeletePriv) { + if s.IsDropRole && !checker.RequestVerification(activeRoles, "", "", "", mysql.DropRolePriv) { + return core.ErrSpecificAccessDenied.GenWithStackByArgs("DROP ROLE") + } + if !s.IsDropRole && !checker.RequestVerification(activeRoles, "", "", "", mysql.CreateUserPriv) { + return core.ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") + } + } + } + failedUsers := make([]string, 0, len(s.UserList)) notExistUsers := make([]string, 0, len(s.UserList)) + sysSession, err := e.getSysSession() + defer e.releaseSysSession(sysSession) + if err != nil { + return err + } + sqlExecutor := sysSession.(sqlexec.SQLExecutor) for _, user := range s.UserList { exists, err := userExists(e.ctx, user.Username, user.Hostname) @@ -643,13 +811,13 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { } // begin a transaction to delete a user. - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "begin"); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), "begin"); err != nil { return err } sql := fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.UserTable, user.Hostname, user.Username) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { return err } continue @@ -657,9 +825,9 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { // delete privileges from mysql.db sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.DBTable, user.Hostname, user.Username) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { return err } continue @@ -667,9 +835,9 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { // delete privileges from mysql.tables_priv sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE Host = '%s' and User = '%s';`, mysql.SystemDB, mysql.TablePrivTable, user.Hostname, user.Username) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { return err } continue @@ -677,18 +845,18 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { // delete relationship from mysql.role_edges sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE TO_HOST = '%s' and TO_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { return err } continue } sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE FROM_HOST = '%s' and FROM_USER = '%s';`, mysql.SystemDB, mysql.RoleEdgeTable, user.Hostname, user.Username) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { return err } continue @@ -696,25 +864,25 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { // delete relationship from mysql.default_roles sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE DEFAULT_ROLE_HOST = '%s' and DEFAULT_ROLE_USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { return err } continue } sql = fmt.Sprintf(`DELETE FROM %s.%s WHERE HOST = '%s' and USER = '%s';`, mysql.SystemDB, mysql.DefaultRoleTable, user.Hostname, user.Username) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), sql); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), sql); err != nil { failedUsers = append(failedUsers, user.String()) - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "rollback"); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), "rollback"); err != nil { return err } continue } //TODO: need delete columns_priv once we implement columns_priv functionality. - if _, err := e.ctx.(sqlexec.SQLExecutor).Execute(context.Background(), "commit"); err != nil { + if _, err := sqlExecutor.Execute(context.Background(), "commit"); err != nil { failedUsers = append(failedUsers, user.String()) } } @@ -799,6 +967,13 @@ func (e *SimpleExec) executeFlush(s *ast.FlushStmt) error { return errors.New("FLUSH TABLES WITH READ LOCK is not supported. Please use @@tidb_snapshot") } case ast.FlushPrivileges: + // If skip-grant-table is configured, do not flush privileges. + // Because LoadPrivilegeLoop does not run and the privilege Handle is nil, + // Call dom.PrivilegeHandle().Update would panic. + if config.GetGlobalConfig().Security.SkipGrantTable { + return nil + } + dom := domain.GetDomain(e.ctx) sysSessionPool := dom.SysSessionPool() ctx, err := sysSessionPool.Get() @@ -828,3 +1003,11 @@ func (e *SimpleExec) executeDropStats(s *ast.DropStatsStmt) error { } return h.Update(GetInfoSchema(e.ctx)) } + +func (e *SimpleExec) autoNewTxn() bool { + switch e.Statement.(type) { + case *ast.CreateUserStmt, *ast.AlterUserStmt, *ast.DropUserStmt: + return true + } + return false +} diff --git a/executor/simple_test.go b/executor/simple_test.go index 4ffd55a97a0e9..23542f1b6b951 100644 --- a/executor/simple_test.go +++ b/executor/simple_test.go @@ -16,17 +16,19 @@ package executor_test import ( "context" - "github.com/pingcap/tidb/planner/core" - . "github.com/pingcap/check" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/executor" + "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/store/mockstore/mocktikv" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testutil" ) @@ -134,6 +136,15 @@ func (s *testSuite3) TestRole(c *C) { grantRoleSQL = `GRANT 'r_1'@'localhost' TO 'r_3'@'localhost', 'r_4'@'localhost';` _, err = tk.Exec(grantRoleSQL) c.Check(err, NotNil) + + // Test grant role for current_user(); + sessionVars := tk.Se.GetSessionVars() + originUser := sessionVars.User + sessionVars.User = &auth.UserIdentity{Username: "root", Hostname: "localhost", AuthUsername: "root", AuthHostname: "%"} + tk.MustExec("grant 'r_1'@'localhost' to current_user();") + tk.MustExec("revoke 'r_1'@'localhost' from 'root'@'%';") + sessionVars.User = originUser + result = tk.MustQuery(`SELECT FROM_USER FROM mysql.role_edges WHERE TO_USER="r_3" and TO_HOST="localhost"`) result.Check(nil) @@ -150,14 +161,20 @@ func (s *testSuite3) TestRole(c *C) { tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('localhost','test','%','root')") tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_1','%','root')") tk.MustExec("insert into mysql.role_edges (FROM_HOST,FROM_USER,TO_HOST,TO_USER) values ('%','r_2','%','root')") + tk.MustExec("flush privileges") + tk.MustExec("SET DEFAULT ROLE r_1, r_2 TO root") _, err = tk.Exec("revoke test@localhost, r_1 from root;") c.Check(err, IsNil) _, err = tk.Exec("revoke `r_2`@`%` from root, u_2;") c.Check(err, NotNil) _, err = tk.Exec("revoke `r_2`@`%` from root;") c.Check(err, IsNil) + _, err = tk.Exec("revoke `r_1`@`%` from root;") + c.Check(err, IsNil) result = tk.MustQuery(`SELECT * FROM mysql.default_roles WHERE DEFAULT_ROLE_USER="test" and DEFAULT_ROLE_HOST="localhost"`) result.Check(nil) + result = tk.MustQuery(`SELECT * FROM mysql.default_roles WHERE USER="root" and HOST="%"`) + result.Check(nil) dropRoleSQL = `DROP ROLE 'test'@'localhost', r_1, r_2;` tk.MustExec(dropRoleSQL) } @@ -386,6 +403,36 @@ func (s *testSuite3) TestFlushPrivileges(c *C) { // After flush. _, err = se.Execute(ctx, `SELECT Password FROM mysql.User WHERE User="testflush" and Host="localhost"`) c.Check(err, IsNil) + +} + +type testFlushSuite struct{} + +func (s *testFlushSuite) TestFlushPrivilegesPanic(c *C) { + // Run in a separate suite because this test need to set SkipGrantTable config. + cluster := mocktikv.NewCluster() + mocktikv.BootstrapWithSingleStore(cluster) + mvccStore := mocktikv.MustNewMVCCStore() + store, err := mockstore.NewMockTikvStore( + mockstore.WithCluster(cluster), + mockstore.WithMVCCStore(mvccStore), + ) + c.Assert(err, IsNil) + defer store.Close() + + saveConf := config.GetGlobalConfig() + conf := config.NewConfig() + conf.Security.SkipGrantTable = true + config.StoreGlobalConfig(conf) + + dom, err := session.BootstrapSession(store) + c.Assert(err, IsNil) + defer dom.Close() + + tk := testkit.NewTestKit(c, store) + tk.MustExec("FLUSH PRIVILEGES") + + config.StoreGlobalConfig(saveConf) } func (s *testSuite3) TestDropStats(c *C) { @@ -404,7 +451,7 @@ func (s *testSuite3) TestDropStats(c *C) { c.Assert(statsTbl.Pseudo, IsFalse) testKit.MustExec("drop stats t") - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl = h.GetTableStats(tableInfo) c.Assert(statsTbl.Pseudo, IsTrue) @@ -412,12 +459,12 @@ func (s *testSuite3) TestDropStats(c *C) { statsTbl = h.GetTableStats(tableInfo) c.Assert(statsTbl.Pseudo, IsFalse) - h.Lease = 1 + h.SetLease(1) testKit.MustExec("drop stats t") - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl = h.GetTableStats(tableInfo) c.Assert(statsTbl.Pseudo, IsTrue) - h.Lease = 0 + h.SetLease(0) } func (s *testSuite3) TestFlushTables(c *C) { @@ -439,3 +486,69 @@ func (s *testSuite3) TestUseDB(c *C) { _, err = tk.Exec("USE ``") c.Assert(terror.ErrorEqual(core.ErrNoDB, err), IsTrue, Commentf("err %v", err)) } + +func (s *testSuite3) TestStmtAutoNewTxn(c *C) { + // Some statements are like DDL, they commit the previous txn automically. + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + // Fix issue https://github.com/pingcap/tidb/issues/10705 + tk.MustExec("begin") + tk.MustExec("create user 'xxx'@'%';") + tk.MustExec("grant all privileges on *.* to 'xxx'@'%';") + + tk.MustExec("create table auto_new (id int)") + tk.MustExec("begin") + tk.MustExec("insert into auto_new values (1)") + tk.MustExec("revoke all privileges on *.* from 'xxx'@'%'") + tk.MustExec("rollback") // insert statement has already committed + tk.MustQuery("select * from auto_new").Check(testkit.Rows("1")) + + // Test the behavior when autocommit is false. + tk.MustExec("set autocommit = 0") + tk.MustExec("insert into auto_new values (2)") + tk.MustExec("create user 'yyy'@'%'") + tk.MustExec("rollback") + tk.MustQuery("select * from auto_new").Check(testkit.Rows("1", "2")) + + tk.MustExec("drop user 'yyy'@'%'") + tk.MustExec("insert into auto_new values (3)") + tk.MustExec("rollback") + tk.MustQuery("select * from auto_new").Check(testkit.Rows("1", "2")) +} + +func (s *testSuite3) TestIssue9111(c *C) { + // CREATE USER / DROP USER fails if admin doesn't have insert privilege on `mysql.user` table. + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("create user 'user_admin'@'localhost';") + tk.MustExec("grant create user on *.* to 'user_admin'@'localhost';") + + // Create a new session. + se, err := session.CreateSession4Test(s.store) + c.Check(err, IsNil) + defer se.Close() + c.Assert(se.Auth(&auth.UserIdentity{Username: "user_admin", Hostname: "localhost"}, nil, nil), IsTrue) + + ctx := context.Background() + _, err = se.Execute(ctx, `create user test_create_user`) + c.Check(err, IsNil) + _, err = se.Execute(ctx, `drop user test_create_user`) + c.Check(err, IsNil) + + tk.MustExec("revoke create user on *.* from 'user_admin'@'localhost';") + tk.MustExec("grant insert, delete on mysql.User to 'user_admin'@'localhost';") + + _, err = se.Execute(ctx, `flush privileges`) + c.Check(err, IsNil) + _, err = se.Execute(ctx, `create user test_create_user`) + c.Check(err, IsNil) + _, err = se.Execute(ctx, `drop user test_create_user`) + c.Check(err, IsNil) + + _, err = se.Execute(ctx, `create role test_create_user`) + c.Check(err, IsNil) + _, err = se.Execute(ctx, `drop role test_create_user`) + c.Check(err, IsNil) + + tk.MustExec("drop user 'user_admin'@'localhost';") +} diff --git a/executor/sort.go b/executor/sort.go index 8e2e221a6828a..fefac4cda4c1a 100644 --- a/executor/sort.go +++ b/executor/sort.go @@ -56,7 +56,6 @@ type SortExec struct { // Close implements the Executor Close interface. func (e *SortExec) Close() error { - e.memTracker.Detach() e.memTracker = nil return e.children[0].Close() } @@ -75,7 +74,7 @@ func (e *SortExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *SortExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *SortExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("sort.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -105,13 +104,13 @@ func (e *SortExec) Next(ctx context.Context, req *chunk.RecordBatch) error { } func (e *SortExec) fetchRowChunks(ctx context.Context) error { - fields := e.retTypes() + fields := retTypes(e) e.rowChunks = chunk.NewList(fields, e.initCap, e.maxChunkSize) e.rowChunks.GetMemTracker().AttachTo(e.memTracker) e.rowChunks.GetMemTracker().SetLabel(rowChunksLabel) for { - chk := e.children[0].newFirstChunk() - err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) + chk := newFirstChunk(e.children[0]) + err := Next(ctx, e.children[0], chk) if err != nil { return err } @@ -239,7 +238,7 @@ func (e *TopNExec) Open(ctx context.Context) error { } // Next implements the Executor Next interface. -func (e *TopNExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *TopNExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("topN.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -275,14 +274,14 @@ func (e *TopNExec) Next(ctx context.Context, req *chunk.RecordBatch) error { func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error { e.chkHeap = &topNChunkHeap{e} - e.rowChunks = chunk.NewList(e.retTypes(), e.initCap, e.maxChunkSize) + e.rowChunks = chunk.NewList(retTypes(e), e.initCap, e.maxChunkSize) e.rowChunks.GetMemTracker().AttachTo(e.memTracker) e.rowChunks.GetMemTracker().SetLabel(rowChunksLabel) for uint64(e.rowChunks.Len()) < e.totalLimit { - srcChk := e.children[0].newFirstChunk() + srcChk := newFirstChunk(e.children[0]) // adjust required rows by total limit srcChk.SetRequiredRows(int(e.totalLimit-uint64(e.rowChunks.Len())), e.maxChunkSize) - err := e.children[0].Next(ctx, chunk.NewRecordBatch(srcChk)) + err := Next(ctx, e.children[0], srcChk) if err != nil { return err } @@ -305,9 +304,9 @@ func (e *TopNExec) executeTopN(ctx context.Context) error { // The number of rows we loaded may exceeds total limit, remove greatest rows by Pop. heap.Pop(e.chkHeap) } - childRowChk := e.children[0].newFirstChunk() + childRowChk := newFirstChunk(e.children[0]) for { - err := e.children[0].Next(ctx, chunk.NewRecordBatch(childRowChk)) + err := Next(ctx, e.children[0], childRowChk) if err != nil { return err } @@ -349,7 +348,7 @@ func (e *TopNExec) processChildChk(childRowChk *chunk.Chunk) error { // but we want descending top N, then we will keep all data in memory. // But if data is distributed randomly, this function will be called log(n) times. func (e *TopNExec) doCompaction() error { - newRowChunks := chunk.NewList(e.retTypes(), e.initCap, e.maxChunkSize) + newRowChunks := chunk.NewList(retTypes(e), e.initCap, e.maxChunkSize) newRowPtrs := make([]chunk.RowPtr, 0, e.rowChunks.Len()) for _, rowPtr := range e.rowPtrs { newRowPtr := newRowChunks.AppendRow(e.rowChunks.GetRow(rowPtr)) diff --git a/executor/split.go b/executor/split.go old mode 100644 new mode 100755 index dc1b75e8018f3..ce1afaa4288b3 --- a/executor/split.go +++ b/executor/split.go @@ -14,15 +14,27 @@ package executor import ( + "bytes" "context" + "encoding/binary" + "fmt" "math" + "time" + "github.com/cznic/mathutil" + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/store/helper" + "github.com/pingcap/tidb/store/tikv" "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -31,54 +43,578 @@ import ( type SplitIndexRegionExec struct { baseExecutor - table table.Table - indexInfo *model.IndexInfo - valueLists [][]types.Datum + tableInfo *model.TableInfo + indexInfo *model.IndexInfo + lower []types.Datum + upper []types.Datum + num int + valueLists [][]types.Datum + splitIdxKeys [][]byte + + done bool + splitRegionResult +} + +type splitRegionResult struct { + splitRegions int + finishScatterNum int } -type splitableStore interface { - SplitRegionAndScatter(splitKey kv.Key) (uint64, error) - WaitScatterRegionFinish(regionID uint64) error +// Open implements the Executor Open interface. +func (e *SplitIndexRegionExec) Open(ctx context.Context) (err error) { + e.splitIdxKeys, err = e.getSplitIdxKeys() + return err } // Next implements the Executor Next interface. -func (e *SplitIndexRegionExec) Next(ctx context.Context, _ *chunk.RecordBatch) error { +func (e *SplitIndexRegionExec) Next(ctx context.Context, chk *chunk.Chunk) error { + chk.Reset() + if e.done { + return nil + } + e.done = true + if err := e.splitIndexRegion(ctx); err != nil { + return err + } + + appendSplitRegionResultToChunk(chk, e.splitRegions, e.finishScatterNum) + return nil +} + +// checkScatterRegionFinishBackOff is the back off time that used to check if a region has finished scattering before split region timeout. +const checkScatterRegionFinishBackOff = 50 + +// splitIndexRegion is used to split index regions. +func (e *SplitIndexRegionExec) splitIndexRegion(ctx context.Context) error { store := e.ctx.GetStore() - s, ok := store.(splitableStore) + s, ok := store.(kv.SplitableStore) if !ok { return nil } - regionIDs := make([]uint64, 0, len(e.valueLists)) - index := tables.NewIndex(e.table.Meta().ID, e.table.Meta(), e.indexInfo) - for _, values := range e.valueLists { - idxKey, _, err := index.GenIndexKey(e.ctx.GetSessionVars().StmtCtx, values, math.MinInt64, nil) - if err != nil { - return err + + start := time.Now() + ctxWithTimeout, cancel := context.WithTimeout(ctx, e.ctx.GetSessionVars().GetSplitRegionTimeout()) + defer cancel() + regionIDs, err := s.SplitRegions(context.Background(), e.splitIdxKeys, true) + if err != nil { + logutil.Logger(context.Background()).Warn("split table index region failed", + zap.String("table", e.tableInfo.Name.L), + zap.String("index", e.indexInfo.Name.L), + zap.Error(err)) + } + e.splitRegions = len(regionIDs) + if e.splitRegions == 0 { + return nil + } + + if !e.ctx.GetSessionVars().WaitSplitRegionFinish { + return nil + } + e.finishScatterNum = waitScatterRegionFinish(ctxWithTimeout, e.ctx, start, s, regionIDs, e.tableInfo.Name.L, e.indexInfo.Name.L) + return nil +} + +func (e *SplitIndexRegionExec) getSplitIdxKeys() ([][]byte, error) { + var idxKeys [][]byte + if e.num > 0 { + idxKeys = make([][]byte, 0, e.num) + } else { + idxKeys = make([][]byte, 0, len(e.valueLists)+1) + } + // Split in the start of the index key. + startIdxKey := tablecodec.EncodeTableIndexPrefix(e.tableInfo.ID, e.indexInfo.ID) + idxKeys = append(idxKeys, startIdxKey) + + // Split in the end for the other index key. + for _, idx := range e.tableInfo.Indices { + if idx.ID <= e.indexInfo.ID { + continue + } + endIdxKey := tablecodec.EncodeTableIndexPrefix(e.tableInfo.ID, idx.ID) + idxKeys = append(idxKeys, endIdxKey) + break + } + + index := tables.NewIndex(e.tableInfo.ID, e.tableInfo, e.indexInfo) + // Split index regions by user specified value lists. + if len(e.valueLists) > 0 { + for _, v := range e.valueLists { + idxKey, _, err := index.GenIndexKey(e.ctx.GetSessionVars().StmtCtx, v, math.MinInt64, nil) + if err != nil { + return nil, err + } + idxKeys = append(idxKeys, idxKey) + } + return idxKeys, nil + } + // Split index regions by lower, upper value and calculate the step by (upper - lower)/num. + lowerIdxKey, _, err := index.GenIndexKey(e.ctx.GetSessionVars().StmtCtx, e.lower, math.MinInt64, nil) + if err != nil { + return nil, err + } + // Use math.MinInt64 as handle_id for the upper index key to avoid affecting calculate split point. + // If use math.MaxInt64 here, test of `TestSplitIndex` will report error. + upperIdxKey, _, err := index.GenIndexKey(e.ctx.GetSessionVars().StmtCtx, e.upper, math.MinInt64, nil) + if err != nil { + return nil, err + } + if bytes.Compare(lowerIdxKey, upperIdxKey) >= 0 { + lowerStr, err1 := datumSliceToString(e.lower) + upperStr, err2 := datumSliceToString(e.upper) + if err1 != nil || err2 != nil { + return nil, errors.Errorf("Split index `%v` region lower value %v should less than the upper value %v", e.indexInfo.Name, e.lower, e.upper) + } + return nil, errors.Errorf("Split index `%v` region lower value %v should less than the upper value %v", e.indexInfo.Name, lowerStr, upperStr) + } + return getValuesList(lowerIdxKey, upperIdxKey, e.num, idxKeys), nil +} + +// getValuesList is used to get `num` values between lower and upper value. +// To Simplify the explain, suppose lower and upper value type is int64, and lower=0, upper=100, num=10, +// then calculate the step=(upper-lower)/num=10, then the function should return 0+10, 10+10, 20+10... all together 9 (num-1) values. +// Then the function will return [10,20,30,40,50,60,70,80,90]. +// The difference is the value type of upper,lower is []byte, So I use getUint64FromBytes to convert []byte to uint64. +func getValuesList(lower, upper []byte, num int, valuesList [][]byte) [][]byte { + commonPrefixIdx := longestCommonPrefixLen(lower, upper) + step := getStepValue(lower[commonPrefixIdx:], upper[commonPrefixIdx:], num) + startV := getUint64FromBytes(lower[commonPrefixIdx:], 0) + // To get `num` regions, only need to split `num-1` idx keys. + buf := make([]byte, 8) + for i := 0; i < num-1; i++ { + value := make([]byte, 0, commonPrefixIdx+8) + value = append(value, lower[:commonPrefixIdx]...) + startV += step + binary.BigEndian.PutUint64(buf, startV) + value = append(value, buf...) + valuesList = append(valuesList, value) + } + return valuesList +} + +// longestCommonPrefixLen gets the longest common prefix byte length. +func longestCommonPrefixLen(s1, s2 []byte) int { + l := mathutil.Min(len(s1), len(s2)) + i := 0 + for ; i < l; i++ { + if s1[i] != s2[i] { + break } + } + return i +} + +// getStepValue gets the step of between the lower and upper value. step = (upper-lower)/num. +// Convert byte slice to uint64 first. +func getStepValue(lower, upper []byte, num int) uint64 { + lowerUint := getUint64FromBytes(lower, 0) + upperUint := getUint64FromBytes(upper, 0xff) + return (upperUint - lowerUint) / uint64(num) +} - regionID, err := s.SplitRegionAndScatter(idxKey) +// getUint64FromBytes gets a uint64 from the `bs` byte slice. +// If len(bs) < 8, then padding with `pad`. +func getUint64FromBytes(bs []byte, pad byte) uint64 { + buf := bs + if len(buf) < 8 { + buf = make([]byte, 0, 8) + buf = append(buf, bs...) + for i := len(buf); i < 8; i++ { + buf = append(buf, pad) + } + } + return binary.BigEndian.Uint64(buf) +} + +func datumSliceToString(ds []types.Datum) (string, error) { + str := "(" + for i, d := range ds { + s, err := d.ToString() if err != nil { - logutil.Logger(context.Background()).Warn("split table index region failed", - zap.String("table", e.table.Meta().Name.L), - zap.String("index", e.indexInfo.Name.L), - zap.Error(err)) - continue + return str, err } - regionIDs = append(regionIDs, regionID) + if i > 0 { + str += "," + } + str += s + } + str += ")" + return str, nil +} + +// SplitTableRegionExec represents a split table regions executor. +type SplitTableRegionExec struct { + baseExecutor + + tableInfo *model.TableInfo + lower types.Datum + upper types.Datum + num int + valueLists [][]types.Datum + splitKeys [][]byte + done bool + splitRegionResult +} + +// Open implements the Executor Open interface. +func (e *SplitTableRegionExec) Open(ctx context.Context) (err error) { + e.splitKeys, err = e.getSplitTableKeys() + return err +} + +// Next implements the Executor Next interface. +func (e *SplitTableRegionExec) Next(ctx context.Context, chk *chunk.Chunk) error { + chk.Reset() + if e.done { + return nil + } + e.done = true + + if err := e.splitTableRegion(ctx); err != nil { + return err + } + appendSplitRegionResultToChunk(chk, e.splitRegions, e.finishScatterNum) + return nil +} + +func (e *SplitTableRegionExec) splitTableRegion(ctx context.Context) error { + store := e.ctx.GetStore() + s, ok := store.(kv.SplitableStore) + if !ok { + return nil } - if !e.ctx.GetSessionVars().WaitTableSplitFinish { + + start := time.Now() + ctxWithTimeout, cancel := context.WithTimeout(ctx, e.ctx.GetSessionVars().GetSplitRegionTimeout()) + defer cancel() + + regionIDs, err := s.SplitRegions(ctxWithTimeout, e.splitKeys, true) + if err != nil { + logutil.Logger(context.Background()).Warn("split table region failed", + zap.String("table", e.tableInfo.Name.L), + zap.Error(err)) + } + e.splitRegions = len(regionIDs) + if e.splitRegions == 0 { + return nil + } + + if !e.ctx.GetSessionVars().WaitSplitRegionFinish { return nil } + + e.finishScatterNum = waitScatterRegionFinish(ctxWithTimeout, e.ctx, start, s, regionIDs, e.tableInfo.Name.L, "") + return nil +} + +func waitScatterRegionFinish(ctxWithTimeout context.Context, sctx sessionctx.Context, startTime time.Time, store kv.SplitableStore, regionIDs []uint64, tableName, indexName string) int { + remainMillisecond := 0 + finishScatterNum := 0 for _, regionID := range regionIDs { - err := s.WaitScatterRegionFinish(regionID) + if isCtxDone(ctxWithTimeout) { + // Do not break here for checking remain regions scatter finished with a very short backoff time. + // Consider this situation - Regions 1, 2, and 3 are to be split. + // Region 1 times out before scattering finishes, while Region 2 and Region 3 have finished scattering. + // In this case, we should return 2 Regions, instead of 0, have finished scattering. + remainMillisecond = checkScatterRegionFinishBackOff + } else { + remainMillisecond = int((sctx.GetSessionVars().GetSplitRegionTimeout().Seconds() - time.Since(startTime).Seconds()) * 1000) + } + + err := store.WaitScatterRegionFinish(regionID, remainMillisecond) + if err == nil { + finishScatterNum++ + } else { + if len(indexName) == 0 { + logutil.Logger(context.Background()).Warn("wait scatter region failed", + zap.Uint64("regionID", regionID), + zap.String("table", tableName), + zap.Error(err)) + } else { + logutil.Logger(context.Background()).Warn("wait scatter region failed", + zap.Uint64("regionID", regionID), + zap.String("table", tableName), + zap.String("index", indexName), + zap.Error(err)) + } + } + } + return finishScatterNum +} + +func appendSplitRegionResultToChunk(chk *chunk.Chunk, totalRegions, finishScatterNum int) { + chk.AppendInt64(0, int64(totalRegions)) + if finishScatterNum > 0 && totalRegions > 0 { + chk.AppendFloat64(1, float64(finishScatterNum)/float64(totalRegions)) + } else { + chk.AppendFloat64(1, float64(0)) + } +} + +func isCtxDone(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + return false + } +} + +var minRegionStepValue = uint64(1000) + +func (e *SplitTableRegionExec) getSplitTableKeys() ([][]byte, error) { + var keys [][]byte + if e.num > 0 { + keys = make([][]byte, 0, e.num) + } else { + keys = make([][]byte, 0, len(e.valueLists)) + } + recordPrefix := tablecodec.GenTableRecordPrefix(e.tableInfo.ID) + if len(e.valueLists) > 0 { + for _, v := range e.valueLists { + key := tablecodec.EncodeRecordKey(recordPrefix, v[0].GetInt64()) + keys = append(keys, key) + } + return keys, nil + } + isUnsigned := false + if e.tableInfo.PKIsHandle { + if pkCol := e.tableInfo.GetPkColInfo(); pkCol != nil { + isUnsigned = mysql.HasUnsignedFlag(pkCol.Flag) + } + } + var step uint64 + var lowerValue int64 + if isUnsigned { + lowerRecordID := e.lower.GetUint64() + upperRecordID := e.upper.GetUint64() + if upperRecordID <= lowerRecordID { + return nil, errors.Errorf("Split table `%s` region lower value %v should less than the upper value %v", e.tableInfo.Name, lowerRecordID, upperRecordID) + } + step = (upperRecordID - lowerRecordID) / uint64(e.num) + lowerValue = int64(lowerRecordID) + } else { + lowerRecordID := e.lower.GetInt64() + upperRecordID := e.upper.GetInt64() + if upperRecordID <= lowerRecordID { + return nil, errors.Errorf("Split table `%s` region lower value %v should less than the upper value %v", e.tableInfo.Name, lowerRecordID, upperRecordID) + } + step = uint64(upperRecordID-lowerRecordID) / uint64(e.num) + lowerValue = lowerRecordID + } + if step < minRegionStepValue { + return nil, errors.Errorf("Split table `%s` region step value should more than %v, step %v is invalid", e.tableInfo.Name, minRegionStepValue, step) + } + + // Split a separate region for index. + if len(e.tableInfo.Indices) > 0 { + keys = append(keys, recordPrefix) + } + recordID := lowerValue + for i := 1; i < e.num; i++ { + recordID += int64(step) + key := tablecodec.EncodeRecordKey(recordPrefix, recordID) + keys = append(keys, key) + } + return keys, nil +} + +// RegionMeta contains a region's peer detail +type regionMeta struct { + region *metapb.Region + leaderID uint64 + storeID uint64 // storeID is the store ID of the leader region. + start string + end string + scattering bool + writtenBytes int64 + readBytes int64 + approximateSize int64 + approximateKeys int64 +} + +func getPhysicalTableRegions(physicalTableID int64, tableInfo *model.TableInfo, tikvStore tikv.Storage, s kv.SplitableStore, uniqueRegionMap map[uint64]struct{}) ([]regionMeta, error) { + if uniqueRegionMap == nil { + uniqueRegionMap = make(map[uint64]struct{}) + } + // for record + startKey, endKey := tablecodec.GetTableHandleKeyRange(physicalTableID) + regionCache := tikvStore.GetRegionCache() + recordRegionMetas, err := regionCache.LoadRegionsInKeyRange(tikv.NewBackoffer(context.Background(), 20000), startKey, endKey) + if err != nil { + return nil, err + } + recordPrefix := tablecodec.GenTableRecordPrefix(physicalTableID) + tablePrefix := tablecodec.GenTablePrefix(physicalTableID) + recordRegions, err := getRegionMeta(tikvStore, recordRegionMetas, uniqueRegionMap, tablePrefix, recordPrefix, nil, physicalTableID, 0) + if err != nil { + return nil, err + } + + regions := recordRegions + // for indices + for _, index := range tableInfo.Indices { + if index.State != model.StatePublic { + continue + } + startKey, endKey := tablecodec.GetTableIndexKeyRange(physicalTableID, index.ID) + regionMetas, err := regionCache.LoadRegionsInKeyRange(tikv.NewBackoffer(context.Background(), 20000), startKey, endKey) + if err != nil { + return nil, err + } + indexPrefix := tablecodec.EncodeTableIndexPrefix(physicalTableID, index.ID) + indexRegions, err := getRegionMeta(tikvStore, regionMetas, uniqueRegionMap, tablePrefix, recordPrefix, indexPrefix, physicalTableID, index.ID) + if err != nil { + return nil, err + } + regions = append(regions, indexRegions...) + } + err = checkRegionsStatus(s, regions) + if err != nil { + return nil, err + } + return regions, nil +} + +func getPhysicalIndexRegions(physicalTableID int64, indexInfo *model.IndexInfo, tikvStore tikv.Storage, s kv.SplitableStore, uniqueRegionMap map[uint64]struct{}) ([]regionMeta, error) { + if uniqueRegionMap == nil { + uniqueRegionMap = make(map[uint64]struct{}) + } + + startKey, endKey := tablecodec.GetTableIndexKeyRange(physicalTableID, indexInfo.ID) + regionCache := tikvStore.GetRegionCache() + regions, err := regionCache.LoadRegionsInKeyRange(tikv.NewBackoffer(context.Background(), 20000), startKey, endKey) + if err != nil { + return nil, err + } + recordPrefix := tablecodec.GenTableRecordPrefix(physicalTableID) + tablePrefix := tablecodec.GenTablePrefix(physicalTableID) + indexPrefix := tablecodec.EncodeTableIndexPrefix(physicalTableID, indexInfo.ID) + indexRegions, err := getRegionMeta(tikvStore, regions, uniqueRegionMap, tablePrefix, recordPrefix, indexPrefix, physicalTableID, indexInfo.ID) + if err != nil { + return nil, err + } + err = checkRegionsStatus(s, indexRegions) + if err != nil { + return nil, err + } + return indexRegions, nil +} + +func checkRegionsStatus(store kv.SplitableStore, regions []regionMeta) error { + for i := range regions { + scattering, err := store.CheckRegionInScattering(regions[i].region.Id) if err != nil { - logutil.Logger(context.Background()).Warn("wait scatter region failed", - zap.Uint64("regionID", regionID), - zap.String("table", e.table.Meta().Name.L), - zap.String("index", e.indexInfo.Name.L), - zap.Error(err)) + return err } + regions[i].scattering = scattering } return nil } + +func decodeRegionsKey(regions []regionMeta, tablePrefix, recordPrefix, indexPrefix []byte, physicalTableID, indexID int64) { + d := ®ionKeyDecoder{ + physicalTableID: physicalTableID, + tablePrefix: tablePrefix, + recordPrefix: recordPrefix, + indexPrefix: indexPrefix, + indexID: indexID, + } + for i := range regions { + regions[i].start = d.decodeRegionKey(regions[i].region.StartKey) + regions[i].end = d.decodeRegionKey(regions[i].region.EndKey) + } +} + +type regionKeyDecoder struct { + physicalTableID int64 + tablePrefix []byte + recordPrefix []byte + indexPrefix []byte + indexID int64 +} + +func (d *regionKeyDecoder) decodeRegionKey(key []byte) string { + if len(d.indexPrefix) > 0 && bytes.HasPrefix(key, d.indexPrefix) { + return fmt.Sprintf("t_%d_i_%d_%x", d.physicalTableID, d.indexID, key[len(d.indexPrefix):]) + } else if len(d.recordPrefix) > 0 && bytes.HasPrefix(key, d.recordPrefix) { + if len(d.recordPrefix) == len(key) { + return fmt.Sprintf("t_%d_r", d.physicalTableID) + } + _, handle, err := codec.DecodeInt(key[len(d.recordPrefix):]) + if err == nil { + return fmt.Sprintf("t_%d_r_%d", d.physicalTableID, handle) + } + } + if len(d.tablePrefix) > 0 && bytes.HasPrefix(key, d.tablePrefix) { + key = key[len(d.tablePrefix):] + // Has index prefix. + if !bytes.HasPrefix(key, []byte("_i")) { + return fmt.Sprintf("t_%d_%x", d.physicalTableID, key) + } + key = key[2:] + // try to decode index ID. + if _, indexID, err := codec.DecodeInt(key); err == nil { + return fmt.Sprintf("t_%d_i_%d_%x", d.physicalTableID, indexID, key[8:]) + } + return fmt.Sprintf("t_%d_i__%x", d.physicalTableID, key) + } + // Has table prefix. + if bytes.HasPrefix(key, []byte("t")) { + key = key[1:] + // try to decode table ID. + if _, tableID, err := codec.DecodeInt(key); err == nil { + return fmt.Sprintf("t_%d_%x", tableID, key[8:]) + } + return fmt.Sprintf("t_%x", key) + } + return fmt.Sprintf("%x", key) +} + +func getRegionMeta(tikvStore tikv.Storage, regionMetas []*tikv.Region, uniqueRegionMap map[uint64]struct{}, tablePrefix, recordPrefix, indexPrefix []byte, physicalTableID, indexID int64) ([]regionMeta, error) { + regions := make([]regionMeta, 0, len(regionMetas)) + for _, r := range regionMetas { + if _, ok := uniqueRegionMap[r.GetID()]; ok { + continue + } + uniqueRegionMap[r.GetID()] = struct{}{} + regions = append(regions, regionMeta{ + region: r.GetMeta(), + leaderID: r.GetLeaderID(), + storeID: r.GetLeaderStoreID(), + }) + } + regions, err := getRegionInfo(tikvStore, regions) + if err != nil { + return regions, err + } + decodeRegionsKey(regions, tablePrefix, recordPrefix, indexPrefix, physicalTableID, indexID) + return regions, nil +} + +func getRegionInfo(store tikv.Storage, regions []regionMeta) ([]regionMeta, error) { + // check pd server exists. + etcd, ok := store.(tikv.EtcdBackend) + if !ok { + return regions, nil + } + pdHosts := etcd.EtcdAddrs() + if len(pdHosts) == 0 { + return regions, nil + } + tikvHelper := &helper.Helper{ + Store: store, + RegionCache: store.GetRegionCache(), + } + for i := range regions { + regionInfo, err := tikvHelper.GetRegionInfoByID(regions[i].region.Id) + if err != nil { + return nil, err + } + regions[i].writtenBytes = regionInfo.WrittenBytes + regions[i].readBytes = regionInfo.ReadBytes + regions[i].approximateSize = regionInfo.ApproximateSize + regions[i].approximateKeys = regionInfo.ApproximateKeys + } + return regions, nil +} diff --git a/executor/split_test.go b/executor/split_test.go new file mode 100644 index 0000000000000..c090faf20d3c7 --- /dev/null +++ b/executor/split_test.go @@ -0,0 +1,374 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "bytes" + "encoding/binary" + "math" + "math/rand" + + . "github.com/pingcap/check" + "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/table/tables" + "github.com/pingcap/tidb/tablecodec" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/mock" +) + +var _ = Suite(&testSplitIndex{}) + +type testSplitIndex struct { +} + +func (s *testSplitIndex) SetUpSuite(c *C) { +} + +func (s *testSplitIndex) TearDownSuite(c *C) { +} + +func (s *testSplitIndex) TestLongestCommonPrefixLen(c *C) { + cases := []struct { + s1 string + s2 string + l int + }{ + {"", "", 0}, + {"", "a", 0}, + {"a", "", 0}, + {"a", "a", 1}, + {"ab", "a", 1}, + {"a", "ab", 1}, + {"b", "ab", 0}, + {"ba", "ab", 0}, + } + + for _, ca := range cases { + re := longestCommonPrefixLen([]byte(ca.s1), []byte(ca.s2)) + c.Assert(re, Equals, ca.l) + } +} + +func (s *testSplitIndex) TestgetStepValue(c *C) { + cases := []struct { + lower []byte + upper []byte + l int + v uint64 + }{ + {[]byte{}, []byte{}, 0, math.MaxUint64}, + {[]byte{0}, []byte{128}, 0, binary.BigEndian.Uint64([]byte{128, 255, 255, 255, 255, 255, 255, 255})}, + {[]byte{'a'}, []byte{'z'}, 0, binary.BigEndian.Uint64([]byte{'z' - 'a', 255, 255, 255, 255, 255, 255, 255})}, + {[]byte("abc"), []byte{'z'}, 0, binary.BigEndian.Uint64([]byte{'z' - 'a', 255 - 'b', 255 - 'c', 255, 255, 255, 255, 255})}, + {[]byte("abc"), []byte("xyz"), 0, binary.BigEndian.Uint64([]byte{'x' - 'a', 'y' - 'b', 'z' - 'c', 255, 255, 255, 255, 255})}, + {[]byte("abc"), []byte("axyz"), 1, binary.BigEndian.Uint64([]byte{'x' - 'b', 'y' - 'c', 'z', 255, 255, 255, 255, 255})}, + {[]byte("abc0123456"), []byte("xyz01234"), 0, binary.BigEndian.Uint64([]byte{'x' - 'a', 'y' - 'b', 'z' - 'c', 0, 0, 0, 0, 0})}, + } + + for _, ca := range cases { + l := longestCommonPrefixLen(ca.lower, ca.upper) + c.Assert(l, Equals, ca.l) + v0 := getStepValue(ca.lower[l:], ca.upper[l:], 1) + c.Assert(v0, Equals, ca.v) + } +} + +func (s *testSplitIndex) TestSplitIndex(c *C) { + tbInfo := &model.TableInfo{ + Name: model.NewCIStr("t1"), + ID: rand.Int63(), + Columns: []*model.ColumnInfo{ + { + Name: model.NewCIStr("c0"), + ID: 1, + Offset: 1, + DefaultValue: 0, + State: model.StatePublic, + FieldType: *types.NewFieldType(mysql.TypeLong), + }, + }, + } + idxCols := []*model.IndexColumn{{Name: tbInfo.Columns[0].Name, Offset: 0, Length: types.UnspecifiedLength}} + idxInfo := &model.IndexInfo{ + ID: 1, + Name: model.NewCIStr("idx1"), + Table: model.NewCIStr("t1"), + Columns: idxCols, + State: model.StatePublic, + } + + // Test for int index. + // range is 0 ~ 100, and split into 10 region. + // So 10 regions range is like below, left close right open interval: + // region1: [-inf ~ 10) + // region2: [10 ~ 20) + // region3: [20 ~ 30) + // region4: [30 ~ 40) + // region5: [40 ~ 50) + // region6: [50 ~ 60) + // region7: [60 ~ 70) + // region8: [70 ~ 80) + // region9: [80 ~ 90) + // region10: [90 ~ +inf) + ctx := mock.NewContext() + e := &SplitIndexRegionExec{ + baseExecutor: newBaseExecutor(ctx, nil, nil), + tableInfo: tbInfo, + indexInfo: idxInfo, + lower: []types.Datum{types.NewDatum(0)}, + upper: []types.Datum{types.NewDatum(100)}, + num: 10, + } + valueList, err := e.getSplitIdxKeys() + c.Assert(err, IsNil) + c.Assert(len(valueList), Equals, e.num) + + cases := []struct { + value int + lessEqualIdx int + }{ + {-1, 0}, + {0, 0}, + {1, 0}, + {10, 1}, + {11, 1}, + {20, 2}, + {21, 2}, + {31, 3}, + {41, 4}, + {51, 5}, + {61, 6}, + {71, 7}, + {81, 8}, + {91, 9}, + {100, 9}, + {1000, 9}, + } + + index := tables.NewIndex(tbInfo.ID, tbInfo, idxInfo) + for _, ca := range cases { + // test for minInt64 handle + idxValue, _, err := index.GenIndexKey(ctx.GetSessionVars().StmtCtx, []types.Datum{types.NewDatum(ca.value)}, math.MinInt64, nil) + c.Assert(err, IsNil) + idx := searchLessEqualIdx(valueList, idxValue) + c.Assert(idx, Equals, ca.lessEqualIdx, Commentf("%#v", ca)) + + // Test for max int64 handle. + idxValue, _, err = index.GenIndexKey(ctx.GetSessionVars().StmtCtx, []types.Datum{types.NewDatum(ca.value)}, math.MaxInt64, nil) + c.Assert(err, IsNil) + idx = searchLessEqualIdx(valueList, idxValue) + c.Assert(idx, Equals, ca.lessEqualIdx, Commentf("%#v", ca)) + } + // Test for varchar index. + // range is a ~ z, and split into 26 region. + // So 26 regions range is like below: + // region1: [-inf ~ b) + // region2: [b ~ c) + // . + // . + // . + // region26: [y ~ +inf) + e.lower = []types.Datum{types.NewDatum("a")} + e.upper = []types.Datum{types.NewDatum("z")} + e.num = 26 + // change index column type to varchar + tbInfo.Columns[0].FieldType = *types.NewFieldType(mysql.TypeVarchar) + + valueList, err = e.getSplitIdxKeys() + c.Assert(err, IsNil) + c.Assert(len(valueList), Equals, e.num) + + cases2 := []struct { + value string + lessEqualIdx int + }{ + {"", 0}, + {"a", 0}, + {"abcde", 0}, + {"b", 1}, + {"bzzzz", 1}, + {"c", 2}, + {"czzzz", 2}, + {"z", 25}, + {"zabcd", 25}, + } + + for _, ca := range cases2 { + // test for minInt64 handle + idxValue, _, err := index.GenIndexKey(ctx.GetSessionVars().StmtCtx, []types.Datum{types.NewDatum(ca.value)}, math.MinInt64, nil) + c.Assert(err, IsNil) + idx := searchLessEqualIdx(valueList, idxValue) + c.Assert(idx, Equals, ca.lessEqualIdx, Commentf("%#v", ca)) + + // Test for max int64 handle. + idxValue, _, err = index.GenIndexKey(ctx.GetSessionVars().StmtCtx, []types.Datum{types.NewDatum(ca.value)}, math.MaxInt64, nil) + c.Assert(err, IsNil) + idx = searchLessEqualIdx(valueList, idxValue) + c.Assert(idx, Equals, ca.lessEqualIdx, Commentf("%#v", ca)) + } + + // Test for timestamp index. + // range is 2010-01-01 00:00:00 ~ 2020-01-01 00:00:00, and split into 10 region. + // So 10 regions range is like below: + // region1: [-inf ~ 2011-01-01 00:00:00) + // region2: [2011-01-01 00:00:00 ~ 2012-01-01 00:00:00) + // . + // . + // . + // region10: [2019-01-01 00:00:00 ~ +inf) + lowerTime := types.Time{ + Time: types.FromDate(2010, 1, 1, 0, 0, 0, 0), + Type: mysql.TypeTimestamp, + } + upperTime := types.Time{ + Time: types.FromDate(2020, 1, 1, 0, 0, 0, 0), + Type: mysql.TypeTimestamp, + } + e.lower = []types.Datum{types.NewDatum(lowerTime)} + e.upper = []types.Datum{types.NewDatum(upperTime)} + e.num = 10 + + // change index column type to timestamp + tbInfo.Columns[0].FieldType = *types.NewFieldType(mysql.TypeTimestamp) + + valueList, err = e.getSplitIdxKeys() + c.Assert(err, IsNil) + c.Assert(len(valueList), Equals, e.num) + + cases3 := []struct { + value types.MysqlTime + lessEqualIdx int + }{ + {types.FromDate(2009, 11, 20, 12, 50, 59, 0), 0}, + {types.FromDate(2010, 1, 1, 0, 0, 0, 0), 0}, + {types.FromDate(2011, 12, 31, 23, 59, 59, 0), 1}, + {types.FromDate(2011, 2, 1, 0, 0, 0, 0), 1}, + {types.FromDate(2012, 3, 1, 0, 0, 0, 0), 2}, + {types.FromDate(2013, 4, 1, 0, 0, 0, 0), 3}, + {types.FromDate(2014, 5, 1, 0, 0, 0, 0), 4}, + {types.FromDate(2015, 6, 1, 0, 0, 0, 0), 5}, + {types.FromDate(2016, 8, 1, 0, 0, 0, 0), 6}, + {types.FromDate(2017, 9, 1, 0, 0, 0, 0), 7}, + {types.FromDate(2018, 10, 1, 0, 0, 0, 0), 8}, + {types.FromDate(2019, 11, 1, 0, 0, 0, 0), 9}, + {types.FromDate(2020, 12, 1, 0, 0, 0, 0), 9}, + {types.FromDate(2030, 12, 1, 0, 0, 0, 0), 9}, + } + + for _, ca := range cases3 { + value := types.Time{ + Time: ca.value, + Type: mysql.TypeTimestamp, + } + // test for min int64 handle + idxValue, _, err := index.GenIndexKey(ctx.GetSessionVars().StmtCtx, []types.Datum{types.NewDatum(value)}, math.MinInt64, nil) + c.Assert(err, IsNil) + idx := searchLessEqualIdx(valueList, idxValue) + c.Assert(idx, Equals, ca.lessEqualIdx, Commentf("%#v", ca)) + + // Test for max int64 handle. + idxValue, _, err = index.GenIndexKey(ctx.GetSessionVars().StmtCtx, []types.Datum{types.NewDatum(value)}, math.MaxInt64, nil) + c.Assert(err, IsNil) + idx = searchLessEqualIdx(valueList, idxValue) + c.Assert(idx, Equals, ca.lessEqualIdx, Commentf("%#v", ca)) + } +} + +func (s *testSplitIndex) TestSplitTable(c *C) { + tbInfo := &model.TableInfo{ + Name: model.NewCIStr("t1"), + ID: rand.Int63(), + Columns: []*model.ColumnInfo{ + { + Name: model.NewCIStr("c0"), + ID: 1, + Offset: 1, + DefaultValue: 0, + State: model.StatePublic, + FieldType: *types.NewFieldType(mysql.TypeLong), + }, + }, + } + defer func(originValue uint64) { + minRegionStepValue = originValue + }(minRegionStepValue) + minRegionStepValue = 10 + // range is 0 ~ 100, and split into 10 region. + // So 10 regions range is like below: + // region1: [-inf ~ 10) + // region2: [10 ~ 20) + // region3: [20 ~ 30) + // region4: [30 ~ 40) + // region5: [40 ~ 50) + // region6: [50 ~ 60) + // region7: [60 ~ 70) + // region8: [70 ~ 80) + // region9: [80 ~ 90 ) + // region10: [90 ~ +inf) + ctx := mock.NewContext() + e := &SplitTableRegionExec{ + baseExecutor: newBaseExecutor(ctx, nil, nil), + tableInfo: tbInfo, + lower: types.NewDatum(0), + upper: types.NewDatum(100), + num: 10, + } + valueList, err := e.getSplitTableKeys() + c.Assert(err, IsNil) + c.Assert(len(valueList), Equals, e.num-1) + + cases := []struct { + value int + lessEqualIdx int + }{ + {-1, -1}, + {0, -1}, + {1, -1}, + {10, 0}, + {11, 0}, + {20, 1}, + {21, 1}, + {31, 2}, + {41, 3}, + {51, 4}, + {61, 5}, + {71, 6}, + {81, 7}, + {91, 8}, + {100, 8}, + {1000, 8}, + } + + recordPrefix := tablecodec.GenTableRecordPrefix(e.tableInfo.ID) + for _, ca := range cases { + // test for minInt64 handle + key := tablecodec.EncodeRecordKey(recordPrefix, int64(ca.value)) + c.Assert(err, IsNil) + idx := searchLessEqualIdx(valueList, key) + c.Assert(idx, Equals, ca.lessEqualIdx, Commentf("%#v", ca)) + } +} + +func searchLessEqualIdx(valueList [][]byte, value []byte) int { + idx := -1 + for i, v := range valueList { + if bytes.Compare(value, v) >= 0 { + idx = i + continue + } + break + } + return idx +} diff --git a/executor/statement_context_test.go b/executor/statement_context_test.go index a3593105d8fe1..2b3d3f393f63c 100644 --- a/executor/statement_context_test.go +++ b/executor/statement_context_test.go @@ -102,9 +102,12 @@ func (s *testSuite1) TestStatementContext(c *C) { _, err = tk.Exec("insert t1 values (unhex('F0A48BAE'))") c.Assert(err, NotNil) c.Assert(terror.ErrorEqual(err, table.ErrTruncateWrongValue), IsTrue, Commentf("err %v", err)) - config.GetGlobalConfig().CheckMb4ValueInUTF8 = false + conf := config.GetGlobalConfig() + conf.CheckMb4ValueInUTF8 = false + config.StoreGlobalConfig(conf) tk.MustExec("insert t1 values (unhex('f09f8c80'))") - config.GetGlobalConfig().CheckMb4ValueInUTF8 = true + conf.CheckMb4ValueInUTF8 = true + config.StoreGlobalConfig(conf) _, err = tk.Exec("insert t1 values (unhex('F0A48BAE'))") c.Assert(err, NotNil) } diff --git a/executor/table_reader.go b/executor/table_reader.go index 327b148b02de7..72d04a9968e8d 100644 --- a/executor/table_reader.go +++ b/executor/table_reader.go @@ -28,8 +28,8 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/memory" "github.com/pingcap/tidb/util/ranger" - "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tipb/go-tipb" ) @@ -74,11 +74,16 @@ type TableReaderExecutor struct { corColInAccess bool plans []plannercore.PhysicalPlan + memTracker *memory.Tracker + selectResultHook // for testing } // Open initialzes necessary variables for using this executor. func (e *TableReaderExecutor) Open(ctx context.Context) error { + e.memTracker = memory.NewTracker(e.id, e.ctx.GetSessionVars().MemQuotaDistSQL) + e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) + var err error if e.corColInFilter { e.dagPB.Executors, _, err = constructDistExec(e.ctx, e.plans) @@ -131,7 +136,7 @@ func (e *TableReaderExecutor) Open(ctx context.Context) error { // Next fills data into the chunk passed by its caller. // The task was actually done by tableReaderHandler. -func (e *TableReaderExecutor) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *TableReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("tableReader.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -140,7 +145,7 @@ func (e *TableReaderExecutor) Next(ctx context.Context, req *chunk.RecordBatch) start := time.Now() defer func() { e.runtimeStats.Record(time.Since(start), req.NumRows()) }() } - if err := e.resultHandler.nextChunk(ctx, req.Chunk); err != nil { + if err := e.resultHandler.nextChunk(ctx, req); err != nil { e.feedback.Invalidate() return err } @@ -158,8 +163,6 @@ func (e *TableReaderExecutor) Close() error { return err } -var tableReaderDistSQLTrackerLabel fmt.Stringer = stringutil.StringerStr("TableReaderDistSQLTracker") - // buildResp first builds request and sends it to tikv using distsql.Select. It uses SelectResut returned by the callee // to fetch all results. func (e *TableReaderExecutor) buildResp(ctx context.Context, ranges []*ranger.Range) (distsql.SelectResult, error) { @@ -170,12 +173,12 @@ func (e *TableReaderExecutor) buildResp(ctx context.Context, ranges []*ranger.Ra SetKeepOrder(e.keepOrder). SetStreaming(e.streaming). SetFromSessionVars(e.ctx.GetSessionVars()). - SetMemTracker(e.ctx, tableReaderDistSQLTrackerLabel). + SetMemTracker(e.memTracker). Build() if err != nil { return nil, err } - result, err := e.SelectResult(ctx, e.ctx, kvReq, e.retTypes(), e.feedback, getPhysicalPlanIDs(e.plans)) + result, err := e.SelectResult(ctx, e.ctx, kvReq, retTypes(e), e.feedback, getPhysicalPlanIDs(e.plans)) if err != nil { return nil, err } diff --git a/executor/table_readers_required_rows_test.go b/executor/table_readers_required_rows_test.go index 21819329d6a82..7f7e55caa07b6 100644 --- a/executor/table_readers_required_rows_test.go +++ b/executor/table_readers_required_rows_test.go @@ -178,10 +178,10 @@ func (s *testExecSuite) TestTableReaderRequiredRows(c *C) { ctx := mockDistsqlSelectCtxSet(testCase.totalRows, testCase.expectedRowsDS) exec := buildTableReader(sctx) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) } c.Assert(exec.Close(), IsNil) @@ -230,10 +230,10 @@ func (s *testExecSuite) TestIndexReaderRequiredRows(c *C) { ctx := mockDistsqlSelectCtxSet(testCase.totalRows, testCase.expectedRowsDS) exec := buildIndexReader(sctx) c.Assert(exec.Open(ctx), IsNil) - chk := exec.newFirstChunk() + chk := newFirstChunk(exec) for i := range testCase.requiredRows { chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize) - c.Assert(exec.Next(ctx, chunk.NewRecordBatch(chk)), IsNil) + c.Assert(exec.Next(ctx, chk), IsNil) c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i]) } c.Assert(exec.Close(), IsNil) diff --git a/executor/trace.go b/executor/trace.go index 8b75fa2b4564c..f5fa27b5917d5 100644 --- a/executor/trace.go +++ b/executor/trace.go @@ -48,7 +48,7 @@ type TraceExec struct { } // Next executes real query and collects span later. -func (e *TraceExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *TraceExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() if e.exhausted { return nil @@ -92,7 +92,7 @@ func (e *TraceExec) Next(ctx context.Context, req *chunk.RecordBatch) error { } trace := traces[0] sortTraceByStartTime(trace) - dfsTree(trace, "", false, req.Chunk) + dfsTree(trace, "", false, req) e.exhausted = true return nil } @@ -116,18 +116,18 @@ func (e *TraceExec) Next(ctx context.Context, req *chunk.RecordBatch) error { func drainRecordSet(ctx context.Context, sctx sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) { var rows []chunk.Row - req := rs.NewRecordBatch() + req := rs.NewChunk() for { err := rs.Next(ctx, req) if err != nil || req.NumRows() == 0 { return rows, errors.Trace(err) } - iter := chunk.NewIterator4Chunk(req.Chunk) + iter := chunk.NewIterator4Chunk(req) for r := iter.Begin(); r != iter.End(); r = iter.Next() { rows = append(rows, r) } - req.Chunk = chunk.Renew(req.Chunk, sctx.GetSessionVars().MaxChunkSize) + req = chunk.Renew(req, sctx.GetSessionVars().MaxChunkSize) } } @@ -179,6 +179,18 @@ func dfsTree(t *appdash.Trace, prefix string, isLast bool, chk *chunk.Chunk) { chk.AppendString(1, start.Format("15:04:05.000000")) chk.AppendString(2, duration.String()) + // Sort events by their start time + sort.Slice(t.Sub, func(i, j int) bool { + var istart, jstart time.Time + if ievent, err := t.Sub[i].TimespanEvent(); err == nil { + istart = ievent.Start() + } + if jevent, err := t.Sub[j].TimespanEvent(); err == nil { + jstart = jevent.Start() + } + return istart.Before(jstart) + }) + for i, sp := range t.Sub { dfsTree(sp, newPrefix, i == (len(t.Sub))-1 /*last element of array*/, chk) } diff --git a/executor/trace_test.go b/executor/trace_test.go index fe60b58692cd5..d1e19e210dd5a 100644 --- a/executor/trace_test.go +++ b/executor/trace_test.go @@ -14,6 +14,8 @@ package executor_test import ( + "strings" + . "github.com/pingcap/check" "github.com/pingcap/tidb/util/testkit" ) @@ -39,6 +41,29 @@ func (s *testSuite1) TestTraceExec(c *C) { // | └─recordSet.Next | 22:08:38.249340 | 155.317µs | // +---------------------------+-----------------+------------+ rows = tk.MustQuery("trace format='row' select * from trace where id = 0;").Rows() - c.Assert(len(rows) > 1, IsTrue) + c.Assert(rowsOrdered(rows), IsTrue) +} + +func rowsOrdered(rows [][]interface{}) (ordered bool) { + for idx := range rows { + if idx == 0 || !isSibling(rows[idx-1][0].(string), rows[idx][0].(string)) { + continue + } + + if rows[idx-1][1].(string) > rows[idx][1].(string) { + return false + } + } + return true +} + +func isSibling(x string, y string) bool { + indexF := func(c rune) bool { + if (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') { + return false + } + return true + } + return strings.IndexFunc(x, indexF) == strings.IndexFunc(y, indexF) } diff --git a/executor/union_scan.go b/executor/union_scan.go index 9f95c88d075e0..0098fe99797bd 100644 --- a/executor/union_scan.go +++ b/executor/union_scan.go @@ -16,6 +16,7 @@ package executor import ( "context" "sort" + "sync" "time" "github.com/opentracing/opentracing-go" @@ -32,12 +33,17 @@ import ( // DirtyDB stores uncommitted write operations for a transaction. // It is stored and retrieved by context.Value and context.SetValue method. type DirtyDB struct { + sync.Mutex + // tables is a map whose key is tableID. tables map[int64]*DirtyTable } // GetDirtyTable gets the DirtyTable by id from the DirtyDB. func (udb *DirtyDB) GetDirtyTable(tid int64) *DirtyTable { + // The index join access the tables map parallelly. + // But the map throws panic in this case. So it's locked. + udb.Lock() dt, ok := udb.tables[tid] if !ok { dt = &DirtyTable{ @@ -47,6 +53,7 @@ func (udb *DirtyDB) GetDirtyTable(tid int64) *DirtyTable { } udb.tables[tid] = dt } + udb.Unlock() return dt } @@ -117,12 +124,12 @@ func (us *UnionScanExec) Open(ctx context.Context) error { if err := us.baseExecutor.Open(ctx); err != nil { return err } - us.snapshotChunkBuffer = us.newFirstChunk() + us.snapshotChunkBuffer = newFirstChunk(us) return nil } // Next implements the Executor Next interface. -func (us *UnionScanExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (us *UnionScanExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("unionScan.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -133,7 +140,7 @@ func (us *UnionScanExec) Next(ctx context.Context, req *chunk.RecordBatch) error defer func() { us.runtimeStats.Record(time.Since(start), req.NumRows()) }() } req.GrowAndReset(us.maxChunkSize) - mutableRow := chunk.MutRowFromTypes(us.retTypes()) + mutableRow := chunk.MutRowFromTypes(retTypes(us)) for i, batchSize := 0, req.Capacity(); i < batchSize; i++ { row, err := us.getOneRow(ctx) if err != nil { @@ -199,7 +206,7 @@ func (us *UnionScanExec) getSnapshotRow(ctx context.Context) ([]types.Datum, err us.cursor4SnapshotRows = 0 us.snapshotRows = us.snapshotRows[:0] for len(us.snapshotRows) == 0 { - err = us.children[0].Next(ctx, chunk.NewRecordBatch(us.snapshotChunkBuffer)) + err = Next(ctx, us.children[0], us.snapshotChunkBuffer) if err != nil || us.snapshotChunkBuffer.NumRows() == 0 { return nil, err } @@ -214,7 +221,7 @@ func (us *UnionScanExec) getSnapshotRow(ctx context.Context) ([]types.Datum, err // commit, but for simplicity, we don't handle it here. continue } - us.snapshotRows = append(us.snapshotRows, row.GetDatumRow(us.children[0].retTypes())) + us.snapshotRows = append(us.snapshotRows, row.GetDatumRow(retTypes(us.children[0]))) } } return us.snapshotRows[0], nil @@ -295,7 +302,7 @@ func (us *UnionScanExec) rowWithColsInTxn(t table.Table, h int64, cols []*table. func (us *UnionScanExec) buildAndSortAddedRows(t table.Table) error { us.addedRows = make([][]types.Datum, 0, len(us.dirty.addedRows)) - mutableRow := chunk.MutRowFromTypes(us.retTypes()) + mutableRow := chunk.MutRowFromTypes(retTypes(us)) cols := t.WritableCols() for h := range us.dirty.addedRows { newData := make([]types.Datum, 0, us.schema.Len()) diff --git a/executor/union_scan_test.go b/executor/union_scan_test.go index 7fc90083cf144..bdb56413a85c4 100644 --- a/executor/union_scan_test.go +++ b/executor/union_scan_test.go @@ -28,7 +28,7 @@ func (s *testSuite4) TestDirtyTransaction(c *C) { tk.MustQuery("select * from t").Check(testkit.Rows("2 3", "4 8", "6 8")) tk.MustExec("insert t values (1, 5), (3, 4), (7, 6)") tk.MustQuery("select * from information_schema.columns") - tk.MustQuery("select * from t").Check(testkit.Rows("1 5", "2 3", "3 4", "4 8", "6 8", "7 6")) + tk.MustQuery("select * from t").Check(testkit.Rows("2 3", "3 4", "1 5", "7 6", "4 8", "6 8")) tk.MustQuery("select * from t where a = 1").Check(testkit.Rows("1 5")) tk.MustQuery("select * from t order by a desc").Check(testkit.Rows("7 6", "6 8", "4 8", "3 4", "2 3", "1 5")) tk.MustQuery("select * from t order by b, a").Check(testkit.Rows("2 3", "3 4", "1 5", "7 6", "4 8", "6 8")) @@ -36,13 +36,13 @@ func (s *testSuite4) TestDirtyTransaction(c *C) { tk.MustQuery("select b from t where b = 8 order by b desc").Check(testkit.Rows("8", "8")) // Delete a snapshot row and a dirty row. tk.MustExec("delete from t where a = 2 or a = 3") - tk.MustQuery("select * from t").Check(testkit.Rows("1 5", "4 8", "6 8", "7 6")) + tk.MustQuery("select * from t").Check(testkit.Rows("1 5", "7 6", "4 8", "6 8")) tk.MustQuery("select * from t order by a desc").Check(testkit.Rows("7 6", "6 8", "4 8", "1 5")) tk.MustQuery("select * from t order by b, a").Check(testkit.Rows("1 5", "7 6", "4 8", "6 8")) tk.MustQuery("select * from t order by b desc, a desc").Check(testkit.Rows("6 8", "4 8", "7 6", "1 5")) // Add deleted row back. tk.MustExec("insert t values (2, 3), (3, 4)") - tk.MustQuery("select * from t").Check(testkit.Rows("1 5", "2 3", "3 4", "4 8", "6 8", "7 6")) + tk.MustQuery("select * from t").Check(testkit.Rows("2 3", "3 4", "1 5", "7 6", "4 8", "6 8")) tk.MustQuery("select * from t order by a desc").Check(testkit.Rows("7 6", "6 8", "4 8", "3 4", "2 3", "1 5")) tk.MustQuery("select * from t order by b, a").Check(testkit.Rows("2 3", "3 4", "1 5", "7 6", "4 8", "6 8")) tk.MustQuery("select * from t order by b desc, a desc").Check(testkit.Rows("6 8", "4 8", "7 6", "1 5", "3 4", "2 3")) diff --git a/executor/update.go b/executor/update.go index c2cdbaf9f9d55..e32b46170b812 100644 --- a/executor/update.go +++ b/executor/update.go @@ -132,7 +132,7 @@ func (e *UpdateExec) canNotUpdate(handle types.Datum) bool { } // Next implements the Executor Next interface. -func (e *UpdateExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *UpdateExec) Next(ctx context.Context, req *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("update.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -165,7 +165,7 @@ func (e *UpdateExec) Next(ctx context.Context, req *chunk.RecordBatch) error { } func (e *UpdateExec) fetchChunkRows(ctx context.Context) error { - fields := e.children[0].retTypes() + fields := retTypes(e.children[0]) schema := e.children[0].Schema() colsInfo := make([]*table.Column, len(fields)) for id, cols := range schema.TblID2Handle { @@ -178,10 +178,10 @@ func (e *UpdateExec) fetchChunkRows(ctx context.Context) error { } } globalRowIdx := 0 - chk := e.children[0].newFirstChunk() + chk := newFirstChunk(e.children[0]) e.evalBuffer = chunk.MutRowFromTypes(fields) for { - err := e.children[0].Next(ctx, chunk.NewRecordBatch(chk)) + err := Next(ctx, e.children[0], chk) if err != nil { return err } diff --git a/executor/update_test.go b/executor/update_test.go index cc51e8dc02f08..7d33c2357e2ad 100644 --- a/executor/update_test.go +++ b/executor/update_test.go @@ -52,7 +52,7 @@ func (s *testUpdateSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() } d, err := session.BootstrapSession(s.store) c.Assert(err, IsNil) @@ -87,3 +87,139 @@ func (s *testUpdateSuite) TestUpdateGenColInTxn(c *C) { tk.MustQuery(`select * from t;`).Check(testkit.Rows( `1 2`)) } + +func (s *testUpdateSuite) TestUpdateWithAutoidSchema(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec(`use test`) + tk.MustExec(`create table t1(id int primary key auto_increment, n int);`) + tk.MustExec(`create table t2(id int primary key, n float auto_increment, key I_n(n));`) + tk.MustExec(`create table t3(id int primary key, n double auto_increment, key I_n(n));`) + + tests := []struct { + exec string + query string + result [][]interface{} + }{ + { + `insert into t1 set n = 1`, + `select * from t1 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `update t1 set id = id+1`, + `select * from t1 where id = 2`, + testkit.Rows(`2 1`), + }, + { + `insert into t1 set n = 2`, + `select * from t1 where id = 3`, + testkit.Rows(`3 2`), + }, + { + `update t1 set id = id + '1.1' where id = 3`, + `select * from t1 where id = 4`, + testkit.Rows(`4 2`), + }, + { + `insert into t1 set n = 3`, + `select * from t1 where id = 5`, + testkit.Rows(`5 3`), + }, + { + `update t1 set id = id + '0.5' where id = 5`, + `select * from t1 where id = 6`, + testkit.Rows(`6 3`), + }, + { + `insert into t1 set n = 4`, + `select * from t1 where id = 7`, + testkit.Rows(`7 4`), + }, + { + `insert into t2 set id = 1`, + `select * from t2 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `update t2 set n = n+1`, + `select * from t2 where id = 1`, + testkit.Rows(`1 2`), + }, + { + `insert into t2 set id = 2`, + `select * from t2 where id = 2`, + testkit.Rows(`2 3`), + }, + { + `update t2 set n = n + '2.2'`, + `select * from t2 where id = 2`, + testkit.Rows(`2 5.2`), + }, + { + `insert into t2 set id = 3`, + `select * from t2 where id = 3`, + testkit.Rows(`3 6`), + }, + { + `update t2 set n = n + '0.5' where id = 3`, + `select * from t2 where id = 3`, + testkit.Rows(`3 6.5`), + }, + { + `insert into t2 set id = 4`, + `select * from t2 where id = 4`, + testkit.Rows(`4 7`), + }, + { + `insert into t3 set id = 1`, + `select * from t3 where id = 1`, + testkit.Rows(`1 1`), + }, + { + `update t3 set n = n+1`, + `select * from t3 where id = 1`, + testkit.Rows(`1 2`), + }, + { + `insert into t3 set id = 2`, + `select * from t3 where id = 2`, + testkit.Rows(`2 3`), + }, + { + `update t3 set n = n + '3.3'`, + `select * from t3 where id = 2`, + testkit.Rows(`2 6.3`), + }, + { + `insert into t3 set id = 3`, + `select * from t3 where id = 3`, + testkit.Rows(`3 7`), + }, + { + `update t3 set n = n + '0.5' where id = 3`, + `select * from t3 where id = 3`, + testkit.Rows(`3 7.5`), + }, + { + `insert into t3 set id = 4`, + `select * from t3 where id = 4`, + testkit.Rows(`4 8`), + }, + } + + for _, tt := range tests { + tk.MustExec(tt.exec) + tk.MustQuery(tt.query).Check(tt.result) + } +} + +func (s *testUpdateSuite) TestUpdateWithSubquery(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t1(id varchar(30) not null, status varchar(1) not null default 'N', id2 varchar(30))") + tk.MustExec("create table t2(id varchar(30) not null, field varchar(4) not null)") + tk.MustExec("insert into t1 values('abc', 'F', 'abc')") + tk.MustExec("insert into t2 values('abc', 'MAIN')") + tk.MustExec("update t1 set status = 'N' where status = 'F' and (id in (select id from t2 where field = 'MAIN') or id2 in (select id from t2 where field = 'main'))") + tk.MustQuery("select * from t1").Check(testkit.Rows("abc N abc")) +} diff --git a/executor/window.go b/executor/window.go index bf4e5a2dab0b1..145ee5606d280 100644 --- a/executor/window.go +++ b/executor/window.go @@ -32,27 +32,27 @@ import ( type WindowExec struct { baseExecutor - groupChecker *groupChecker - inputIter *chunk.Iterator4Chunk - inputRow chunk.Row - groupRows []chunk.Row - childResults []*chunk.Chunk - executed bool - meetNewGroup bool - remainingRowsInGroup int - remainingRowsInChunk int - numWindowFuncs int - processor windowProcessor + groupChecker *groupChecker + // inputIter is the iterator of child chunks + inputIter *chunk.Iterator4Chunk + // executed indicates the child executor is drained or something unexpected happened. + executed bool + // resultChunks stores the chunks to return + resultChunks []*chunk.Chunk + // remainingRowsInChunk indicates how many rows the resultChunks[i] is not prepared. + remainingRowsInChunk []int + + numWindowFuncs int + processor windowProcessor } // Close implements the Executor Close interface. func (e *WindowExec) Close() error { - e.childResults = nil return errors.Trace(e.baseExecutor.Close()) } // Next implements the Executor Next interface. -func (e *WindowExec) Next(ctx context.Context, chk *chunk.RecordBatch) error { +func (e *WindowExec) Next(ctx context.Context, chk *chunk.Chunk) error { if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("windowExec.Next", opentracing.ChildOf(span.Context())) defer span1.Finish() @@ -62,120 +62,114 @@ func (e *WindowExec) Next(ctx context.Context, chk *chunk.RecordBatch) error { defer func() { e.runtimeStats.Record(time.Now().Sub(start), chk.NumRows()) }() } chk.Reset() - if e.meetNewGroup && e.remainingRowsInGroup > 0 { - err := e.appendResult2Chunk(chk.Chunk) + for !e.executed && !e.preparedChunkAvailable() { + err := e.consumeOneGroup(ctx) if err != nil { + e.executed = true return err } } - for !e.executed && (chk.NumRows() == 0 || e.remainingRowsInChunk > 0) { - err := e.consumeOneGroup(ctx, chk.Chunk) - if err != nil { - e.executed = true - return errors.Trace(err) - } + if len(e.resultChunks) > 0 { + chk.SwapColumns(e.resultChunks[0]) + e.resultChunks[0] = nil // GC it. TODO: Reuse it. + e.resultChunks = e.resultChunks[1:] + e.remainingRowsInChunk = e.remainingRowsInChunk[1:] } return nil } -func (e *WindowExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) error { - var err error - if err = e.fetchChildIfNecessary(ctx, chk); err != nil { - return errors.Trace(err) - } - for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { - e.meetNewGroup, err = e.groupChecker.meetNewGroup(e.inputRow) +func (e *WindowExec) preparedChunkAvailable() bool { + return len(e.resultChunks) > 0 && e.remainingRowsInChunk[0] == 0 +} + +func (e *WindowExec) consumeOneGroup(ctx context.Context) error { + var groupRows []chunk.Row + for { + eof, err := e.fetchChildIfNecessary(ctx) if err != nil { return errors.Trace(err) } - if e.meetNewGroup { - err := e.consumeGroupRows() + if eof { + e.executed = true + return e.consumeGroupRows(groupRows) + } + for inputRow := e.inputIter.Current(); inputRow != e.inputIter.End(); inputRow = e.inputIter.Next() { + meetNewGroup, err := e.groupChecker.meetNewGroup(inputRow) if err != nil { return errors.Trace(err) } - err = e.appendResult2Chunk(chk) - if err != nil { - return errors.Trace(err) + if meetNewGroup { + return e.consumeGroupRows(groupRows) } - } - e.remainingRowsInGroup++ - e.groupRows = append(e.groupRows, e.inputRow) - if e.meetNewGroup { - e.inputRow = e.inputIter.Next() - return nil + groupRows = append(groupRows, inputRow) } } - return nil } -func (e *WindowExec) consumeGroupRows() (err error) { - if len(e.groupRows) == 0 { +func (e *WindowExec) consumeGroupRows(groupRows []chunk.Row) (err error) { + remainingRowsInGroup := len(groupRows) + if remainingRowsInGroup == 0 { return nil } - e.groupRows, err = e.processor.consumeGroupRows(e.ctx, e.groupRows) - if err != nil { - return errors.Trace(err) + for i := 0; i < len(e.resultChunks); i++ { + remained := mathutil.Min(e.remainingRowsInChunk[i], remainingRowsInGroup) + e.remainingRowsInChunk[i] -= remained + remainingRowsInGroup -= remained + + // TODO: Combine these three methods. + // The old implementation needs the processor has these three methods + // but now it does not have to. + groupRows, err = e.processor.consumeGroupRows(e.ctx, groupRows) + if err != nil { + return errors.Trace(err) + } + _, err = e.processor.appendResult2Chunk(e.ctx, groupRows, e.resultChunks[i], remained) + if err != nil { + return errors.Trace(err) + } + if remainingRowsInGroup == 0 { + e.processor.resetPartialResult() + break + } } return nil } -func (e *WindowExec) fetchChildIfNecessary(ctx context.Context, chk *chunk.Chunk) (err error) { - if e.inputIter != nil && e.inputRow != e.inputIter.End() { - return nil +func (e *WindowExec) fetchChildIfNecessary(ctx context.Context) (EOF bool, err error) { + if e.inputIter != nil && e.inputIter.Current() != e.inputIter.End() { + return false, nil } - // Before fetching a new batch of input, we should consume the last group rows. - err = e.consumeGroupRows() + childResult := newFirstChunk(e.children[0]) + err = Next(ctx, e.children[0], childResult) if err != nil { - return errors.Trace(err) + return false, errors.Trace(err) } - - childResult := e.children[0].newFirstChunk() - err = e.children[0].Next(ctx, &chunk.RecordBatch{Chunk: childResult}) - if err != nil { - return errors.Trace(err) - } - e.childResults = append(e.childResults, childResult) // No more data. - if childResult.NumRows() == 0 { - e.executed = true - err = e.appendResult2Chunk(chk) - return errors.Trace(err) + numRows := childResult.NumRows() + if numRows == 0 { + return true, nil } - e.inputIter = chunk.NewIterator4Chunk(childResult) - e.inputRow = e.inputIter.Begin() - return nil -} - -// appendResult2Chunk appends result of the window function to the result chunk. -func (e *WindowExec) appendResult2Chunk(chk *chunk.Chunk) (err error) { - e.copyChk(chk) - remained := mathutil.Min(e.remainingRowsInChunk, e.remainingRowsInGroup) - e.groupRows, err = e.processor.appendResult2Chunk(e.ctx, e.groupRows, chk, remained) + resultChk := chunk.New(e.retFieldTypes, 0, numRows) + err = e.copyChk(childResult, resultChk) if err != nil { - return err + return false, err } - e.remainingRowsInGroup -= remained - e.remainingRowsInChunk -= remained - if e.remainingRowsInGroup == 0 { - e.processor.resetPartialResult() - e.groupRows = e.groupRows[:0] - } - return nil + e.resultChunks = append(e.resultChunks, resultChk) + e.remainingRowsInChunk = append(e.remainingRowsInChunk, numRows) + + e.inputIter = chunk.NewIterator4Chunk(childResult) + e.inputIter.Begin() + return false, nil } -func (e *WindowExec) copyChk(chk *chunk.Chunk) { - if len(e.childResults) == 0 || chk.NumRows() > 0 { - return - } - childResult := e.childResults[0] - e.childResults = e.childResults[1:] - e.remainingRowsInChunk = childResult.NumRows() +func (e *WindowExec) copyChk(src, dst *chunk.Chunk) error { columns := e.Schema().Columns[:len(e.Schema().Columns)-e.numWindowFuncs] for i, col := range columns { - chk.MakeRefTo(i, childResult, col.Index) + dst.MakeRefTo(i, src, col.Index) } + return nil } // windowProcessor is the interface for processing different kinds of windows. diff --git a/executor/window_test.go b/executor/window_test.go index 357c6ff2c785e..6f9a460d9e36d 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -20,6 +20,7 @@ import ( func (s *testSuite4) TestWindowFunctions(c *C) { tk := testkit.NewTestKit(c, s.store) + var result *testkit.Result tk.MustExec("use test") tk.MustExec("drop table if exists t") tk.MustExec("create table t (a int, b int, c int)") @@ -28,7 +29,7 @@ func (s *testSuite4) TestWindowFunctions(c *C) { tk.MustExec("set @@tidb_enable_window_function = 0") }() tk.MustExec("insert into t values (1,2,3),(4,3,2),(2,3,4)") - result := tk.MustQuery("select count(a) over () from t") + result = tk.MustQuery("select count(a) over () from t") result.Check(testkit.Rows("3", "3", "3")) result = tk.MustQuery("select sum(a) over () + count(a) over () from t") result.Check(testkit.Rows("10", "10", "10")) @@ -160,10 +161,43 @@ func (s *testSuite4) TestWindowFunctions(c *C) { ), ) + tk.MustExec("CREATE TABLE td_dec (id DECIMAL(10,2), sex CHAR(1));") + tk.MustExec("insert into td_dec value (2.0, 'F'), (NULL, 'F'), (1.0, 'F')") + tk.MustQuery("SELECT id, FIRST_VALUE(id) OVER w FROM td_dec WINDOW w AS (ORDER BY id);").Check( + testkit.Rows(" ", "1.00 ", "2.00 "), + ) + result = tk.MustQuery("select sum(a) over w, sum(b) over w from t window w as (order by a)") result.Check(testkit.Rows("2 3", "2 3", "6 6", "6 6")) result = tk.MustQuery("select row_number() over w, sum(b) over w from t window w as (order by a)") result.Check(testkit.Rows("1 3", "2 3", "3 6", "4 6")) result = tk.MustQuery("select row_number() over w, sum(b) over w from t window w as (rows between 1 preceding and 1 following)") result.Check(testkit.Rows("1 3", "2 4", "3 5", "4 3")) + + tk.Se.GetSessionVars().MaxChunkSize = 1 + result = tk.MustQuery("select a, row_number() over (partition by a) from t") + result.Check(testkit.Rows("1 1", "1 2", "2 1", "2 2")) +} + +func (s *testSuite4) TestWindowFunctionsDataReference(c *C) { + // see https://github.com/pingcap/tidb/issues/11614 + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("insert into t values (2,1),(2,2),(2,3)") + + tk.Se.GetSessionVars().MaxChunkSize = 2 + result := tk.MustQuery("select a, b, rank() over (partition by a order by b) from t") + result.Check(testkit.Rows("2 1 1", "2 2 2", "2 3 3")) + result = tk.MustQuery("select a, b, PERCENT_RANK() over (partition by a order by b) from t") + result.Check(testkit.Rows("2 1 0", "2 2 0.5", "2 3 1")) + result = tk.MustQuery("select a, b, CUME_DIST() over (partition by a order by b) from t") + result.Check(testkit.Rows("2 1 0.3333333333333333", "2 2 0.6666666666666666", "2 3 1")) + + // see https://github.com/pingcap/tidb/issues/12415 + result = tk.MustQuery("select b, first_value(b) over (order by b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) from t") + result.Check(testkit.Rows("1 1", "2 1", "3 1")) + result = tk.MustQuery("select b, first_value(b) over (order by b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) from t") + result.Check(testkit.Rows("1 1", "2 1", "3 1")) } diff --git a/executor/write.go b/executor/write.go index f93c302b479b1..2a531eeff63f0 100644 --- a/executor/write.go +++ b/executor/write.go @@ -88,7 +88,11 @@ func updateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datu modified[i] = true // Rebase auto increment id if the field is changed. if mysql.HasAutoIncrementFlag(col.Flag) { - if err = t.RebaseAutoID(ctx, newData[i].GetInt64(), true); err != nil { + recordID, err := getAutoRecordID(newData[i], &col.FieldType, false) + if err != nil { + return false, false, 0, err + } + if err = t.RebaseAutoID(ctx, recordID, true); err != nil { return false, false, 0, err } } diff --git a/executor/write_test.go b/executor/write_test.go index 8749290563a10..e664fe6a2dcb9 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -17,12 +17,10 @@ import ( "context" "errors" "fmt" - "strings" - "sync/atomic" . "github.com/pingcap/check" - "github.com/pingcap/failpoint" "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/core" @@ -253,7 +251,7 @@ func (s *testSuite4) TestInsert(c *C) { tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) tk.MustExec("insert into t select cast(-1 as unsigned);") tk.MustExec("insert into t value (-1.111);") - tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1 overflows bigint")) + tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 constant -1.111 overflows bigint")) tk.MustExec("insert into t value ('-1.111');") tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1690 BIGINT UNSIGNED value is out of range in '-1'")) r = tk.MustQuery("select * from t;") @@ -1777,6 +1775,43 @@ func (s *testSuite4) TestQualifiedDelete(c *C) { c.Assert(err, NotNil) } +func (s *testSuite4) TestLoadDataMissingColumn(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + createSQL := `create table load_data_missing (id int, t timestamp not null)` + tk.MustExec(createSQL) + tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_missing") + ctx := tk.Se.(sessionctx.Context) + ld, ok := ctx.Value(executor.LoadDataVarKey).(*executor.LoadDataInfo) + c.Assert(ok, IsTrue) + defer ctx.SetValue(executor.LoadDataVarKey, nil) + c.Assert(ld, NotNil) + + deleteSQL := "delete from load_data_missing" + selectSQL := "select * from load_data_missing;" + _, reachLimit, err := ld.InsertData(context.Background(), nil, nil) + c.Assert(err, IsNil) + c.Assert(reachLimit, IsFalse) + r := tk.MustQuery(selectSQL) + r.Check(nil) + + curTime := types.CurrentTime(mysql.TypeTimestamp) + timeStr := curTime.String() + tests := []testCase{ + {nil, []byte("12\n"), []string{fmt.Sprintf("12|%v", timeStr)}, nil, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"}, + } + checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) + + tk.MustExec("alter table load_data_missing add column t2 timestamp null") + curTime = types.CurrentTime(mysql.TypeTimestamp) + timeStr = curTime.String() + tests = []testCase{ + {nil, []byte("12\n"), []string{fmt.Sprintf("12|%v|", timeStr)}, nil, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"}, + } + checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) + +} + func (s *testSuite4) TestLoadData(c *C) { trivialMsg := "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0" tk := testkit.NewTestKit(c, s.store) @@ -1802,7 +1837,7 @@ func (s *testSuite4) TestLoadData(c *C) { // data1 = nil, data2 = nil, fields and lines is default ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true - _, reachLimit, err := ld.InsertData(nil, nil) + _, reachLimit, err := ld.InsertData(context.Background(), nil, nil) c.Assert(err, IsNil) c.Assert(reachLimit, IsFalse) r := tk.MustQuery(selectSQL) @@ -2040,123 +2075,6 @@ func (s *testSuite4) TestLoadDataOverflowBigintUnsigned(c *C) { checkCases(tests, ld, c, tk, ctx, selectSQL, deleteSQL) } -func (s *testSuite4) TestBatchInsertDelete(c *C) { - originLimit := atomic.LoadUint64(&kv.TxnEntryCountLimit) - defer func() { - atomic.StoreUint64(&kv.TxnEntryCountLimit, originLimit) - }() - // Set the limitation to a small value, make it easier to reach the limitation. - atomic.StoreUint64(&kv.TxnEntryCountLimit, 100) - - tk := testkit.NewTestKit(c, s.store) - tk.MustExec("use test") - tk.MustExec("drop table if exists batch_insert") - tk.MustExec("create table batch_insert (c int)") - tk.MustExec("drop table if exists batch_insert_on_duplicate") - tk.MustExec("create table batch_insert_on_duplicate (id int primary key, c int)") - // Insert 10 rows. - tk.MustExec("insert into batch_insert values (1),(1),(1),(1),(1),(1),(1),(1),(1),(1)") - r := tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("10")) - // Insert 10 rows. - tk.MustExec("insert into batch_insert (c) select * from batch_insert;") - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("20")) - // Insert 20 rows. - tk.MustExec("insert into batch_insert (c) select * from batch_insert;") - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("40")) - // Insert 40 rows. - tk.MustExec("insert into batch_insert (c) select * from batch_insert;") - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("80")) - // Insert 80 rows. - tk.MustExec("insert into batch_insert (c) select * from batch_insert;") - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("160")) - // for on duplicate key - for i := 0; i < 160; i++ { - tk.MustExec(fmt.Sprintf("insert into batch_insert_on_duplicate values(%d, %d);", i, i)) - } - r = tk.MustQuery("select count(*) from batch_insert_on_duplicate;") - r.Check(testkit.Rows("160")) - - // This will meet txn too large error. - _, err := tk.Exec("insert into batch_insert (c) select * from batch_insert;") - c.Assert(err, NotNil) - c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue) - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("160")) - - // for on duplicate key - _, err = tk.Exec(`insert into batch_insert_on_duplicate select * from batch_insert_on_duplicate as tt - on duplicate key update batch_insert_on_duplicate.id=batch_insert_on_duplicate.id+1000;`) - c.Assert(err, NotNil) - c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue, Commentf("%v", err)) - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("160")) - - // Change to batch inset mode and batch size to 50. - tk.MustExec("set @@session.tidb_batch_insert=1;") - tk.MustExec("set @@session.tidb_dml_batch_size=50;") - tk.MustExec("insert into batch_insert (c) select * from batch_insert;") - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("320")) - - // Enlarge the batch size to 150 which is larger than the txn limitation (100). - // So the insert will meet error. - tk.MustExec("set @@session.tidb_dml_batch_size=150;") - _, err = tk.Exec("insert into batch_insert (c) select * from batch_insert;") - c.Assert(err, NotNil) - c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue) - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("320")) - // Set it back to 50. - tk.MustExec("set @@session.tidb_dml_batch_size=50;") - - // for on duplicate key - _, err = tk.Exec(`insert into batch_insert_on_duplicate select * from batch_insert_on_duplicate as tt - on duplicate key update batch_insert_on_duplicate.id=batch_insert_on_duplicate.id+1000;`) - c.Assert(err, IsNil) - r = tk.MustQuery("select count(*) from batch_insert_on_duplicate;") - r.Check(testkit.Rows("160")) - - // Disable BachInsert mode in transition. - tk.MustExec("begin;") - _, err = tk.Exec("insert into batch_insert (c) select * from batch_insert;") - c.Assert(err, NotNil) - c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue) - tk.MustExec("rollback;") - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("320")) - - tk.MustExec("drop table if exists com_batch_insert") - tk.MustExec("create table com_batch_insert (c int)") - sql := "insert into com_batch_insert values " - values := make([]string, 0, 200) - for i := 0; i < 200; i++ { - values = append(values, "(1)") - } - sql = sql + strings.Join(values, ",") - tk.MustExec(sql) - tk.MustQuery("select count(*) from com_batch_insert;").Check(testkit.Rows("200")) - - // Test case for batch delete. - // This will meet txn too large error. - _, err = tk.Exec("delete from batch_insert;") - c.Assert(err, NotNil) - c.Assert(kv.ErrTxnTooLarge.Equal(err), IsTrue) - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("320")) - // Enable batch delete and set batch size to 50. - tk.MustExec("set @@session.tidb_batch_delete=on;") - tk.MustExec("set @@session.tidb_dml_batch_size=50;") - tk.MustExec("delete from batch_insert;") - // Make sure that all rows are gone. - r = tk.MustQuery("select count(*) from batch_insert;") - r.Check(testkit.Rows("0")) -} - func (s *testSuite4) TestNullDefault(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test; drop table if exists test_null_default;") @@ -2518,20 +2436,14 @@ func (s *testSuite4) TestDefEnumInsert(c *C) { tk.MustQuery("select prescription_type from test").Check(testkit.Rows("a")) } -func (s *testSuite4) TestAutoIDInRetry(c *C) { +func (s *testSuite4) TestSetWithCurrentTimestampAndNow(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) - tk.MustExec("create table t (id int not null auto_increment primary key)") - - tk.MustExec("set @@tidb_disable_txn_auto_retry = 0") - tk.MustExec("begin") - tk.MustExec("insert into t values ()") - tk.MustExec("insert into t values (),()") - tk.MustExec("insert into t values ()") - - c.Assert(failpoint.Enable("github.com/pingcap/tidb/session/mockCommitRetryForAutoID", `return(true)`), IsNil) - tk.MustExec("commit") - c.Assert(failpoint.Disable("github.com/pingcap/tidb/session/mockCommitRetryForAutoID"), IsNil) - - tk.MustExec("insert into t values ()") - tk.MustQuery(`select * from t`).Check(testkit.Rows("1", "2", "3", "4", "5")) + tk.MustExec("use test") + tk.MustExec(`drop table if exists tbl;`) + tk.MustExec(`create table t1(c1 timestamp default current_timestamp, c2 int, c3 timestamp default current_timestamp);`) + //c1 insert using now() function result, c3 using default value calculation, should be same + tk.MustExec(`insert into t1 set c1 = current_timestamp, c2 = sleep(2);`) + tk.MustQuery("select c1 = c3 from t1").Check(testkit.Rows("1")) + tk.MustExec(`insert into t1 set c1 = current_timestamp, c2 = sleep(1);`) + tk.MustQuery("select c1 = c3 from t1").Check(testkit.Rows("1", "1")) } diff --git a/expression/aggregation/agg_to_pb_test.go b/expression/aggregation/agg_to_pb_test.go index 7e80f3d2c7a90..115bb6bb81ba3 100644 --- a/expression/aggregation/agg_to_pb_test.go +++ b/expression/aggregation/agg_to_pb_test.go @@ -76,7 +76,8 @@ func (s *testEvaluatorSuite) TestAggFunc2Pb(c *C) { } for _, funcName := range funcNames { args := []expression.Expression{dg.genColumn(mysql.TypeDouble, 1)} - aggFunc := NewAggFuncDesc(s.ctx, funcName, args, true) + aggFunc, err := NewAggFuncDesc(s.ctx, funcName, args, true) + c.Assert(err, IsNil) pbExpr := AggFuncToPBExpr(sc, client, aggFunc) js, err := json.Marshal(pbExpr) c.Assert(err, IsNil) @@ -94,7 +95,8 @@ func (s *testEvaluatorSuite) TestAggFunc2Pb(c *C) { } for i, funcName := range funcNames { args := []expression.Expression{dg.genColumn(mysql.TypeDouble, 1)} - aggFunc := NewAggFuncDesc(s.ctx, funcName, args, false) + aggFunc, err := NewAggFuncDesc(s.ctx, funcName, args, false) + c.Assert(err, IsNil) aggFunc.RetTp = funcTypes[i] pbExpr := AggFuncToPBExpr(sc, client, aggFunc) js, err := json.Marshal(pbExpr) diff --git a/expression/aggregation/aggregation_test.go b/expression/aggregation/aggregation_test.go index 0ebbe8a330eef..307382e08f46c 100644 --- a/expression/aggregation/aggregation_test.go +++ b/expression/aggregation/aggregation_test.go @@ -58,7 +58,9 @@ func (s *testAggFuncSuit) TestAvg(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - avgFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, false) + c.Assert(err, IsNil) + avgFunc := desc.GetAggFunc(ctx) evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := avgFunc.GetResult(evalCtx) @@ -71,12 +73,14 @@ func (s *testAggFuncSuit) TestAvg(c *C) { result = avgFunc.GetResult(evalCtx) needed := types.NewDecFromStringForTest("67.000000000000000000000000000000") c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) - err := avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) + err = avgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) c.Assert(err, IsNil) result = avgFunc.GetResult(evalCtx) c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) - distinctAvgFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{col}, true) + c.Assert(err, IsNil) + distinctAvgFunc := desc.GetAggFunc(ctx) evalCtx = distinctAvgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { err := distinctAvgFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) @@ -105,7 +109,8 @@ func (s *testAggFuncSuit) TestAvgFinalMode(c *C) { Index: 1, RetType: types.NewFieldType(mysql.TypeNewDecimal), } - aggFunc := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{cntCol, sumCol}, false) + aggFunc, err := NewAggFuncDesc(s.ctx, ast.AggFuncAvg, []expression.Expression{cntCol, sumCol}, false) + c.Assert(err, IsNil) aggFunc.Mode = FinalMode avgFunc := aggFunc.GetAggFunc(ctx) evalCtx := avgFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) @@ -125,7 +130,9 @@ func (s *testAggFuncSuit) TestSum(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - sumFunc := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, false) + c.Assert(err, IsNil) + sumFunc := desc.GetAggFunc(ctx) evalCtx := sumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := sumFunc.GetResult(evalCtx) @@ -138,14 +145,16 @@ func (s *testAggFuncSuit) TestSum(c *C) { result = sumFunc.GetResult(evalCtx) needed := types.NewDecFromStringForTest("338350") c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) - err := sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) + err = sumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) c.Assert(err, IsNil) result = sumFunc.GetResult(evalCtx) c.Assert(result.GetMysqlDecimal().Compare(needed) == 0, IsTrue) partialResult := sumFunc.GetPartialResult(evalCtx) c.Assert(partialResult[0].GetMysqlDecimal().Compare(needed) == 0, IsTrue) - distinctSumFunc := NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncSum, []expression.Expression{col}, true) + c.Assert(err, IsNil) + distinctSumFunc := desc.GetAggFunc(ctx) evalCtx = distinctSumFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { err := distinctSumFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) @@ -162,14 +171,16 @@ func (s *testAggFuncSuit) TestBitAnd(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - bitAndFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitAnd, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitAnd, []expression.Expression{col}, false) + c.Assert(err, IsNil) + bitAndFunc := desc.GetAggFunc(ctx) evalCtx := bitAndFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitAndFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(math.MaxUint64)) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = bitAndFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result = bitAndFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -238,14 +249,16 @@ func (s *testAggFuncSuit) TestBitOr(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - bitOrFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitOr, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitOr, []expression.Expression{col}, false) + c.Assert(err, IsNil) + bitOrFunc := desc.GetAggFunc(ctx) evalCtx := bitOrFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitOrFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(0)) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = bitOrFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result = bitOrFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -322,14 +335,16 @@ func (s *testAggFuncSuit) TestBitXor(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - bitXorFunc := NewAggFuncDesc(s.ctx, ast.AggFuncBitXor, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncBitXor, []expression.Expression{col}, false) + c.Assert(err, IsNil) + bitXorFunc := desc.GetAggFunc(ctx) evalCtx := bitXorFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := bitXorFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(0)) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = bitXorFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result = bitXorFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -398,7 +413,9 @@ func (s *testAggFuncSuit) TestCount(c *C) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - countFunc := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, false) + c.Assert(err, IsNil) + countFunc := desc.GetAggFunc(ctx) evalCtx := countFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := countFunc.GetResult(evalCtx) @@ -410,14 +427,16 @@ func (s *testAggFuncSuit) TestCount(c *C) { } result = countFunc.GetResult(evalCtx) c.Assert(result.GetInt64(), Equals, int64(5050)) - err := countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) + err = countFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, s.nullRow) c.Assert(err, IsNil) result = countFunc.GetResult(evalCtx) c.Assert(result.GetInt64(), Equals, int64(5050)) partialResult := countFunc.GetPartialResult(evalCtx) c.Assert(partialResult[0].GetInt64(), Equals, int64(5050)) - distinctCountFunc := NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncCount, []expression.Expression{col}, true) + c.Assert(err, IsNil) + distinctCountFunc := desc.GetAggFunc(ctx) evalCtx = distinctCountFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) for _, row := range s.rows { @@ -438,14 +457,16 @@ func (s *testAggFuncSuit) TestConcat(c *C) { RetType: types.NewFieldType(mysql.TypeVarchar), } ctx := mock.NewContext() - concatFunc := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, false) + c.Assert(err, IsNil) + concatFunc := desc.GetAggFunc(ctx) evalCtx := concatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) result := concatFunc.GetResult(evalCtx) c.Assert(result.IsNull(), IsTrue) row := chunk.MutRowFromDatums(types.MakeDatums(1, "x")) - err := concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) + err = concatFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) c.Assert(err, IsNil) result = concatFunc.GetResult(evalCtx) c.Assert(result.GetString(), Equals, "1") @@ -464,7 +485,9 @@ func (s *testAggFuncSuit) TestConcat(c *C) { partialResult := concatFunc.GetPartialResult(evalCtx) c.Assert(partialResult[0].GetString(), Equals, "1x2") - distinctConcatFunc := NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, true).GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncGroupConcat, []expression.Expression{col, sep}, true) + c.Assert(err, IsNil) + distinctConcatFunc := desc.GetAggFunc(ctx) evalCtx = distinctConcatFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) row.SetDatum(0, types.NewIntDatum(1)) @@ -487,11 +510,13 @@ func (s *testAggFuncSuit) TestFirstRow(c *C) { } ctx := mock.NewContext() - firstRowFunc := NewAggFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + c.Assert(err, IsNil) + firstRowFunc := desc.GetAggFunc(ctx) evalCtx := firstRowFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) row := chunk.MutRowFromDatums(types.MakeDatums(1)).ToRow() - err := firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) + err = firstRowFunc.Update(evalCtx, s.ctx.GetSessionVars().StmtCtx, row) c.Assert(err, IsNil) result := firstRowFunc.GetResult(evalCtx) c.Assert(result.GetUint64(), Equals, uint64(1)) @@ -512,8 +537,12 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) { } ctx := mock.NewContext() - maxFunc := NewAggFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}, false).GetAggFunc(ctx) - minFunc := NewAggFuncDesc(s.ctx, ast.AggFuncMin, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}, false) + c.Assert(err, IsNil) + maxFunc := desc.GetAggFunc(ctx) + desc, err = NewAggFuncDesc(s.ctx, ast.AggFuncMin, []expression.Expression{col}, false) + c.Assert(err, IsNil) + minFunc := desc.GetAggFunc(ctx) maxEvalCtx := maxFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) minEvalCtx := minFunc.CreateContext(s.ctx.GetSessionVars().StmtCtx) @@ -523,7 +552,7 @@ func (s *testAggFuncSuit) TestMaxMin(c *C) { c.Assert(result.IsNull(), IsTrue) row := chunk.MutRowFromDatums(types.MakeDatums(2)) - err := maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) + err = maxFunc.Update(maxEvalCtx, s.ctx.GetSessionVars().StmtCtx, row.ToRow()) c.Assert(err, IsNil) result = maxFunc.GetResult(maxEvalCtx) c.Assert(result.GetInt64(), Equals, int64(2)) diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index ba0e716853476..6706eeea99d9d 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -19,6 +19,7 @@ import ( "strings" "github.com/cznic/mathutil" + "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" @@ -37,10 +38,10 @@ type baseFuncDesc struct { RetTp *types.FieldType } -func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) baseFuncDesc { +func newBaseFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (baseFuncDesc, error) { b := baseFuncDesc{Name: strings.ToLower(name), Args: args} - b.typeInfer(ctx) - return b + err := b.typeInfer(ctx) + return b, err } func (a *baseFuncDesc) equal(ctx sessionctx.Context, other *baseFuncDesc) bool { @@ -81,7 +82,7 @@ func (a *baseFuncDesc) String() string { } // typeInfer infers the arguments and return types of an function. -func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { +func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error { switch a.Name { case ast.AggFuncCount: a.typeInfer4Count(ctx) @@ -107,8 +108,9 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) { case ast.WindowFuncLead, ast.WindowFuncLag: a.typeInfer4LeadLag(ctx) default: - panic("unsupported agg function: " + a.Name) + return errors.Errorf("unsupported agg function: %s", a.Name) } + return nil } func (a *baseFuncDesc) typeInfer4Count(ctx sessionctx.Context) { @@ -182,6 +184,10 @@ func (a *baseFuncDesc) typeInfer4MaxMin(ctx sessionctx.Context) { a.Args[0] = expression.BuildCastFunction(ctx, a.Args[0], tp) } a.RetTp = a.Args[0].GetType() + if (a.Name == ast.AggFuncMax || a.Name == ast.AggFuncMin) && a.RetTp.Tp != mysql.TypeBit { + a.RetTp = a.Args[0].GetType().Clone() + a.RetTp.Flag &^= mysql.NotNullFlag + } if a.RetTp.Tp == mysql.TypeEnum || a.RetTp.Tp == mysql.TypeSet { a.RetTp = &types.FieldType{Tp: mysql.TypeString, Flen: mysql.MaxFieldCharLength} } diff --git a/expression/aggregation/base_func_test.go b/expression/aggregation/base_func_test.go index 3002400c6ce85..bf6e96364fe39 100644 --- a/expression/aggregation/base_func_test.go +++ b/expression/aggregation/base_func_test.go @@ -25,7 +25,8 @@ func (s *testBaseFuncSuite) TestClone(c *check.C) { UniqueID: 0, RetType: types.NewFieldType(mysql.TypeLonglong), } - desc := newBaseFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}) + desc, err := newBaseFuncDesc(s.ctx, ast.AggFuncFirstRow, []expression.Expression{col}) + c.Assert(err, check.IsNil) cloned := desc.clone() c.Assert(desc.equal(s.ctx, cloned), check.IsTrue) @@ -38,3 +39,14 @@ func (s *testBaseFuncSuite) TestClone(c *check.C) { c.Assert(desc.Args[0], check.Equals, col) c.Assert(desc.equal(s.ctx, cloned), check.IsFalse) } + +func (s *testBaseFuncSuite) TestMaxMin(c *check.C) { + col := &expression.Column{ + UniqueID: 0, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + col.RetType.Flag |= mysql.NotNullFlag + desc, err := newBaseFuncDesc(s.ctx, ast.AggFuncMax, []expression.Expression{col}) + c.Assert(err, check.IsNil) + c.Assert(mysql.HasNotNullFlag(desc.RetTp.Flag), check.IsFalse) +} diff --git a/expression/aggregation/bench_test.go b/expression/aggregation/bench_test.go index e49deebe00da3..c3f72695709cd 100644 --- a/expression/aggregation/bench_test.go +++ b/expression/aggregation/bench_test.go @@ -29,7 +29,11 @@ func BenchmarkCreateContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) b.StartTimer() for i := 0; i < b.N; i++ { fun.CreateContext(ctx.GetSessionVars().StmtCtx) @@ -43,7 +47,11 @@ func BenchmarkResetContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, false) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx) b.StartTimer() for i := 0; i < b.N; i++ { @@ -58,7 +66,11 @@ func BenchmarkCreateDistinctContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) b.StartTimer() for i := 0; i < b.N; i++ { fun.CreateContext(ctx.GetSessionVars().StmtCtx) @@ -72,7 +84,11 @@ func BenchmarkResetDistinctContext(b *testing.B) { RetType: types.NewFieldType(mysql.TypeLonglong), } ctx := mock.NewContext() - fun := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true).GetAggFunc(ctx) + desc, err := NewAggFuncDesc(ctx, ast.AggFuncAvg, []expression.Expression{col}, true) + if err != nil { + b.Fatal(err) + } + fun := desc.GetAggFunc(ctx) evalCtx := fun.CreateContext(ctx.GetSessionVars().StmtCtx) b.StartTimer() for i := 0; i < b.N; i++ { diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index d8e544171d8cd..66f5f3346c805 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -37,9 +37,12 @@ type AggFuncDesc struct { } // NewAggFuncDesc creates an aggregation function signature descriptor. -func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) *AggFuncDesc { - b := newBaseFuncDesc(ctx, name, args) - return &AggFuncDesc{baseFuncDesc: b, HasDistinct: hasDistinct} +func NewAggFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression, hasDistinct bool) (*AggFuncDesc, error) { + b, err := newBaseFuncDesc(ctx, name, args) + if err != nil { + return nil, err + } + return &AggFuncDesc{baseFuncDesc: b, HasDistinct: hasDistinct}, nil } // Equal checks whether two aggregation function signatures are equal. diff --git a/expression/aggregation/window_func.go b/expression/aggregation/window_func.go index 28ccfed44e98d..8f963480dde16 100644 --- a/expression/aggregation/window_func.go +++ b/expression/aggregation/window_func.go @@ -27,19 +27,19 @@ type WindowFuncDesc struct { } // NewWindowFuncDesc creates a window function signature descriptor. -func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) *WindowFuncDesc { +func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Expression) (*WindowFuncDesc, error) { switch strings.ToLower(name) { case ast.WindowFuncNthValue: val, isNull, ok := expression.GetUint64FromConstant(args[1]) // nth_value does not allow `0`, but allows `null`. if !ok || (val == 0 && !isNull) { - return nil + return nil, nil } case ast.WindowFuncNtile: val, isNull, ok := expression.GetUint64FromConstant(args[0]) // ntile does not allow `0`, but allows `null`. if !ok || (val == 0 && !isNull) { - return nil + return nil, nil } case ast.WindowFuncLead, ast.WindowFuncLag: if len(args) < 2 { @@ -47,10 +47,14 @@ func NewWindowFuncDesc(ctx sessionctx.Context, name string, args []expression.Ex } _, isNull, ok := expression.GetUint64FromConstant(args[1]) if !ok || isNull { - return nil + return nil, nil } } - return &WindowFuncDesc{newBaseFuncDesc(ctx, name, args)} + base, err := newBaseFuncDesc(ctx, name, args) + if err != nil { + return nil, err + } + return &WindowFuncDesc{base}, nil } // noFrameWindowFuncs is the functions that operate on the entire partition, diff --git a/expression/builtin.go b/expression/builtin.go index f779687d863f0..da5f0a6509a0b 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -415,7 +415,6 @@ var funcs = map[string]functionClass{ ast.Year: &yearFunctionClass{baseFunctionClass{ast.Year, 1, 1}}, ast.YearWeek: &yearWeekFunctionClass{baseFunctionClass{ast.YearWeek, 1, 2}}, ast.LastDay: &lastDayFunctionClass{baseFunctionClass{ast.LastDay, 1, 1}}, - ast.TiDBParseTso: &tidbParseTsoFunctionClass{baseFunctionClass{ast.TiDBParseTso, 1, 1}}, // string functions ast.ASCII: &asciiFunctionClass{baseFunctionClass{ast.ASCII, 1, 1}}, @@ -486,9 +485,6 @@ var funcs = map[string]functionClass{ ast.RowCount: &rowCountFunctionClass{baseFunctionClass{ast.RowCount, 0, 0}}, ast.SessionUser: &userFunctionClass{baseFunctionClass{ast.SessionUser, 0, 0}}, ast.SystemUser: &userFunctionClass{baseFunctionClass{ast.SystemUser, 0, 0}}, - // This function is used to show tidb-server version info. - ast.TiDBVersion: &tidbVersionFunctionClass{baseFunctionClass{ast.TiDBVersion, 0, 0}}, - ast.TiDBIsDDLOwner: &tidbIsDDLOwnerFunctionClass{baseFunctionClass{ast.TiDBIsDDLOwner, 0, 0}}, // control functions ast.If: &ifFunctionClass{baseFunctionClass{ast.If, 3, 3}}, @@ -544,8 +540,8 @@ var funcs = map[string]functionClass{ ast.Xor: &bitXorFunctionClass{baseFunctionClass{ast.Xor, 2, 2}}, ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}}, ast.In: &inFunctionClass{baseFunctionClass{ast.In, 2, -1}}, - ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth}, - ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity}, + ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, false}, + ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity, false}, ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 3, 3}}, ast.Regexp: ®expFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}}, ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}}, @@ -600,4 +596,11 @@ var funcs = map[string]functionClass{ ast.JSONDepth: &jsonDepthFunctionClass{baseFunctionClass{ast.JSONDepth, 1, 1}}, ast.JSONKeys: &jsonKeysFunctionClass{baseFunctionClass{ast.JSONKeys, 1, 2}}, ast.JSONLength: &jsonLengthFunctionClass{baseFunctionClass{ast.JSONLength, 1, 2}}, + + // TiDB internal function. + // This function is used to show tidb-server version info. + ast.TiDBVersion: &tidbVersionFunctionClass{baseFunctionClass{ast.TiDBVersion, 0, 0}}, + ast.TiDBIsDDLOwner: &tidbIsDDLOwnerFunctionClass{baseFunctionClass{ast.TiDBIsDDLOwner, 0, 0}}, + ast.TiDBParseTso: &tidbParseTsoFunctionClass{baseFunctionClass{ast.TiDBParseTso, 1, 1}}, + ast.TiDBDecodePlan: &tidbDecodePlanFunctionClass{baseFunctionClass{ast.TiDBDecodePlan, 1, 1}}, } diff --git a/expression/builtin_arithmetic.go b/expression/builtin_arithmetic.go index 85b1d757767d8..20bcf0c7f5261 100644 --- a/expression/builtin_arithmetic.go +++ b/expression/builtin_arithmetic.go @@ -90,9 +90,12 @@ func setFlenDecimal4Int(retTp, a, b *types.FieldType) { // setFlenDecimal4RealOrDecimal is called to set proper `Flen` and `Decimal` of return // type according to the two input parameter's types. -func setFlenDecimal4RealOrDecimal(retTp, a, b *types.FieldType, isReal bool) { +func setFlenDecimal4RealOrDecimal(retTp, a, b *types.FieldType, isReal bool, isMultiply bool) { if a.Decimal != types.UnspecifiedLength && b.Decimal != types.UnspecifiedLength { retTp.Decimal = a.Decimal + b.Decimal + if !isMultiply { + retTp.Decimal = mathutil.Max(a.Decimal, b.Decimal) + } if !isReal && retTp.Decimal > mysql.MaxDecimalScale { retTp.Decimal = mysql.MaxDecimalScale } @@ -101,6 +104,9 @@ func setFlenDecimal4RealOrDecimal(retTp, a, b *types.FieldType, isReal bool) { return } digitsInt := mathutil.Max(a.Flen-a.Decimal, b.Flen-b.Decimal) + if isMultiply { + digitsInt = a.Flen - a.Decimal + b.Flen - b.Decimal + } retTp.Flen = digitsInt + retTp.Decimal + 3 if isReal { retTp.Flen = mathutil.Min(retTp.Flen, mysql.MaxRealWidth) @@ -155,13 +161,13 @@ func (c *arithmeticPlusFunctionClass) getFunction(ctx sessionctx.Context, args [ lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETReal, types.ETReal, types.ETReal) - setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true) + setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true, false) sig := &builtinArithmeticPlusRealSig{bf} sig.setPbCode(tipb.ScalarFuncSig_PlusReal) return sig, nil } else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal { bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal) - setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false) + setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false, false) sig := &builtinArithmeticPlusDecimalSig{bf} sig.setPbCode(tipb.ScalarFuncSig_PlusDecimal) return sig, nil @@ -293,13 +299,13 @@ func (c *arithmeticMinusFunctionClass) getFunction(ctx sessionctx.Context, args lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETReal, types.ETReal, types.ETReal) - setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true) + setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true, false) sig := &builtinArithmeticMinusRealSig{bf} sig.setPbCode(tipb.ScalarFuncSig_MinusReal) return sig, nil } else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal { bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal) - setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false) + setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false, false) sig := &builtinArithmeticMinusDecimalSig{bf} sig.setPbCode(tipb.ScalarFuncSig_MinusDecimal) return sig, nil @@ -439,13 +445,13 @@ func (c *arithmeticMultiplyFunctionClass) getFunction(ctx sessionctx.Context, ar lhsEvalTp, rhsEvalTp := numericContextResultType(lhsTp), numericContextResultType(rhsTp) if lhsEvalTp == types.ETReal || rhsEvalTp == types.ETReal { bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETReal, types.ETReal, types.ETReal) - setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true) + setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), true, true) sig := &builtinArithmeticMultiplyRealSig{bf} sig.setPbCode(tipb.ScalarFuncSig_MultiplyReal) return sig, nil } else if lhsEvalTp == types.ETDecimal || rhsEvalTp == types.ETDecimal { bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDecimal, types.ETDecimal, types.ETDecimal) - setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false) + setFlenDecimal4RealOrDecimal(bf.tp, args[0].GetType(), args[1].GetType(), false, true) sig := &builtinArithmeticMultiplyDecimalSig{bf} sig.setPbCode(tipb.ScalarFuncSig_MultiplyDecimal) return sig, nil @@ -762,11 +768,30 @@ func (s *builtinArithmeticIntDivideDecimalSig) evalInt(row chunk.Row) (ret int64 return 0, true, err } - ret, err = c.ToInt() - // err returned by ToInt may be ErrTruncated or ErrOverflow, only handle ErrOverflow, ignore ErrTruncated. - if err == types.ErrOverflow { - return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT", fmt.Sprintf("(%s DIV %s)", s.args[0].String(), s.args[1].String())) + isLHSUnsigned := mysql.HasUnsignedFlag(s.args[0].GetType().Flag) + isRHSUnsigned := mysql.HasUnsignedFlag(s.args[1].GetType().Flag) + + if isLHSUnsigned || isRHSUnsigned { + val, err := c.ToUint() + // err returned by ToUint may be ErrTruncated or ErrOverflow, only handle ErrOverflow, ignore ErrTruncated. + if err == types.ErrOverflow { + v, err := c.ToInt() + // when the final result is at (-1, 0], it should be return 0 instead of the error + if v == 0 && err == types.ErrTruncated { + ret = int64(0) + return ret, false, nil + } + return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", fmt.Sprintf("(%s DIV %s)", s.args[0].String(), s.args[1].String())) + } + ret = int64(val) + } else { + ret, err = c.ToInt() + // err returned by ToInt may be ErrTruncated or ErrOverflow, only handle ErrOverflow, ignore ErrTruncated. + if err == types.ErrOverflow { + return 0, true, types.ErrOverflow.GenWithStackByArgs("BIGINT", fmt.Sprintf("(%s DIV %s)", s.args[0].String(), s.args[1].String())) + } } + return ret, false, nil } diff --git a/expression/builtin_arithmetic_test.go b/expression/builtin_arithmetic_test.go index 61a47bbc5d334..a1e5afc63180d 100644 --- a/expression/builtin_arithmetic_test.go +++ b/expression/builtin_arithmetic_test.go @@ -37,25 +37,56 @@ func (s *testEvaluatorSuite) TestSetFlenDecimal4RealOrDecimal(c *C) { Decimal: 0, Flen: 2, } - setFlenDecimal4RealOrDecimal(ret, a, b, true) + setFlenDecimal4RealOrDecimal(ret, a, b, true, false) c.Assert(ret.Decimal, Equals, 1) c.Assert(ret.Flen, Equals, 6) b.Flen = 65 - setFlenDecimal4RealOrDecimal(ret, a, b, true) + setFlenDecimal4RealOrDecimal(ret, a, b, true, false) c.Assert(ret.Decimal, Equals, 1) c.Assert(ret.Flen, Equals, mysql.MaxRealWidth) - setFlenDecimal4RealOrDecimal(ret, a, b, false) + setFlenDecimal4RealOrDecimal(ret, a, b, false, false) c.Assert(ret.Decimal, Equals, 1) c.Assert(ret.Flen, Equals, mysql.MaxDecimalWidth) b.Flen = types.UnspecifiedLength - setFlenDecimal4RealOrDecimal(ret, a, b, true) + setFlenDecimal4RealOrDecimal(ret, a, b, true, false) c.Assert(ret.Decimal, Equals, 1) c.Assert(ret.Flen, Equals, types.UnspecifiedLength) b.Decimal = types.UnspecifiedLength - setFlenDecimal4RealOrDecimal(ret, a, b, true) + setFlenDecimal4RealOrDecimal(ret, a, b, true, false) + c.Assert(ret.Decimal, Equals, types.UnspecifiedLength) + c.Assert(ret.Flen, Equals, types.UnspecifiedLength) + + ret = &types.FieldType{} + a = &types.FieldType{ + Decimal: 1, + Flen: 3, + } + b = &types.FieldType{ + Decimal: 0, + Flen: 2, + } + setFlenDecimal4RealOrDecimal(ret, a, b, true, true) + c.Assert(ret.Decimal, Equals, 1) + c.Assert(ret.Flen, Equals, 8) + + b.Flen = 65 + setFlenDecimal4RealOrDecimal(ret, a, b, true, true) + c.Assert(ret.Decimal, Equals, 1) + c.Assert(ret.Flen, Equals, mysql.MaxRealWidth) + setFlenDecimal4RealOrDecimal(ret, a, b, false, true) + c.Assert(ret.Decimal, Equals, 1) + c.Assert(ret.Flen, Equals, mysql.MaxDecimalWidth) + + b.Flen = types.UnspecifiedLength + setFlenDecimal4RealOrDecimal(ret, a, b, true, true) + c.Assert(ret.Decimal, Equals, 1) + c.Assert(ret.Flen, Equals, types.UnspecifiedLength) + + b.Decimal = types.UnspecifiedLength + setFlenDecimal4RealOrDecimal(ret, a, b, true, true) c.Assert(ret.Decimal, Equals, types.UnspecifiedLength) c.Assert(ret.Flen, Equals, types.UnspecifiedLength) } @@ -441,6 +472,14 @@ func (s *testEvaluatorSuite) TestArithmeticIntDivide(c *C) { args: []interface{}{int64(-9223372036854775808), float64(-1)}, expect: []interface{}{nil, "*BIGINT value is out of range in '\\(-9223372036854775808 DIV -1\\)'"}, }, + { + args: []interface{}{uint64(1), float64(-2)}, + expect: []interface{}{0, nil}, + }, + { + args: []interface{}{uint64(1), float64(-1)}, + expect: []interface{}{nil, "*BIGINT UNSIGNED value is out of range in '\\(1 DIV -1\\)'"}, + }, } for _, tc := range testCases { diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 1b8205c009118..ccdbd626e6242 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -437,6 +437,9 @@ func (b *builtinCastIntAsIntSig) Clone() builtinFunc { func (b *builtinCastIntAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { res, isNull, err = b.args[0].EvalInt(b.ctx, row) + if isNull || err != nil { + return + } if b.inUnion && mysql.HasUnsignedFlag(b.tp.Flag) && res < 0 { res = 0 } @@ -463,10 +466,8 @@ func (b *builtinCastIntAsRealSig) evalReal(row chunk.Row) (res float64, isNull b } else if b.inUnion && val < 0 { res = 0 } else { - var uVal uint64 - sc := b.ctx.GetSessionVars().StmtCtx - uVal, err = types.ConvertIntToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) - res = float64(uVal) + // recall that, int to float is different from uint to float + res = float64(uint64(val)) } return res, false, err } @@ -491,13 +492,7 @@ func (b *builtinCastIntAsDecimalSig) evalDecimal(row chunk.Row) (res *types.MyDe } else if b.inUnion && val < 0 { res = &types.MyDecimal{} } else { - var uVal uint64 - sc := b.ctx.GetSessionVars().StmtCtx - uVal, err = types.ConvertIntToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) - if err != nil { - return res, false, err - } - res = types.NewDecFromUint(uVal) + res = types.NewDecFromUint(uint64(val)) } res, err = types.ProduceDecWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx) return res, isNull, err @@ -521,13 +516,7 @@ func (b *builtinCastIntAsStringSig) evalString(row chunk.Row) (res string, isNul if !mysql.HasUnsignedFlag(b.args[0].GetType().Flag) { res = strconv.FormatInt(val, 10) } else { - var uVal uint64 - sc := b.ctx.GetSessionVars().StmtCtx - uVal, err = types.ConvertIntToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) - if err != nil { - return res, false, err - } - res = strconv.FormatUint(uVal, 10) + res = strconv.FormatUint(uint64(val), 10) } res, err = types.ProduceStrWithSpecifiedTp(res, b.tp, b.ctx.GetSessionVars().StmtCtx, false) if err != nil { @@ -748,13 +737,13 @@ func (b *builtinCastRealAsIntSig) evalInt(row chunk.Row) (res int64, isNull bool return res, isNull, err } if !mysql.HasUnsignedFlag(b.tp.Flag) { - res, err = types.ConvertFloatToInt(val, types.IntergerSignedLowerBound(mysql.TypeLonglong), types.IntergerSignedUpperBound(mysql.TypeLonglong), mysql.TypeDouble) + res, err = types.ConvertFloatToInt(val, types.IntergerSignedLowerBound(mysql.TypeLonglong), types.IntergerSignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) } else if b.inUnion && val < 0 { res = 0 } else { var uintVal uint64 sc := b.ctx.GetSessionVars().StmtCtx - uintVal, err = types.ConvertFloatToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeDouble) + uintVal, err = types.ConvertFloatToUint(sc, val, types.IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) res = int64(uintVal) } return res, isNull, err @@ -1786,7 +1775,10 @@ func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression { return expr } tp := types.NewFieldType(mysql.TypeNewDecimal) - tp.Flen, tp.Decimal = expr.GetType().Flen, types.UnspecifiedLength + tp.Flen, tp.Decimal = expr.GetType().Flen, expr.GetType().Decimal + if expr.GetType().EvalType() == types.ETInt { + tp.Flen = mysql.MaxIntWidth + } types.SetBinChsClnFlag(tp) tp.Flag |= expr.GetType().Flag & mysql.UnsignedFlag return BuildCastFunction(ctx, expr, tp) diff --git a/expression/builtin_cast_test.go b/expression/builtin_cast_test.go index 05207acfde41a..c0636e4b34a04 100644 --- a/expression/builtin_cast_test.go +++ b/expression/builtin_cast_test.go @@ -88,7 +88,11 @@ func (s *testEvaluatorSuite) TestCast(c *C) { c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) origSc := sc + oldInSelectStmt := sc.InSelectStmt sc.InSelectStmt = true + defer func() { + sc.InSelectStmt = oldInSelectStmt + }() sc.OverflowAsWarning = true // cast('18446744073709551616' as unsigned); diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index a6e12a6b2474c..6e7cc3a2e4d0b 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -1070,7 +1070,13 @@ func isTemporalColumn(expr Expression) bool { } // tryToConvertConstantInt tries to convert a constant with other type to a int constant. -func tryToConvertConstantInt(ctx sessionctx.Context, isUnsigned bool, con *Constant) (_ *Constant, isAlwaysFalse bool) { +// isExceptional indicates whether the 'int column [cmp] const' might be true/false. +// If isExceptional is true, ExecptionalVal is returned. Or, CorrectVal is returned. +// CorrectVal: The computed result. If the constant can be converted to int without exception, return the val. Else return 'con'(the input). +// ExceptionalVal : It is used to get more information to check whether 'int column [cmp] const' is true/false +// If the op == LT,LE,GT,GE and it gets an Overflow when converting, return inf/-inf. +// If the op == EQ,NullEQ and the constant can never be equal to the int column, return ‘con’(the input, a non-int constant). +func tryToConvertConstantInt(ctx sessionctx.Context, targetFieldType *types.FieldType, con *Constant) (_ *Constant, isExceptional bool) { if con.GetType().EvalType() == types.ETInt { return con, false } @@ -1079,37 +1085,51 @@ func tryToConvertConstantInt(ctx sessionctx.Context, isUnsigned bool, con *Const return con, false } sc := ctx.GetSessionVars().StmtCtx - fieldType := types.NewFieldType(mysql.TypeLonglong) - if isUnsigned { - fieldType.Flag |= mysql.UnsignedFlag - } - dt, err = dt.ConvertTo(sc, fieldType) + + dt, err = dt.ConvertTo(sc, targetFieldType) if err != nil { - return con, terror.ErrorEqual(err, types.ErrOverflow) + if terror.ErrorEqual(err, types.ErrOverflow) { + return &Constant{ + Value: dt, + RetType: targetFieldType, + }, true + } + return con, false } return &Constant{ Value: dt, - RetType: fieldType, + RetType: targetFieldType, DeferredExpr: con.DeferredExpr, }, false } -// RefineComparedConstant changes an non-integer constant argument to its ceiling or floor result by the given op. -// isAlwaysFalse indicates whether the int column "con" is false. -func RefineComparedConstant(ctx sessionctx.Context, isUnsigned bool, con *Constant, op opcode.Op) (_ *Constant, isAlwaysFalse bool) { +// RefineComparedConstant changes a non-integer constant argument to its ceiling or floor result by the given op. +// isExceptional indicates whether the 'int column [cmp] const' might be true/false. +// If isExceptional is true, ExecptionalVal is returned. Or, CorrectVal is returned. +// CorrectVal: The computed result. If the constant can be converted to int without exception, return the val. Else return 'con'(the input). +// ExceptionalVal : It is used to get more information to check whether 'int column [cmp] const' is true/false +// If the op == LT,LE,GT,GE and it gets an Overflow when converting, return inf/-inf. +// If the op == EQ,NullEQ and the constant can never be equal to the int column, return ‘con’(the input, a non-int constant). +func RefineComparedConstant(ctx sessionctx.Context, targetFieldType types.FieldType, con *Constant, op opcode.Op) (_ *Constant, isExceptional bool) { dt, err := con.Eval(chunk.Row{}) if err != nil { return con, false } sc := ctx.GetSessionVars().StmtCtx - intFieldType := types.NewFieldType(mysql.TypeLonglong) - if isUnsigned { - intFieldType.Flag |= mysql.UnsignedFlag + + if targetFieldType.Tp == mysql.TypeBit { + targetFieldType = *types.NewFieldType(mysql.TypeLonglong) } var intDatum types.Datum - intDatum, err = dt.ConvertTo(sc, intFieldType) + intDatum, err = dt.ConvertTo(sc, &targetFieldType) if err != nil { - return con, terror.ErrorEqual(err, types.ErrOverflow) + if terror.ErrorEqual(err, types.ErrOverflow) { + return &Constant{ + Value: intDatum, + RetType: &targetFieldType, + }, true + } + return con, false } c, err := intDatum.CompareDatum(sc, &con.Value) if err != nil { @@ -1118,7 +1138,7 @@ func RefineComparedConstant(ctx sessionctx.Context, isUnsigned bool, con *Consta if c == 0 { return &Constant{ Value: intDatum, - RetType: intFieldType, + RetType: &targetFieldType, DeferredExpr: con.DeferredExpr, }, false } @@ -1126,12 +1146,12 @@ func RefineComparedConstant(ctx sessionctx.Context, isUnsigned bool, con *Consta case opcode.LT, opcode.GE: resultExpr := NewFunctionInternal(ctx, ast.Ceil, types.NewFieldType(mysql.TypeUnspecified), con) if resultCon, ok := resultExpr.(*Constant); ok { - return tryToConvertConstantInt(ctx, isUnsigned, resultCon) + return tryToConvertConstantInt(ctx, &targetFieldType, resultCon) } case opcode.LE, opcode.GT: resultExpr := NewFunctionInternal(ctx, ast.Floor, types.NewFieldType(mysql.TypeUnspecified), con) if resultCon, ok := resultExpr.(*Constant); ok { - return tryToConvertConstantInt(ctx, isUnsigned, resultCon) + return tryToConvertConstantInt(ctx, &targetFieldType, resultCon) } case opcode.NullEQ, opcode.EQ: switch con.RetType.EvalType() { @@ -1159,7 +1179,7 @@ func RefineComparedConstant(ctx sessionctx.Context, isUnsigned bool, con *Consta } return &Constant{ Value: intDatum, - RetType: intFieldType, + RetType: &targetFieldType, DeferredExpr: con.DeferredExpr, }, false } @@ -1175,27 +1195,60 @@ func (c *compareFunctionClass) refineArgs(ctx sessionctx.Context, args []Express arg1IsInt := arg1Type.EvalType() == types.ETInt arg0, arg0IsCon := args[0].(*Constant) arg1, arg1IsCon := args[1].(*Constant) - isAlways, finalArg0, finalArg1 := false, args[0], args[1] + isExceptional, finalArg0, finalArg1 := false, args[0], args[1] + isPositiveInfinite, isNegativeInfinite := false, false // int non-constant [cmp] non-int constant if arg0IsInt && !arg0IsCon && !arg1IsInt && arg1IsCon { - finalArg1, isAlways = RefineComparedConstant(ctx, mysql.HasUnsignedFlag(arg0Type.Flag), arg1, c.op) + arg1, isExceptional = RefineComparedConstant(ctx, *arg0Type, arg1, c.op) + finalArg1 = arg1 + if isExceptional && arg1.RetType.EvalType() == types.ETInt { + // Judge it is inf or -inf + // For int: + // inf: 01111111 & 1 == 1 + // -inf: 10000000 & 1 == 0 + // For uint: + // inf: 11111111 & 1 == 1 + // -inf: 00000000 & 0 == 0 + if arg1.Value.GetInt64()&1 == 1 { + isPositiveInfinite = true + } else { + isNegativeInfinite = true + } + } } // non-int constant [cmp] int non-constant if arg1IsInt && !arg1IsCon && !arg0IsInt && arg0IsCon { - finalArg0, isAlways = RefineComparedConstant(ctx, mysql.HasUnsignedFlag(arg1Type.Flag), arg0, symmetricOp[c.op]) + arg0, isExceptional = RefineComparedConstant(ctx, *arg1Type, arg0, symmetricOp[c.op]) + finalArg0 = arg0 + if isExceptional && arg0.RetType.EvalType() == types.ETInt { + if arg0.Value.GetInt64()&1 == 1 { + isNegativeInfinite = true + } else { + isPositiveInfinite = true + } + } } - if !isAlways { - return []Expression{finalArg0, finalArg1} + + if isExceptional && (c.op == opcode.EQ || c.op == opcode.NullEQ) { + // This will always be false. + return []Expression{Zero.Clone(), One.Clone()} } - switch c.op { - case opcode.LT, opcode.LE: + if isPositiveInfinite { + // If the op is opcode.LT, opcode.LE // This will always be true. + // If the op is opcode.GT, opcode.GE + // This will always be false. return []Expression{Zero.Clone(), One.Clone()} - case opcode.EQ, opcode.NullEQ, opcode.GT, opcode.GE: + } + if isNegativeInfinite { + // If the op is opcode.GT, opcode.GE + // This will always be true. + // If the op is opcode.LT, opcode.LE // This will always be false. return []Expression{One.Clone(), Zero.Clone()} } - return args + + return []Expression{finalArg0, finalArg1} } // getFunction sets compare built-in function signatures for various types. diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index 3aee49ddaec77..ad09320fde511 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -63,6 +63,10 @@ func (s *testEvaluatorSuite) TestCompareFunctionWithRefine(c *C) { {"'1.1' != a", "ne(1.1, cast(a))"}, {"'123456789123456711111189' = a", "0"}, {"123456789123456789.12345 = a", "0"}, + {"123456789123456789123456789.12345 > a", "1"}, + {"-123456789123456789123456789.12345 > a", "0"}, + {"123456789123456789123456789.12345 < a", "0"}, + {"-123456789123456789123456789.12345 < a", "1"}, // This cast can not be eliminated, // since converting "aaaa" to an int will cause DataTruncate error. {"'aaaa'=a", "eq(cast(aaaa), cast(a))"}, diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 015f9705f2844..80e8d5e233e14 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -67,9 +67,8 @@ func InferType4ControlFuncs(lhs, rhs *types.FieldType) *types.FieldType { } else if rhs.Tp == mysql.TypeNull { *resultFieldType = *lhs } else { - var unsignedFlag uint - evalType := types.AggregateEvalType([]*types.FieldType{lhs, rhs}, &unsignedFlag) resultFieldType = types.AggFieldType([]*types.FieldType{lhs, rhs}) + evalType := types.AggregateEvalType([]*types.FieldType{lhs, rhs}, &resultFieldType.Flag) if evalType == types.ETInt { resultFieldType.Decimal = 0 } else { diff --git a/expression/builtin_info.go b/expression/builtin_info.go index 0653afef87847..fa945845d3e82 100644 --- a/expression/builtin_info.go +++ b/expression/builtin_info.go @@ -18,11 +18,14 @@ package expression import ( + "sort" + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/plancodec" "github.com/pingcap/tidb/util/printer" ) @@ -42,6 +45,7 @@ var ( _ functionClass = &rowCountFunctionClass{} _ functionClass = &tidbVersionFunctionClass{} _ functionClass = &tidbIsDDLOwnerFunctionClass{} + _ functionClass = &tidbDecodePlanFunctionClass{} ) var ( @@ -192,8 +196,13 @@ func (b *builtinCurrentRoleSig) evalString(row chunk.Row) (string, bool, error) return "", false, nil } res := "" - for i, r := range data.ActiveRoles { - res += r.String() + sortedRes := make([]string, 0, 10) + for _, r := range data.ActiveRoles { + sortedRes = append(sortedRes, r.String()) + } + sort.Strings(sortedRes) + for i, r := range sortedRes { + res += r if i != len(data.ActiveRoles)-1 { res += "," } @@ -582,3 +591,35 @@ func (b *builtinRowCountSig) evalInt(_ chunk.Row) (res int64, isNull bool, err e res = int64(b.ctx.GetSessionVars().StmtCtx.PrevAffectedRows) return res, false, nil } + +type tidbDecodePlanFunctionClass struct { + baseFunctionClass +} + +func (c *tidbDecodePlanFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString) + sig := &builtinTiDBDecodePlanSig{bf} + return sig, nil +} + +type builtinTiDBDecodePlanSig struct { + baseBuiltinFunc +} + +func (b *builtinTiDBDecodePlanSig) Clone() builtinFunc { + newSig := &builtinTiDBDecodePlanSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinTiDBDecodePlanSig) evalString(row chunk.Row) (string, bool, error) { + planString, isNull, err := b.args[0].EvalString(b.ctx, row) + if isNull || err != nil { + return "", isNull, err + } + planTree, err := plancodec.DecodePlan(planString) + return planTree, false, err +} diff --git a/expression/builtin_info_test.go b/expression/builtin_info_test.go index 0f82dfef2c510..2b50a59f9ffd5 100644 --- a/expression/builtin_info_test.go +++ b/expression/builtin_info_test.go @@ -42,6 +42,7 @@ func (s *testEvaluatorSuite) TestDatabase(c *C) { d, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(d.GetString(), Equals, "test") + c.Assert(f.Clone().PbCode(), Equals, f.PbCode()) // Test case for schema(). fc = funcs[ast.Schema] @@ -51,6 +52,7 @@ func (s *testEvaluatorSuite) TestDatabase(c *C) { d, err = evalBuiltinFunc(f, chunk.MutRowFromDatums(types.MakeDatums()).ToRow()) c.Assert(err, IsNil) c.Assert(d.GetString(), Equals, "test") + c.Assert(f.Clone().PbCode(), Equals, f.PbCode()) } func (s *testEvaluatorSuite) TestFoundRows(c *C) { @@ -79,6 +81,7 @@ func (s *testEvaluatorSuite) TestUser(c *C) { d, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(d.GetString(), Equals, "root@localhost") + c.Assert(f.Clone().PbCode(), Equals, f.PbCode()) } func (s *testEvaluatorSuite) TestCurrentUser(c *C) { @@ -93,6 +96,7 @@ func (s *testEvaluatorSuite) TestCurrentUser(c *C) { d, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(d.GetString(), Equals, "root@localhost") + c.Assert(f.Clone().PbCode(), Equals, f.PbCode()) } func (s *testEvaluatorSuite) TestCurrentRole(c *C) { @@ -109,6 +113,7 @@ func (s *testEvaluatorSuite) TestCurrentRole(c *C) { d, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(d.GetString(), Equals, "`r_1`@`%`,`r_2`@`localhost`") + c.Assert(f.Clone().PbCode(), Equals, f.PbCode()) } func (s *testEvaluatorSuite) TestConnectionID(c *C) { @@ -123,6 +128,7 @@ func (s *testEvaluatorSuite) TestConnectionID(c *C) { d, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(d.GetUint64(), Equals, uint64(1)) + c.Assert(f.Clone().PbCode(), Equals, f.PbCode()) } func (s *testEvaluatorSuite) TestVersion(c *C) { @@ -133,6 +139,7 @@ func (s *testEvaluatorSuite) TestVersion(c *C) { v, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(v.GetString(), Equals, mysql.ServerVersion) + c.Assert(f.Clone().PbCode(), Equals, f.PbCode()) } func (s *testEvaluatorSuite) TestBenchMark(c *C) { @@ -213,6 +220,7 @@ func (s *testEvaluatorSuite) TestRowCount(c *C) { c.Assert(err, IsNil) c.Assert(isNull, IsFalse) c.Assert(intResult, Equals, int64(10)) + c.Assert(f.Clone().PbCode(), Equals, f.PbCode()) } // TestTiDBVersion for tidb_server(). diff --git a/expression/builtin_json.go b/expression/builtin_json.go index 462ea88ed0d6a..32385b06b0210 100644 --- a/expression/builtin_json.go +++ b/expression/builtin_json.go @@ -14,6 +14,7 @@ package expression import ( + json2 "encoding/json" "strings" "github.com/pingcap/errors" @@ -22,6 +23,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tipb/go-tipb" ) @@ -71,6 +73,9 @@ var ( _ builtinFunc = &builtinJSONKeysSig{} _ builtinFunc = &builtinJSONKeys2ArgsSig{} _ builtinFunc = &builtinJSONLengthSig{} + _ builtinFunc = &builtinJSONValidJSONSig{} + _ builtinFunc = &builtinJSONValidStringSig{} + _ builtinFunc = &builtinJSONValidOthersSig{} ) type jsonTypeFunctionClass struct { @@ -715,7 +720,87 @@ type jsonValidFunctionClass struct { } func (c *jsonValidFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", "JSON_VALID") + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + var sig builtinFunc + argType := args[0].GetType().EvalType() + switch argType { + case types.ETJson: + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETJson) + sig = &builtinJSONValidJSONSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonValidJsonSig) + case types.ETString: + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETString) + sig = &builtinJSONValidStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonValidStringSig) + default: + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, argType) + sig = &builtinJSONValidOthersSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonValidOthersSig) + } + return sig, nil +} + +type builtinJSONValidJSONSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONValidJSONSig) Clone() builtinFunc { + newSig := &builtinJSONValidJSONSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinJSONValidJSONSig. +// See https://dev.mysql.com/doc/refman/5.7/en/json-attribute-functions.html#function_json-valid +func (b *builtinJSONValidJSONSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { + _, isNull, err = b.args[0].EvalJSON(b.ctx, row) + return 1, isNull, err +} + +type builtinJSONValidStringSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONValidStringSig) Clone() builtinFunc { + newSig := &builtinJSONValidStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinJSONValidStringSig. +// See https://dev.mysql.com/doc/refman/5.7/en/json-attribute-functions.html#function_json-valid +func (b *builtinJSONValidStringSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { + val, isNull, err := b.args[0].EvalString(b.ctx, row) + if err != nil || isNull { + return 0, isNull, err + } + + data := hack.Slice(val) + if json2.Valid(data) { + res = 1 + } else { + res = 0 + } + return res, false, nil +} + +type builtinJSONValidOthersSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONValidOthersSig) Clone() builtinFunc { + newSig := &builtinJSONValidOthersSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinJSONValidOthersSig. +// See https://dev.mysql.com/doc/refman/5.7/en/json-attribute-functions.html#function_json-valid +func (b *builtinJSONValidOthersSig) evalInt(row chunk.Row) (res int64, isNull bool, err error) { + return 0, false, nil } type jsonArrayAppendFunctionClass struct { diff --git a/expression/builtin_json_test.go b/expression/builtin_json_test.go index dba46f94b3afb..ab294b6b1d4f8 100644 --- a/expression/builtin_json_test.go +++ b/expression/builtin_json_test.go @@ -95,8 +95,10 @@ func (s *testEvaluatorSuite) TestJSONUnquote(c *C) { {`{"a": "b"}`, `{"a": "b"}`}, {`"hello,\"quoted string\",world"`, `hello,"quoted string",world`}, {`"hello,\"宽字符\",world"`, `hello,"宽字符",world`}, - {`Invalid Json string\tis OK`, `Invalid Json string is OK`}, + {`Invalid Json string\tis OK`, `Invalid Json string\tis OK`}, {`"1\\u2232\\u22322"`, `1\u2232\u22322`}, + {`"[{\"x\":\"{\\\"y\\\":12}\"}]"`, `[{"x":"{\"y\":12}"}]`}, + {`[{\"x\":\"{\\\"y\\\":12}\"}]`, `[{\"x\":\"{\\\"y\\\":12}\"}]`}, } dtbl := tblToDtbl(tbl) for _, t := range dtbl { @@ -865,3 +867,34 @@ func (s *testEvaluatorSuite) TestJSONSearch(c *C) { } } } + +func (s *testEvaluatorSuite) TestJSONValid(c *C) { + defer testleak.AfterTest(c)() + fc := funcs[ast.JSONValid] + tbl := []struct { + Input interface{} + Expected interface{} + }{ + {`{"a":1}`, 1}, + {`hello`, 0}, + {`"hello"`, 1}, + {`null`, 1}, + {`{}`, 1}, + {`[]`, 1}, + {`2`, 1}, + {`2.5`, 1}, + {`2019-8-19`, 0}, + {`"2019-8-19"`, 1}, + {2, 0}, + {2.5, 0}, + {nil, nil}, + } + dtbl := tblToDtbl(tbl) + for _, t := range dtbl { + f, err := fc.getFunction(s.ctx, s.datumsToConstants(t["Input"])) + c.Assert(err, IsNil) + d, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(d, testutil.DatumEquals, t["Expected"][0]) + } +} diff --git a/expression/builtin_like.go b/expression/builtin_like.go index f64df725d6129..093dfaf2abd97 100644 --- a/expression/builtin_like.go +++ b/expression/builtin_like.go @@ -95,7 +95,7 @@ func (c *regexpFunctionClass) getFunction(ctx sessionctx.Context, args []Express bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETString, types.ETString) bf.tp.Flen = 1 var sig builtinFunc - if types.IsBinaryStr(args[0].GetType()) { + if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[1].GetType()) { sig = &builtinRegexpBinarySig{bf} } else { sig = &builtinRegexpSig{bf} diff --git a/expression/builtin_math.go b/expression/builtin_math.go index 4ec7a4accb4c5..88d7c9c171170 100644 --- a/expression/builtin_math.go +++ b/expression/builtin_math.go @@ -24,6 +24,7 @@ import ( "math/rand" "strconv" "strings" + "sync" "time" "github.com/cznic/mathutil" @@ -966,7 +967,7 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio bt := bf if len(args) == 0 { seed := time.Now().UnixNano() - sig = &builtinRandSig{bt, rand.New(rand.NewSource(seed))} + sig = &builtinRandSig{bt, &sync.Mutex{}, rand.New(rand.NewSource(seed))} } else if _, isConstant := args[0].(*Constant); isConstant { // According to MySQL manual: // If an integer argument N is specified, it is used as the seed value: @@ -979,7 +980,7 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio if isNull { seed = time.Now().UnixNano() } - sig = &builtinRandSig{bt, rand.New(rand.NewSource(seed))} + sig = &builtinRandSig{bt, &sync.Mutex{}, rand.New(rand.NewSource(seed))} } else { sig = &builtinRandWithSeedSig{bt} } @@ -988,11 +989,12 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio type builtinRandSig struct { baseBuiltinFunc + mu *sync.Mutex randGen *rand.Rand } func (b *builtinRandSig) Clone() builtinFunc { - newSig := &builtinRandSig{randGen: b.randGen} + newSig := &builtinRandSig{randGen: b.randGen, mu: b.mu} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -1000,7 +1002,10 @@ func (b *builtinRandSig) Clone() builtinFunc { // evalReal evals RAND(). // See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_rand func (b *builtinRandSig) evalReal(row chunk.Row) (float64, bool, error) { - return b.randGen.Float64(), false, nil + b.mu.Lock() + res := b.randGen.Float64() + b.mu.Unlock() + return res, false, nil } type builtinRandWithSeedSig struct { diff --git a/expression/builtin_miscellaneous.go b/expression/builtin_miscellaneous.go index 12c896df433ea..70c52503b52b0 100644 --- a/expression/builtin_miscellaneous.go +++ b/expression/builtin_miscellaneous.go @@ -19,14 +19,15 @@ import ( "math" "net" "strings" + "sync/atomic" "time" + "github.com/google/uuid" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" - "github.com/twinj/uuid" ) var ( @@ -113,6 +114,7 @@ func (b *builtinSleepSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, isNull, err } + sessVars := b.ctx.GetSessionVars() if isNull { if sessVars.StrictSQLMode { @@ -131,12 +133,25 @@ func (b *builtinSleepSig) evalInt(row chunk.Row) (int64, bool, error) { if val > math.MaxFloat64/float64(time.Second.Nanoseconds()) { return 0, false, errIncorrectArgs.GenWithStackByArgs("sleep") } + dur := time.Duration(val * float64(time.Second.Nanoseconds())) - select { - case <-time.After(dur): - // TODO: Handle Ctrl-C is pressed in `mysql` client. - // return 1 when SLEEP() is KILLed + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + start := time.Now() + finish := false + for !finish { + select { + case now := <-ticker.C: + if now.Sub(start) > dur { + finish = true + } + default: + if atomic.CompareAndSwapUint32(&sessVars.Killed, 1, 0) { + return 1, false, nil + } + } } + return 0, false, nil } @@ -980,7 +995,13 @@ func (b *builtinUUIDSig) Clone() builtinFunc { // evalString evals a builtinUUIDSig. // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_uuid func (b *builtinUUIDSig) evalString(_ chunk.Row) (d string, isNull bool, err error) { - return uuid.NewV1().String(), false, nil + var id uuid.UUID + id, err = uuid.NewUUID() + if err != nil { + return + } + d = id.String() + return } type uuidShortFunctionClass struct { diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 3b27a09aa7853..cf6dc0a71a857 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -17,6 +17,7 @@ import ( "fmt" "math" + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" "github.com/pingcap/tidb/sessionctx" @@ -64,6 +65,15 @@ func (c *logicAndFunctionClass) getFunction(ctx sessionctx.Context, args []Expre if err != nil { return nil, err } + args[0], err = wrapWithIsTrue(ctx, true, args[0]) + if err != nil { + return nil, errors.Trace(err) + } + args[1], err = wrapWithIsTrue(ctx, true, args[1]) + if err != nil { + return nil, errors.Trace(err) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) sig := &builtinLogicAndSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LogicalAnd) @@ -105,6 +115,15 @@ func (c *logicOrFunctionClass) getFunction(ctx sessionctx.Context, args []Expres if err != nil { return nil, err } + args[0], err = wrapWithIsTrue(ctx, true, args[0]) + if err != nil { + return nil, errors.Trace(err) + } + args[1], err = wrapWithIsTrue(ctx, true, args[1]) + if err != nil { + return nil, errors.Trace(err) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) bf.tp.Flen = 1 sig := &builtinLogicOrSig{bf} @@ -152,6 +171,7 @@ func (c *logicXorFunctionClass) getFunction(ctx sessionctx.Context, args []Expre if err != nil { return nil, err } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) sig := &builtinLogicXorSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LogicalXor) @@ -375,6 +395,11 @@ func (b *builtinRightShiftSig) evalInt(row chunk.Row) (int64, bool, error) { type isTrueOrFalseFunctionClass struct { baseFunctionClass op opcode.Op + + // keepNull indicates how this function treats a null input parameter. + // If keepNull is true and the input parameter is null, the function will return null. + // If keepNull is false, the null input parameter will be cast to 0. + keepNull bool } func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { @@ -395,25 +420,25 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] case opcode.IsTruth: switch argTp { case types.ETReal: - sig = &builtinRealIsTrueSig{bf} + sig = &builtinRealIsTrueSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_RealIsTrue) case types.ETDecimal: - sig = &builtinDecimalIsTrueSig{bf} + sig = &builtinDecimalIsTrueSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_DecimalIsTrue) case types.ETInt: - sig = &builtinIntIsTrueSig{bf} + sig = &builtinIntIsTrueSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue) } case opcode.IsFalsity: switch argTp { case types.ETReal: - sig = &builtinRealIsFalseSig{bf} + sig = &builtinRealIsFalseSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_RealIsFalse) case types.ETDecimal: - sig = &builtinDecimalIsFalseSig{bf} + sig = &builtinDecimalIsFalseSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_DecimalIsFalse) case types.ETInt: - sig = &builtinIntIsFalseSig{bf} + sig = &builtinIntIsFalseSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse) } } @@ -422,10 +447,11 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] type builtinRealIsTrueSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinRealIsTrueSig) Clone() builtinFunc { - newSig := &builtinRealIsTrueSig{} + newSig := &builtinRealIsTrueSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -435,6 +461,9 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input == 0 { return 0, false, nil } @@ -443,10 +472,11 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinDecimalIsTrueSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinDecimalIsTrueSig) Clone() builtinFunc { - newSig := &builtinDecimalIsTrueSig{} + newSig := &builtinDecimalIsTrueSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -456,6 +486,9 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input.IsZero() { return 0, false, nil } @@ -464,10 +497,11 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinIntIsTrueSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinIntIsTrueSig) Clone() builtinFunc { - newSig := &builtinIntIsTrueSig{} + newSig := &builtinIntIsTrueSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -477,6 +511,9 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input == 0 { return 0, false, nil } @@ -485,10 +522,11 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinRealIsFalseSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinRealIsFalseSig) Clone() builtinFunc { - newSig := &builtinRealIsFalseSig{} + newSig := &builtinRealIsFalseSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -498,6 +536,9 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input != 0 { return 0, false, nil } @@ -506,10 +547,11 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinDecimalIsFalseSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinDecimalIsFalseSig) Clone() builtinFunc { - newSig := &builtinDecimalIsFalseSig{} + newSig := &builtinDecimalIsFalseSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -519,6 +561,9 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || !input.IsZero() { return 0, false, nil } @@ -527,10 +572,11 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinIntIsFalseSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinIntIsFalseSig) Clone() builtinFunc { - newSig := &builtinIntIsFalseSig{} + newSig := &builtinIntIsFalseSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -540,6 +586,9 @@ func (b *builtinIntIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input != 0 { return 0, false, nil } @@ -642,16 +691,15 @@ func (c *unaryMinusFunctionClass) handleIntOverflow(arg *Constant) (overflow boo // typeInfer infers unaryMinus function return type. when the arg is an int constant and overflow, // typerInfer will infers the return type as types.ETDecimal, not types.ETInt. -func (c *unaryMinusFunctionClass) typeInfer(ctx sessionctx.Context, argExpr Expression) (types.EvalType, bool) { +func (c *unaryMinusFunctionClass) typeInfer(argExpr Expression) (types.EvalType, bool) { tp := argExpr.GetType().EvalType() if tp != types.ETInt && tp != types.ETDecimal { tp = types.ETReal } - sc := ctx.GetSessionVars().StmtCtx overflow := false // TODO: Handle float overflow. - if arg, ok := argExpr.(*Constant); sc.InSelectStmt && ok && tp == types.ETInt { + if arg, ok := argExpr.(*Constant); ok && tp == types.ETInt { overflow = c.handleIntOverflow(arg) if overflow { tp = types.ETDecimal @@ -666,7 +714,7 @@ func (c *unaryMinusFunctionClass) getFunction(ctx sessionctx.Context, args []Exp } argExpr, argExprTp := args[0], args[0].GetType() - _, intOverflow := c.typeInfer(ctx, argExpr) + _, intOverflow := c.typeInfer(argExpr) var bf baseBuiltinFunc switch argExprTp.EvalType() { diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index b2f700cb7ca57..a45d488dfff95 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -86,11 +86,21 @@ func (s *testEvaluatorSuite) TestLogicAnd(c *C) { {[]interface{}{0, 1}, 0, false, false}, {[]interface{}{0, 0}, 0, false, false}, {[]interface{}{2, -1}, 1, false, false}, + {[]interface{}{"a", "0"}, 0, false, false}, {[]interface{}{"a", "1"}, 0, false, false}, + {[]interface{}{"1a", "0"}, 0, false, false}, {[]interface{}{"1a", "1"}, 1, false, false}, {[]interface{}{0, nil}, 0, false, false}, {[]interface{}{nil, 0}, 0, false, false}, {[]interface{}{nil, 1}, 0, true, false}, + {[]interface{}{0.001, 0}, 0, false, false}, + {[]interface{}{0.001, 1}, 1, false, false}, + {[]interface{}{nil, 0.000}, 0, false, false}, + {[]interface{}{nil, 0.001}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false}, {[]interface{}{errors.New("must error"), 1}, 0, false, true}, } @@ -300,11 +310,26 @@ func (s *testEvaluatorSuite) TestLogicOr(c *C) { {[]interface{}{0, 1}, 1, false, false}, {[]interface{}{0, 0}, 0, false, false}, {[]interface{}{2, -1}, 1, false, false}, + {[]interface{}{"a", "0"}, 0, false, false}, {[]interface{}{"a", "1"}, 1, false, false}, + {[]interface{}{"1a", "0"}, 1, false, false}, {[]interface{}{"1a", "1"}, 1, false, false}, + // casting string to real depends on #10498, which will not be cherry-picked. + // {[]interface{}{"0.0a", 0}, 0, false, false}, + // {[]interface{}{"0.0001a", 0}, 1, false, false}, {[]interface{}{1, nil}, 1, false, false}, {[]interface{}{nil, 1}, 1, false, false}, {[]interface{}{nil, 0}, 0, true, false}, + {[]interface{}{0.000, 0}, 0, false, false}, + {[]interface{}{0.001, 0}, 1, false, false}, + {[]interface{}{nil, 0.000}, 0, true, false}, + {[]interface{}{nil, 0.001}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), 0}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 1, false, false}, {[]interface{}{errors.New("must error"), 1}, 0, false, true}, } @@ -541,3 +566,68 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) { c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse)) } } + +func (s *testEvaluatorSuite) TestLogicXor(c *C) { + defer testleak.AfterTest(c)() + + sc := s.ctx.GetSessionVars().StmtCtx + origin := sc.IgnoreTruncate + defer func() { + sc.IgnoreTruncate = origin + }() + sc.IgnoreTruncate = true + + cases := []struct { + args []interface{} + expected int64 + isNil bool + getErr bool + }{ + {[]interface{}{1, 1}, 0, false, false}, + {[]interface{}{1, 0}, 1, false, false}, + {[]interface{}{0, 1}, 1, false, false}, + {[]interface{}{0, 0}, 0, false, false}, + {[]interface{}{2, -1}, 0, false, false}, + {[]interface{}{"a", "0"}, 0, false, false}, + {[]interface{}{"a", "1"}, 1, false, false}, + {[]interface{}{"1a", "0"}, 1, false, false}, + {[]interface{}{"1a", "1"}, 0, false, false}, + {[]interface{}{0, nil}, 0, true, false}, + {[]interface{}{nil, 0}, 0, true, false}, + {[]interface{}{nil, 1}, 0, true, false}, + {[]interface{}{0.5000, 0.4999}, 1, false, false}, + {[]interface{}{0.5000, 1.0}, 0, false, false}, + {[]interface{}{0.4999, 1.0}, 1, false, false}, + {[]interface{}{nil, 0.000}, 0, true, false}, + {[]interface{}{nil, 0.001}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 0.00001}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false}, + + {[]interface{}{errors.New("must error"), 1}, 0, false, true}, + } + + for _, t := range cases { + f, err := newFunctionForTest(s.ctx, ast.LogicXor, s.primitiveValsToConstants(t.args)...) + c.Assert(err, IsNil) + d, err := f.Eval(chunk.Row{}) + if t.getErr { + c.Assert(err, NotNil) + } else { + c.Assert(err, IsNil) + if t.isNil { + c.Assert(d.Kind(), Equals, types.KindNull) + } else { + c.Assert(d.GetInt64(), Equals, t.expected) + } + } + } + + // Test incorrect parameter count. + _, err := newFunctionForTest(s.ctx, ast.LogicXor, Zero) + c.Assert(err, NotNil) + + _, err = funcs[ast.LogicXor].getFunction(s.ctx, []Expression{Zero, Zero}) + c.Assert(err, IsNil) +} diff --git a/expression/builtin_other.go b/expression/builtin_other.go index 4b8ff8d79c343..2be110ea0b3a8 100644 --- a/expression/builtin_other.go +++ b/expression/builtin_other.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tipb/go-tipb" ) @@ -420,7 +421,7 @@ func (b *builtinSetVarSig) evalString(row chunk.Row) (res string, isNull bool, e } varName = strings.ToLower(varName) sessionVars.UsersLock.Lock() - sessionVars.Users[varName] = res + sessionVars.Users[varName] = stringutil.Copy(res) sessionVars.UsersLock.Unlock() return res, false, nil } diff --git a/expression/builtin_other_test.go b/expression/builtin_other_test.go index 713965ce81f03..cd9660e13a7d5 100644 --- a/expression/builtin_other_test.go +++ b/expression/builtin_other_test.go @@ -234,3 +234,46 @@ func (s *testEvaluatorSuite) TestValues(c *C) { c.Assert(err, IsNil) c.Assert(cmp, Equals, 0) } + +func (s *testEvaluatorSuite) TestSetVarFromColumn(c *C) { + defer testleak.AfterTest(c)() + + // Construct arguments. + argVarName := &Constant{ + Value: types.NewStringDatum("a"), + RetType: &types.FieldType{Tp: mysql.TypeVarString, Flen: 20}, + } + argCol := &Column{ + RetType: &types.FieldType{Tp: mysql.TypeVarString, Flen: 20}, + Index: 0, + } + + // Construct SetVar function. + funcSetVar, err := NewFunction( + s.ctx, + ast.SetVar, + &types.FieldType{Tp: mysql.TypeVarString, Flen: 20}, + []Expression{argVarName, argCol}..., + ) + c.Assert(err, IsNil) + + // Construct input and output Chunks. + inputChunk := chunk.NewChunkWithCapacity([]*types.FieldType{argCol.RetType}, 1) + inputChunk.AppendString(0, "a") + outputChunk := chunk.NewChunkWithCapacity([]*types.FieldType{argCol.RetType}, 1) + + // Evaluate the SetVar function. + err = evalOneCell(s.ctx, funcSetVar, inputChunk.GetRow(0), outputChunk, 0) + c.Assert(err, IsNil) + c.Assert(outputChunk.GetRow(0).GetString(0), Equals, "a") + + // Change the content of the underlying Chunk. + inputChunk.Reset() + inputChunk.AppendString(0, "b") + + // Check whether the user variable changed. + sessionVars := s.ctx.GetSessionVars() + sessionVars.UsersLock.RLock() + defer sessionVars.UsersLock.RUnlock() + c.Assert(sessionVars.Users["a"], Equals, "a") +} diff --git a/expression/builtin_string.go b/expression/builtin_string.go index c0bcb9dfcd87f..862385f8eaa3d 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -273,17 +273,26 @@ func (c *concatFunctionClass) getFunction(ctx sessionctx.Context, args []Express if bf.tp.Flen >= mysql.MaxBlobWidth { bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinConcatSig{bf} + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, err + } + + sig := &builtinConcatSig{bf, maxAllowedPacket} return sig, nil } type builtinConcatSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinConcatSig) Clone() builtinFunc { newSig := &builtinConcatSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -296,6 +305,10 @@ func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err if isNull || err != nil { return d, isNull, err } + if uint64(len(s)+len(d)) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat", b.maxAllowedPacket)) + return "", true, nil + } s = append(s, []byte(d)...) } return string(s), false, nil @@ -338,17 +351,25 @@ func (c *concatWSFunctionClass) getFunction(ctx sessionctx.Context, args []Expre bf.tp.Flen = mysql.MaxBlobWidth } - sig := &builtinConcatWSSig{bf} + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, err + } + + sig := &builtinConcatWSSig{bf, maxAllowedPacket} return sig, nil } type builtinConcatWSSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinConcatWSSig) Clone() builtinFunc { newSig := &builtinConcatWSSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -358,25 +379,35 @@ func (b *builtinConcatWSSig) evalString(row chunk.Row) (string, bool, error) { args := b.getArgs() strs := make([]string, 0, len(args)) var sep string - for i, arg := range args { - val, isNull, err := arg.EvalString(b.ctx, row) + var targetLength int + + N := len(args) + if N > 0 { + val, isNull, err := args[0].EvalString(b.ctx, row) + if err != nil || isNull { + // If the separator is NULL, the result is NULL. + return val, isNull, err + } + sep = val + } + for i := 1; i < N; i++ { + val, isNull, err := args[i].EvalString(b.ctx, row) if err != nil { return val, isNull, err } - if isNull { - // If the separator is NULL, the result is NULL. - if i == 0 { - return val, isNull, nil - } // CONCAT_WS() does not skip empty strings. However, // it does skip any NULL values after the separator argument. continue } - if i == 0 { - sep = val - continue + targetLength += len(val) + if i > 1 { + targetLength += len(sep) + } + if uint64(targetLength) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenWithStackByArgs("concat_ws", b.maxAllowedPacket)) + return "", true, nil } strs = append(strs, val) } @@ -2665,10 +2696,18 @@ func (b *builtinQuoteSig) Clone() builtinFunc { // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_quote func (b *builtinQuoteSig) evalString(row chunk.Row) (string, bool, error) { str, isNull, err := b.args[0].EvalString(b.ctx, row) - if isNull || err != nil { + if err != nil { return "", true, err + } else if isNull { + // If the argument is NULL, the return value is the word "NULL" without enclosing single quotation marks. see ref. + return "NULL", false, err } + return Quote(str), false, nil +} + +// Quote produce a result that can be used as a properly escaped data value in an SQL statement. +func Quote(str string) string { runes := []rune(str) buffer := bytes.NewBufferString("") buffer.WriteRune('\'') @@ -2689,7 +2728,7 @@ func (b *builtinQuoteSig) evalString(row chunk.Row) (string, bool, error) { } buffer.WriteRune('\'') - return buffer.String(), false, nil + return buffer.String() } type binFunctionClass struct { @@ -3326,15 +3365,11 @@ func (b *builtinInsertBinarySig) evalString(row chunk.Row) (string, bool, error) if isNull || err != nil { return "", true, err } - strLength := int64(len(str)) pos, isNull, err := b.args[1].EvalInt(b.ctx, row) if isNull || err != nil { return "", true, err } - if pos < 1 || pos > strLength { - return str, false, nil - } length, isNull, err := b.args[2].EvalInt(b.ctx, row) if isNull || err != nil { @@ -3346,6 +3381,10 @@ func (b *builtinInsertBinarySig) evalString(row chunk.Row) (string, bool, error) return "", true, err } + strLength := int64(len(str)) + if pos < 1 || pos > strLength { + return str, false, nil + } if length > strLength-pos+1 || length < 0 { length = strLength - pos + 1 } @@ -3377,16 +3416,11 @@ func (b *builtinInsertSig) evalString(row chunk.Row) (string, bool, error) { if isNull || err != nil { return "", true, err } - runes := []rune(str) - runeLength := int64(len(runes)) pos, isNull, err := b.args[1].EvalInt(b.ctx, row) if isNull || err != nil { return "", true, err } - if pos < 1 || pos > runeLength { - return str, false, nil - } length, isNull, err := b.args[2].EvalInt(b.ctx, row) if isNull || err != nil { @@ -3398,6 +3432,11 @@ func (b *builtinInsertSig) evalString(row chunk.Row) (string, bool, error) { return "", true, err } + runes := []rune(str) + runeLength := int64(len(runes)) + if pos < 1 || pos > runeLength { + return str, false, nil + } if length > runeLength-pos+1 || length < 0 { length = runeLength - pos + 1 } diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index c4fdf1cc5e2b6..0272e2b7d8c20 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -171,6 +171,50 @@ func (s *testEvaluatorSuite) TestConcat(c *C) { } } +func (s *testEvaluatorSuite) TestConcatSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeVarchar}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + } + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + concat := &builtinConcatSig{base, 5} + + cases := []struct { + args []interface{} + warnings int + res string + }{ + {[]interface{}{"a", "b"}, 0, "ab"}, + {[]interface{}{"aaa", "bbb"}, 1, ""}, + {[]interface{}{"中", "a"}, 0, "中a"}, + {[]interface{}{"中文", "a"}, 2, ""}, + } + + for _, t := range cases { + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, t.args[0].(string)) + input.AppendString(1, t.args[1].(string)) + + res, isNull, err := concat.evalString(input.GetRow(0)) + c.Assert(res, Equals, t.res) + c.Assert(err, IsNil) + if t.warnings == 0 { + c.Assert(isNull, IsFalse) + } else { + c.Assert(isNull, IsTrue) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(warnings, HasLen, t.warnings) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + } + } +} + func (s *testEvaluatorSuite) TestConcatWS(c *C) { defer testleak.AfterTest(c)() cases := []struct { @@ -246,6 +290,53 @@ func (s *testEvaluatorSuite) TestConcatWS(c *C) { c.Assert(err, IsNil) } +func (s *testEvaluatorSuite) TestConcatWSSig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeVarchar}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 1000} + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + &Column{Index: 2, RetType: colTypes[2]}, + } + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + concat := &builtinConcatWSSig{base, 6} + + cases := []struct { + args []interface{} + warnings int + res string + }{ + {[]interface{}{",", "a", "b"}, 0, "a,b"}, + {[]interface{}{",", "aaa", "bbb"}, 1, ""}, + {[]interface{}{",", "中", "a"}, 0, "中,a"}, + {[]interface{}{",", "中文", "a"}, 2, ""}, + } + + for _, t := range cases { + input := chunk.NewChunkWithCapacity(colTypes, 10) + input.AppendString(0, t.args[0].(string)) + input.AppendString(1, t.args[1].(string)) + input.AppendString(2, t.args[2].(string)) + + res, isNull, err := concat.evalString(input.GetRow(0)) + c.Assert(res, Equals, t.res) + c.Assert(err, IsNil) + if t.warnings == 0 { + c.Assert(isNull, IsFalse) + } else { + c.Assert(isNull, IsTrue) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(warnings, HasLen, t.warnings) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) + } + } +} + func (s *testEvaluatorSuite) TestLeft(c *C) { defer testleak.AfterTest(c)() stmtCtx := s.ctx.GetSessionVars().StmtCtx @@ -1474,12 +1565,32 @@ func (s *testEvaluatorSuite) TestInsertBinarySig(c *C) { input := chunk.NewChunkWithCapacity(colTypes, 2) input.AppendString(0, "abc") input.AppendString(0, "abc") + input.AppendString(0, "abc") + input.AppendNull(0) + input.AppendString(0, "abc") + input.AppendString(0, "abc") + input.AppendString(0, "abc") input.AppendInt64(1, 3) input.AppendInt64(1, 3) + input.AppendInt64(1, 0) + input.AppendInt64(1, 3) + input.AppendNull(1) + input.AppendInt64(1, 3) + input.AppendInt64(1, 3) + input.AppendInt64(2, -1) + input.AppendInt64(2, -1) + input.AppendInt64(2, -1) input.AppendInt64(2, -1) input.AppendInt64(2, -1) + input.AppendNull(2) + input.AppendInt64(2, -1) input.AppendString(3, "d") input.AppendString(3, "de") + input.AppendString(3, "d") + input.AppendString(3, "d") + input.AppendString(3, "d") + input.AppendString(3, "d") + input.AppendNull(3) res, isNull, err := insert.evalString(input.GetRow(0)) c.Assert(res, Equals, "abd") @@ -1491,6 +1602,31 @@ func (s *testEvaluatorSuite) TestInsertBinarySig(c *C) { c.Assert(isNull, IsTrue) c.Assert(err, IsNil) + res, isNull, err = insert.evalString(input.GetRow(2)) + c.Assert(res, Equals, "abc") + c.Assert(isNull, IsFalse) + c.Assert(err, IsNil) + + res, isNull, err = insert.evalString(input.GetRow(3)) + c.Assert(res, Equals, "") + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + + res, isNull, err = insert.evalString(input.GetRow(4)) + c.Assert(res, Equals, "") + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + + res, isNull, err = insert.evalString(input.GetRow(5)) + c.Assert(res, Equals, "") + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + + res, isNull, err = insert.evalString(input.GetRow(6)) + c.Assert(res, Equals, "") + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() c.Assert(len(warnings), Equals, 1) lastWarn := warnings[len(warnings)-1] @@ -1855,6 +1991,8 @@ func (s *testEvaluatorSuite) TestInsert(c *C) { {[]interface{}{"Quadratic", 3, 4, nil}, nil}, {[]interface{}{"Quadratic", 3, -1, "What"}, "QuWhat"}, {[]interface{}{"Quadratic", 3, 1, "What"}, "QuWhatdratic"}, + {[]interface{}{"Quadratic", -1, nil, "What"}, nil}, + {[]interface{}{"Quadratic", -1, 4, nil}, nil}, {[]interface{}{"我叫小雨呀", 3, 2, "王雨叶"}, "我叫王雨叶呀"}, {[]interface{}{"我叫小雨呀", -1, 2, "王雨叶"}, "我叫小雨呀"}, @@ -1865,6 +2003,8 @@ func (s *testEvaluatorSuite) TestInsert(c *C) { {[]interface{}{"我叫小雨呀", 3, 4, nil}, nil}, {[]interface{}{"我叫小雨呀", 3, -1, "王雨叶"}, "我叫王雨叶"}, {[]interface{}{"我叫小雨呀", 3, 1, "王雨叶"}, "我叫王雨叶雨呀"}, + {[]interface{}{"我叫小雨呀", -1, nil, "王雨叶"}, nil}, + {[]interface{}{"我叫小雨呀", -1, 2, nil}, nil}, } fc := funcs[ast.InsertFunc] for _, test := range tests { @@ -2028,7 +2168,7 @@ func (s *testEvaluatorSuite) TestQuote(c *C) { {`萌萌哒(๑•ᴗ•๑)😊`, `'萌萌哒(๑•ᴗ•๑)😊'`}, {`㍿㌍㍑㌫`, `'㍿㌍㍑㌫'`}, {string([]byte{0, 26}), `'\0\Z'`}, - {nil, nil}, + {nil, "NULL"}, } for _, t := range tbl { diff --git a/expression/builtin_time.go b/expression/builtin_time.go index d9d659074dba7..9b88af516f94c 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -1572,26 +1572,30 @@ func (c *fromUnixTimeFunctionClass) getFunction(ctx sessionctx.Context, args []E _, isArg0Con := args[0].(*Constant) isArg0Str := args[0].GetType().EvalType() == types.ETString bf := newBaseBuiltinFuncWithTp(ctx, args, retTp, argTps...) - if len(args) == 1 { - if isArg0Str { - bf.tp.Decimal = types.MaxFsp - } else if isArg0Con { - arg0, _, err1 := args[0].EvalDecimal(ctx, chunk.Row{}) - if err1 != nil { - return sig, err1 - } + + if len(args) > 1 { + bf.tp.Flen = args[1].GetType().Flen + return &builtinFromUnixTime2ArgSig{bf}, nil + } + + // Calculate the time fsp. + switch { + case isArg0Str: + bf.tp.Decimal = int(types.MaxFsp) + case isArg0Con: + arg0, arg0IsNull, err0 := args[0].EvalDecimal(ctx, chunk.Row{}) + if err0 != nil { + return nil, err0 + } + + bf.tp.Decimal = int(types.MaxFsp) + if !arg0IsNull { fsp := int(arg0.GetDigitsFrac()) - if fsp > types.MaxFsp { - fsp = types.MaxFsp - } - bf.tp.Decimal = fsp + bf.tp.Decimal = mathutil.Min(fsp, int(types.MaxFsp)) } - sig = &builtinFromUnixTime1ArgSig{bf} - } else { - bf.tp.Flen = args[1].GetType().Flen - sig = &builtinFromUnixTime2ArgSig{bf} } - return sig, nil + + return &builtinFromUnixTime1ArgSig{bf}, nil } func evalFromUnixTime(ctx sessionctx.Context, fsp int, row chunk.Row, arg Expression) (res types.Time, isNull bool, err error) { @@ -1785,6 +1789,7 @@ func (c *strToDateFunctionClass) getRetTp(ctx sessionctx.Context, arg Expression if err != nil || isNull { return } + isDuration, isDate := types.GetFormatType(format) if isDuration && !isDate { tp = mysql.TypeDuration @@ -2029,9 +2034,9 @@ func (b *builtinCurrentDateSig) Clone() builtinFunc { // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_curdate func (b *builtinCurrentDateSig) evalTime(row chunk.Row) (d types.Time, isNull bool, err error) { tz := b.ctx.GetSessionVars().Location() - var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs - if nowTs.Equal(time.Time{}) { - *nowTs = time.Now() + nowTs, err := getStmtTimestamp(b.ctx) + if err != nil { + return types.Time{}, true, err } year, month, day := nowTs.In(tz).Date() result := types.Time{ @@ -2088,9 +2093,9 @@ func (b *builtinCurrentTime0ArgSig) Clone() builtinFunc { func (b *builtinCurrentTime0ArgSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { tz := b.ctx.GetSessionVars().Location() - var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs - if nowTs.Equal(time.Time{}) { - *nowTs = time.Now() + nowTs, err := getStmtTimestamp(b.ctx) + if err != nil { + return types.Duration{}, true, err } dur := nowTs.In(tz).Format(types.TimeFormat) res, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, dur, types.MinFsp) @@ -2116,9 +2121,9 @@ func (b *builtinCurrentTime1ArgSig) evalDuration(row chunk.Row) (types.Duration, return types.Duration{}, true, err } tz := b.ctx.GetSessionVars().Location() - var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs - if nowTs.Equal(time.Time{}) { - *nowTs = time.Now() + nowTs, err := getStmtTimestamp(b.ctx) + if err != nil { + return types.Duration{}, true, err } dur := nowTs.In(tz).Format(types.TimeFSPFormat) res, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, dur, int(fsp)) @@ -2258,9 +2263,9 @@ func (b *builtinUTCDateSig) Clone() builtinFunc { // evalTime evals UTC_DATE, UTC_DATE(). // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-date func (b *builtinUTCDateSig) evalTime(row chunk.Row) (types.Time, bool, error) { - var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs - if nowTs.Equal(time.Time{}) { - *nowTs = time.Now() + nowTs, err := getStmtTimestamp(b.ctx) + if err != nil { + return types.Time{}, true, err } year, month, day := nowTs.UTC().Date() result := types.Time{ @@ -2319,9 +2324,9 @@ func (c *utcTimestampFunctionClass) getFunction(ctx sessionctx.Context, args []E } func evalUTCTimestampWithFsp(ctx sessionctx.Context, fsp int) (types.Time, bool, error) { - var nowTs = &ctx.GetSessionVars().StmtCtx.NowTs - if nowTs.Equal(time.Time{}) { - *nowTs = time.Now() + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.Time{}, true, err } result, err := convertTimeToMysqlTime(nowTs.UTC(), fsp, types.ModeHalfEven) if err != nil { @@ -2406,13 +2411,9 @@ func (c *nowFunctionClass) getFunction(ctx sessionctx.Context, args []Expression } func evalNowWithFsp(ctx sessionctx.Context, fsp int) (types.Time, bool, error) { - var sysTs = &ctx.GetSessionVars().StmtCtx.SysTs - if sysTs.Equal(time.Time{}) { - var err error - *sysTs, err = getSystemTimestamp(ctx) - if err != nil { - return types.Time{}, true, err - } + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.Time{}, true, err } // In MySQL's implementation, now() will truncate the result instead of rounding it. @@ -2423,7 +2424,7 @@ func evalNowWithFsp(ctx sessionctx.Context, fsp int) (types.Time, bool, error) { // +----------------------------+-------------------------+---------------------+ // | 2019-03-25 15:57:56.612966 | 2019-03-25 15:57:56.612 | 2019-03-25 15:57:56 | // +----------------------------+-------------------------+---------------------+ - result, err := convertTimeToMysqlTime(*sysTs, fsp, types.ModeTruncate) + result, err := convertTimeToMysqlTime(nowTs, fsp, types.ModeTruncate) if err != nil { return types.Time{}, true, err } @@ -2657,32 +2658,45 @@ func (du *baseDateArithmitical) getIntervalFromString(ctx sessionctx.Context, ar } func (du *baseDateArithmitical) getIntervalFromDecimal(ctx sessionctx.Context, args []Expression, row chunk.Row, unit string) (string, bool, error) { - interval, isNull, err := args[1].EvalString(ctx, row) + decimal, isNull, err := args[1].EvalDecimal(ctx, row) if isNull || err != nil { return "", true, err } + interval := decimal.String() switch strings.ToUpper(unit) { - case "HOUR_MINUTE", "MINUTE_SECOND": - interval = strings.Replace(interval, ".", ":", -1) - case "YEAR_MONTH": - interval = strings.Replace(interval, ".", "-", -1) - case "DAY_HOUR": - interval = strings.Replace(interval, ".", " ", -1) - case "DAY_MINUTE": - interval = "0 " + strings.Replace(interval, ".", ":", -1) - case "DAY_SECOND": - interval = "0 00:" + strings.Replace(interval, ".", ":", -1) - case "DAY_MICROSECOND": - interval = "0 00:00:" + interval - case "HOUR_MICROSECOND": - interval = "00:00:" + interval - case "HOUR_SECOND": - interval = "00:" + strings.Replace(interval, ".", ":", -1) - case "MINUTE_MICROSECOND": - interval = "00:" + interval - case "SECOND_MICROSECOND": - /* keep interval as original decimal */ + case "HOUR_MINUTE", "MINUTE_SECOND", "YEAR_MONTH", "DAY_HOUR", "DAY_MINUTE", + "DAY_SECOND", "DAY_MICROSECOND", "HOUR_MICROSECOND", "HOUR_SECOND", "MINUTE_MICROSECOND", "SECOND_MICROSECOND": + neg := false + if interval != "" && interval[0] == '-' { + neg = true + interval = interval[1:] + } + switch strings.ToUpper(unit) { + case "HOUR_MINUTE", "MINUTE_SECOND": + interval = strings.Replace(interval, ".", ":", -1) + case "YEAR_MONTH": + interval = strings.Replace(interval, ".", "-", -1) + case "DAY_HOUR": + interval = strings.Replace(interval, ".", " ", -1) + case "DAY_MINUTE": + interval = "0 " + strings.Replace(interval, ".", ":", -1) + case "DAY_SECOND": + interval = "0 00:" + strings.Replace(interval, ".", ":", -1) + case "DAY_MICROSECOND": + interval = "0 00:00:" + interval + case "HOUR_MICROSECOND": + interval = "00:00:" + interval + case "HOUR_SECOND": + interval = "00:" + strings.Replace(interval, ".", ":", -1) + case "MINUTE_MICROSECOND": + interval = "00:" + interval + case "SECOND_MICROSECOND": + /* keep interval as original decimal */ + } + if neg { + interval = "-" + interval + } case "SECOND": // Decimal's EvalString is like %f format. interval, isNull, err = args[1].EvalString(ctx, row) @@ -2714,7 +2728,7 @@ func (du *baseDateArithmitical) getIntervalFromReal(ctx sessionctx.Context, args if isNull || err != nil { return "", true, err } - return strconv.FormatFloat(interval, 'f', -1, 64), false, nil + return strconv.FormatFloat(interval, 'f', args[1].GetType().Decimal, 64), false, nil } func (du *baseDateArithmitical) add(ctx sessionctx.Context, date types.Time, interval string, unit string) (types.Time, bool, error) { @@ -2737,6 +2751,10 @@ func (du *baseDateArithmitical) add(ctx sessionctx.Context, date types.Time, int date.Fsp = 6 } + if goTime.Year() < 0 || goTime.Year() > (1<<16-1) { + return types.Time{}, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) + } + date.Time = types.FromGoTime(goTime) overflow, err := types.DateTimeIsOverflow(ctx.GetSessionVars().StmtCtx, date) if err := handleInvalidTimeError(ctx, err); err != nil { @@ -2760,6 +2778,18 @@ func (du *baseDateArithmitical) addDuration(ctx sessionctx.Context, d types.Dura return retDur, false, nil } +func (du *baseDateArithmitical) subDuration(ctx sessionctx.Context, d types.Duration, interval string, unit string) (types.Duration, bool, error) { + dur, err := types.ExtractDurationValue(unit, interval) + if err != nil { + return types.ZeroDuration, true, handleInvalidTimeError(ctx, err) + } + retDur, err := d.Sub(dur) + if err != nil { + return types.ZeroDuration, true, err + } + return retDur, false, nil +} + func (du *baseDateArithmitical) sub(ctx sessionctx.Context, date types.Time, interval string, unit string) (types.Time, bool, error) { year, month, day, nano, err := types.ParseDurationValue(unit, interval) if err := handleInvalidTimeError(ctx, err); err != nil { @@ -2782,6 +2812,10 @@ func (du *baseDateArithmitical) sub(ctx sessionctx.Context, date types.Time, int date.Fsp = 6 } + if goTime.Year() < 0 || goTime.Year() > (1<<16-1) { + return types.Time{}, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) + } + date.Time = types.FromGoTime(goTime) overflow, err := types.DateTimeIsOverflow(ctx.GetSessionVars().StmtCtx, date) if err := handleInvalidTimeError(ctx, err); err != nil { @@ -2914,6 +2948,11 @@ func (c *addDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres baseBuiltinFunc: bf, baseDateArithmitical: newDateArighmeticalUtil(), } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETReal: + sig = &builtinAddDateDurationRealSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETDecimal: sig = &builtinAddDateDurationDecimalSig{ baseBuiltinFunc: bf, @@ -3324,6 +3363,12 @@ type builtinAddDateDurationStringSig struct { baseDateArithmitical } +func (b *builtinAddDateDurationStringSig) Clone() builtinFunc { + newSig := &builtinAddDateDurationStringSig{baseDateArithmitical: b.baseDateArithmitical} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + func (b *builtinAddDateDurationStringSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { unit, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { @@ -3349,6 +3394,12 @@ type builtinAddDateDurationIntSig struct { baseDateArithmitical } +func (b *builtinAddDateDurationIntSig) Clone() builtinFunc { + newSig := &builtinAddDateDurationIntSig{baseDateArithmitical: b.baseDateArithmitical} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + func (b *builtinAddDateDurationIntSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { unit, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { @@ -3373,6 +3424,12 @@ type builtinAddDateDurationDecimalSig struct { baseDateArithmitical } +func (b *builtinAddDateDurationDecimalSig) Clone() builtinFunc { + newSig := &builtinAddDateDurationDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + func (b *builtinAddDateDurationDecimalSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { unit, isNull, err := b.args[2].EvalString(b.ctx, row) if isNull || err != nil { @@ -3392,6 +3449,36 @@ func (b *builtinAddDateDurationDecimalSig) evalDuration(row chunk.Row) (types.Du return result, isNull || err != nil, err } +type builtinAddDateDurationRealSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinAddDateDurationRealSig) Clone() builtinFunc { + newSig := &builtinAddDateDurationRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinAddDateDurationRealSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + interval, isNull, err := b.getIntervalFromReal(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.addDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + type subDateFunctionClass struct { baseFunctionClass } @@ -3402,7 +3489,7 @@ func (c *subDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres } dateEvalTp := args[0].GetType().EvalType() - if dateEvalTp != types.ETString && dateEvalTp != types.ETInt { + if dateEvalTp != types.ETString && dateEvalTp != types.ETInt && dateEvalTp != types.ETDuration { dateEvalTp = types.ETDatetime } @@ -3412,8 +3499,35 @@ func (c *subDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres } argTps := []types.EvalType{dateEvalTp, intervalEvalTp, types.ETString} - bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDatetime, argTps...) - bf.tp.Flen, bf.tp.Decimal = mysql.MaxDatetimeFullWidth, types.UnspecifiedLength + var bf baseBuiltinFunc + if dateEvalTp == types.ETDuration { + unit, _, err := args[2].EvalString(ctx, chunk.Row{}) + if err != nil { + return nil, err + } + internalFsp := 0 + switch unit { + // If the unit has micro second, then the fsp must be the MaxFsp. + case "MICROSECOND", "SECOND_MICROSECOND", "MINUTE_MICROSECOND", "HOUR_MICROSECOND", "DAY_MICROSECOND": + internalFsp = types.MaxFsp + // If the unit is second, the fsp is related with the arg[1]'s. + case "SECOND": + internalFsp = types.MaxFsp + if intervalEvalTp != types.ETString { + internalFsp = mathutil.Min(args[1].GetType().Decimal, types.MaxFsp) + } + // Otherwise, the fsp should be 0. + } + bf = newBaseBuiltinFuncWithTp(ctx, args, types.ETDuration, argTps...) + arg0Dec, err := getExpressionFsp(ctx, args[0]) + if err != nil { + return nil, err + } + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDurationWidthWithFsp, mathutil.Max(arg0Dec, internalFsp) + } else { + bf = newBaseBuiltinFuncWithTp(ctx, args, types.ETDatetime, argTps...) + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDatetimeFullWidth, types.UnspecifiedLength + } switch { case dateEvalTp == types.ETString && intervalEvalTp == types.ETString: @@ -3476,6 +3590,26 @@ func (c *subDateFunctionClass) getFunction(ctx sessionctx.Context, args []Expres baseBuiltinFunc: bf, baseDateArithmitical: newDateArighmeticalUtil(), } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETString: + sig = &builtinSubDateDurationStringSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETInt: + sig = &builtinSubDateDurationIntSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETReal: + sig = &builtinSubDateDurationRealSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETDecimal: + sig = &builtinSubDateDurationDecimalSig{ + baseBuiltinFunc: bf, + baseDateArithmitical: newDateArighmeticalUtil(), + } } return sig, nil } @@ -3874,6 +4008,129 @@ func (b *builtinSubDateDatetimeDecimalSig) evalTime(row chunk.Row) (types.Time, return result, isNull || err != nil, err } +type builtinSubDateDurationStringSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinSubDateDurationStringSig) Clone() builtinFunc { + newSig := &builtinSubDateDurationStringSig{baseDateArithmitical: b.baseDateArithmitical} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinSubDateDurationStringSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + interval, isNull, err := b.getIntervalFromString(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.subDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + +type builtinSubDateDurationIntSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinSubDateDurationIntSig) Clone() builtinFunc { + newSig := &builtinSubDateDurationIntSig{baseDateArithmitical: b.baseDateArithmitical} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinSubDateDurationIntSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + interval, isNull, err := b.getIntervalFromInt(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.subDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + +type builtinSubDateDurationDecimalSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinSubDateDurationDecimalSig) Clone() builtinFunc { + newSig := &builtinSubDateDurationDecimalSig{baseDateArithmitical: b.baseDateArithmitical} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinSubDateDurationDecimalSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + interval, isNull, err := b.getIntervalFromDecimal(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.subDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + +type builtinSubDateDurationRealSig struct { + baseBuiltinFunc + baseDateArithmitical +} + +func (b *builtinSubDateDurationRealSig) Clone() builtinFunc { + newSig := &builtinSubDateDurationRealSig{baseDateArithmitical: b.baseDateArithmitical} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinSubDateDurationRealSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(b.ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + interval, isNull, err := b.getIntervalFromReal(b.ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.subDuration(b.ctx, dur, interval, unit) + return result, isNull || err != nil, err +} + type timestampDiffFunctionClass struct { baseFunctionClass } @@ -4032,11 +4289,11 @@ func (b *builtinUnixTimestampCurrentSig) Clone() builtinFunc { // evalInt evals a UNIX_TIMESTAMP(). // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_unix-timestamp func (b *builtinUnixTimestampCurrentSig) evalInt(row chunk.Row) (int64, bool, error) { - var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs - if nowTs.Equal(time.Time{}) { - *nowTs = time.Now() + nowTs, err := getStmtTimestamp(b.ctx) + if err != nil { + return 0, true, err } - dec, err := goTimeToMysqlUnixTimestamp(*nowTs, 1) + dec, err := goTimeToMysqlUnixTimestamp(nowTs, 1) if err != nil { return 0, true, err } @@ -4146,7 +4403,7 @@ func (c *timestampFunctionClass) getFunction(ctx sessionctx.Context, args []Expr isFloat = true } bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETDatetime, evalTps...) - bf.tp.Decimal, bf.tp.Flen = fsp, 19 + bf.tp.Decimal, bf.tp.Flen = -1, 19 if fsp != 0 { bf.tp.Flen += 1 + fsp } @@ -4554,6 +4811,10 @@ func (b *builtinAddDatetimeAndStringSig) evalTime(row chunk.Row) (types.Time, bo sc := b.ctx.GetSessionVars().StmtCtx arg1, err := types.ParseDuration(sc, s, types.GetFsp(s)) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return types.ZeroDatetime, true, nil + } return types.ZeroDatetime, true, err } result, err := arg0.Add(sc, arg1) @@ -4628,8 +4889,13 @@ func (b *builtinAddDurationAndStringSig) evalDuration(row chunk.Row) (types.Dura if !isDuration(s) { return types.ZeroDuration, true, nil } - arg1, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, s, types.GetFsp(s)) + sc := b.ctx.GetSessionVars().StmtCtx + arg1, err := types.ParseDuration(sc, s, types.GetFsp(s)) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return types.ZeroDuration, true, nil + } return types.ZeroDuration, true, err } result, err := arg0.Add(arg1) @@ -4684,6 +4950,10 @@ func (b *builtinAddStringAndDurationSig) evalString(row chunk.Row) (result strin if isDuration(arg0) { result, err = strDurationAddDuration(sc, arg0, arg1) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return "", true, nil + } return "", true, err } return result, false, nil @@ -4724,11 +4994,19 @@ func (b *builtinAddStringAndStringSig) evalString(row chunk.Row) (result string, sc := b.ctx.GetSessionVars().StmtCtx arg1, err = types.ParseDuration(sc, arg1Str, getFsp4TimeAddSub(arg1Str)) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return "", true, nil + } return "", true, err } if isDuration(arg0) { result, err = strDurationAddDuration(sc, arg0, arg1) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return "", true, nil + } return "", true, err } return result, false, nil @@ -4786,8 +5064,13 @@ func (b *builtinAddDateAndStringSig) evalString(row chunk.Row) (string, bool, er if !isDuration(s) { return "", true, nil } - arg1, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, s, getFsp4TimeAddSub(s)) + sc := b.ctx.GetSessionVars().StmtCtx + arg1, err := types.ParseDuration(sc, s, getFsp4TimeAddSub(s)) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return "", true, nil + } return "", true, err } result, err := arg0.Add(arg1) @@ -4854,21 +5137,22 @@ func (b *builtinConvertTzSig) Clone() builtinFunc { } // evalTime evals CONVERT_TZ(dt,from_tz,to_tz). +// `CONVERT_TZ` function returns NULL if the arguments are invalid. // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_convert-tz func (b *builtinConvertTzSig) evalTime(row chunk.Row) (types.Time, bool, error) { dt, isNull, err := b.args[0].EvalTime(b.ctx, row) if isNull || err != nil { - return types.Time{}, true, err + return types.Time{}, true, nil } fromTzStr, isNull, err := b.args[1].EvalString(b.ctx, row) - if isNull || err != nil { - return types.Time{}, true, err + if isNull || err != nil || fromTzStr == "" { + return types.Time{}, true, nil } toTzStr, isNull, err := b.args[2].EvalString(b.ctx, row) - if isNull || err != nil { - return types.Time{}, true, err + if isNull || err != nil || toTzStr == "" { + return types.Time{}, true, nil } fromTzMatched := b.timezoneRegex.MatchString(fromTzStr) @@ -4877,17 +5161,17 @@ func (b *builtinConvertTzSig) evalTime(row chunk.Row) (types.Time, bool, error) if !fromTzMatched && !toTzMatched { fromTz, err := time.LoadLocation(fromTzStr) if err != nil { - return types.Time{}, true, err + return types.Time{}, true, nil } toTz, err := time.LoadLocation(toTzStr) if err != nil { - return types.Time{}, true, err + return types.Time{}, true, nil } t, err := dt.Time.GoTime(fromTz) if err != nil { - return types.Time{}, true, err + return types.Time{}, true, nil } return types.Time{ @@ -4899,7 +5183,7 @@ func (b *builtinConvertTzSig) evalTime(row chunk.Row) (types.Time, bool, error) if fromTzMatched && toTzMatched { t, err := dt.Time.GoTime(time.Local) if err != nil { - return types.Time{}, true, err + return types.Time{}, true, nil } return types.Time{ @@ -5458,6 +5742,10 @@ func (b *builtinSubDatetimeAndStringSig) evalTime(row chunk.Row) (types.Time, bo sc := b.ctx.GetSessionVars().StmtCtx arg1, err := types.ParseDuration(sc, s, types.GetFsp(s)) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return types.ZeroDatetime, true, nil + } return types.ZeroDatetime, true, err } arg1time, err := arg1.ConvertToTime(sc, mysql.TypeDatetime) @@ -5514,6 +5802,10 @@ func (b *builtinSubStringAndDurationSig) evalString(row chunk.Row) (result strin if isDuration(arg0) { result, err = strDurationSubDuration(sc, arg0, arg1) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return "", true, nil + } return "", true, err } return result, false, nil @@ -5551,14 +5843,22 @@ func (b *builtinSubStringAndStringSig) evalString(row chunk.Row) (result string, if isNull || err != nil { return "", isNull, err } - arg1, err = types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, s, getFsp4TimeAddSub(s)) + sc := b.ctx.GetSessionVars().StmtCtx + arg1, err = types.ParseDuration(sc, s, getFsp4TimeAddSub(s)) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return "", true, nil + } return "", true, err } - sc := b.ctx.GetSessionVars().StmtCtx if isDuration(arg0) { result, err = strDurationSubDuration(sc, arg0, arg1) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return "", true, nil + } return "", true, err } return result, false, nil @@ -5635,8 +5935,13 @@ func (b *builtinSubDurationAndStringSig) evalDuration(row chunk.Row) (types.Dura if !isDuration(s) { return types.ZeroDuration, true, nil } - arg1, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, s, types.GetFsp(s)) + sc := b.ctx.GetSessionVars().StmtCtx + arg1, err := types.ParseDuration(sc, s, types.GetFsp(s)) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return types.ZeroDuration, true, nil + } return types.ZeroDuration, true, err } result, err := arg0.Sub(arg1) @@ -5708,8 +6013,13 @@ func (b *builtinSubDateAndStringSig) evalString(row chunk.Row) (string, bool, er if !isDuration(s) { return "", true, nil } - arg1, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, s, getFsp4TimeAddSub(s)) + sc := b.ctx.GetSessionVars().StmtCtx + arg1, err := types.ParseDuration(sc, s, getFsp4TimeAddSub(s)) if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + sc.AppendWarning(err) + return "", true, nil + } return "", true, err } result, err := arg0.Sub(arg1) @@ -6041,9 +6351,9 @@ func (b *builtinUTCTimeWithoutArgSig) Clone() builtinFunc { // evalDuration evals a builtinUTCTimeWithoutArgSig. // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-time func (b *builtinUTCTimeWithoutArgSig) evalDuration(row chunk.Row) (types.Duration, bool, error) { - var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs - if nowTs.Equal(time.Time{}) { - *nowTs = time.Now() + nowTs, err := getStmtTimestamp(b.ctx) + if err != nil { + return types.Duration{}, true, err } v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, nowTs.UTC().Format(types.TimeFormat), 0) return v, false, err @@ -6072,9 +6382,9 @@ func (b *builtinUTCTimeWithArgSig) evalDuration(row chunk.Row) (types.Duration, if fsp < int64(types.MinFsp) { return types.Duration{}, true, errors.Errorf("Invalid negative %d specified, must in [0, 6].", fsp) } - var nowTs = &b.ctx.GetSessionVars().StmtCtx.NowTs - if nowTs.Equal(time.Time{}) { - *nowTs = time.Now() + nowTs, err := getStmtTimestamp(b.ctx) + if err != nil { + return types.Duration{}, true, err } v, err := types.ParseDuration(b.ctx.GetSessionVars().StmtCtx, nowTs.UTC().Format(types.TimeFSPFormat), int(fsp)) return v, false, err diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index bb837c1c0f6a4..ed7e784cc231a 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -763,8 +764,7 @@ func (s *testEvaluatorSuite) TestTime(c *C) { } func resetStmtContext(ctx sessionctx.Context) { - ctx.GetSessionVars().StmtCtx.NowTs = time.Time{} - ctx.GetSessionVars().StmtCtx.SysTs = time.Time{} + ctx.GetSessionVars().StmtCtx.ResetNowTs() } func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) { @@ -783,9 +783,9 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) { {funcs[ast.Now], func() time.Time { return time.Now() }}, {funcs[ast.UTCTimestamp], func() time.Time { return time.Now().UTC() }}, } { - resetStmtContext(s.ctx) f, err := x.fc.getFunction(s.ctx, s.datumsToConstants(nil)) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err := evalBuiltinFunc(f, chunk.Row{}) ts := x.now() c.Assert(err, IsNil) @@ -795,9 +795,9 @@ func (s *testEvaluatorSuite) TestNowAndUTCTimestamp(c *C) { c.Assert(strings.Contains(t.String(), "."), IsFalse) c.Assert(ts.Sub(gotime(t, ts.Location())), LessEqual, time.Second) - resetStmtContext(s.ctx) f, err = x.fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(6))) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err = evalBuiltinFunc(f, chunk.Row{}) ts = x.now() c.Assert(err, IsNil) @@ -937,6 +937,33 @@ func (s *testEvaluatorSuite) TestAddTimeSig(c *C) { c.Assert(result, Equals, t.expect) } + tblWarning := []struct { + Input interface{} + InputDuration interface{} + warning *terror.Error + }{ + {"0", "-32073", types.ErrTruncatedWrongVal}, + {"-32073", "0", types.ErrTruncatedWrongVal}, + {types.ZeroDuration, "-32073", types.ErrTruncatedWrongVal}, + {"-32073", types.ZeroDuration, types.ErrTruncatedWrongVal}, + {types.CurrentTime(mysql.TypeTimestamp), "-32073", types.ErrTruncatedWrongVal}, + {types.CurrentTime(mysql.TypeDate), "-32073", types.ErrTruncatedWrongVal}, + {types.CurrentTime(mysql.TypeDatetime), "-32073", types.ErrTruncatedWrongVal}, + } + for i, t := range tblWarning { + tmpInput := types.NewDatum(t.Input) + tmpInputDuration := types.NewDatum(t.InputDuration) + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{tmpInput, tmpInputDuration})) + c.Assert(err, IsNil) + d, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + result, _ := d.ToString() + c.Assert(result, Equals, "") + c.Assert(d.IsNull(), Equals, true) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, i+1) + c.Assert(terror.ErrorEqual(t.warning, warnings[i].Err), IsTrue, Commentf("err %v", warnings[i].Err)) + } } func (s *testEvaluatorSuite) TestSubTimeSig(c *C) { @@ -1002,6 +1029,34 @@ func (s *testEvaluatorSuite) TestSubTimeSig(c *C) { result, _ := d.ToString() c.Assert(result, Equals, t.expect) } + + tblWarning := []struct { + Input interface{} + InputDuration interface{} + warning *terror.Error + }{ + {"0", "-32073", types.ErrTruncatedWrongVal}, + {"-32073", "0", types.ErrTruncatedWrongVal}, + {types.ZeroDuration, "-32073", types.ErrTruncatedWrongVal}, + {"-32073", types.ZeroDuration, types.ErrTruncatedWrongVal}, + {types.CurrentTime(mysql.TypeTimestamp), "-32073", types.ErrTruncatedWrongVal}, + {types.CurrentTime(mysql.TypeDate), "-32073", types.ErrTruncatedWrongVal}, + {types.CurrentTime(mysql.TypeDatetime), "-32073", types.ErrTruncatedWrongVal}, + } + for i, t := range tblWarning { + tmpInput := types.NewDatum(t.Input) + tmpInputDuration := types.NewDatum(t.InputDuration) + f, err := fc.getFunction(s.ctx, s.datumsToConstants([]types.Datum{tmpInput, tmpInputDuration})) + c.Assert(err, IsNil) + d, err := evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + result, _ := d.ToString() + c.Assert(result, Equals, "") + c.Assert(d.IsNull(), Equals, true) + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, i+1) + c.Assert(terror.ErrorEqual(t.warning, warnings[i].Err), IsTrue, Commentf("err %v", warnings[i].Err)) + } } func (s *testEvaluatorSuite) TestSysDate(c *C) { @@ -1016,6 +1071,7 @@ func (s *testEvaluatorSuite) TestSysDate(c *C) { variable.SetSessionSystemVar(ctx.GetSessionVars(), "timestamp", timezone) f, err := fc.getFunction(ctx, s.datumsToConstants(nil)) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err := evalBuiltinFunc(f, chunk.Row{}) last := time.Now() c.Assert(err, IsNil) @@ -1026,6 +1082,7 @@ func (s *testEvaluatorSuite) TestSysDate(c *C) { last := time.Now() f, err := fc.getFunction(ctx, s.datumsToConstants(types.MakeDatums(6))) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) n := v.GetMysqlTime() @@ -1158,6 +1215,7 @@ func (s *testEvaluatorSuite) TestCurrentDate(c *C) { fc := funcs[ast.CurrentDate] f, err := fc.getFunction(mock.NewContext(), s.datumsToConstants(nil)) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) n := v.GetMysqlTime() @@ -1172,6 +1230,7 @@ func (s *testEvaluatorSuite) TestCurrentTime(c *C) { fc := funcs[ast.CurrentTime] f, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(nil))) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) n := v.GetMysqlDuration() @@ -1180,6 +1239,7 @@ func (s *testEvaluatorSuite) TestCurrentTime(c *C) { f, err = fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(3))) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) n = v.GetMysqlDuration() @@ -1188,6 +1248,7 @@ func (s *testEvaluatorSuite) TestCurrentTime(c *C) { f, err = fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(6))) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) n = v.GetMysqlDuration() @@ -1214,9 +1275,9 @@ func (s *testEvaluatorSuite) TestUTCTime(c *C) { }{{0, 8}, {3, 12}, {6, 15}, {-1, 0}, {7, 0}} for _, test := range tests { - resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, s.datumsToConstants(types.MakeDatums(test.param))) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err := evalBuiltinFunc(f, chunk.Row{}) if test.expect > 0 { c.Assert(err, IsNil) @@ -1230,6 +1291,7 @@ func (s *testEvaluatorSuite) TestUTCTime(c *C) { f, err := fc.getFunction(s.ctx, make([]Expression, 0)) c.Assert(err, IsNil) + resetStmtContext(s.ctx) v, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) n := v.GetMysqlDuration() @@ -1241,9 +1303,9 @@ func (s *testEvaluatorSuite) TestUTCDate(c *C) { defer testleak.AfterTest(c)() last := time.Now().UTC() fc := funcs[ast.UTCDate] - resetStmtContext(mock.NewContext()) f, err := fc.getFunction(mock.NewContext(), s.datumsToConstants(nil)) c.Assert(err, IsNil) + resetStmtContext(mock.NewContext()) v, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) n := v.GetMysqlTime() @@ -1580,9 +1642,9 @@ func (s *testEvaluatorSuite) TestTimestampDiff(c *C) { func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) { // Test UNIX_TIMESTAMP(). fc := funcs[ast.UnixTimestamp] - resetStmtContext(s.ctx) f, err := fc.getFunction(s.ctx, nil) c.Assert(err, IsNil) + resetStmtContext(s.ctx) d, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) c.Assert(d.GetInt64()-time.Now().Unix(), GreaterEqual, int64(-1)) @@ -1597,9 +1659,9 @@ func (s *testEvaluatorSuite) TestUnixTimestamp(c *C) { n := types.Datum{} n.SetMysqlTime(now) args := []types.Datum{n} - resetStmtContext(s.ctx) f, err = fc.getFunction(s.ctx, s.datumsToConstants(args)) c.Assert(err, IsNil) + resetStmtContext(s.ctx) d, err = evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) val, _ := d.GetMysqlDecimal().ToInt() @@ -1764,7 +1826,37 @@ func (s *testEvaluatorSuite) TestDateArithFuncs(c *C) { c.Assert(err, IsNil) c.Assert(v.GetMysqlTime().String(), Equals, test.expected) } + + testOverflowYears := []struct { + input string + year int + }{ + {"2008-11-23", -1465647104}, + {"2008-11-23", 1465647104}, + } + + for _, test := range testOverflowYears { + args = types.MakeDatums(test.input, test.year, "YEAR") + f, err = fcAdd.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v.IsNull(), IsTrue) + } + + for _, test := range testOverflowYears { + args = types.MakeDatums(test.input, test.year, "YEAR") + f, err = fcSub.getFunction(s.ctx, s.datumsToConstants(args)) + c.Assert(err, IsNil) + c.Assert(f, NotNil) + v, err = evalBuiltinFunc(f, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(v.IsNull(), IsTrue) + } + testDurations := []struct { + fc functionClass dur string fsp int unit string @@ -1772,19 +1864,100 @@ func (s *testEvaluatorSuite) TestDateArithFuncs(c *C) { expected string }{ { + fc: fcAdd, dur: "00:00:00", fsp: 0, unit: "MICROSECOND", format: "100", expected: "00:00:00.000100", }, + { + fc: fcAdd, + dur: "00:00:00", + fsp: 0, + unit: "MICROSECOND", + format: 100.0, + expected: "00:00:00.000100", + }, + { + fc: fcSub, + dur: "00:00:01", + fsp: 0, + unit: "MICROSECOND", + format: "100", + expected: "00:00:00.999900", + }, + { + fc: fcAdd, + dur: "00:00:00", + fsp: 0, + unit: "DAY", + format: "1", + expected: "24:00:00", + }, + { + fc: fcAdd, + dur: "00:00:00", + fsp: 0, + unit: "SECOND", + format: 1, + expected: "00:00:01", + }, + { + fc: fcAdd, + dur: "00:00:00", + fsp: 0, + unit: "DAY", + format: types.NewDecFromInt(1), + expected: "24:00:00", + }, + { + fc: fcAdd, + dur: "00:00:00", + fsp: 0, + unit: "DAY", + format: 1.0, + expected: "24:00:00", + }, + { + fc: fcSub, + dur: "26:00:00", + fsp: 0, + unit: "DAY", + format: "1", + expected: "02:00:00", + }, + { + fc: fcSub, + dur: "26:00:00", + fsp: 0, + unit: "DAY", + format: 1, + expected: "02:00:00", + }, + { + fc: fcSub, + dur: "26:00:00", + fsp: 0, + unit: "SECOND", + format: types.NewDecFromInt(1), + expected: "25:59:59", + }, + { + fc: fcSub, + dur: "27:00:00", + fsp: 0, + unit: "DAY", + format: 1.0, + expected: "03:00:00", + }, } for _, tt := range testDurations { dur, _, ok, err := types.StrToDuration(nil, tt.dur, tt.fsp) c.Assert(err, IsNil) c.Assert(ok, IsTrue) args = types.MakeDatums(dur, tt.format, tt.unit) - f, err = fcAdd.getFunction(s.ctx, s.datumsToConstants(args)) + f, err = tt.fc.getFunction(s.ctx, s.datumsToConstants(args)) c.Assert(err, IsNil) c.Assert(f, NotNil) v, err = evalBuiltinFunc(f, chunk.Row{}) @@ -2355,8 +2528,8 @@ func (s *testEvaluatorSuite) TestSecToTime(c *C) { func (s *testEvaluatorSuite) TestConvertTz(c *C) { tests := []struct { t interface{} - fromTz string - toTz string + fromTz interface{} + toTz interface{} Success bool expect string }{ @@ -2368,11 +2541,20 @@ func (s *testEvaluatorSuite) TestConvertTz(c *C) { {"2004-01-01 12:00:00", "-00:00", "+13:00", true, "2004-01-02 01:00:00"}, {"2004-01-01 12:00:00", "-00:00", "-13:00", true, ""}, {"2004-01-01 12:00:00", "-00:00", "-12:88", true, ""}, - {"2004-01-01 12:00:00", "+10:82", "GMT", false, ""}, + {"2004-01-01 12:00:00", "+10:82", "GMT", true, ""}, {"2004-01-01 12:00:00", "+00:00", "GMT", true, ""}, {"2004-01-01 12:00:00", "GMT", "+00:00", true, ""}, {20040101, "+00:00", "+10:32", true, "2004-01-01 10:32:00"}, {3.14159, "+00:00", "+10:32", true, ""}, + {"2004-01-01 12:00:00", "", "GMT", true, ""}, + {"2004-01-01 12:00:00", "GMT", "", true, ""}, + {"2004-01-01 12:00:00", "a", "GMT", true, ""}, + {"2004-01-01 12:00:00", "0", "GMT", true, ""}, + {"2004-01-01 12:00:00", "GMT", "a", true, ""}, + {"2004-01-01 12:00:00", "GMT", "0", true, ""}, + {nil, "GMT", "+00:00", true, ""}, + {"2004-01-01 12:00:00", nil, "+00:00", true, ""}, + {"2004-01-01 12:00:00", "GMT", nil, true, ""}, } fc := funcs[ast.ConvertTz] for _, test := range tests { @@ -2380,8 +2562,8 @@ func (s *testEvaluatorSuite) TestConvertTz(c *C) { s.datumsToConstants( []types.Datum{ types.NewDatum(test.t), - types.NewStringDatum(test.fromTz), - types.NewStringDatum(test.toTz)})) + types.NewDatum(test.fromTz), + types.NewDatum(test.toTz)})) c.Assert(err, IsNil) d, err := evalBuiltinFunc(f, chunk.Row{}) if test.Success { @@ -2539,9 +2721,9 @@ func (s *testEvaluatorSuite) TestWithTimeZone(c *C) { for _, t := range tests { now := time.Now().In(sv.TimeZone) - resetStmtContext(s.ctx) f, err := funcs[t.method].getFunction(s.ctx, s.datumsToConstants(t.Input)) c.Assert(err, IsNil) + resetStmtContext(s.ctx) d, err := evalBuiltinFunc(f, chunk.Row{}) c.Assert(err, IsNil) result := t.convertToTime(d, sv.TimeZone) @@ -2585,3 +2767,29 @@ func (s *testEvaluatorSuite) TestTidbParseTso(c *C) { c.Assert(d.IsNull(), IsTrue) } } + +func (s *testEvaluatorSuite) TestGetIntervalFromDecimal(c *C) { + defer testleak.AfterTest(c)() + du := baseDateArithmitical{} + + tests := []struct { + param string + expect string + unit string + }{ + {"1.100", "1:100", "MINUTE_SECOND"}, + {"1.10000", "1-10000", "YEAR_MONTH"}, + {"1.10000", "1 10000", "DAY_HOUR"}, + {"11000", "0 00:00:11000", "DAY_MICROSECOND"}, + {"11000", "00:00:11000", "HOUR_MICROSECOND"}, + {"11.1000", "00:11:1000", "HOUR_SECOND"}, + {"1000", "00:1000", "MINUTE_MICROSECOND"}, + } + + for _, test := range tests { + interval, isNull, err := du.getIntervalFromDecimal(s.ctx, s.datumsToConstants([]types.Datum{types.NewDatum("CURRENT DATE"), types.NewDecimalDatum(newMyDecimal(c, test.param))}), chunk.Row{}, test.unit) + c.Assert(isNull, IsFalse) + c.Assert(err, IsNil) + c.Assert(interval, Equals, test.expect) + } +} diff --git a/expression/column.go b/expression/column.go index 99b94e0ea4414..493c51de01984 100644 --- a/expression/column.go +++ b/expression/column.go @@ -338,6 +338,16 @@ func (col *Column) resolveIndices(schema *Schema) error { return nil } +// ToInfo converts the expression.Column to model.ColumnInfo for casting values, +// beware it doesn't fill all the fields of the model.ColumnInfo. +func (col *Column) ToInfo() *model.ColumnInfo { + return &model.ColumnInfo{ + ID: col.ID, + Name: col.ColName, + FieldType: *col.RetType, + } +} + // Column2Exprs will transfer column slice to expression slice. func Column2Exprs(cols []*Column) []Expression { result := make([]Expression, 0, len(cols)) diff --git a/expression/constant.go b/expression/constant.go index c678bc40b8d78..4d61e92387f46 100644 --- a/expression/constant.go +++ b/expression/constant.go @@ -277,7 +277,6 @@ func (c *Constant) EvalJSON(ctx sessionctx.Context, _ chunk.Row) (json.BinaryJSO if err != nil { return json.BinaryJSON{}, true, err } - fmt.Println("const eval json", val.GetMysqlJSON().String()) c.Value.SetMysqlJSON(val.GetMysqlJSON()) c.GetType().Tp = mysql.TypeJSON } else { diff --git a/expression/constant_fold.go b/expression/constant_fold.go index adec69c78f8a0..9822b9225b8a3 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -29,6 +29,7 @@ func init() { specialFoldHandler = map[string]func(*ScalarFunction) (Expression, bool){ ast.If: ifFoldHandler, ast.Ifnull: ifNullFoldHandler, + ast.Case: caseWhenHandler, } } @@ -80,6 +81,59 @@ func ifNullFoldHandler(expr *ScalarFunction) (Expression, bool) { return expr, isDeferredConst } +func caseWhenHandler(expr *ScalarFunction) (Expression, bool) { + args, l := expr.GetArgs(), len(expr.GetArgs()) + var isDeferred, isDeferredConst, hasNonConstCondition bool + for i := 0; i < l-1; i += 2 { + expr.GetArgs()[i], isDeferred = foldConstant(args[i]) + isDeferredConst = isDeferredConst || isDeferred + if _, isConst := expr.GetArgs()[i].(*Constant); isConst && !hasNonConstCondition { + // If the condition is const and true, and the previous conditions + // has no expr, then the folded execution body is returned, otherwise + // the arguments of the casewhen are folded and replaced. + val, isNull, err := args[i].EvalInt(expr.GetCtx(), chunk.Row{}) + if err != nil { + return expr, false + } + if val != 0 && !isNull { + foldedExpr, isDeferred := foldConstant(args[i+1]) + isDeferredConst = isDeferredConst || isDeferred + if _, isConst := foldedExpr.(*Constant); isConst { + foldedExpr.GetType().Decimal = expr.GetType().Decimal + return foldedExpr, isDeferredConst + } + return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst + } + } else { + hasNonConstCondition = true + } + expr.GetArgs()[i+1], isDeferred = foldConstant(args[i+1]) + isDeferredConst = isDeferredConst || isDeferred + } + + if l%2 == 0 { + return expr, isDeferredConst + } + + // If the number of arguments in casewhen is odd, and the previous conditions + // is const and false, then the folded else execution body is returned. otherwise + // the execution body of the else are folded and replaced. + if !hasNonConstCondition { + foldedExpr, isDeferred := foldConstant(args[l-1]) + isDeferredConst = isDeferredConst || isDeferred + if _, isConst := foldedExpr.(*Constant); isConst { + foldedExpr.GetType().Decimal = expr.GetType().Decimal + return foldedExpr, isDeferredConst + } + return BuildCastFunction(expr.GetCtx(), foldedExpr, foldedExpr.GetType()), isDeferredConst + } + + expr.GetArgs()[l-1], isDeferred = foldConstant(args[l-1]) + isDeferredConst = isDeferredConst || isDeferred + + return expr, isDeferredConst +} + func foldConstant(expr Expression) (Expression, bool) { switch x := expr.(type) { case *ScalarFunction: diff --git a/expression/constant_propagation_test.go b/expression/constant_propagation_test.go index d457e85d9eb17..f8a9ab352dca2 100644 --- a/expression/constant_propagation_test.go +++ b/expression/constant_propagation_test.go @@ -66,7 +66,7 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { // Positive tests. tk.MustQuery("explain select * from t1 left join t2 on t1.a > t2.a and t1.a = 1;").Check(testkit.Rows( - "HashLeftJoin_6 33233333.33 root left outer join, inner:TableReader_11, left cond:[eq(test.t1.a, 1)]", + "HashLeftJoin_6 33233333.33 root CARTESIAN left outer join, inner:TableReader_11, left cond:[eq(test.t1.a, 1)]", "├─TableReader_8 10000.00 root data:TableScan_7", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableReader_11 3323.33 root data:Selection_10", @@ -74,7 +74,7 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 left join t2 on t1.a > t2.a where t1.a = 1;").Check(testkit.Rows( - "HashLeftJoin_7 33233.33 root left outer join, inner:TableReader_13", + "HashLeftJoin_7 33233.33 root CARTESIAN left outer join, inner:TableReader_13", "├─TableReader_10 10.00 root data:Selection_9", "│ └─Selection_9 10.00 cop eq(test.t1.a, 1)", "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", @@ -100,7 +100,7 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { " └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 right join t2 on t1.a > t2.a where t2.a = 1;").Check(testkit.Rows( - "HashRightJoin_7 33333.33 root right outer join, inner:TableReader_10", + "HashRightJoin_7 33333.33 root CARTESIAN right outer join, inner:TableReader_10", "├─TableReader_10 3333.33 root data:Selection_9", "│ └─Selection_9 3333.33 cop gt(test.t1.a, 1)", "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", @@ -126,7 +126,7 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { " └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 right join t2 on t1.a > t2.a and t2.a = 1;").Check(testkit.Rows( - "HashRightJoin_6 33333333.33 root right outer join, inner:TableReader_9, right cond:eq(test.t2.a, 1)", + "HashRightJoin_6 33333333.33 root CARTESIAN right outer join, inner:TableReader_9, right cond:eq(test.t2.a, 1)", "├─TableReader_9 3333.33 root data:Selection_8", "│ └─Selection_8 3333.33 cop gt(test.t1.a, 1)", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", @@ -143,7 +143,7 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 left join t2 on t1.a > t2.a and t2.a = 1;").Check(testkit.Rows( - "HashLeftJoin_6 100000.00 root left outer join, inner:TableReader_11, other cond:gt(test.t1.a, test.t2.a)", + "HashLeftJoin_6 100000.00 root CARTESIAN left outer join, inner:TableReader_11, other cond:gt(test.t1.a, test.t2.a)", "├─TableReader_8 10000.00 root data:TableScan_7", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableReader_11 10.00 root data:Selection_10", @@ -151,7 +151,7 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 right join t2 on t1.a > t2.a and t1.a = 1;").Check(testkit.Rows( - "HashRightJoin_6 100000.00 root right outer join, inner:TableReader_9, other cond:gt(test.t1.a, test.t2.a)", + "HashRightJoin_6 100000.00 root CARTESIAN right outer join, inner:TableReader_9, other cond:gt(test.t1.a, test.t2.a)", "├─TableReader_9 10.00 root data:Selection_8", "│ └─Selection_8 10.00 cop eq(test.t1.a, 1), not(isnull(test.t1.a))", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", @@ -167,14 +167,14 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { " └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 left join t2 on t1.a = t1.b and t1.a > 1;").Check(testkit.Rows( - "HashLeftJoin_6 100000000.00 root left outer join, inner:TableReader_10, left cond:[eq(test.t1.a, test.t1.b) gt(test.t1.a, 1)]", + "HashLeftJoin_6 100000000.00 root CARTESIAN left outer join, inner:TableReader_10, left cond:[eq(test.t1.a, test.t1.b) gt(test.t1.a, 1)]", "├─TableReader_8 10000.00 root data:TableScan_7", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableReader_10 10000.00 root data:TableScan_9", " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 left join t2 on t2.a = t2.b and t2.a > 1;").Check(testkit.Rows( - "HashLeftJoin_6 26666666.67 root left outer join, inner:TableReader_11", + "HashLeftJoin_6 26666666.67 root CARTESIAN left outer join, inner:TableReader_11", "├─TableReader_8 10000.00 root data:TableScan_7", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableReader_11 2666.67 root data:Selection_10", @@ -195,7 +195,7 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { "TableDual_8 0.00 root rows:0", )) tk.MustQuery("explain select * from t1 left join t2 on true where t1.a = 1 and t1.a = 1;").Check(testkit.Rows( - "HashLeftJoin_7 80000.00 root left outer join, inner:TableReader_12", + "HashLeftJoin_7 80000.00 root CARTESIAN left outer join, inner:TableReader_12", "├─TableReader_10 10.00 root data:Selection_9", "│ └─Selection_9 10.00 cop eq(test.t1.a, 1)", "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", @@ -203,32 +203,32 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { " └─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 left join t2 on false;").Check(testkit.Rows( - "HashLeftJoin_6 80000000.00 root left outer join, inner:TableDual_9", + "HashLeftJoin_6 80000000.00 root CARTESIAN left outer join, inner:TableDual_9", "├─TableReader_8 10000.00 root data:TableScan_7", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableDual_9 8000.00 root rows:0", )) tk.MustQuery("explain select * from t1 right join t2 on false;").Check(testkit.Rows( - "HashRightJoin_6 80000000.00 root right outer join, inner:TableDual_7", + "HashRightJoin_6 80000000.00 root CARTESIAN right outer join, inner:TableDual_7", "├─TableDual_7 8000.00 root rows:0", "└─TableReader_9 10000.00 root data:TableScan_8", " └─TableScan_8 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 left join t2 on t1.a = 1 and t1.a = 2;").Check(testkit.Rows( - "HashLeftJoin_6 80000000.00 root left outer join, inner:TableDual_9", + "HashLeftJoin_6 80000000.00 root CARTESIAN left outer join, inner:TableDual_9", "├─TableReader_8 10000.00 root data:TableScan_7", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableDual_9 8000.00 root rows:0", )) tk.MustQuery("explain select * from t1 left join t2 on t1.a =1 where t1.a = 2;").Check(testkit.Rows( - "HashLeftJoin_7 80000.00 root left outer join, inner:TableDual_11", + "HashLeftJoin_7 80000.00 root CARTESIAN left outer join, inner:TableDual_11", "├─TableReader_10 10.00 root data:Selection_9", "│ └─Selection_9 10.00 cop eq(test.t1.a, 2)", "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableDual_11 8000.00 root rows:0", )) tk.MustQuery("explain select * from t1 left join t2 on t2.a = 1 and t2.a = 2;").Check(testkit.Rows( - "HashLeftJoin_6 0.00 root left outer join, inner:TableReader_11", + "HashLeftJoin_6 0.00 root CARTESIAN left outer join, inner:TableReader_11", "├─TableReader_8 10000.00 root data:TableScan_7", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableReader_11 0.00 root data:Selection_10", @@ -237,14 +237,14 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { )) // Constant propagation for DNF in outer join. tk.MustQuery("explain select * from t1 left join t2 on t1.a = 1 or (t1.a = 2 and t1.a = 3);").Check(testkit.Rows( - "HashLeftJoin_6 100000000.00 root left outer join, inner:TableReader_10, left cond:[or(eq(test.t1.a, 1), 0)]", + "HashLeftJoin_6 100000000.00 root CARTESIAN left outer join, inner:TableReader_10, left cond:[or(eq(test.t1.a, 1), 0)]", "├─TableReader_8 10000.00 root data:TableScan_7", "│ └─TableScan_7 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "└─TableReader_10 10000.00 root data:TableScan_9", " └─TableScan_9 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", )) tk.MustQuery("explain select * from t1 left join t2 on true where t1.a = 1 or (t1.a = 2 and t1.a = 3);").Check(testkit.Rows( - "HashLeftJoin_7 80000.00 root left outer join, inner:TableReader_12", + "HashLeftJoin_7 80000.00 root CARTESIAN left outer join, inner:TableReader_12", "├─TableReader_10 10.00 root data:Selection_9", "│ └─Selection_9 10.00 cop or(eq(test.t1.a, 1), 0)", "│ └─TableScan_8 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", @@ -255,7 +255,7 @@ func (s *testSuite) TestOuterJoinPropConst(c *C) { tk.MustQuery("explain select * from t1 where t1.b > 1 or t1.b in (select b from t2);").Check(testkit.Rows( "Projection_7 8000.00 root test.t1.id, test.t1.a, test.t1.b", "└─Selection_8 8000.00 root or(gt(test.t1.b, 1), 5_aux_0)", - " └─HashLeftJoin_9 10000.00 root left outer semi join, inner:TableReader_13, other cond:eq(test.t1.b, test.t2.b)", + " └─HashLeftJoin_9 10000.00 root CARTESIAN left outer semi join, inner:TableReader_13, other cond:eq(test.t1.b, test.t2.b)", " ├─TableReader_11 10000.00 root data:TableScan_10", " │ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", " └─TableReader_13 10000.00 root data:TableScan_12", diff --git a/expression/constant_test.go b/expression/constant_test.go index e0eba43757412..75451b90c1c2d 100644 --- a/expression/constant_test.go +++ b/expression/constant_test.go @@ -23,6 +23,8 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/testleak" ) @@ -53,14 +55,22 @@ func newLonglong(value int64) *Constant { } func newDate(year, month, day int) *Constant { + return newTimeConst(year, month, day, 0, 0, 0, mysql.TypeDate) +} + +func newTimestamp(yy, mm, dd, hh, min, ss int) *Constant { + return newTimeConst(yy, mm, dd, hh, min, ss, mysql.TypeTimestamp) +} + +func newTimeConst(yy, mm, dd, hh, min, ss int, tp uint8) *Constant { var tmp types.Datum tmp.SetMysqlTime(types.Time{ - Time: types.FromDate(year, month, day, 0, 0, 0, 0), - Type: mysql.TypeDate, + Time: types.FromDate(yy, mm, dd, 0, 0, 0, 0), + Type: tp, }) return &Constant{ Value: tmp, - RetType: types.NewFieldType(mysql.TypeDate), + RetType: types.NewFieldType(tp), } } @@ -199,6 +209,7 @@ func (*testExpressionSuite) TestConstantPropagation(c *C) { func (*testExpressionSuite) TestConstraintPropagation(c *C) { defer testleak.AfterTest(c)() col1 := newColumnWithType(1, types.NewFieldType(mysql.TypeDate)) + col2 := newColumnWithType(2, types.NewFieldType(mysql.TypeTimestamp)) tests := []struct { solver constraintSolver conditions []Expression @@ -264,6 +275,15 @@ func (*testExpressionSuite) TestConstraintPropagation(c *C) { }, result: "0", }, + { + solver: newConstraintSolver(ruleColumnOPConst), + // col2 > unixtimestamp('2008-05-01 00:00:00') and unixtimestamp(col2) < unixtimestamp('2008-04-01 00:00:00') => false + conditions: []Expression{ + newFunction(ast.GT, col2, newTimestamp(2008, 5, 1, 0, 0, 0)), + newFunction(ast.LT, newFunction(ast.UnixTimestamp, col2), newLonglong(1206979200)), + }, + result: "0", + }, } for _, tt := range tests { ctx := mock.NewContext() @@ -345,3 +365,79 @@ func (*testExpressionSuite) TestDeferredExprNullConstantFold(c *C) { c.Assert(newConst.DeferredExpr.String(), Equals, tt.deferred, comment) } } + +func (*testExpressionSuite) TestDeferredExprNotNull(c *C) { + defer testleak.AfterTest(c)() + m := &MockExpr{} + ctx := mock.NewContext() + cst := &Constant{DeferredExpr: m, RetType: newIntFieldType()} + m.i, m.err = nil, fmt.Errorf("ERROR") + _, _, err := cst.EvalInt(ctx, chunk.Row{}) + c.Assert(err, NotNil) + _, _, err = cst.EvalReal(ctx, chunk.Row{}) + c.Assert(err, NotNil) + _, _, err = cst.EvalDecimal(ctx, chunk.Row{}) + c.Assert(err, NotNil) + _, _, err = cst.EvalString(ctx, chunk.Row{}) + c.Assert(err, NotNil) + _, _, err = cst.EvalTime(ctx, chunk.Row{}) + c.Assert(err, NotNil) + _, _, err = cst.EvalDuration(ctx, chunk.Row{}) + c.Assert(err, NotNil) + _, _, err = cst.EvalJSON(ctx, chunk.Row{}) + c.Assert(err, NotNil) + + m.i, m.err = nil, nil + _, isNull, err := cst.EvalInt(ctx, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(isNull, IsTrue) + _, isNull, err = cst.EvalReal(ctx, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(isNull, IsTrue) + _, isNull, err = cst.EvalDecimal(ctx, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(isNull, IsTrue) + _, isNull, err = cst.EvalString(ctx, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(isNull, IsTrue) + _, isNull, err = cst.EvalTime(ctx, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(isNull, IsTrue) + _, isNull, err = cst.EvalDuration(ctx, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(isNull, IsTrue) + _, isNull, err = cst.EvalJSON(ctx, chunk.Row{}) + c.Assert(err, IsNil) + c.Assert(isNull, IsTrue) + + m.i = int64(2333) + xInt, _, _ := cst.EvalInt(ctx, chunk.Row{}) + c.Assert(xInt, Equals, int64(2333)) + + m.i = float64(123.45) + xFlo, _, _ := cst.EvalReal(ctx, chunk.Row{}) + c.Assert(xFlo, Equals, float64(123.45)) + + m.i = "abc" + xStr, _, _ := cst.EvalString(ctx, chunk.Row{}) + c.Assert(xStr, Equals, "abc") + + m.i = &types.MyDecimal{} + xDec, _, _ := cst.EvalDecimal(ctx, chunk.Row{}) + c.Assert(xDec.Compare(m.i.(*types.MyDecimal)), Equals, 0) + + m.i = types.Time{} + xTim, _, _ := cst.EvalTime(ctx, chunk.Row{}) + c.Assert(xTim.Compare(m.i.(types.Time)), Equals, 0) + + m.i = types.Duration{} + xDur, _, _ := cst.EvalDuration(ctx, chunk.Row{}) + c.Assert(xDur.Compare(m.i.(types.Duration)), Equals, 0) + + m.i = json.BinaryJSON{} + xJsn, _, _ := cst.EvalJSON(ctx, chunk.Row{}) + c.Assert(m.i.(json.BinaryJSON).String(), Equals, xJsn.String()) + + cln := cst.Clone().(*Constant) + c.Assert(cln.DeferredExpr, Equals, cst.DeferredExpr) +} diff --git a/expression/constraint_propagation.go b/expression/constraint_propagation.go index 269eb45298e13..28993b482c78f 100644 --- a/expression/constraint_propagation.go +++ b/expression/constraint_propagation.go @@ -307,7 +307,8 @@ func negOP(cmp string) string { // monotoneIncFuncs are those functions that for any x y, if x > y => f(x) > f(y) var monotoneIncFuncs = map[string]struct{}{ - ast.ToDays: {}, + ast.ToDays: {}, + ast.UnixTimestamp: {}, } // compareConstant compares two expressions. c1 and c2 should be constant with the same type. diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 127e6baa95c78..9ddf5a7130ad8 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -371,17 +371,17 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti f = &builtinCaseWhenIntSig{base} case tipb.ScalarFuncSig_IntIsFalse: - f = &builtinIntIsFalseSig{base} + f = &builtinIntIsFalseSig{base, false} case tipb.ScalarFuncSig_RealIsFalse: - f = &builtinRealIsFalseSig{base} + f = &builtinRealIsFalseSig{base, false} case tipb.ScalarFuncSig_DecimalIsFalse: - f = &builtinDecimalIsFalseSig{base} + f = &builtinDecimalIsFalseSig{base, false} case tipb.ScalarFuncSig_IntIsTrue: - f = &builtinIntIsTrueSig{base} + f = &builtinIntIsTrueSig{base, false} case tipb.ScalarFuncSig_RealIsTrue: - f = &builtinRealIsTrueSig{base} + f = &builtinRealIsTrueSig{base, false} case tipb.ScalarFuncSig_DecimalIsTrue: - f = &builtinDecimalIsTrueSig{base} + f = &builtinDecimalIsTrueSig{base, false} case tipb.ScalarFuncSig_IfNullReal: f = &builtinIfNullRealSig{base} @@ -446,6 +446,12 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti f = &builtinJSONDepthSig{base} case tipb.ScalarFuncSig_JsonSearchSig: f = &builtinJSONSearchSig{base} + case tipb.ScalarFuncSig_JsonValidJsonSig: + f = &builtinJSONValidJSONSig{base} + case tipb.ScalarFuncSig_JsonValidStringSig: + f = &builtinJSONValidStringSig{base} + case tipb.ScalarFuncSig_JsonValidOthersSig: + f = &builtinJSONValidOthersSig{base} case tipb.ScalarFuncSig_InInt: f = &builtinInIntSig{base} diff --git a/expression/distsql_builtin_test.go b/expression/distsql_builtin_test.go index 840fd5e89c755..a9ab97f2711c1 100644 --- a/expression/distsql_builtin_test.go +++ b/expression/distsql_builtin_test.go @@ -42,8 +42,60 @@ func (s *testEvalSuite) allocColID() int64 { return s.colID } +func (s *testEvalSuite) TestPBToExpr(c *C) { + sc := new(stmtctx.StatementContext) + fieldTps := make([]*types.FieldType, 1) + ds := []types.Datum{types.NewIntDatum(1), types.NewUintDatum(1), types.NewFloat64Datum(1), + types.NewDecimalDatum(newMyDecimal(c, "1")), types.NewDurationDatum(newDuration(time.Second))} + + for _, d := range ds { + expr := datumExpr(c, d) + expr.Val = expr.Val[:len(expr.Val)/2] + _, err := PBToExpr(expr, fieldTps, sc) + c.Assert(err, NotNil) + } + + expr := &tipb.Expr{ + Tp: tipb.ExprType_ScalarFunc, + Children: []*tipb.Expr{ + { + Tp: tipb.ExprType_ValueList, + }, + }, + } + _, err := PBToExpr(expr, fieldTps, sc) + c.Assert(err, IsNil) + + val := make([]byte, 0, 32) + val = codec.EncodeInt(val, 1) + expr = &tipb.Expr{ + Tp: tipb.ExprType_ScalarFunc, + Children: []*tipb.Expr{ + { + Tp: tipb.ExprType_ValueList, + Val: val[:len(val)/2], + }, + }, + } + _, err = PBToExpr(expr, fieldTps, sc) + c.Assert(err, NotNil) + + expr = &tipb.Expr{ + Tp: tipb.ExprType_ScalarFunc, + Children: []*tipb.Expr{ + { + Tp: tipb.ExprType_ValueList, + Val: val, + }, + }, + Sig: tipb.ScalarFuncSig_AbsInt, + FieldType: ToPBFieldType(newIntFieldType()), + } + _, err = PBToExpr(expr, fieldTps, sc) + c.Assert(err, NotNil) +} + // TestEval test expr.Eval(). -// TODO: add more tests. func (s *testEvalSuite) TestEval(c *C) { row := chunk.MutRowFromDatums([]types.Datum{types.NewDatum(100)}).ToRow() fieldTps := make([]*types.FieldType, 1) @@ -120,6 +172,612 @@ func (s *testEvalSuite) TestEval(c *C) { ), newJSONDatum(c, `"$[1][0].k"`), }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastIntAsInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(2333))), + types.NewIntDatum(2333), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastRealAsInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewFloat64Datum(2333))), + types.NewIntDatum(2333), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastStringAsInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewStringDatum("2333"))), + types.NewIntDatum(2333), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastDecimalAsInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2333")))), + types.NewIntDatum(2333), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastIntAsReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewIntDatum(2333))), + types.NewFloat64Datum(2333), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastRealAsReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewFloat64Datum(2333))), + types.NewFloat64Datum(2333), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastStringAsReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewStringDatum("2333"))), + types.NewFloat64Datum(2333), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastDecimalAsReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2333")))), + types.NewFloat64Datum(2333), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastStringAsString, + toPBFieldType(newStringFieldType()), datumExpr(c, types.NewStringDatum("2333"))), + types.NewStringDatum("2333"), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastIntAsString, + toPBFieldType(newStringFieldType()), datumExpr(c, types.NewIntDatum(2333))), + types.NewStringDatum("2333"), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastRealAsString, + toPBFieldType(newStringFieldType()), datumExpr(c, types.NewFloat64Datum(2333))), + types.NewStringDatum("2333"), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastDecimalAsString, + toPBFieldType(newStringFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2333")))), + types.NewStringDatum("2333"), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastDecimalAsDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2333")))), + types.NewDecimalDatum(newMyDecimal(c, "2333")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastIntAsDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewIntDatum(2333))), + types.NewDecimalDatum(newMyDecimal(c, "2333")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastRealAsDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewFloat64Datum(2333))), + types.NewDecimalDatum(newMyDecimal(c, "2333")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastStringAsDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewStringDatum("2333"))), + types.NewDecimalDatum(newMyDecimal(c, "2333")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_GEInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(2)), datumExpr(c, types.NewIntDatum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LEInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(2))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NEInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(2))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NullEQInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewDatum(nil))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_GEReal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewFloat64Datum(2)), datumExpr(c, types.NewFloat64Datum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LEReal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewFloat64Datum(1)), datumExpr(c, types.NewFloat64Datum(2))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LTReal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewFloat64Datum(1)), datumExpr(c, types.NewFloat64Datum(2))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_EQReal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewFloat64Datum(1)), datumExpr(c, types.NewFloat64Datum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NEReal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewFloat64Datum(1)), datumExpr(c, types.NewFloat64Datum(2))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NullEQReal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewDatum(nil))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_GEDecimal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2"))), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LEDecimal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1"))), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2")))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LTDecimal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1"))), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2")))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_EQDecimal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1"))), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NEDecimal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1"))), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2")))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NullEQDecimal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewDatum(nil))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_GEDuration, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDurationDatum(newDuration(time.Second*2))), datumExpr(c, types.NewDurationDatum(newDuration(time.Second)))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_GTDuration, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDurationDatum(newDuration(time.Second*2))), datumExpr(c, types.NewDurationDatum(newDuration(time.Second)))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_EQDuration, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDurationDatum(newDuration(time.Second))), datumExpr(c, types.NewDurationDatum(newDuration(time.Second)))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LEDuration, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDurationDatum(newDuration(time.Second))), datumExpr(c, types.NewDurationDatum(newDuration(time.Second*2)))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NEDuration, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDurationDatum(newDuration(time.Second))), datumExpr(c, types.NewDurationDatum(newDuration(time.Second*2)))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NullEQDuration, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewDatum(nil))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_GEString, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewStringDatum("1")), datumExpr(c, types.NewStringDatum("1"))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LEString, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewStringDatum("1")), datumExpr(c, types.NewStringDatum("1"))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NEString, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewStringDatum("2")), datumExpr(c, types.NewStringDatum("1"))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NullEQString, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewDatum(nil))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_GTJson, + toPBFieldType(newIntFieldType()), jsonDatumExpr(c, "[2]"), jsonDatumExpr(c, "[1]")), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_GEJson, + toPBFieldType(newIntFieldType()), jsonDatumExpr(c, "[2]"), jsonDatumExpr(c, "[1]")), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LTJson, + toPBFieldType(newIntFieldType()), jsonDatumExpr(c, "[1]"), jsonDatumExpr(c, "[2]")), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LEJson, + toPBFieldType(newIntFieldType()), jsonDatumExpr(c, "[1]"), jsonDatumExpr(c, "[2]")), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_EQJson, + toPBFieldType(newIntFieldType()), jsonDatumExpr(c, "[1]"), jsonDatumExpr(c, "[1]")), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NEJson, + toPBFieldType(newIntFieldType()), jsonDatumExpr(c, "[1]"), jsonDatumExpr(c, "[2]")), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_NullEQJson, + toPBFieldType(newIntFieldType()), jsonDatumExpr(c, "[1]"), jsonDatumExpr(c, "[1]")), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_DecimalIsNull, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDatum(nil))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_DurationIsNull, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDatum(nil))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_RealIsNull, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDatum(nil))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_AbsInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(-1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_AbsUInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewUintDatum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_AbsReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewFloat64Datum(-1.23))), + types.NewFloat64Datum(1.23), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_AbsDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "-1.23")))), + types.NewDecimalDatum(newMyDecimal(c, "1.23")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LogicalAnd, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LogicalOr, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(0))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_LogicalXor, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(0))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_BitAndSig, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_BitOrSig, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(0))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_BitXorSig, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(0))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_BitNegSig, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(0))), + types.NewIntDatum(-1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_InReal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewFloat64Datum(1)), datumExpr(c, types.NewFloat64Datum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_InDecimal, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1"))), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_InString, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewStringDatum("1")), datumExpr(c, types.NewStringDatum("1"))), + types.NewIntDatum(1), + }, + //{ + // scalarFunctionExpr(tipb.ScalarFuncSig_InTime, + // toPBFieldType(newIntFieldType()), datumExpr(c, types.NewTimeDatum(types.ZeroDate)), datumExpr(c, types.NewTimeDatum(types.ZeroDate))), + // types.NewIntDatum(1), + //}, + { + scalarFunctionExpr(tipb.ScalarFuncSig_InDuration, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDurationDatum(newDuration(time.Second))), datumExpr(c, types.NewDurationDatum(newDuration(time.Second)))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfNullInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewIntDatum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(2))), + types.NewIntDatum(2), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfNullReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewFloat64Datum(1))), + types.NewFloat64Datum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewFloat64Datum(1)), datumExpr(c, types.NewFloat64Datum(2))), + types.NewFloat64Datum(2), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfNullDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewDecimalDatum(newMyDecimal(c, "1")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2")))), + types.NewDecimalDatum(newMyDecimal(c, "2")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfNullString, + toPBFieldType(newStringFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewStringDatum("1"))), + types.NewStringDatum("1"), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfString, + toPBFieldType(newStringFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewStringDatum("2"))), + types.NewStringDatum("2"), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfNullDuration, + toPBFieldType(newDurFieldType()), datumExpr(c, types.NewDatum(nil)), datumExpr(c, types.NewDurationDatum(newDuration(time.Second)))), + types.NewDurationDatum(newDuration(time.Second)), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_IfDuration, + toPBFieldType(newDurFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewDurationDatum(newDuration(time.Second*2)))), + types.NewDurationDatum(newDuration(time.Second * 2)), + }, + + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastIntAsDuration, + toPBFieldType(newDurFieldType()), datumExpr(c, types.NewIntDatum(1))), + types.NewDurationDatum(newDuration(time.Second * 1)), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastRealAsDuration, + toPBFieldType(newDurFieldType()), datumExpr(c, types.NewFloat64Datum(1))), + types.NewDurationDatum(newDuration(time.Second * 1)), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastDecimalAsDuration, + toPBFieldType(newDurFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewDurationDatum(newDuration(time.Second * 1)), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastDurationAsDuration, + toPBFieldType(newDurFieldType()), datumExpr(c, types.NewDurationDatum(newDuration(time.Second*1)))), + types.NewDurationDatum(newDuration(time.Second * 1)), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastStringAsDuration, + toPBFieldType(newDurFieldType()), datumExpr(c, types.NewStringDatum("1"))), + types.NewDurationDatum(newDuration(time.Second * 1)), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastTimeAsTime, + toPBFieldType(newDateFieldType()), datumExpr(c, types.NewTimeDatum(newDateTime(c, "2000-01-01")))), + types.NewTimeDatum(newDateTime(c, "2000-01-01")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastIntAsTime, + toPBFieldType(newDateFieldType()), datumExpr(c, types.NewIntDatum(20000101))), + types.NewTimeDatum(newDateTime(c, "2000-01-01")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastRealAsTime, + toPBFieldType(newDateFieldType()), datumExpr(c, types.NewFloat64Datum(20000101))), + types.NewTimeDatum(newDateTime(c, "2000-01-01")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastDecimalAsTime, + toPBFieldType(newDateFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "20000101")))), + types.NewTimeDatum(newDateTime(c, "2000-01-01")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CastStringAsTime, + toPBFieldType(newDateFieldType()), datumExpr(c, types.NewStringDatum("20000101"))), + types.NewTimeDatum(newDateTime(c, "2000-01-01")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_PlusInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(2))), + types.NewIntDatum(3), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_PlusDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1"))), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2")))), + types.NewDecimalDatum(newMyDecimal(c, "3")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_PlusReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewFloat64Datum(1)), datumExpr(c, types.NewFloat64Datum(2))), + types.NewFloat64Datum(3), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_MinusInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(2))), + types.NewIntDatum(-1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_MinusDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1"))), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2")))), + types.NewDecimalDatum(newMyDecimal(c, "-1")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_MinusReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewFloat64Datum(1)), datumExpr(c, types.NewFloat64Datum(2))), + types.NewFloat64Datum(-1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_MultiplyInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1)), datumExpr(c, types.NewIntDatum(2))), + types.NewIntDatum(2), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_MultiplyDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1"))), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "2")))), + types.NewDecimalDatum(newMyDecimal(c, "2")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_MultiplyReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewFloat64Datum(1)), datumExpr(c, types.NewFloat64Datum(2))), + types.NewFloat64Datum(2), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CeilIntToInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CeilIntToDec, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewIntDatum(1))), + types.NewDecimalDatum(newMyDecimal(c, "1")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CeilDecToInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CeilReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewFloat64Datum(1))), + types.NewFloat64Datum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_FloorIntToInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_FloorIntToDec, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewIntDatum(1))), + types.NewDecimalDatum(newMyDecimal(c, "1")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_FloorDecToInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_FloorReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewFloat64Datum(1))), + types.NewFloat64Datum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CoalesceInt, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewIntDatum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CoalesceReal, + toPBFieldType(newRealFieldType()), datumExpr(c, types.NewFloat64Datum(1))), + types.NewFloat64Datum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CoalesceDecimal, + toPBFieldType(newDecimalFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewDecimalDatum(newMyDecimal(c, "1")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CoalesceString, + toPBFieldType(newStringFieldType()), datumExpr(c, types.NewStringDatum("1"))), + types.NewStringDatum("1"), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CoalesceDuration, + toPBFieldType(newDurFieldType()), datumExpr(c, types.NewDurationDatum(newDuration(time.Second)))), + types.NewDurationDatum(newDuration(time.Second)), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CoalesceTime, + toPBFieldType(newDateFieldType()), datumExpr(c, types.NewTimeDatum(newDateTime(c, "2000-01-01")))), + types.NewTimeDatum(newDateTime(c, "2000-01-01")), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CaseWhenInt, + toPBFieldType(newIntFieldType())), + types.NewDatum(nil), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CaseWhenReal, + toPBFieldType(newRealFieldType())), + types.NewDatum(nil), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CaseWhenDecimal, + toPBFieldType(newDecimalFieldType())), + types.NewDatum(nil), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CaseWhenDuration, + toPBFieldType(newDurFieldType())), + types.NewDatum(nil), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CaseWhenTime, + toPBFieldType(newDateFieldType())), + types.NewDatum(nil), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_CaseWhenJson, + toPBFieldType(newJSONFieldType())), + types.NewDatum(nil), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_RealIsFalse, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewFloat64Datum(1))), + types.NewIntDatum(0), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_DecimalIsFalse, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewIntDatum(0), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_RealIsTrue, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewFloat64Datum(1))), + types.NewIntDatum(1), + }, + { + scalarFunctionExpr(tipb.ScalarFuncSig_DecimalIsTrue, + toPBFieldType(newIntFieldType()), datumExpr(c, types.NewDecimalDatum(newMyDecimal(c, "1")))), + types.NewIntDatum(1), + }, } sc := new(stmtctx.StatementContext) for _, tt := range tests { @@ -184,6 +842,12 @@ func datumExpr(c *C, d types.Datum) *tipb.Expr { expr.Val = make([]byte, 0, 1024) expr.Val, err = codec.EncodeValue(nil, expr.Val, d) c.Assert(err, IsNil) + case types.KindMysqlTime: + expr.Tp = tipb.ExprType_MysqlTime + var err error + expr.Val, err = codec.EncodeMySQLTime(nil, d, mysql.TypeUnspecified, nil) + c.Assert(err, IsNil) + expr.FieldType = ToPBFieldType(newDateFieldType()) default: expr.Tp = tipb.ExprType_Null } @@ -220,6 +884,31 @@ func toPBFieldType(ft *types.FieldType) *tipb.FieldType { } } +func newMyDecimal(c *C, s string) *types.MyDecimal { + d := new(types.MyDecimal) + c.Assert(d.FromString([]byte(s)), IsNil) + return d +} + +func newDuration(dur time.Duration) types.Duration { + return types.Duration{ + Duration: dur, + Fsp: types.DefaultFsp, + } +} + +func newDateTime(c *C, s string) types.Time { + t, err := types.ParseDate(nil, s) + c.Assert(err, IsNil) + return t +} + +func newDateFieldType() *types.FieldType { + return &types.FieldType{ + Tp: mysql.TypeDate, + } +} + func newIntFieldType() *types.FieldType { return &types.FieldType{ Tp: mysql.TypeLonglong, @@ -229,6 +918,34 @@ func newIntFieldType() *types.FieldType { } } +func newDurFieldType() *types.FieldType { + return &types.FieldType{ + Tp: mysql.TypeDuration, + Flag: types.DefaultFsp, + } +} + +func newStringFieldType() *types.FieldType { + return &types.FieldType{ + Tp: mysql.TypeVarString, + Flen: types.UnspecifiedLength, + } +} + +func newRealFieldType() *types.FieldType { + return &types.FieldType{ + Tp: mysql.TypeFloat, + Flen: types.UnspecifiedLength, + } +} + +func newDecimalFieldType() *types.FieldType { + return &types.FieldType{ + Tp: mysql.TypeNewDecimal, + Flen: types.UnspecifiedLength, + } +} + func newJSONFieldType() *types.FieldType { return &types.FieldType{ Tp: mysql.TypeJSON, diff --git a/expression/evaluator_test.go b/expression/evaluator_test.go index 8c04dd8bb59ac..ef86d53f0b242 100644 --- a/expression/evaluator_test.go +++ b/expression/evaluator_test.go @@ -14,6 +14,7 @@ package expression import ( + "sync/atomic" "testing" "time" @@ -31,7 +32,7 @@ import ( "github.com/pingcap/tidb/util/testutil" ) -var _ = Suite(&testEvaluatorSuite{}) +var _ = SerialSuites(&testEvaluatorSuite{}) func TestT(t *testing.T) { CustomVerboseFlag = true @@ -176,23 +177,21 @@ func (s *testEvaluatorSuite) TestSleep(c *C) { sub := time.Since(start) c.Assert(sub.Nanoseconds(), GreaterEqual, int64(0.5*1e9)) - // quit when context canceled. - // TODO: recover it. - // d[0].SetFloat64(2) - // f, err = fc.getFunction(ctx, s.datumsToConstants(d)) - // c.Assert(err, IsNil) - // start = time.Now() - // go func() { - // time.Sleep(1 * time.Second) - // ctx.Cancel() - // }() - // ret, isNull, err = f.evalInt(chunk.Row{}) - // sub = time.Since(start) - // c.Assert(err, IsNil) - // c.Assert(isNull, IsFalse) - // c.Assert(ret, Equals, int64(1)) - // c.Assert(sub.Nanoseconds(), LessEqual, int64(2*1e9)) - // c.Assert(sub.Nanoseconds(), GreaterEqual, int64(1*1e9)) + d[0].SetFloat64(3) + f, err = fc.getFunction(ctx, s.datumsToConstants(d)) + c.Assert(err, IsNil) + start = time.Now() + go func() { + time.Sleep(1 * time.Second) + atomic.CompareAndSwapUint32(&ctx.GetSessionVars().Killed, 0, 1) + }() + ret, isNull, err = f.evalInt(chunk.Row{}) + sub = time.Since(start) + c.Assert(err, IsNil) + c.Assert(isNull, IsFalse) + c.Assert(ret, Equals, int64(1)) + c.Assert(sub.Nanoseconds(), LessEqual, int64(2*1e9)) + c.Assert(sub.Nanoseconds(), GreaterEqual, int64(1*1e9)) } func (s *testEvaluatorSuite) TestBinopComparison(c *C) { diff --git a/expression/explain.go b/expression/explain.go index 37bcba7f0252d..93eb01cd498f7 100644 --- a/expression/explain.go +++ b/expression/explain.go @@ -37,8 +37,8 @@ func (expr *ScalarFunction) ExplainInfo() string { } // ExplainInfo implements the Expression interface. -func (expr *Column) ExplainInfo() string { - return expr.String() +func (col *Column) ExplainInfo() string { + return col.String() } // ExplainInfo implements the Expression interface. diff --git a/expression/expr_to_pb.go b/expression/expr_to_pb.go index 4416b464c1f4c..a863f28ba47b6 100644 --- a/expression/expr_to_pb.go +++ b/expression/expr_to_pb.go @@ -15,6 +15,7 @@ package expression import ( "context" + "sync/atomic" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/charset" @@ -321,8 +322,16 @@ func (pc PbConverter) canFuncBePushed(sf *ScalarFunction) bool { // date functions. ast.DateFormat: - - return true + _, disallowPushdown := DefaultExprPushdownBlacklist.Load().(map[string]struct{})[sf.FuncName.L] + return true && !disallowPushdown } return false } + +// DefaultExprPushdownBlacklist indicates the expressions which can not be pushed down to TiKV. +var DefaultExprPushdownBlacklist *atomic.Value + +func init() { + DefaultExprPushdownBlacklist = new(atomic.Value) + DefaultExprPushdownBlacklist.Store(make(map[string]struct{})) +} diff --git a/expression/expression.go b/expression/expression.go index 162cacb015474..7f5f65f3f6f58 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/opcode" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" @@ -36,7 +37,7 @@ const ( ) // EvalAstExpr evaluates ast expression directly. -var EvalAstExpr func(ctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error) +var EvalAstExpr func(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error) // Expression represents all scalar expression in SQL. type Expression interface { @@ -275,7 +276,7 @@ func EvaluateExprWithNull(ctx sessionctx.Context, schema *Schema, expr Expressio for i, arg := range x.GetArgs() { args[i] = EvaluateExprWithNull(ctx, schema, arg) } - return NewFunctionInternal(ctx, x.FuncName.L, types.NewFieldType(mysql.TypeTiny), args...) + return NewFunctionInternal(ctx, x.FuncName.L, x.RetType, args...) case *Column: if !schema.Contains(x) { return x @@ -376,3 +377,24 @@ func IsBinaryLiteral(expr Expression) bool { con, ok := expr.(*Constant) return ok && con.Value.Kind() == types.KindBinaryLiteral } + +// wrapWithIsTrue wraps `arg` with istrue function if the return type of expr is not +// type int, otherwise, returns `arg` directly. +// The `keepNull` controls what the istrue function will return when `arg` is null: +// 1. keepNull is true and arg is null, the istrue function returns null. +// 2. keepNull is false and arg is null, the istrue function returns 0. +func wrapWithIsTrue(ctx sessionctx.Context, keepNull bool, arg Expression) (Expression, error) { + if arg.GetType().EvalType() == types.ETInt { + return arg, nil + } + fc := &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, keepNull} + f, err := fc.getFunction(ctx, []Expression{arg}) + if err != nil { + return nil, err + } + return &ScalarFunction{ + FuncName: model.NewCIStr(fmt.Sprintf("sig_%T", f)), + Function: f, + RetType: f.getRetTp(), + }, nil +} diff --git a/expression/helper.go b/expression/helper.go index cd150fcb6cc84..f5268a6619107 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -34,12 +34,24 @@ func boolToInt64(v bool) int64 { return 0 } -// IsCurrentTimestampExpr returns whether e is CurrentTimestamp expression. -func IsCurrentTimestampExpr(e ast.ExprNode) bool { - if fn, ok := e.(*ast.FuncCallExpr); ok && fn.FnName.L == ast.CurrentTimestamp { - return true +// IsValidCurrentTimestampExpr returns true if exprNode is a valid CurrentTimestamp expression. +// Here `valid` means it is consistent with the given fieldType's Decimal. +func IsValidCurrentTimestampExpr(exprNode ast.ExprNode, fieldType *types.FieldType) bool { + fn, isFuncCall := exprNode.(*ast.FuncCallExpr) + if !isFuncCall || fn.FnName.L != ast.CurrentTimestamp { + return false } - return false + + containsArg := len(fn.Args) > 0 + // Fsp represents fractional seconds precision. + containsFsp := fieldType != nil && fieldType.Decimal > 0 + var isConsistent bool + if containsArg { + v, ok := fn.Args[0].(*driver.ValueExpr) + isConsistent = ok && fieldType != nil && v.Datum.GetInt64() == int64(fieldType.Decimal) + } + + return (containsArg && isConsistent) || (!containsArg && !containsFsp) } // GetTimeValue gets the time value with type tp. @@ -54,7 +66,7 @@ func GetTimeValue(ctx sessionctx.Context, v interface{}, tp byte, fsp int) (d ty case string: upperX := strings.ToUpper(x) if upperX == strings.ToUpper(ast.CurrentTimestamp) { - defaultTime, err := getSystemTimestamp(ctx) + defaultTime, err := getStmtTimestamp(ctx) if err != nil { return d, err } @@ -120,7 +132,9 @@ func GetTimeValue(ctx sessionctx.Context, v interface{}, tp byte, fsp int) (d ty return d, nil } -func getSystemTimestamp(ctx sessionctx.Context) (time.Time, error) { +// if timestamp session variable set, use session variable as current time, otherwise use cached time +// during one sql statement, the "current_time" should be the same +func getStmtTimestamp(ctx sessionctx.Context) (time.Time, error) { now := time.Now() if ctx == nil { @@ -133,15 +147,16 @@ func getSystemTimestamp(ctx sessionctx.Context) (time.Time, error) { return now, err } - if timestampStr == "" { - return now, nil - } - timestamp, err := types.StrToInt(sessionVars.StmtCtx, timestampStr) - if err != nil { - return time.Time{}, err - } - if timestamp <= 0 { - return now, nil + if timestampStr != "" { + timestamp, err := types.StrToInt(sessionVars.StmtCtx, timestampStr) + if err != nil { + return time.Time{}, err + } + if timestamp <= 0 { + return now, nil + } + return time.Unix(timestamp, 0), nil } - return time.Unix(timestamp, 0), nil + stmtCtx := ctx.GetSessionVars().StmtCtx + return stmtCtx.GetNowTsCached(), nil } diff --git a/expression/helper_test.go b/expression/helper_test.go index e0acecfdf817b..21db9667f1063 100644 --- a/expression/helper_test.go +++ b/expression/helper_test.go @@ -14,6 +14,7 @@ package expression import ( + driver "github.com/pingcap/tidb/types/parser_driver" "strings" "time" @@ -108,11 +109,28 @@ func (s *testExpressionSuite) TestGetTimeValue(c *C) { func (s *testExpressionSuite) TestIsCurrentTimestampExpr(c *C) { defer testleak.AfterTest(c)() - v := IsCurrentTimestampExpr(ast.NewValueExpr("abc")) - c.Assert(v, IsFalse) + buildTimestampFuncCallExpr := func(i int64) *ast.FuncCallExpr { + var args []ast.ExprNode + if i != 0 { + args = []ast.ExprNode{&driver.ValueExpr{Datum: types.NewIntDatum(i)}} + } + return &ast.FuncCallExpr{FnName: model.NewCIStr("CURRENT_TIMESTAMP"), Args: args} + } - v = IsCurrentTimestampExpr(&ast.FuncCallExpr{FnName: model.NewCIStr("CURRENT_TIMESTAMP")}) + v := IsValidCurrentTimestampExpr(ast.NewValueExpr("abc"), nil) + c.Assert(v, IsFalse) + v = IsValidCurrentTimestampExpr(buildTimestampFuncCallExpr(0), nil) + c.Assert(v, IsTrue) + v = IsValidCurrentTimestampExpr(buildTimestampFuncCallExpr(3), &types.FieldType{Decimal: 3}) c.Assert(v, IsTrue) + v = IsValidCurrentTimestampExpr(buildTimestampFuncCallExpr(1), &types.FieldType{Decimal: 3}) + c.Assert(v, IsFalse) + v = IsValidCurrentTimestampExpr(buildTimestampFuncCallExpr(0), &types.FieldType{Decimal: 3}) + c.Assert(v, IsFalse) + v = IsValidCurrentTimestampExpr(buildTimestampFuncCallExpr(2), &types.FieldType{Decimal: 0}) + c.Assert(v, IsFalse) + v = IsValidCurrentTimestampExpr(buildTimestampFuncCallExpr(2), nil) + c.Assert(v, IsFalse) } func (s *testExpressionSuite) TestCurrentTimestampTimeZone(c *C) { diff --git a/expression/integration_test.go b/expression/integration_test.go index dc6b89f0a69a2..b60bbb08428db 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -550,6 +550,8 @@ func (s *testIntegrationSuite) TestMathBuiltin(c *C) { tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int)") tk.MustExec("insert into t values(1),(2),(3)") + tk.Se.GetSessionVars().MaxChunkSize = 1 + tk.MustQuery("select rand(1) from t").Sort().Check(testkit.Rows("0.6046602879796196", "0.6645600532184904", "0.9405090880450124")) tk.MustQuery("select rand(a) from t").Check(testkit.Rows("0.6046602879796196", "0.16729663442585624", "0.7199826688373036")) tk.MustQuery("select rand(1), rand(2), rand(3)").Check(testkit.Rows("0.6046602879796196 0.16729663442585624 0.7199826688373036")) } @@ -898,7 +900,14 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) { result = tk.MustQuery(`select quote("aaaa"), quote(""), quote("\"\""), quote("\n\n");`) result.Check(testkit.Rows("'aaaa' '' '\"\"' '\n\n'")) result = tk.MustQuery(`select quote(0121), quote(0000), quote("中文"), quote(NULL);`) - result.Check(testkit.Rows("'121' '0' '中文' ")) + result.Check(testkit.Rows("'121' '0' '中文' NULL")) + tk.MustQuery(`select quote(null) is NULL;`).Check(testkit.Rows(`0`)) + tk.MustQuery(`select quote(null) is NOT NULL;`).Check(testkit.Rows(`1`)) + tk.MustQuery(`select length(quote(null));`).Check(testkit.Rows(`4`)) + tk.MustQuery(`select quote(null) REGEXP binary 'null'`).Check(testkit.Rows(`0`)) + tk.MustQuery(`select quote(null) REGEXP binary 'NULL'`).Check(testkit.Rows(`1`)) + tk.MustQuery(`select quote(null) REGEXP 'NULL'`).Check(testkit.Rows(`1`)) + tk.MustQuery(`select quote(null) REGEXP 'null'`).Check(testkit.Rows(`1`)) // for convert result = tk.MustQuery(`select convert("123" using "binary"), convert("中文" using "binary"), convert("中文" using "utf8"), convert("中文" using "utf8mb4"), convert(cast("中文" as binary) using "utf8");`) @@ -917,6 +926,10 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) { result.Check(testkit.Rows("aaa文 ba aaa ba")) result = tk.MustQuery(`select insert("bb", NULL, 1, "aa"), insert("bb", 1, NULL, "aa"), insert(NULL, 1, 1, "aaa"), insert("bb", 1, 1, NULL);`) result.Check(testkit.Rows(" ")) + result = tk.MustQuery(`SELECT INSERT("bb", 0, 1, NULL), INSERT("bb", 0, NULL, "aaa");`) + result.Check(testkit.Rows(" ")) + result = tk.MustQuery(`SELECT INSERT("中文", 0, 1, NULL), INSERT("中文", 0, NULL, "aaa");`) + result.Check(testkit.Rows(" ")) // for export_set result = tk.MustQuery(`select export_set(7, "1", "0", ",", 65);`) @@ -1392,6 +1405,14 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { "c), addtime(c,a), addtime(c,b)" + " from t;") result.Check(testkit.Rows(" 2017-01-01 13:31:32 2017-01-01 13:31:32 ")) + result = tk.MustQuery("select addtime('01:01:11', cast('1' as time))") + result.Check(testkit.Rows("01:01:12")) + tk.MustQuery("select addtime(cast(null as char(20)), cast('1' as time))").Check(testkit.Rows("")) + c.Assert(tk.QueryToErr(`select addtime("01:01:11", cast('sdf' as time))`), NotNil) + tk.MustQuery(`select addtime("01:01:11", cast(null as char(20)))`).Check(testkit.Rows("")) + tk.MustQuery(`select addtime(cast(1 as time), cast(1 as time))`).Check(testkit.Rows("00:00:02")) + tk.MustQuery(`select addtime(cast(null as time), cast(1 as time))`).Check(testkit.Rows("")) + tk.MustQuery(`select addtime(cast(1 as time), cast(null as time))`).Check(testkit.Rows("")) // for SUBTIME result = tk.MustQuery("select subtime('01:01:11', '00:00:01.013'), subtime('01:01:11.00', '00:00:01'), subtime" + @@ -1424,7 +1445,6 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { result.Check(testkit.Rows(" ")) result = tk.MustQuery("select addtime('2017-01-01', 1), addtime('2017-01-01 01:01:01', 1), addtime(cast('2017-01-01' as date), 1)") result.Check(testkit.Rows("2017-01-01 00:00:01 2017-01-01 01:01:02 00:00:01")) - result = tk.MustQuery("select subtime(a, e), subtime(b, e), subtime(c, e), subtime(d, e) from t") result.Check(testkit.Rows(" ")) result = tk.MustQuery("select subtime('2017-01-01 01:01:01', 0b1), subtime('2017-01-01', b'1'), subtime('01:01:01', 0b1011)") @@ -1432,6 +1452,40 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { result = tk.MustQuery("select subtime('2017-01-01', 1), subtime('2017-01-01 01:01:01', 1), subtime(cast('2017-01-01' as date), 1)") result.Check(testkit.Rows("2016-12-31 23:59:59 2017-01-01 01:01:00 -00:00:01")) + result = tk.MustQuery("select addtime(-32073, 0), addtime(0, -32073);") + result.Check(testkit.Rows(" ")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", + "Warning|1292|Truncated incorrect time value: '-32073'", + "Warning|1292|Truncated incorrect time value: '-32073'")) + result = tk.MustQuery("select addtime(-32073, c), addtime(c, -32073) from t;") + result.Check(testkit.Rows(" ")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", + "Warning|1292|Truncated incorrect time value: '-32073'", + "Warning|1292|Truncated incorrect time value: '-32073'")) + result = tk.MustQuery("select addtime(a, -32073), addtime(b, -32073), addtime(d, -32073) from t;") + result.Check(testkit.Rows(" ")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", + "Warning|1292|Truncated incorrect time value: '-32073'", + "Warning|1292|Truncated incorrect time value: '-32073'", + "Warning|1292|Truncated incorrect time value: '-32073'")) + + result = tk.MustQuery("select subtime(-32073, 0), subtime(0, -32073);") + result.Check(testkit.Rows(" ")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", + "Warning|1292|Truncated incorrect time value: '-32073'", + "Warning|1292|Truncated incorrect time value: '-32073'")) + result = tk.MustQuery("select subtime(-32073, c), subtime(c, -32073) from t;") + result.Check(testkit.Rows(" ")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", + "Warning|1292|Truncated incorrect time value: '-32073'", + "Warning|1292|Truncated incorrect time value: '-32073'")) + result = tk.MustQuery("select subtime(a, -32073), subtime(b, -32073), subtime(d, -32073) from t;") + result.Check(testkit.Rows(" ")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", + "Warning|1292|Truncated incorrect time value: '-32073'", + "Warning|1292|Truncated incorrect time value: '-32073'", + "Warning|1292|Truncated incorrect time value: '-32073'")) + // fixed issue #3986 tk.MustExec("SET SQL_MODE='NO_ENGINE_SUBSTITUTION';") tk.MustExec("SET TIME_ZONE='+03:00';") @@ -1784,9 +1838,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { // for convert_tz result = tk.MustQuery(`select convert_tz("2004-01-01 12:00:00", "+00:00", "+10:32"), convert_tz("2004-01-01 12:00:00.01", "+00:00", "+10:32"), convert_tz("2004-01-01 12:00:00.01234567", "+00:00", "+10:32");`) result.Check(testkit.Rows("2004-01-01 22:32:00 2004-01-01 22:32:00.01 2004-01-01 22:32:00.012346")) - // TODO: release the following test after fix #4462 - //result = tk.MustQuery(`select convert_tz(20040101, "+00:00", "+10:32"), convert_tz(20040101.01, "+00:00", "+10:32"), convert_tz(20040101.01234567, "+00:00", "+10:32");`) - //result.Check(testkit.Rows("2004-01-01 10:32:00 2004-01-01 10:32:00.00 2004-01-01 10:32:00.000000")) + result = tk.MustQuery(`select convert_tz(20040101, "+00:00", "+10:32"), convert_tz(20040101.01, "+00:00", "+10:32"), convert_tz(20040101.01234567, "+00:00", "+10:32");`) + result.Check(testkit.Rows("2004-01-01 10:32:00 2004-01-01 10:32:00.00 2004-01-01 10:32:00.000000")) + result = tk.MustQuery(`select convert_tz(NULL, "+00:00", "+10:32"), convert_tz("2004-01-01 12:00:00", NULL, "+10:32"), convert_tz("2004-01-01 12:00:00", "+00:00", NULL);`) + result.Check(testkit.Rows(" ")) + result = tk.MustQuery(`select convert_tz("a", "+00:00", "+10:32"), convert_tz("2004-01-01 12:00:00", "a", "+10:32"), convert_tz("2004-01-01 12:00:00", "+00:00", "a");`) + result.Check(testkit.Rows(" ")) + result = tk.MustQuery(`select convert_tz("", "+00:00", "+10:32"), convert_tz("2004-01-01 12:00:00", "", "+10:32"), convert_tz("2004-01-01 12:00:00", "+00:00", "");`) + result.Check(testkit.Rows(" ")) + result = tk.MustQuery(`select convert_tz("0", "+00:00", "+10:32"), convert_tz("2004-01-01 12:00:00", "0", "+10:32"), convert_tz("2004-01-01 12:00:00", "+00:00", "0");`) + result.Check(testkit.Rows(" ")) // for from_unixtime tk.MustExec(`set @@session.time_zone = "+08:00"`) @@ -1927,6 +1988,26 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { result = tk.MustQuery(subDate) result.Check(testkit.Rows(tc.SubResult)) } + tk.MustQuery(`select subdate(cast("2000-02-01" as datetime), cast(1 as decimal))`).Check(testkit.Rows("2000-01-31 00:00:00")) + tk.MustQuery(`select subdate(cast("2000-02-01" as datetime), cast(null as decimal))`).Check(testkit.Rows("")) + tk.MustQuery(`select subdate(cast(null as datetime), cast(1 as decimal))`).Check(testkit.Rows("")) + tk.MustQuery(`select subdate(cast("2000-02-01" as datetime), cast("xxx" as decimal))`).Check(testkit.Rows("2000-02-01 00:00:00")) + tk.MustQuery(`select subdate(cast("xxx" as datetime), cast(1 as decimal))`).Check(testkit.Rows("")) + tk.MustQuery(`select subdate(cast(20000101 as SIGNED), cast("1" as decimal))`).Check(testkit.Rows("1999-12-31")) + tk.MustQuery(`select subdate(cast(20000101 as SIGNED), cast("xxx" as decimal))`).Check(testkit.Rows("2000-01-01")) + tk.MustQuery(`select subdate(cast("abc" as SIGNED), cast("1" as decimal))`).Check(testkit.Rows("")) + tk.MustQuery(`select subdate(cast(null as SIGNED), cast("1" as decimal))`).Check(testkit.Rows("")) + tk.MustQuery(`select subdate(cast(20000101 as SIGNED), cast(null as decimal))`).Check(testkit.Rows("")) + tk.MustQuery(`select adddate(cast("2000-02-01" as datetime), cast(1 as decimal))`).Check(testkit.Rows("2000-02-02 00:00:00")) + tk.MustQuery(`select adddate(cast("2000-02-01" as datetime), cast(null as decimal))`).Check(testkit.Rows("")) + tk.MustQuery(`select adddate(cast(null as datetime), cast(1 as decimal))`).Check(testkit.Rows("")) + tk.MustQuery(`select adddate(cast("2000-02-01" as datetime), cast("xxx" as decimal))`).Check(testkit.Rows("2000-02-01 00:00:00")) + tk.MustQuery(`select adddate(cast("xxx" as datetime), cast(1 as decimal))`).Check(testkit.Rows("")) + tk.MustQuery(`select adddate(cast("2000-02-01" as datetime), cast(1 as SIGNED))`).Check(testkit.Rows("2000-02-02 00:00:00")) + tk.MustQuery(`select adddate(cast("2000-02-01" as datetime), cast(null as SIGNED))`).Check(testkit.Rows("")) + tk.MustQuery(`select adddate(cast(null as datetime), cast(1 as SIGNED))`).Check(testkit.Rows("")) + tk.MustQuery(`select adddate(cast("2000-02-01" as datetime), cast("xxx" as SIGNED))`).Check(testkit.Rows("2000-02-01 00:00:00")) + tk.MustQuery(`select adddate(cast("xxx" as datetime), cast(1 as SIGNED))`).Check(testkit.Rows("")) // for localtime, localtimestamp result = tk.MustQuery(`select localtime() = now(), localtime = now(), localtimestamp() = now(), localtimestamp = now()`) @@ -1948,6 +2029,17 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { result.Check(testkit.Rows("")) result = tk.MustQuery(`select tidb_parse_tso(-1)`) result.Check(testkit.Rows("")) + + // fix issue 10308 + result = tk.MustQuery("select time(\"- -\");") + result.Check(testkit.Rows("00:00:00")) + tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect time value: '- -'")) + result = tk.MustQuery("select time(\"---1\");") + result.Check(testkit.Rows("00:00:00")) + tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect time value: '---1'")) + result = tk.MustQuery("select time(\"-- --1\");") + result.Check(testkit.Rows("00:00:00")) + tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect time value: '-- --1'")) } func (s *testIntegrationSuite) TestOpBuiltin(c *C) { @@ -1988,6 +2080,12 @@ func (s *testIntegrationSuite) TestOpBuiltin(c *C) { // for unaryPlus result = tk.MustQuery(`select +1, +0, +(-9), +(-0.001), +0.999, +null, +"aaa"`) result.Check(testkit.Rows("1 0 -9 -0.001 0.999 aaa")) + // for unaryMinus + tk.MustExec("drop table if exists f") + tk.MustExec("create table f(a decimal(65,0))") + tk.MustExec("insert into f value (-17000000000000000000)") + result = tk.MustQuery("select a from f") + result.Check(testkit.Rows("-17000000000000000000")) } func (s *testIntegrationSuite) TestDatetimeOverflow(c *C) { @@ -2022,6 +2120,12 @@ func (s *testIntegrationSuite) TestDatetimeOverflow(c *C) { rows = append(rows, "") } tk.MustQuery("select * from t1").Check(testkit.Rows(rows...)) + + //Fix ISSUE 11256 + tk.MustQuery(`select DATE_ADD('2000-04-13 07:17:02',INTERVAL -1465647104 YEAR);`).Check(testkit.Rows("")) + tk.MustQuery(`select DATE_ADD('2008-11-23 22:47:31',INTERVAL 266076160 QUARTER);`).Check(testkit.Rows("")) + tk.MustQuery(`select DATE_SUB('2000-04-13 07:17:02',INTERVAL 1465647104 YEAR);`).Check(testkit.Rows("")) + tk.MustQuery(`select DATE_SUB('2008-11-23 22:47:31',INTERVAL -266076160 QUARTER);`).Check(testkit.Rows("")) } func (s *testIntegrationSuite) TestBuiltin(c *C) { @@ -2091,9 +2195,46 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) { result.Check(testkit.Rows("")) result = tk.MustQuery(`select cast(cast('2017-01-01 01:01:11.12' as date) as datetime(2));`) result.Check(testkit.Rows("2017-01-01 00:00:00.00")) - result = tk.MustQuery(`select cast(20170118.999 as datetime);`) result.Check(testkit.Rows("2017-01-18 00:00:00")) + tk.MustQuery(`select convert(a2.a, unsigned int) from (select cast('"9223372036854775808"' as json) as a) as a2;`) + + tk.MustExec(`create table tb5(a bigint(64) unsigned, b double);`) + tk.MustExec(`insert into tb5 (a, b) values (9223372036854776000, 9223372036854776000);`) + tk.MustExec(`insert into tb5 (a, b) select * from (select cast(a as json) as a1, b from tb5) as t where t.a1 = t.b;`) + tk.MustExec(`drop table tb5;`) + + tk.MustExec(`create table tb5(a float(64));`) + tk.MustExec(`insert into tb5(a) values (13835058055282163712);`) + err := tk.QueryToErr(`select convert(t.a1, signed int) from (select convert(a, json) as a1 from tb5) as t`) + msg := strings.Split(err.Error(), " ") + last := msg[len(msg)-1] + c.Assert(last, Equals, "bigint") + tk.MustExec(`drop table tb5;`) + + // test builtinCastIntAsDecimalSig + tk.MustExec(`create table tb5(a bigint(64) unsigned, b decimal(64, 10));`) + tk.MustExec(`insert into tb5 (a, b) values (9223372036854775808, 9223372036854775808);`) + tk.MustExec(`insert into tb5 (select * from tb5 where a = b);`) + result = tk.MustQuery(`select * from tb5;`) + result.Check(testkit.Rows("9223372036854775808 9223372036854775808.0000000000", "9223372036854775808 9223372036854775808.0000000000")) + tk.MustExec(`drop table tb5;`) + + // test builtinCastIntAsRealSig + tk.MustExec(`create table tb5(a bigint(64) unsigned, b double(64, 10));`) + tk.MustExec(`insert into tb5 (a, b) values (13835058000000000000, 13835058000000000000);`) + tk.MustExec(`insert into tb5 (select * from tb5 where a = b);`) + result = tk.MustQuery(`select * from tb5;`) + result.Check(testkit.Rows("13835058000000000000 13835058000000000000", "13835058000000000000 13835058000000000000")) + tk.MustExec(`drop table tb5;`) + + // test builtinCastIntAsStringSig + tk.MustExec(`create table tb5(a bigint(64) unsigned,b varchar(50));`) + tk.MustExec(`insert into tb5(a, b) values (9223372036854775808, '9223372036854775808');`) + tk.MustExec(`insert into tb5(select * from tb5 where a = b);`) + result = tk.MustQuery(`select * from tb5;`) + result.Check(testkit.Rows("9223372036854775808 9223372036854775808", "9223372036854775808 9223372036854775808")) + tk.MustExec(`drop table tb5;`) // Test corner cases of cast string as datetime result = tk.MustQuery(`select cast("170102034" as datetime);`) @@ -2253,7 +2394,7 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) { result.Check(testkit.Rows("99999.99")) result = tk.MustQuery("select cast(s1 as decimal(8, 2)) from t1;") result.Check(testkit.Rows("111111.00")) - _, err := tk.Exec("insert into t1 values(cast('111111.00' as decimal(7, 2)));") + _, err = tk.Exec("insert into t1 values(cast('111111.00' as decimal(7, 2)));") c.Assert(err, NotNil) result = tk.MustQuery(`select CAST(0x8fffffffffffffff as signed) a, @@ -2419,18 +2560,24 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) { result.Check(testkit.Rows("1")) result = tk.MustQuery(`select b regexp 'Xt' from t;`) result.Check(testkit.Rows("1")) + result = tk.MustQuery(`select b regexp binary 'Xt' from t;`) + result.Check(testkit.Rows("0")) result = tk.MustQuery(`select c regexp 'Xt' from t;`) result.Check(testkit.Rows("0")) result = tk.MustQuery(`select d regexp 'Xt' from t;`) result.Check(testkit.Rows("0")) result = tk.MustQuery(`select a rlike 'Xt' from t;`) result.Check(testkit.Rows("1")) + result = tk.MustQuery(`select a rlike binary 'Xt' from t;`) + result.Check(testkit.Rows("0")) result = tk.MustQuery(`select b rlike 'Xt' from t;`) result.Check(testkit.Rows("1")) result = tk.MustQuery(`select c rlike 'Xt' from t;`) result.Check(testkit.Rows("0")) result = tk.MustQuery(`select d rlike 'Xt' from t;`) result.Check(testkit.Rows("0")) + result = tk.MustQuery(`select 'a' regexp 'A', 'a' regexp binary 'A'`) + result.Check(testkit.Rows("1 0")) // testCase is for like and regexp type testCase struct { @@ -2655,6 +2802,11 @@ func (s *testIntegrationSuite) TestControlBuiltin(c *C) { result = tk.MustQuery("select ifnull(null, null)") result.Check(testkit.Rows("")) + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(a bigint not null)") + result = tk.MustQuery("select ifnull(max(a),0) from t1") + result.Check(testkit.Rows("0")) + tk.MustExec("drop table if exists t1") tk.MustExec("drop table if exists t2") tk.MustExec("create table t1(a decimal(20,4))") @@ -2780,7 +2932,10 @@ func (s *testIntegrationSuite) TestArithmeticBuiltin(c *C) { c.Assert(err, NotNil) c.Assert(err.Error(), Equals, "[types:1690]BIGINT UNSIGNED value is out of range in '(18446744073709551615 - -1)'") c.Assert(rs.Close(), IsNil) + tk.MustQuery(`select cast(-3 as unsigned) - cast(-1 as signed);`).Check(testkit.Rows("18446744073709551614")) + tk.MustQuery("select 1.11 - 1.11;").Check(testkit.Rows("0.00")) + // for multiply tk.MustQuery("select 1234567890 * 1234567890").Check(testkit.Rows("1524157875019052100")) rs, err = tk.Exec("select 1234567890 * 12345671890") c.Assert(err, IsNil) @@ -2807,8 +2962,7 @@ func (s *testIntegrationSuite) TestArithmeticBuiltin(c *C) { _, err = session.GetRows4Test(ctx, tk.Se, rs) c.Assert(terror.ErrorEqual(err, types.ErrOverflow), IsTrue) c.Assert(rs.Close(), IsNil) - result = tk.MustQuery(`select cast(-3 as unsigned) - cast(-1 as signed);`) - result.Check(testkit.Rows("18446744073709551614")) + tk.MustQuery("select 0.0 * -1;").Check(testkit.Rows("0.0")) tk.MustExec("DROP TABLE IF EXISTS t;") tk.MustExec("CREATE TABLE t(a DECIMAL(4, 2), b DECIMAL(5, 3));") @@ -2878,6 +3032,7 @@ func (s *testIntegrationSuite) TestArithmeticBuiltin(c *C) { tk.MustExec("INSERT IGNORE INTO t VALUE(12 MOD 0);") tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1365 Division by 0")) tk.MustQuery("select v from t;").Check(testkit.Rows("")) + tk.MustQuery("select 0.000 % 0.11234500000000000000;").Check(testkit.Rows("0.00000000000000000000")) _, err = tk.Exec("INSERT INTO t VALUE(12 MOD 0);") c.Assert(terror.ErrorEqual(err, expression.ErrDivisionByZero), IsTrue) @@ -3419,6 +3574,39 @@ func (s *testIntegrationSuite) TestJSONBuiltin(c *C) { tk.MustExec("CREATE TABLE `my_collection` ( `doc` json DEFAULT NULL, `_id` varchar(32) GENERATED ALWAYS AS (JSON_UNQUOTE(JSON_EXTRACT(doc,'$._id'))) STORED NOT NULL, PRIMARY KEY (`_id`))") _, err := tk.Exec("UPDATE `test`.`my_collection` SET doc=JSON_SET(doc) WHERE (JSON_EXTRACT(doc,'$.name') = 'clare');") c.Assert(err, NotNil) + + r := tk.MustQuery("select json_valid(null);") + r.Check(testkit.Rows("")) + + r = tk.MustQuery(`select json_valid("null");`) + r.Check(testkit.Rows("1")) + + r = tk.MustQuery("select json_valid(0);") + r.Check(testkit.Rows("0")) + + r = tk.MustQuery(`select json_valid("0");`) + r.Check(testkit.Rows("1")) + + r = tk.MustQuery(`select json_valid("hello");`) + r.Check(testkit.Rows("0")) + + r = tk.MustQuery(`select json_valid('"hello"');`) + r.Check(testkit.Rows("1")) + + r = tk.MustQuery(`select json_valid('{"a":1}');`) + r.Check(testkit.Rows("1")) + + r = tk.MustQuery("select json_valid('{}');") + r.Check(testkit.Rows("1")) + + r = tk.MustQuery(`select json_valid('[]');`) + r.Check(testkit.Rows("1")) + + r = tk.MustQuery("select json_valid('2019-8-19');") + r.Check(testkit.Rows("0")) + + r = tk.MustQuery(`select json_valid('"2019-8-19"');`) + r.Check(testkit.Rows("1")) } func (s *testIntegrationSuite) TestTimeLiteral(c *C) { @@ -3463,6 +3651,11 @@ func (s *testIntegrationSuite) TestTimeLiteral(c *C) { _, err = tk.Exec("select time '20171231235959.999999';") c.Assert(err, NotNil) c.Assert(terror.ErrorEqual(err, types.ErrIncorrectDatetimeValue.GenWithStackByArgs("20171231235959.999999")), IsTrue) + + _, err = tk.Exec("select ADDDATE('2008-01-34', -1);") + c.Assert(err, IsNil) + tk.MustQuery("Show warnings;").Check(testutil.RowsWithSep("|", + "Warning|1292|Incorrect datetime value: '2008-1-34'")) } func (s *testIntegrationSuite) TestTimestampLiteral(c *C) { @@ -3811,24 +4004,25 @@ func (s *testIntegrationSuite) TestFilterExtractFromDNF(c *C) { }, } + ctx := context.Background() for _, tt := range tests { sql := "select * from t where " + tt.exprStr - ctx := tk.Se.(sessionctx.Context) - sc := ctx.GetSessionVars().StmtCtx - stmts, err := session.Parse(ctx, sql) + sctx := tk.Se.(sessionctx.Context) + sc := sctx.GetSessionVars().StmtCtx + stmts, err := session.Parse(sctx, sql) c.Assert(err, IsNil, Commentf("error %v, for expr %s", err, tt.exprStr)) c.Assert(stmts, HasLen, 1) - is := domain.GetDomain(ctx).InfoSchema() - err = plannercore.Preprocess(ctx, stmts[0], is) + is := domain.GetDomain(sctx).InfoSchema() + err = plannercore.Preprocess(sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for resolve name, expr %s", err, tt.exprStr)) - p, err := plannercore.BuildLogicalPlan(ctx, stmts[0], is) + p, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for build plan, expr %s", err, tt.exprStr)) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) conds := make([]expression.Expression, 0, len(selection.Conditions)) for _, cond := range selection.Conditions { - conds = append(conds, expression.PushDownNot(ctx, cond, false)) + conds = append(conds, expression.PushDownNot(sctx, cond, false)) } - afterFunc := expression.ExtractFiltersFromDNFs(ctx, conds) + afterFunc := expression.ExtractFiltersFromDNFs(sctx, conds) sort.Slice(afterFunc, func(i, j int) bool { return bytes.Compare(afterFunc[i].HashCode(sc), afterFunc[j].HashCode(sc)) < 0 }) @@ -3849,6 +4043,25 @@ func (s *testIntegrationSuite) testTiDBIsOwnerFunc(c *C) { result.Check(testkit.Rows(fmt.Sprintf("%v", ret))) } +func (s *testIntegrationSuite) TestTiDBDecodePlanFunc(c *C) { + tk := testkit.NewTestKit(c, s.store) + defer s.cleanEnv(c) + tk.MustQuery("select tidb_decode_plan('')").Check(testkit.Rows("")) + tk.MustQuery("select tidb_decode_plan('7APIMAk1XzEzCTAJMQlmdW5jczpjb3VudCgxKQoxCTE3XzE0CTAJMAlpbm5lciBqb2luLCBp" + + "AQyQOlRhYmxlUmVhZGVyXzIxLCBlcXVhbDpbZXEoQ29sdW1uIzEsIA0KCDkpIBkXADIVFywxMCldCjIJMzJfMTgFZXhkYXRhOlNlbGVjdGlvbl" + + "8xNwozCTFfMTcJMQkwCWx0HVlATlVMTCksIG5vdChpc251bGwVHAApUhcAUDIpKQo0CTEwXzE2CTEJMTAwMDAJdAHB2Dp0MSwgcmFuZ2U6Wy1p" + + "bmYsK2luZl0sIGtlZXAgb3JkZXI6ZmFsc2UsIHN0YXRzOnBzZXVkbwoFtgAyAZcEMAk6tgAEMjAFtgQyMDq2AAg5LCBmtgAAMFa3AAA5FbcAO" + + "T63AAAyzrcA')").Check(testkit.Rows("" + + "\tStreamAgg_13 \troot\t1 \tfuncs:count(1)\n" + + "\t└─HashLeftJoin_14 \troot\t0 \tinner join, inner:TableReader_21, equal:[eq(Column#1, Column#9) eq(Column#2, Column#10)]\n" + + "\t ├─TableReader_18 \troot\t0 \tdata:Selection_17\n" + + "\t │ └─Selection_17 \tcop \t0 \tlt(Column#1, NULL), not(isnull(Column#1)), not(isnull(Column#2))\n" + + "\t │ └─TableScan_16\tcop \t10000\ttable:t1, range:[-inf,+inf], keep order:false, stats:pseudo\n" + + "\t └─TableReader_21 \troot\t0 \tdata:Selection_20\n" + + "\t └─Selection_20 \tcop \t0 \tlt(Column#9, NULL), not(isnull(Column#10)), not(isnull(Column#9))\n" + + "\t └─TableScan_19\tcop \t10000\ttable:t2, range:[-inf,+inf], keep order:false, stats:pseudo")) +} + func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) { store, err := mockstore.NewMockTikvStore() if err != nil { @@ -4071,6 +4284,10 @@ func (s *testIntegrationSuite) TestFuncNameConst(c *C) { r.Check(testkit.Rows("2")) r = tk.MustQuery("SELECT concat('hello', name_const('test_string', 'world')) FROM t;") r.Check(testkit.Rows("helloworld")) + r = tk.MustQuery("SELECT NAME_CONST('come', -1);") + r.Check(testkit.Rows("-1")) + r = tk.MustQuery("SELECT NAME_CONST('come', -1.0);") + r.Check(testkit.Rows("-1.0")) err := tk.ExecToErr(`select name_const(a,b) from t;`) c.Assert(err.Error(), Equals, "[planner:1210]Incorrect arguments to NAME_CONST") err = tk.ExecToErr(`select name_const(a,"hello") from t;`) @@ -4353,3 +4570,294 @@ func (s *testIntegrationSuite) TestIssue10181(c *C) { tk.MustExec(`insert into t values(9223372036854775807), (18446744073709551615)`) tk.MustQuery(`select * from t where a > 9223372036854775807-0.5 order by a`).Check(testkit.Rows(`9223372036854775807`, `18446744073709551615`)) } + +func (s *testIntegrationSuite) TestMySQLExtAssignment(c *C) { + tk := testkit.NewTestKit(c, s.store) + defer s.cleanEnv(c) + + tk.MustExec("use test") + tk.MustExec("set @@autocommit := on;") + tk.MustExec("set autocommit := on;") + tk.MustExec("set session autocommit := on;") + tk.MustExec("set global autocommit := on;") + tk.MustExec("set @count := 100;") + tk.MustExec("set @count := @count + 5;") +} + +func (s *testIntegrationSuite) TestExprPushdownBlacklist(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustQuery(`select * from mysql.expr_pushdown_blacklist`).Check(testkit.Rows()) +} + +func (s *testIntegrationSuite) TestIssue10675(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a int);`) + tk.MustExec(`insert into t values(1);`) + tk.MustQuery(`select * from t where a < -184467440737095516167.1;`).Check(testkit.Rows()) + tk.MustQuery(`select * from t where a > -184467440737095516167.1;`).Check( + testkit.Rows("1")) + tk.MustQuery(`select * from t where a < 184467440737095516167.1;`).Check( + testkit.Rows("1")) + tk.MustQuery(`select * from t where a > 184467440737095516167.1;`).Check(testkit.Rows()) + + // issue 11647 + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(b bit(1));`) + tk.MustExec(`insert into t values(b'1');`) + tk.MustQuery(`select count(*) from t where b = 1;`).Check(testkit.Rows("1")) + tk.MustQuery(`select count(*) from t where b = '1';`).Check(testkit.Rows("1")) + tk.MustQuery(`select count(*) from t where b = b'1';`).Check(testkit.Rows("1")) + + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(b bit(63));`) + // Not 64, because the behavior of mysql is amazing. I have no idea to fix it. + tk.MustExec(`insert into t values(b'111111111111111111111111111111111111111111111111111111111111111');`) + tk.MustQuery(`select count(*) from t where b = 9223372036854775807;`).Check(testkit.Rows("1")) + tk.MustQuery(`select count(*) from t where b = '9223372036854775807';`).Check(testkit.Rows("1")) + tk.MustQuery(`select count(*) from t where b = b'111111111111111111111111111111111111111111111111111111111111111';`).Check(testkit.Rows("1")) +} + +func (s *testIntegrationSuite) TestDatetimeMicrosecond(c *C) { + tk := testkit.NewTestKit(c, s.store) + // For int + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2 SECOND_MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2 MINUTE_MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2 HOUR_MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2 DAY_MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.800000")) + + // For Decimal + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 HOUR_MINUTE);`).Check( + testkit.Rows("2007-03-29 00:10:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MINUTE_SECOND);`).Check( + testkit.Rows("2007-03-28 22:10:30")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 YEAR_MONTH);`).Check( + testkit.Rows("2009-05-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_HOUR);`).Check( + testkit.Rows("2007-03-31 00:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_MINUTE);`).Check( + testkit.Rows("2007-03-29 00:10:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_SECOND);`).Check( + testkit.Rows("2007-03-28 22:10:30")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 HOUR_SECOND);`).Check( + testkit.Rows("2007-03-28 22:10:30")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 SECOND);`).Check( + testkit.Rows("2007-03-28 22:08:30.200000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 YEAR);`).Check( + testkit.Rows("2009-03-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 QUARTER);`).Check( + testkit.Rows("2007-09-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MONTH);`).Check( + testkit.Rows("2007-05-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 WEEK);`).Check( + testkit.Rows("2007-04-11 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY);`).Check( + testkit.Rows("2007-03-30 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 HOUR);`).Check( + testkit.Rows("2007-03-29 00:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MINUTE);`).Check( + testkit.Rows("2007-03-28 22:10:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:28.000002")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 HOUR_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MINUTE_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 YEAR_MONTH);`).Check( + testkit.Rows("2005-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_HOUR);`).Check( + testkit.Rows("2007-03-26 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 HOUR_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 SECOND);`).Check( + // testkit.Rows("2007-03-28 22:08:25.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 YEAR);`).Check( + testkit.Rows("2005-03-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 QUARTER);`).Check( + testkit.Rows("2006-09-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MONTH);`).Check( + testkit.Rows("2007-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 WEEK);`).Check( + testkit.Rows("2007-03-14 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY);`).Check( + testkit.Rows("2007-03-26 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 HOUR);`).Check( + testkit.Rows("2007-03-28 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MINUTE);`).Check( + testkit.Rows("2007-03-28 22:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.999998")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" HOUR_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" MINUTE_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" YEAR_MONTH);`).Check( + testkit.Rows("2005-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" DAY_HOUR);`).Check( + testkit.Rows("2007-03-26 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" DAY_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" DAY_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" HOUR_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" SECOND);`).Check( + // testkit.Rows("2007-03-28 22:08:25.800000")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" YEAR);`).Check( + testkit.Rows("2005-03-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" QUARTER);`).Check( + testkit.Rows("2006-09-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" MONTH);`).Check( + testkit.Rows("2007-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" WEEK);`).Check( + testkit.Rows("2007-03-14 22:08:28")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" DAY);`).Check( + // testkit.Rows("2007-03-26 22:08:28")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" HOUR);`).Check( + // testkit.Rows("2007-03-28 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" MINUTE);`).Check( + testkit.Rows("2007-03-28 22:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.2" MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.999998")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" HOUR_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" MINUTE_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" YEAR_MONTH);`).Check( + testkit.Rows("2005-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" DAY_HOUR);`).Check( + testkit.Rows("2007-03-26 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" DAY_MINUTE);`).Check( + testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" DAY_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" HOUR_SECOND);`).Check( + testkit.Rows("2007-03-28 22:06:26")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" SECOND);`).Check( + // testkit.Rows("2007-03-28 22:08:26")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" YEAR);`).Check( + testkit.Rows("2005-03-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" QUARTER);`).Check( + testkit.Rows("2006-09-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" MONTH);`).Check( + testkit.Rows("2007-01-28 22:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" WEEK);`).Check( + testkit.Rows("2007-03-14 22:08:28")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" DAY);`).Check( + // testkit.Rows("2007-03-26 22:08:28")) + // tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" HOUR);`).Check( + // testkit.Rows("2007-03-28 20:08:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" MINUTE);`).Check( + testkit.Rows("2007-03-28 22:06:28")) + tk.MustQuery(`select DATE_ADD('2007-03-28 22:08:28',INTERVAL "-2.-2" MICROSECOND);`).Check( + testkit.Rows("2007-03-28 22:08:27.999998")) +} + +func (s *testIntegrationSuite) TestFuncCaseWithLeftJoin(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + + tk.MustExec("create table kankan1(id int, name text)") + tk.MustExec("insert into kankan1 values(1, 'a')") + tk.MustExec("insert into kankan1 values(2, 'a')") + + tk.MustExec("create table kankan2(id int, h1 text)") + tk.MustExec("insert into kankan2 values(2, 'z')") + + tk.MustQuery("select t1.id from kankan1 t1 left join kankan2 t2 on t1.id = t2.id where (case when t1.name='b' then 'case2' when t1.name='a' then 'case1' else NULL end) = 'case1' order by t1.id").Check(testkit.Rows("1", "2")) +} + +func (s *testIntegrationSuite) TestIssue11594(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`drop table if exists t1;`) + tk.MustExec("CREATE TABLE t1 (v bigint(20) UNSIGNED NOT NULL);") + tk.MustExec("INSERT INTO t1 VALUES (1), (2);") + tk.MustQuery("SELECT SUM(IF(v > 1, v, -v)) FROM t1;").Check(testkit.Rows("1")) + tk.MustQuery("SELECT sum(IFNULL(cast(null+rand() as unsigned), -v)) FROM t1;").Check(testkit.Rows("-3")) + tk.MustQuery("SELECT sum(COALESCE(cast(null+rand() as unsigned), -v)) FROM t1;").Check(testkit.Rows("-3")) + tk.MustQuery("SELECT sum(COALESCE(cast(null+rand() as unsigned), v)) FROM t1;").Check(testkit.Rows("3")) +} + +func (s *testIntegrationSuite) TestIssue11309And11319(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`CREATE TABLE t (a decimal(6,3),b double(6,3),c float(6,3));`) + tk.MustExec(`INSERT INTO t VALUES (1.100,1.100,1.100);`) + tk.MustQuery(`SELECT DATE_ADD('2003-11-18 07:25:13',INTERVAL a MINUTE_SECOND) FROM t`).Check(testkit.Rows(`2003-11-18 07:27:53`)) + tk.MustQuery(`SELECT DATE_ADD('2003-11-18 07:25:13',INTERVAL b MINUTE_SECOND) FROM t`).Check(testkit.Rows(`2003-11-18 07:27:53`)) + tk.MustQuery(`SELECT DATE_ADD('2003-11-18 07:25:13',INTERVAL c MINUTE_SECOND) FROM t`).Check(testkit.Rows(`2003-11-18 07:27:53`)) + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`CREATE TABLE t (a decimal(11,7),b double(11,7),c float(11,7));`) + tk.MustExec(`INSERT INTO t VALUES (123.9999999,123.9999999,123.9999999),(-123.9999999,-123.9999999,-123.9999999);`) + tk.MustQuery(`SELECT DATE_ADD('2003-11-18 07:25:13',INTERVAL a MINUTE_SECOND) FROM t`).Check(testkit.Rows(`2004-03-13 03:14:52`, `2003-07-25 11:35:34`)) + tk.MustQuery(`SELECT DATE_ADD('2003-11-18 07:25:13',INTERVAL b MINUTE_SECOND) FROM t`).Check(testkit.Rows(`2004-03-13 03:14:52`, `2003-07-25 11:35:34`)) + tk.MustQuery(`SELECT DATE_ADD('2003-11-18 07:25:13',INTERVAL c MINUTE_SECOND) FROM t`).Check(testkit.Rows(`2003-11-18 09:29:13`, `2003-11-18 05:21:13`)) + tk.MustExec(`drop table if exists t;`) + + // for https://github.com/pingcap/tidb/issues/11319 + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MINUTE_MICROSECOND)`).Check(testkit.Rows("2007-03-28 22:08:25.800000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 SECOND_MICROSECOND)`).Check(testkit.Rows("2007-03-28 22:08:25.800000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 HOUR_MICROSECOND)`).Check(testkit.Rows("2007-03-28 22:08:25.800000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_MICROSECOND)`).Check(testkit.Rows("2007-03-28 22:08:25.800000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 SECOND)`).Check(testkit.Rows("2007-03-28 22:08:25.800000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 HOUR_SECOND)`).Check(testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_SECOND)`).Check(testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MINUTE_SECOND)`).Check(testkit.Rows("2007-03-28 22:06:26")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 MINUTE)`).Check(testkit.Rows("2007-03-28 22:06:28")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_MINUTE)`).Check(testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 HOUR_MINUTE)`).Check(testkit.Rows("2007-03-28 20:06:28")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 DAY_HOUR)`).Check(testkit.Rows("2007-03-26 20:08:28")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL -2.2 YEAR_MONTH)`).Check(testkit.Rows("2005-01-28 22:08:28")) + + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MINUTE_MICROSECOND)`).Check(testkit.Rows("2007-03-28 22:08:30.200000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 SECOND_MICROSECOND)`).Check(testkit.Rows("2007-03-28 22:08:30.200000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 HOUR_MICROSECOND)`).Check(testkit.Rows("2007-03-28 22:08:30.200000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_MICROSECOND)`).Check(testkit.Rows("2007-03-28 22:08:30.200000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 SECOND)`).Check(testkit.Rows("2007-03-28 22:08:30.200000")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 HOUR_SECOND)`).Check(testkit.Rows("2007-03-28 22:10:30")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_SECOND)`).Check(testkit.Rows("2007-03-28 22:10:30")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MINUTE_SECOND)`).Check(testkit.Rows("2007-03-28 22:10:30")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 MINUTE)`).Check(testkit.Rows("2007-03-28 22:10:28")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_MINUTE)`).Check(testkit.Rows("2007-03-29 00:10:28")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 HOUR_MINUTE)`).Check(testkit.Rows("2007-03-29 00:10:28")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 DAY_HOUR)`).Check(testkit.Rows("2007-03-31 00:08:28")) + tk.MustQuery(`SELECT DATE_ADD('2007-03-28 22:08:28',INTERVAL 2.2 YEAR_MONTH)`).Check(testkit.Rows("2009-05-28 22:08:28")) +} + +func (s *testIntegrationSuite) TestIssue12301(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t (d decimal(19, 0), i bigint(11))") + tk.MustExec("insert into t values (123456789012, 123456789012)") + tk.MustQuery("select * from t where d = i").Check(testkit.Rows("123456789012 123456789012")) +} + +func (s *testIntegrationSuite) TestNotExistFunc(c *C) { + tk := testkit.NewTestKit(c, s.store) + + // current db is empty + _, err := tk.Exec("SELECT xxx(1)") + c.Assert(err.Error(), Equals, "[planner:1046]No database selected") + + _, err = tk.Exec("SELECT yyy()") + c.Assert(err.Error(), Equals, "[planner:1046]No database selected") + + // current db is not empty + tk.MustExec("use test") + _, err = tk.Exec("SELECT xxx(1)") + c.Assert(err.Error(), Equals, "[expression:1305]FUNCTION test.xxx does not exist") + + _, err = tk.Exec("SELECT yyy()") + c.Assert(err.Error(), Equals, "[expression:1305]FUNCTION test.yyy does not exist") + +} diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 77ce863a39406..dff91c56abd95 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -80,7 +80,12 @@ func newFunctionImpl(ctx sessionctx.Context, fold bool, funcName string, retType } fc, ok := funcs[funcName] if !ok { - return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", funcName) + db := ctx.GetSessionVars().CurrentDB + if db == "" { + return nil, terror.ClassOptimizer.New(mysql.ErrNoDB, mysql.MySQLErrName[mysql.ErrNoDB]) + } + + return nil, errFunctionNotExists.GenWithStackByArgs("FUNCTION", db+"."+funcName) } funcArgs := make([]Expression, len(args)) copy(funcArgs, args) diff --git a/expression/simple_rewriter.go b/expression/simple_rewriter.go index 31589e2599185..045d22fe1265f 100644 --- a/expression/simple_rewriter.go +++ b/expression/simple_rewriter.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/parser/opcode" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/types/parser_driver" + driver "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util" ) @@ -468,10 +468,15 @@ func (sr *simpleRewriter) inToExpression(lLen int, not bool, tp *types.FieldType return } leftEt := leftFt.EvalType() + if leftEt == types.ETInt { for i := 0; i < len(elems); i++ { if c, ok := elems[i].(*Constant); ok { - elems[i], _ = RefineComparedConstant(sr.ctx, mysql.HasUnsignedFlag(leftFt.Flag), c, opcode.EQ) + var isExceptional bool + elems[i], isExceptional = RefineComparedConstant(sr.ctx, *leftFt, c, opcode.EQ) + if isExceptional { + elems[i] = c + } } } } diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 9c0f2fe00472c..f36efac0b832f 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -126,8 +126,9 @@ func (s *testInferTypeSuite) TestInferType(c *C) { tests = append(tests, s.createTestCase4JSONFuncs()...) tests = append(tests, s.createTestCase4MiscellaneousFunc()...) + ctx := context.Background() for _, tt := range tests { - ctx := testKit.Se.(sessionctx.Context) + sctx := testKit.Se.(sessionctx.Context) sql := "select " + tt.sql + " from t" comment := Commentf("for %s", sql) stmt, err := s.ParseOneStmt(sql, "", "") @@ -136,10 +137,10 @@ func (s *testInferTypeSuite) TestInferType(c *C) { err = se.NewTxn(context.Background()) c.Assert(err, IsNil) - is := domain.GetDomain(ctx).InfoSchema() - err = plannercore.Preprocess(ctx, stmt, is) + is := domain.GetDomain(sctx).InfoSchema() + err = plannercore.Preprocess(sctx, stmt, is) c.Assert(err, IsNil, comment) - p, err := plannercore.BuildLogicalPlan(ctx, stmt, is) + p, err := plannercore.BuildLogicalPlan(ctx, sctx, stmt, is) c.Assert(err, IsNil, comment) tp := p.Schema().Columns[0].RetType @@ -170,9 +171,9 @@ func (s *testInferTypeSuite) createTestCase4Constants() []typeInferTestCase { {"b'0001'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, {"b'000100001'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, {"b'0000000000010000'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 2, 0}, - {"x'10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 1, 0}, - {"x'ff10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 2, 0}, - {"x'0000000000000000ff10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 10, 0}, + {"x'10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 3, 0}, + {"x'ff10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 6, 0}, + {"x'0000000000000000ff10'", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag | mysql.UnsignedFlag, 30, 0}, } } @@ -692,9 +693,9 @@ func (s *testInferTypeSuite) createTestCase4ArithmeticFuncs() []typeInferTestCas {"c_int_d + c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_int_d + c_time_d", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"c_int_d + c_double_d", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, - {"c_int_d + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, mysql.MaxDecimalScale}, - {"c_datetime + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, mysql.MaxDecimalScale}, - {"c_bigint_d + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, mysql.MaxDecimalScale}, + {"c_int_d + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, + {"c_datetime + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, + {"c_bigint_d + c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, {"c_double_d + c_decimal", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d + c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d + c_enum", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, @@ -704,9 +705,9 @@ func (s *testInferTypeSuite) createTestCase4ArithmeticFuncs() []typeInferTestCas {"c_int_d - c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_int_d - c_time_d", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"c_int_d - c_double_d", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, - {"c_int_d - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, mysql.MaxDecimalScale}, - {"c_datetime - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, mysql.MaxDecimalScale}, - {"c_bigint_d - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, mysql.MaxDecimalScale}, + {"c_int_d - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, + {"c_datetime - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, + {"c_bigint_d - c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 26, 3}, {"c_double_d - c_decimal", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d - c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d - c_enum", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, @@ -716,9 +717,9 @@ func (s *testInferTypeSuite) createTestCase4ArithmeticFuncs() []typeInferTestCas {"c_int_d * c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_int_d * c_time_d", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"c_int_d * c_double_d", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, - {"c_int_d * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, mysql.MaxDecimalScale}, - {"c_datetime * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, mysql.MaxDecimalScale}, - {"c_bigint_d * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDecimalWidth, mysql.MaxDecimalScale}, + {"c_int_d * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 29, 3}, + {"c_datetime * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 31, 5}, + {"c_bigint_d * c_decimal", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 29, 3}, {"c_double_d * c_decimal", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d * c_char", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, {"c_double_d * c_enum", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, types.UnspecifiedLength, types.UnspecifiedLength}, @@ -1266,31 +1267,31 @@ func (s *testInferTypeSuite) createTestCase4TimeFuncs() []typeInferTestCase { {"subtime(c_date, c_timestamp)", mysql.TypeString, charset.CharsetUTF8MB4, 0, 26, types.UnspecifiedLength}, {"subtime(c_date, c_time)", mysql.TypeString, charset.CharsetUTF8MB4, 0, 26, types.UnspecifiedLength}, - {"timestamp(c_int_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, - {"timestamp(c_float_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_double_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_decimal)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 23, 3}, - {"timestamp(c_udecimal)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 23, 3}, - {"timestamp(c_decimal_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, - {"timestamp(c_udecimal_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, - {"timestamp(c_datetime)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 22, 2}, - {"timestamp(c_datetime_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, - {"timestamp(c_timestamp)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 24, 4}, - {"timestamp(c_time)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 23, 3}, - {"timestamp(c_time_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, 0}, - {"timestamp(c_bchar)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_char)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_varchar)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_text_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_btext_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_blob_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_set)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_enum)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - - {"timestamp(c_int_d, c_float_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_datetime, c_timestamp)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 24, 4}, - {"timestamp(c_timestamp, c_char)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 6}, - {"timestamp(c_int_d, c_datetime)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 22, 2}, + {"timestamp(c_int_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, types.UnspecifiedLength}, + {"timestamp(c_float_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_double_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_decimal)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 23, types.UnspecifiedLength}, + {"timestamp(c_udecimal)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 23, types.UnspecifiedLength}, + {"timestamp(c_decimal_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, types.UnspecifiedLength}, + {"timestamp(c_udecimal_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, types.UnspecifiedLength}, + {"timestamp(c_datetime)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 22, types.UnspecifiedLength}, + {"timestamp(c_datetime_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, types.UnspecifiedLength}, + {"timestamp(c_timestamp)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 24, types.UnspecifiedLength}, + {"timestamp(c_time)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 23, types.UnspecifiedLength}, + {"timestamp(c_time_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 19, types.UnspecifiedLength}, + {"timestamp(c_bchar)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_char)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_varchar)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_text_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_btext_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_blob_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_set)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_enum)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + + {"timestamp(c_int_d, c_float_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_datetime, c_timestamp)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 24, types.UnspecifiedLength}, + {"timestamp(c_timestamp, c_char)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, types.UnspecifiedLength}, + {"timestamp(c_int_d, c_datetime)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 22, types.UnspecifiedLength}, {"addtime(c_int_d, c_time_d)", mysql.TypeString, charset.CharsetUTF8MB4, 0, 26, types.UnspecifiedLength}, {"addtime(c_datetime_d, c_time_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, 26, 0}, diff --git a/expression/util.go b/expression/util.go index 368a211d2c472..3e885d67185bb 100644 --- a/expression/util.go +++ b/expression/util.go @@ -542,20 +542,6 @@ func (s *exprStack) len() int { return len(s.stack) } -// ColumnSliceIsIntersect checks whether two column slice is intersected. -func ColumnSliceIsIntersect(s1, s2 []*Column) bool { - intSet := map[int64]struct{}{} - for _, col := range s1 { - intSet[col.UniqueID] = struct{}{} - } - for _, col := range s2 { - if _, ok := intSet[col.UniqueID]; ok { - return true - } - } - return false -} - // DatumToConstant generates a Constant expression from a Datum. func DatumToConstant(d types.Datum, tp byte) *Constant { return &Constant{Value: d, RetType: types.NewFieldType(tp)} diff --git a/expression/util_test.go b/expression/util_test.go index bf6ff7f459cb0..babc840c69f71 100644 --- a/expression/util_test.go +++ b/expression/util_test.go @@ -18,8 +18,13 @@ import ( "github.com/pingcap/check" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/testleak" ) @@ -29,6 +34,54 @@ var _ = check.Suite(&testUtilSuite{}) type testUtilSuite struct { } +func (s *testUtilSuite) TestSetExprColumnInOperand(c *check.C) { + col := &Column{RetType: newIntFieldType()} + c.Assert(setExprColumnInOperand(col).(*Column).InOperand, check.IsTrue) + + f, err := funcs[ast.Abs].getFunction(mock.NewContext(), []Expression{col}) + c.Assert(err, check.IsNil) + fun := &ScalarFunction{Function: f} + setExprColumnInOperand(fun) + c.Assert(f.getArgs()[0].(*Column).InOperand, check.IsTrue) +} + +func (s testUtilSuite) TestPopRowFirstArg(c *check.C) { + c1, c2, c3 := &Column{RetType: newIntFieldType()}, &Column{RetType: newIntFieldType()}, &Column{RetType: newIntFieldType()} + f, err := funcs[ast.RowFunc].getFunction(mock.NewContext(), []Expression{c1, c2, c3}) + c.Assert(err, check.IsNil) + fun := &ScalarFunction{Function: f, FuncName: model.NewCIStr(ast.RowFunc), RetType: newIntFieldType()} + fun2, err := PopRowFirstArg(mock.NewContext(), fun) + c.Assert(err, check.IsNil) + c.Assert(len(fun2.(*ScalarFunction).GetArgs()), check.Equals, 2) +} + +func (s testUtilSuite) TestGetStrIntFromConstant(c *check.C) { + col := &Column{} + _, _, err := GetStringFromConstant(mock.NewContext(), col) + c.Assert(err, check.NotNil) + + con := &Constant{RetType: &types.FieldType{Tp: mysql.TypeNull}} + _, isNull, err := GetStringFromConstant(mock.NewContext(), con) + c.Assert(err, check.IsNil) + c.Assert(isNull, check.IsTrue) + + con = &Constant{RetType: newIntFieldType(), Value: types.NewIntDatum(1)} + ret, _, _ := GetStringFromConstant(mock.NewContext(), con) + c.Assert(ret, check.Equals, "1") + + con = &Constant{RetType: &types.FieldType{Tp: mysql.TypeNull}} + _, isNull, _ = GetIntFromConstant(mock.NewContext(), con) + c.Assert(isNull, check.IsTrue) + + con = &Constant{RetType: newStringFieldType(), Value: types.NewStringDatum("abc")} + _, isNull, _ = GetIntFromConstant(mock.NewContext(), con) + c.Assert(isNull, check.IsTrue) + + con = &Constant{RetType: newStringFieldType(), Value: types.NewStringDatum("123")} + num, _, _ := GetIntFromConstant(mock.NewContext(), con) + c.Assert(num, check.Equals, 123) +} + func (s *testUtilSuite) TestSubstituteCorCol2Constant(c *check.C) { defer testleak.AfterTest(c)() ctx := mock.NewContext() @@ -161,3 +214,66 @@ func BenchmarkExprFromSchema(b *testing.B) { } b.ReportAllocs() } + +// MockExpr is mainly for test. +type MockExpr struct { + err error + t *types.FieldType + i interface{} +} + +func (m *MockExpr) String() string { return "" } +func (m *MockExpr) MarshalJSON() ([]byte, error) { return nil, nil } +func (m *MockExpr) Eval(row chunk.Row) (types.Datum, error) { return types.NewDatum(m.i), m.err } +func (m *MockExpr) EvalInt(ctx sessionctx.Context, row chunk.Row) (val int64, isNull bool, err error) { + if x, ok := m.i.(int64); ok { + return int64(x), false, m.err + } + return 0, m.i == nil, m.err +} +func (m *MockExpr) EvalReal(ctx sessionctx.Context, row chunk.Row) (val float64, isNull bool, err error) { + if x, ok := m.i.(float64); ok { + return float64(x), false, m.err + } + return 0, m.i == nil, m.err +} +func (m *MockExpr) EvalString(ctx sessionctx.Context, row chunk.Row) (val string, isNull bool, err error) { + if x, ok := m.i.(string); ok { + return string(x), false, m.err + } + return "", m.i == nil, m.err +} +func (m *MockExpr) EvalDecimal(ctx sessionctx.Context, row chunk.Row) (val *types.MyDecimal, isNull bool, err error) { + if x, ok := m.i.(*types.MyDecimal); ok { + return x, false, m.err + } + return nil, m.i == nil, m.err +} +func (m *MockExpr) EvalTime(ctx sessionctx.Context, row chunk.Row) (val types.Time, isNull bool, err error) { + if x, ok := m.i.(types.Time); ok { + return x, false, m.err + } + return types.Time{}, m.i == nil, m.err +} +func (m *MockExpr) EvalDuration(ctx sessionctx.Context, row chunk.Row) (val types.Duration, isNull bool, err error) { + if x, ok := m.i.(types.Duration); ok { + return x, false, m.err + } + return types.Duration{}, m.i == nil, m.err +} +func (m *MockExpr) EvalJSON(ctx sessionctx.Context, row chunk.Row) (val json.BinaryJSON, isNull bool, err error) { + if x, ok := m.i.(json.BinaryJSON); ok { + return x, false, m.err + } + return json.BinaryJSON{}, m.i == nil, m.err +} +func (m *MockExpr) GetType() *types.FieldType { return m.t } +func (m *MockExpr) Clone() Expression { return nil } +func (m *MockExpr) Equal(ctx sessionctx.Context, e Expression) bool { return false } +func (m *MockExpr) IsCorrelated() bool { return false } +func (m *MockExpr) ConstItem() bool { return false } +func (m *MockExpr) Decorrelate(schema *Schema) Expression { return m } +func (m *MockExpr) ResolveIndices(schema *Schema) (Expression, error) { return m, nil } +func (m *MockExpr) resolveIndices(schema *Schema) error { return nil } +func (m *MockExpr) ExplainInfo() string { return "" } +func (m *MockExpr) HashCode(sc *stmtctx.StatementContext) []byte { return nil } diff --git a/go.mod b/go.mod index 88a2117c0f59c..ff24d6fa46cfa 100644 --- a/go.mod +++ b/go.mod @@ -3,51 +3,55 @@ module github.com/pingcap/tidb require ( github.com/BurntSushi/toml v0.3.1 github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f // indirect + github.com/beorn7/perks v1.0.0 // indirect github.com/blacktear23/go-proxyprotocol v0.0.0-20180807104634-af7a81e8dd0d - github.com/boltdb/bolt v1.3.1 // indirect github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect - github.com/coreos/bbolt v1.3.0 // indirect - github.com/coreos/etcd v3.3.12+incompatible + github.com/coreos/bbolt v1.3.3 // indirect + github.com/coreos/etcd v3.3.13+incompatible github.com/coreos/go-systemd v0.0.0-20181031085051-9002847aa142 // indirect github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f // indirect github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 github.com/cznic/sortutil v0.0.0-20150617083342-4c7342852e65 + github.com/dgryski/go-farm v0.0.0-20190104051053-3adb47b1fb0f github.com/dustin/go-humanize v1.0.0 // indirect github.com/go-ole/go-ole v1.2.1 // indirect github.com/go-sql-driver/mysql v0.0.0-20170715192408-3955978caca4 - github.com/gogo/protobuf v1.2.0 // indirect + github.com/gogo/protobuf v1.2.0 github.com/golang/protobuf v1.2.0 + github.com/golang/snappy v0.0.1 github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c + github.com/google/pprof v0.0.0-20190930153522-6ce02741cba3 + github.com/google/uuid v1.1.1 github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/mux v1.6.2 github.com/gorilla/websocket v1.4.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.0.1-0.20190118093823-f849b5445de4 github.com/grpc-ecosystem/grpc-gateway v1.5.1 // indirect + github.com/json-iterator/go v1.1.6 // indirect github.com/klauspost/cpuid v0.0.0-20170728055534-ae7887de9fa5 github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.1 // indirect github.com/montanaflynn/stats v0.0.0-20180911141734-db72e6cae808 // indirect - github.com/myesui/uuid v1.0.0 // indirect github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7 github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef - github.com/onsi/ginkgo v1.7.0 // indirect - github.com/onsi/gomega v1.4.3 // indirect github.com/opentracing/basictracer-go v1.0.0 github.com/opentracing/opentracing-go v1.0.2 github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8 github.com/pingcap/errors v0.11.4 - github.com/pingcap/failpoint v0.0.0-20190422094118-d8535965f59b + github.com/pingcap/failpoint v0.0.0-20190512135322-30cc7431d99c github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e - github.com/pingcap/kvproto v0.0.0-20190517030054-ff2e03f6fdfe - github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 - github.com/pingcap/parser v0.0.0-20190522123204-4628ac31ee47 - github.com/pingcap/pd v0.0.0-20190424024702-bd1e2496a669 + github.com/pingcap/kvproto v0.0.0-20190918085321-44e3817e1f18 + github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd + github.com/pingcap/parser v0.0.0-20191018040038-555b97093a2a + github.com/pingcap/pd v1.1.0-beta.0.20190912093418-dc03c839debd github.com/pingcap/tidb-tools v2.1.3-0.20190321065848-1e8b48f5c168+incompatible - github.com/pingcap/tipb v0.0.0-20190428032612-535e1abaa330 + github.com/pingcap/tipb v0.0.0-20191031111650-d14196d52154 github.com/prometheus/client_golang v0.9.0 github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 github.com/prometheus/common v0.0.0-20181020173914-7e9e6cabbd39 // indirect github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d // indirect - github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7 // indirect github.com/shirou/gopsutil v2.18.10+incompatible github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371 // indirect github.com/shurcooL/vfsgen v0.0.0-20181020040650-a97a25d856ca // indirect @@ -55,17 +59,17 @@ require ( github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 github.com/struCoder/pidusage v0.1.2 github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2 - github.com/twinj/uuid v1.0.0 github.com/uber-go/atomic v1.3.2 // indirect github.com/uber/jaeger-client-go v2.15.0+incompatible github.com/uber/jaeger-lib v1.5.0 // indirect github.com/unrolled/render v0.0.0-20180914162206-b9786414de4d // indirect + go.etcd.io/bbolt v1.3.3 // indirect go.uber.org/atomic v1.3.2 go.uber.org/zap v1.9.1 - golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e - golang.org/x/sys v0.0.0-20190109145017-48ac38b7c8cb // indirect + golang.org/x/crypto v0.0.0-20190909091759-094676da4a83 // indirect + golang.org/x/net v0.0.0-20190909003024-a7b16738d86b + golang.org/x/sync v0.0.0-20190423024810-112230192c58 // indirect golang.org/x/text v0.3.0 - golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect golang.org/x/tools v0.0.0-20190130214255-bb1329dc71a0 google.golang.org/genproto v0.0.0-20190108161440-ae2f86662275 // indirect google.golang.org/grpc v1.17.0 @@ -73,3 +77,7 @@ require ( sourcegraph.com/sourcegraph/appdash v0.0.0-20180531100431-4c381bd170b4 sourcegraph.com/sourcegraph/appdash-data v0.0.0-20151005221446-73f23eafcf67 ) + +replace github.com/google/pprof => github.com/lonng/pprof v0.0.0-20191012154247-04dfd648ce8d + +go 1.13 diff --git a/go.sum b/go.sum index ef1d81233c978..6955852681809 100644 --- a/go.sum +++ b/go.sum @@ -4,23 +4,24 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f h1:5ZfJxyXo8KyX8DgGXC5B7ILL8y51fci/qYz2B4j8iLY= github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0 h1:HWo1m869IqiPhD389kmkxeTalrjNbbJTC8LXupb+sl0= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= github.com/blacktear23/go-proxyprotocol v0.0.0-20180807104634-af7a81e8dd0d h1:rQlvB2AYWme2bIB18r/SipGiMEVJYE9U0z+MGoU/LtQ= github.com/blacktear23/go-proxyprotocol v0.0.0-20180807104634-af7a81e8dd0d/go.mod h1:VKt7CNAQxpFpSDz3sXyj9hY/GbVsQCr0sB3w59nE7lU= -github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4= -github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20171208011716-f6d7a1f6fbf3/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/client9/misspell v0.3.4 h1:ta993UF76GwbvJcIo3Y68y/M3WxlpEHPWIGDkJYwzJI= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd h1:qMd81Ts1T2OTKmB4acZcyKaMtRnY5Y44NuXGX2GFJ1w= github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd/go.mod h1:sE/e/2PUdi/liOCUjSTXgM1o87ZssimdTWN964YiIeI= -github.com/coreos/bbolt v1.3.0 h1:HIgH5xUWXT914HCI671AxuTTqjj64UOFr7pHn48LUTI= -github.com/coreos/bbolt v1.3.0/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= -github.com/coreos/etcd v3.3.12+incompatible h1:pAWNwdf7QiT1zfaWyqCtNZQWCLByQyA3JrSQyuYAqnQ= -github.com/coreos/etcd v3.3.12+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= +github.com/coreos/bbolt v1.3.3 h1:n6AiVyVRKQFNb6mJlwESEvvLoDyiTzXX7ORAUlkeBdY= +github.com/coreos/bbolt v1.3.3/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= +github.com/coreos/etcd v3.3.13+incompatible h1:8F3hqu9fGYLBifCmRCJsicFqDx/D68Rt3q1JMazcgBQ= +github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-semver v0.2.0 h1:3Jm3tLmsgAYcjC+4Up7hJrFBPr+n7rAqYeSw/SZazuY= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -38,6 +39,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/dgryski/go-farm v0.0.0-20190104051053-3adb47b1fb0f h1:dDxpBYafY/GYpcl+LS4Bn3ziLPuEdGRkRjYAbSlWxSA= +github.com/dgryski/go-farm v0.0.0-20190104051053-3adb47b1fb0f/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v0.0.0-20180421182945-02af3965c54e/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= @@ -68,14 +71,16 @@ github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfb github.com/golang/protobuf v0.0.0-20180814211427-aa810b61a9c7/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180124185431-e89373fe6b4a/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c h1:964Od4U6p2jUkFxvCydnIczKteheJEzHRToSGK3Bnlw= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/shlex v0.0.0-20181106134648-c34317bd91bf/go.mod h1:RpwtwJQFrIEPstU94h88MWPXP2ektJZ8cZ0YntAmXiE= -github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= @@ -95,9 +100,13 @@ github.com/grpc-ecosystem/grpc-gateway v1.5.1 h1:3scN4iuXkNOyP98jF55Lv8a9j1o/Iwv github.com/grpc-ecosystem/grpc-gateway v1.5.1/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6 h1:UDMh68UUwekSh5iP2OMhRRZJiiBccgV7axzUG8vi56c= +github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/json-iterator/go v1.1.6 h1:MrUvLMLTMxbqFJ9kzlvat/rYZqZnW3u4wkLzWTaFwKs= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/juju/ratelimit v1.0.1 h1:+7AIFJVQ0EQgq/K9+0Krm7m530Du7tIz0METWzN0RgY= github.com/juju/ratelimit v1.0.1/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -111,6 +120,8 @@ github.com/kr/pty v1.0.0/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lonng/pprof v0.0.0-20191012154247-04dfd648ce8d h1:6Ike9EBxOFsCMMih14rQJmb7WPWdgRu4C0OLl6oRHwE= +github.com/lonng/pprof v0.0.0-20191012154247-04dfd648ce8d/go.mod h1:0vjxLpmyJvBwQbIQuxhHxmogQFpJvB9doVHvxFFfyoY= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= @@ -118,23 +129,23 @@ github.com/mattn/go-shellwords v1.0.3/go.mod h1:3xCvwCdWdlDJUrvuMn7Wuy9eWs4pE8vq github.com/matttproud/golang_protobuf_extensions v1.0.0/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/montanaflynn/stats v0.0.0-20151014174947-eeaced052adb/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/montanaflynn/stats v0.0.0-20180911141734-db72e6cae808 h1:pmpDGKLw4n82EtrNiLqB+xSz/JQwFOaZuMALYUHwX5s= github.com/montanaflynn/stats v0.0.0-20180911141734-db72e6cae808/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= -github.com/myesui/uuid v1.0.0 h1:xCBmH4l5KuvLYc5L7AS7SZg9/jKdIFubM7OVoLqaQUI= -github.com/myesui/uuid v1.0.0/go.mod h1:2CDfNgU0LR8mIdO8vdWd8i9gWWxLlcoIGGpSNgafq84= github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7 h1:7KAv7KMGTTqSmYZtNdcNTgsos+vFzULLwyElndwn+5c= github.com/ngaut/pools v0.0.0-20180318154953-b7bc8c42aac7/go.mod h1:iWMfgwqYW+e8n5lC/jjNEhwcjbRDpl5NT7n2h+4UNcI= github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef h1:K0Fn+DoFqNqktdZtdV3bPQ/0cuYh2H4rkg0tytX/07k= github.com/ngaut/sync2 v0.0.0-20141008032647-7a24ed77b2ef/go.mod h1:7WjlapSfwQyo6LNmIvEWzsW1hbBQfpUO4JWnuQRmva8= github.com/nicksnyder/go-i18n v1.10.0/go.mod h1:HrK7VCrbOvQoUAQ7Vpy7i87N7JZZZ7R2xBGjv0j365Q= github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= +github.com/onsi/ginkgo v1.6.0 h1:Ix8l273rp3QzYgXSR+c8d1fTG7UPgYkOSELPhiY/YGw= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v1.4.2 h1:3mYCb7aPxS/RU7TI1y4rkEn1oKmPRjNJLNEXgw7MH2I= github.com/onsi/gomega v1.4.2/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= -github.com/onsi/gomega v1.4.3 h1:RE1xgDvH7imwFD45h+u2SgIfERHlS2yNG4DObb5BSKU= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/opentracing/basictracer-go v1.0.0 h1:YyUAhaEfjoWXclZVJ9sGoNct7j4TVk7lZWlQw5UXuoo= github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg= @@ -144,30 +155,26 @@ github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8 h1:USx2/E1bX46VG32FI github.com/pingcap/check v0.0.0-20190102082844-67f458068fc8/go.mod h1:B1+S9LNcuMyLH/4HMTViQOJevkGiik3wW2AN9zb2fNQ= github.com/pingcap/errcode v0.0.0-20180921232412-a1a7271709d9 h1:KH4f4Si9XK6/IW50HtoaiLIFHGkapOM6w83za47UYik= github.com/pingcap/errcode v0.0.0-20180921232412-a1a7271709d9/go.mod h1:4b2X8xSqxIroj/IZ9MX/VGZhAwc11wB9wRIzHvz6SeM= -github.com/pingcap/errors v0.10.1/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= -github.com/pingcap/failpoint v0.0.0-20190422094118-d8535965f59b h1:gFQVlQbthX4C2WKV/zqGBF3bZFr7oceKK9jGOVNkfws= -github.com/pingcap/failpoint v0.0.0-20190422094118-d8535965f59b/go.mod h1:fdAkVXuIXHAPZ7a280nj9bRORfK9NuSsOguvBH0+W6c= -github.com/pingcap/gofail v0.0.0-20181217135706-6a951c1e42c3 h1:04yuCf5NMvLU8rB2m4Qs3rynH7EYpMno3lHkewIOdMo= -github.com/pingcap/gofail v0.0.0-20181217135706-6a951c1e42c3/go.mod h1:DazNTg0PTldtpsQiT9I5tVJwV1onHMKBBgXzmJUlMns= +github.com/pingcap/failpoint v0.0.0-20190512135322-30cc7431d99c h1:hvQd3aOLKLF7xvRV6DzvPkKY4QXzfVbjU1BhW0d9yL8= +github.com/pingcap/failpoint v0.0.0-20190512135322-30cc7431d99c/go.mod h1:DNS3Qg7bEDhU6EXNHF+XSv/PGznQaMJ5FWvctpm6pQI= github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e h1:P73/4dPCL96rGrobssy1nVy2VaVpNCuLpCbr+FEaTA8= github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e/go.mod h1:O17XtbryoCJhkKGbT62+L2OlrniwqiGLSqrmdHCMzZw= -github.com/pingcap/kvproto v0.0.0-20190327032727-3d8cb3a30d5d/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY= -github.com/pingcap/kvproto v0.0.0-20190517030054-ff2e03f6fdfe h1:hs1Y4RTsPg0DOEGanGxaXG/2iqewWNY6/GVLkFnZMaU= -github.com/pingcap/kvproto v0.0.0-20190517030054-ff2e03f6fdfe/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY= -github.com/pingcap/log v0.0.0-20190214045112-b37da76f67a7/go.mod h1:xsfkWVaFVV5B8e1K9seWfyJWFrIhbtUTAD8NV1Pq3+w= -github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596 h1:t2OQTpPJnrPDGlvA+3FwJptMTt6MEPdzK1Wt99oaefQ= -github.com/pingcap/log v0.0.0-20190307075452-bd41d9273596/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw= -github.com/pingcap/parser v0.0.0-20190522123204-4628ac31ee47 h1:dEGqJ89QHTj1bfSbV9e8A9TtOiwaTEVkrR7VQLa7mTQ= -github.com/pingcap/parser v0.0.0-20190522123204-4628ac31ee47/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= -github.com/pingcap/pd v0.0.0-20190424024702-bd1e2496a669 h1:ZoKjndm/Ig7Ru/wojrQkc/YLUttUdQXoH77gtuWCvL4= -github.com/pingcap/pd v0.0.0-20190424024702-bd1e2496a669/go.mod h1:MUCxRzOkYiWZtlyi4MhxjCIj9PgQQ/j+BLNGm7aUsnM= +github.com/pingcap/kvproto v0.0.0-20190516013202-4cf58ad90b6c/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY= +github.com/pingcap/kvproto v0.0.0-20190918085321-44e3817e1f18 h1:5vQV8S/8B9nE+I+0Me6vZGyASeXl/QymwqtaOL5e5ZA= +github.com/pingcap/kvproto v0.0.0-20190918085321-44e3817e1f18/go.mod h1:QMdbTAXCHzzygQzqcG9uVUgU2fKeSN1GmfMiykdSzzY= +github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd h1:hWDol43WY5PGhsh3+8794bFHY1bPrmu6bTalpssCrGg= +github.com/pingcap/log v0.0.0-20190715063458-479153f07ebd/go.mod h1:WpHUKhNZ18v116SvGrmjkA9CBhYmuUTKL+p8JC9ANEw= +github.com/pingcap/parser v0.0.0-20191018040038-555b97093a2a h1:PMjYrxWKdVUlJ77+9YHbYVciDQCyqZ/noS9nIni76KQ= +github.com/pingcap/parser v0.0.0-20191018040038-555b97093a2a/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= +github.com/pingcap/pd v1.1.0-beta.0.20190912093418-dc03c839debd h1:bKj6hodu/ro78B0oN2yicdGn0t4yd9XjnyoW95qmWic= +github.com/pingcap/pd v1.1.0-beta.0.20190912093418-dc03c839debd/go.mod h1:I7TEby5BHTYIxgHszfsOJSBsk8b2Qt8QrSIgdv5n5QQ= github.com/pingcap/tidb-tools v2.1.3-0.20190321065848-1e8b48f5c168+incompatible h1:MkWCxgZpJBgY2f4HtwWMMFzSBb3+JPzeJgF3VrXE/bU= github.com/pingcap/tidb-tools v2.1.3-0.20190321065848-1e8b48f5c168+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM= -github.com/pingcap/tipb v0.0.0-20190428032612-535e1abaa330 h1:rRMLMjIMFulCX9sGKZ1hoov/iROMsKyC8Snc02nSukw= -github.com/pingcap/tipb v0.0.0-20190428032612-535e1abaa330/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= +github.com/pingcap/tipb v0.0.0-20191031111650-d14196d52154 h1:ZA7eV+GvatF/yQwttU++wK/61LEa4YYGBCO6DSfBgEM= +github.com/pingcap/tipb v0.0.0-20191031111650-d14196d52154/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -186,8 +193,10 @@ github.com/prometheus/common v0.0.0-20181020173914-7e9e6cabbd39/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180612222113-7d6f385de8be/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFdaDqxJVlbOQ1DtGmZWs/Qau0hIlk+WQ= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= -github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446 h1:/NRJ5vAYoqz+7sG51ubIDHXeWO8DlTSrToPu6q11ziA= -github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446/go.mod h1:uYEyJGbgTkfkS4+E/PavXkNJcbFIpEtjt2B0KDQ5+9M= +github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7 h1:FUL3b97ZY2EPqg2NbXKuMHs5pXJB9hjj1fDHnF2vl28= +github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44 h1:tB9NOR21++IjLyVx3/PCPhWMwqGNCMQEH96A6dMZ/gc= +github.com/sergi/go-diff v1.0.1-0.20180205163309-da645544ed44/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs= github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shurcooL/httpfs v0.0.0-20171119174359-809beceb2371 h1:SWV2fHctRpRrp49VXJ6UZja7gU9QLHwRpIPBN89SKEo= @@ -218,8 +227,6 @@ github.com/tiancaiamao/appdash v0.0.0-20181126055449-889f96f722a2/go.mod h1:2PfK github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6 h1:lYIiVDtZnyTWlNwiAxLj0bbpTcx1BWCFhXjfsvmPdNc= github.com/tmc/grpc-websocket-proxy v0.0.0-20171017195756-830351dc03c6/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= -github.com/twinj/uuid v1.0.0 h1:fzz7COZnDrXGTAOHGuUGYd6sG+JMq+AoE7+Jlu0przk= -github.com/twinj/uuid v1.0.0/go.mod h1:mMgcE1RHFUFqe5AfiwlINXisXfDGro23fWdPUfOMjRY= github.com/uber-go/atomic v1.3.2 h1:Azu9lPBWRNKzYXSIwRfgRuDuS0YKsK4NFhiQv98gkxo= github.com/uber-go/atomic v1.3.2/go.mod h1:/Ct5t2lcmbJ4OSe/waGBoaVvVqtO0bmtfVNex1PFV8g= github.com/uber/jaeger-client-go v2.15.0+incompatible h1:NP3qsSqNxh8VYr956ur1N/1C1PjvOJnJykCzcD5QHbk= @@ -238,8 +245,9 @@ github.com/urfave/negroni v0.3.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKn github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2 h1:eY9dn8+vbi4tKz5Qo6v2eYzo7kUS51QINcR5jNpbZS8= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/yookoala/realpath v1.0.0/go.mod h1:gJJMA9wuX7AcqLy1+ffPatSCySA1FQ2S8Ya9AIoYBpE= -go.etcd.io/bbolt v1.3.2 h1:Z/90sZLPOeCy2PwprqkFa25PdkusRzaj9P8zm/KNyvk= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= +go.etcd.io/bbolt v1.3.3 h1:MUGmc65QhB3pIlaQ5bB4LwqSj6GIonVJXpZiaKNyaKk= +go.etcd.io/bbolt v1.3.3/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/etcd v0.0.0-20190320044326-77d4b742cdbf h1:rmttwKPEgG/l4UscTDYtaJgeUsedKPKSyFfNQLI6q+I= go.etcd.io/etcd v0.0.0-20190320044326-77d4b742cdbf/go.mod h1:KSGwdbiFchh5KIC9My2+ZVl5/3ANcwohw50dpPwa2cw= go.uber.org/atomic v1.3.2 h1:2Oa65PReHzfn29GpvgsYwloV9AVFHPDk8tYxt2c2tr4= @@ -251,28 +259,35 @@ go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20180608092829-8ac0e0d97ce4/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793 h1:u+LnwYTOOW7Ukr/fppxEb1Nwz0AtPflrblfvUudpo+I= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190909091759-094676da4a83 h1:mgAKeshyNqWKdENOnQsg+8dRTwZFIwFaO3HNl52sweA= +golang.org/x/crypto v0.0.0-20190909091759-094676da4a83/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181005035420-146acd28ed58/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190909003024-a7b16738d86b h1:XfVGCX+0T4WOStkaOsJRllbsiImhB2jgVBGc9L0lPGc= +golang.org/x/net v0.0.0-20190909003024-a7b16738d86b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190109145017-48ac38b7c8cb h1:1w588/yEchbPNpa9sEvOcMZYbWHedwJjg4VOAdDHWHk= -golang.org/x/sys v0.0.0-20190109145017-48ac38b7c8cb/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191010194322-b09406accb47 h1:/XfQ9z7ib8eEJX2hdgFTZJ/ntt0swNk5oYBziWeTCvY= +golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2 h1:+DCIGbF/swA92ohVg0//6X2IVY3KZs6p9mix0ziNYJM= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= -golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190130214255-bb1329dc71a0 h1:iRpjPej1fPzmfoBhMFkp3HdqzF+ytPmAwiQhJGV0zGw= golang.org/x/tools v0.0.0-20190130214255-bb1329dc71a0/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/infoschema/builder.go b/infoschema/builder.go index 7b90e1f23cd12..72d0d9ce29e72 100644 --- a/infoschema/builder.go +++ b/infoschema/builder.go @@ -154,22 +154,22 @@ func (b *Builder) applyDropSchema(schemaID int64) []int64 { delete(b.is.schemaMap, di.Name.L) // Copy the sortedTables that contain the table we are going to drop. + tableIDs := make([]int64, 0, len(di.Tables)) bucketIdxMap := make(map[int]struct{}) for _, tbl := range di.Tables { bucketIdxMap[tableBucketIdx(tbl.ID)] = struct{}{} + // TODO: If the table ID doesn't exist. + tableIDs = append(tableIDs, tbl.ID) } for bucketIdx := range bucketIdxMap { b.copySortedTablesBucket(bucketIdx) } - ids := make([]int64, 0, len(di.Tables)) di = di.Clone() - for _, tbl := range di.Tables { - b.applyDropTable(di, tbl.ID) - // TODO: If the table ID doesn't exist. - ids = append(ids, tbl.ID) + for _, id := range tableIDs { + b.applyDropTable(di, id) } - return ids + return tableIDs } func (b *Builder) copySortedTablesBucket(bucketIdx int) { diff --git a/infoschema/infoschema.go b/infoschema/infoschema.go index ed455b3dacfcb..9da6e2ee72dd0 100644 --- a/infoschema/infoschema.go +++ b/infoschema/infoschema.go @@ -292,6 +292,11 @@ func (h *Handle) Get() InfoSchema { return schema } +// IsValid uses to check whether handle value is valid. +func (h *Handle) IsValid() bool { + return h.value.Load() != nil +} + // EmptyClone creates a new Handle with the same store and memSchema, but the value is not set. func (h *Handle) EmptyClone() *Handle { newHandle := &Handle{ @@ -383,3 +388,13 @@ func IsMemoryDB(dbName string) bool { } return false } + +// HasAutoIncrementColumn checks whether the table has auto_increment columns, if so, return true and the column name. +func HasAutoIncrementColumn(tbInfo *model.TableInfo) (bool, string) { + for _, col := range tbInfo.Columns { + if mysql.HasAutoIncrementFlag(col.Flag) { + return true, col.Name.L + } + } + return false, "" +} diff --git a/infoschema/perfschema/const.go b/infoschema/perfschema/const.go index db8f9af0f5e51..b72226e954d36 100644 --- a/infoschema/perfschema/const.go +++ b/infoschema/perfschema/const.go @@ -36,6 +36,13 @@ var perfSchemaTables = []string{ tableStagesCurrent, tableStagesHistory, tableStagesHistoryLong, + tableEventsStatementsSummaryByDigest, + tableTiDBProfileCPU, + tableTiDBProfileMemory, + tableTiDBProfileMutex, + tableTiDBAllocsProfile, + tableTiDBProfileBlock, + tableTiDBProfileGoroutines, } // tableGlobalStatus contains the column name definitions for table global_status, same as MySQL. @@ -371,3 +378,71 @@ const tableStagesHistoryLong = "CREATE TABLE if not exists performance_schema.ev "WORK_ESTIMATED BIGINT(20) UNSIGNED," + "NESTING_EVENT_ID BIGINT(20) UNSIGNED," + "NESTING_EVENT_TYPE ENUM('TRANSACTION','STATEMENT','STAGE'));" + +// tableEventsStatementsSummaryByDigest contains the column name definitions for table +// events_statements_summary_by_digest, same as MySQL. +const tableEventsStatementsSummaryByDigest = "CREATE TABLE if not exists events_statements_summary_by_digest (" + + "SCHEMA_NAME VARCHAR(64) DEFAULT NULL," + + "DIGEST VARCHAR(64) DEFAULT NULL," + + "DIGEST_TEXT LONGTEXT DEFAULT NULL," + + "EXEC_COUNT BIGINT(20) UNSIGNED NOT NULL," + + "SUM_LATENCY BIGINT(20) UNSIGNED NOT NULL," + + "MAX_LATENCY BIGINT(20) UNSIGNED NOT NULL," + + "MIN_LATENCY BIGINT(20) UNSIGNED NOT NULL," + + "AVG_LATENCY BIGINT(20) UNSIGNED NOT NULL," + + "SUM_ROWS_AFFECTED BIGINT(20) UNSIGNED NOT NULL," + + "FIRST_SEEN TIMESTAMP(6) NOT NULL," + + "LAST_SEEN TIMESTAMP(6) NOT NULL," + + "QUERY_SAMPLE_TEXT LONGTEXT DEFAULT NULL);" + +// tableTiDBProfileCPU contains the columns name definitions for table events_cpu_profile_graph +const tableTiDBProfileCPU = "CREATE TABLE IF NOT EXISTS " + tableNameTiDBProfileCPU + " (" + + "FUNCTION VARCHAR(512) NOT NULL," + + "PERCENT_ABS VARCHAR(8) NOT NULL," + + "PERCENT_REL VARCHAR(8) NOT NULL," + + "ROOT_CHILD INT(8) NOT NULL," + + "DEPTH INT(8) NOT NULL," + + "FILE VARCHAR(512) NOT NULL);" + +// tableTiDBProfileMemory contains the columns name definitions for table events_memory_profile_graph +const tableTiDBProfileMemory = "CREATE TABLE IF NOT EXISTS " + tableNameTiDBProfileMemory + " (" + + "FUNCTION VARCHAR(512) NOT NULL," + + "PERCENT_ABS VARCHAR(8) NOT NULL," + + "PERCENT_REL VARCHAR(8) NOT NULL," + + "ROOT_CHILD INT(8) NOT NULL," + + "DEPTH INT(8) NOT NULL," + + "FILE VARCHAR(512) NOT NULL);" + +// tableTiDBProfileMutex contains the columns name definitions for table events_mutex_profile_graph +const tableTiDBProfileMutex = "CREATE TABLE IF NOT EXISTS " + tableNameTiDBProfileMutex + " (" + + "FUNCTION VARCHAR(512) NOT NULL," + + "PERCENT_ABS VARCHAR(8) NOT NULL," + + "PERCENT_REL VARCHAR(8) NOT NULL," + + "ROOT_CHILD INT(8) NOT NULL," + + "DEPTH INT(8) NOT NULL," + + "FILE VARCHAR(512) NOT NULL);" + +// tableTiDBAllocsProfile contains the columns name definitions for table events_allocs_profile_graph +const tableTiDBAllocsProfile = "CREATE TABLE IF NOT EXISTS " + tableNameTiDBProfileAllocs + " (" + + "FUNCTION VARCHAR(512) NOT NULL," + + "PERCENT_ABS VARCHAR(8) NOT NULL," + + "PERCENT_REL VARCHAR(8) NOT NULL," + + "ROOT_CHILD INT(8) NOT NULL," + + "DEPTH INT(8) NOT NULL," + + "FILE VARCHAR(512) NOT NULL);" + +// tableTiDBProfileBlock contains the columns name definitions for table events_block_profile_graph +const tableTiDBProfileBlock = "CREATE TABLE IF NOT EXISTS " + tableNameTiDBProfileBlock + " (" + + "FUNCTION VARCHAR(512) NOT NULL," + + "PERCENT_ABS VARCHAR(8) NOT NULL," + + "PERCENT_REL VARCHAR(8) NOT NULL," + + "ROOT_CHILD INT(8) NOT NULL," + + "DEPTH INT(8) NOT NULL," + + "FILE VARCHAR(512) NOT NULL);" + +// tableTiDBProfileGoroutines contains the columns name definitions for table events_goroutine +const tableTiDBProfileGoroutines = "CREATE TABLE IF NOT EXISTS " + tableNameTiDBProfileGoroutines + " (" + + "FUNCTION VARCHAR(512) NOT NULL," + + "ID INT(8) NOT NULL," + + "STATE VARCHAR(16) NOT NULL," + + "LOCATION VARCHAR(512));" diff --git a/infoschema/perfschema/tables.go b/infoschema/perfschema/tables.go index dbe3e68155659..aa3229f731365 100644 --- a/infoschema/perfschema/tables.go +++ b/infoschema/perfschema/tables.go @@ -16,8 +16,23 @@ package perfschema import ( "github.com/pingcap/parser/model" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/profile" + "github.com/pingcap/tidb/util/stmtsummary" +) + +const ( + tableNameEventsStatementsSummaryByDigest = "events_statements_summary_by_digest" + tableNameTiDBProfileCPU = "tidb_profile_cpu" + tableNameTiDBProfileMemory = "tidb_profile_memory" + tableNameTiDBProfileMutex = "tidb_profile_mutex" + tableNameTiDBProfileAllocs = "tidb_profile_allocs" + tableNameTiDBProfileBlock = "tidb_profile_block" + tableNameTiDBProfileGoroutines = "tidb_profile_goroutines" ) // perfSchemaTable stands for the fake table all its data is in the memory. @@ -77,3 +92,59 @@ func (vt *perfSchemaTable) GetPhysicalID() int64 { func (vt *perfSchemaTable) Meta() *model.TableInfo { return vt.meta } + +func (vt *perfSchemaTable) getRows(ctx sessionctx.Context, cols []*table.Column) (fullRows [][]types.Datum, err error) { + switch vt.meta.Name.O { + case tableNameEventsStatementsSummaryByDigest: + fullRows = stmtsummary.StmtSummaryByDigestMap.ToDatum() + case tableNameTiDBProfileCPU: + fullRows, err = (&profile.Collector{}).ProfileGraph("cpu") + case tableNameTiDBProfileMemory: + fullRows, err = (&profile.Collector{}).ProfileGraph("heap") + case tableNameTiDBProfileMutex: + fullRows, err = (&profile.Collector{}).ProfileGraph("mutex") + case tableNameTiDBProfileAllocs: + fullRows, err = (&profile.Collector{}).ProfileGraph("allocs") + case tableNameTiDBProfileBlock: + fullRows, err = (&profile.Collector{}).ProfileGraph("block") + case tableNameTiDBProfileGoroutines: + fullRows, err = (&profile.Collector{}).Goroutines() + } + if err != nil { + return + } + if len(cols) == len(vt.cols) { + return + } + rows := make([][]types.Datum, len(fullRows)) + for i, fullRow := range fullRows { + row := make([]types.Datum, len(cols)) + for j, col := range cols { + row[j] = fullRow[col.Offset] + } + rows[i] = row + } + return rows, nil +} + +// IterRecords implements table.Table IterRecords interface. +func (vt *perfSchemaTable) IterRecords(ctx sessionctx.Context, startKey kv.Key, cols []*table.Column, + fn table.RecordIterFunc) error { + if len(startKey) != 0 { + return table.ErrUnsupportedOp + } + rows, err := vt.getRows(ctx, cols) + if err != nil { + return err + } + for i, row := range rows { + more, err := fn(int64(i), row, cols) + if err != nil { + return err + } + if !more { + break + } + } + return nil +} diff --git a/infoschema/perfschema/tables_test.go b/infoschema/perfschema/tables_test.go index 4349e4e77ca2c..b344a1fad9ec3 100644 --- a/infoschema/perfschema/tables_test.go +++ b/infoschema/perfschema/tables_test.go @@ -17,6 +17,8 @@ import ( "testing" . "github.com/pingcap/check" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/util/testkit" @@ -28,22 +30,32 @@ func TestT(t *testing.T) { TestingT(t) } -var _ = Suite(&testSuite{}) +var _ = Suite(&testTableSuite{}) -type testSuite struct { +type testTableSuite struct { + store kv.Storage + dom *domain.Domain } -func (s *testSuite) TestPerfSchemaTables(c *C) { +func (s *testTableSuite) SetUpSuite(c *C) { testleak.BeforeTest() - defer testleak.AfterTest(c)() - store, err := mockstore.NewMockTikvStore() + + var err error + s.store, err = mockstore.NewMockTikvStore() c.Assert(err, IsNil) - defer store.Close() - do, err := session.BootstrapSession(store) + session.DisableStats4Test() + s.dom, err = session.BootstrapSession(s.store) c.Assert(err, IsNil) - defer do.Close() +} - tk := testkit.NewTestKit(c, store) +func (s *testTableSuite) TearDownSuite(c *C) { + defer testleak.AfterTest(c)() + s.dom.Close() + s.store.Close() +} + +func (s *testTableSuite) TestPerfSchemaTables(c *C) { + tk := testkit.NewTestKit(c, s.store) tk.MustExec("use performance_schema") tk.MustQuery("select * from global_status where variable_name = 'Ssl_verify_mode'").Check(testkit.Rows()) @@ -51,3 +63,95 @@ func (s *testSuite) TestPerfSchemaTables(c *C) { tk.MustQuery("select * from setup_actors").Check(testkit.Rows()) tk.MustQuery("select * from events_stages_history_long").Check(testkit.Rows()) } + +// Test events_statements_summary_by_digest +func (s *testTableSuite) TestStmtSummaryTable(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b varchar(10))") + + // Statement summary is disabled by default + tk.MustQuery("select @@global.tidb_enable_stmt_summary").Check(testkit.Rows("0")) + tk.MustExec("insert into t values(1, 'a')") + tk.MustQuery("select * from performance_schema.events_statements_summary_by_digest").Check(testkit.Rows()) + + tk.MustExec("set global tidb_enable_stmt_summary = 1") + tk.MustQuery("select @@global.tidb_enable_stmt_summary").Check(testkit.Rows("1")) + + // Invalidate the cache manually so that tidb_enable_stmt_summary works immediately. + s.dom.GetGlobalVarsCache().Disable() + + // Create a new session to test + tk = testkit.NewTestKitWithInit(c, s.store) + + // Test INSERT + tk.MustExec("insert into t values(1, 'a')") + tk.MustExec("insert into t values(2, 'b')") + tk.MustExec("insert into t VALUES(3, 'c')") + tk.MustExec("/**/insert into t values(4, 'd')") + tk.MustQuery(`select schema_name, exec_count, sum_rows_affected, query_sample_text + from performance_schema.events_statements_summary_by_digest + where digest_text like 'insert into t%'`, + ).Check(testkit.Rows("test 4 4 insert into t values(1, 'a')")) + + // Test SELECT + tk.MustQuery("select * from t where a=2") + tk.MustQuery(`select schema_name, exec_count, sum_rows_affected, query_sample_text + from performance_schema.events_statements_summary_by_digest + where digest_text like 'select * from t%'`, + ).Check(testkit.Rows("test 1 0 select * from t where a=2")) + + // select ... order by + tk.MustQuery(`select schema_name, exec_count, sum_rows_affected, query_sample_text + from performance_schema.events_statements_summary_by_digest + order by exec_count desc limit 1`, + ).Check(testkit.Rows("test 4 4 insert into t values(1, 'a')")) + + // Disable it again + tk.MustExec("set global tidb_enable_stmt_summary = false") + tk.MustQuery("select @@global.tidb_enable_stmt_summary").Check(testkit.Rows("0")) + + // Create a new session to test + tk = testkit.NewTestKitWithInit(c, s.store) + + // This statement shouldn't be summarized + tk.MustQuery("select * from t where a=2") + + // The table should be cleared + tk.MustQuery(`select schema_name, exec_count, sum_rows_affected, query_sample_text + from performance_schema.events_statements_summary_by_digest`, + ).Check(testkit.Rows()) + + // Enable it in session scope + tk.MustExec("set session tidb_enable_stmt_summary = on") + // It should work immediately + tk.MustQuery("select * from t where a=2") + tk.MustQuery(`select schema_name, exec_count, sum_rows_affected, query_sample_text + from performance_schema.events_statements_summary_by_digest + where digest_text like 'select * from t%'`, + ).Check(testkit.Rows("test 1 0 select * from t where a=2")) + + // Disable it in global scope + tk.MustExec("set global tidb_enable_stmt_summary = off") + + // Create a new session to test + tk = testkit.NewTestKitWithInit(c, s.store) + + tk.MustQuery("select * from t where a=2") + + // Statement summary is still enabled + tk.MustQuery(`select schema_name, exec_count, sum_rows_affected, query_sample_text + from performance_schema.events_statements_summary_by_digest + where digest_text like 'select * from t%'`, + ).Check(testkit.Rows("test 2 0 select * from t where a=2")) + + // Unset session variable + tk.MustExec("set session tidb_enable_stmt_summary = ''") + tk.MustQuery("select * from t where a=2") + + // Statement summary is disabled + tk.MustQuery(`select schema_name, exec_count, sum_rows_affected, query_sample_text + from performance_schema.events_statements_summary_by_digest`, + ).Check(testkit.Rows()) +} diff --git a/infoschema/slow_log.go b/infoschema/slow_log.go index 68a0053eb81d4..ada31aaaeb43e 100644 --- a/infoschema/slow_log.go +++ b/infoschema/slow_log.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/plancodec" "go.uber.org/zap" ) @@ -37,8 +38,22 @@ var slowQueryCols = []columnInfo{ {variable.SlowLogTimeStr, mysql.TypeTimestamp, 26, 0, nil, nil}, {variable.SlowLogTxnStartTSStr, mysql.TypeLonglong, 20, mysql.UnsignedFlag, nil, nil}, {variable.SlowLogUserStr, mysql.TypeVarchar, 64, 0, nil, nil}, + {variable.SlowLogHostStr, mysql.TypeVarchar, 64, 0, nil, nil}, {variable.SlowLogConnIDStr, mysql.TypeLonglong, 20, mysql.UnsignedFlag, nil, nil}, {variable.SlowLogQueryTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, + {variable.SlowLogParseTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, + {variable.SlowLogCompileTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, + {execdetails.PreWriteTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, + {execdetails.CommitTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, + {execdetails.GetCommitTSTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, + {execdetails.CommitBackoffTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, + {execdetails.BackoffTypesStr, mysql.TypeVarchar, 64, 0, nil, nil}, + {execdetails.ResolveLockTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, + {execdetails.LocalLatchWaitTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, + {execdetails.WriteKeysStr, mysql.TypeLonglong, 22, 0, nil, nil}, + {execdetails.WriteSizeStr, mysql.TypeLonglong, 22, 0, nil, nil}, + {execdetails.PrewriteRegionStr, mysql.TypeLonglong, 22, 0, nil, nil}, + {execdetails.TxnRetryStr, mysql.TypeLonglong, 22, 0, nil, nil}, {execdetails.ProcessTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, {execdetails.WaitTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, {execdetails.BackoffTimeStr, mysql.TypeDouble, 22, 0, nil, nil}, @@ -46,7 +61,7 @@ var slowQueryCols = []columnInfo{ {execdetails.TotalKeysStr, mysql.TypeLonglong, 20, mysql.UnsignedFlag, nil, nil}, {execdetails.ProcessKeysStr, mysql.TypeLonglong, 20, mysql.UnsignedFlag, nil, nil}, {variable.SlowLogDBStr, mysql.TypeVarchar, 64, 0, nil, nil}, - {variable.SlowLogIndexIDsStr, mysql.TypeVarchar, 100, 0, nil, nil}, + {variable.SlowLogIndexNamesStr, mysql.TypeVarchar, 100, 0, nil, nil}, {variable.SlowLogIsInternalStr, mysql.TypeTiny, 1, 0, nil, nil}, {variable.SlowLogDigestStr, mysql.TypeVarchar, 64, 0, nil, nil}, {variable.SlowLogStatsInfoStr, mysql.TypeVarchar, 512, 0, nil, nil}, @@ -59,7 +74,10 @@ var slowQueryCols = []columnInfo{ {variable.SlowLogCopWaitMax, mysql.TypeDouble, 22, 0, nil, nil}, {variable.SlowLogCopWaitAddr, mysql.TypeVarchar, 64, 0, nil, nil}, {variable.SlowLogMemMax, mysql.TypeLonglong, 20, 0, nil, nil}, - {variable.SlowLogQuerySQLStr, mysql.TypeVarchar, 4096, 0, nil, nil}, + {variable.SlowLogSucc, mysql.TypeTiny, 1, 0, nil, nil}, + {variable.SlowLogPlan, mysql.TypeLongBlob, types.UnspecifiedLength, 0, nil, nil}, + {variable.SlowLogPrevStmt, mysql.TypeLongBlob, types.UnspecifiedLength, 0, nil, nil}, + {variable.SlowLogQuerySQLStr, mysql.TypeLongBlob, types.UnspecifiedLength, 0, nil, nil}, } func dataForSlowLog(ctx sessionctx.Context) ([][]types.Datum, error) { @@ -111,15 +129,19 @@ func ParseSlowLog(tz *time.Location, reader *bufio.Reader) ([][]types.Datum, err // Parse slow log field. if strings.HasPrefix(line, variable.SlowLogRowPrefixStr) { line = line[len(variable.SlowLogRowPrefixStr):] - fieldValues := strings.Split(line, " ") - for i := 0; i < len(fieldValues)-1; i += 2 { - field := fieldValues[i] - if strings.HasSuffix(field, ":") { - field = field[:len(field)-1] - } - err = st.setFieldValue(tz, field, fieldValues[i+1]) - if err != nil { - return rows, err + if strings.HasPrefix(line, variable.SlowLogPrevStmtPrefix) { + st.prevStmt = line[len(variable.SlowLogPrevStmtPrefix):] + } else { + fieldValues := strings.Split(line, " ") + for i := 0; i < len(fieldValues)-1; i += 2 { + field := fieldValues[i] + if strings.HasSuffix(field, ":") { + field = field[:len(field)-1] + } + err = st.setFieldValue(tz, field, fieldValues[i+1]) + if err != nil { + return rows, err + } } } } else if strings.HasSuffix(line, variable.SlowLogSQLSuffixStr) { @@ -138,125 +160,148 @@ func ParseSlowLog(tz *time.Location, reader *bufio.Reader) ([][]types.Datum, err } func getOneLine(reader *bufio.Reader) ([]byte, error) { + var resByte []byte lineByte, isPrefix, err := reader.ReadLine() + if isPrefix { + // Need to read more data. + resByte = make([]byte, len(lineByte), len(lineByte)*2) + } else { + resByte = make([]byte, len(lineByte)) + } + // Use copy here to avoid shallow copy problem. + copy(resByte, lineByte) if err != nil { - return lineByte, err + return resByte, err } + var tempLine []byte for isPrefix { tempLine, isPrefix, err = reader.ReadLine() - lineByte = append(lineByte, tempLine...) + resByte = append(resByte, tempLine...) // Use the max value of max_allowed_packet to check the single line length. - if len(lineByte) > int(variable.MaxOfMaxAllowedPacket) { - return lineByte, errors.Errorf("single line length exceeds limit: %v", variable.MaxOfMaxAllowedPacket) + if len(resByte) > int(variable.MaxOfMaxAllowedPacket) { + return resByte, errors.Errorf("single line length exceeds limit: %v", variable.MaxOfMaxAllowedPacket) } if err != nil { - return lineByte, err + return resByte, err } } - return lineByte, err + return resByte, err } type slowQueryTuple struct { - time time.Time - txnStartTs uint64 - user string - connID uint64 - queryTime float64 - processTime float64 - waitTime float64 - backOffTime float64 - requestCount uint64 - totalKeys uint64 - processKeys uint64 - db string - indexIDs string - isInternal bool - digest string - statsInfo string - avgProcessTime float64 - p90ProcessTime float64 - maxProcessTime float64 - maxProcessAddress string - avgWaitTime float64 - p90WaitTime float64 - maxWaitTime float64 - maxWaitAddress string - memMax int64 - sql string + time time.Time + txnStartTs uint64 + user string + host string + connID uint64 + queryTime float64 + parseTime float64 + compileTime float64 + preWriteTime float64 + commitTime float64 + getCommitTSTime float64 + commitBackoffTime float64 + backoffTypes string + resolveLockTime float64 + localLatchWaitTime float64 + writeKeys uint64 + writeSize uint64 + prewriteRegion uint64 + txnRetry uint64 + processTime float64 + waitTime float64 + backOffTime float64 + requestCount uint64 + totalKeys uint64 + processKeys uint64 + db string + indexIDs string + digest string + statsInfo string + avgProcessTime float64 + p90ProcessTime float64 + maxProcessTime float64 + maxProcessAddress string + avgWaitTime float64 + p90WaitTime float64 + maxWaitTime float64 + maxWaitAddress string + memMax int64 + prevStmt string + sql string + isInternal bool + succ bool + plan string } func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string) error { + var err error switch field { case variable.SlowLogTimeStr: - t, err := ParseTime(value) + st.time, err = ParseTime(value) if err != nil { - return err + break } - if t.Location() != tz { - t = t.In(tz) + if st.time.Location() != tz { + st.time = st.time.In(tz) } - st.time = t case variable.SlowLogTxnStartTSStr: - num, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return errors.AddStack(err) - } - st.txnStartTs = num + st.txnStartTs, err = strconv.ParseUint(value, 10, 64) case variable.SlowLogUserStr: - st.user = value - case variable.SlowLogConnIDStr: - num, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return errors.AddStack(err) + fields := strings.SplitN(value, "@", 2) + if len(field) > 0 { + st.user = fields[0] } - st.connID = num - case variable.SlowLogQueryTimeStr: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) + if len(field) > 1 { + st.host = fields[1] } - st.queryTime = num + case variable.SlowLogConnIDStr: + st.connID, err = strconv.ParseUint(value, 10, 64) + case variable.SlowLogQueryTimeStr: + st.queryTime, err = strconv.ParseFloat(value, 64) + case variable.SlowLogParseTimeStr: + st.parseTime, err = strconv.ParseFloat(value, 64) + case variable.SlowLogCompileTimeStr: + st.compileTime, err = strconv.ParseFloat(value, 64) + case execdetails.PreWriteTimeStr: + st.preWriteTime, err = strconv.ParseFloat(value, 64) + case execdetails.CommitTimeStr: + st.commitTime, err = strconv.ParseFloat(value, 64) + case execdetails.GetCommitTSTimeStr: + st.getCommitTSTime, err = strconv.ParseFloat(value, 64) + case execdetails.CommitBackoffTimeStr: + st.commitBackoffTime, err = strconv.ParseFloat(value, 64) + case execdetails.BackoffTypesStr: + st.backoffTypes = value + case execdetails.ResolveLockTimeStr: + st.resolveLockTime, err = strconv.ParseFloat(value, 64) + case execdetails.LocalLatchWaitTimeStr: + st.localLatchWaitTime, err = strconv.ParseFloat(value, 64) + case execdetails.WriteKeysStr: + st.writeKeys, err = strconv.ParseUint(value, 10, 64) + case execdetails.WriteSizeStr: + st.writeSize, err = strconv.ParseUint(value, 10, 64) + case execdetails.PrewriteRegionStr: + st.prewriteRegion, err = strconv.ParseUint(value, 10, 64) + case execdetails.TxnRetryStr: + st.txnRetry, err = strconv.ParseUint(value, 10, 64) case execdetails.ProcessTimeStr: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) - } - st.processTime = num + st.processTime, err = strconv.ParseFloat(value, 64) case execdetails.WaitTimeStr: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) - } - st.waitTime = num + st.waitTime, err = strconv.ParseFloat(value, 64) case execdetails.BackoffTimeStr: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) - } - st.backOffTime = num + st.backOffTime, err = strconv.ParseFloat(value, 64) case execdetails.RequestCountStr: - num, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return errors.AddStack(err) - } - st.requestCount = num + st.requestCount, err = strconv.ParseUint(value, 10, 64) case execdetails.TotalKeysStr: - num, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return errors.AddStack(err) - } - st.totalKeys = num + st.totalKeys, err = strconv.ParseUint(value, 10, 64) case execdetails.ProcessKeysStr: - num, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return errors.AddStack(err) - } - st.processKeys = num + st.processKeys, err = strconv.ParseUint(value, 10, 64) case variable.SlowLogDBStr: st.db = value - case variable.SlowLogIndexIDsStr: + case variable.SlowLogIndexNamesStr: st.indexIDs = value case variable.SlowLogIsInternalStr: st.isInternal = value == "true" @@ -265,54 +310,33 @@ func (st *slowQueryTuple) setFieldValue(tz *time.Location, field, value string) case variable.SlowLogStatsInfoStr: st.statsInfo = value case variable.SlowLogCopProcAvg: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) - } - st.avgProcessTime = num + st.avgProcessTime, err = strconv.ParseFloat(value, 64) case variable.SlowLogCopProcP90: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) - } - st.p90ProcessTime = num + st.p90ProcessTime, err = strconv.ParseFloat(value, 64) case variable.SlowLogCopProcMax: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) - } - st.maxProcessTime = num + st.maxProcessTime, err = strconv.ParseFloat(value, 64) case variable.SlowLogCopProcAddr: st.maxProcessAddress = value case variable.SlowLogCopWaitAvg: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) - } - st.avgWaitTime = num + st.avgWaitTime, err = strconv.ParseFloat(value, 64) case variable.SlowLogCopWaitP90: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) - } - st.p90WaitTime = num + st.p90WaitTime, err = strconv.ParseFloat(value, 64) case variable.SlowLogCopWaitMax: - num, err := strconv.ParseFloat(value, 64) - if err != nil { - return errors.AddStack(err) - } - st.maxWaitTime = num + st.maxWaitTime, err = strconv.ParseFloat(value, 64) case variable.SlowLogCopWaitAddr: st.maxWaitAddress = value case variable.SlowLogMemMax: - num, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return errors.AddStack(err) - } - st.memMax = num + st.memMax, err = strconv.ParseInt(value, 10, 64) + case variable.SlowLogSucc: + st.succ, err = strconv.ParseBool(value) + case variable.SlowLogPlan: + st.plan = value case variable.SlowLogQuerySQLStr: st.sql = value } + if err != nil { + return errors.Wrap(err, "parse slow log failed `"+field+"` error") + } return nil } @@ -325,8 +349,22 @@ func (st *slowQueryTuple) convertToDatumRow() []types.Datum { })) record = append(record, types.NewUintDatum(st.txnStartTs)) record = append(record, types.NewStringDatum(st.user)) + record = append(record, types.NewStringDatum(st.host)) record = append(record, types.NewUintDatum(st.connID)) record = append(record, types.NewFloat64Datum(st.queryTime)) + record = append(record, types.NewFloat64Datum(st.parseTime)) + record = append(record, types.NewFloat64Datum(st.compileTime)) + record = append(record, types.NewFloat64Datum(st.preWriteTime)) + record = append(record, types.NewFloat64Datum(st.commitTime)) + record = append(record, types.NewFloat64Datum(st.getCommitTSTime)) + record = append(record, types.NewFloat64Datum(st.commitBackoffTime)) + record = append(record, types.NewStringDatum(st.backoffTypes)) + record = append(record, types.NewFloat64Datum(st.resolveLockTime)) + record = append(record, types.NewFloat64Datum(st.localLatchWaitTime)) + record = append(record, types.NewUintDatum(st.writeKeys)) + record = append(record, types.NewUintDatum(st.writeSize)) + record = append(record, types.NewUintDatum(st.prewriteRegion)) + record = append(record, types.NewUintDatum(st.txnRetry)) record = append(record, types.NewFloat64Datum(st.processTime)) record = append(record, types.NewFloat64Datum(st.waitTime)) record = append(record, types.NewFloat64Datum(st.backOffTime)) @@ -347,10 +385,31 @@ func (st *slowQueryTuple) convertToDatumRow() []types.Datum { record = append(record, types.NewFloat64Datum(st.maxWaitTime)) record = append(record, types.NewStringDatum(st.maxWaitAddress)) record = append(record, types.NewIntDatum(st.memMax)) + if st.succ { + record = append(record, types.NewIntDatum(1)) + } else { + record = append(record, types.NewIntDatum(0)) + } + record = append(record, types.NewStringDatum(parsePlan(st.plan))) + record = append(record, types.NewStringDatum(st.prevStmt)) record = append(record, types.NewStringDatum(st.sql)) return record } +func parsePlan(planString string) string { + if len(planString) <= len(variable.SlowLogPlanPrefix)+len(variable.SlowLogPlanSuffix) { + return planString + } + planString = planString[len(variable.SlowLogPlanPrefix) : len(planString)-len(variable.SlowLogPlanSuffix)] + decodePlanString, err := plancodec.DecodePlan(planString) + if err == nil { + planString = decodePlanString + } else { + logutil.Logger(context.Background()).Error("decode plan in slow log failed", zap.String("plan", planString), zap.Error(err)) + } + return planString +} + // ParseTime exports for testing. func ParseTime(s string) (time.Time, error) { t, err := time.Parse(logutil.SlowLogTimeFormat, s) diff --git a/infoschema/slow_log_test.go b/infoschema/slow_log_test.go index 5692f609c8a47..c151bd117521b 100644 --- a/infoschema/slow_log_test.go +++ b/infoschema/slow_log_test.go @@ -37,6 +37,8 @@ func (s *testSuite) TestParseSlowLogFile(c *C) { # Cop_proc_avg: 0.1 Cop_proc_p90: 0.2 Cop_proc_max: 0.03 Cop_proc_addr: 127.0.0.1:20160 # Cop_wait_avg: 0.05 Cop_wait_p90: 0.6 Cop_wait_max: 0.8 Cop_wait_addr: 0.0.0.0:20160 # Mem_max: 70724 +# Succ: false +# Prev_stmt: update t set i = 1; select * from t;`) reader := bufio.NewReader(slowLog) loc, err := time.LoadLocation("Asia/Shanghai") @@ -53,7 +55,7 @@ select * from t;`) } recordString += str } - expectRecordString := "2019-04-28 15:24:04.309074,405888132465033227,,0,0.216905,0.021,0,0,1,637,0,,,1,42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772,t1:1,t2:2,0.1,0.2,0.03,127.0.0.1:20160,0.05,0.6,0.8,0.0.0.0:20160,70724,select * from t;" + expectRecordString := "2019-04-28 15:24:04.309074,405888132465033227,,,0,0.216905,0,0,0,0,0,0,,0,0,0,0,0,0,0.021,0,0,1,637,0,,,1,42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772,t1:1,t2:2,0.1,0.2,0.03,127.0.0.1:20160,0.05,0.6,0.8,0.0.0.0:20160,70724,0,,update t set i = 1;,select * from t;" c.Assert(expectRecordString, Equals, recordString) // fix sql contain '# ' bug @@ -67,6 +69,7 @@ select a# from t; # Is_internal: true # Digest: 42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772 # Stats: t1:1,t2:2 +# Succ: false select * from t; `) reader = bufio.NewReader(slowLog) @@ -110,6 +113,17 @@ select * from t; reader = bufio.NewReader(slowLog) _, err = infoschema.ParseSlowLog(loc, reader) c.Assert(err, IsNil) + + // Add parse error check. + slowLog = bytes.NewBufferString( + `# Time: 2019-04-28T15:24:04.309074+08:00 +# Succ: abc +select * from t; +`) + reader = bufio.NewReader(slowLog) + _, err = infoschema.ParseSlowLog(loc, reader) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, "parse slow log failed `Succ` error: strconv.ParseBool: parsing \"abc\": invalid syntax") } func (s *testSuite) TestSlowLogParseTime(c *C) { diff --git a/infoschema/tables.go b/infoschema/tables.go index 697a68a7d1997..71a2701577eb9 100644 --- a/infoschema/tables.go +++ b/infoschema/tables.go @@ -538,11 +538,12 @@ var tableProcesslistCols = []columnInfo{ {"ID", mysql.TypeLonglong, 21, mysql.NotNullFlag, 0, nil}, {"USER", mysql.TypeVarchar, 16, mysql.NotNullFlag, "", nil}, {"HOST", mysql.TypeVarchar, 64, mysql.NotNullFlag, "", nil}, - {"DB", mysql.TypeVarchar, 64, mysql.NotNullFlag, "", nil}, + {"DB", mysql.TypeVarchar, 64, 0, nil, nil}, {"COMMAND", mysql.TypeVarchar, 16, mysql.NotNullFlag, "", nil}, {"TIME", mysql.TypeLong, 7, mysql.NotNullFlag, 0, nil}, {"STATE", mysql.TypeVarchar, 7, 0, nil, nil}, {"INFO", mysql.TypeString, 512, 0, nil, nil}, + {"MEM", mysql.TypeLonglong, 21, 0, nil, nil}, } var tableTiDBIndexesCols = []columnInfo{ @@ -563,6 +564,7 @@ var tableTiDBHotRegionsCols = []columnInfo{ {"DB_NAME", mysql.TypeVarchar, 64, 0, nil, nil}, {"TABLE_NAME", mysql.TypeVarchar, 64, 0, nil, nil}, {"INDEX_NAME", mysql.TypeVarchar, 64, 0, nil, nil}, + {"REGION_ID", mysql.TypeLonglong, 21, 0, nil, nil}, {"TYPE", mysql.TypeVarchar, 64, 0, nil, nil}, {"MAX_HOT_DEGREE", mysql.TypeLonglong, 21, 0, nil, nil}, {"REGION_COUNT", mysql.TypeLonglong, 21, 0, nil, nil}, @@ -583,8 +585,8 @@ var tableTiKVStoreStatusCols = []columnInfo{ {"LEADER_SCORE", mysql.TypeLonglong, 21, 0, nil, nil}, {"LEADER_SIZE", mysql.TypeLonglong, 21, 0, nil, nil}, {"REGION_COUNT", mysql.TypeLonglong, 21, 0, nil, nil}, - {"REGION_WEIGHT", mysql.TypeLonglong, 21, 0, nil, nil}, - {"REGION_SCORE", mysql.TypeLonglong, 21, 0, nil, nil}, + {"REGION_WEIGHT", mysql.TypeDouble, 22, 0, nil, nil}, + {"REGION_SCORE", mysql.TypeDouble, 22, 0, nil, nil}, {"REGION_SIZE", mysql.TypeLonglong, 21, 0, nil, nil}, {"START_TS", mysql.TypeDatetime, 0, 0, nil, nil}, {"LAST_HEARTBEAT_TS", mysql.TypeDatetime, 0, 0, nil, nil}, @@ -685,18 +687,18 @@ func dataForTikVRegionPeers(ctx sessionctx.Context) (records [][]types.Datum, er row[0].SetInt64(regionStat.ID) row[1].SetInt64(peer.ID) row[2].SetInt64(peer.StoreID) - if peer.ID == regionStat.Leader.ID { + if peer.IsLearner { row[3].SetInt64(1) } else { row[3].SetInt64(0) } - if peer.IsLearner { + if peer.ID == regionStat.Leader.ID { row[4].SetInt64(1) } else { - row[4].SetInt64(0) + row[3].SetInt64(0) } if pendingPeerIDSet.Exist(peer.ID) { - row[5].SetString(pendingPeer) + row[4].SetString(pendingPeer) } else if downSec, ok := downPeerMap[peer.ID]; ok { row[5].SetString(downPeer) row[6].SetInt64(downSec) @@ -745,8 +747,8 @@ func dataForTiKVStoreStatus(ctx sessionctx.Context) (records [][]types.Datum, er row[10].SetInt64(storeStat.Status.LeaderScore) row[11].SetInt64(storeStat.Status.LeaderSize) row[12].SetInt64(storeStat.Status.RegionCount) - row[13].SetInt64(storeStat.Status.RegionWeight) - row[14].SetInt64(storeStat.Status.RegionScore) + row[13].SetFloat64(storeStat.Status.RegionWeight) + row[14].SetFloat64(storeStat.Status.RegionScore) row[15].SetInt64(storeStat.Status.RegionSize) startTs := types.Time{ Time: types.FromGoTime(storeStat.Status.StartTs), @@ -861,7 +863,7 @@ func dataForProcesslist(ctx sessionctx.Context) [][]types.Datum { continue } - rows := pi.ToRow(true) + rows := pi.ToRow() record := types.MakeDatums(rows...) records = append(records, record) } @@ -989,7 +991,7 @@ func getColLengthAllTables(ctx sessionctx.Context) (map[tableHistID]uint64, erro return colLengthMap, nil } -func getDataAndIndexLength(info *model.TableInfo, rowCount uint64, columnLengthMap map[tableHistID]uint64) (uint64, uint64) { +func getDataAndIndexLength(info *model.TableInfo, physicalID int64, rowCount uint64, columnLengthMap map[tableHistID]uint64) (uint64, uint64) { columnLength := make(map[string]uint64) for _, col := range info.Columns { if col.State != model.StatePublic { @@ -999,7 +1001,7 @@ func getDataAndIndexLength(info *model.TableInfo, rowCount uint64, columnLengthM if length != types.VarStorageLen { columnLength[col.Name.L] = rowCount * uint64(length) } else { - length := columnLengthMap[tableHistID{tableID: info.ID, histID: col.ID}] + length := columnLengthMap[tableHistID{tableID: physicalID, histID: col.ID}] columnLength[col.Name.L] = length } } @@ -1072,26 +1074,12 @@ func (c *statsCache) get(ctx sessionctx.Context) (map[int64]uint64, map[tableHis } func getAutoIncrementID(ctx sessionctx.Context, schema *model.DBInfo, tblInfo *model.TableInfo) (int64, error) { - hasAutoIncID := false - for _, col := range tblInfo.Cols() { - if mysql.HasAutoIncrementFlag(col.Flag) { - hasAutoIncID = true - break - } - } - autoIncID := tblInfo.AutoIncID - if hasAutoIncID { - is := ctx.GetSessionVars().TxnCtx.InfoSchema.(InfoSchema) - tbl, err := is.TableByName(schema.Name, tblInfo.Name) - if err != nil { - return 0, err - } - autoIncID, err = tbl.Allocator(ctx).NextGlobalAutoID(tblInfo.ID) - if err != nil { - return 0, err - } + is := ctx.GetSessionVars().TxnCtx.InfoSchema.(InfoSchema) + tbl, err := is.TableByName(schema.Name, tblInfo.Name) + if err != nil { + return 0, err } - return autoIncID, nil + return tbl.Allocator(ctx).Base() + 1, nil } func dataForViews(ctx sessionctx.Context, schemas []*model.DBInfo) ([][]types.Datum, error) { @@ -1162,12 +1150,27 @@ func dataForTables(ctx sessionctx.Context, schemas []*model.DBInfo) ([][]types.D if table.GetPartitionInfo() != nil { createOptions = "partitioned" } - autoIncID, err := getAutoIncrementID(ctx, schema, table) - if err != nil { - return nil, err + var autoIncID interface{} + hasAutoIncID, _ := HasAutoIncrementColumn(table) + if hasAutoIncID { + autoIncID, err = getAutoIncrementID(ctx, schema, table) + if err != nil { + return nil, err + } + } + + var rowCount, dataLength, indexLength uint64 + if table.GetPartitionInfo() == nil { + rowCount = tableRowsMap[table.ID] + dataLength, indexLength = getDataAndIndexLength(table, table.ID, rowCount, colLengthMap) + } else { + for _, pi := range table.GetPartitionInfo().Definitions { + rowCount += tableRowsMap[pi.ID] + parDataLen, parIndexLen := getDataAndIndexLength(table, pi.ID, tableRowsMap[pi.ID], colLengthMap) + dataLength += parDataLen + indexLength += parIndexLen + } } - rowCount := tableRowsMap[table.ID] - dataLength, indexLength := getDataAndIndexLength(table, rowCount, colLengthMap) avgRowLength := uint64(0) if rowCount != 0 { avgRowLength = dataLength / rowCount @@ -1663,9 +1666,9 @@ func dataForTiDBHotRegions(ctx sessionctx.Context) (records [][]types.Datum, err return records, nil } -func dataForHotRegionByMetrics(metrics map[helper.TblIndex]helper.RegionMetric, tp string) [][]types.Datum { +func dataForHotRegionByMetrics(metrics []helper.HotTableIndex, tp string) [][]types.Datum { rows := make([][]types.Datum, 0, len(metrics)) - for tblIndex, regionMetric := range metrics { + for _, tblIndex := range metrics { row := make([]types.Datum, len(tableTiDBHotRegionsCols)) if tblIndex.IndexName != "" { row[1].SetInt64(tblIndex.IndexID) @@ -1677,10 +1680,16 @@ func dataForHotRegionByMetrics(metrics map[helper.TblIndex]helper.RegionMetric, row[0].SetInt64(tblIndex.TableID) row[2].SetString(tblIndex.DbName) row[3].SetString(tblIndex.TableName) - row[5].SetString(tp) - row[6].SetInt64(int64(regionMetric.MaxHotDegree)) - row[7].SetInt64(int64(regionMetric.Count)) - row[8].SetUint64(regionMetric.FlowBytes) + row[5].SetUint64(tblIndex.RegionID) + row[6].SetString(tp) + if tblIndex.RegionMetric == nil { + row[7].SetNull() + row[8].SetNull() + } else { + row[7].SetInt64(int64(tblIndex.RegionMetric.MaxHotDegree)) + row[8].SetInt64(int64(tblIndex.RegionMetric.Count)) + } + row[9].SetUint64(tblIndex.RegionMetric.FlowBytes) rows = append(rows, row) } return rows @@ -1870,6 +1879,7 @@ func (it *infoschemaTable) getRows(ctx sessionctx.Context, cols []*table.Column) return rows, nil } +// IterRecords implements table.Table IterRecords interface. func (it *infoschemaTable) IterRecords(ctx sessionctx.Context, startKey kv.Key, cols []*table.Column, fn table.RecordIterFunc) error { if len(startKey) != 0 { @@ -1891,6 +1901,7 @@ func (it *infoschemaTable) IterRecords(ctx sessionctx.Context, startKey kv.Key, return nil } +// RowWithCols implements table.Table RowWithCols interface. func (it *infoschemaTable) RowWithCols(ctx sessionctx.Context, h int64, cols []*table.Column) ([]types.Datum, error) { return nil, table.ErrUnsupportedOp } @@ -1900,79 +1911,97 @@ func (it *infoschemaTable) Row(ctx sessionctx.Context, h int64) ([]types.Datum, return nil, table.ErrUnsupportedOp } +// Cols implements table.Table Cols interface. func (it *infoschemaTable) Cols() []*table.Column { return it.cols } +// WritableCols implements table.Table WritableCols interface. func (it *infoschemaTable) WritableCols() []*table.Column { return it.cols } +// Indices implements table.Table Indices interface. func (it *infoschemaTable) Indices() []table.Index { return nil } +// WritableIndices implements table.Table WritableIndices interface. func (it *infoschemaTable) WritableIndices() []table.Index { return nil } +// DeletableIndices implements table.Table DeletableIndices interface. func (it *infoschemaTable) DeletableIndices() []table.Index { return nil } +// RecordPrefix implements table.Table RecordPrefix interface. func (it *infoschemaTable) RecordPrefix() kv.Key { return nil } +// IndexPrefix implements table.Table IndexPrefix interface. func (it *infoschemaTable) IndexPrefix() kv.Key { return nil } +// FirstKey implements table.Table FirstKey interface. func (it *infoschemaTable) FirstKey() kv.Key { return nil } +// RecordKey implements table.Table RecordKey interface. func (it *infoschemaTable) RecordKey(h int64) kv.Key { return nil } +// AddRecord implements table.Table AddRecord interface. func (it *infoschemaTable) AddRecord(ctx sessionctx.Context, r []types.Datum, opts ...*table.AddRecordOpt) (recordID int64, err error) { return 0, table.ErrUnsupportedOp } +// RemoveRecord implements table.Table RemoveRecord interface. func (it *infoschemaTable) RemoveRecord(ctx sessionctx.Context, h int64, r []types.Datum) error { return table.ErrUnsupportedOp } +// UpdateRecord implements table.Table UpdateRecord interface. func (it *infoschemaTable) UpdateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datum, touched []bool) error { return table.ErrUnsupportedOp } -func (it *infoschemaTable) AllocAutoID(ctx sessionctx.Context) (int64, error) { +// AllocHandle implements table.Table AllocHandle interface. +func (it *infoschemaTable) AllocHandle(ctx sessionctx.Context) (int64, error) { return 0, table.ErrUnsupportedOp } +// Allocator implements table.Table Allocator interface. func (it *infoschemaTable) Allocator(ctx sessionctx.Context) autoid.Allocator { return nil } +// RebaseAutoID implements table.Table RebaseAutoID interface. func (it *infoschemaTable) RebaseAutoID(ctx sessionctx.Context, newBase int64, isSetStep bool) error { return table.ErrUnsupportedOp } +// Meta implements table.Table Meta interface. func (it *infoschemaTable) Meta() *model.TableInfo { return it.meta } +// GetPhysicalID implements table.Table GetPhysicalID interface. func (it *infoschemaTable) GetPhysicalID() int64 { return it.meta.ID } -// Seek is the first method called for table scan, we lazy initialize it here. +// Seek implements table.Table Seek interface. func (it *infoschemaTable) Seek(ctx sessionctx.Context, h int64) (int64, bool, error) { return 0, false, table.ErrUnsupportedOp } +// Type implements table.Table Type interface. func (it *infoschemaTable) Type() table.Type { return table.VirtualTable } @@ -1980,7 +2009,7 @@ func (it *infoschemaTable) Type() table.Type { // VirtualTable is a dummy table.Table implementation. type VirtualTable struct{} -// IterRecords implements table.Table Type interface. +// IterRecords implements table.Table IterRecords interface. func (vt *VirtualTable) IterRecords(ctx sessionctx.Context, startKey kv.Key, cols []*table.Column, fn table.RecordIterFunc) error { if len(startKey) != 0 { @@ -1989,92 +2018,92 @@ func (vt *VirtualTable) IterRecords(ctx sessionctx.Context, startKey kv.Key, col return nil } -// RowWithCols implements table.Table Type interface. +// RowWithCols implements table.Table RowWithCols interface. func (vt *VirtualTable) RowWithCols(ctx sessionctx.Context, h int64, cols []*table.Column) ([]types.Datum, error) { return nil, table.ErrUnsupportedOp } -// Row implements table.Table Type interface. +// Row implements table.Table Row interface. func (vt *VirtualTable) Row(ctx sessionctx.Context, h int64) ([]types.Datum, error) { return nil, table.ErrUnsupportedOp } -// Cols implements table.Table Type interface. +// Cols implements table.Table Cols interface. func (vt *VirtualTable) Cols() []*table.Column { return nil } -// WritableCols implements table.Table Type interface. +// WritableCols implements table.Table WritableCols interface. func (vt *VirtualTable) WritableCols() []*table.Column { return nil } -// Indices implements table.Table Type interface. +// Indices implements table.Table Indices interface. func (vt *VirtualTable) Indices() []table.Index { return nil } -// WritableIndices implements table.Table Type interface. +// WritableIndices implements table.Table WritableIndices interface. func (vt *VirtualTable) WritableIndices() []table.Index { return nil } -// DeletableIndices implements table.Table Type interface. +// DeletableIndices implements table.Table DeletableIndices interface. func (vt *VirtualTable) DeletableIndices() []table.Index { return nil } -// RecordPrefix implements table.Table Type interface. +// RecordPrefix implements table.Table RecordPrefix interface. func (vt *VirtualTable) RecordPrefix() kv.Key { return nil } -// IndexPrefix implements table.Table Type interface. +// IndexPrefix implements table.Table IndexPrefix interface. func (vt *VirtualTable) IndexPrefix() kv.Key { return nil } -// FirstKey implements table.Table Type interface. +// FirstKey implements table.Table FirstKey interface. func (vt *VirtualTable) FirstKey() kv.Key { return nil } -// RecordKey implements table.Table Type interface. +// RecordKey implements table.Table RecordKey interface. func (vt *VirtualTable) RecordKey(h int64) kv.Key { return nil } -// AddRecord implements table.Table Type interface. +// AddRecord implements table.Table AddRecord interface. func (vt *VirtualTable) AddRecord(ctx sessionctx.Context, r []types.Datum, opts ...*table.AddRecordOpt) (recordID int64, err error) { return 0, table.ErrUnsupportedOp } -// RemoveRecord implements table.Table Type interface. +// RemoveRecord implements table.Table RemoveRecord interface. func (vt *VirtualTable) RemoveRecord(ctx sessionctx.Context, h int64, r []types.Datum) error { return table.ErrUnsupportedOp } -// UpdateRecord implements table.Table Type interface. +// UpdateRecord implements table.Table UpdateRecord interface. func (vt *VirtualTable) UpdateRecord(ctx sessionctx.Context, h int64, oldData, newData []types.Datum, touched []bool) error { return table.ErrUnsupportedOp } -// AllocAutoID implements table.Table Type interface. -func (vt *VirtualTable) AllocAutoID(ctx sessionctx.Context) (int64, error) { +// AllocHandle implements table.Table AllocHandle interface. +func (vt *VirtualTable) AllocHandle(ctx sessionctx.Context) (int64, error) { return 0, table.ErrUnsupportedOp } -// Allocator implements table.Table Type interface. +// Allocator implements table.Table Allocator interface. func (vt *VirtualTable) Allocator(ctx sessionctx.Context) autoid.Allocator { return nil } -// RebaseAutoID implements table.Table Type interface. +// RebaseAutoID implements table.Table RebaseAutoID interface. func (vt *VirtualTable) RebaseAutoID(ctx sessionctx.Context, newBase int64, isSetStep bool) error { return table.ErrUnsupportedOp } -// Meta implements table.Table Type interface. +// Meta implements table.Table Meta interface. func (vt *VirtualTable) Meta() *model.TableInfo { return nil } @@ -2084,7 +2113,7 @@ func (vt *VirtualTable) GetPhysicalID() int64 { return 0 } -// Seek implements table.Table Type interface. +// Seek implements table.Table Seek interface. func (vt *VirtualTable) Seek(ctx sessionctx.Context, h int64) (int64, bool, error) { return 0, false, table.ErrUnsupportedOp } diff --git a/infoschema/tables_test.go b/infoschema/tables_test.go index 31020e99ed846..1b278b4b84bed 100644 --- a/infoschema/tables_test.go +++ b/infoschema/tables_test.go @@ -17,10 +17,13 @@ import ( "fmt" "os" "strconv" + "strings" . "github.com/pingcap/check" "github.com/pingcap/parser/auth" + "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" @@ -47,7 +50,7 @@ func (s *testTableSuite) SetUpSuite(c *C) { var err error s.store, err = mockstore.NewMockTikvStore() c.Assert(err, IsNil) - session.SetStatsLease(0) + session.DisableStats4Test() s.dom, err = session.BootstrapSession(s.store) c.Assert(err, IsNil) } @@ -83,7 +86,21 @@ func (s *testTableSuite) TestInfoschemaFieldValue(c *C) { testkit.Rows("1")) tk.MustExec("insert into t(c, d) values(1, 1)") tk.MustQuery("select auto_increment from information_schema.tables where table_name='t'").Check( - testkit.Rows("30002")) + testkit.Rows("2")) + + tk.MustQuery("show create table t").Check( + testkit.Rows("" + + "t CREATE TABLE `t` (\n" + + " `c` int(11) NOT NULL AUTO_INCREMENT,\n" + + " `d` int(11) DEFAULT NULL,\n" + + " PRIMARY KEY (`c`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin AUTO_INCREMENT=30002")) + + // Test auto_increment for table without auto_increment column + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (d int)") + tk.MustQuery("select auto_increment from information_schema.tables where table_name='t'").Check( + testkit.Rows("")) tk.MustExec("create user xxx") tk.MustExec("flush privileges") @@ -108,12 +125,13 @@ func (s *testTableSuite) TestInfoschemaFieldValue(c *C) { tk1.MustQuery("select distinct(table_schema) from information_schema.tables").Check(testkit.Rows("INFORMATION_SCHEMA")) // Fix issue 9836 - sm := &mockSessionManager{make(map[uint64]util.ProcessInfo, 1)} - sm.processInfoMap[1] = util.ProcessInfo{ + sm := &mockSessionManager{make(map[uint64]*util.ProcessInfo, 1)} + sm.processInfoMap[1] = &util.ProcessInfo{ ID: 1, User: "root", Host: "127.0.0.1", Command: mysql.ComQuery, + StmtCtx: tk.Se.GetSessionVars().StmtCtx, } tk.Se.SetSessionManager(sm) tk.MustQuery("SELECT user,host,command FROM information_schema.processlist;").Check(testkit.Rows("root 127.0.0.1 Query")) @@ -138,25 +156,35 @@ func (s *testTableSuite) TestDataForTableStatsField(c *C) { tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( testkit.Rows("0 0 0 0")) tk.MustExec(`insert into t(c, d, e) values(1, 2, "c"), (2, 3, "d"), (3, 4, "e")`) - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( testkit.Rows("3 17 51 3")) tk.MustExec(`insert into t(c, d, e) values(4, 5, "f")`) - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( testkit.Rows("4 17 68 4")) tk.MustExec("delete from t where c >= 3") - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( testkit.Rows("2 17 34 2")) tk.MustExec("delete from t where c=3") - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( testkit.Rows("2 17 34 2")) + + // Test partition table. + tk.MustExec("drop table if exists t") + tk.MustExec(`CREATE TABLE t (a int, b int, c varchar(5), primary key(a), index idx(c)) PARTITION BY RANGE (a) (PARTITION p0 VALUES LESS THAN (6), PARTITION p1 VALUES LESS THAN (11), PARTITION p2 VALUES LESS THAN (16))`) + h.HandleDDLEvent(<-h.DDLEventCh()) + tk.MustExec(`insert into t(a, b, c) values(1, 2, "c"), (7, 3, "d"), (12, 4, "e")`) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) + tk.MustQuery("select table_rows, avg_row_length, data_length, index_length from information_schema.tables where table_name='t'").Check( + testkit.Rows("3 17 51 3")) } func (s *testTableSuite) TestCharacterSetCollations(c *C) { @@ -240,13 +268,53 @@ func (s *testTableSuite) TestCharacterSetCollations(c *C) { tk.MustExec("DROP DATABASE charset_collate_test") } +func (s *testTableSuite) TestCurrentTimestampAsDefault(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("DROP DATABASE IF EXISTS default_time_test") + tk.MustExec("CREATE DATABASE default_time_test; USE default_time_test") + + tk.MustExec(`CREATE TABLE default_time_table( + c_datetime datetime, + c_datetime_default datetime default current_timestamp, + c_datetime_default_2 datetime(2) default current_timestamp(2), + c_timestamp timestamp, + c_timestamp_default timestamp default current_timestamp, + c_timestamp_default_3 timestamp(3) default current_timestamp(3), + c_varchar_default varchar(20) default "current_timestamp", + c_varchar_default_3 varchar(20) default "current_timestamp(3)", + c_varchar_default_on_update datetime default current_timestamp on update current_timestamp, + c_varchar_default_on_update_fsp datetime(3) default current_timestamp(3) on update current_timestamp(3), + c_varchar_default_with_case varchar(20) default "cUrrent_tImestamp" + );`) + + tk.MustQuery(`SELECT column_name, column_default, extra + FROM information_schema.COLUMNS + WHERE table_schema = "default_time_test" AND table_name = "default_time_table" + ORDER BY column_name`, + ).Check(testkit.Rows( + "c_datetime ", + "c_datetime_default CURRENT_TIMESTAMP ", + "c_datetime_default_2 CURRENT_TIMESTAMP(2) ", + "c_timestamp ", + "c_timestamp_default CURRENT_TIMESTAMP ", + "c_timestamp_default_3 CURRENT_TIMESTAMP(3) ", + "c_varchar_default current_timestamp ", + "c_varchar_default_3 current_timestamp(3) ", + "c_varchar_default_on_update CURRENT_TIMESTAMP DEFAULT_GENERATED on update CURRENT_TIMESTAMP", + "c_varchar_default_on_update_fsp CURRENT_TIMESTAMP(3) DEFAULT_GENERATED on update CURRENT_TIMESTAMP(3)", + "c_varchar_default_with_case cUrrent_tImestamp ", + )) + tk.MustExec("DROP DATABASE default_time_test") +} + type mockSessionManager struct { - processInfoMap map[uint64]util.ProcessInfo + processInfoMap map[uint64]*util.ProcessInfo } -func (sm *mockSessionManager) ShowProcessList() map[uint64]util.ProcessInfo { return sm.processInfoMap } +func (sm *mockSessionManager) ShowProcessList() map[uint64]*util.ProcessInfo { return sm.processInfoMap } -func (sm *mockSessionManager) GetProcessInfo(id uint64) (util.ProcessInfo, bool) { +func (sm *mockSessionManager) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) { rs, ok := sm.processInfoMap[id] return rs, ok } @@ -267,27 +335,89 @@ func (s *testTableSuite) TestSomeTables(c *C) { testkit.Rows("def mysql columns_priv 0 mysql PRIMARY 1 Host A BTREE ")) tk.MustQuery("select * from information_schema.USER_PRIVILEGES where PRIVILEGE_TYPE='Select';").Check(testkit.Rows("'root'@'%' def Select YES")) - sm := &mockSessionManager{make(map[uint64]util.ProcessInfo, 2)} - sm.processInfoMap[1] = util.ProcessInfo{ + sm := &mockSessionManager{make(map[uint64]*util.ProcessInfo, 2)} + sm.processInfoMap[1] = &util.ProcessInfo{ ID: 1, User: "user-1", Host: "localhost", DB: "information_schema", Command: byte(1), State: 1, - Info: "do something"} - sm.processInfoMap[2] = util.ProcessInfo{ + Info: "do something", + StmtCtx: tk.Se.GetSessionVars().StmtCtx, + } + sm.processInfoMap[2] = &util.ProcessInfo{ ID: 2, User: "user-2", Host: "localhost", DB: "test", Command: byte(2), State: 2, - Info: "do something"} + Info: strings.Repeat("x", 101), + StmtCtx: tk.Se.GetSessionVars().StmtCtx, + } + tk.Se.SetSessionManager(sm) + tk.MustQuery("select * from information_schema.PROCESSLIST order by ID;").Sort().Check( + testkit.Rows( + fmt.Sprintf("1 user-1 localhost information_schema Quit 9223372036 1 %s 0", "do something"), + fmt.Sprintf("2 user-2 localhost test Init DB 9223372036 2 %s 0", strings.Repeat("x", 101)), + )) + tk.MustQuery("SHOW PROCESSLIST;").Sort().Check( + testkit.Rows( + fmt.Sprintf("1 user-1 localhost information_schema Quit 9223372036 1 %s", "do something"), + fmt.Sprintf("2 user-2 localhost test Init DB 9223372036 2 %s", strings.Repeat("x", 100)), + )) + tk.MustQuery("SHOW FULL PROCESSLIST;").Sort().Check( + testkit.Rows( + fmt.Sprintf("1 user-1 localhost information_schema Quit 9223372036 1 %s", "do something"), + fmt.Sprintf("2 user-2 localhost test Init DB 9223372036 2 %s", strings.Repeat("x", 101)), + )) + + sm = &mockSessionManager{make(map[uint64]*util.ProcessInfo, 2)} + sm.processInfoMap[1] = &util.ProcessInfo{ + ID: 1, + User: "user-1", + Host: "localhost", + DB: "information_schema", + Command: byte(1), + State: 1, + Info: nil, + StmtCtx: tk.Se.GetSessionVars().StmtCtx, + } + sm.processInfoMap[2] = &util.ProcessInfo{ + ID: 2, + User: "user-2", + Host: "localhost", + DB: nil, + Command: byte(2), + State: 2, + Info: strings.Repeat("x", 101), + StmtCtx: tk.Se.GetSessionVars().StmtCtx, + } tk.Se.SetSessionManager(sm) tk.MustQuery("select * from information_schema.PROCESSLIST order by ID;").Check( - testkit.Rows("1 user-1 localhost information_schema Quit 9223372036 1 do something", - "2 user-2 localhost test Init DB 9223372036 2 do something")) + testkit.Rows( + fmt.Sprintf("1 user-1 localhost information_schema Quit 9223372036 1 %s 0", ""), + fmt.Sprintf("2 user-2 localhost Init DB 9223372036 2 %s 0", strings.Repeat("x", 101)), + )) + tk.MustQuery("SHOW PROCESSLIST;").Sort().Check( + testkit.Rows( + fmt.Sprintf("1 user-1 localhost information_schema Quit 9223372036 1 %s", ""), + fmt.Sprintf("2 user-2 localhost Init DB 9223372036 2 %s", strings.Repeat("x", 100)), + )) + tk.MustQuery("SHOW FULL PROCESSLIST;").Sort().Check( + testkit.Rows( + fmt.Sprintf("1 user-1 localhost information_schema Quit 9223372036 1 %s", ""), + fmt.Sprintf("2 user-2 localhost Init DB 9223372036 2 %s", strings.Repeat("x", 101)), + )) + tk.MustQuery("select * from information_schema.PROCESSLIST where db is null;").Check( + testkit.Rows( + fmt.Sprintf("2 user-2 localhost Init DB 9223372036 2 %s 0", strings.Repeat("x", 101)), + )) + tk.MustQuery("select * from information_schema.PROCESSLIST where Info is null;").Check( + testkit.Rows( + fmt.Sprintf("1 user-1 localhost information_schema Quit 9223372036 1 %s 0", ""), + )) } func (s *testTableSuite) TestSchemataCharacterSet(c *C) { @@ -334,6 +464,9 @@ func (s *testTableSuite) TestSlowQuery(c *C) { # User: root@127.0.0.1 # Conn_ID: 6 # Query_time: 4.895492 +# Parse_time: 0.4 +# Compile_time: 0.2 +# Request_count: 1 Prewrite_time: 0.19 Commit_time: 0.01 Commit_backoff_time: 0.18 Backoff_types: [txnLock] Resolve_lock_time: 0.03 Write_keys: 15 Write_size: 480 Prewrite_region: 1 Txn_retry: 8 # Process_time: 0.161 Request_count: 1 Total_keys: 100001 Process_keys: 100000 # Wait_time: 0.101 # Backoff_time: 0.092 @@ -344,18 +477,38 @@ func (s *testTableSuite) TestSlowQuery(c *C) { # Cop_proc_avg: 0.1 Cop_proc_p90: 0.2 Cop_proc_max: 0.03 Cop_proc_addr: 127.0.0.1:20160 # Cop_wait_avg: 0.05 Cop_wait_p90: 0.6 Cop_wait_max: 0.8 Cop_wait_addr: 0.0.0.0:20160 # Mem_max: 70724 +# Succ: true +# Plan: abcd +# Prev_stmt: update t set i = 2; select * from t_slim;`)) - c.Assert(f.Close(), IsNil) + c.Assert(f.Sync(), IsNil) c.Assert(err, IsNil) tk.MustExec(fmt.Sprintf("set @@tidb_slow_query_file='%v'", slowLogFileName)) tk.MustExec("set time_zone = '+08:00';") re := tk.MustQuery("select * from information_schema.slow_query") re.Check(testutil.RowsWithSep("|", - "2019-02-12 19:33:56.571953|406315658548871171|root@127.0.0.1|6|4.895492|0.161|0.101|0.092|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|select * from t_slim;")) + "2019-02-12 19:33:56.571953|406315658548871171|root|127.0.0.1|6|4.895492|0.4|0.2|0.19|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.161|0.101|0.092|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|1|abcd|update t set i = 2;|select * from t_slim;")) tk.MustExec("set time_zone = '+00:00';") re = tk.MustQuery("select * from information_schema.slow_query") - re.Check(testutil.RowsWithSep("|", "2019-02-12 11:33:56.571953|406315658548871171|root@127.0.0.1|6|4.895492|0.161|0.101|0.092|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|select * from t_slim;")) + re.Check(testutil.RowsWithSep("|", "2019-02-12 11:33:56.571953|406315658548871171|root|127.0.0.1|6|4.895492|0.4|0.2|0.19|0.01|0|0.18|[txnLock]|0.03|0|15|480|1|8|0.161|0.101|0.092|1|100001|100000|test||0|42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772|t1:1,t2:2|0.1|0.2|0.03|127.0.0.1:20160|0.05|0.6|0.8|0.0.0.0:20160|70724|1|abcd|update t set i = 2;|select * from t_slim;")) + + // Test for long query. + _, err = f.Write([]byte(` +# Time: 2019-02-13T19:33:56.571953+08:00 +`)) + c.Assert(err, IsNil) + sql := "select * from " + for len(sql) < 5000 { + sql += "abcdefghijklmnopqrstuvwxyz_1234567890_qwertyuiopasdfghjklzxcvbnm" + } + sql += ";" + _, err = f.Write([]byte(sql)) + c.Assert(err, IsNil) + c.Assert(f.Close(), IsNil) + re = tk.MustQuery("select query from information_schema.slow_query order by time desc limit 1") + rows := re.Rows() + c.Assert(rows[0][0], Equals, sql) } func (s *testTableSuite) TestForAnalyzeStatus(c *C) { @@ -387,3 +540,21 @@ func (s *testTableSuite) TestForAnalyzeStatus(c *C) { c.Assert(result.Rows()[1][5], NotNil) c.Assert(result.Rows()[1][6], Equals, "finished") } + +func (s *testTableSuite) TestReloadDropDatabase(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("create database test_dbs") + tk.MustExec("use test_dbs") + tk.MustExec("create table t1 (a int)") + tk.MustExec("create table t2 (a int)") + tk.MustExec("create table t3 (a int)") + is := domain.GetDomain(tk.Se).InfoSchema() + t2, err := is.TableByName(model.NewCIStr("test_dbs"), model.NewCIStr("t2")) + c.Assert(err, IsNil) + tk.MustExec("drop database test_dbs") + is = domain.GetDomain(tk.Se).InfoSchema() + _, err = is.TableByName(model.NewCIStr("test_dbs"), model.NewCIStr("t2")) + c.Assert(terror.ErrorEqual(infoschema.ErrTableNotExists, err), IsTrue) + _, ok := is.TableByID(t2.Meta().ID) + c.Assert(ok, IsFalse) +} diff --git a/kv/kv.go b/kv/kv.go index 7e98844c2082f..7ca90087377e2 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -47,6 +47,8 @@ const ( KeyOnly // Pessimistic is defined for pessimistic lock Pessimistic + // SnapshotTS is defined to set snapshot ts. + SnapshotTS ) // Priority value for transaction priority. @@ -134,7 +136,7 @@ type Transaction interface { // String implements fmt.Stringer interface. String() string // LockKeys tries to lock the entries with the keys in KV store. - LockKeys(ctx context.Context, forUpdateTS uint64, keys ...Key) error + LockKeys(ctx context.Context, killed *uint32, forUpdateTS uint64, keys ...Key) error // SetOption sets an option with a value, when val is nil, uses the default // value of this option. SetOption(opt Option, val interface{}) @@ -154,6 +156,8 @@ type Transaction interface { // SetAssertion sets an assertion for an operation on the key. SetAssertion(key Key, assertion AssertionType) // BatchGet gets kv from the memory buffer of statement and transaction, and the kv storage. + // Do not use len(value) == 0 or value == nil to represent non-exist. + // If a key doesn't exist, there shouldn't be any corresponding entry in the result map. BatchGet(keys []Key) (map[string][]byte, error) IsPessimistic() bool } @@ -293,3 +297,10 @@ type Iterator interface { Next() error Close() } + +// SplitableStore is the kv store which supports split regions. +type SplitableStore interface { + SplitRegions(ctx context.Context, splitKey [][]byte, scatter bool) (regionID []uint64, err error) + WaitScatterRegionFinish(regionID uint64, backOff int) error + CheckRegionInScattering(regionID uint64) (bool, error) +} diff --git a/kv/mock.go b/kv/mock.go index 8d007e64e6893..877ab09ff9344 100644 --- a/kv/mock.go +++ b/kv/mock.go @@ -39,7 +39,7 @@ func (t *mockTxn) String() string { return "" } -func (t *mockTxn) LockKeys(_ context.Context, _ uint64, _ ...Key) error { +func (t *mockTxn) LockKeys(_ context.Context, _ *uint32, _ uint64, _ ...Key) error { return nil } diff --git a/kv/mock_test.go b/kv/mock_test.go index 67b4193f7ea9f..4cbe5631e0610 100644 --- a/kv/mock_test.go +++ b/kv/mock_test.go @@ -37,7 +37,7 @@ func (s testMockSuite) TestInterface(c *C) { transaction, err := storage.Begin() c.Check(err, IsNil) - err = transaction.LockKeys(context.Background(), 0, Key("lock")) + err = transaction.LockKeys(context.Background(), nil, 0, Key("lock")) c.Check(err, IsNil) transaction.SetOption(Option(23), struct{}{}) if mock, ok := transaction.(*mockTxn); ok { diff --git a/kv/union_store.go b/kv/union_store.go index ccf0a72a91cd5..27a34923d5aed 100644 --- a/kv/union_store.go +++ b/kv/union_store.go @@ -19,6 +19,8 @@ type UnionStore interface { MemBuffer // Returns related condition pair LookupConditionPair(k Key) *conditionPair + // DeleteConditionPair deletes a condition pair. + DeleteConditionPair(k Key) // WalkBuffer iterates all buffered kv pairs. WalkBuffer(f func(k Key, v []byte) error) error // SetOption sets an option with a value, when val is nil, uses the default @@ -217,6 +219,10 @@ func (us *unionStore) LookupConditionPair(k Key) *conditionPair { return nil } +func (us *unionStore) DeleteConditionPair(k Key) { + delete(us.lazyConditionPairs, string(k)) +} + // SetOption implements the UnionStore SetOption interface. func (us *unionStore) SetOption(opt Option, val interface{}) { us.opts[opt] = val diff --git a/meta/autoid/autoid.go b/meta/autoid/autoid.go old mode 100644 new mode 100755 index 8f52cb1e4adbb..0b65c84595e13 --- a/meta/autoid/autoid.go +++ b/meta/autoid/autoid.go @@ -22,6 +22,7 @@ import ( "github.com/cznic/mathutil" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" @@ -30,6 +31,12 @@ import ( "go.uber.org/zap" ) +const ( + minStep = 30000 + maxStep = 2000000 + defaultConsumeTime = 10 * time.Second +) + // Test needs to change it, so it's a variable. var step = int64(30000) @@ -38,9 +45,10 @@ var errInvalidTableID = terror.ClassAutoid.New(codeInvalidTableID, "invalid Tabl // Allocator is an auto increment id generator. // Just keep id unique actually. type Allocator interface { - // Alloc allocs the next autoID for table with tableID. + // Alloc allocs N consecutive autoID for table with tableID, returning (min, max] of the allocated autoID batch. // It gets a batch of autoIDs at a time. So it does not need to access storage for each call. - Alloc(tableID int64) (int64, error) + // The consecutive feature is used to insert multiple rows in a statement. + Alloc(tableID int64, n uint64) (int64, int64, error) // Rebase rebases the autoID base for table with tableID and the new base value. // If allocIDs is true, it will allocate some IDs and save to the cache. // If allocIDs is false, it will not allocate IDs. @@ -59,8 +67,10 @@ type allocator struct { end int64 store kv.Storage // dbID is current database's ID. - dbID int64 - isUnsigned bool + dbID int64 + isUnsigned bool + lastAllocTime time.Time + step int64 } // GetStep is only used by tests @@ -124,7 +134,7 @@ func (alloc *allocator) rebase4Unsigned(tableID int64, requiredBase uint64, allo uCurrentEnd := uint64(currentEnd) if allocIDs { newBase = mathutil.MaxUint64(uCurrentEnd, requiredBase) - newEnd = mathutil.MinUint64(math.MaxUint64-uint64(step), newBase) + uint64(step) + newEnd = mathutil.MinUint64(math.MaxUint64-uint64(alloc.step), newBase) + uint64(alloc.step) } else { if uCurrentEnd >= requiredBase { newBase = uCurrentEnd @@ -169,7 +179,7 @@ func (alloc *allocator) rebase4Signed(tableID, requiredBase int64, allocIDs bool } if allocIDs { newBase = mathutil.MaxInt64(currentEnd, requiredBase) - newEnd = mathutil.MinInt64(math.MaxInt64-step, newBase) + step + newEnd = mathutil.MinInt64(math.MaxInt64-alloc.step, newBase) + alloc.step } else { if currentEnd >= requiredBase { newBase = currentEnd @@ -211,10 +221,80 @@ func (alloc *allocator) Rebase(tableID, requiredBase int64, allocIDs bool) error return alloc.rebase4Signed(tableID, requiredBase, allocIDs) } -func (alloc *allocator) alloc4Unsigned(tableID int64) (int64, error) { - if alloc.base == alloc.end { // step +// NextStep return new auto id step according to previous step and consuming time. +func NextStep(curStep int64, consumeDur time.Duration) int64 { + failpoint.Inject("mockAutoIDChange", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(step) + } + }) + + consumeRate := defaultConsumeTime.Seconds() / consumeDur.Seconds() + res := int64(float64(curStep) * consumeRate) + if res < minStep { + return minStep + } else if res > maxStep { + return maxStep + } + return res +} + +// NewAllocator returns a new auto increment id generator on the store. +func NewAllocator(store kv.Storage, dbID int64, isUnsigned bool) Allocator { + return &allocator{ + store: store, + dbID: dbID, + isUnsigned: isUnsigned, + step: step, + lastAllocTime: time.Now(), + } +} + +//codeInvalidTableID is the code of autoid error. +const codeInvalidTableID terror.ErrCode = 1 + +var localSchemaID = int64(math.MaxInt64) + +// GenLocalSchemaID generates a local schema ID. +func GenLocalSchemaID() int64 { + return atomic.AddInt64(&localSchemaID, -1) +} + +// Alloc implements autoid.Allocator Alloc interface. +func (alloc *allocator) Alloc(tableID int64, n uint64) (int64, int64, error) { + if tableID == 0 { + return 0, 0, errInvalidTableID.GenWithStackByArgs("Invalid tableID") + } + if n == 0 { + return 0, 0, nil + } + alloc.mu.Lock() + defer alloc.mu.Unlock() + if alloc.isUnsigned { + return alloc.alloc4Unsigned(tableID, n) + } + return alloc.alloc4Signed(tableID, n) +} + +func (alloc *allocator) alloc4Signed(tableID int64, n uint64) (int64, int64, error) { + n1 := int64(n) + // Condition alloc.base+N1 > alloc.end will overflow when alloc.base + N1 > MaxInt64. So need this. + if math.MaxInt64-alloc.base <= n1 { + return 0, 0, ErrAutoincReadFailed + } + // The local rest is not enough for allocN, skip it. + if alloc.base+n1 > alloc.end { var newBase, newEnd int64 startTime := time.Now() + // Although it may skip a segment here, we still think it is consumed. + consumeDur := startTime.Sub(alloc.lastAllocTime) + nextStep := NextStep(alloc.step, consumeDur) + // Make sure nextStep is big enough. + if nextStep <= n1 { + alloc.step = mathutil.MinInt64(n1*2, maxStep) + } else { + alloc.step = nextStep + } err := kv.RunInNewTxn(alloc.store, true, func(txn kv.Transaction) error { m := meta.NewMeta(txn) var err1 error @@ -222,32 +302,53 @@ func (alloc *allocator) alloc4Unsigned(tableID int64) (int64, error) { if err1 != nil { return err1 } - tmpStep := int64(mathutil.MinUint64(math.MaxUint64-uint64(newBase), uint64(step))) + tmpStep := mathutil.MinInt64(math.MaxInt64-newBase, alloc.step) + // The global rest is not enough for alloc. + if tmpStep < n1 { + return ErrAutoincReadFailed + } newEnd, err1 = m.GenAutoTableID(alloc.dbID, tableID, tmpStep) return err1 }) metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDAlloc, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) if err != nil { - return 0, err + return 0, 0, err } - if uint64(newBase) == math.MaxUint64 { - return 0, ErrAutoincReadFailed + alloc.lastAllocTime = time.Now() + if newBase == math.MaxInt64 { + return 0, 0, ErrAutoincReadFailed } alloc.base, alloc.end = newBase, newEnd } - - alloc.base = int64(uint64(alloc.base) + 1) - logutil.Logger(context.Background()).Debug("alloc unsigned ID", - zap.Uint64("ID", uint64(alloc.base)), + logutil.Logger(context.TODO()).Debug("alloc N signed ID", + zap.Uint64("from ID", uint64(alloc.base)), + zap.Uint64("to ID", uint64(alloc.base+n1)), zap.Int64("table ID", tableID), zap.Int64("database ID", alloc.dbID)) - return alloc.base, nil + min := alloc.base + alloc.base += n1 + return min, alloc.base, nil } -func (alloc *allocator) alloc4Signed(tableID int64) (int64, error) { - if alloc.base == alloc.end { // step +func (alloc *allocator) alloc4Unsigned(tableID int64, n uint64) (int64, int64, error) { + n1 := int64(n) + // Condition alloc.base+n1 > alloc.end will overflow when alloc.base + n1 > MaxInt64. So need this. + if math.MaxUint64-uint64(alloc.base) <= n { + return 0, 0, ErrAutoincReadFailed + } + // The local rest is not enough for alloc, skip it. + if uint64(alloc.base)+n > uint64(alloc.end) { var newBase, newEnd int64 startTime := time.Now() + // Although it may skip a segment here, we still treat it as consumed. + consumeDur := startTime.Sub(alloc.lastAllocTime) + nextStep := NextStep(alloc.step, consumeDur) + // Make sure nextStep is big enough. + if nextStep <= n1 { + alloc.step = mathutil.MinInt64(n1*2, maxStep) + } else { + alloc.step = nextStep + } err := kv.RunInNewTxn(alloc.store, true, func(txn kv.Transaction) error { m := meta.NewMeta(txn) var err1 error @@ -255,56 +356,31 @@ func (alloc *allocator) alloc4Signed(tableID int64) (int64, error) { if err1 != nil { return err1 } - tmpStep := mathutil.MinInt64(math.MaxInt64-newBase, step) + tmpStep := int64(mathutil.MinUint64(math.MaxUint64-uint64(newBase), uint64(alloc.step))) + // The global rest is not enough for alloc. + if tmpStep < n1 { + return ErrAutoincReadFailed + } newEnd, err1 = m.GenAutoTableID(alloc.dbID, tableID, tmpStep) return err1 }) metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDAlloc, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) if err != nil { - return 0, err + return 0, 0, err } - if newBase == math.MaxInt64 { - return 0, ErrAutoincReadFailed + alloc.lastAllocTime = time.Now() + if uint64(newBase) == math.MaxUint64 { + return 0, 0, ErrAutoincReadFailed } alloc.base, alloc.end = newBase, newEnd } - - alloc.base++ - logutil.Logger(context.Background()).Debug("alloc signed ID", - zap.Uint64("ID", uint64(alloc.base)), + logutil.Logger(context.TODO()).Debug("alloc unsigned ID", + zap.Uint64(" from ID", uint64(alloc.base)), + zap.Uint64("to ID", uint64(alloc.base+n1)), zap.Int64("table ID", tableID), zap.Int64("database ID", alloc.dbID)) - return alloc.base, nil -} - -// Alloc implements autoid.Allocator Alloc interface. -func (alloc *allocator) Alloc(tableID int64) (int64, error) { - if tableID == 0 { - return 0, errInvalidTableID.GenWithStack("Invalid tableID") - } - alloc.mu.Lock() - defer alloc.mu.Unlock() - if alloc.isUnsigned { - return alloc.alloc4Unsigned(tableID) - } - return alloc.alloc4Signed(tableID) -} - -// NewAllocator returns a new auto increment id generator on the store. -func NewAllocator(store kv.Storage, dbID int64, isUnsigned bool) Allocator { - return &allocator{ - store: store, - dbID: dbID, - isUnsigned: isUnsigned, - } -} - -//autoid error codes. -const codeInvalidTableID terror.ErrCode = 1 - -var localSchemaID = int64(math.MaxInt64) - -// GenLocalSchemaID generates a local schema ID. -func GenLocalSchemaID() int64 { - return atomic.AddInt64(&localSchemaID, -1) + min := alloc.base + // Use uint64 n directly. + alloc.base = int64(uint64(alloc.base) + n) + return min, alloc.base, nil } diff --git a/meta/autoid/autoid_test.go b/meta/autoid/autoid_test.go index 569578b7c4bbd..32ba1c49dbb98 100644 --- a/meta/autoid/autoid_test.go +++ b/meta/autoid/autoid_test.go @@ -15,12 +15,15 @@ package autoid_test import ( "fmt" + "math" + "math/rand" "sync" "testing" "time" . "github.com/pingcap/check" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" @@ -39,6 +42,11 @@ type testSuite struct { } func (*testSuite) TestT(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange"), IsNil) + }() + store, err := mockstore.NewMockTikvStore() c.Assert(err, IsNil) defer store.Close() @@ -53,6 +61,8 @@ func (*testSuite) TestT(c *C) { c.Assert(err, IsNil) err = m.CreateTableOrView(1, &model.TableInfo{ID: 3, Name: model.NewCIStr("t1")}) c.Assert(err, IsNil) + err = m.CreateTableOrView(1, &model.TableInfo{ID: 4, Name: model.NewCIStr("t2")}) + c.Assert(err, IsNil) return nil }) c.Assert(err, IsNil) @@ -63,13 +73,13 @@ func (*testSuite) TestT(c *C) { globalAutoID, err := alloc.NextGlobalAutoID(1) c.Assert(err, IsNil) c.Assert(globalAutoID, Equals, int64(1)) - id, err := alloc.Alloc(1) + _, id, err := alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(1)) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(2)) - _, err = alloc.Alloc(0) + _, _, err = alloc.Alloc(0, 1) c.Assert(err, NotNil) globalAutoID, err = alloc.NextGlobalAutoID(1) c.Assert(err, IsNil) @@ -78,28 +88,28 @@ func (*testSuite) TestT(c *C) { // rebase err = alloc.Rebase(1, int64(1), true) c.Assert(err, IsNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(3)) err = alloc.Rebase(1, int64(3), true) c.Assert(err, IsNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(4)) err = alloc.Rebase(1, int64(10), true) c.Assert(err, IsNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(11)) err = alloc.Rebase(1, int64(3010), true) c.Assert(err, IsNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(3011)) alloc = autoid.NewAllocator(store, 1, false) c.Assert(alloc, NotNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(autoid.GetStep()+1)) @@ -107,7 +117,7 @@ func (*testSuite) TestT(c *C) { c.Assert(alloc, NotNil) err = alloc.Rebase(2, int64(1), false) c.Assert(err, IsNil) - id, err = alloc.Alloc(2) + _, id, err = alloc.Alloc(2, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(2)) @@ -119,17 +129,73 @@ func (*testSuite) TestT(c *C) { c.Assert(alloc, NotNil) err = alloc.Rebase(3, int64(3000), false) c.Assert(err, IsNil) - id, err = alloc.Alloc(3) + _, id, err = alloc.Alloc(3, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(3211)) err = alloc.Rebase(3, int64(6543), false) c.Assert(err, IsNil) - id, err = alloc.Alloc(3) + _, id, err = alloc.Alloc(3, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(6544)) + + // Test the MaxInt64 is the upper bound of `alloc` function but not `rebase`. + err = alloc.Rebase(3, int64(math.MaxInt64-1), true) + c.Assert(err, IsNil) + _, _, err = alloc.Alloc(3, 1) + c.Assert(alloc, NotNil) + err = alloc.Rebase(3, int64(math.MaxInt64), true) + c.Assert(err, IsNil) + + // alloc N for signed + alloc = autoid.NewAllocator(store, 1, false) + c.Assert(alloc, NotNil) + globalAutoID, err = alloc.NextGlobalAutoID(4) + c.Assert(err, IsNil) + c.Assert(globalAutoID, Equals, int64(1)) + min, max, err := alloc.Alloc(4, 1) + c.Assert(err, IsNil) + c.Assert(max-min, Equals, int64(1)) + c.Assert(min+1, Equals, int64(1)) + + min, max, err = alloc.Alloc(4, 2) + c.Assert(err, IsNil) + c.Assert(max-min, Equals, int64(2)) + c.Assert(min+1, Equals, int64(2)) + c.Assert(max, Equals, int64(3)) + + min, max, err = alloc.Alloc(4, 100) + c.Assert(err, IsNil) + c.Assert(max-min, Equals, int64(100)) + expected := int64(4) + for i := min + 1; i <= max; i++ { + c.Assert(i, Equals, expected) + expected++ + } + + err = alloc.Rebase(4, int64(1000), false) + c.Assert(err, IsNil) + min, max, err = alloc.Alloc(4, 3) + c.Assert(err, IsNil) + c.Assert(max-min, Equals, int64(3)) + c.Assert(min+1, Equals, int64(1001)) + c.Assert(min+2, Equals, int64(1002)) + c.Assert(max, Equals, int64(1003)) + + lastRemainOne := alloc.End() + err = alloc.Rebase(4, alloc.End()-2, false) + c.Assert(err, IsNil) + min, max, err = alloc.Alloc(4, 5) + c.Assert(err, IsNil) + c.Assert(max-min, Equals, int64(5)) + c.Assert(min+1, Greater, lastRemainOne) } func (*testSuite) TestUnsignedAutoid(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange", `return(true)`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/meta/autoid/mockAutoIDChange"), IsNil) + }() + store, err := mockstore.NewMockTikvStore() c.Assert(err, IsNil) defer store.Close() @@ -144,6 +210,8 @@ func (*testSuite) TestUnsignedAutoid(c *C) { c.Assert(err, IsNil) err = m.CreateTableOrView(1, &model.TableInfo{ID: 3, Name: model.NewCIStr("t1")}) c.Assert(err, IsNil) + err = m.CreateTableOrView(1, &model.TableInfo{ID: 4, Name: model.NewCIStr("t2")}) + c.Assert(err, IsNil) return nil }) c.Assert(err, IsNil) @@ -154,13 +222,13 @@ func (*testSuite) TestUnsignedAutoid(c *C) { globalAutoID, err := alloc.NextGlobalAutoID(1) c.Assert(err, IsNil) c.Assert(globalAutoID, Equals, int64(1)) - id, err := alloc.Alloc(1) + _, id, err := alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(1)) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(2)) - _, err = alloc.Alloc(0) + _, _, err = alloc.Alloc(0, 1) c.Assert(err, NotNil) globalAutoID, err = alloc.NextGlobalAutoID(1) c.Assert(err, IsNil) @@ -169,28 +237,28 @@ func (*testSuite) TestUnsignedAutoid(c *C) { // rebase err = alloc.Rebase(1, int64(1), true) c.Assert(err, IsNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(3)) err = alloc.Rebase(1, int64(3), true) c.Assert(err, IsNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(4)) err = alloc.Rebase(1, int64(10), true) c.Assert(err, IsNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(11)) err = alloc.Rebase(1, int64(3010), true) c.Assert(err, IsNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(3011)) alloc = autoid.NewAllocator(store, 1, true) c.Assert(alloc, NotNil) - id, err = alloc.Alloc(1) + _, id, err = alloc.Alloc(1, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(autoid.GetStep()+1)) @@ -198,7 +266,7 @@ func (*testSuite) TestUnsignedAutoid(c *C) { c.Assert(alloc, NotNil) err = alloc.Rebase(2, int64(1), false) c.Assert(err, IsNil) - id, err = alloc.Alloc(2) + _, id, err = alloc.Alloc(2, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(2)) @@ -210,14 +278,54 @@ func (*testSuite) TestUnsignedAutoid(c *C) { c.Assert(alloc, NotNil) err = alloc.Rebase(3, int64(3000), false) c.Assert(err, IsNil) - id, err = alloc.Alloc(3) + _, id, err = alloc.Alloc(3, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(3211)) err = alloc.Rebase(3, int64(6543), false) c.Assert(err, IsNil) - id, err = alloc.Alloc(3) + _, id, err = alloc.Alloc(3, 1) c.Assert(err, IsNil) c.Assert(id, Equals, int64(6544)) + + // Test the MaxUint64 is the upper bound of `alloc` func but not `rebase`. + var n uint64 = math.MaxUint64 - 1 + un := int64(n) + err = alloc.Rebase(3, un, true) + c.Assert(err, IsNil) + _, _, err = alloc.Alloc(3, 1) + c.Assert(err, NotNil) + un = int64(n + 1) + err = alloc.Rebase(3, un, true) + c.Assert(err, IsNil) + + // alloc N for unsigned + alloc = autoid.NewAllocator(store, 1, true) + c.Assert(alloc, NotNil) + globalAutoID, err = alloc.NextGlobalAutoID(4) + c.Assert(err, IsNil) + c.Assert(globalAutoID, Equals, int64(1)) + + min, max, err := alloc.Alloc(4, 2) + c.Assert(err, IsNil) + c.Assert(max-min, Equals, int64(2)) + c.Assert(min+1, Equals, int64(1)) + c.Assert(max, Equals, int64(2)) + + err = alloc.Rebase(4, int64(500), true) + c.Assert(err, IsNil) + min, max, err = alloc.Alloc(4, 2) + c.Assert(err, IsNil) + c.Assert(max-min, Equals, int64(2)) + c.Assert(min+1, Equals, int64(501)) + c.Assert(max, Equals, int64(502)) + + lastRemainOne := alloc.End() + err = alloc.Rebase(4, alloc.End()-2, false) + c.Assert(err, IsNil) + min, max, err = alloc.Alloc(4, 5) + c.Assert(err, IsNil) + c.Assert(max-min, Equals, int64(5)) + c.Assert(min+1, Greater, lastRemainOne) } // TestConcurrentAlloc is used for the test that @@ -252,7 +360,7 @@ func (*testSuite) TestConcurrentAlloc(c *C) { allocIDs := func() { alloc := autoid.NewAllocator(store, dbID, false) for j := 0; j < int(autoid.GetStep())+5; j++ { - id, err1 := alloc.Alloc(tblID) + _, id, err1 := alloc.Alloc(tblID, 1) if err1 != nil { errCh <- err1 break @@ -266,6 +374,30 @@ func (*testSuite) TestConcurrentAlloc(c *C) { } m[id] = struct{}{} mu.Unlock() + + //test Alloc N + N := rand.Uint64() % 100 + min, max, err1 := alloc.Alloc(tblID, N) + if err1 != nil { + errCh <- err1 + break + } + + errFlag := false + mu.Lock() + for i := min + 1; i <= max; i++ { + if _, ok := m[i]; ok { + errCh <- fmt.Errorf("duplicate id:%v", i) + errFlag = true + mu.Unlock() + break + } + m[i] = struct{}{} + } + if errFlag { + break + } + mu.Unlock() } } for i := 0; i < count; i++ { @@ -305,7 +437,7 @@ func (*testSuite) TestRollbackAlloc(c *C) { injectConf.SetCommitError(errors.New("injected")) injectedStore := kv.NewInjectedStore(store, injectConf) alloc := autoid.NewAllocator(injectedStore, 1, false) - _, err = alloc.Alloc(2) + _, _, err = alloc.Alloc(2, 1) c.Assert(err, NotNil) c.Assert(alloc.Base(), Equals, int64(0)) c.Assert(alloc.End(), Equals, int64(0)) @@ -315,3 +447,44 @@ func (*testSuite) TestRollbackAlloc(c *C) { c.Assert(alloc.Base(), Equals, int64(0)) c.Assert(alloc.End(), Equals, int64(0)) } + +// TestNextStep tests generate next auto id step. +func (*testSuite) TestNextStep(c *C) { + nextStep := autoid.NextStep(2000000, 1*time.Nanosecond) + c.Assert(nextStep, Equals, int64(2000000)) + nextStep = autoid.NextStep(678910, 10*time.Second) + c.Assert(nextStep, Equals, int64(678910)) + nextStep = autoid.NextStep(50000, 10*time.Minute) + c.Assert(nextStep, Equals, int64(30000)) +} + +func BenchmarkAllocator_Alloc(b *testing.B) { + b.StopTimer() + store, err := mockstore.NewMockTikvStore() + if err != nil { + return + } + defer store.Close() + dbID := int64(1) + tblID := int64(2) + err = kv.RunInNewTxn(store, false, func(txn kv.Transaction) error { + m := meta.NewMeta(txn) + err = m.CreateDatabase(&model.DBInfo{ID: dbID, Name: model.NewCIStr("a")}) + if err != nil { + return err + } + err = m.CreateTableOrView(dbID, &model.TableInfo{ID: tblID, Name: model.NewCIStr("t")}) + if err != nil { + return err + } + return nil + }) + if err != nil { + return + } + alloc := autoid.NewAllocator(store, 1, false) + b.StartTimer() + for i := 0; i < b.N; i++ { + alloc.Alloc(2, 1) + } +} diff --git a/meta/autoid/errors.go b/meta/autoid/errors.go index 44ef83650a202..e9b0b0fa6fae4 100644 --- a/meta/autoid/errors.go +++ b/meta/autoid/errors.go @@ -21,12 +21,14 @@ import ( // Error instances. var ( ErrAutoincReadFailed = terror.ClassAutoid.New(mysql.ErrAutoincReadFailed, mysql.MySQLErrName[mysql.ErrAutoincReadFailed]) + ErrWrongAutoKey = terror.ClassAutoid.New(mysql.ErrWrongAutoKey, mysql.MySQLErrName[mysql.ErrWrongAutoKey]) ) func init() { // Map error codes to mysql error codes. tableMySQLErrCodes := map[terror.ErrCode]uint16{ mysql.ErrAutoincReadFailed: mysql.ErrAutoincReadFailed, + mysql.ErrWrongAutoKey: mysql.ErrWrongAutoKey, } terror.ErrClassToMySQLCodes[terror.ClassAutoid] = tableMySQLErrCodes } diff --git a/meta/meta.go b/meta/meta.go index 09cc9c3a89a92..4104d1a11929f 100644 --- a/meta/meta.go +++ b/meta/meta.go @@ -135,6 +135,11 @@ func (m *Meta) tableKey(tableID int64) []byte { return []byte(fmt.Sprintf("%s:%d", mTablePrefix, tableID)) } +// DDLJobHistoryKey is only used for testing. +func DDLJobHistoryKey(m *Meta, jobID int64) []byte { + return m.txn.EncodeHashDataKey(mDDLJobHistoryKey, m.jobIDKey(jobID)) +} + // GenAutoTableIDKeyValue generates meta key by dbID, tableID and corresponding value by autoID. func (m *Meta) GenAutoTableIDKeyValue(dbID, tableID, autoID int64) (key, value []byte) { dbKey := m.dbKey(dbID) @@ -452,8 +457,13 @@ func (m *Meta) enQueueDDLJob(key []byte, job *model.Job) error { } // EnQueueDDLJob adds a DDL job to the list. -func (m *Meta) EnQueueDDLJob(job *model.Job) error { - return m.enQueueDDLJob(m.jobListKey, job) +func (m *Meta) EnQueueDDLJob(job *model.Job, jobListKeys ...JobListKeyType) error { + listKey := m.jobListKey + if len(jobListKeys) != 0 { + listKey = jobListKeys[0] + } + + return m.enQueueDDLJob(listKey, job) } func (m *Meta) deQueueDDLJob(key []byte) (*model.Job, error) { @@ -637,10 +647,23 @@ func (m *Meta) GetAllHistoryDDLJobs() ([]*model.Job, error) { if err != nil { return nil, errors.Trace(err) } - jobs := make([]*model.Job, 0, len(pairs)) - for _, pair := range pairs { + return decodeAndSortJob(pairs) +} + +// GetLastNHistoryDDLJobs gets latest N history ddl jobs. +func (m *Meta) GetLastNHistoryDDLJobs(num int) ([]*model.Job, error) { + pairs, err := m.txn.HGetLastN(mDDLJobHistoryKey, num) + if err != nil { + return nil, errors.Trace(err) + } + return decodeAndSortJob(pairs) +} + +func decodeAndSortJob(jobPairs []structure.HashPair) ([]*model.Job, error) { + jobs := make([]*model.Job, 0, len(jobPairs)) + for _, pair := range jobPairs { job := &model.Job{} - err = job.Decode(pair.Value) + err := job.Decode(pair.Value) if err != nil { return nil, errors.Trace(err) } diff --git a/meta/meta_test.go b/meta/meta_test.go index 5e719680f5e7a..2d0ef05f6605b 100644 --- a/meta/meta_test.go +++ b/meta/meta_test.go @@ -237,6 +237,10 @@ func (s *testSuite) TestMeta(c *C) { err = txn.Commit(context.Background()) c.Assert(err, IsNil) + + // Test for DDLJobHistoryKey. + key = meta.DDLJobHistoryKey(t, 888) + c.Assert(key, DeepEquals, []byte{0x6d, 0x44, 0x44, 0x4c, 0x4a, 0x6f, 0x62, 0x48, 0x69, 0xff, 0x73, 0x74, 0x6f, 0x72, 0x79, 0x0, 0x0, 0x0, 0xfc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x68, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x78, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf7}) } func (s *testSuite) TestSnapshot(c *C) { @@ -356,6 +360,13 @@ func (s *testSuite) TestDDL(c *C) { lastID = job.ID } + // Test for get last N history ddl jobs. + historyJobs, err := t.GetLastNHistoryDDLJobs(2) + c.Assert(err, IsNil) + c.Assert(len(historyJobs), Equals, 2) + c.Assert(historyJobs[0].ID == 123, IsTrue) + c.Assert(historyJobs[1].ID == 1234, IsTrue) + // Test GetAllDDLJobsInQueue. err = t.EnQueueDDLJob(job) c.Assert(err, IsNil) diff --git a/metrics/bindinfo.go b/metrics/bindinfo.go new file mode 100644 index 0000000000000..958bd110c2b23 --- /dev/null +++ b/metrics/bindinfo.go @@ -0,0 +1,43 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package metrics + +import "github.com/prometheus/client_golang/prometheus" + +// bindinfo metrics. +var ( + BindUsageCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "bindinfo", + Name: "bind_usage_counter", + Help: "Counter of query using sql bind", + }, []string{LableScope}) + + BindTotalGauge = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "tidb", + Subsystem: "bindinfo", + Name: "bind_total_gauge", + Help: "Total number of sql bind", + }, []string{LableScope, LblType}) + + BindMemoryUsage = prometheus.NewGaugeVec( + prometheus.GaugeOpts{ + Namespace: "tidb", + Subsystem: "bindinfo", + Name: "bind_memory_usage", + Help: "Memory usage of sql bind", + }, []string{LableScope, LblType}) +) diff --git a/metrics/ddl.go b/metrics/ddl.go index 8a6de6347cd38..e36f58bd7b1f3 100644 --- a/metrics/ddl.go +++ b/metrics/ddl.go @@ -65,10 +65,13 @@ var ( Buckets: prometheus.ExponentialBuckets(0.001, 2, 20), // 1ms ~ 1024s }, []string{LblResult}) - OwnerUpdateGlobalVersion = "update_global_version" - OwnerGetGlobalVersion = "get_global_version" - OwnerCheckAllVersions = "check_all_versions" - OwnerHandleSyncerHistogram = prometheus.NewHistogramVec( + OwnerUpdateGlobalVersion = "update_global_version" + OwnerGetGlobalVersion = "get_global_version" + OwnerCheckAllVersions = "check_all_versions" + OwnerNotifyCleanExpirePaths = "notify_clean_expire_paths" + OwnerCleanExpirePaths = "clean_expire_paths" + OwnerCleanOneExpirePath = "clean_an_expire_path" + OwnerHandleSyncerHistogram = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "tidb", Subsystem: "ddl", @@ -93,6 +96,7 @@ var ( CreateDDLInstance = "create_ddl_instance" CreateDDL = "create_ddl" + StartCleanWork = "start_clean_work" DDLOwner = "owner" DDLCounter = prometheus.NewCounterVec( prometheus.CounterOpts{ @@ -101,6 +105,14 @@ var ( Name: "worker_operation_total", Help: "Counter of creating ddl/worker and isowner.", }, []string{LblType}) + + AddIndexTotalCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "ddl", + Name: "add_index_total", + Help: "Speed of add index", + }, []string{LblType}) ) // Label constants. diff --git a/metrics/domain.go b/metrics/domain.go index 017e007e4fb98..a8ea4e3d4cbbf 100644 --- a/metrics/domain.go +++ b/metrics/domain.go @@ -17,6 +17,7 @@ import ( "github.com/prometheus/client_golang/prometheus" ) +// Metrics for the domain package. var ( // LoadSchemaCounter records the counter of load schema. LoadSchemaCounter = prometheus.NewCounterVec( @@ -45,4 +46,18 @@ var ( Name: "load_privilege_total", Help: "Counter of load privilege", }, []string{LblType}) + + SchemaValidatorStop = "stop" + SchemaValidatorRestart = "restart" + SchemaValidatorReset = "reset" + SchemaValidatorCacheEmpty = "cache_empty" + SchemaValidatorCacheMiss = "cache_miss" + // HandleSchemaValidate records the counter of handling schema validate. + HandleSchemaValidate = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "domain", + Name: "handle_schema_validate", + Help: "Counter of handle schema validate", + }, []string{LblType}) ) diff --git a/metrics/gprc.go b/metrics/gprc.go new file mode 100644 index 0000000000000..33875054b64a1 --- /dev/null +++ b/metrics/gprc.go @@ -0,0 +1,27 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package metrics + +import "github.com/prometheus/client_golang/prometheus" + +// Metrics to monitor gRPC service +var ( + GRPCConnTransientFailureCounter = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "tidb", + Subsystem: "grpc", + Name: "connection_transient_failure_count", + Help: "Counter of gRPC connection transient failure", + }, []string{LblAddress, LblStore}) +) diff --git a/metrics/metrics.go b/metrics/metrics.go index 5711d411efc12..74805734e9679 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -28,18 +28,23 @@ var ( // metrics labels. const ( - LabelSession = "session" - LabelDomain = "domain" - LabelDDLOwner = "ddl-owner" - LabelDDL = "ddl" - LabelGCWorker = "gcworker" - LabelAnalyze = "analyze" + LabelSession = "session" + LabelDomain = "domain" + LabelDDLOwner = "ddl-owner" + LabelDDL = "ddl" + LabelDDLSyncer = "ddl-syncer" + LabelGCWorker = "gcworker" + LabelAnalyze = "analyze" LabelBatchRecvLoop = "batch-recv-loop" LabelBatchSendLoop = "batch-send-loop" opSucc = "ok" opFailed = "err" + + LableScope = "scope" + ScopeGlobal = "global" + ScopeSession = "session" ) // RetLabel returns "ok" when err == nil and "err" when err != nil. @@ -57,11 +62,15 @@ func RegisterMetrics() { prometheus.MustRegister(AutoAnalyzeHistogram) prometheus.MustRegister(AutoIDHistogram) prometheus.MustRegister(BatchAddIdxHistogram) + prometheus.MustRegister(BindUsageCounter) + prometheus.MustRegister(BindTotalGauge) + prometheus.MustRegister(BindMemoryUsage) prometheus.MustRegister(CampaignOwnerCounter) prometheus.MustRegister(ConnGauge) prometheus.MustRegister(PreparedStmtGauge) prometheus.MustRegister(CriticalErrorCounter) prometheus.MustRegister(DDLCounter) + prometheus.MustRegister(AddIndexTotalCounter) prometheus.MustRegister(DDLWorkerHistogram) prometheus.MustRegister(DeploySyncerHistogram) prometheus.MustRegister(DistSQLPartialCountHistogram) @@ -113,6 +122,7 @@ func RegisterMetrics() { prometheus.MustRegister(TiKVSecondaryLockCleanupFailureCounter) prometheus.MustRegister(TiKVSendReqHistogram) prometheus.MustRegister(TiKVSnapshotCounter) + prometheus.MustRegister(TiKVTxnCmdCounter) prometheus.MustRegister(TiKVTxnCmdHistogram) prometheus.MustRegister(TiKVTxnCounter) prometheus.MustRegister(TiKVTxnRegionsNumHistogram) @@ -141,4 +151,7 @@ func RegisterMetrics() { prometheus.MustRegister(TiKVBatchClientUnavailable) prometheus.MustRegister(TiKVRangeTaskStats) prometheus.MustRegister(TiKVRangeTaskPushDuration) + prometheus.MustRegister(HandleSchemaValidate) + prometheus.MustRegister(TiKVTxnHeartBeatHistogram) + prometheus.MustRegister(GRPCConnTransientFailureCounter) } diff --git a/metrics/session.go b/metrics/session.go index 0ea3548547986..b06f70984c3aa 100644 --- a/metrics/session.go +++ b/metrics/session.go @@ -106,10 +106,13 @@ const ( LblOK = "ok" LblError = "error" LblRollback = "rollback" + LblComRol = "com_rol" LblType = "type" LblDb = "db" LblResult = "result" LblSQLType = "sql_type" LblGeneral = "general" LblInternal = "internal" + LblStore = "store" + LblAddress = "address" ) diff --git a/metrics/tikvclient.go b/metrics/tikvclient.go index 1f66c1986a15b..f1e52ddde43ab 100644 --- a/metrics/tikvclient.go +++ b/metrics/tikvclient.go @@ -58,14 +58,14 @@ var ( Help: "Counter of backoff.", }, []string{LblType}) - TiKVBackoffHistogram = prometheus.NewHistogram( + TiKVBackoffHistogram = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "tidb", Subsystem: "tikvclient", Name: "backoff_seconds", Help: "total backoff seconds of a single backoffer.", Buckets: prometheus.ExponentialBuckets(0.0005, 2, 20), // 0.5ms ~ 524s - }) + }, []string{LblType}) TiKVSendReqHistogram = prometheus.NewHistogramVec( prometheus.HistogramOpts{ @@ -74,7 +74,7 @@ var ( Name: "request_seconds", Help: "Bucketed histogram of sending request duration.", Buckets: prometheus.ExponentialBuckets(0.0005, 2, 20), // 0.5ms ~ 524s - }, []string{LblType, "store"}) + }, []string{LblType, LblStore}) TiKVCoprocessorHistogram = prometheus.NewHistogram( prometheus.HistogramOpts{ @@ -180,13 +180,13 @@ var ( }) // TiKVPendingBatchRequests indicates the number of requests pending in the batch channel. - TiKVPendingBatchRequests = prometheus.NewGauge( + TiKVPendingBatchRequests = prometheus.NewGaugeVec( prometheus.GaugeOpts{ Namespace: "tidb", Subsystem: "tikvclient", Name: "pending_batch_requests", Help: "Pending batch requests", - }) + }, []string{"store"}) TiKVBatchWaitDuration = prometheus.NewHistogram( prometheus.HistogramOpts{ @@ -223,4 +223,13 @@ var ( Buckets: prometheus.ExponentialBuckets(0.001, 2, 20), Help: "duration to push sub tasks to range task workers", }, []string{LblType}) + + TiKVTxnHeartBeatHistogram = prometheus.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: "tidb", + Subsystem: "tikvclient", + Name: "txn_heart_beat", + Help: "Bucketed histogram of the txn_heartbeat request duration.", + Buckets: prometheus.ExponentialBuckets(0.001, 2, 18), // 1ms ~ 292s + }, []string{LblType}) ) diff --git a/owner/fail_test.go b/owner/fail_test.go index 3b0ed1c1f2149..f224359049f90 100644 --- a/owner/fail_test.go +++ b/owner/fail_test.go @@ -75,6 +75,7 @@ func (s *testSuite) TestFailNewSession(c *C) { if cli != nil { cli.Close() } + c.Assert(failpoint.Disable("github.com/pingcap/tidb/owner/closeClient"), IsNil) }() c.Assert(failpoint.Enable("github.com/pingcap/tidb/owner/closeClient", `return(true)`), IsNil) _, err = NewSession(context.Background(), "fail_new_serssion", cli, retryCnt, ManagerSessionTTL) @@ -92,6 +93,7 @@ func (s *testSuite) TestFailNewSession(c *C) { if cli != nil { cli.Close() } + c.Assert(failpoint.Disable("github.com/pingcap/tidb/owner/closeGrpc"), IsNil) }() c.Assert(failpoint.Enable("github.com/pingcap/tidb/owner/closeGrpc", `return(true)`), IsNil) _, err = NewSession(context.Background(), "fail_new_serssion", cli, retryCnt, ManagerSessionTTL) diff --git a/owner/manager.go b/owner/manager.go index 5f738d11a9328..bd21159c526f1 100644 --- a/owner/manager.go +++ b/owner/manager.go @@ -245,6 +245,7 @@ func (m *ownerManager) campaignLoop(ctx context.Context, etcdSession *concurrenc return } case <-ctx.Done(): + logutil.Logger(logCtx).Info("break campaign loop, context is done") m.revokeSession(logPrefix, etcdSession.Lease()) return default: @@ -288,7 +289,7 @@ func (m *ownerManager) revokeSession(logPrefix string, leaseID clientv3.LeaseID) time.Duration(ManagerSessionTTL)*time.Second) _, err := m.etcdCli.Revoke(cancelCtx, leaseID) cancel() - logutil.Logger(m.logCtx).Info("break campaign loop, revoke err", zap.Error(err)) + logutil.Logger(m.logCtx).Info("revoke session", zap.Error(err)) } // GetOwnerID implements Manager.GetOwnerID interface. diff --git a/owner/manager_test.go b/owner/manager_test.go new file mode 100644 index 0000000000000..a83e4dc699352 --- /dev/null +++ b/owner/manager_test.go @@ -0,0 +1,175 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package owner_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + "github.com/coreos/etcd/integration" + "github.com/pingcap/errors" + "github.com/pingcap/parser/terror" + . "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/owner" + "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/util/logutil" + goctx "golang.org/x/net/context" +) + +const testLease = 5 * time.Millisecond + +func checkOwner(d DDL, fbVal bool) (isOwner bool) { + manager := d.OwnerManager() + // The longest to wait for 3 seconds to + // make sure that campaigning owners is completed. + for i := 0; i < 600; i++ { + time.Sleep(5 * time.Millisecond) + isOwner = manager.IsOwner() + if isOwner == fbVal { + break + } + } + return +} + +func TestSingle(t *testing.T) { + store, err := mockstore.NewMockTikvStore() + if err != nil { + t.Fatal(err) + } + defer store.Close() + + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 1}) + defer clus.Terminate(t) + cli := clus.RandClient() + ctx := goctx.Background() + d := NewDDL(ctx, cli, store, nil, nil, testLease, nil) + defer d.Stop() + + isOwner := checkOwner(d, true) + if !isOwner { + t.Fatalf("expect true, got isOwner:%v", isOwner) + } + + // test for newSession failed + ctx, cancel := goctx.WithCancel(ctx) + cancel() + manager := owner.NewOwnerManager(cli, "ddl", "ddl_id", DDLOwnerKey, nil) + err = manager.CampaignOwner(ctx) + if !terror.ErrorEqual(err, goctx.Canceled) && + !terror.ErrorEqual(err, goctx.DeadlineExceeded) { + t.Fatalf("campaigned result don't match, err %v", err) + } + isOwner = checkOwner(d, true) + if !isOwner { + t.Fatalf("expect true, got isOwner:%v", isOwner) + } + // The test is used to exit campaign loop. + d.OwnerManager().Cancel() + isOwner = checkOwner(d, false) + if isOwner { + t.Fatalf("expect false, got isOwner:%v", isOwner) + } + time.Sleep(10 * time.Millisecond) + ownerID, _ := manager.GetOwnerID(goctx.Background()) + // The error is ok to be not nil since we canceled the manager. + if ownerID != "" { + t.Fatalf("owner %s is not empty", ownerID) + } +} + +func TestCluster(t *testing.T) { + tmpTTL := 3 + orignalTTL := owner.ManagerSessionTTL + owner.ManagerSessionTTL = tmpTTL + defer func() { + owner.ManagerSessionTTL = orignalTTL + }() + store, err := mockstore.NewMockTikvStore() + if err != nil { + t.Fatal(err) + } + defer store.Close() + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 4}) + defer clus.Terminate(t) + + cli := clus.Client(0) + d := NewDDL(goctx.Background(), cli, store, nil, nil, testLease, nil) + isOwner := checkOwner(d, true) + if !isOwner { + t.Fatalf("expect true, got isOwner:%v", isOwner) + } + cli1 := clus.Client(1) + d1 := NewDDL(goctx.Background(), cli1, store, nil, nil, testLease, nil) + isOwner = checkOwner(d1, false) + if isOwner { + t.Fatalf("expect false, got isOwner:%v", isOwner) + } + + // Delete the leader key, the d1 become the owner. + cliRW := clus.Client(2) + err = deleteLeader(cliRW, DDLOwnerKey) + if err != nil { + t.Fatal(err) + } + isOwner = checkOwner(d, false) + if isOwner { + t.Fatalf("expect false, got isOwner:%v", isOwner) + } + d.Stop() + + // d3 (not owner) stop + cli3 := clus.Client(3) + d3 := NewDDL(goctx.Background(), cli3, store, nil, nil, testLease, nil) + defer d3.Stop() + isOwner = checkOwner(d3, false) + if isOwner { + t.Fatalf("expect false, got isOwner:%v", isOwner) + } + d3.Stop() + + // Cancel the owner context, there is no owner. + d1.Stop() + time.Sleep(time.Duration(tmpTTL+1) * time.Second) + session, err := concurrency.NewSession(cliRW) + if err != nil { + t.Fatalf("new session failed %v", err) + } + elec := concurrency.NewElection(session, DDLOwnerKey) + logPrefix := fmt.Sprintf("[ddl] %s ownerManager %s", DDLOwnerKey, "useless id") + logCtx := logutil.WithKeyValue(context.Background(), "owner info", logPrefix) + _, err = owner.GetOwnerInfo(goctx.Background(), logCtx, elec, "useless id") + if !terror.ErrorEqual(err, concurrency.ErrElectionNoLeader) { + t.Fatalf("get owner info result don't match, err %v", err) + } +} + +func deleteLeader(cli *clientv3.Client, prefixKey string) error { + session, err := concurrency.NewSession(cli) + if err != nil { + return errors.Trace(err) + } + defer session.Close() + elec := concurrency.NewElection(session, prefixKey) + resp, err := elec.Leader(goctx.Background()) + if err != nil { + return errors.Trace(err) + } + _, err = cli.Delete(goctx.Background(), string(resp.Kvs[0].Key)) + return errors.Trace(err) +} diff --git a/planner/cascades/optimize_test.go b/planner/cascades/optimize_test.go index e0f9660c78238..3ea1cec66e580 100644 --- a/planner/cascades/optimize_test.go +++ b/planner/cascades/optimize_test.go @@ -14,6 +14,7 @@ package cascades import ( + "context" "math" "testing" @@ -42,7 +43,7 @@ type testCascadesSuite struct { func (s *testCascadesSuite) SetUpSuite(c *C) { testleak.BeforeTest() - s.is = infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockTable()}) + s.is = infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockSignedTable()}) s.sctx = plannercore.MockContext() s.Parser = parser.New() } @@ -54,7 +55,7 @@ func (s *testCascadesSuite) TearDownSuite(c *C) { func (s *testCascadesSuite) TestImplGroupZeroCost(c *C) { stmt, err := s.ParseOneStmt("select t1.a, t2.a from t as t1 left join t as t2 on t1.a = t2.a where t1.a < 1.0", "", "") c.Assert(err, IsNil) - p, err := plannercore.BuildLogicalPlan(s.sctx, stmt, s.is) + p, err := plannercore.BuildLogicalPlan(context.Background(), s.sctx, stmt, s.is) c.Assert(err, IsNil) logic, ok := p.(plannercore.LogicalPlan) c.Assert(ok, IsTrue) @@ -70,7 +71,7 @@ func (s *testCascadesSuite) TestImplGroupZeroCost(c *C) { func (s *testCascadesSuite) TestInitGroupSchema(c *C) { stmt, err := s.ParseOneStmt("select a from t", "", "") c.Assert(err, IsNil) - p, err := plannercore.BuildLogicalPlan(s.sctx, stmt, s.is) + p, err := plannercore.BuildLogicalPlan(context.Background(), s.sctx, stmt, s.is) c.Assert(err, IsNil) logic, ok := p.(plannercore.LogicalPlan) c.Assert(ok, IsTrue) @@ -84,7 +85,7 @@ func (s *testCascadesSuite) TestInitGroupSchema(c *C) { func (s *testCascadesSuite) TestFillGroupStats(c *C) { stmt, err := s.ParseOneStmt("select * from t t1 join t t2 on t1.a = t2.a", "", "") c.Assert(err, IsNil) - p, err := plannercore.BuildLogicalPlan(s.sctx, stmt, s.is) + p, err := plannercore.BuildLogicalPlan(context.Background(), s.sctx, stmt, s.is) c.Assert(err, IsNil) logic, ok := p.(plannercore.LogicalPlan) c.Assert(ok, IsTrue) diff --git a/planner/core/cacheable_checker.go b/planner/core/cacheable_checker.go index 49e08eb8227b1..f4ef4f9c22b6d 100644 --- a/planner/core/cacheable_checker.go +++ b/planner/core/cacheable_checker.go @@ -82,6 +82,11 @@ func (checker *cacheableChecker) Enter(in ast.Node) (out ast.Node, skipChildren return in, true } } + case *ast.FrameBound: + if _, ok := node.Expr.(*driver.ParamMarkerExpr); ok { + checker.cacheable = false + return in, true + } } return in, false } diff --git a/planner/core/cacheable_checker_test.go b/planner/core/cacheable_checker_test.go index 8f3d287701533..6d195c25f4b27 100644 --- a/planner/core/cacheable_checker_test.go +++ b/planner/core/cacheable_checker_test.go @@ -191,4 +191,7 @@ func (s *testCacheableSuite) TestCacheable(c *C) { OrderBy: orderByClause, } c.Assert(Cacheable(stmt), IsTrue) + + boundExpr := &ast.FrameBound{Expr: &driver.ParamMarkerExpr{}} + c.Assert(Cacheable(boundExpr), IsFalse) } diff --git a/planner/core/cbo_test.go b/planner/core/cbo_test.go index 8dca13b7215d2..fff0ef7c17037 100644 --- a/planner/core/cbo_test.go +++ b/planner/core/cbo_test.go @@ -14,6 +14,7 @@ package core_test import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -35,11 +36,23 @@ import ( "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" + "github.com/pingcap/tidb/util/testutil" ) var _ = Suite(&testAnalyzeSuite{}) type testAnalyzeSuite struct { + testData testutil.TestData +} + +func (s *testAnalyzeSuite) SetUpSuite(c *C) { + var err error + s.testData, err = testutil.LoadTestSuiteData("testdata", "analyze_suite") + c.Assert(err, IsNil) +} + +func (s *testAnalyzeSuite) TearDownSuite(c *C) { + c.Assert(s.testData.GenerateOutputIfNeeded(), IsNil) } func (s *testAnalyzeSuite) loadTableStats(fileName string, dom *domain.Domain) error { @@ -80,7 +93,7 @@ func (s *testAnalyzeSuite) TestExplainAnalyze(c *C) { rs := tk.MustQuery("explain analyze select t1.a, t1.b, sum(t1.c) from t1 join t2 on t1.a = t2.b where t1.a > 1") c.Assert(len(rs.Rows()), Equals, 10) for _, row := range rs.Rows() { - c.Assert(len(row), Equals, 5) + c.Assert(len(row), Equals, 6) execInfo := row[4].(string) c.Assert(strings.Contains(execInfo, "time"), Equals, true) c.Assert(strings.Contains(execInfo, "loops"), Equals, true) @@ -106,7 +119,7 @@ func (s *testAnalyzeSuite) TestCBOWithoutAnalyze(c *C) { c.Assert(h.HandleDDLEvent(<-h.DDLEventCh()), IsNil) testKit.MustExec("insert into t1 values (1), (2), (3), (4), (5), (6)") testKit.MustExec("insert into t2 values (1), (2), (3), (4), (5), (6)") - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(dom.InfoSchema()), IsNil) testKit.MustQuery("explain select * from t1, t2 where t1.a = t2.a").Check(testkit.Rows( "HashLeftJoin_8 7.49 root inner join, inner:TableReader_15, equal:[eq(test.t1.a, test.t2.a)]", @@ -136,9 +149,9 @@ func (s *testAnalyzeSuite) TestStraightJoin(c *C) { } testKit.MustQuery("explain select straight_join * from t1, t2, t3, t4").Check(testkit.Rows( - "HashLeftJoin_10 10000000000000000.00 root inner join, inner:TableReader_23", - "├─HashLeftJoin_12 1000000000000.00 root inner join, inner:TableReader_21", - "│ ├─HashLeftJoin_14 100000000.00 root inner join, inner:TableReader_19", + "HashLeftJoin_10 10000000000000000.00 root CARTESIAN inner join, inner:TableReader_23", + "├─HashLeftJoin_12 1000000000000.00 root CARTESIAN inner join, inner:TableReader_21", + "│ ├─HashLeftJoin_14 100000000.00 root CARTESIAN inner join, inner:TableReader_19", "│ │ ├─TableReader_17 10000.00 root data:TableScan_16", "│ │ │ └─TableScan_16 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "│ │ └─TableReader_19 10000.00 root data:TableScan_18", @@ -150,9 +163,9 @@ func (s *testAnalyzeSuite) TestStraightJoin(c *C) { )) testKit.MustQuery("explain select * from t1 straight_join t2 straight_join t3 straight_join t4").Check(testkit.Rows( - "HashLeftJoin_10 10000000000000000.00 root inner join, inner:TableReader_23", - "├─HashLeftJoin_12 1000000000000.00 root inner join, inner:TableReader_21", - "│ ├─HashLeftJoin_14 100000000.00 root inner join, inner:TableReader_19", + "HashLeftJoin_10 10000000000000000.00 root CARTESIAN inner join, inner:TableReader_23", + "├─HashLeftJoin_12 1000000000000.00 root CARTESIAN inner join, inner:TableReader_21", + "│ ├─HashLeftJoin_14 100000000.00 root CARTESIAN inner join, inner:TableReader_19", "│ │ ├─TableReader_17 10000.00 root data:TableScan_16", "│ │ │ └─TableScan_16 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", "│ │ └─TableReader_19 10000.00 root data:TableScan_18", @@ -165,8 +178,8 @@ func (s *testAnalyzeSuite) TestStraightJoin(c *C) { testKit.MustQuery("explain select straight_join * from t1, t2, t3, t4 where t1.a=t4.a;").Check(testkit.Rows( "HashLeftJoin_11 1248750000000.00 root inner join, inner:TableReader_26, equal:[eq(test.t1.a, test.t4.a)]", - "├─HashLeftJoin_13 999000000000.00 root inner join, inner:TableReader_23", - "│ ├─HashRightJoin_16 99900000.00 root inner join, inner:TableReader_19", + "├─HashLeftJoin_13 999000000000.00 root CARTESIAN inner join, inner:TableReader_23", + "│ ├─HashRightJoin_16 99900000.00 root CARTESIAN inner join, inner:TableReader_19", "│ │ ├─TableReader_19 9990.00 root data:Selection_18", "│ │ │ └─Selection_18 9990.00 cop not(isnull(test.t1.a))", "│ │ │ └─TableScan_17 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", @@ -196,7 +209,7 @@ func (s *testAnalyzeSuite) TestTableDual(c *C) { testKit.MustExec("insert into t values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)") c.Assert(h.HandleDDLEvent(<-h.DDLEventCh()), IsNil) - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(dom.InfoSchema()), IsNil) testKit.MustQuery(`explain select * from t where 1 = 0`).Check(testkit.Rows( @@ -226,12 +239,12 @@ func (s *testAnalyzeSuite) TestEstimation(c *C) { testKit.MustExec("insert into t select * from t") h := dom.StatsHandle() h.HandleDDLEvent(<-h.DDLEventCh()) - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) testKit.MustExec("analyze table t") for i := 1; i <= 8; i++ { testKit.MustExec("delete from t where a = ?", i) } - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(dom.InfoSchema()), IsNil) testKit.MustQuery("explain select count(*) from t group by a").Check(testkit.Rows( "HashAgg_9 2.00 root group by:col_1, funcs:count(col_0)", @@ -380,7 +393,7 @@ func (s *testAnalyzeSuite) TestIndexRead(c *C) { is := domain.GetDomain(ctx).InfoSchema() err = core.Preprocess(ctx, stmt, is) c.Assert(err, IsNil) - p, err := planner.Optimize(ctx, stmt, is) + p, err := planner.Optimize(context.TODO(), ctx, stmt, is) c.Assert(err, IsNil) c.Assert(core.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) } @@ -430,7 +443,7 @@ func (s *testAnalyzeSuite) TestEmptyTable(c *C) { is := domain.GetDomain(ctx).InfoSchema() err = core.Preprocess(ctx, stmt, is) c.Assert(err, IsNil) - p, err := planner.Optimize(ctx, stmt, is) + p, err := planner.Optimize(context.TODO(), ctx, stmt, is) c.Assert(err, IsNil) c.Assert(core.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) } @@ -546,7 +559,7 @@ func (s *testAnalyzeSuite) TestAnalyze(c *C) { is := domain.GetDomain(ctx).InfoSchema() err = core.Preprocess(ctx, stmt, is) c.Assert(err, IsNil) - p, err := planner.Optimize(ctx, stmt, is) + p, err := planner.Optimize(context.TODO(), ctx, stmt, is) c.Assert(err, IsNil) c.Assert(core.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) } @@ -568,12 +581,12 @@ func (s *testAnalyzeSuite) TestOutdatedAnalyze(c *C) { } h := dom.StatsHandle() h.HandleDDLEvent(<-h.DDLEventCh()) - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) testKit.MustExec("analyze table t") testKit.MustExec("insert into t select * from t") testKit.MustExec("insert into t select * from t") testKit.MustExec("insert into t select * from t") - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(dom.InfoSchema()), IsNil) statistics.RatioOfPseudoEstimate.Store(10.0) testKit.MustQuery("explain select * from t where a <= 5 and b <= 5").Check(testkit.Rows( @@ -623,7 +636,7 @@ func (s *testAnalyzeSuite) TestPreparedNullParam(c *C) { is := domain.GetDomain(ctx).InfoSchema() err = core.Preprocess(ctx, stmt, is, core.InPrepare) c.Assert(err, IsNil) - p, err := planner.Optimize(ctx, stmt, is) + p, err := planner.Optimize(context.TODO(), ctx, stmt, is) c.Assert(err, IsNil) c.Assert(core.ToString(p), Equals, best, Commentf("for %s", sql)) @@ -658,8 +671,6 @@ func (s *testAnalyzeSuite) TestNullCount(c *C) { )) h := dom.StatsHandle() h.Clear() - h.Lease = 1 - defer func() { h.Lease = 0 }() c.Assert(h.Update(dom.InfoSchema()), IsNil) testKit.MustQuery("explain select * from t where b = 1").Check(testkit.Rows( "TableReader_7 0.00 root data:Selection_6", @@ -690,7 +701,7 @@ func (s *testAnalyzeSuite) TestCorrelatedEstimation(c *C) { tk.MustQuery("explain select t.c in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t;"). Check(testkit.Rows( "Projection_11 10.00 root 9_aux_0", - "└─Apply_13 10.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0)", + "└─Apply_13 10.00 root CARTESIAN left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.c, 7_col_0)", " ├─TableReader_15 10.00 root data:TableScan_14", " │ └─TableScan_14 10.00 cop table:t, range:[-inf,+inf], keep order:false", " └─StreamAgg_20 1.00 root funcs:count(1)", @@ -705,7 +716,7 @@ func (s *testAnalyzeSuite) TestCorrelatedEstimation(c *C) { tk.MustQuery("explain select (select concat(t1.a, \",\", t1.b) from t t1 where t1.a=t.a and t1.c=t.c) from t"). Check(testkit.Rows( "Projection_8 10.00 root concat(t1.a, \",\", t1.b)", - "└─Apply_10 10.00 root left outer join, inner:MaxOneRow_13", + "└─Apply_10 10.00 root CARTESIAN left outer join, inner:MaxOneRow_13", " ├─TableReader_12 10.00 root data:TableScan_11", " │ └─TableScan_11 10.00 cop table:t, range:[-inf,+inf], keep order:false", " └─MaxOneRow_13 1.00 root ", @@ -755,7 +766,7 @@ func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) { } session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() dom, err := session.BootstrapSession(store) if err != nil { @@ -881,7 +892,7 @@ func BenchmarkOptimize(b *testing.B) { b.Run(tt.sql, func(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := planner.Optimize(ctx, stmt, is) + _, err := planner.Optimize(context.TODO(), ctx, stmt, is) c.Assert(err, IsNil) } b.ReportAllocs() @@ -908,21 +919,21 @@ func (s *testAnalyzeSuite) TestIssue9562(c *C) { "├─TableReader_12 9980.01 root data:Selection_11", "│ └─Selection_11 9980.01 cop not(isnull(test.t1.a)), not(isnull(test.t1.c))", "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", - "└─IndexReader_8 0.00 root index:Selection_7", - " └─Selection_7 0.00 cop not(isnull(test.t2.a)), not(isnull(test.t2.c))", + "└─IndexReader_8 9.98 root index:Selection_7", + " └─Selection_7 9.98 cop not(isnull(test.t2.a)), not(isnull(test.t2.c))", " └─IndexScan_6 10.00 cop table:t2, index:a, b, c, range: decided by [eq(test.t2.a, test.t1.a) gt(test.t2.b, minus(test.t1.b, 1)) lt(test.t2.b, plus(test.t1.b, 1))], keep order:false, stats:pseudo", )) tk.MustExec("create table t(a int, b int, index idx_ab(a, b))") tk.MustQuery("explain select * from t t1 join t t2 where t1.b = t2.b and t2.b is null").Check(testkit.Rows( "Projection_7 0.00 root test.t1.a, test.t1.b, test.t2.a, test.t2.b", - "└─HashRightJoin_9 0.00 root inner join, inner:TableReader_12, equal:[eq(test.t2.b, test.t1.b)]", - " ├─TableReader_12 0.00 root data:Selection_11", + "└─HashRightJoin_9 0.00 root inner join, inner:IndexReader_12, equal:[eq(test.t2.b, test.t1.b)]", + " ├─IndexReader_12 0.00 root index:Selection_11", " │ └─Selection_11 0.00 cop isnull(test.t2.b), not(isnull(test.t2.b))", - " │ └─TableScan_10 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo", - " └─TableReader_15 9990.00 root data:Selection_14", + " │ └─IndexScan_10 10000.00 cop table:t2, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo", + " └─IndexReader_15 9990.00 root index:Selection_14", " └─Selection_14 9990.00 cop not(isnull(test.t1.b))", - " └─TableScan_13 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + " └─IndexScan_13 10000.00 cop table:t1, index:a, b, range:[NULL,+inf], keep order:false, stats:pseudo", )) } @@ -979,7 +990,7 @@ func (s *testAnalyzeSuite) TestIssue9805(c *C) { c.Assert(rs.Rows(), HasLen, 10) hasIndexLookUp12 := false for _, row := range rs.Rows() { - c.Assert(row, HasLen, 5) + c.Assert(row, HasLen, 6) if strings.HasSuffix(row[0].(string), "IndexLookUp_12") { hasIndexLookUp12 = true c.Assert(row[4], Equals, "time:0ns, loops:0, rows:0") @@ -1000,102 +1011,42 @@ func (s *testAnalyzeSuite) TestLimitCrossEstimation(c *C) { tk.MustExec("use test") tk.MustExec("drop table if exists t") - tk.MustExec("create table t(a int primary key, b int not null, index idx_b(b))") - // Pseudo stats. - tk.MustQuery("EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1;").Check(testkit.Rows( - "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", - "└─IndexReader_16 1.00 root index:TopN_15", - " └─TopN_15 1.00 cop test.t.a:asc, offset:0, count:1", - " └─IndexScan_14 10.00 cop table:t, index:b, range:[2,2], keep order:false, stats:pseudo", - )) - // Positive correlation. - tk.MustExec("insert into t values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 1),(8, 1),(9, 1),(10, 1),(11, 1),(12, 1),(13, 1),(14, 1),(15, 1),(16, 1),(17, 1),(18, 1),(19, 1),(20, 2),(21, 2),(22, 2),(23, 2),(24, 2),(25, 2)") - tk.MustExec("analyze table t") - tk.MustQuery("EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1;").Check(testkit.Rows( - "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", - "└─IndexReader_16 1.00 root index:TopN_15", - " └─TopN_15 1.00 cop test.t.a:asc, offset:0, count:1", - " └─IndexScan_14 6.00 cop table:t, index:b, range:[2,2], keep order:false", - )) - // Negative correlation. - tk.MustExec("truncate table t") - tk.MustExec("insert into t values (1, 25),(2, 24),(3, 23),(4, 23),(5, 21),(6, 20),(7, 19),(8, 18),(9, 17),(10, 16),(11, 15),(12, 14),(13, 13),(14, 12),(15, 11),(16, 10),(17, 9),(18, 8),(19, 7),(20, 6),(21, 5),(22, 4),(23, 3),(24, 2),(25, 1)") - tk.MustExec("analyze table t") - tk.MustQuery("EXPLAIN SELECT * FROM t WHERE b <= 6 ORDER BY a limit 1").Check(testkit.Rows( - "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", - "└─IndexReader_16 1.00 root index:TopN_15", - " └─TopN_15 1.00 cop test.t.a:asc, offset:0, count:1", - " └─IndexScan_14 6.00 cop table:t, index:b, range:[-inf,6], keep order:false", - )) - // Outer plan of index join (to test that correct column ID is used). - tk.MustQuery("EXPLAIN SELECT *, t1.a IN (SELECT t2.b FROM t t2) FROM t t1 WHERE t1.b <= 6 ORDER BY t1.a limit 1").Check(testkit.Rows( - "Limit_17 1.00 root offset:0, count:1", - "└─IndexJoin_58 1.00 root left outer semi join, inner:IndexReader_57, outer key:test.t1.a, inner key:test.t2.b", - " ├─TopN_23 1.00 root test.t1.a:asc, offset:0, count:1", - " │ └─IndexReader_31 1.00 root index:TopN_30", - " │ └─TopN_30 1.00 cop test.t1.a:asc, offset:0, count:1", - " │ └─IndexScan_29 6.00 cop table:t1, index:b, range:[-inf,6], keep order:false", - " └─IndexReader_57 1.04 root index:IndexScan_56", - " └─IndexScan_56 1.04 cop table:t2, index:b, range: decided by [eq(test.t2.b, test.t1.a)], keep order:false", - )) - // Desc TableScan. - tk.MustExec("truncate table t") - tk.MustExec("insert into t values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 2),(8, 2),(9, 2),(10, 2),(11, 2),(12, 2),(13, 2),(14, 2),(15, 2),(16, 2),(17, 2),(18, 2),(19, 2),(20, 2),(21, 2),(22, 2),(23, 2),(24, 2),(25, 2)") - tk.MustExec("analyze table t") - tk.MustQuery("EXPLAIN SELECT * FROM t WHERE b = 1 ORDER BY a desc limit 1").Check(testkit.Rows( - "TopN_8 1.00 root test.t.a:desc, offset:0, count:1", - "└─IndexReader_16 1.00 root index:TopN_15", - " └─TopN_15 1.00 cop test.t.a:desc, offset:0, count:1", - " └─IndexScan_14 6.00 cop table:t, index:b, range:[1,1], keep order:false", - )) - // Correlation threshold not met. - tk.MustExec("truncate table t") - tk.MustExec("insert into t values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 1),(8, 1),(9, 2),(10, 1),(11, 1),(12, 1),(13, 1),(14, 2),(15, 2),(16, 1),(17, 2),(18, 1),(19, 2),(20, 1),(21, 2),(22, 1),(23, 1),(24, 1),(25, 1)") - tk.MustExec("analyze table t") - tk.MustQuery("EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1").Check(testkit.Rows( - "Limit_11 1.00 root offset:0, count:1", - "└─TableReader_22 1.00 root data:Limit_21", - " └─Limit_21 1.00 cop offset:0, count:1", - " └─Selection_20 1.00 cop eq(test.t.b, 2)", - " └─TableScan_19 4.17 cop table:t, range:[-inf,+inf], keep order:true", - )) - tk.MustExec("set @@tidb_opt_correlation_exp_factor = 1") - tk.MustQuery("EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1").Check(testkit.Rows( - "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", - "└─IndexReader_16 1.00 root index:TopN_15", - " └─TopN_15 1.00 cop test.t.a:asc, offset:0, count:1", - " └─IndexScan_14 6.00 cop table:t, index:b, range:[2,2], keep order:false", - )) - tk.MustExec("set @@tidb_opt_correlation_exp_factor = 0") - // TableScan has access conditions, but correlation is 1. - tk.MustExec("truncate table t") - tk.MustExec("insert into t values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 1),(8, 1),(9, 1),(10, 1),(11, 1),(12, 1),(13, 1),(14, 1),(15, 1),(16, 1),(17, 1),(18, 1),(19, 1),(20, 2),(21, 2),(22, 2),(23, 2),(24, 2),(25, 2)") - tk.MustExec("analyze table t") - tk.MustQuery("EXPLAIN SELECT * FROM t WHERE b = 2 and a > 0 ORDER BY a limit 1").Check(testkit.Rows( - "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", - "└─IndexReader_19 1.00 root index:TopN_18", - " └─TopN_18 1.00 cop test.t.a:asc, offset:0, count:1", - " └─Selection_17 6.00 cop gt(test.t.a, 0)", - " └─IndexScan_16 6.00 cop table:t, index:b, range:[2,2], keep order:false", - )) - // Multi-column filter. - tk.MustExec("drop table t") - tk.MustExec("create table t(a int primary key, b int, c int, index idx_b(b))") - tk.MustExec("insert into t values (1, 1, 1),(2, 1, 2),(3, 1, 1),(4, 1, 2),(5, 1, 1),(6, 1, 2),(7, 1, 1),(8, 1, 2),(9, 1, 1),(10, 1, 2),(11, 1, 1),(12, 1, 2),(13, 1, 1),(14, 1, 2),(15, 1, 1),(16, 1, 2),(17, 1, 1),(18, 1, 2),(19, 1, 1),(20, 2, 2),(21, 2, 1),(22, 2, 2),(23, 2, 1),(24, 2, 2),(25, 2, 1)") - tk.MustExec("analyze table t") - tk.MustQuery("EXPLAIN SELECT * FROM t WHERE b = 2 and c > 0 ORDER BY a limit 1").Check(testkit.Rows( - "TopN_9 1.00 root test.t.a:asc, offset:0, count:1", - "└─IndexLookUp_22 1.00 root ", - " ├─IndexScan_18 6.00 cop table:t, index:b, range:[2,2], keep order:false", - " └─TopN_21 1.00 cop test.t.a:asc, offset:0, count:1", - " └─Selection_20 6.00 cop gt(test.t.c, 0)", - " └─TableScan_19 6.00 cop table:t, keep order:false", - )) - tk.MustQuery("EXPLAIN SELECT * FROM t WHERE b = 2 or c > 0 ORDER BY a limit 1").Check(testkit.Rows( - "Limit_11 1.00 root offset:0, count:1", - "└─TableReader_24 1.00 root data:Limit_23", - " └─Limit_23 1.00 cop offset:0, count:1", - " └─Selection_22 1.00 cop or(eq(test.t.b, 2), gt(test.t.c, 0))", - " └─TableScan_21 1.25 cop table:t, range:[-inf,+inf], keep order:true", - )) + tk.MustExec("create table t(a int primary key, b int not null, c int not null default 0, index idx_bc(b, c))") + var input [][]string + var output []struct { + SQL []string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, ts := range input { + for j, tt := range ts { + if j != len(ts)-1 { + tk.MustExec(tt) + } + s.testData.OnRecord(func() { + output[i].SQL = ts + if j == len(ts)-1 { + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + } + }) + if j == len(ts)-1 { + tk.MustQuery(tt).Check(testkit.Rows(output[i].Plan...)) + } + } + } +} + +func (s *testAnalyzeSuite) TestUpdateProjEliminate(c *C) { + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("explain update t t1, (select distinct b from t) t2 set t1.b = t2.b") } diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index ed0ad362ac43e..1117e6e0232f9 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -15,6 +15,7 @@ package core import ( "bytes" + "context" "fmt" "strconv" "strings" @@ -34,6 +35,7 @@ import ( "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/kvcache" "github.com/pingcap/tidb/util/ranger" + "github.com/pingcap/tidb/util/texttree" ) var planCacheCounter = metrics.PlanCacheCounter.WithLabelValues("prepare") @@ -74,9 +76,10 @@ type ShowNextRowID struct { type CheckTable struct { baseSchemaProducer - Tables []*ast.TableName - - GenExprs map[model.TableColumnID]expression.Expression + DBName string + Table table.Table + IndexInfos []*model.IndexInfo + IndexLookUpReaders []*PhysicalIndexLookUpReader } // RecoverIndex is used for backfilling corrupted index data. @@ -128,6 +131,33 @@ type CancelDDLJobs struct { JobIDs []int64 } +// ReloadExprPushdownBlacklist reloads the data from expr_pushdown_blacklist table. +type ReloadExprPushdownBlacklist struct { + baseSchemaProducer +} + +// ReloadOptRuleBlacklist reloads the data from opt_rule_blacklist table. +type ReloadOptRuleBlacklist struct { + baseSchemaProducer +} + +// AdminPluginsAction indicate action will be taken on plugins. +type AdminPluginsAction int + +const ( + // Enable indicates enable plugins. + Enable AdminPluginsAction = iota + 1 + // Disable indicates disable plugins. + Disable +) + +// AdminPlugins administrates tidb plugins. +type AdminPlugins struct { + baseSchemaProducer + Action AdminPluginsAction + Plugins []string +} + // Change represents a change plan. type Change struct { baseSchemaProducer @@ -150,12 +180,13 @@ type Execute struct { UsingVars []expression.Expression ExecID uint32 Stmt ast.StmtNode + StmtType string Plan Plan } // OptimizePreparedPlan optimizes the prepared statement. -func (e *Execute) OptimizePreparedPlan(ctx sessionctx.Context, is infoschema.InfoSchema) error { - vars := ctx.GetSessionVars() +func (e *Execute) OptimizePreparedPlan(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema) error { + vars := sctx.GetSessionVars() if e.Name != "" { e.ExecID = vars.PreparedStmtNameToID[e.Name] } @@ -163,6 +194,7 @@ func (e *Execute) OptimizePreparedPlan(ctx sessionctx.Context, is infoschema.Inf if !ok { return errors.Trace(ErrStmtNotFound) } + vars.StmtCtx.StmtType = prepared.StmtType if len(prepared.Params) != len(e.UsingVars) { return errors.Trace(ErrWrongParamCount) @@ -181,13 +213,13 @@ func (e *Execute) OptimizePreparedPlan(ctx sessionctx.Context, is infoschema.Inf if prepared.SchemaVersion != is.SchemaMetaVersion() { // If the schema version has changed we need to preprocess it again, // if this time it failed, the real reason for the error is schema changed. - err := Preprocess(ctx, prepared.Stmt, is, InPrepare) + err := Preprocess(sctx, prepared.Stmt, is, InPrepare) if err != nil { return ErrSchemaChanged.GenWithStack("Schema change caused error: %s", err.Error()) } prepared.SchemaVersion = is.SchemaMetaVersion() } - p, err := e.getPhysicalPlan(ctx, is, prepared) + p, err := e.getPhysicalPlan(ctx, sctx, is, prepared) if err != nil { return err } @@ -196,13 +228,13 @@ func (e *Execute) OptimizePreparedPlan(ctx sessionctx.Context, is infoschema.Inf return nil } -func (e *Execute) getPhysicalPlan(ctx sessionctx.Context, is infoschema.InfoSchema, prepared *ast.Prepared) (Plan, error) { +func (e *Execute) getPhysicalPlan(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema, prepared *ast.Prepared) (Plan, error) { var cacheKey kvcache.Key - sessionVars := ctx.GetSessionVars() + sessionVars := sctx.GetSessionVars() sessionVars.StmtCtx.UseCache = prepared.UseCache if prepared.UseCache { cacheKey = NewPSTMTPlanCacheKey(sessionVars, e.ExecID, prepared.SchemaVersion) - if cacheValue, exists := ctx.PreparedPlanCache().Get(cacheKey); exists { + if cacheValue, exists := sctx.PreparedPlanCache().Get(cacheKey); exists { if metrics.ResettablePlanCacheCounterFortTest { metrics.PlanCacheCounter.WithLabelValues("prepare").Inc() } else { @@ -216,12 +248,13 @@ func (e *Execute) getPhysicalPlan(ctx sessionctx.Context, is infoschema.InfoSche return plan, nil } } - p, err := OptimizeAstNode(ctx, prepared.Stmt, is) + p, err := OptimizeAstNode(ctx, sctx, prepared.Stmt, is) if err != nil { return nil, err } - if prepared.UseCache { - ctx.PreparedPlanCache().Put(cacheKey, NewPSTMTPlanCacheValue(p)) + _, isTableDual := p.(*PhysicalTableDual) + if !isTableDual && prepared.UseCache { + sctx.PreparedPlanCache().Put(cacheKey, NewPSTMTPlanCacheValue(p)) } return p, err } @@ -317,20 +350,19 @@ type Deallocate struct { // Show represents a show plan. type Show struct { - baseSchemaProducer + physicalSchemaProducer Tp ast.ShowStmtType // Databases/Tables/Columns/.... DBName string Table *ast.TableName // Used for showing columns. Column *ast.ColumnName // Used for `desc table column`. - Flag int // Some flag parsed from sql, such as FULL. + IndexName model.CIStr + Flag int // Some flag parsed from sql, such as FULL. Full bool User *auth.UserIdentity // Used for show grants. Roles []*auth.RoleIdentity // Used for show grants. IfNotExists bool // Used for `show create database if not exists` - Conditions []expression.Expression - GlobalScope bool // Used by show variables } @@ -478,15 +510,26 @@ type LoadStats struct { Path string } -// SplitIndexRegion represents a split index regions plan. -type SplitIndexRegion struct { +// SplitRegion represents a split regions plan. +type SplitRegion struct { baseSchemaProducer - Table table.Table + TableInfo *model.TableInfo IndexInfo *model.IndexInfo + Lower []types.Datum + Upper []types.Datum + Num int ValueLists [][]types.Datum } +// SplitRegionStatus represents a split regions status plan. +type SplitRegionStatus struct { + baseSchemaProducer + + Table table.Table + IndexInfo *model.IndexInfo +} + // DDL represents a DDL statement plan. type DDL struct { baseSchemaProducer @@ -513,7 +556,7 @@ func (e *Explain) prepareSchema() error { case ast.ExplainFormatROW: retFields := []string{"id", "count", "task", "operator info"} if e.Analyze { - retFields = append(retFields, "execution info") + retFields = append(retFields, "execution info", "memory") } schema := expression.NewSchema(make([]*expression.Column, 0, len(retFields))...) for _, fieldName := range retFields { @@ -556,7 +599,7 @@ func (e *Explain) explainPlanInRowFormat(p PhysicalPlan, taskType, indent string e.explainedPlans[p.ID()] = true // For every child we create a new sub-tree rooted by it. - childIndent := e.getIndent4Child(indent, isLastChild) + childIndent := texttree.Indent4Child(indent, isLastChild) for i, child := range p.Children() { if e.explainedPlans[child.ID()] { continue @@ -581,7 +624,7 @@ func (e *Explain) prepareOperatorInfo(p PhysicalPlan, taskType string, indent st operatorInfo := p.ExplainInfo() count := string(strconv.AppendFloat([]byte{}, p.statsInfo().RowCount, 'f', 2, 64)) explainID := p.ExplainID().String() - row := []string{e.prettyIdentifier(explainID, indent, isLastChild), count, taskType, operatorInfo} + row := []string{texttree.PrettyIdentifier(explainID, indent, isLastChild), count, taskType, operatorInfo} if e.Analyze { runtimeStatsColl := e.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl // There maybe some mock information for cop task to let runtimeStatsColl.Exists(p.ExplainID()) is true. @@ -593,72 +636,15 @@ func (e *Explain) prepareOperatorInfo(p PhysicalPlan, taskType string, indent st } else { row = append(row, "time:0ns, loops:0, rows:0") } - } - e.Rows = append(e.Rows, row) -} -const ( - // treeBody indicates the current operator sub-tree is not finished, still - // has child operators to be attached on. - treeBody = '│' - // treeMiddleNode indicates this operator is not the last child of the - // current sub-tree rooted by its parent. - treeMiddleNode = '├' - // treeLastNode indicates this operator is the last child of the current - // sub-tree rooted by its parent. - treeLastNode = '└' - // treeGap is used to represent the gap between the branches of the tree. - treeGap = ' ' - // treeNodeIdentifier is used to replace the treeGap once we need to attach - // a node to a sub-tree. - treeNodeIdentifier = '─' -) - -func (e *Explain) prettyIdentifier(id, indent string, isLastChild bool) string { - if len(indent) == 0 { - return id - } - - indentBytes := []rune(indent) - for i := len(indentBytes) - 1; i >= 0; i-- { - if indentBytes[i] != treeBody { - continue - } - - // Here we attach a new node to the current sub-tree by changing - // the closest treeBody to a: - // 1. treeLastNode, if this operator is the last child. - // 2. treeMiddleNode, if this operator is not the last child.. - if isLastChild { - indentBytes[i] = treeLastNode + tracker := e.ctx.GetSessionVars().StmtCtx.MemTracker.SearchTracker(p.ExplainID().String()) + if tracker != nil { + row = append(row, tracker.BytesToString(tracker.MaxConsumed())) } else { - indentBytes[i] = treeMiddleNode - } - break - } - - // Replace the treeGap between the treeBody and the node to a - // treeNodeIdentifier. - indentBytes[len(indentBytes)-1] = treeNodeIdentifier - return string(indentBytes) + id -} - -func (e *Explain) getIndent4Child(indent string, isLastChild bool) string { - if !isLastChild { - return string(append([]rune(indent), treeBody, treeGap)) - } - - // If the current node is the last node of the current operator tree, we - // need to end this sub-tree by changing the closest treeBody to a treeGap. - indentBytes := []rune(indent) - for i := len(indentBytes) - 1; i >= 0; i-- { - if indentBytes[i] == treeBody { - indentBytes[i] = treeGap - break + row = append(row, "N/A") } } - - return string(append(indentBytes, treeBody, treeGap)) + e.Rows = append(e.Rows, row) } func (e *Explain) prepareDotInfo(p PhysicalPlan) { diff --git a/planner/core/encode.go b/planner/core/encode.go new file mode 100644 index 0000000000000..4abb3ed9b7cd9 --- /dev/null +++ b/planner/core/encode.go @@ -0,0 +1,68 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "bytes" + "sync" + + "github.com/pingcap/tidb/util/plancodec" +) + +var encoderPool = sync.Pool{ + New: func() interface{} { + return &planEncoder{} + }, +} + +type planEncoder struct { + buf bytes.Buffer + encodedPlans map[int]bool +} + +// EncodePlan is used to encodePlan the plan to the plan tree with compressing. +func EncodePlan(p PhysicalPlan) string { + pn := encoderPool.Get().(*planEncoder) + defer encoderPool.Put(pn) + return pn.encodePlanTree(p) +} + +func (pn *planEncoder) encodePlanTree(p PhysicalPlan) string { + pn.encodedPlans = make(map[int]bool) + pn.buf.Reset() + pn.encodePlan(p, true, 0) + return plancodec.Compress(pn.buf.Bytes()) +} + +func (pn *planEncoder) encodePlan(p PhysicalPlan, isRoot bool, depth int) { + plancodec.EncodePlanNode(depth, p.ID(), p.TP(), isRoot, p.statsInfo().RowCount, p.ExplainInfo(), &pn.buf) + pn.encodedPlans[p.ID()] = true + + depth++ + for _, child := range p.Children() { + if pn.encodedPlans[child.ID()] { + continue + } + pn.encodePlan(child.(PhysicalPlan), isRoot, depth) + } + switch copPlan := p.(type) { + case *PhysicalTableReader: + pn.encodePlan(copPlan.tablePlan, false, depth) + case *PhysicalIndexReader: + pn.encodePlan(copPlan.indexPlan, false, depth) + case *PhysicalIndexLookUpReader: + pn.encodePlan(copPlan.indexPlan, false, depth) + pn.encodePlan(copPlan.tablePlan, false, depth) + } +} diff --git a/planner/core/errors.go b/planner/core/errors.go index 9facfedc6d240..72d75a9a5f47a 100644 --- a/planner/core/errors.go +++ b/planner/core/errors.go @@ -66,6 +66,7 @@ const ( codeDBaccessDenied = mysql.ErrDBaccessDenied codeTableaccessDenied = mysql.ErrTableaccessDenied codeSpecificAccessDenied = mysql.ErrSpecificAccessDenied + codeViewNoExplain = mysql.ErrViewNoExplain codeWindowFrameStartIllegal = mysql.ErrWindowFrameStartIllegal codeWindowFrameEndIllegal = mysql.ErrWindowFrameEndIllegal codeWindowFrameIllegal = mysql.ErrWindowFrameIllegal @@ -128,6 +129,7 @@ var ( ErrDBaccessDenied = terror.ClassOptimizer.New(mysql.ErrDBaccessDenied, mysql.MySQLErrName[mysql.ErrDBaccessDenied]) ErrTableaccessDenied = terror.ClassOptimizer.New(mysql.ErrTableaccessDenied, mysql.MySQLErrName[mysql.ErrTableaccessDenied]) ErrSpecificAccessDenied = terror.ClassOptimizer.New(mysql.ErrSpecificAccessDenied, mysql.MySQLErrName[mysql.ErrSpecificAccessDenied]) + ErrViewNoExplain = terror.ClassOptimizer.New(mysql.ErrViewNoExplain, mysql.MySQLErrName[mysql.ErrViewNoExplain]) ErrWindowFrameStartIllegal = terror.ClassOptimizer.New(codeWindowFrameStartIllegal, mysql.MySQLErrName[mysql.ErrWindowFrameStartIllegal]) ErrWindowFrameEndIllegal = terror.ClassOptimizer.New(codeWindowFrameEndIllegal, mysql.MySQLErrName[mysql.ErrWindowFrameEndIllegal]) ErrWindowFrameIllegal = terror.ClassOptimizer.New(codeWindowFrameIllegal, mysql.MySQLErrName[mysql.ErrWindowFrameIllegal]) @@ -183,6 +185,7 @@ func init() { codeDBaccessDenied: mysql.ErrDBaccessDenied, codeTableaccessDenied: mysql.ErrTableaccessDenied, codeSpecificAccessDenied: mysql.ErrSpecificAccessDenied, + codeViewNoExplain: mysql.ErrViewNoExplain, codeWindowFrameStartIllegal: mysql.ErrWindowFrameStartIllegal, codeWindowFrameEndIllegal: mysql.ErrWindowFrameEndIllegal, codeWindowFrameIllegal: mysql.ErrWindowFrameIllegal, diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index e1eecf0766540..25e6de5922a76 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -292,38 +292,17 @@ func (p *LogicalJoin) getHashJoin(prop *property.PhysicalProperty, innerIdx int) return hashJoin } -// joinKeysMatchIndex checks whether the join key is in the index. -// It returns a slice a[] what a[i] means keys[i] is related with indexCols[a[i]], -1 for no matching column. -// It will return nil if there's no column that matches index. -func joinKeysMatchIndex(keys, indexCols []*expression.Column, colLengths []int) []int { - keyOff2IdxOff := make([]int, len(keys)) - for i := range keyOff2IdxOff { - keyOff2IdxOff[i] = -1 - } - // There should be at least one column in join keys which can match the index's column. - matched := false - tmpSchema := expression.NewSchema(keys...) - for i, idxCol := range indexCols { - if colLengths[i] != types.UnspecifiedLength { - continue - } - keyOff := tmpSchema.ColumnIndex(idxCol) - if keyOff == -1 { - continue - } - matched = true - keyOff2IdxOff[keyOff] = i - } - if !matched { - return nil - } - return keyOff2IdxOff -} - // When inner plan is TableReader, the parameter `ranges` will be nil. Because pk only have one column. So all of its range // is generated during execution time. -func (p *LogicalJoin) constructIndexJoin(prop *property.PhysicalProperty, outerIdx int, innerPlan PhysicalPlan, - ranges []*ranger.Range, keyOff2IdxOff []int, compareFilters *ColWithCmpFuncManager) []PhysicalPlan { +func (p *LogicalJoin) constructIndexJoin( + prop *property.PhysicalProperty, + outerIdx int, + innerPlan PhysicalPlan, + ranges []*ranger.Range, + keyOff2IdxOff []int, + lens []int, + compareFilters *ColWithCmpFuncManager, +) []PhysicalPlan { joinType := p.JoinType outerSchema := p.children[outerIdx].Schema() var ( @@ -373,6 +352,7 @@ func (p *LogicalJoin) constructIndexJoin(prop *property.PhysicalProperty, outerI DefaultValues: p.DefaultValues, innerPlan: innerPlan, KeyOff2IdxOff: newKeyOff, + IdxColLens: lens, Ranges: ranges, CompareFilters: compareFilters, }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), chReqProps...) @@ -431,7 +411,7 @@ func (p *LogicalJoin) getIndexJoinByOuterIdx(prop *property.PhysicalProperty, ou innerPlan := p.constructInnerTableScan(ds, pkCol, outerJoinKeys, us) // Since the primary key means one value corresponding to exact one row, this will always be a no worse one // comparing to other index. - return p.constructIndexJoin(prop, outerIdx, innerPlan, nil, keyOff2IdxOff, nil) + return p.constructIndexJoin(prop, outerIdx, innerPlan, nil, keyOff2IdxOff, nil, nil) } } helper := &indexJoinBuildHelper{join: p} @@ -440,7 +420,10 @@ func (p *LogicalJoin) getIndexJoinByOuterIdx(prop *property.PhysicalProperty, ou continue } indexInfo := path.index - err := helper.analyzeLookUpFilters(indexInfo, ds, innerJoinKeys) + emptyRange, err := helper.analyzeLookUpFilters(indexInfo, ds, innerJoinKeys) + if emptyRange { + return nil + } if err != nil { logutil.Logger(context.Background()).Warn("build index join failed", zap.Error(err)) } @@ -455,10 +438,10 @@ func (p *LogicalJoin) getIndexJoinByOuterIdx(prop *property.PhysicalProperty, ou keyOff2IdxOff[keyOff] = idxOff } } - idxCols, _ := expression.IndexInfo2Cols(ds.schema.Columns, helper.chosenIndexInfo) + idxCols, lens := expression.IndexInfo2Cols(ds.schema.Columns, helper.chosenIndexInfo) rangeInfo := helper.buildRangeDecidedByInformation(idxCols, outerJoinKeys) innerPlan := p.constructInnerIndexScan(ds, helper.chosenIndexInfo, helper.chosenRemained, outerJoinKeys, us, rangeInfo) - return p.constructIndexJoin(prop, outerIdx, innerPlan, helper.chosenRanges, keyOff2IdxOff, helper.lastColManager) + return p.constructIndexJoin(prop, outerIdx, innerPlan, helper.chosenRanges, keyOff2IdxOff, lens, helper.lastColManager) } return nil } @@ -561,6 +544,8 @@ func (p *LogicalJoin) constructInnerIndexScan(ds *DataSource, idx *model.IndexIn KeepOrder: false, Ranges: ranger.FullRange(), rangeInfo: rangeInfo, + isPartition: ds.isPartition, + physicalTableID: ds.physicalTableID, }.Init(ds.ctx) var rowCount float64 @@ -581,15 +566,35 @@ func (p *LogicalJoin) constructInnerIndexScan(ds *DataSource, idx *model.IndexIn } if !isCoveringIndex(ds.schema.Columns, is.Index.Columns, is.Table.PKIsHandle) { // On this way, it's double read case. - ts := PhysicalTableScan{Columns: ds.Columns, Table: is.Table}.Init(ds.ctx) + ts := PhysicalTableScan{ + Columns: ds.Columns, + Table: is.Table, + isPartition: ds.isPartition, + physicalTableID: ds.physicalTableID, + }.Init(ds.ctx) ts.SetSchema(is.dataSourceSchema) cop.tablePlan = ts } - is.initSchema(ds.id, idx, cop.tablePlan != nil) + is.initSchema(idx, cop.tablePlan != nil) indexConds, tblConds := splitIndexFilterConditions(filterConds, idx.Columns, ds.tableInfo) - path := &accessPath{indexFilters: indexConds, tableFilters: tblConds, countAfterIndex: math.MaxFloat64} - is.addPushedDownSelection(cop, ds, math.MaxFloat64, path) + path := &accessPath{ + indexFilters: indexConds, + tableFilters: tblConds, + countAfterAccess: rowCount, + } + // Assume equal conditions used by index join and other conditions are independent. + if len(indexConds) > 0 { + selectivity, _, err := ds.tableStats.HistColl.Selectivity(ds.ctx, indexConds) + if err != nil { + logutil.Logger(context.Background()).Debug("calculate selectivity failed, use selection factor", zap.Error(err)) + selectivity = selectionFactor + } + path.countAfterIndex = rowCount * selectivity + } + selectivity := ds.stats.RowCount / ds.tableStats.RowCount + finalStats := ds.stats.ScaleByExpectCnt(selectivity * rowCount) + is.addPushedDownSelection(cop, ds, path, finalStats) t := finishCopTask(ds.ctx, cop) reader := t.plan() return p.constructInnerUnionScan(us, reader) @@ -786,10 +791,10 @@ func (ijHelper *indexJoinBuildHelper) removeUselessEqAndInFunc( return notKeyEqAndIn, nil } -func (ijHelper *indexJoinBuildHelper) analyzeLookUpFilters(indexInfo *model.IndexInfo, innerPlan *DataSource, innerJoinKeys []*expression.Column) error { +func (ijHelper *indexJoinBuildHelper) analyzeLookUpFilters(indexInfo *model.IndexInfo, innerPlan *DataSource, innerJoinKeys []*expression.Column) (emptyRange bool, err error) { idxCols, colLengths := expression.IndexInfo2Cols(innerPlan.schema.Columns, indexInfo) if len(idxCols) == 0 { - return nil + return false, nil } accesses := make([]expression.Expression, 0, len(idxCols)) ijHelper.resetContextForIndex(innerJoinKeys, idxCols, colLengths) @@ -799,7 +804,7 @@ func (ijHelper *indexJoinBuildHelper) analyzeLookUpFilters(indexInfo *model.Inde matchedKeyCnt := len(ijHelper.curPossibleUsedKeys) // If no join key is matched while join keys actually are not empty. We don't choose index join for now. if matchedKeyCnt <= 0 && len(innerJoinKeys) > 0 { - return nil + return false, nil } accesses = append(accesses, notKeyEqAndIn...) remained = append(remained, remainedEqAndIn...) @@ -807,7 +812,7 @@ func (ijHelper *indexJoinBuildHelper) analyzeLookUpFilters(indexInfo *model.Inde // There should be some equal conditions. But we don't need that there must be some join key in accesses here. // A more strict check is applied later. if lastColPos <= 0 { - return nil + return false, nil } // If all the index columns are covered by eq/in conditions, we don't need to consider other conditions anymore. if lastColPos == len(idxCols) { @@ -815,15 +820,18 @@ func (ijHelper *indexJoinBuildHelper) analyzeLookUpFilters(indexInfo *model.Inde // e.g. select * from t1, t2 where t2.a=1 and t2.b=1. And t2 has index(a, b). // If we don't have the following check, TiDB will build index join for this case. if matchedKeyCnt <= 0 { - return nil + return false, nil } remained = append(remained, rangeFilterCandidates...) - ranges, err := ijHelper.buildTemplateRange(matchedKeyCnt, notKeyEqAndIn, nil, false) + ranges, emptyRange, err := ijHelper.buildTemplateRange(matchedKeyCnt, notKeyEqAndIn, nil, false) if err != nil { - return err + return false, err + } + if emptyRange { + return true, nil } ijHelper.updateBestChoice(ranges, indexInfo, accesses, remained, nil) - return nil + return false, nil } lastPossibleCol := idxCols[lastColPos] lastColManager := &ColWithCmpFuncManager{ @@ -838,7 +846,7 @@ func (ijHelper *indexJoinBuildHelper) analyzeLookUpFilters(indexInfo *model.Inde // e.g. select * from t1, t2 where t2.a=1 and t2.b=1 and t2.c > 10 and t2.c < 20. And t2 has index(a, b, c). // If we don't have the following check, TiDB will build index join for this case. if matchedKeyCnt <= 0 { - return nil + return false, nil } colAccesses, colRemained := ranger.DetachCondsForColumn(ijHelper.join.ctx, rangeFilterCandidates, lastPossibleCol) var ranges, nextColRange []*ranger.Range @@ -846,12 +854,15 @@ func (ijHelper *indexJoinBuildHelper) analyzeLookUpFilters(indexInfo *model.Inde if len(colAccesses) > 0 { nextColRange, err = ranger.BuildColumnRange(colAccesses, ijHelper.join.ctx.GetSessionVars().StmtCtx, lastPossibleCol.RetType, colLengths[lastColPos]) if err != nil { - return err + return false, err } } - ranges, err = ijHelper.buildTemplateRange(matchedKeyCnt, notKeyEqAndIn, nextColRange, false) + ranges, emptyRange, err = ijHelper.buildTemplateRange(matchedKeyCnt, notKeyEqAndIn, nextColRange, false) if err != nil { - return err + return false, err + } + if emptyRange { + return true, nil } remained = append(remained, colRemained...) if colLengths[lastColPos] != types.UnspecifiedLength { @@ -859,16 +870,19 @@ func (ijHelper *indexJoinBuildHelper) analyzeLookUpFilters(indexInfo *model.Inde } accesses = append(accesses, colAccesses...) ijHelper.updateBestChoice(ranges, indexInfo, accesses, remained, nil) - return nil + return false, nil } accesses = append(accesses, lastColAccess...) remained = append(remained, rangeFilterCandidates...) - ranges, err := ijHelper.buildTemplateRange(matchedKeyCnt, notKeyEqAndIn, nil, true) + ranges, emptyRange, err := ijHelper.buildTemplateRange(matchedKeyCnt, notKeyEqAndIn, nil, true) if err != nil { - return err + return false, err + } + if emptyRange { + return true, nil } ijHelper.updateBestChoice(ranges, indexInfo, accesses, remained, lastColManager) - return nil + return false, nil } func (ijHelper *indexJoinBuildHelper) updateBestChoice(ranges []*ranger.Range, idxInfo *model.IndexInfo, accesses, @@ -887,7 +901,7 @@ func (ijHelper *indexJoinBuildHelper) updateBestChoice(ranges []*ranger.Range, i } } -func (ijHelper *indexJoinBuildHelper) buildTemplateRange(matchedKeyCnt int, eqAndInFuncs []expression.Expression, nextColRange []*ranger.Range, haveExtraCol bool) (ranges []*ranger.Range, err error) { +func (ijHelper *indexJoinBuildHelper) buildTemplateRange(matchedKeyCnt int, eqAndInFuncs []expression.Expression, nextColRange []*ranger.Range, haveExtraCol bool) (ranges []*ranger.Range, emptyRange bool, err error) { pointLength := matchedKeyCnt + len(eqAndInFuncs) if nextColRange != nil { for _, colRan := range nextColRange { @@ -914,49 +928,37 @@ func (ijHelper *indexJoinBuildHelper) buildTemplateRange(matchedKeyCnt int, eqAn HighVal: make([]types.Datum, pointLength, pointLength), }) } - emptyRow := chunk.Row{} + sc := ijHelper.join.ctx.GetSessionVars().StmtCtx for i, j := 0, 0; j < len(eqAndInFuncs); i++ { // This position is occupied by join key. if ijHelper.curIdxOff2KeyOff[i] != -1 { continue } - sf := eqAndInFuncs[j].(*expression.ScalarFunction) - // Deal with the first two args. - if _, ok := sf.GetArgs()[0].(*expression.Column); ok { - for _, ran := range ranges { - ran.LowVal[i], err = sf.GetArgs()[1].Eval(emptyRow) - if err != nil { - return nil, err - } - ran.HighVal[i] = ran.LowVal[i] - } - } else { - for _, ran := range ranges { - ran.LowVal[i], err = sf.GetArgs()[0].Eval(emptyRow) - if err != nil { - return nil, err - } - ran.HighVal[i] = ran.LowVal[i] - } + oneColumnRan, err := ranger.BuildColumnRange([]expression.Expression{eqAndInFuncs[j]}, sc, ijHelper.curNotUsedIndexCols[j].RetType, ijHelper.curNotUsedColLens[j]) + if err != nil { + return nil, false, err + } + if len(oneColumnRan) == 0 { + return nil, true, nil + } + for _, ran := range ranges { + ran.LowVal[i] = oneColumnRan[0].LowVal[0] + ran.HighVal[i] = oneColumnRan[0].HighVal[0] } - // If the length of in function's constant list is more than one, we will expand ranges. curRangeLen := len(ranges) - for argIdx := 2; argIdx < len(sf.GetArgs()); argIdx++ { + for ranIdx := 1; ranIdx < len(oneColumnRan); ranIdx++ { newRanges := make([]*ranger.Range, 0, curRangeLen) for oldRangeIdx := 0; oldRangeIdx < curRangeLen; oldRangeIdx++ { newRange := ranges[oldRangeIdx].Clone() - newRange.LowVal[i], err = sf.GetArgs()[argIdx].Eval(emptyRow) - if err != nil { - return nil, err - } - newRange.HighVal[i] = newRange.LowVal[i] + newRange.LowVal[i] = oneColumnRan[ranIdx].LowVal[0] + newRange.HighVal[i] = oneColumnRan[ranIdx].HighVal[0] newRanges = append(newRanges, newRange) } ranges = append(ranges, newRanges...) } j++ } - return ranges, nil + return ranges, false, nil } // tryToGetIndexJoin will get index join by hints. If we can generate a valid index join by hint, the second return value @@ -1007,15 +1009,19 @@ func (p *LogicalJoin) tryToGetIndexJoin(prop *property.PhysicalProperty) (indexJ } if leftJoins != nil && lhsCardinality < rhsCardinality { - return leftJoins, hasIndexJoinHint + return leftJoins, leftOuter } if rightJoins != nil && rhsCardinality < lhsCardinality { - return rightJoins, hasIndexJoinHint + return rightJoins, rightOuter } + canForceLeft := leftJoins != nil && leftOuter + canForceRight := rightJoins != nil && rightOuter + forced = canForceLeft || canForceRight + joins := append(leftJoins, rightJoins...) - return joins, hasIndexJoinHint && len(joins) != 0 + return joins, forced } return nil, false diff --git a/planner/core/exhaust_physical_plans_test.go b/planner/core/exhaust_physical_plans_test.go index a7cf2a21a9a1f..644c9607ec795 100644 --- a/planner/core/exhaust_physical_plans_test.go +++ b/planner/core/exhaust_physical_plans_test.go @@ -227,7 +227,7 @@ func (s *testUnitTestSuit) TestIndexJoinAnalyzeLookUpFilters(c *C) { c.Assert(err, IsNil) joinNode.OtherConditions = others helper := &indexJoinBuildHelper{join: joinNode, lastColManager: nil} - err = helper.analyzeLookUpFilters(idxInfo, dataSourceNode, tt.innerKeys) + _, err = helper.analyzeLookUpFilters(idxInfo, dataSourceNode, tt.innerKeys) c.Assert(err, IsNil) c.Assert(fmt.Sprintf("%v", helper.chosenRanges), Equals, tt.ranges, Commentf("test case: #%v", i)) c.Assert(fmt.Sprintf("%v", helper.idxOff2KeyOff), Equals, tt.idxOff2KeyOff) diff --git a/planner/core/explain.go b/planner/core/explain.go index 3b3a027fbe606..1f865d663ae11 100644 --- a/planner/core/explain.go +++ b/planner/core/explain.go @@ -141,6 +141,9 @@ func (p *PhysicalIndexReader) ExplainInfo() string { // ExplainInfo implements PhysicalPlan interface. func (p *PhysicalIndexLookUpReader) ExplainInfo() string { // The children can be inferred by the relation symbol. + if p.PushedLimit != nil { + return fmt.Sprintf("limit embedded(offset:%v, count:%v)", p.PushedLimit.Offset, p.PushedLimit.Count) + } return "" } @@ -233,7 +236,13 @@ func (p *PhysicalIndexJoin) ExplainInfo() string { // ExplainInfo implements PhysicalPlan interface. func (p *PhysicalHashJoin) ExplainInfo() string { - buffer := bytes.NewBufferString(p.JoinType.String()) + buffer := new(bytes.Buffer) + + if len(p.EqualConditions) == 0 { + buffer.WriteString("CARTESIAN ") + } + + buffer.WriteString(p.JoinType.String()) fmt.Fprintf(buffer, ", inner:%s", p.Children()[p.InnerChildIdx].ExplainID()) if len(p.EqualConditions) > 0 { fmt.Fprintf(buffer, ", equal:%v", p.EqualConditions) diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 91ea45fca7a26..2f7cbea61fa24 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -14,6 +14,7 @@ package core import ( + "context" "strconv" "strings" @@ -29,39 +30,39 @@ import ( "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/types/parser_driver" + driver "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/stringutil" ) // EvalSubquery evaluates incorrelated subqueries once. -var EvalSubquery func(p PhysicalPlan, is infoschema.InfoSchema, ctx sessionctx.Context) ([][]types.Datum, error) +var EvalSubquery func(ctx context.Context, p PhysicalPlan, is infoschema.InfoSchema, sctx sessionctx.Context) ([][]types.Datum, error) // evalAstExpr evaluates ast expression directly. -func evalAstExpr(ctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error) { +func evalAstExpr(sctx sessionctx.Context, expr ast.ExprNode) (types.Datum, error) { if val, ok := expr.(*driver.ValueExpr); ok { return val.Datum, nil } b := &PlanBuilder{ - ctx: ctx, + ctx: sctx, colMapper: make(map[*ast.ColumnNameExpr]int), } - if ctx.GetSessionVars().TxnCtx.InfoSchema != nil { - b.is = ctx.GetSessionVars().TxnCtx.InfoSchema.(infoschema.InfoSchema) + if sctx.GetSessionVars().TxnCtx.InfoSchema != nil { + b.is = sctx.GetSessionVars().TxnCtx.InfoSchema.(infoschema.InfoSchema) } - fakePlan := LogicalTableDual{}.Init(ctx) - newExpr, _, err := b.rewrite(expr, fakePlan, nil, true) + fakePlan := LogicalTableDual{}.Init(sctx) + newExpr, _, err := b.rewrite(context.TODO(), expr, fakePlan, nil, true) if err != nil { return types.Datum{}, err } return newExpr.Eval(chunk.Row{}) } -func (b *PlanBuilder) rewriteInsertOnDuplicateUpdate(exprNode ast.ExprNode, mockPlan LogicalPlan, insertPlan *Insert) (expression.Expression, error) { +func (b *PlanBuilder) rewriteInsertOnDuplicateUpdate(ctx context.Context, exprNode ast.ExprNode, mockPlan LogicalPlan, insertPlan *Insert) (expression.Expression, error) { b.rewriterCounter++ defer func() { b.rewriterCounter-- }() - rewriter := b.getExpressionRewriter(mockPlan) + rewriter := b.getExpressionRewriter(ctx, mockPlan) // The rewriter maybe is obtained from "b.rewriterPool", "rewriter.err" is // not nil means certain previous procedure has not handled this error. // Here we give us one more chance to make a correct behavior by handling @@ -81,19 +82,19 @@ func (b *PlanBuilder) rewriteInsertOnDuplicateUpdate(exprNode ast.ExprNode, mock // aggMapper maps ast.AggregateFuncExpr to the columns offset in p's output schema. // asScalar means whether this expression must be treated as a scalar expression. // And this function returns a result expression, a new plan that may have apply or semi-join. -func (b *PlanBuilder) rewrite(exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool) (expression.Expression, LogicalPlan, error) { - expr, resultPlan, err := b.rewriteWithPreprocess(exprNode, p, aggMapper, nil, asScalar, nil) +func (b *PlanBuilder) rewrite(ctx context.Context, exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, asScalar bool) (expression.Expression, LogicalPlan, error) { + expr, resultPlan, err := b.rewriteWithPreprocess(ctx, exprNode, p, aggMapper, nil, asScalar, nil) return expr, resultPlan, err } // rewriteWithPreprocess is for handling the situation that we need to adjust the input ast tree // before really using its node in `expressionRewriter.Leave`. In that case, we first call // er.preprocess(expr), which returns a new expr. Then we use the new expr in `Leave`. -func (b *PlanBuilder) rewriteWithPreprocess(exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, asScalar bool, preprocess func(ast.Node) ast.Node) (expression.Expression, LogicalPlan, error) { +func (b *PlanBuilder) rewriteWithPreprocess(ctx context.Context, exprNode ast.ExprNode, p LogicalPlan, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, asScalar bool, preprocess func(ast.Node) ast.Node) (expression.Expression, LogicalPlan, error) { b.rewriterCounter++ defer func() { b.rewriterCounter-- }() - rewriter := b.getExpressionRewriter(p) + rewriter := b.getExpressionRewriter(ctx, p) // The rewriter maybe is obtained from "b.rewriterPool", "rewriter.err" is // not nil means certain previous procedure has not handled this error. // Here we give us one more chance to make a correct behavior by handling @@ -111,7 +112,7 @@ func (b *PlanBuilder) rewriteWithPreprocess(exprNode ast.ExprNode, p LogicalPlan return expr, resultPlan, err } -func (b *PlanBuilder) getExpressionRewriter(p LogicalPlan) (rewriter *expressionRewriter) { +func (b *PlanBuilder) getExpressionRewriter(ctx context.Context, p LogicalPlan) (rewriter *expressionRewriter) { defer func() { if p != nil { rewriter.schema = p.Schema() @@ -119,7 +120,7 @@ func (b *PlanBuilder) getExpressionRewriter(p LogicalPlan) (rewriter *expression }() if len(b.rewriterPool) < b.rewriterCounter { - rewriter = &expressionRewriter{p: p, b: b, ctx: b.ctx} + rewriter = &expressionRewriter{p: p, b: b, sctx: b.ctx, ctx: ctx} b.rewriterPool = append(b.rewriterPool, rewriter) return } @@ -132,6 +133,7 @@ func (b *PlanBuilder) getExpressionRewriter(p LogicalPlan) (rewriter *expression rewriter.insertPlan = nil rewriter.disableFoldCounter = 0 rewriter.ctxStack = rewriter.ctxStack[:0] + rewriter.ctx = ctx return } @@ -161,7 +163,8 @@ type expressionRewriter struct { aggrMap map[*ast.AggregateFuncExpr]int windowMap map[*ast.WindowFuncExpr]int b *PlanBuilder - ctx sessionctx.Context + sctx sessionctx.Context + ctx context.Context // asScalar indicates the return value must be a scalar value. // NOTE: This value can be changed during expression rewritten. @@ -206,21 +209,21 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, } } if op == ast.NE { - return expression.ComposeDNFCondition(er.ctx, funcs...), nil + return expression.ComposeDNFCondition(er.sctx, funcs...), nil } - return expression.ComposeCNFCondition(er.ctx, funcs...), nil + return expression.ComposeCNFCondition(er.sctx, funcs...), nil default: larg0, rarg0 := expression.GetFuncArg(l, 0), expression.GetFuncArg(r, 0) var expr1, expr2, expr3, expr4, expr5 expression.Expression - expr1 = expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) - expr2 = expression.NewFunctionInternal(er.ctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) - expr3 = expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr1) + expr1 = expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) + expr2 = expression.NewFunctionInternal(er.sctx, op, types.NewFieldType(mysql.TypeTiny), larg0, rarg0) + expr3 = expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr1) var err error - l, err = expression.PopRowFirstArg(er.ctx, l) + l, err = expression.PopRowFirstArg(er.sctx, l) if err != nil { return nil, err } - r, err = expression.PopRowFirstArg(er.ctx, r) + r, err = expression.PopRowFirstArg(er.sctx, r) if err != nil { return nil, err } @@ -236,14 +239,14 @@ func (er *expressionRewriter) constructBinaryOpFunction(l expression.Expression, } } -func (er *expressionRewriter) buildSubquery(subq *ast.SubqueryExpr) (LogicalPlan, error) { +func (er *expressionRewriter) buildSubquery(ctx context.Context, subq *ast.SubqueryExpr) (LogicalPlan, error) { if er.schema != nil { outerSchema := er.schema.Clone() er.b.outerSchemas = append(er.b.outerSchemas, outerSchema) defer func() { er.b.outerSchemas = er.b.outerSchemas[0 : len(er.b.outerSchemas)-1] }() } - np, err := er.b.buildResultSetNode(subq.Query) + np, err := er.b.buildResultSetNode(ctx, subq.Query) if err != nil { return nil, err } @@ -270,12 +273,12 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { return inNode, true } case *ast.CompareSubqueryExpr: - return er.handleCompareSubquery(v) + return er.handleCompareSubquery(er.ctx, v) case *ast.ExistsSubqueryExpr: - return er.handleExistSubquery(v) + return er.handleExistSubquery(er.ctx, v) case *ast.PatternInExpr: if v.Sel != nil { - return er.handleInSubquery(v) + return er.handleInSubquery(er.ctx, v) } if len(v.List) != 1 { break @@ -287,7 +290,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { switch y := x.(type) { case *ast.SubqueryExpr: v.Sel = y - return er.handleInSubquery(v) + return er.handleInSubquery(er.ctx, v) case *ast.ParenthesesExpr: x = y.Expr default: @@ -295,7 +298,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { } } case *ast.SubqueryExpr: - return er.handleScalarSubquery(v) + return er.handleScalarSubquery(er.ctx, v) case *ast.ParenthesesExpr: case *ast.ValuesExpr: schema := er.schema @@ -314,7 +317,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { er.err = ErrUnknownColumn.GenWithStackByArgs(v.Column.Name.OrigColName(), "field list") return inNode, false } - er.ctxStack = append(er.ctxStack, expression.NewValuesFunc(er.ctx, col.Index, col.RetType)) + er.ctxStack = append(er.ctxStack, expression.NewValuesFunc(er.sctx, col.Index, col.RetType)) return inNode, true case *ast.WindowFuncExpr: index, ok := -1, false @@ -322,7 +325,7 @@ func (er *expressionRewriter) Enter(inNode ast.Node) (ast.Node, bool) { index, ok = er.windowMap[v] } if !ok { - er.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(v.F) + er.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.F)) return inNode, true } er.ctxStack = append(er.ctxStack, er.schema.Columns[index]) @@ -354,7 +357,7 @@ func (er *expressionRewriter) buildSemiApplyFromEqualSubq(np LogicalPlan, l, r e er.p, er.err = er.b.buildSemiApply(er.p, np, []expression.Expression{condition}, er.asScalar, not) } -func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) (ast.Node, bool) { +func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, v *ast.CompareSubqueryExpr) (ast.Node, bool) { v.L.Accept(er) if er.err != nil { return v, true @@ -365,7 +368,7 @@ func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) er.err = errors.Errorf("Unknown compare type %T.", v.R) return v, true } - np, err := er.buildSubquery(subq) + np, err := er.buildSubquery(ctx, subq) if err != nil { er.err = err return v, true @@ -437,7 +440,7 @@ func (er *expressionRewriter) handleCompareSubquery(v *ast.CompareSubqueryExpr) // handleOtherComparableSubq handles the queries like < any, < max, etc. For example, if the query is t.id < any (select s.id from s), // it will be rewrote to t.id < (select max(s.id) from s). func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.Expression, np LogicalPlan, useMin bool, cmpFunc string, all bool) { - plan4Agg := LogicalAggregation{}.Init(er.ctx) + plan4Agg := LogicalAggregation{}.Init(er.sctx) plan4Agg.SetChildren(np) // Create a "max" or "min" aggregation. @@ -445,12 +448,16 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression. if useMin { funcName = ast.AggFuncMin } - funcMaxOrMin := aggregation.NewAggFuncDesc(er.ctx, funcName, []expression.Expression{rexpr}, false) + funcMaxOrMin, err := aggregation.NewAggFuncDesc(er.sctx, funcName, []expression.Expression{rexpr}, false) + if err != nil { + er.err = err + return + } // Create a column and append it to the schema of that aggregation. colMaxOrMin := &expression.Column{ ColName: model.NewCIStr("agg_Col_0"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: funcMaxOrMin.RetTp, } schema := expression.NewSchema(colMaxOrMin) @@ -458,30 +465,38 @@ func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression. plan4Agg.SetSchema(schema) plan4Agg.AggFuncs = []*aggregation.AggFuncDesc{funcMaxOrMin} - cond := expression.NewFunctionInternal(er.ctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin) + cond := expression.NewFunctionInternal(er.sctx, cmpFunc, types.NewFieldType(mysql.TypeTiny), lexpr, colMaxOrMin) er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, all) } // buildQuantifierPlan adds extra condition for any / all subquery. func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, cond, lexpr, rexpr expression.Expression, all bool) { - innerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) - outerIsNull := expression.NewFunctionInternal(er.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr) + innerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), rexpr) + outerIsNull := expression.NewFunctionInternal(er.sctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), lexpr) - funcSum := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) + funcSum, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncSum, []expression.Expression{innerIsNull}, false) + if err != nil { + er.err = err + return + } colSum := &expression.Column{ ColName: model.NewCIStr("agg_col_sum"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: funcSum.RetTp, } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcSum) plan4Agg.schema.Append(colSum) - innerHasNull := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) + innerHasNull := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colSum, expression.Zero) // Build `count(1)` aggregation to check if subquery is empty. - funcCount := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) + funcCount, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []expression.Expression{expression.One}, false) + if err != nil { + er.err = err + return + } colCount := &expression.Column{ ColName: model.NewCIStr("agg_col_cnt"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: funcCount.RetTp, } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, funcCount) @@ -490,23 +505,23 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, if all { // All of the inner record set should not contain null value. So for t.id < all(select s.id from s), it // should be rewrote to t.id < min(s.id) and if(sum(s.id is null) != 0, null, true). - innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.One) - cond = expression.ComposeCNFCondition(er.ctx, cond, innerNullChecker) + innerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.sctx, cond, innerNullChecker) // If the subquery is empty, it should always return true. - emptyChecker := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + emptyChecker := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) // If outer key is null, and subquery is not empty, it should always return null, even when it is `null = all (1, 2)`. - outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) + outerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.sctx, cond, emptyChecker, outerNullChecker) } else { // For "any" expression, if the subquery has null and the cond returns false, the result should be NULL. // Specifically, `t.id < any (select s.id from s)` would be rewrote to `t.id < max(s.id) or if(sum(s.id is null) != 0, null, false)` - innerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.Zero) - cond = expression.ComposeDNFCondition(er.ctx, cond, innerNullChecker) + innerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), innerHasNull, expression.Null, expression.Zero) + cond = expression.ComposeDNFCondition(er.sctx, cond, innerNullChecker) // If the subquery is empty, it should always return false. - emptyChecker := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) + emptyChecker := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), colCount, expression.Zero) // If outer key is null, and subquery is not empty, it should return null. - outerNullChecker := expression.NewFunctionInternal(er.ctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.One) - cond = expression.ComposeCNFCondition(er.ctx, cond, emptyChecker, outerNullChecker) + outerNullChecker := expression.NewFunctionInternal(er.sctx, ast.If, types.NewFieldType(mysql.TypeTiny), outerIsNull, expression.Null, expression.One) + cond = expression.ComposeCNFCondition(er.sctx, cond, emptyChecker, outerNullChecker) } // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. @@ -522,12 +537,12 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, joinSchema := er.p.Schema() proj := LogicalProjection{ Exprs: expression.Column2Exprs(joinSchema.Clone().Columns[:outerSchemaLen]), - }.Init(er.ctx) + }.Init(er.sctx) proj.SetSchema(expression.NewSchema(joinSchema.Clone().Columns[:outerSchemaLen]...)) proj.Exprs = append(proj.Exprs, cond) proj.schema.Append(&expression.Column{ ColName: model.NewCIStr("aux_col"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), IsReferenced: true, RetType: cond.GetType(), }) @@ -539,62 +554,78 @@ func (er *expressionRewriter) buildQuantifierPlan(plan4Agg *LogicalAggregation, // t.id != s.id or count(distinct s.id) > 1 or [any checker]. If there are two different values in s.id , // there must exist a s.id that doesn't equal to t.id. func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np LogicalPlan) { - firstRowFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) - countFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + firstRowFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) + if err != nil { + er.err = err + return + } + countFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + if err != nil { + er.err = err + return + } plan4Agg := LogicalAggregation{ AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc}, - }.Init(er.ctx) + }.Init(er.sctx) plan4Agg.SetChildren(np) firstRowResultCol := &expression.Column{ ColName: model.NewCIStr("col_firstRow"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: firstRowFunc.RetTp, } count := &expression.Column{ ColName: model.NewCIStr("col_count"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: countFunc.RetTp, } plan4Agg.SetSchema(expression.NewSchema(firstRowResultCol, count)) - gtFunc := expression.NewFunctionInternal(er.ctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count, expression.One) - neCond := expression.NewFunctionInternal(er.ctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) - cond := expression.ComposeDNFCondition(er.ctx, gtFunc, neCond) + gtFunc := expression.NewFunctionInternal(er.sctx, ast.GT, types.NewFieldType(mysql.TypeTiny), count, expression.One) + neCond := expression.NewFunctionInternal(er.sctx, ast.NE, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) + cond := expression.ComposeDNFCondition(er.sctx, gtFunc, neCond) er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, false) } // handleEQAll handles the case of = all. For example, if the query is t.id = all (select s.id from s), it will be rewrote to // t.id = (select s.id from s having count(distinct s.id) <= 1 and [all checker]). func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np LogicalPlan) { - firstRowFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) - countFunc := aggregation.NewAggFuncDesc(er.ctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + firstRowFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncFirstRow, []expression.Expression{rexpr}, false) + if err != nil { + er.err = err + return + } + countFunc, err := aggregation.NewAggFuncDesc(er.sctx, ast.AggFuncCount, []expression.Expression{rexpr}, true) + if err != nil { + er.err = err + return + } plan4Agg := LogicalAggregation{ AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc}, - }.Init(er.ctx) + }.Init(er.sctx) plan4Agg.SetChildren(np) firstRowResultCol := &expression.Column{ ColName: model.NewCIStr("col_firstRow"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: firstRowFunc.RetTp, } count := &expression.Column{ ColName: model.NewCIStr("col_count"), - UniqueID: er.ctx.GetSessionVars().AllocPlanColumnID(), + UniqueID: er.sctx.GetSessionVars().AllocPlanColumnID(), RetType: countFunc.RetTp, } plan4Agg.SetSchema(expression.NewSchema(firstRowResultCol, count)) - leFunc := expression.NewFunctionInternal(er.ctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.One) - eqCond := expression.NewFunctionInternal(er.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) - cond := expression.ComposeCNFCondition(er.ctx, leFunc, eqCond) + leFunc := expression.NewFunctionInternal(er.sctx, ast.LE, types.NewFieldType(mysql.TypeTiny), count, expression.One) + eqCond := expression.NewFunctionInternal(er.sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lexpr, firstRowResultCol) + cond := expression.ComposeCNFCondition(er.sctx, leFunc, eqCond) er.buildQuantifierPlan(plan4Agg, cond, lexpr, rexpr, true) } -func (er *expressionRewriter) handleExistSubquery(v *ast.ExistsSubqueryExpr) (ast.Node, bool) { +func (er *expressionRewriter) handleExistSubquery(ctx context.Context, v *ast.ExistsSubqueryExpr) (ast.Node, bool) { subq, ok := v.Sel.(*ast.SubqueryExpr) if !ok { er.err = errors.Errorf("Unknown exists type %T.", v.Sel) return v, true } - np, err := er.buildSubquery(subq) + np, err := er.buildSubquery(ctx, subq) if err != nil { er.err = err return v, true @@ -607,12 +638,12 @@ func (er *expressionRewriter) handleExistSubquery(v *ast.ExistsSubqueryExpr) (as } er.ctxStack = append(er.ctxStack, er.p.Schema().Columns[er.p.Schema().Len()-1]) } else { - physicalPlan, err := DoOptimize(er.b.optFlag, np) + physicalPlan, err := DoOptimize(ctx, er.b.optFlag, np) if err != nil { er.err = err return v, true } - rows, err := EvalSubquery(physicalPlan, er.b.is, er.b.ctx) + rows, err := EvalSubquery(ctx, physicalPlan, er.b.is, er.b.ctx) if err != nil { er.err = err return v, true @@ -638,7 +669,7 @@ out: p = p.Children()[0] case *LogicalAggregation: if len(plan.GroupByItems) == 0 { - p = LogicalTableDual{RowCount: 1}.Init(er.ctx) + p = LogicalTableDual{RowCount: 1}.Init(er.sctx) break out } p = p.Children()[0] @@ -649,7 +680,7 @@ out: return p } -func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, bool) { +func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.PatternInExpr) (ast.Node, bool) { asScalar := er.asScalar er.asScalar = true v.Expr.Accept(er) @@ -662,7 +693,7 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, er.err = errors.Errorf("Unknown compare type %T.", v.Sel) return v, true } - np, err := er.buildSubquery(subq) + np, err := er.buildSubquery(ctx, subq) if err != nil { er.err = err return v, true @@ -706,18 +737,22 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, // and has no correlated column from the current level plan(if the correlated column is from upper level, // we can treat it as constant, because the upper LogicalApply cannot be eliminated since current node is a join node), // and don't need to append a scalar value, we can rewrite it to inner join. - if er.ctx.GetSessionVars().AllowInSubqToJoinAndAgg && !v.Not && !asScalar && len(extractCorColumnsBySchema(np, er.p.Schema())) == 0 { + if er.sctx.GetSessionVars().AllowInSubqToJoinAndAgg && !v.Not && !asScalar && len(extractCorColumnsBySchema(np, er.p.Schema())) == 0 { // We need to try to eliminate the agg and the projection produced by this operation. er.b.optFlag |= flagEliminateAgg er.b.optFlag |= flagEliminateProjection er.b.optFlag |= flagJoinReOrder // Build distinct for the inner query. - agg := er.b.buildDistinct(np, np.Schema().Len()) + agg, err := er.b.buildDistinct(np, np.Schema().Len()) + if err != nil { + er.err = err + return v, true + } for _, col := range agg.schema.Columns { col.IsReferenced = true } // Build inner join above the aggregation. - join := LogicalJoin{JoinType: InnerJoin}.Init(er.ctx) + join := LogicalJoin{JoinType: InnerJoin}.Init(er.sctx) join.SetChildren(er.p, agg) join.SetSchema(expression.MergeSchema(er.p.Schema(), agg.schema)) join.attachOnConds(expression.SplitCNFItems(checkCondition)) @@ -745,8 +780,8 @@ func (er *expressionRewriter) handleInSubquery(v *ast.PatternInExpr) (ast.Node, return v, true } -func (er *expressionRewriter) handleScalarSubquery(v *ast.SubqueryExpr) (ast.Node, bool) { - np, err := er.buildSubquery(v) +func (er *expressionRewriter) handleScalarSubquery(ctx context.Context, v *ast.SubqueryExpr) (ast.Node, bool) { + np, err := er.buildSubquery(ctx, v) if err != nil { er.err = err return v, true @@ -770,12 +805,12 @@ func (er *expressionRewriter) handleScalarSubquery(v *ast.SubqueryExpr) (ast.Nod } return v, true } - physicalPlan, err := DoOptimize(er.b.optFlag, np) + physicalPlan, err := DoOptimize(ctx, er.b.optFlag, np) if err != nil { er.err = err return v, true } - rows, err := EvalSubquery(physicalPlan, er.b.is, er.b.ctx) + rows, err := EvalSubquery(ctx, physicalPlan, er.b.is, er.b.ctx) if err != nil { er.err = err return v, true @@ -819,7 +854,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok er.ctxStack = append(er.ctxStack, value) case *driver.ParamMarkerExpr: var value expression.Expression - value, er.err = expression.GetParamExpression(er.ctx, v) + value, er.err = expression.GetParamExpression(er.sctx, v) if er.err != nil { return retNode, false } @@ -854,7 +889,7 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok return retNode, false } - er.ctxStack[len(er.ctxStack)-1] = expression.BuildCastFunction(er.ctx, arg, v.Tp) + er.ctxStack[len(er.ctxStack)-1] = expression.BuildCastFunction(er.sctx, arg, v.Tp) case *ast.PatternLikeExpr: er.patternLikeToExpression(v) case *ast.PatternRegexpExpr: @@ -887,9 +922,9 @@ func (er *expressionRewriter) Leave(originInNode ast.Node) (retNode ast.Node, ok // newFunction chooses which expression.NewFunctionImpl() will be used. func (er *expressionRewriter) newFunction(funcName string, retType *types.FieldType, args ...expression.Expression) (expression.Expression, error) { if er.disableFoldCounter > 0 { - return expression.NewFunctionBase(er.ctx, funcName, retType, args...) + return expression.NewFunctionBase(er.sctx, funcName, retType, args...) } - return expression.NewFunction(er.ctx, funcName, retType, args...) + return expression.NewFunction(er.sctx, funcName, retType, args...) } func (er *expressionRewriter) checkTimePrecision(ft *types.FieldType) error { @@ -900,7 +935,7 @@ func (er *expressionRewriter) checkTimePrecision(ft *types.FieldType) error { } func (er *expressionRewriter) useCache() bool { - return er.ctx.GetSessionVars().StmtCtx.UseCache + return er.sctx.GetSessionVars().StmtCtx.UseCache } func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { @@ -951,8 +986,8 @@ func (er *expressionRewriter) rewriteVariable(v *ast.VariableExpr) { return } e := expression.DatumToConstant(types.NewStringDatum(val), mysql.TypeVarString) - e.RetType.Charset, _ = er.ctx.GetSessionVars().GetSystemVar(variable.CharacterSetConnection) - e.RetType.Collate, _ = er.ctx.GetSessionVars().GetSystemVar(variable.CollationConnection) + e.RetType.Charset, _ = er.sctx.GetSessionVars().GetSystemVar(variable.CharacterSetConnection) + e.RetType.Collate, _ = er.sctx.GetSessionVars().GetSystemVar(variable.CollationConnection) er.ctxStack = append(er.ctxStack, e) } @@ -1039,7 +1074,7 @@ func (er *expressionRewriter) positionToScalarFunc(v *ast.PositionExpr) { if v.P != nil { stkLen := len(er.ctxStack) val := er.ctxStack[stkLen-1] - intNum, isNull, err := expression.GetIntFromConstant(er.ctx, val) + intNum, isNull, err := expression.GetIntFromConstant(er.sctx, val) str = "?" if err == nil { if isNull { @@ -1095,7 +1130,11 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field if leftEt == types.ETInt { for i := 1; i < len(args); i++ { if c, ok := args[i].(*expression.Constant); ok { - args[i], _ = expression.RefineComparedConstant(er.ctx, mysql.HasUnsignedFlag(leftFt.Flag), c, opcode.EQ) + var isExceptional bool + args[i], isExceptional = expression.RefineComparedConstant(er.sctx, *leftFt, c, opcode.EQ) + if isExceptional { + args[i] = c + } } } } @@ -1119,7 +1158,7 @@ func (er *expressionRewriter) inToExpression(lLen int, not bool, tp *types.Field } eqFunctions = append(eqFunctions, expr) } - function = expression.ComposeDNFCondition(er.ctx, eqFunctions...) + function = expression.ComposeDNFCondition(er.sctx, eqFunctions...) if not { var err error function, err = er.newFunction(ast.UnaryNot, tp, function) @@ -1203,7 +1242,7 @@ func (er *expressionRewriter) patternLikeToExpression(v *ast.PatternLikeExpr) { } if !isNull { patValue, patTypes := stringutil.CompilePattern(patString, v.Escape) - if stringutil.IsExactMatch(patTypes) { + if stringutil.IsExactMatch(patTypes) && er.ctxStack[l-2].GetType().EvalType() == types.ETString { op := ast.EQ if v.Not { op = ast.NE @@ -1263,9 +1302,9 @@ func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) { expr, lexp, rexp := er.ctxStack[stkLen-3], er.ctxStack[stkLen-2], er.ctxStack[stkLen-1] if expression.GetCmpTp4MinMax([]expression.Expression{expr, lexp, rexp}) == types.ETDatetime { - expr = expression.WrapWithCastAsTime(er.ctx, expr, types.NewFieldType(mysql.TypeDatetime)) - lexp = expression.WrapWithCastAsTime(er.ctx, lexp, types.NewFieldType(mysql.TypeDatetime)) - rexp = expression.WrapWithCastAsTime(er.ctx, rexp, types.NewFieldType(mysql.TypeDatetime)) + expr = expression.WrapWithCastAsTime(er.sctx, expr, types.NewFieldType(mysql.TypeDatetime)) + lexp = expression.WrapWithCastAsTime(er.sctx, lexp, types.NewFieldType(mysql.TypeDatetime)) + rexp = expression.WrapWithCastAsTime(er.sctx, rexp, types.NewFieldType(mysql.TypeDatetime)) } var op string @@ -1369,7 +1408,7 @@ func (er *expressionRewriter) funcCallToExpression(v *ast.FuncCallExpr) { var function expression.Expression er.ctxStack = er.ctxStack[:stackLen-len(v.Args)] if _, ok := expression.DeferredFunctions[v.FnName.L]; er.useCache() && ok { - function, er.err = expression.NewFunctionBase(er.ctx, v.FnName.L, &v.Type, args...) + function, er.err = expression.NewFunctionBase(er.sctx, v.FnName.L, &v.Type, args...) c := &expression.Constant{Value: types.NewDatum(nil), RetType: function.GetType().Clone(), DeferredExpr: function} er.ctxStack = append(er.ctxStack, c) } else { @@ -1442,7 +1481,7 @@ func (er *expressionRewriter) evalDefaultExpr(v *ast.DefaultExpr) { dbName := colExpr.DBName if dbName.O == "" { // if database name is not specified, use current database name - dbName = model.NewCIStr(er.ctx.GetSessionVars().CurrentDB) + dbName = model.NewCIStr(er.sctx.GetSessionVars().CurrentDB) } if colExpr.OrigTblName.O == "" { // column is evaluated by some expressions, for example: diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go index 034634685e237..c180b6aea50cc 100644 --- a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -242,3 +242,21 @@ func (s *testExpressionRewriterSuite) TestCheckFullGroupBy(c *C) { err = tk.ExecToErr("select t1.a, (select t2.a, max(t2.b) from t t2) from t t1") c.Assert(terror.ErrorEqual(err, core.ErrMixOfGroupFuncAndFields), IsTrue, Commentf("err %v", err)) } + +func (s *testExpressionRewriterSuite) TestPatternLikeToExpression(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustQuery("select 0 like 'a string';").Check(testkit.Rows("0")) + tk.MustQuery("select 0.0 like 'a string';").Check(testkit.Rows("0")) + tk.MustQuery("select 0 like '0.00';").Check(testkit.Rows("0")) + tk.MustQuery("select cast(\"2011-5-3\" as datetime) like \"2011-05-03\";").Check(testkit.Rows("0")) + tk.MustQuery("select 1 like '1';").Check(testkit.Rows("1")) + tk.MustQuery("select 0 like '0';").Check(testkit.Rows("1")) + tk.MustQuery("select 0.00 like '0.00';").Check(testkit.Rows("1")) +} diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index b71546145175b..0030510e1e31d 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/ranger" + "github.com/pingcap/tidb/util/set" "golang.org/x/tools/container/intsets" ) @@ -68,7 +69,10 @@ func (p *LogicalTableDual) findBestTask(prop *property.PhysicalProperty) (task, if !prop.IsEmpty() { return invalidTask, nil } - dual := PhysicalTableDual{RowCount: p.RowCount}.Init(p.ctx, p.stats) + dual := PhysicalTableDual{ + RowCount: p.RowCount, + placeHolder: p.placeHolder, + }.Init(p.ctx, p.stats) dual.SetSchema(p.schema) return &rootTask{p: dual}, nil } @@ -207,6 +211,18 @@ type candidatePath struct { isMatchProp bool } +// getScanType converts the scan type to int, the higher the better. +// DoubleScan -> 0, TableScan -> 1, IndexScan -> 2. +func getScanTypeScore(p *candidatePath) int { + if !p.isSingleScan { + return 0 + } + if p.path.isTablePath { + return 1 + } + return 2 +} + // compareColumnSet will compares the two set. The last return value is used to indicate // if they are comparable, it is false when both two sets have columns that do not occur in the other. // When the second return value is true, the value of first: @@ -237,10 +253,20 @@ func compareBool(l, r bool) int { return 1 } +func compareInt(l, r int) int { + if l == r { + return 0 + } + if l < r { + return -1 + } + return 1 +} + // compareCandidates is the core of skyline pruning. It compares the two candidate paths on three dimensions: // (1): the set of columns that occurred in the access condition, // (2): whether or not it matches the physical property -// (3): does it require a double scan. +// (3): whether the candidate is a IndexScan or TableScan or DoubleScan. (IndexScan > TableScan > DoubleScan) // If `x` is not worse than `y` at all factors, // and there exists one factor that `x` is better than `y`, then `x` is better than `y`. func compareCandidates(lhs, rhs *candidatePath) int { @@ -248,7 +274,7 @@ func compareCandidates(lhs, rhs *candidatePath) int { if !comparable { return 0 } - scanResult := compareBool(lhs.isSingleScan, rhs.isSingleScan) + scanResult := compareInt(getScanTypeScore(lhs), getScanTypeScore(rhs)) matchResult := compareBool(lhs.isMatchProp, rhs.isMatchProp) sum := setsResult + scanResult + matchResult if setsResult >= 0 && scanResult >= 0 && matchResult >= 0 && sum > 0 { @@ -301,11 +327,12 @@ func (ds *DataSource) skylinePruning(prop *property.PhysicalProperty) []*candida var currentCandidate *candidatePath if path.isTablePath { currentCandidate = ds.getTableCandidate(path, prop) - } else if len(path.accessConds) > 0 || !prop.IsEmpty() || path.forced { - // We will use index to generate physical plan if: - // this path's access cond is not nil or - // we have prop to match or - // this index is forced to choose. + } else if len(path.accessConds) > 0 || !prop.IsEmpty() || path.forced || isCoveringIndex(ds.schema.Columns, path.index.Columns, ds.tableInfo.PKIsHandle) { + // We will use index to generate physical plan if any of the following conditions is satisfied: + // 1. This path's access cond is not nil. + // 2. We have a non-empty prop to match. + // 3. This index is forced to choose. + // 4. The needed columns are all covered by index columns(and handleCol). currentCandidate = ds.getIndexCandidate(path, prop) } else { continue @@ -498,7 +525,7 @@ func (ds *DataSource) convertToIndexScan(prop *property.PhysicalProperty, candid ts.SetSchema(ds.schema.Clone()) cop.tablePlan = ts } - is.initSchema(ds.id, idx, cop.tablePlan != nil) + is.initSchema(idx, cop.tablePlan != nil) // Only use expectedCnt when it's smaller than the count we calculated. // e.g. IndexScan(count1)->After Filter(count2). The `ds.stats.RowCount` is count2. count1 is the one we need to calculate // If expectedCnt and count2 are both zero and we go into the below `if` block, the count1 will be set to zero though it's shouldn't be. @@ -528,7 +555,8 @@ func (ds *DataSource) convertToIndexScan(prop *property.PhysicalProperty, candid } // prop.IsEmpty() would always return true when coming to here, // so we can just use prop.ExpectedCnt as parameter of addPushedDownSelection. - is.addPushedDownSelection(cop, ds, prop.ExpectedCnt, path) + finalStats := ds.stats.ScaleByExpectCnt(prop.ExpectedCnt) + is.addPushedDownSelection(cop, ds, path, finalStats) if prop.TaskTp == property.RootTaskType { task = finishCopTask(ds.ctx, task) } else if _, ok := task.(*rootTask); ok { @@ -538,7 +566,7 @@ func (ds *DataSource) convertToIndexScan(prop *property.PhysicalProperty, candid } // TODO: refactor this part, we should not call Clone in fact. -func (is *PhysicalIndexScan) initSchema(id int, idx *model.IndexInfo, isDoubleRead bool) { +func (is *PhysicalIndexScan) initSchema(idx *model.IndexInfo, isDoubleRead bool) { indexCols := make([]*expression.Column, 0, len(idx.Columns)) for _, col := range idx.Columns { colFound := is.dataSourceSchema.FindColumnByName(col.Name.L) @@ -569,16 +597,16 @@ func (is *PhysicalIndexScan) initSchema(id int, idx *model.IndexInfo, isDoubleRe is.SetSchema(expression.NewSchema(indexCols...)) } -func (is *PhysicalIndexScan) addPushedDownSelection(copTask *copTask, p *DataSource, expectedCnt float64, path *accessPath) { +func (is *PhysicalIndexScan) addPushedDownSelection(copTask *copTask, p *DataSource, path *accessPath, finalStats *property.StatsInfo) { // Add filter condition to table plan now. indexConds, tableConds := path.indexFilters, path.tableFilters if indexConds != nil { copTask.cst += copTask.count() * cpuFactor - count := path.countAfterAccess - if count >= 1.0 { - selectivity := path.countAfterIndex / path.countAfterAccess - count = is.stats.RowCount * selectivity + var selectivity float64 + if path.countAfterAccess > 0 { + selectivity = path.countAfterIndex / path.countAfterAccess } + count := is.stats.RowCount * selectivity stats := &property.StatsInfo{RowCount: count} indexSel := PhysicalSelection{Conditions: indexConds}.Init(is.ctx, stats) indexSel.SetChildren(is) @@ -587,7 +615,7 @@ func (is *PhysicalIndexScan) addPushedDownSelection(copTask *copTask, p *DataSou if tableConds != nil { copTask.finishIndexPlan() copTask.cst += copTask.count() * cpuFactor - tableSel := PhysicalSelection{Conditions: tableConds}.Init(is.ctx, p.stats.ScaleByExpectCnt(expectedCnt)) + tableSel := PhysicalSelection{Conditions: tableConds}.Init(is.ctx, finalStats) tableSel.SetChildren(copTask.tablePlan) copTask.tablePlan = tableSel } @@ -619,33 +647,35 @@ func splitIndexFilterConditions(conditions []expression.Expression, indexColumns } // getMostCorrColFromExprs checks if column in the condition is correlated enough with handle. If the condition -// contains multiple columns, choose the most correlated one, and compute an overall correlation factor by multiplying -// single factors. +// contains multiple columns, return nil and get the max correlation, which would be used in the heuristic estimation. func getMostCorrColFromExprs(exprs []expression.Expression, histColl *statistics.Table, threshold float64) (*expression.Column, float64) { var cols []*expression.Column cols = expression.ExtractColumnsFromExpressions(cols, exprs, nil) if len(cols) == 0 { return nil, 0 } - compCorr := 1.0 + colSet := set.NewInt64Set() var corr float64 var corrCol *expression.Column for _, col := range cols { + if colSet.Exist(col.UniqueID) { + continue + } + colSet.Insert(col.UniqueID) hist, ok := histColl.Columns[col.ID] if !ok { - return nil, 0 - } - curCorr := math.Abs(hist.Correlation) - compCorr *= curCorr - if curCorr < threshold { continue } + curCorr := math.Abs(hist.Correlation) if corrCol == nil || corr < curCorr { corrCol = col corr = curCorr } } - return corrCol, compCorr + if len(colSet) == 1 && corr >= threshold { + return corrCol, corr + } + return nil, corr } // getColumnRangeCounts estimates row count for each range respectively. @@ -804,6 +834,13 @@ func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candid if prop.ExpectedCnt < ds.stats.RowCount { count, ok, corr := ds.crossEstimateRowCount(path, prop.ExpectedCnt, candidate.isMatchProp && prop.Items[0].Desc) if ok { + // TODO: actually, before using this count as the estimated row count of table scan, we need additionally + // check if count < row_count(first_region | last_region), and use the larger one since we build one copTask + // for one region now, so even if it is `limit 1`, we have to scan at least one region in table scan. + // Currently, we can use `tikvrpc.CmdDebugGetRegionProperties` interface as `getSampRegionsRowCount()` does + // to get the row count in a region, but that result contains MVCC old version rows, so it is not that accurate. + // Considering that when this scenario happens, the execution time is close between IndexScan and TableScan, + // we do not add this check temporarily. rowCount = count } else if corr < 1 { correlationFactor := math.Pow(1-corr, float64(ds.ctx.GetSessionVars().CorrelationExpFactor)) diff --git a/planner/core/initialize.go b/planner/core/initialize.go index ff2a3dbc7b059..95ae4e3f426ba 100644 --- a/planner/core/initialize.go +++ b/planner/core/initialize.go @@ -16,108 +16,42 @@ package core import ( "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/sessionctx" -) - -const ( - // TypeSel is the type of Selection. - TypeSel = "Selection" - // TypeSet is the type of Set. - TypeSet = "Set" - // TypeProj is the type of Projection. - TypeProj = "Projection" - // TypeAgg is the type of Aggregation. - TypeAgg = "Aggregation" - // TypeStreamAgg is the type of StreamAgg. - TypeStreamAgg = "StreamAgg" - // TypeHashAgg is the type of HashAgg. - TypeHashAgg = "HashAgg" - // TypeShow is the type of show. - TypeShow = "Show" - // TypeJoin is the type of Join. - TypeJoin = "Join" - // TypeUnion is the type of Union. - TypeUnion = "Union" - // TypeTableScan is the type of TableScan. - TypeTableScan = "TableScan" - // TypeMemTableScan is the type of TableScan. - TypeMemTableScan = "MemTableScan" - // TypeUnionScan is the type of UnionScan. - TypeUnionScan = "UnionScan" - // TypeIdxScan is the type of IndexScan. - TypeIdxScan = "IndexScan" - // TypeSort is the type of Sort. - TypeSort = "Sort" - // TypeTopN is the type of TopN. - TypeTopN = "TopN" - // TypeLimit is the type of Limit. - TypeLimit = "Limit" - // TypeHashLeftJoin is the type of left hash join. - TypeHashLeftJoin = "HashLeftJoin" - // TypeHashRightJoin is the type of right hash join. - TypeHashRightJoin = "HashRightJoin" - // TypeMergeJoin is the type of merge join. - TypeMergeJoin = "MergeJoin" - // TypeIndexJoin is the type of index look up join. - TypeIndexJoin = "IndexJoin" - // TypeApply is the type of Apply. - TypeApply = "Apply" - // TypeMaxOneRow is the type of MaxOneRow. - TypeMaxOneRow = "MaxOneRow" - // TypeExists is the type of Exists. - TypeExists = "Exists" - // TypeDual is the type of TableDual. - TypeDual = "TableDual" - // TypeLock is the type of SelectLock. - TypeLock = "SelectLock" - // TypeInsert is the type of Insert - TypeInsert = "Insert" - // TypeUpdate is the type of Update. - TypeUpdate = "Update" - // TypeDelete is the type of Delete. - TypeDelete = "Delete" - // TypeIndexLookUp is the type of IndexLookUp. - TypeIndexLookUp = "IndexLookUp" - // TypeTableReader is the type of TableReader. - TypeTableReader = "TableReader" - // TypeIndexReader is the type of IndexReader. - TypeIndexReader = "IndexReader" - // TypeWindow is the type of Window. - TypeWindow = "Window" + "github.com/pingcap/tidb/util/plancodec" ) // Init initializes LogicalAggregation. func (la LogicalAggregation) Init(ctx sessionctx.Context) *LogicalAggregation { - la.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeAgg, &la) + la.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeAgg, &la) return &la } // Init initializes LogicalJoin. func (p LogicalJoin) Init(ctx sessionctx.Context) *LogicalJoin { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeJoin, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeJoin, &p) return &p } // Init initializes DataSource. func (ds DataSource) Init(ctx sessionctx.Context) *DataSource { - ds.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeTableScan, &ds) + ds.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeTableScan, &ds) return &ds } // Init initializes LogicalApply. func (la LogicalApply) Init(ctx sessionctx.Context) *LogicalApply { - la.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeApply, &la) + la.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeApply, &la) return &la } // Init initializes LogicalSelection. func (p LogicalSelection) Init(ctx sessionctx.Context) *LogicalSelection { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeSel, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeSel, &p) return &p } // Init initializes PhysicalSelection. func (p PhysicalSelection) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalSelection { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeSel, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeSel, &p) p.childrenReqProps = props p.stats = stats return &p @@ -125,19 +59,19 @@ func (p PhysicalSelection) Init(ctx sessionctx.Context, stats *property.StatsInf // Init initializes LogicalUnionScan. func (p LogicalUnionScan) Init(ctx sessionctx.Context) *LogicalUnionScan { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeUnionScan, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeUnionScan, &p) return &p } // Init initializes LogicalProjection. func (p LogicalProjection) Init(ctx sessionctx.Context) *LogicalProjection { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeProj, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeProj, &p) return &p } // Init initializes PhysicalProjection. func (p PhysicalProjection) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalProjection { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeProj, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeProj, &p) p.childrenReqProps = props p.stats = stats return &p @@ -145,13 +79,13 @@ func (p PhysicalProjection) Init(ctx sessionctx.Context, stats *property.StatsIn // Init initializes LogicalUnionAll. func (p LogicalUnionAll) Init(ctx sessionctx.Context) *LogicalUnionAll { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeUnion, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeUnion, &p) return &p } // Init initializes PhysicalUnionAll. func (p PhysicalUnionAll) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalUnionAll { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeUnion, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeUnion, &p) p.childrenReqProps = props p.stats = stats return &p @@ -159,13 +93,13 @@ func (p PhysicalUnionAll) Init(ctx sessionctx.Context, stats *property.StatsInfo // Init initializes LogicalSort. func (ls LogicalSort) Init(ctx sessionctx.Context) *LogicalSort { - ls.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeSort, &ls) + ls.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeSort, &ls) return &ls } // Init initializes PhysicalSort. func (p PhysicalSort) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalSort { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeSort, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeSort, &p) p.childrenReqProps = props p.stats = stats return &p @@ -173,20 +107,20 @@ func (p PhysicalSort) Init(ctx sessionctx.Context, stats *property.StatsInfo, pr // Init initializes NominalSort. func (p NominalSort) Init(ctx sessionctx.Context, props ...*property.PhysicalProperty) *NominalSort { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeSort, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeSort, &p) p.childrenReqProps = props return &p } // Init initializes LogicalTopN. func (lt LogicalTopN) Init(ctx sessionctx.Context) *LogicalTopN { - lt.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeTopN, <) + lt.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeTopN, <) return < } // Init initializes PhysicalTopN. func (p PhysicalTopN) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalTopN { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeTopN, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeTopN, &p) p.childrenReqProps = props p.stats = stats return &p @@ -194,13 +128,13 @@ func (p PhysicalTopN) Init(ctx sessionctx.Context, stats *property.StatsInfo, pr // Init initializes LogicalLimit. func (p LogicalLimit) Init(ctx sessionctx.Context) *LogicalLimit { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeLimit, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeLimit, &p) return &p } // Init initializes PhysicalLimit. func (p PhysicalLimit) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalLimit { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeLimit, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeLimit, &p) p.childrenReqProps = props p.stats = stats return &p @@ -208,26 +142,26 @@ func (p PhysicalLimit) Init(ctx sessionctx.Context, stats *property.StatsInfo, p // Init initializes LogicalTableDual. func (p LogicalTableDual) Init(ctx sessionctx.Context) *LogicalTableDual { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeDual, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeDual, &p) return &p } // Init initializes PhysicalTableDual. func (p PhysicalTableDual) Init(ctx sessionctx.Context, stats *property.StatsInfo) *PhysicalTableDual { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeDual, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeDual, &p) p.stats = stats return &p } // Init initializes LogicalMaxOneRow. func (p LogicalMaxOneRow) Init(ctx sessionctx.Context) *LogicalMaxOneRow { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeMaxOneRow, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeMaxOneRow, &p) return &p } // Init initializes PhysicalMaxOneRow. func (p PhysicalMaxOneRow) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalMaxOneRow { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeMaxOneRow, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeMaxOneRow, &p) p.childrenReqProps = props p.stats = stats return &p @@ -235,13 +169,13 @@ func (p PhysicalMaxOneRow) Init(ctx sessionctx.Context, stats *property.StatsInf // Init initializes LogicalWindow. func (p LogicalWindow) Init(ctx sessionctx.Context) *LogicalWindow { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeWindow, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeWindow, &p) return &p } // Init initializes PhysicalWindow. func (p PhysicalWindow) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalWindow { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeWindow, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeWindow, &p) p.childrenReqProps = props p.stats = stats return &p @@ -249,37 +183,39 @@ func (p PhysicalWindow) Init(ctx sessionctx.Context, stats *property.StatsInfo, // Init initializes Update. func (p Update) Init(ctx sessionctx.Context) *Update { - p.basePlan = newBasePlan(ctx, TypeUpdate) + p.basePlan = newBasePlan(ctx, plancodec.TypeUpdate) return &p } // Init initializes Delete. func (p Delete) Init(ctx sessionctx.Context) *Delete { - p.basePlan = newBasePlan(ctx, TypeDelete) + p.basePlan = newBasePlan(ctx, plancodec.TypeDelete) return &p } // Init initializes Insert. func (p Insert) Init(ctx sessionctx.Context) *Insert { - p.basePlan = newBasePlan(ctx, TypeInsert) + p.basePlan = newBasePlan(ctx, plancodec.TypeInsert) return &p } // Init initializes Show. func (p Show) Init(ctx sessionctx.Context) *Show { - p.basePlan = newBasePlan(ctx, TypeShow) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeShow, &p) + // Just use pseudo stats to avoid panic. + p.stats = &property.StatsInfo{RowCount: 1} return &p } // Init initializes LogicalLock. func (p LogicalLock) Init(ctx sessionctx.Context) *LogicalLock { - p.baseLogicalPlan = newBaseLogicalPlan(ctx, TypeLock, &p) + p.baseLogicalPlan = newBaseLogicalPlan(ctx, plancodec.TypeLock, &p) return &p } // Init initializes PhysicalLock. func (p PhysicalLock) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalLock { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeLock, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeLock, &p) p.childrenReqProps = props p.stats = stats return &p @@ -287,28 +223,28 @@ func (p PhysicalLock) Init(ctx sessionctx.Context, stats *property.StatsInfo, pr // Init initializes PhysicalTableScan. func (p PhysicalTableScan) Init(ctx sessionctx.Context) *PhysicalTableScan { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeTableScan, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeTableScan, &p) return &p } // Init initializes PhysicalIndexScan. func (p PhysicalIndexScan) Init(ctx sessionctx.Context) *PhysicalIndexScan { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeIdxScan, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeIdxScan, &p) return &p } // Init initializes PhysicalMemTable. func (p PhysicalMemTable) Init(ctx sessionctx.Context, stats *property.StatsInfo) *PhysicalMemTable { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeMemTableScan, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeMemTableScan, &p) p.stats = stats return &p } // Init initializes PhysicalHashJoin. func (p PhysicalHashJoin) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalHashJoin { - tp := TypeHashRightJoin + tp := plancodec.TypeHashRightJoin if p.InnerChildIdx == 1 { - tp = TypeHashLeftJoin + tp = plancodec.TypeHashLeftJoin } p.basePhysicalPlan = newBasePhysicalPlan(ctx, tp, &p) p.childrenReqProps = props @@ -318,21 +254,21 @@ func (p PhysicalHashJoin) Init(ctx sessionctx.Context, stats *property.StatsInfo // Init initializes PhysicalMergeJoin. func (p PhysicalMergeJoin) Init(ctx sessionctx.Context, stats *property.StatsInfo) *PhysicalMergeJoin { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeMergeJoin, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeMergeJoin, &p) p.stats = stats return &p } // Init initializes basePhysicalAgg. func (base basePhysicalAgg) Init(ctx sessionctx.Context, stats *property.StatsInfo) *basePhysicalAgg { - base.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeHashAgg, &base) + base.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeHashAgg, &base) base.stats = stats return &base } func (base basePhysicalAgg) initForHash(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalHashAgg { p := &PhysicalHashAgg{base} - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeHashAgg, p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeHashAgg, p) p.childrenReqProps = props p.stats = stats return p @@ -340,7 +276,7 @@ func (base basePhysicalAgg) initForHash(ctx sessionctx.Context, stats *property. func (base basePhysicalAgg) initForStream(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalStreamAgg { p := &PhysicalStreamAgg{base} - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeStreamAgg, p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeStreamAgg, p) p.childrenReqProps = props p.stats = stats return p @@ -348,7 +284,7 @@ func (base basePhysicalAgg) initForStream(ctx sessionctx.Context, stats *propert // Init initializes PhysicalApply. func (p PhysicalApply) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalApply { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeApply, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeApply, &p) p.childrenReqProps = props p.stats = stats return &p @@ -356,7 +292,7 @@ func (p PhysicalApply) Init(ctx sessionctx.Context, stats *property.StatsInfo, p // Init initializes PhysicalUnionScan. func (p PhysicalUnionScan) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalUnionScan { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeUnionScan, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeUnionScan, &p) p.childrenReqProps = props p.stats = stats return &p @@ -364,7 +300,7 @@ func (p PhysicalUnionScan) Init(ctx sessionctx.Context, stats *property.StatsInf // Init initializes PhysicalIndexLookUpReader. func (p PhysicalIndexLookUpReader) Init(ctx sessionctx.Context) *PhysicalIndexLookUpReader { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeIndexLookUp, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeIndexLookUp, &p) p.TablePlans = flattenPushDownPlan(p.tablePlan) p.IndexPlans = flattenPushDownPlan(p.indexPlan) p.schema = p.tablePlan.Schema() @@ -373,7 +309,7 @@ func (p PhysicalIndexLookUpReader) Init(ctx sessionctx.Context) *PhysicalIndexLo // Init initializes PhysicalTableReader. func (p PhysicalTableReader) Init(ctx sessionctx.Context) *PhysicalTableReader { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeTableReader, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeTableReader, &p) p.TablePlans = flattenPushDownPlan(p.tablePlan) p.schema = p.tablePlan.Schema() return &p @@ -381,7 +317,7 @@ func (p PhysicalTableReader) Init(ctx sessionctx.Context) *PhysicalTableReader { // Init initializes PhysicalIndexReader. func (p PhysicalIndexReader) Init(ctx sessionctx.Context) *PhysicalIndexReader { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeIndexReader, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeIndexReader, &p) p.IndexPlans = flattenPushDownPlan(p.indexPlan) switch p.indexPlan.(type) { case *PhysicalHashAgg, *PhysicalStreamAgg: @@ -396,7 +332,7 @@ func (p PhysicalIndexReader) Init(ctx sessionctx.Context) *PhysicalIndexReader { // Init initializes PhysicalIndexJoin. func (p PhysicalIndexJoin) Init(ctx sessionctx.Context, stats *property.StatsInfo, props ...*property.PhysicalProperty) *PhysicalIndexJoin { - p.basePhysicalPlan = newBasePhysicalPlan(ctx, TypeIndexJoin, &p) + p.basePhysicalPlan = newBasePhysicalPlan(ctx, plancodec.TypeIndexJoin, &p) p.childrenReqProps = props p.stats = stats return &p diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go new file mode 100644 index 0000000000000..b2c5e3d8b84ac --- /dev/null +++ b/planner/core/integration_test.go @@ -0,0 +1,191 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package core_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/util/testkit" + "github.com/pingcap/tidb/util/testutil" +) + +var _ = Suite(&testIntegrationSuite{}) + +type testIntegrationSuite struct { + testData testutil.TestData + store kv.Storage + dom *domain.Domain +} + +func (s *testIntegrationSuite) SetUpSuite(c *C) { + var err error + s.testData, err = testutil.LoadTestSuiteData("testdata", "integration_suite") + c.Assert(err, IsNil) +} + +func (s *testIntegrationSuite) TearDownSuite(c *C) { + c.Assert(s.testData.GenerateOutputIfNeeded(), IsNil) +} + +func (s *testIntegrationSuite) SetUpTest(c *C) { + var err error + s.store, s.dom, err = newStoreWithBootstrap() + c.Assert(err, IsNil) +} + +func (s *testIntegrationSuite) TearDownTest(c *C) { + s.dom.Close() + err := s.store.Close() + c.Assert(err, IsNil) +} + +func (s *testIntegrationSuite) TestShowSubquery(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a varchar(10), b int, c int)") + tk.MustQuery("show columns from t where true").Check(testkit.Rows( + "a varchar(10) YES ", + "b int(11) YES ", + "c int(11) YES ", + )) + tk.MustQuery("show columns from t where field = 'b'").Check(testkit.Rows( + "b int(11) YES ", + )) + tk.MustQuery("show columns from t where field in (select 'b')").Check(testkit.Rows( + "b int(11) YES ", + )) + tk.MustQuery("show columns from t where field in (select 'b') and true").Check(testkit.Rows( + "b int(11) YES ", + )) + tk.MustQuery("show columns from t where field in (select 'b') and false").Check(testkit.Rows()) + tk.MustExec("insert into t values('c', 0, 0)") + tk.MustQuery("show columns from t where field < all (select a from t)").Check(testkit.Rows( + "a varchar(10) YES ", + "b int(11) YES ", + )) + tk.MustExec("insert into t values('b', 0, 0)") + tk.MustQuery("show columns from t where field < all (select a from t)").Check(testkit.Rows( + "a varchar(10) YES ", + )) +} + +func (s *testIntegrationSuite) TestIsFromUnixtimeNullRejective(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`drop table if exists t;`) + tk.MustExec(`create table t(a bigint, b bigint);`) + s.runTestsWithTestData("TestIsFromUnixtimeNullRejective", tk, c) +} + +func (s *testIntegrationSuite) runTestsWithTestData(caseName string, tk *testkit.TestKit, c *C) { + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCasesByName(caseName, c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + tk.MustQuery(tt).Check(testkit.Rows(output[i].Plan...)) + } +} + +func (s *testIntegrationSuite) TestApplyNotNullFlag(c *C) { + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1(x int not null)") + tk.MustExec("create table t2(x int)") + tk.MustExec("insert into t2 values (1)") + + tk.MustQuery("select IFNULL((select t1.x from t1 where t1.x = t2.x), 'xxx') as col1 from t2").Check(testkit.Rows("xxx")) +} + +func (s *testIntegrationSuite) TestSimplifyOuterJoinWithCast(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int not null, b datetime default null)") + + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + tk.MustQuery(tt).Check(testkit.Rows(output[i].Plan...)) + } +} + +func (s *testIntegrationSuite) TestAntiJoinConstProp(c *C) { + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1(a int not null, b int not null)") + tk.MustExec("insert into t1 values (1,1)") + tk.MustExec("create table t2(a int not null, b int not null)") + tk.MustExec("insert into t2 values (2,2)") + + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t2.a = t1.a and t2.a > 1)").Check(testkit.Rows( + "1 1", + )) + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t2.b = t1.b and t2.a > 1)").Check(testkit.Rows( + "1 1", + )) + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t2.b = t1.b and t2.b > 1)").Check(testkit.Rows( + "1 1", + )) + tk.MustQuery("select q.a in (select count(*) from t1 s where not exists (select 1 from t1 p where q.a > 1 and p.a = s.a)) from t1 q").Check(testkit.Rows( + "1", + )) + tk.MustQuery("select q.a in (select not exists (select 1 from t1 p where q.a > 1 and p.a = s.a) from t1 s) from t1 q").Check(testkit.Rows( + "1", + )) + + tk.MustExec("drop table t1, t2") + tk.MustExec("create table t1(a int not null, b int)") + tk.MustExec("insert into t1 values (1,null)") + tk.MustExec("create table t2(a int not null, b int)") + tk.MustExec("insert into t2 values (2,2)") + + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t2.b > t1.b)").Check(testkit.Rows( + "1 ", + )) + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t1.a = 2)").Check(testkit.Rows( + "1 ", + )) +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index c38cde5ff6851..e2434021df74b 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -14,10 +14,12 @@ package core import ( + "context" "fmt" "math" "math/bits" "reflect" + "sort" "strings" "unicode" @@ -25,6 +27,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/format" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" @@ -42,6 +45,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/plancodec" ) const ( @@ -69,7 +73,7 @@ func (la *LogicalAggregation) collectGroupByColumns() { } } -func (b *PlanBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int, error) { +func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int, error) { b.optFlag = b.optFlag | flagBuildKeyInfo b.optFlag = b.optFlag | flagPushDownAgg // We may apply aggregation eliminate optimization. @@ -89,14 +93,17 @@ func (b *PlanBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega for i, aggFunc := range aggFuncList { newArgList := make([]expression.Expression, 0, len(aggFunc.Args)) for _, arg := range aggFunc.Args { - newArg, np, err := b.rewrite(arg, p, nil, true) + newArg, np, err := b.rewrite(ctx, arg, p, nil, true) if err != nil { return nil, nil, err } p = np newArgList = append(newArgList, newArg) } - newFunc := aggregation.NewAggFuncDesc(b.ctx, aggFunc.F, newArgList, aggFunc.Distinct) + newFunc, err := aggregation.NewAggFuncDesc(b.ctx, aggFunc.F, newArgList, aggFunc.Distinct) + if err != nil { + return nil, nil, err + } combined := false for j, oldFunc := range plan4Agg.AggFuncs { if oldFunc.Equal(b.ctx, newFunc) { @@ -118,7 +125,10 @@ func (b *PlanBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega } } for _, col := range p.Schema().Columns { - newFunc := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + newFunc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, nil, err + } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) newCol, _ := col.Clone().(*expression.Column) newCol.RetType = newFunc.RetTp @@ -131,18 +141,18 @@ func (b *PlanBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.Aggrega return plan4Agg, aggIndexMap, nil } -func (b *PlanBuilder) buildResultSetNode(node ast.ResultSetNode) (p LogicalPlan, err error) { +func (b *PlanBuilder) buildResultSetNode(ctx context.Context, node ast.ResultSetNode) (p LogicalPlan, err error) { switch x := node.(type) { case *ast.Join: - return b.buildJoin(x) + return b.buildJoin(ctx, x) case *ast.TableSource: switch v := x.Source.(type) { case *ast.SelectStmt: - p, err = b.buildSelect(v) + p, err = b.buildSelect(ctx, v) case *ast.UnionStmt: - p, err = b.buildUnion(v) + p, err = b.buildUnion(ctx, v) case *ast.TableName: - p, err = b.buildDataSource(v) + p, err = b.buildDataSource(ctx, v) default: err = ErrUnsupportedType.GenWithStackByArgs(v) } @@ -171,9 +181,9 @@ func (b *PlanBuilder) buildResultSetNode(node ast.ResultSetNode) (p LogicalPlan, } return p, nil case *ast.SelectStmt: - return b.buildSelect(x) + return b.buildSelect(ctx, x) case *ast.UnionStmt: - return b.buildUnion(x) + return b.buildUnion(ctx, x) default: return nil, ErrUnsupportedType.GenWithStack("Unsupported ast.ResultSetNode(%T) for buildResultSetNode()", x) } @@ -201,9 +211,14 @@ func (p *LogicalJoin) pushDownConstExpr(expr expression.Expression, leftCond []e } else { leftCond = append(leftCond, expr) } - case SemiJoin, AntiSemiJoin, InnerJoin: + case SemiJoin, InnerJoin: leftCond = append(leftCond, expr) rightCond = append(rightCond, expr) + case AntiSemiJoin: + if filterCond { + leftCond = append(leftCond, expr) + } + rightCond = append(rightCond, expr) } return leftCond, rightCond } @@ -222,41 +237,36 @@ func (p *LogicalJoin) extractOnCondition(conditions []expression.Expression, der arg0, lOK := binop.GetArgs()[0].(*expression.Column) arg1, rOK := binop.GetArgs()[1].(*expression.Column) if lOK && rOK { - var leftCol, rightCol *expression.Column - if left.Schema().Contains(arg0) && right.Schema().Contains(arg1) { - leftCol, rightCol = arg0, arg1 - } - if leftCol == nil && left.Schema().Contains(arg1) && right.Schema().Contains(arg0) { - leftCol, rightCol = arg1, arg0 + leftCol := left.Schema().RetrieveColumn(arg0) + rightCol := right.Schema().RetrieveColumn(arg1) + if leftCol == nil || rightCol == nil { + leftCol = left.Schema().RetrieveColumn(arg1) + rightCol = right.Schema().RetrieveColumn(arg0) + arg0, arg1 = arg1, arg0 } - if leftCol != nil { - // Do not derive `is not null` for anti join, since it may cause wrong results. - // For example: - // `select * from t t1 where t1.a not in (select b from t t2)` does not imply `t2.b is not null`, - // `select * from t t1 where t1.a not in (select a from t t2 where t1.b = t2.b` does not imply `t1.b is not null`, - // `select * from t t1 where not exists (select * from t t2 where t2.a = t1.a)` does not imply `t1.a is not null`, - if deriveLeft && p.JoinType != AntiSemiJoin { + if leftCol != nil && rightCol != nil { + if deriveLeft { if isNullRejected(ctx, left.Schema(), expr) && !mysql.HasNotNullFlag(leftCol.RetType.Flag) { notNullExpr := expression.BuildNotNullExpr(ctx, leftCol) leftCond = append(leftCond, notNullExpr) } } - if deriveRight && p.JoinType != AntiSemiJoin { + if deriveRight { if isNullRejected(ctx, right.Schema(), expr) && !mysql.HasNotNullFlag(rightCol.RetType.Flag) { notNullExpr := expression.BuildNotNullExpr(ctx, rightCol) rightCond = append(rightCond, notNullExpr) } } - } - // For quries like `select a in (select a from s where s.b = t.b) from t`, - // if subquery is empty caused by `s.b = t.b`, the result should always be - // false even if t.a is null or s.a is null. To make this join "empty aware", - // we should differentiate `t.a = s.a` from other column equal conditions, so - // we put it into OtherConditions instead of EqualConditions of join. - if leftCol != nil && binop.FuncName.L == ast.EQ && !leftCol.InOperand && !rightCol.InOperand { - cond := expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), leftCol, rightCol) - eqCond = append(eqCond, cond.(*expression.ScalarFunction)) - continue + // For queries like `select a in (select a from s where s.b = t.b) from t`, + // if subquery is empty caused by `s.b = t.b`, the result should always be + // false even if t.a is null or s.a is null. To make this join "empty aware", + // we should differentiate `t.a = s.a` from other column equal conditions, so + // we put it into OtherConditions instead of EqualConditions of join. + if binop.FuncName.L == ast.EQ && !arg0.InOperand && !arg1.InOperand { + cond := expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), arg0, arg1) + eqCond = append(eqCond, cond.(*expression.ScalarFunction)) + continue + } } } } @@ -302,8 +312,17 @@ func (p *LogicalJoin) extractOnCondition(conditions []expression.Expression, der return } +// extractTableAlias returns table alias of the LogicalPlan's columns. +// It will return nil when there are multiple table alias, because the alias is only used to check if +// the logicalPlan match some optimizer hints, and hints are not expected to take effect in this case. func extractTableAlias(p LogicalPlan) *model.CIStr { if p.Schema().Len() > 0 && p.Schema().Columns[0].TblName.L != "" { + tblName := p.Schema().Columns[0].TblName.L + for _, column := range p.Schema().Columns { + if column.TblName.L != tblName { + return nil + } + } return &(p.Schema().Columns[0].TblName) } return nil @@ -352,22 +371,22 @@ func resetNotNullFlag(schema *expression.Schema, start, end int) { } } -func (b *PlanBuilder) buildJoin(joinNode *ast.Join) (LogicalPlan, error) { +func (b *PlanBuilder) buildJoin(ctx context.Context, joinNode *ast.Join) (LogicalPlan, error) { // We will construct a "Join" node for some statements like "INSERT", // "DELETE", "UPDATE", "REPLACE". For this scenario "joinNode.Right" is nil // and we only build the left "ResultSetNode". if joinNode.Right == nil { - return b.buildResultSetNode(joinNode.Left) + return b.buildResultSetNode(ctx, joinNode.Left) } b.optFlag = b.optFlag | flagPredicatePushDown - leftPlan, err := b.buildResultSetNode(joinNode.Left) + leftPlan, err := b.buildResultSetNode(ctx, joinNode.Left) if err != nil { return nil, err } - rightPlan, err := b.buildResultSetNode(joinNode.Right) + rightPlan, err := b.buildResultSetNode(ctx, joinNode.Right) if err != nil { return nil, err } @@ -430,7 +449,7 @@ func (b *PlanBuilder) buildJoin(joinNode *ast.Join) (LogicalPlan, error) { } } else if joinNode.On != nil { b.curClause = onClause - onExpr, newPlan, err := b.rewrite(joinNode.On.Expr, joinPlan, nil, false) + onExpr, newPlan, err := b.rewrite(ctx, joinNode.On.Expr, joinPlan, nil, false) if err != nil { return nil, err } @@ -541,7 +560,7 @@ func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan return nil } -func (b *PlanBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMapper map[*ast.AggregateFuncExpr]int) (LogicalPlan, error) { +func (b *PlanBuilder) buildSelection(ctx context.Context, p LogicalPlan, where ast.ExprNode, AggMapper map[*ast.AggregateFuncExpr]int) (LogicalPlan, error) { b.optFlag = b.optFlag | flagPredicatePushDown if b.curClause != havingClause { b.curClause = whereClause @@ -551,7 +570,7 @@ func (b *PlanBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMappe expressions := make([]expression.Expression, 0, len(conditions)) selection := LogicalSelection{}.Init(b.ctx) for _, cond := range conditions { - expr, np, err := b.rewrite(cond, p, AggMapper, false) + expr, np, err := b.rewrite(ctx, cond, p, AggMapper, false) if err != nil { return nil, err } @@ -583,12 +602,10 @@ func (b *PlanBuilder) buildSelection(p LogicalPlan, where ast.ExprNode, AggMappe } // buildProjectionFieldNameFromColumns builds the field name, table name and database name when field expression is a column reference. -func (b *PlanBuilder) buildProjectionFieldNameFromColumns(field *ast.SelectField, c *expression.Column) (colName, origColName, tblName, origTblName, dbName model.CIStr) { - if astCol, ok := getInnerFromParenthesesAndUnaryPlus(field.Expr).(*ast.ColumnNameExpr); ok { - origColName, tblName, dbName = astCol.Name.Name, astCol.Name.Table, astCol.Name.Schema - } - if field.AsName.L != "" { - colName = field.AsName +func (b *PlanBuilder) buildProjectionFieldNameFromColumns(origField *ast.SelectField, colNameField *ast.ColumnNameExpr, c *expression.Column) (colName, origColName, tblName, origTblName, dbName model.CIStr) { + origColName, tblName, dbName = colNameField.Name.Name, colNameField.Name.Table, colNameField.Name.Schema + if origField.AsName.L != "" { + colName = origField.AsName } else { colName = origColName } @@ -602,7 +619,7 @@ func (b *PlanBuilder) buildProjectionFieldNameFromColumns(field *ast.SelectField } // buildProjectionFieldNameFromExpressions builds the field name when field expression is a normal expression. -func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectField) (model.CIStr, error) { +func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(ctx context.Context, field *ast.SelectField) (model.CIStr, error) { if agg, ok := field.Expr.(*ast.AggregateFuncExpr); ok && agg.F == ast.AggFuncFirstRow { // When the query is select t.a from t group by a; The Column Name should be a but not t.a; return agg.Args[0].(*ast.ColumnNameExpr).Name.Name, nil @@ -667,18 +684,22 @@ func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(field *ast.SelectF } // buildProjectionField builds the field object according to SelectField in projection. -func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectField, expr expression.Expression) (*expression.Column, error) { +func (b *PlanBuilder) buildProjectionField(ctx context.Context, id, position int, field *ast.SelectField, expr expression.Expression) (*expression.Column, error) { var origTblName, tblName, origColName, colName, dbName model.CIStr - if c, ok := expr.(*expression.Column); ok && !c.IsReferenced { + innerNode := getInnerFromParenthesesAndUnaryPlus(field.Expr) + col, isCol := expr.(*expression.Column) + // Correlated column won't affect the final output names. So we can put it in any of the three logic block. + // Don't put it into the first block just for simplifying the codes. + if colNameField, ok := innerNode.(*ast.ColumnNameExpr); ok && isCol { // Field is a column reference. - colName, origColName, tblName, origTblName, dbName = b.buildProjectionFieldNameFromColumns(field, c) + colName, origColName, tblName, origTblName, dbName = b.buildProjectionFieldNameFromColumns(field, colNameField, col) } else if field.AsName.L != "" { // Field has alias. colName = field.AsName } else { // Other: field is an expression. var err error - if colName, err = b.buildProjectionFieldNameFromExpressions(field); err != nil { + if colName, err = b.buildProjectionFieldNameFromExpressions(ctx, field); err != nil { return nil, err } } @@ -694,7 +715,7 @@ func (b *PlanBuilder) buildProjectionField(id, position int, field *ast.SelectFi } // buildProjection returns a Projection plan and non-aux columns length. -func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, considerWindow bool) (LogicalPlan, int, error) { +func (b *PlanBuilder) buildProjection(ctx context.Context, p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, considerWindow bool) (LogicalPlan, int, error) { b.optFlag |= flagEliminateProjection b.curClause = fieldList proj := LogicalProjection{Exprs: make([]expression.Expression, 0, len(fields))}.Init(b.ctx) @@ -719,14 +740,14 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, } else if !considerWindow && isWindowFuncField { expr := expression.Zero proj.Exprs = append(proj.Exprs, expr) - col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, expr) + col, err := b.buildProjectionField(ctx, proj.id, schema.Len()+1, field, expr) if err != nil { return nil, 0, err } schema.Append(col) continue } - newExpr, np, err := b.rewriteWithPreprocess(field.Expr, p, mapper, windowMapper, true, nil) + newExpr, np, err := b.rewriteWithPreprocess(ctx, field.Expr, p, mapper, windowMapper, true, nil) if err != nil { return nil, 0, err } @@ -742,7 +763,7 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, p = np proj.Exprs = append(proj.Exprs, newExpr) - col, err := b.buildProjectionField(proj.id, schema.Len()+1, field, newExpr) + col, err := b.buildProjectionField(ctx, proj.id, schema.Len()+1, field, newExpr) if err != nil { return nil, 0, err } @@ -753,7 +774,7 @@ func (b *PlanBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField, return proj, oldLen, nil } -func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggregation { +func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) (*LogicalAggregation, error) { b.optFlag = b.optFlag | flagBuildKeyInfo b.optFlag = b.optFlag | flagPushDownAgg plan4Agg := LogicalAggregation{ @@ -762,7 +783,10 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggre }.Init(b.ctx) plan4Agg.collectGroupByColumns() for _, col := range child.Schema().Columns { - aggDesc := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, err + } plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, aggDesc) } plan4Agg.SetChildren(child) @@ -772,7 +796,7 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) *LogicalAggre for i, col := range plan4Agg.schema.Columns { col.RetType = plan4Agg.AggFuncs[i].RetTp } - return plan4Agg + return plan4Agg, nil } // unionJoinFieldType finds the type which can carry the given types in Union. @@ -798,7 +822,7 @@ func unionJoinFieldType(a, b *types.FieldType) *types.FieldType { return resultTp } -func (b *PlanBuilder) buildProjection4Union(u *LogicalUnionAll) { +func (b *PlanBuilder) buildProjection4Union(ctx context.Context, u *LogicalUnionAll) { unionCols := make([]*expression.Column, 0, u.children[0].Schema().Len()) // Infer union result types by its children's schema. @@ -836,22 +860,25 @@ func (b *PlanBuilder) buildProjection4Union(u *LogicalUnionAll) { } } -func (b *PlanBuilder) buildUnion(union *ast.UnionStmt) (LogicalPlan, error) { - distinctSelectPlans, allSelectPlans, err := b.divideUnionSelectPlans(union.SelectList.Selects) +func (b *PlanBuilder) buildUnion(ctx context.Context, union *ast.UnionStmt) (LogicalPlan, error) { + distinctSelectPlans, allSelectPlans, err := b.divideUnionSelectPlans(ctx, union.SelectList.Selects) if err != nil { return nil, err } - unionDistinctPlan := b.buildUnionAll(distinctSelectPlans) + unionDistinctPlan := b.buildUnionAll(ctx, distinctSelectPlans) if unionDistinctPlan != nil { - unionDistinctPlan = b.buildDistinct(unionDistinctPlan, unionDistinctPlan.Schema().Len()) + unionDistinctPlan, err = b.buildDistinct(unionDistinctPlan, unionDistinctPlan.Schema().Len()) + if err != nil { + return nil, err + } if len(allSelectPlans) > 0 { // Can't change the statements order in order to get the correct column info. allSelectPlans = append([]LogicalPlan{unionDistinctPlan}, allSelectPlans...) } } - unionAllPlan := b.buildUnionAll(allSelectPlans) + unionAllPlan := b.buildUnionAll(ctx, allSelectPlans) unionPlan := unionDistinctPlan if unionAllPlan != nil { unionPlan = unionAllPlan @@ -860,7 +887,7 @@ func (b *PlanBuilder) buildUnion(union *ast.UnionStmt) (LogicalPlan, error) { oldLen := unionPlan.Schema().Len() if union.OrderBy != nil { - unionPlan, err = b.buildSort(unionPlan, union.OrderBy.Items, nil, nil) + unionPlan, err = b.buildSort(ctx, unionPlan, union.OrderBy.Items, nil, nil) if err != nil { return nil, err } @@ -893,7 +920,7 @@ func (b *PlanBuilder) buildUnion(union *ast.UnionStmt) (LogicalPlan, error) { // and divide result plans into "union-distinct" and "union-all" parts. // divide rule ref: https://dev.mysql.com/doc/refman/5.7/en/union.html // "Mixed UNION types are treated such that a DISTINCT union overrides any ALL union to its left." -func (b *PlanBuilder) divideUnionSelectPlans(selects []*ast.SelectStmt) (distinctSelects []LogicalPlan, allSelects []LogicalPlan, err error) { +func (b *PlanBuilder) divideUnionSelectPlans(ctx context.Context, selects []*ast.SelectStmt) (distinctSelects []LogicalPlan, allSelects []LogicalPlan, err error) { firstUnionAllIdx, columnNums := 0, -1 // The last slot is reserved for appending distinct union outside this function. children := make([]LogicalPlan, len(selects), len(selects)+1) @@ -903,7 +930,7 @@ func (b *PlanBuilder) divideUnionSelectPlans(selects []*ast.SelectStmt) (distinc firstUnionAllIdx = i + 1 } - selectPlan, err := b.buildSelect(stmt) + selectPlan, err := b.buildSelect(ctx, stmt) if err != nil { return nil, nil, err } @@ -919,13 +946,13 @@ func (b *PlanBuilder) divideUnionSelectPlans(selects []*ast.SelectStmt) (distinc return children[:firstUnionAllIdx], children[firstUnionAllIdx:], nil } -func (b *PlanBuilder) buildUnionAll(subPlan []LogicalPlan) LogicalPlan { +func (b *PlanBuilder) buildUnionAll(ctx context.Context, subPlan []LogicalPlan) LogicalPlan { if len(subPlan) == 0 { return nil } u := LogicalUnionAll{}.Init(b.ctx) u.children = subPlan - b.buildProjection4Union(u) + b.buildProjection4Union(ctx, u) return u } @@ -965,7 +992,7 @@ func (t *itemTransformer) Leave(inNode ast.Node) (ast.Node, bool) { return inNode, false } -func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int) (*LogicalSort, error) { +func (b *PlanBuilder) buildSort(ctx context.Context, p LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int) (*LogicalSort, error) { if _, isUnion := p.(*LogicalUnionAll); isUnion { b.curClause = globalOrderByClause } else { @@ -977,7 +1004,7 @@ func (b *PlanBuilder) buildSort(p LogicalPlan, byItems []*ast.ByItem, aggMapper for _, item := range byItems { newExpr, _ := item.Expr.Accept(transformer) item.Expr = newExpr.(ast.ExprNode) - it, np, err := b.rewriteWithPreprocess(item.Expr, p, aggMapper, windowMapper, true, nil) + it, np, err := b.rewriteWithPreprocess(ctx, item.Expr, p, aggMapper, windowMapper, true, nil) if err != nil { return nil, err } @@ -1215,7 +1242,7 @@ func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, o case *ast.WindowFuncExpr: a.inWindowFunc = false if a.curClause == havingClause { - a.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(v.F) + a.err = ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.F)) return node, false } if a.curClause == orderByClause { @@ -1229,7 +1256,7 @@ func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, o a.inWindowSpec = false case *ast.ColumnNameExpr: resolveFieldsFirst := true - if a.inAggFunc || a.inWindowFunc || a.inWindowSpec || (a.orderBy && a.inExpr) { + if a.inAggFunc || a.inWindowFunc || a.inWindowSpec || (a.orderBy && a.inExpr) || a.curClause == fieldList { resolveFieldsFirst = false } if !a.inAggFunc && !a.orderBy { @@ -1264,7 +1291,7 @@ func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, o var err error index, err = a.resolveFromSchema(v, a.p.Schema()) _ = err - if index == -1 && a.curClause != windowClause { + if index == -1 && a.curClause != fieldList { index, a.err = resolveFromSelectFields(v, a.selectFields, false) if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { a.err = ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) @@ -1369,7 +1396,7 @@ func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p LogicalPlan) colMapper: b.colMapper, outerSchemas: b.outerSchemas, } - extractor.curClause = windowClause + extractor.curClause = fieldList for _, field := range sel.Fields.Fields { if !ast.HasWindowFlag(field.Expr) { continue @@ -1411,16 +1438,25 @@ type gbyResolver struct { err error inExpr bool isParam bool + + exprDepth int // exprDepth is the depth of current expression in expression tree. } func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { + g.exprDepth++ switch n := inNode.(type) { case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: return inNode, true case *driver.ParamMarkerExpr: - newNode := expression.ConstructPositionExpr(n) g.isParam = true - return newNode, true + if g.exprDepth == 1 { + _, isNull, isExpectedType := getUintFromNode(g.ctx, n) + // For constant uint expression in top level, it should be treated as position expression. + if !isNull && isExpectedType { + return expression.ConstructPositionExpr(n), true + } + } + return n, true case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: default: g.inExpr = true @@ -1743,7 +1779,7 @@ func (b *PlanBuilder) checkOnlyFullGroupByWithGroupClause(p LogicalPlan, sel *as } switch errExprLoc.Loc { case ErrExprInSelect: - return ErrFieldNotInGroupBy.GenWithStackByArgs(errExprLoc.Offset+1, errExprLoc.Loc, sel.Fields.Fields[errExprLoc.Offset].Text()) + return ErrFieldNotInGroupBy.GenWithStackByArgs(errExprLoc.Offset+1, errExprLoc.Loc, col.DBName.O+"."+col.TblName.O+"."+col.OrigColName.O) case ErrExprInOrderBy: return ErrFieldNotInGroupBy.GenWithStackByArgs(errExprLoc.Offset+1, errExprLoc.Loc, sel.OrderBy.Items[errExprLoc.Offset].Expr.Text()) } @@ -1839,7 +1875,7 @@ func allColFromExprNode(p LogicalPlan, n ast.Node, cols map[*expression.Column]s n.Accept(extractor) } -func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fields []*ast.SelectField) (LogicalPlan, []expression.Expression, error) { +func (b *PlanBuilder) resolveGbyExprs(ctx context.Context, p LogicalPlan, gby *ast.GroupByClause, fields []*ast.SelectField) (LogicalPlan, []expression.Expression, error) { b.curClause = groupByClause exprs := make([]expression.Expression, 0, len(gby.Items)) resolver := &gbyResolver{ @@ -1858,7 +1894,7 @@ func (b *PlanBuilder) resolveGbyExprs(p LogicalPlan, gby *ast.GroupByClause, fie } itemExpr := retExpr.(ast.ExprNode) - expr, np, err := b.rewrite(itemExpr, p, nil, true) + expr, np, err := b.rewrite(ctx, itemExpr, p, nil, true) if err != nil { return nil, nil, err } @@ -1956,7 +1992,7 @@ func (b *PlanBuilder) TableHints() *tableHintInfo { return &(b.tableHintInfo[len(b.tableHintInfo)-1]) } -func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error) { +func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p LogicalPlan, err error) { if b.pushTableHints(sel.TableHints) { // table hints are only visible in the current SELECT statement. defer b.popTableHints() @@ -1975,7 +2011,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error ) if sel.From != nil { - p, err = b.buildResultSetNode(sel.From.TableRefs) + p, err = b.buildResultSetNode(ctx, sel.From.TableRefs) if err != nil { return nil, err } @@ -1988,9 +2024,12 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error if err != nil { return nil, err } + if b.capFlag&canExpandAST != 0 { + originalFields = sel.Fields.Fields + } if sel.GroupBy != nil { - p, gbyCols, err = b.resolveGbyExprs(p, sel.GroupBy, sel.Fields.Fields) + p, gbyCols, err = b.resolveGbyExprs(ctx, p, sel.GroupBy, sel.Fields.Fields) if err != nil { return nil, err } @@ -2019,7 +2058,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error } if sel.Where != nil { - p, err = b.buildSelection(p, sel.Where, nil) + p, err = b.buildSelection(ctx, p, sel.Where, nil) if err != nil { return nil, err } @@ -2033,7 +2072,7 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error if hasAgg { aggFuncs, totalMap = b.extractAggFuncs(sel.Fields.Fields) var aggIndexMap map[int]int - p, aggIndexMap, err = b.buildAggregation(p, aggFuncs, gbyCols) + p, aggIndexMap, err = b.buildAggregation(ctx, p, aggFuncs, gbyCols) if err != nil { return nil, err } @@ -2045,14 +2084,14 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error var oldLen int // According to https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html, // we can only process window functions after having clause, so `considerWindow` is false now. - p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, totalMap, nil, false) + p, oldLen, err = b.buildProjection(ctx, p, sel.Fields.Fields, totalMap, nil, false) if err != nil { return nil, err } if sel.Having != nil { b.curClause = havingClause - p, err = b.buildSelection(p, sel.Having.Expr, havingMap) + p, err = b.buildSelection(ctx, p, sel.Having.Expr, havingMap) if err != nil { return nil, err } @@ -2066,27 +2105,35 @@ func (b *PlanBuilder) buildSelect(sel *ast.SelectStmt) (p LogicalPlan, err error var windowMapper map[*ast.WindowFuncExpr]int if hasWindowFuncField { windowFuncs := extractWindowFuncs(sel.Fields.Fields) + // we need to check the func args first before we check the window spec + err := b.checkWindowFuncArgs(ctx, p, windowFuncs, windowAggMap) + if err != nil { + return nil, err + } groupedFuncs, err := b.groupWindowFuncs(windowFuncs) if err != nil { return nil, err } - p, windowMapper, err = b.buildWindowFunctions(p, groupedFuncs, windowAggMap) + p, windowMapper, err = b.buildWindowFunctions(ctx, p, groupedFuncs, windowAggMap) if err != nil { return nil, err } // Now we build the window function fields. - p, oldLen, err = b.buildProjection(p, sel.Fields.Fields, windowAggMap, windowMapper, true) + p, oldLen, err = b.buildProjection(ctx, p, sel.Fields.Fields, windowAggMap, windowMapper, true) if err != nil { return nil, err } } if sel.Distinct { - p = b.buildDistinct(p, oldLen) + p, err = b.buildDistinct(p, oldLen) + if err != nil { + return nil, err + } } if sel.OrderBy != nil { - p, err = b.buildSort(p, sel.OrderBy.Items, orderMap, windowMapper) + p, err = b.buildSort(ctx, p, sel.OrderBy.Items, orderMap, windowMapper) if err != nil { return nil, err } @@ -2164,7 +2211,7 @@ func getStatsTable(ctx sessionctx.Context, tblInfo *model.TableInfo, pid int64) return statsTbl } -func (b *PlanBuilder) buildDataSource(tn *ast.TableName) (LogicalPlan, error) { +func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName) (LogicalPlan, error) { dbName := tn.Schema if dbName.L == "" { dbName = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB) @@ -2178,12 +2225,12 @@ func (b *PlanBuilder) buildDataSource(tn *ast.TableName) (LogicalPlan, error) { tableInfo := tbl.Meta() var authErr error if b.ctx.GetSessionVars().User != nil { - authErr = ErrTableaccessDenied.GenWithStackByArgs("SELECT", b.ctx.GetSessionVars().User.Hostname, b.ctx.GetSessionVars().User.Username, tableInfo.Name.L) + authErr = ErrTableaccessDenied.GenWithStackByArgs("SELECT", b.ctx.GetSessionVars().User.Username, b.ctx.GetSessionVars().User.Hostname, tableInfo.Name.L) } b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName.L, tableInfo.Name.L, "", authErr) if tableInfo.IsView() { - return b.BuildDataSourceFromView(dbName, tableInfo) + return b.BuildDataSourceFromView(ctx, dbName, tableInfo) } if tableInfo.GetPartitionInfo() != nil { @@ -2284,7 +2331,7 @@ func (b *PlanBuilder) buildDataSource(tn *ast.TableName) (LogicalPlan, error) { // If this table contains any virtual generated columns, we need a // "Projection" to calculate these columns. - proj, err := b.projectVirtualColumns(ds, columns) + proj, err := b.projectVirtualColumns(ctx, ds, columns) if err != nil { return nil, err } @@ -2297,7 +2344,7 @@ func (b *PlanBuilder) buildDataSource(tn *ast.TableName) (LogicalPlan, error) { } // BuildDataSourceFromView is used to build LogicalPlan from view -func (b *PlanBuilder) BuildDataSourceFromView(dbName model.CIStr, tableInfo *model.TableInfo) (LogicalPlan, error) { +func (b *PlanBuilder) BuildDataSourceFromView(ctx context.Context, dbName model.CIStr, tableInfo *model.TableInfo) (LogicalPlan, error) { charset, collation := b.ctx.GetSessionVars().GetCharsetInfo() viewParser := parser.New() viewParser.EnableWindowFunc(b.ctx.GetSessionVars().EnableWindowFunction) @@ -2307,8 +2354,9 @@ func (b *PlanBuilder) BuildDataSourceFromView(dbName model.CIStr, tableInfo *mod } originalVisitInfo := b.visitInfo b.visitInfo = make([]visitInfo, 0) - selectLogicalPlan, err := b.Build(selectNode) + selectLogicalPlan, err := b.Build(ctx, selectNode) if err != nil { + err = ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) return nil, err } @@ -2324,25 +2372,52 @@ func (b *PlanBuilder) BuildDataSourceFromView(dbName model.CIStr, tableInfo *mod } b.visitInfo = append(originalVisitInfo, b.visitInfo...) - projSchema := expression.NewSchema(make([]*expression.Column, 0, len(tableInfo.View.Cols))...) - projExprs := make([]expression.Expression, 0, len(tableInfo.View.Cols)) - for i := range tableInfo.View.Cols { - col := selectLogicalPlan.Schema().FindColumnByName(tableInfo.View.Cols[i].L) - if col == nil { - return nil, ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) + if b.ctx.GetSessionVars().StmtCtx.InExplainStmt { + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.ShowViewPriv, dbName.L, tableInfo.Name.L, "", ErrViewNoExplain) + } + + if len(tableInfo.Columns) != selectLogicalPlan.Schema().Len() { + return nil, ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) + } + + return b.buildProjUponView(ctx, dbName, tableInfo, selectLogicalPlan) +} + +func (b *PlanBuilder) buildProjUponView(ctx context.Context, dbName model.CIStr, tableInfo *model.TableInfo, selectLogicalPlan Plan) (LogicalPlan, error) { + columnInfo := tableInfo.Cols() + cols := selectLogicalPlan.Schema().Columns + // In the old version of VIEW implementation, tableInfo.View.Cols is used to + // store the origin columns' names of the underlying SelectStmt used when + // creating the view. + if tableInfo.View.Cols != nil { + cols = cols[:0] + for _, info := range columnInfo { + col := selectLogicalPlan.Schema().FindColumnByName(info.Name.L) + if col == nil { + return nil, ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) + } + cols = append(cols, col) + } + } + + projSchema := expression.NewSchema(make([]*expression.Column, 0, len(tableInfo.Columns))...) + projExprs := make([]expression.Expression, 0, len(tableInfo.Columns)) + for i, col := range cols { + origColName := col.ColName + if tableInfo.View.Cols != nil { + origColName = tableInfo.View.Cols[i] } projSchema.Append(&expression.Column{ UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), TblName: col.TblName, OrigTblName: col.OrigTblName, - ColName: tableInfo.Cols()[i].Name, - OrigColName: tableInfo.View.Cols[i], + ColName: columnInfo[i].Name, + OrigColName: origColName, DBName: col.DBName, RetType: col.GetType(), }) projExprs = append(projExprs, col) } - projUponView := LogicalProjection{Exprs: projExprs}.Init(b.ctx) projUponView.SetChildren(selectLogicalPlan.(LogicalPlan)) projUponView.SetSchema(projSchema) @@ -2352,7 +2427,7 @@ func (b *PlanBuilder) BuildDataSourceFromView(dbName model.CIStr, tableInfo *mod // projectVirtualColumns is only for DataSource. If some table has virtual generated columns, // we add a projection on the original DataSource, and calculate those columns in the projection // so that plans above it can reference generated columns by their name. -func (b *PlanBuilder) projectVirtualColumns(ds *DataSource, columns []*table.Column) (*LogicalProjection, error) { +func (b *PlanBuilder) projectVirtualColumns(ctx context.Context, ds *DataSource, columns []*table.Column) (*LogicalProjection, error) { var hasVirtualGeneratedColumn = false for _, column := range columns { if column.IsGenerated() && !column.GeneratedStored { @@ -2374,7 +2449,7 @@ func (b *PlanBuilder) projectVirtualColumns(ds *DataSource, columns []*table.Col if i < len(columns) { if columns[i].IsGenerated() && !columns[i].GeneratedStored { var err error - expr, _, err = b.rewrite(columns[i].GeneratedExpr, ds, nil, true) + expr, _, err = b.rewrite(ctx, columns[i].GeneratedExpr, ds, nil, true) if err != nil { return nil, err } @@ -2414,6 +2489,11 @@ func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan LogicalPlan, t ap := LogicalApply{LogicalJoin: LogicalJoin{JoinType: tp}}.Init(b.ctx) ap.SetChildren(outerPlan, innerPlan) ap.SetSchema(expression.MergeSchema(outerPlan.Schema(), innerPlan.Schema())) + // Note that, tp can only be LeftOuterJoin or InnerJoin, so we don't consider other outer joins. + if tp == LeftOuterJoin { + b.optFlag = b.optFlag | flagEliminateOuterJoin + resetNotNullFlag(ap.schema, outerPlan.Schema().Len(), ap.schema.Len()) + } for i := outerPlan.Schema().Len(); i < ap.Schema().Len(); i++ { ap.schema.Columns[i].IsReferenced = true } @@ -2432,7 +2512,7 @@ func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan LogicalPlan, condition } ap := &LogicalApply{LogicalJoin: *join} - ap.tp = TypeApply + ap.tp = plancodec.TypeApply ap.self = ap return ap, nil } @@ -2493,7 +2573,7 @@ func (b *PlanBuilder) buildSemiJoin(outerPlan, innerPlan LogicalPlan, onConditio return joinPlan, nil } -func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { +func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (Plan, error) { if b.pushTableHints(update.TableHints) { // table hints are only visible in the current UPDATE statement. defer b.popTableHints() @@ -2519,7 +2599,7 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { Limit: update.Limit, } - p, err := b.buildResultSetNode(sel.From.TableRefs) + p, err := b.buildResultSetNode(ctx, sel.From.TableRefs) if err != nil { return nil, err } @@ -2537,14 +2617,22 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "", nil) } + oldSchema := p.Schema().Clone() if sel.Where != nil { - p, err = b.buildSelection(p, sel.Where, nil) + p, err = b.buildSelection(ctx, p, update.Where, nil) if err != nil { return nil, err } } + // TODO: expression rewriter should not change the output columns. We should cut the columns here. + if p.Schema().Len() != oldSchema.Len() { + proj := LogicalProjection{Exprs: expression.Column2Exprs(oldSchema.Columns)}.Init(b.ctx) + proj.SetSchema(oldSchema) + proj.SetChildren(p) + p = proj + } if sel.OrderBy != nil { - p, err = b.buildSort(p, sel.OrderBy.Items, nil, nil) + p, err = b.buildSort(ctx, p, sel.OrderBy.Items, nil, nil) if err != nil { return nil, err } @@ -2558,7 +2646,7 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { var updateTableList []*ast.TableName updateTableList = extractTableList(sel.From.TableRefs, updateTableList, true) - orderedList, np, err := b.buildUpdateLists(updateTableList, update.List, p) + orderedList, np, err := b.buildUpdateLists(ctx, updateTableList, update.List, p) if err != nil { return nil, err } @@ -2566,7 +2654,9 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { updt := Update{OrderedList: orderedList}.Init(b.ctx) updt.SetSchema(p.Schema()) - updt.SelectPlan, err = DoOptimize(b.optFlag, p) + // We cannot apply projection elimination when building the subplan, because + // columns in orderedList cannot be resolved. + updt.SelectPlan, err = DoOptimize(ctx, b.optFlag&^flagEliminateProjection, p) if err != nil { return nil, err } @@ -2574,7 +2664,7 @@ func (b *PlanBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { return updt, err } -func (b *PlanBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan) ([]*expression.Assignment, LogicalPlan, error) { +func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan) ([]*expression.Assignment, LogicalPlan, error) { b.curClause = fieldList modifyColumns := make(map[string]struct{}, p.Schema().Len()) // Which columns are in set list. for _, assign := range list { @@ -2625,7 +2715,7 @@ func (b *PlanBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A var newExpr expression.Expression var np LogicalPlan if i < len(list) { - newExpr, np, err = b.rewrite(assign.Expr, p, nil, false) + newExpr, np, err = b.rewrite(ctx, assign.Expr, p, nil, false) } else { // rewrite with generation expression rewritePreprocess := func(expr ast.Node) ast.Node { @@ -2640,7 +2730,7 @@ func (b *PlanBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A return expr } } - newExpr, np, err = b.rewriteWithPreprocess(assign.Expr, p, nil, nil, false, rewritePreprocess) + newExpr, np, err = b.rewriteWithPreprocess(ctx, assign.Expr, p, nil, nil, false, rewritePreprocess) } if err != nil { return nil, nil, err @@ -2711,7 +2801,7 @@ func extractTableAsNameForUpdate(p LogicalPlan, asNames map[*model.TableInfo][]* } } -func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { +func (b *PlanBuilder) buildDelete(ctx context.Context, delete *ast.DeleteStmt) (Plan, error) { if b.pushTableHints(delete.TableHints) { // table hints are only visible in the current DELETE statement. defer b.popTableHints() @@ -2724,7 +2814,7 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { OrderBy: delete.Order, Limit: delete.Limit, } - p, err := b.buildResultSetNode(sel.From.TableRefs) + p, err := b.buildResultSetNode(ctx, sel.From.TableRefs) if err != nil { return nil, err } @@ -2732,14 +2822,14 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { oldLen := oldSchema.Len() if sel.Where != nil { - p, err = b.buildSelection(p, sel.Where, nil) + p, err = b.buildSelection(ctx, p, sel.Where, nil) if err != nil { return nil, err } } if sel.OrderBy != nil { - p, err = b.buildSort(p, sel.OrderBy.Items, nil, nil) + p, err = b.buildSort(ctx, p, sel.OrderBy.Items, nil, nil) if err != nil { return nil, err } @@ -2771,7 +2861,7 @@ func (b *PlanBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { IsMultiTable: delete.IsMultiTable, }.Init(b.ctx) - del.SelectPlan, err = DoOptimize(b.optFlag, p) + del.SelectPlan, err = DoOptimize(ctx, b.optFlag, p) if err != nil { return nil, err } @@ -2846,7 +2936,7 @@ func getWindowName(name string) string { // buildProjectionForWindow builds the projection for expressions in the window specification that is not an column, // so after the projection, window functions only needs to deal with columns. -func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, spec *ast.WindowSpec, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, []property.Item, []property.Item, []expression.Expression, error) { +func (b *PlanBuilder) buildProjectionForWindow(ctx context.Context, p LogicalPlan, spec *ast.WindowSpec, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, []property.Item, []property.Item, []expression.Expression, error) { b.optFlag |= flagEliminateProjection var partitionItems, orderItems []*ast.ByItem @@ -2867,19 +2957,19 @@ func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, spec *ast.WindowSp propertyItems := make([]property.Item, 0, len(partitionItems)+len(orderItems)) var err error - p, propertyItems, err = b.buildByItemsForWindow(p, proj, partitionItems, propertyItems, aggMap) + p, propertyItems, err = b.buildByItemsForWindow(ctx, p, proj, partitionItems, propertyItems, aggMap) if err != nil { return nil, nil, nil, nil, err } lenPartition := len(propertyItems) - p, propertyItems, err = b.buildByItemsForWindow(p, proj, orderItems, propertyItems, aggMap) + p, propertyItems, err = b.buildByItemsForWindow(ctx, p, proj, orderItems, propertyItems, aggMap) if err != nil { return nil, nil, nil, nil, err } newArgList := make([]expression.Expression, 0, len(args)) for _, arg := range args { - newArg, np, err := b.rewrite(arg, p, aggMap, true) + newArg, np, err := b.rewrite(ctx, arg, p, aggMap, true) if err != nil { return nil, nil, nil, nil, err } @@ -2903,7 +2993,37 @@ func (b *PlanBuilder) buildProjectionForWindow(p LogicalPlan, spec *ast.WindowSp return proj, propertyItems[:lenPartition], propertyItems[lenPartition:], newArgList, nil } +func (b *PlanBuilder) buildArgs4WindowFunc(ctx context.Context, p LogicalPlan, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) ([]expression.Expression, error) { + b.optFlag |= flagEliminateProjection + + newArgList := make([]expression.Expression, 0, len(args)) + // use below index for created a new col definition + // it's okay here because we only want to return the args used in window function + newColIndex := 0 + for _, arg := range args { + newArg, np, err := b.rewrite(ctx, arg, p, aggMap, true) + if err != nil { + return nil, err + } + p = np + switch newArg.(type) { + case *expression.Column, *expression.Constant: + newArgList = append(newArgList, newArg) + continue + } + col := &expression.Column{ + ColName: model.NewCIStr(fmt.Sprintf("%d_proj_window_%d", p.ID(), newColIndex)), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: newArg.GetType(), + } + newColIndex += 1 + newArgList = append(newArgList, col) + } + return newArgList, nil +} + func (b *PlanBuilder) buildByItemsForWindow( + ctx context.Context, p LogicalPlan, proj *LogicalProjection, items []*ast.ByItem, @@ -2914,7 +3034,7 @@ func (b *PlanBuilder) buildByItemsForWindow( for _, item := range items { newExpr, _ := item.Expr.Accept(transformer) item.Expr = newExpr.(ast.ExprNode) - it, np, err := b.rewrite(item.Expr, p, aggMap, true) + it, np, err := b.rewrite(ctx, item.Expr, p, aggMap, true) if err != nil { return nil, nil, err } @@ -2941,7 +3061,7 @@ func (b *PlanBuilder) buildByItemsForWindow( // buildWindowFunctionFrameBound builds the bounds of window function frames. // For type `Rows`, the bound expr must be an unsigned integer. // For type `Range`, the bound expr must be temporal or numeric types. -func (b *PlanBuilder) buildWindowFunctionFrameBound(spec *ast.WindowSpec, orderByItems []property.Item, boundClause *ast.FrameBound) (*FrameBound, error) { +func (b *PlanBuilder) buildWindowFunctionFrameBound(ctx context.Context, spec *ast.WindowSpec, orderByItems []property.Item, boundClause *ast.FrameBound) (*FrameBound, error) { frameType := spec.Frame.Type bound := &FrameBound{Type: boundClause.Type, UnBounded: boundClause.UnBounded} if bound.UnBounded { @@ -2952,14 +3072,7 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(spec *ast.WindowSpec, orderB if bound.Type == ast.CurrentRow { return bound, nil } - // Rows type does not support interval range. - if boundClause.Unit != nil { - return nil, ErrWindowRowsIntervalUse.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - numRows, isNull, isExpectedType := getUintFromNode(b.ctx, boundClause.Expr) - if isNull || !isExpectedType { - return nil, ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) - } + numRows, _, _ := getUintFromNode(b.ctx, boundClause.Expr) bound.Num = numRows return bound, nil } @@ -2975,23 +3088,7 @@ func (b *PlanBuilder) buildWindowFunctionFrameBound(spec *ast.WindowSpec, orderB return bound, nil } - if len(orderByItems) != 1 { - return nil, ErrWindowRangeFrameOrderType.GenWithStackByArgs(getWindowName(spec.Name.O)) - } col := orderByItems[0].Col - isNumeric, isTemporal := types.IsTypeNumeric(col.RetType.Tp), types.IsTypeTemporal(col.RetType.Tp) - if !isNumeric && !isTemporal { - return nil, ErrWindowRangeFrameOrderType.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - // Interval bounds only support order by temporal types. - if boundClause.Unit != nil && isNumeric { - return nil, ErrWindowRangeFrameNumericType.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - // Non-interval bound only support order by numeric types. - if boundClause.Unit == nil && !isNumeric { - return nil, ErrWindowRangeFrameTemporalType.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - // TODO: We also need to raise error for non-deterministic expressions, like rand(). val, err := evalAstExpr(b.ctx, boundClause.Expr) if err != nil { @@ -3071,46 +3168,114 @@ func (pc *paramMarkerInPrepareChecker) Leave(in ast.Node) (out ast.Node, ok bool // buildWindowFunctionFrame builds the window function frames. // See https://dev.mysql.com/doc/refman/8.0/en/window-functions-frames.html -func (b *PlanBuilder) buildWindowFunctionFrame(spec *ast.WindowSpec, orderByItems []property.Item) (*WindowFrame, error) { +func (b *PlanBuilder) buildWindowFunctionFrame(ctx context.Context, spec *ast.WindowSpec, orderByItems []property.Item) (*WindowFrame, error) { frameClause := spec.Frame if frameClause == nil { return nil, nil } - if frameClause.Type == ast.Groups { - return nil, ErrNotSupportedYet.GenWithStackByArgs("GROUPS") - } frame := &WindowFrame{Type: frameClause.Type} - start := frameClause.Extent.Start - if start.Type == ast.Following && start.UnBounded { - return nil, ErrWindowFrameStartIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) - } var err error - frame.Start, err = b.buildWindowFunctionFrameBound(spec, orderByItems, &start) + frame.Start, err = b.buildWindowFunctionFrameBound(ctx, spec, orderByItems, &frameClause.Extent.Start) if err != nil { return nil, err } + frame.End, err = b.buildWindowFunctionFrameBound(ctx, spec, orderByItems, &frameClause.Extent.End) + return frame, err +} - end := frameClause.Extent.End - if end.Type == ast.Preceding && end.UnBounded { - return nil, ErrWindowFrameEndIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) +func (b *PlanBuilder) checkWindowFuncArgs(ctx context.Context, p LogicalPlan, windowFuncExprs []*ast.WindowFuncExpr, windowAggMap map[*ast.AggregateFuncExpr]int) error { + for _, windowFuncExpr := range windowFuncExprs { + args, err := b.buildArgs4WindowFunc(ctx, p, windowFuncExpr.Args, windowAggMap) + if err != nil { + return err + } + desc, err := aggregation.NewWindowFuncDesc(b.ctx, windowFuncExpr.F, args) + if err != nil { + return err + } + if desc == nil { + return ErrWrongArguments.GenWithStackByArgs(strings.ToLower(windowFuncExpr.F)) + } } - frame.End, err = b.buildWindowFunctionFrameBound(spec, orderByItems, &end) - return frame, err + return nil +} + +func getAllByItems(itemsBuf []*ast.ByItem, spec *ast.WindowSpec) []*ast.ByItem { + itemsBuf = itemsBuf[:0] + if spec.PartitionBy != nil { + itemsBuf = append(itemsBuf, spec.PartitionBy.Items...) + } + if spec.OrderBy != nil { + itemsBuf = append(itemsBuf, spec.OrderBy.Items...) + } + return itemsBuf +} + +func restoreByItemText(item *ast.ByItem) string { + var sb strings.Builder + ctx := format.NewRestoreCtx(0, &sb) + err := item.Expr.Restore(ctx) + if err != nil { + return "" + } + return sb.String() } -func (b *PlanBuilder) buildWindowFunctions(p LogicalPlan, groupedFuncs map[*ast.WindowSpec][]*ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, map[*ast.WindowFuncExpr]int, error) { +func compareItems(lItems []*ast.ByItem, rItems []*ast.ByItem) bool { + minLen := mathutil.Min(len(lItems), len(rItems)) + for i := 0; i < minLen; i++ { + res := strings.Compare(restoreByItemText(lItems[i]), restoreByItemText(rItems[i])) + if res != 0 { + return res < 0 + } + res = compareBool(lItems[i].Desc, rItems[i].Desc) + if res != 0 { + return res < 0 + } + } + return len(lItems) < len(rItems) +} + +type windowFuncs struct { + spec *ast.WindowSpec + funcs []*ast.WindowFuncExpr +} + +// sortWindowSpecs sorts the window specifications by reversed alphabetical order, then we could add less `Sort` operator +// in physical plan because the window functions with the same partition by and order by clause will be at near places. +func sortWindowSpecs(groupedFuncs map[*ast.WindowSpec][]*ast.WindowFuncExpr) []windowFuncs { + windows := make([]windowFuncs, 0, len(groupedFuncs)) + for spec, funcs := range groupedFuncs { + windows = append(windows, windowFuncs{spec, funcs}) + } + lItemsBuf := make([]*ast.ByItem, 0, 4) + rItemsBuf := make([]*ast.ByItem, 0, 4) + sort.SliceStable(windows, func(i, j int) bool { + lItemsBuf = getAllByItems(lItemsBuf, windows[i].spec) + rItemsBuf = getAllByItems(rItemsBuf, windows[j].spec) + return !compareItems(lItemsBuf, rItemsBuf) + }) + return windows +} + +func (b *PlanBuilder) buildWindowFunctions(ctx context.Context, p LogicalPlan, groupedFuncs map[*ast.WindowSpec][]*ast.WindowFuncExpr, aggMap map[*ast.AggregateFuncExpr]int) (LogicalPlan, map[*ast.WindowFuncExpr]int, error) { args := make([]ast.ExprNode, 0, 4) windowMap := make(map[*ast.WindowFuncExpr]int) - for spec, funcs := range groupedFuncs { + for _, window := range sortWindowSpecs(groupedFuncs) { args = args[:0] + spec, funcs := window.spec, window.funcs for _, windowFunc := range funcs { args = append(args, windowFunc.Args...) } - np, partitionBy, orderBy, args, err := b.buildProjectionForWindow(p, spec, args, aggMap) + np, partitionBy, orderBy, args, err := b.buildProjectionForWindow(ctx, p, spec, args, aggMap) + if err != nil { + return nil, nil, err + } + err = b.checkOriginWindowSpecs(funcs, orderBy) if err != nil { return nil, nil, err } - frame, err := b.buildWindowFunctionFrame(spec, orderBy) + frame, err := b.buildWindowFunctionFrame(ctx, spec, orderBy) if err != nil { return nil, nil, err } @@ -3124,9 +3289,12 @@ func (b *PlanBuilder) buildWindowFunctions(p LogicalPlan, groupedFuncs map[*ast. descs := make([]*aggregation.WindowFuncDesc, 0, len(funcs)) preArgs := 0 for _, windowFunc := range funcs { - desc := aggregation.NewWindowFuncDesc(b.ctx, windowFunc.F, args[preArgs:preArgs+len(windowFunc.Args)]) + desc, err := aggregation.NewWindowFuncDesc(b.ctx, windowFunc.F, args[preArgs:preArgs+len(windowFunc.Args)]) + if err != nil { + return nil, nil, err + } if desc == nil { - return nil, nil, ErrWrongArguments.GenWithStackByArgs(windowFunc.F) + return nil, nil, ErrWrongArguments.GenWithStackByArgs(strings.ToLower(windowFunc.F)) } preArgs += len(windowFunc.Args) desc.WrapCastForAggArgs(b.ctx) @@ -3147,6 +3315,89 @@ func (b *PlanBuilder) buildWindowFunctions(p LogicalPlan, groupedFuncs map[*ast. return p, windowMap, nil } +// checkOriginWindowSpecs checks the validation for origin window specifications for a group of functions. +// Because of the grouped specification is different from it, we should especially check them before build window frame. +func (b *PlanBuilder) checkOriginWindowSpecs(funcs []*ast.WindowFuncExpr, orderByItems []property.Item) error { + for _, f := range funcs { + if f.IgnoreNull { + return ErrNotSupportedYet.GenWithStackByArgs("IGNORE NULLS") + } + if f.Distinct { + return ErrNotSupportedYet.GenWithStackByArgs("(DISTINCT ..)") + } + if f.FromLast { + return ErrNotSupportedYet.GenWithStackByArgs("FROM LAST") + } + spec := &f.Spec + if f.Spec.Name.L != "" { + spec = b.windowSpecs[f.Spec.Name.L] + } + if spec.Frame == nil { + continue + } + if spec.Frame.Type == ast.Groups { + return ErrNotSupportedYet.GenWithStackByArgs("GROUPS") + } + start, end := spec.Frame.Extent.Start, spec.Frame.Extent.End + if start.Type == ast.Following && start.UnBounded { + return ErrWindowFrameStartIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if end.Type == ast.Preceding && end.UnBounded { + return ErrWindowFrameEndIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if start.Type == ast.Following && (end.Type == ast.Preceding || end.Type == ast.CurrentRow) { + return ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if (start.Type == ast.Following || start.Type == ast.CurrentRow) && end.Type == ast.Preceding { + return ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + + err := b.checkOriginWindowFrameBound(&start, spec, orderByItems) + if err != nil { + return err + } + err = b.checkOriginWindowFrameBound(&end, spec, orderByItems) + if err != nil { + return err + } + } + return nil +} + +func (b *PlanBuilder) checkOriginWindowFrameBound(bound *ast.FrameBound, spec *ast.WindowSpec, orderByItems []property.Item) error { + if bound.Type == ast.CurrentRow || bound.UnBounded { + return nil + } + + frameType := spec.Frame.Type + if frameType == ast.Rows { + if bound.Unit != nil { + return ErrWindowRowsIntervalUse.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + _, isNull, isExpectedType := getUintFromNode(b.ctx, bound.Expr) + if isNull || !isExpectedType { + return ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + return nil + } + + if len(orderByItems) != 1 { + return ErrWindowRangeFrameOrderType.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + orderItemType := orderByItems[0].Col.RetType.Tp + isNumeric, isTemporal := types.IsTypeNumeric(orderItemType), types.IsTypeTemporal(orderItemType) + if !isNumeric && !isTemporal { + return ErrWindowRangeFrameOrderType.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if bound.Unit != nil && !isTemporal { + return ErrWindowRangeFrameNumericType.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if bound.Unit == nil && !isNumeric { + return ErrWindowRangeFrameTemporalType.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + return nil +} + func extractWindowFuncs(fields []*ast.SelectField) []*ast.WindowFuncExpr { extractor := &WindowFuncExtractor{} for _, f := range fields { @@ -3156,8 +3407,8 @@ func extractWindowFuncs(fields []*ast.SelectField) []*ast.WindowFuncExpr { return extractor.windowFuncs } -func (b *PlanBuilder) handleDefaultFrame(spec *ast.WindowSpec, name string) (*ast.WindowSpec, bool) { - needFrame := aggregation.NeedFrame(name) +func (b *PlanBuilder) handleDefaultFrame(spec *ast.WindowSpec, windowFuncName string) (*ast.WindowSpec, bool) { + needFrame := aggregation.NeedFrame(windowFuncName) // According to MySQL, In the absence of a frame clause, the default frame depends on whether an ORDER BY clause is present: // (1) With order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"; // (2) Without order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", @@ -3176,7 +3427,7 @@ func (b *PlanBuilder) handleDefaultFrame(spec *ast.WindowSpec, name string) (*as // For functions that operate on the entire partition, the frame clause will be ignored. if !needFrame && spec.Frame != nil { specName := spec.Name.O - b.ctx.GetSessionVars().StmtCtx.AppendNote(ErrWindowFunctionIgnoresFrame.GenWithStackByArgs(name, specName)) + b.ctx.GetSessionVars().StmtCtx.AppendNote(ErrWindowFunctionIgnoresFrame.GenWithStackByArgs(windowFuncName, getWindowName(specName))) newSpec := *spec newSpec.Frame = nil return &newSpec, true diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 71598d87d2203..897a104097c8f 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -14,6 +14,7 @@ package core import ( + "context" "fmt" "sort" "strings" @@ -22,6 +23,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/parser" "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/format" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -47,7 +49,7 @@ type testPlanSuite struct { } func (s *testPlanSuite) SetUpSuite(c *C) { - s.is = infoschema.MockInfoSchema([]*model.TableInfo{MockTable(), MockView()}) + s.is = infoschema.MockInfoSchema([]*model.TableInfo{MockSignedTable(), MockView()}) s.ctx = MockContext() s.Parser = parser.New() } @@ -84,7 +86,7 @@ func (s *testPlanSuite) TestPredicatePushDown(c *C) { }, { sql: "select * from t t1, t t2 where t1.a = t2.b and t2.b > 0 and t1.a = t1.c and t1.d like 'abc' and t2.d = t1.d", - best: "Join{DataScan(t1)->Sel([eq(cast(test.t1.d), cast(abc))])->DataScan(t2)->Sel([eq(cast(test.t2.d), cast(abc))])}(test.t1.a,test.t2.b)(test.t1.d,test.t2.d)->Projection", + best: "Join{DataScan(t1)->Sel([like(cast(test.t1.d), abc, 92)])->DataScan(t2)->Sel([like(cast(test.t2.d), abc, 92)])}(test.t1.a,test.t2.b)(test.t1.d,test.t2.d)->Projection", }, { sql: "select * from t ta join t tb on ta.d = tb.d and ta.d > 1 where tb.a = 0", @@ -200,13 +202,15 @@ func (s *testPlanSuite) TestPredicatePushDown(c *C) { best: "Dual->Projection", }, } + + ctx := context.Background() for ith, ca := range tests { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) - p, err = logicalOptimize(flagPredicatePushDown|flagDecorrelate|flagPrunColumns, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagDecorrelate|flagPrunColumns, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, ca.best, Commentf("for %s %d", ca.sql, ith)) } @@ -293,13 +297,15 @@ func (s *testPlanSuite) TestJoinPredicatePushDown(c *C) { right: "[]", }, } + + ctx := context.Background() for _, ca := range tests { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) - p, err = logicalOptimize(flagPredicatePushDown|flagDecorrelate|flagPrunColumns, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagDecorrelate|flagPrunColumns, p.(LogicalPlan)) c.Assert(err, IsNil, comment) proj, ok := p.(*LogicalProjection) c.Assert(ok, IsTrue, comment) @@ -344,13 +350,15 @@ func (s *testPlanSuite) TestOuterWherePredicatePushDown(c *C) { right: "[]", }, } + + ctx := context.Background() for _, ca := range tests { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) - p, err = logicalOptimize(flagPredicatePushDown|flagDecorrelate|flagPrunColumns, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagDecorrelate|flagPrunColumns, p.(LogicalPlan)) c.Assert(err, IsNil, comment) proj, ok := p.(*LogicalProjection) c.Assert(ok, IsTrue, comment) @@ -409,13 +417,15 @@ func (s *testPlanSuite) TestSimplifyOuterJoin(c *C) { joinType: "left outer join", }, } + + ctx := context.Background() for _, ca := range tests { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) - p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns, p.(LogicalPlan)) c.Assert(err, IsNil, comment) c.Assert(ToString(p), Equals, ca.best, comment) join, ok := p.(LogicalPlan).Children()[0].(*LogicalJoin) @@ -436,17 +446,19 @@ func (s *testPlanSuite) TestAntiSemiJoinConstFalse(c *C) { }{ { sql: "select a from t t1 where not exists (select a from t t2 where t1.a = t2.a and t2.b = 1 and t2.b = 2)", - best: "Join{DataScan(t1)->DataScan(t2)}->Projection", + best: "Join{DataScan(t1)->DataScan(t2)}(test.t1.a,test.t2.a)->Projection", joinType: "anti semi join", }, } + + ctx := context.Background() for _, ca := range tests { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) - p, err = logicalOptimize(flagDecorrelate|flagPredicatePushDown|flagPrunColumns, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagDecorrelate|flagPredicatePushDown|flagPrunColumns, p.(LogicalPlan)) c.Assert(err, IsNil, comment) c.Assert(ToString(p), Equals, ca.best, comment) join, _ := p.(LogicalPlan).Children()[0].(*LogicalJoin) @@ -543,13 +555,15 @@ func (s *testPlanSuite) TestDeriveNotNullConds(c *C) { right: "[]", }, } + + ctx := context.Background() for _, ca := range tests { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) - p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagDecorrelate, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagDecorrelate, p.(LogicalPlan)) c.Assert(err, IsNil, comment) c.Assert(ToString(p), Equals, ca.plan, comment) join := p.(LogicalPlan).Children()[0].(*LogicalJoin) @@ -562,14 +576,79 @@ func (s *testPlanSuite) TestDeriveNotNullConds(c *C) { } } +func buildLogicPlan4GroupBy(s *testPlanSuite, c *C, sql string) (Plan, error) { + sqlMode := s.ctx.GetSessionVars().SQLMode + mockedTableInfo := MockSignedTable() + // mock the table info here for later use + // enable only full group by + s.ctx.GetSessionVars().SQLMode = sqlMode | mysql.ModeOnlyFullGroupBy + defer func() { s.ctx.GetSessionVars().SQLMode = sqlMode }() // restore it + comment := Commentf("for %s", sql) + stmt, err := s.ParseOneStmt(sql, "", "") + c.Assert(err, IsNil, comment) + + stmt.(*ast.SelectStmt).From.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName).TableInfo = mockedTableInfo + + return BuildLogicalPlan(context.Background(), s.ctx, stmt, s.is) +} + +func (s *testPlanSuite) TestGroupByWhenNotExistCols(c *C) { + sqlTests := []struct { + sql string + expectedErrMatch string + }{ + { + sql: "select a from t group by b", + expectedErrMatch: ".*contains nonaggregated column 'test\\.t\\.a'.*", + }, + { + // has an as column alias + sql: "select a as tempField from t group by b", + expectedErrMatch: ".*contains nonaggregated column 'test\\.t\\.a'.*", + }, + { + // has as table alias + sql: "select tempTable.a from t as tempTable group by b", + expectedErrMatch: ".*contains nonaggregated column 'test\\.tempTable\\.a'.*", + }, + { + // has a func call + sql: "select length(a) from t group by b", + expectedErrMatch: ".*contains nonaggregated column 'test\\.t\\.a'.*", + }, + { + // has a func call with two cols + sql: "select length(b + a) from t group by b", + expectedErrMatch: ".*contains nonaggregated column 'test\\.t\\.a'.*", + }, + { + // has a func call with two cols + sql: "select length(a + b) from t group by b", + expectedErrMatch: ".*contains nonaggregated column 'test\\.t\\.a'.*", + }, + { + // has a func call with two cols + sql: "select length(a + b) as tempField from t group by b", + expectedErrMatch: ".*contains nonaggregated column 'test\\.t\\.a'.*", + }, + } + for _, test := range sqlTests { + sql := test.sql + p, err := buildLogicPlan4GroupBy(s, c, sql) + c.Assert(err, NotNil) + c.Assert(p, IsNil) + c.Assert(err, ErrorMatches, test.expectedErrMatch) + } +} + func (s *testPlanSuite) TestDupRandJoinCondsPushDown(c *C) { sql := "select * from t as t1 join t t2 on t1.a > rand() and t1.a > rand()" comment := Commentf("for %s", sql) stmt, err := s.ParseOneStmt(sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(context.Background(), s.ctx, stmt, s.is) c.Assert(err, IsNil, comment) - p, err = logicalOptimize(flagPredicatePushDown, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagPredicatePushDown, p.(LogicalPlan)) c.Assert(err, IsNil, comment) proj, ok := p.(*LogicalProjection) c.Assert(ok, IsTrue, comment) @@ -670,13 +749,15 @@ func (s *testPlanSuite) TestTablePartition(c *C) { is: is1, }, } + + ctx := context.Background() for _, ca := range tests { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, ca.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, ca.is) c.Assert(err, IsNil) - p, err = logicalOptimize(flagDecorrelate|flagPrunColumns|flagPredicatePushDown|flagPartitionProcessor, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagDecorrelate|flagPrunColumns|flagPredicatePushDown|flagPartitionProcessor, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, ca.best, Commentf("for %s", ca.sql)) } @@ -756,16 +837,17 @@ func (s *testPlanSuite) TestSubquery(c *C) { }, } + ctx := context.Background() for ith, ca := range tests { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) if lp, ok := p.(LogicalPlan); ok { - p, err = logicalOptimize(flagBuildKeyInfo|flagDecorrelate|flagPrunColumns, lp) + p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagDecorrelate|flagPrunColumns, lp) c.Assert(err, IsNil) } c.Assert(ToString(p), Equals, ca.best, Commentf("for %s %d", ca.sql, ith)) @@ -829,7 +911,7 @@ func (s *testPlanSuite) TestPlanBuilder(c *C) { }, { sql: "show columns from t where `Key` = 'pri' like 't*'", - plan: "Show([eq(cast(key), 0)])", + plan: "Show->Sel([eq(cast(key), 0)])", }, { sql: "do sleep(5)", @@ -859,7 +941,13 @@ func (s *testPlanSuite) TestPlanBuilder(c *C) { // binlog columns, because the schema and data are not consistent. plan: "LeftHashJoin{LeftHashJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[666,666]], Table(t))}(test.t.a,test.t.b)->IndexReader(Index(t.c_d_e)[[42,42]])}(test.t.b,test.t.a)->Sel([or(6_aux_0, 10_aux_0)])->Projection->Delete", }, + { + sql: "update t set a = 2 where b in (select c from t)", + plan: "LeftHashJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])->StreamAgg}(test.t.b,test.t.c)->Projection->Update", + }, } + + ctx := context.Background() for _, ca := range tests { comment := Commentf("for %s", ca.sql) stmt, err := s.ParseOneStmt(ca.sql, "", "") @@ -867,10 +955,10 @@ func (s *testPlanSuite) TestPlanBuilder(c *C) { s.ctx.GetSessionVars().HashJoinConcurrency = 1 Preprocess(s.ctx, stmt, s.is) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) if lp, ok := p.(LogicalPlan); ok { - p, err = logicalOptimize(flagPrunColumns, lp) + p, err = logicalOptimize(context.TODO(), flagPrunColumns, lp) c.Assert(err, IsNil) } c.Assert(ToString(p), Equals, ca.plan, Commentf("for %s", ca.sql)) @@ -908,14 +996,16 @@ func (s *testPlanSuite) TestJoinReOrder(c *C) { best: "Apply{DataScan(o)->Join{Join{DataScan(t1)->DataScan(t2)}->DataScan(t3)}->Projection}->Projection", }, } + + ctx := context.Background() for _, tt := range tests { comment := Commentf("for %s", tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) - p, err = logicalOptimize(flagPredicatePushDown|flagJoinReOrder, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagJoinReOrder, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) } @@ -1008,15 +1098,17 @@ func (s *testPlanSuite) TestEagerAggregation(c *C) { best: "Join{DataScan(t1)->DataScan(t2)}(test.t1.a,test.t2.a)->Projection->Projection", }, } + + ctx := context.Background() s.ctx.GetSessionVars().AllowAggPushDown = true for ith, tt := range tests { comment := Commentf("for %s", tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) - p, err = logicalOptimize(flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagPushDownAgg, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagPushDownAgg, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, Commentf("for %s %d", tt.sql, ith)) } @@ -1148,15 +1240,35 @@ func (s *testPlanSuite) TestColumnPruning(c *C) { 12: {"test.t4.a"}, }, }, + { + sql: "select 1 from (select count(b) as cnt from t) t1;", + ans: map[int][]string{ + 1: {"test.t.a"}, + }, + }, + { + sql: "select count(1) from (select count(b) as cnt from t) t1;", + ans: map[int][]string{ + 1: {"test.t.a"}, + }, + }, + { + sql: "select count(1) from (select count(b) as cnt from t group by c) t1;", + ans: map[int][]string{ + 1: {"test.t.c"}, + }, + }, } + + ctx := context.Background() for _, tt := range tests { comment := Commentf("for %s", tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) - lp, err := logicalOptimize(flagPredicatePushDown|flagPrunColumns, p.(LogicalPlan)) + lp, err := logicalOptimize(ctx, flagPredicatePushDown|flagPrunColumns, p.(LogicalPlan)) c.Assert(err, IsNil) checkDataSourceCols(lp, c, tt.ans, comment) } @@ -1342,13 +1454,15 @@ func (s *testPlanSuite) TestValidate(c *C) { err: ErrUnknownColumn, }, } + + ctx := context.Background() for _, tt := range tests { sql := tt.sql comment := Commentf("for %s", sql) stmt, err := s.ParseOneStmt(sql, "", "") c.Assert(err, IsNil, comment) Preprocess(s.ctx, stmt, s.is) - _, err = BuildLogicalPlan(s.ctx, stmt, s.is) + _, err = BuildLogicalPlan(ctx, s.ctx, stmt, s.is) if tt.err == nil { c.Assert(err, IsNil, comment) } else { @@ -1439,14 +1553,16 @@ func (s *testPlanSuite) TestUniqueKeyInfo(c *C) { }, }, } + + ctx := context.Background() for ith, tt := range tests { comment := Commentf("for %s %d", tt.sql, ith) stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) - lp, err := logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo, p.(LogicalPlan)) + lp, err := logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo, p.(LogicalPlan)) c.Assert(err, IsNil) checkUniqueKeys(lp, c, tt.ans, tt.sql) } @@ -1483,15 +1599,17 @@ func (s *testPlanSuite) TestAggPrune(c *C) { best: "DataScan(t)->Projection->Projection", }, } + + ctx := context.Background() for _, tt := range tests { comment := Commentf("for %s", tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) - p, err := BuildLogicalPlan(s.ctx, stmt, s.is) + p, err := BuildLogicalPlan(ctx, s.ctx, stmt, s.is) c.Assert(err, IsNil) - p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagEliminateAgg|flagEliminateProjection, p.(LogicalPlan)) + p, err = logicalOptimize(context.TODO(), flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagEliminateAgg|flagEliminateProjection, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, comment) } @@ -1599,18 +1717,6 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { {mysql.IndexPriv, "test", "t", "", nil}, }, }, - { - sql: `create user 'test'@'%' identified by '123456'`, - ans: []visitInfo{ - {mysql.CreateUserPriv, "", "", "", ErrSpecificAccessDenied}, - }, - }, - { - sql: `drop user 'test'@'%'`, - ans: []visitInfo{ - {mysql.CreateUserPriv, "", "", "", ErrSpecificAccessDenied}, - }, - }, { sql: `grant all privileges on test.* to 'test'@'%'`, ans: []visitInfo{ @@ -1702,7 +1808,7 @@ func (s *testPlanSuite) TestVisitInfo(c *C) { is: s.is, } builder.ctx.GetSessionVars().HashJoinConcurrency = 1 - _, err = builder.Build(stmt) + _, err = builder.Build(context.TODO(), stmt) c.Assert(err, IsNil, comment) checkVisitInfo(c, builder.visitInfo, tt.ans, comment) @@ -1809,6 +1915,7 @@ func (s *testPlanSuite) TestUnion(c *C) { err: false, }, } + ctx := context.TODO() for i, tt := range tests { comment := Commentf("case:%v sql:%s", i, tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") @@ -1819,14 +1926,14 @@ func (s *testPlanSuite) TestUnion(c *C) { is: s.is, colMapper: make(map[*ast.ColumnNameExpr]int), } - plan, err := builder.Build(stmt) + plan, err := builder.Build(ctx, stmt) if tt.err { c.Assert(err, NotNil) continue } c.Assert(err, IsNil) p := plan.(LogicalPlan) - p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + p, err = logicalOptimize(ctx, builder.optFlag, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, comment) } @@ -1941,6 +2048,7 @@ func (s *testPlanSuite) TestTopNPushDown(c *C) { best: "Join{DataScan(t1)->DataScan(t2)}(test.t1.e,test.t2.e)->TopN([ifnull(test.t1.h, test.t2.b)],0,5)->Projection->Projection", }, } + ctx := context.TODO() for i, tt := range tests { comment := Commentf("case:%v sql:%s", i, tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") @@ -1951,9 +2059,9 @@ func (s *testPlanSuite) TestTopNPushDown(c *C) { is: s.is, colMapper: make(map[*ast.ColumnNameExpr]int), } - p, err := builder.Build(stmt) + p, err := builder.Build(ctx, stmt) c.Assert(err, IsNil) - p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + p, err = logicalOptimize(ctx, builder.optFlag, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, comment) } @@ -1991,13 +2099,14 @@ func (s *testPlanSuite) TestNameResolver(c *C) { {"update t, (select * from t) as b set b.a = t.a", "[planner:1288]The target table b of the UPDATE is not updatable"}, } + ctx := context.Background() for _, t := range tests { comment := Commentf("for %s", t.sql) stmt, err := s.ParseOneStmt(t.sql, "", "") c.Assert(err, IsNil, comment) s.ctx.GetSessionVars().HashJoinConcurrency = 1 - _, err = BuildLogicalPlan(s.ctx, stmt, s.is) + _, err = BuildLogicalPlan(ctx, s.ctx, stmt, s.is) if t.err == "" { c.Check(err, IsNil) } else { @@ -2048,10 +2157,25 @@ func (s *testPlanSuite) TestOuterJoinEliminator(c *C) { // For complex join query { sql: "select max(t3.b) from (t t1 left join t t2 on t1.a = t2.a) right join t t3 on t1.b = t3.b", - best: "DataScan(t3)->TopN([test.t3.b true],0,1)->Aggr(max(test.t3.b))->Projection", + best: "Join{Join{DataScan(t1)->DataScan(t2)}(test.t1.a,test.t2.a)->DataScan(t3)->TopN([test.t3.b true],0,1)}(test.t1.b,test.t3.b)->TopN([test.t3.b true],0,1)->Aggr(max(test.t3.b))->Projection", + }, + { + sql: "select t1.a ta, t1.b tb from t t1 left join t t2 on t1.a = t2.a", + best: "DataScan(t1)->Projection", + }, + { + // Because the `order by` uses t2.a, the `join` can't be eliminated. + sql: "select t1.a, t1.b from t t1 left join t t2 on t1.a = t2.a order by t2.a", + best: "Join{DataScan(t1)->DataScan(t2)}(test.t1.a,test.t2.a)->Sort->Projection", + }, + // For issue 11167 + { + sql: "select a.a from t a natural left join t b natural left join t c", + best: "DataScan(a)->Projection", }, } + ctx := context.TODO() for i, tt := range tests { comment := Commentf("case:%v sql:%s", i, tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") @@ -2062,9 +2186,9 @@ func (s *testPlanSuite) TestOuterJoinEliminator(c *C) { is: s.is, colMapper: make(map[*ast.ColumnNameExpr]int), } - p, err := builder.Build(stmt) + p, err := builder.Build(ctx, stmt) c.Assert(err, IsNil) - p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + p, err = logicalOptimize(ctx, builder.optFlag, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, comment) } @@ -2083,6 +2207,7 @@ func (s *testPlanSuite) TestSelectView(c *C) { best: "DataScan(t)->Projection", }, } + ctx := context.TODO() for i, tt := range tests { comment := Commentf("case:%v sql:%s", i, tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") @@ -2093,9 +2218,9 @@ func (s *testPlanSuite) TestSelectView(c *C) { is: s.is, colMapper: make(map[*ast.ColumnNameExpr]int), } - p, err := builder.Build(stmt) + p, err := builder.Build(ctx, stmt) c.Assert(err, IsNil) - p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + p, err = logicalOptimize(ctx, builder.optFlag, p.(LogicalPlan)) c.Assert(err, IsNil) c.Assert(ToString(p), Equals, tt.best, comment) } @@ -2117,7 +2242,7 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { }, { sql: "select a, avg(a+1) over(partition by (a+1)) from t", - result: "TableReader(Table(t))->Projection->Sort->Window(avg(cast(2_proj_window_3)) over(partition by 2_proj_window_2))->Projection", + result: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Projection->Sort->Window(avg(cast(2_proj_window_3)) over(partition by 2_proj_window_2))->Projection", }, { sql: "select a, avg(a) over(order by a asc, b desc) from t order by a asc, b desc", @@ -2137,7 +2262,7 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { }, { sql: "select sum(avg(a)) over() from t", - result: "TableReader(Table(t)->StreamAgg)->StreamAgg->Window(sum(sel_agg_2) over())->Projection", + result: "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->StreamAgg)->StreamAgg->Window(sum(sel_agg_2) over())->Projection", }, { sql: "select b from t order by(sum(a) over())", @@ -2153,7 +2278,7 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { }, { sql: "select a from t having (select sum(a) over() as w from t tt where a > t.a)", - result: "Apply{TableReader(Table(t))->TableReader(Table(t)->Sel([gt(test.tt.a, test.t.a)]))->Window(sum(cast(test.tt.a)) over())->MaxOneRow->Sel([w])}->Projection", + result: "Apply{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Sel([gt(test.tt.a, test.t.a)]))->Window(sum(cast(test.tt.a)) over())->MaxOneRow->Sel([w])}->Projection", }, { sql: "select avg(a) over() as w from t having w > 1", @@ -2185,7 +2310,7 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { }, { sql: "select sum(a) over w from t window w as (rows between 1 preceding AND 1 following)", - result: "TableReader(Table(t))->Window(sum(cast(test.t.a)) over(rows between 1 preceding and 1 following))->Projection", + result: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Window(sum(cast(test.t.a)) over(rows between 1 preceding and 1 following))->Projection", }, { sql: "select sum(a) over(w order by b) from t window w as (order by a)", @@ -2253,7 +2378,7 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { }, { sql: "select row_number() over(rows between 1 preceding and 1 following) from t", - result: "TableReader(Table(t))->Window(row_number() over())->Projection", + result: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Window(row_number() over())->Projection", }, { sql: "select avg(b), max(avg(b)) over(rows between 1 preceding and 1 following) max from t group by c", @@ -2263,6 +2388,10 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { sql: "select nth_value(a, 1.0) over() from t", result: "[planner:1210]Incorrect arguments to nth_value", }, + { + sql: "SELECT NTH_VALUE(a, 1.0) OVER() FROM t", + result: "[planner:1210]Incorrect arguments to nth_value", + }, { sql: "select nth_value(a, 0) over() from t", result: "[planner:1210]Incorrect arguments to nth_value", @@ -2273,7 +2402,7 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { }, { sql: "select ntile(null) over() from t", - result: "TableReader(Table(t))->Window(ntile() over())->Projection", + result: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Window(ntile() over())->Projection", }, { sql: "select avg(a) over w from t window w as(partition by b)", @@ -2291,38 +2420,140 @@ func (s *testPlanSuite) TestWindowFunction(c *C) { sql: "delete from t order by (sum(a) over())", result: "[planner:3593]You cannot use the window function 'sum' in this context.'", }, + { + sql: "delete from t order by (SUM(a) over())", + result: "[planner:3593]You cannot use the window function 'sum' in this context.'", + }, + { + sql: "SELECT * from t having ROW_NUMBER() over()", + result: "[planner:3593]You cannot use the window function 'row_number' in this context.'", + }, + { + // The best execution order should be (a,c), (a, b, c), (a, b), (), it requires only 2 sort operations. + sql: "select sum(a) over (partition by a order by b), sum(b) over (order by a, b, c), sum(c) over(partition by a order by c), sum(d) over() from t", + result: "TableReader(Table(t))->Sort->Window(sum(cast(test.t.c)) over(partition by test.t.a order by test.t.c asc range between unbounded preceding and current row))->Sort->Window(sum(cast(test.t.b)) over(order by test.t.a asc, test.t.b asc, test.t.c asc range between unbounded preceding and current row))->Window(sum(cast(test.t.a)) over(partition by test.t.a order by test.t.b asc range between unbounded preceding and current row))->Window(sum(cast(test.t.d)) over())->Projection", + }, + // Test issue 11010. + { + sql: "select dense_rank() over w1, a, b from t window w1 as (partition by t.b order by t.a desc, t.b desc range between current row and 1 following)", + result: "[planner:3587]Window 'w1' with RANGE N PRECEDING/FOLLOWING frame requires exactly one ORDER BY expression, of numeric or temporal type", + }, + { + sql: "select dense_rank() over w1, a, b from t window w1 as (partition by t.b order by t.a desc, t.b desc range between current row and unbounded following)", + result: "TableReader(Table(t))->Sort->Window(dense_rank() over(partition by test.t.b order by test.t.a desc, test.t.b desc))->Projection", + }, + { + sql: "select dense_rank() over w1, a, b from t window w1 as (partition by t.b order by t.a desc, t.b desc range between 1 preceding and 1 following)", + result: "[planner:3587]Window 'w1' with RANGE N PRECEDING/FOLLOWING frame requires exactly one ORDER BY expression, of numeric or temporal type", + }, + // Test issue 11001. + { + sql: "SELECT PERCENT_RANK() OVER w1 AS 'percent_rank', fieldA, fieldB FROM ( SELECT a AS fieldA, b AS fieldB FROM t ) t1 WINDOW w1 AS ( ROWS BETWEEN 0 FOLLOWING AND UNBOUNDED PRECEDING)", + result: "[planner:3585]Window 'w1': frame end cannot be UNBOUNDED PRECEDING.", + }, + // Test issue 11002. + { + sql: "SELECT PERCENT_RANK() OVER w1 AS 'percent_rank', fieldA, fieldB FROM ( SELECT a AS fieldA, b AS fieldB FROM t ) as t1 WINDOW w1 AS ( ROWS BETWEEN UNBOUNDED FOLLOWING AND UNBOUNDED FOLLOWING)", + result: "[planner:3584]Window 'w1': frame start cannot be UNBOUNDED FOLLOWING.", + }, + // Test issue 11011. + { + sql: "select dense_rank() over w1, a, b from t window w1 as (partition by t.b order by t.a asc range between 1250951168 following AND 1250951168 preceding)", + result: "[planner:3586]Window 'w1': frame start or end is negative, NULL or of non-integral type", + }, + // Test issue 10556. + { + sql: "SELECT FIRST_VALUE(a) IGNORE NULLS OVER () FROM t", + result: "[planner:1235]This version of TiDB doesn't yet support 'IGNORE NULLS'", + }, + { + sql: "SELECT SUM(DISTINCT a) OVER () FROM t", + result: "[planner:1235]This version of TiDB doesn't yet support '(DISTINCT ..)'", + }, + { + sql: "SELECT NTH_VALUE(a, 1) FROM LAST over (partition by b order by b), a FROM t", + result: "[planner:1235]This version of TiDB doesn't yet support 'FROM LAST'", + }, + { + sql: "SELECT NTH_VALUE(a, 1) FROM LAST IGNORE NULLS over (partition by b order by b), a FROM t", + result: "[planner:1235]This version of TiDB doesn't yet support 'IGNORE NULLS'", + }, + { + sql: "SELECT NTH_VALUE(fieldA, ATAN(-1)) OVER (w1) AS 'ntile', fieldA, fieldB FROM ( SELECT a AS fieldA, b AS fieldB FROM t ) as te WINDOW w1 AS ( ORDER BY fieldB ASC, fieldA DESC )", + result: "[planner:1210]Incorrect arguments to nth_value", + }, + { + sql: "SELECT NTH_VALUE(fieldA, -1) OVER (w1 PARTITION BY fieldB ORDER BY fieldB , fieldA ) AS 'ntile', fieldA, fieldB FROM ( SELECT a AS fieldA, b AS fieldB FROM t ) as temp WINDOW w1 AS ( ORDER BY fieldB ASC, fieldA DESC )", + result: "[planner:1210]Incorrect arguments to nth_value", + }, + { + sql: "SELECT SUM(a) OVER w AS 'sum' FROM t WINDOW w AS (ROWS BETWEEN 1 FOLLOWING AND CURRENT ROW )", + result: "[planner:3586]Window 'w': frame start or end is negative, NULL or of non-integral type", + }, + { + sql: "SELECT SUM(a) OVER w AS 'sum' FROM t WINDOW w AS (ROWS BETWEEN CURRENT ROW AND 1 PRECEDING )", + result: "[planner:3586]Window 'w': frame start or end is negative, NULL or of non-integral type", + }, + { + sql: "SELECT SUM(a) OVER w AS 'sum' FROM t WINDOW w AS (ROWS BETWEEN 1 FOLLOWING AND 1 PRECEDING )", + result: "[planner:3586]Window 'w': frame start or end is negative, NULL or of non-integral type", + }, + // Test issue 11943 + { + sql: "SELECT ROW_NUMBER() OVER (partition by b) + a FROM t", + result: "TableReader(Table(t))->Sort->Window(row_number() over(partition by test.t.b))->Projection->Projection", + }, } s.Parser.EnableWindowFunc(true) defer func() { s.Parser.EnableWindowFunc(false) }() + ctx := context.TODO() for i, tt := range tests { comment := Commentf("case:%v sql:%s", i, tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") - c.Assert(err, IsNil, comment) - Preprocess(s.ctx, stmt, s.is) - builder := &PlanBuilder{ - ctx: MockContext(), - is: s.is, - colMapper: make(map[*ast.ColumnNameExpr]int), - } - p, err := builder.Build(stmt) + p, stmt, err := s.optimize(ctx, tt.sql) if err != nil { c.Assert(err.Error(), Equals, tt.result, comment) continue } + c.Assert(ToString(p), Equals, tt.result, comment) + + var sb strings.Builder + // After restore, the result should be the same. + err = stmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) c.Assert(err, IsNil) - p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) - c.Assert(err, IsNil) - lp, ok := p.(LogicalPlan) - c.Assert(ok, IsTrue) - p, err = physicalOptimize(lp) - c.Assert(err, IsNil) + p, _, err = s.optimize(ctx, sb.String()) + if err != nil { + c.Assert(err.Error(), Equals, tt.result, comment) + continue + } c.Assert(ToString(p), Equals, tt.result, comment) } } +func (s *testPlanSuite) optimize(ctx context.Context, sql string) (PhysicalPlan, ast.Node, error) { + stmt, err := s.ParseOneStmt(sql, "", "") + if err != nil { + return nil, nil, err + } + err = Preprocess(s.ctx, stmt, s.is) + if err != nil { + return nil, nil, err + } + builder := NewPlanBuilder(MockContext(), s.is) + p, err := builder.Build(ctx, stmt) + if err != nil { + return nil, nil, err + } + p, err = logicalOptimize(ctx, builder.optFlag, p.(LogicalPlan)) + if err != nil { + return nil, nil, err + } + p, err = physicalOptimize(p.(LogicalPlan)) + return p.(PhysicalPlan), stmt, err +} + func byItemsToProperty(byItems []*ByItems) *property.PhysicalProperty { pp := &property.PhysicalProperty{} for _, item := range byItems { @@ -2381,7 +2612,12 @@ func (s *testPlanSuite) TestSkylinePruning(c *C) { sql: "select * from t where f > 1 and g > 1", result: "PRIMARY_KEY,f,g,f_g", }, + { + sql: "select count(1) from t", + result: "c_d_e,f,g,f_g,c_d_e_str,e_d_c_str_prefix", + }, } + ctx := context.TODO() for i, tt := range tests { comment := Commentf("case:%v sql:%s", i, tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") @@ -2392,13 +2628,13 @@ func (s *testPlanSuite) TestSkylinePruning(c *C) { is: s.is, colMapper: make(map[*ast.ColumnNameExpr]int), } - p, err := builder.Build(stmt) + p, err := builder.Build(ctx, stmt) if err != nil { c.Assert(err.Error(), Equals, tt.result, comment) continue } c.Assert(err, IsNil) - p, err = logicalOptimize(builder.optFlag, p.(LogicalPlan)) + p, err = logicalOptimize(ctx, builder.optFlag, p.(LogicalPlan)) c.Assert(err, IsNil) lp := p.(LogicalPlan) _, err = lp.recursiveDeriveStats() @@ -2420,3 +2656,45 @@ func (s *testPlanSuite) TestSkylinePruning(c *C) { c.Assert(pathsName(paths), Equals, tt.result) } } + +func (s *testPlanSuite) TestFastPlanContextTables(c *C) { + defer testleak.AfterTest(c)() + tests := []struct { + sql string + fastPlan bool + }{ + { + "select * from t where a=1", + true, + }, + { + + "update t set f=0 where a=43215", + true, + }, + { + "delete from t where a =43215", + true, + }, + { + "select * from t where a>1", + false, + }, + } + for _, tt := range tests { + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil) + Preprocess(s.ctx, stmt, s.is) + s.ctx.GetSessionVars().StmtCtx.Tables = nil + p := TryFastPlan(s.ctx, stmt) + if tt.fastPlan { + c.Assert(p, NotNil) + c.Assert(len(s.ctx.GetSessionVars().StmtCtx.Tables), Equals, 1) + c.Assert(s.ctx.GetSessionVars().StmtCtx.Tables[0].Table, Equals, "t") + c.Assert(s.ctx.GetSessionVars().StmtCtx.Tables[0].DB, Equals, "test") + } else { + c.Assert(p, IsNil) + c.Assert(len(s.ctx.GetSessionVars().StmtCtx.Tables), Equals, 0) + } + } +} diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index d6c2bfadbf922..bc92c0241dcf6 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -309,6 +309,10 @@ type LogicalTableDual struct { logicalSchemaProducer RowCount int + // placeHolder indicates if this dual plan is a place holder in query optimization + // for data sources like `Show`, if true, the dual plan would be substituted by + // `Show` in the final plan. + placeHolder bool } // LogicalUnionScan is only used in non read-only txn. @@ -318,7 +322,7 @@ type LogicalUnionScan struct { conditions []expression.Expression } -// DataSource represents a tablescan without condition push down. +// DataSource represents a tableScan without condition push down. type DataSource struct { logicalSchemaProducer @@ -337,6 +341,7 @@ type DataSource struct { allConds []expression.Expression statisticTable *statistics.Table + tableStats *property.StatsInfo // possibleAccessPaths stores all the possible access path for physical plan, including table scan. possibleAccessPaths []*accessPath @@ -367,6 +372,16 @@ type accessPath struct { forced bool } +// getTablePath finds the TablePath from a group of accessPaths. +func getTablePath(paths []*accessPath) *accessPath { + for _, path := range paths { + if path.isTablePath { + return path + } + } + return nil +} + // deriveTablePathStats will fulfill the information that the accessPath need. // And it will check whether the primary key is covered only by point query. func (ds *DataSource) deriveTablePathStats(path *accessPath) (bool, error) { @@ -470,7 +485,7 @@ func (ds *DataSource) deriveIndexPathStats(path *accessPath) (bool, error) { path.tableFilters = res.RemainedConds path.eqCondCount = res.EqCondCount eqOrInCount = res.EqOrInCount - path.countAfterAccess, err = ds.stats.HistColl.GetRowCountByIndexRanges(sc, path.index.ID, path.ranges) + path.countAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, path.index.ID, path.ranges) if err != nil { return false, err } @@ -503,9 +518,9 @@ func (ds *DataSource) deriveIndexPathStats(path *accessPath) (bool, error) { path.countAfterAccess = math.Min(ds.stats.RowCount/selectionFactor, float64(ds.statisticTable.Count)) } if path.indexFilters != nil { - selectivity, _, err := ds.stats.HistColl.Selectivity(ds.ctx, path.indexFilters) + selectivity, _, err := ds.tableStats.HistColl.Selectivity(ds.ctx, path.indexFilters) if err != nil { - logutil.Logger(context.Background()).Warn("calculate selectivity faild, use selection factor", zap.Error(err)) + logutil.Logger(context.Background()).Debug("calculate selectivity failed, use selection factor", zap.Error(err)) selectivity = selectionFactor } path.countAfterIndex = math.Max(path.countAfterAccess*selectivity, ds.stats.RowCount) diff --git a/planner/core/mock.go b/planner/core/mock.go index e91146078a389..3a2cfe1fad51e 100644 --- a/planner/core/mock.go +++ b/planner/core/mock.go @@ -39,8 +39,8 @@ func newDateType() types.FieldType { return *ft } -// MockTable is only used for plan related tests. -func MockTable() *model.TableInfo { +// MockSignedTable is only used for plan related tests. +func MockSignedTable() *model.TableInfo { // column: a, b, c, d, e, c_str, d_str, e_str, f, g // PK: a // indeices: c_d_e, e, f, g, f_g, c_d_e_str, c_d_e_str_prefix @@ -263,6 +263,75 @@ func MockTable() *model.TableInfo { return table } +// MockUnsignedTable is only used for plan related tests. +func MockUnsignedTable() *model.TableInfo { + // column: a, b + // PK: a + // indeices: b + indices := []*model.IndexInfo{ + { + Name: model.NewCIStr("b"), + Columns: []*model.IndexColumn{ + { + Name: model.NewCIStr("b"), + Length: types.UnspecifiedLength, + Offset: 1, + }, + }, + State: model.StatePublic, + Unique: true, + }, + { + Name: model.NewCIStr("b_c"), + Columns: []*model.IndexColumn{ + { + Name: model.NewCIStr("b"), + Length: types.UnspecifiedLength, + Offset: 1, + }, + { + Name: model.NewCIStr("c"), + Length: types.UnspecifiedLength, + Offset: 2, + }, + }, + State: model.StatePublic, + }, + } + pkColumn := &model.ColumnInfo{ + State: model.StatePublic, + Offset: 0, + Name: model.NewCIStr("a"), + FieldType: newLongType(), + ID: 1, + } + col0 := &model.ColumnInfo{ + State: model.StatePublic, + Offset: 1, + Name: model.NewCIStr("b"), + FieldType: newLongType(), + ID: 2, + } + col1 := &model.ColumnInfo{ + State: model.StatePublic, + Offset: 2, + Name: model.NewCIStr("c"), + FieldType: newLongType(), + ID: 3, + } + pkColumn.Flag = mysql.PriKeyFlag | mysql.NotNullFlag | mysql.UnsignedFlag + // Column 'b', 'c', 'd', 'f', 'g' is not null. + col0.Flag = mysql.NotNullFlag + col1.Flag = mysql.UnsignedFlag + table := &model.TableInfo{ + Columns: []*model.ColumnInfo{pkColumn, col0, col1}, + Indices: indices, + Name: model.NewCIStr("t2"), + PKIsHandle: true, + } + return table +} + // MockView is only used for plan related tests. func MockView() *model.TableInfo { selectStmt := "select b,c,d from t" @@ -308,7 +377,7 @@ func MockContext() sessionctx.Context { // MockPartitionInfoSchema mocks an info schema for partition table. func MockPartitionInfoSchema(definitions []model.PartitionDefinition) infoschema.InfoSchema { - tableInfo := MockTable() + tableInfo := MockSignedTable() cols := make([]*model.ColumnInfo, 0, len(tableInfo.Columns)) cols = append(cols, tableInfo.Columns...) last := tableInfo.Columns[len(tableInfo.Columns)-1] diff --git a/planner/core/optimizer.go b/planner/core/optimizer.go index e0fd6c507c304..04ac09e2be074 100644 --- a/planner/core/optimizer.go +++ b/planner/core/optimizer.go @@ -14,6 +14,7 @@ package core import ( + "context" "math" "github.com/pingcap/errors" @@ -24,11 +25,12 @@ import ( "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/set" "go.uber.org/atomic" ) // OptimizeAstNode optimizes the query to a physical plan directly. -var OptimizeAstNode func(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (Plan, error) +var OptimizeAstNode func(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (Plan, error) // AllowCartesianProduct means whether tidb allows cartesian join without equal conditions. var AllowCartesianProduct = atomic.NewBool(true) @@ -65,19 +67,20 @@ var optRuleList = []logicalOptRule{ // logicalOptRule means a logical optimizing rule, which contains decorrelate, ppd, column pruning, etc. type logicalOptRule interface { - optimize(LogicalPlan) (LogicalPlan, error) + optimize(context.Context, LogicalPlan) (LogicalPlan, error) + name() string } // BuildLogicalPlan used to build logical plan from ast.Node. -func BuildLogicalPlan(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (Plan, error) { - ctx.GetSessionVars().PlanID = 0 - ctx.GetSessionVars().PlanColumnID = 0 +func BuildLogicalPlan(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (Plan, error) { + sctx.GetSessionVars().PlanID = 0 + sctx.GetSessionVars().PlanColumnID = 0 builder := &PlanBuilder{ - ctx: ctx, + ctx: sctx, is: is, colMapper: make(map[*ast.ColumnNameExpr]int), } - p, err := builder.Build(node) + p, err := builder.Build(ctx, node) if err != nil { return nil, err } @@ -98,8 +101,8 @@ func CheckPrivilege(activeRoles []*auth.RoleIdentity, pm privilege.Manager, vs [ } // DoOptimize optimizes a logical plan to a physical plan. -func DoOptimize(flag uint64, logic LogicalPlan) (PhysicalPlan, error) { - logic, err := logicalOptimize(flag, logic) +func DoOptimize(ctx context.Context, flag uint64, logic LogicalPlan) (PhysicalPlan, error) { + logic, err := logicalOptimize(ctx, flag, logic) if err != nil { return nil, err } @@ -120,16 +123,16 @@ func postOptimize(plan PhysicalPlan) PhysicalPlan { return plan } -func logicalOptimize(flag uint64, logic LogicalPlan) (LogicalPlan, error) { +func logicalOptimize(ctx context.Context, flag uint64, logic LogicalPlan) (LogicalPlan, error) { var err error for i, rule := range optRuleList { // The order of flags is same as the order of optRule in the list. // We use a bitmask to record which opt rules should be used. If the i-th bit is 1, it means we should // apply i-th optimizing rule. - if flag&(1<= unix_timestamp('2008-01-01')", schema) + c.Assert(err, IsNil) + queryExpr, err = expression.ParseSimpleExprsWithSchema(ctx, "report_updated > '2008-05-01 00:00:00'", schema) + c.Assert(err, IsNil) + succ, err = s.canBePruned(ctx, nil, partitionExpr[0], queryExpr) + c.Assert(err, IsNil) + c.Assert(succ, IsTrue) + + queryExpr, err = expression.ParseSimpleExprsWithSchema(ctx, "report_updated > unix_timestamp('2008-05-01 00:00:00')", schema) + c.Assert(err, IsNil) + succ, err = s.canBePruned(ctx, nil, partitionExpr[0], queryExpr) + c.Assert(err, IsNil) + _ = succ + // c.Assert(succ, IsTrue) + // TODO: Uncomment the check after fixing issue https://github.com/pingcap/tidb/issues/12028 + // report_updated > unix_timestamp('2008-05-01 00:00:00') is converted to gt(t.t.report_updated, ) + // Because unix_timestamp('2008-05-01 00:00:00') is fold to constant int 1564761600, and compare it with timestamp (report_updated) + // need to convert 1564761600 to a timestamp, during that step, an error happen and the result is set to } diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index 6772d1fbbb0b0..fc508d11935fb 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -28,20 +28,30 @@ import ( "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/util/testleak" + "github.com/pingcap/tidb/util/testutil" ) var _ = Suite(&testPlanSuite{}) type testPlanSuite struct { *parser.Parser - is infoschema.InfoSchema + + testData testutil.TestData } func (s *testPlanSuite) SetUpSuite(c *C) { - s.is = infoschema.MockInfoSchema([]*model.TableInfo{core.MockTable()}) + s.is = infoschema.MockInfoSchema([]*model.TableInfo{core.MockSignedTable(), core.MockUnsignedTable()}) s.Parser = parser.New() s.Parser.EnableWindowFunc(true) + + var err error + s.testData, err = testutil.LoadTestSuiteData("testdata", "plan_suite") + c.Assert(err, IsNil) +} + +func (s *testPlanSuite) TearDownSuite(c *C) { + c.Assert(s.testData.GenerateOutputIfNeeded(), IsNil) } func (s *testPlanSuite) TestDAGPlanBuilderSimpleCase(c *C) { @@ -56,171 +66,26 @@ func (s *testPlanSuite) TestDAGPlanBuilderSimpleCase(c *C) { c.Assert(err, IsNil) _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) - tests := []struct { - sql string - best string - }{ - // Test index hint. - { - sql: "select * from t t1 use index(c_d_e)", - best: "IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))", - }, - // Test ts + Sort vs. DoubleRead + filter. - { - sql: "select a from t where a between 1 and 2 order by c", - best: "TableReader(Table(t))->Sort->Projection", - }, - // Test DNF condition + Double Read. - { - sql: "select * from t where (t.c > 0 and t.c < 2) or (t.c > 4 and t.c < 6) or (t.c > 8 and t.c < 10) or (t.c > 12 and t.c < 14) or (t.c > 16 and t.c < 18)", - best: "IndexLookUp(Index(t.c_d_e)[(0,2) (4,6) (8,10) (12,14) (16,18)], Table(t))", - }, - { - sql: "select * from t where (t.c > 0 and t.c < 1) or (t.c > 2 and t.c < 3) or (t.c > 4 and t.c < 5) or (t.c > 6 and t.c < 7) or (t.c > 9 and t.c < 10)", - best: "Dual", - }, - // Test TopN to table branch in double read. - { - sql: "select * from t where t.c = 1 and t.e = 1 order by t.b limit 1", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t))->TopN([test.t.b],0,1)", - }, - // Test Null Range - { - sql: "select * from t where t.e_str is null", - best: "IndexLookUp(Index(t.e_d_c_str_prefix)[[NULL,NULL]], Table(t))", - }, - // Test Null Range but the column has not null flag. - { - sql: "select * from t where t.c is null", - best: "Dual", - }, - // Test TopN to index branch in double read. - { - sql: "select * from t where t.c = 1 and t.e = 1 order by t.e limit 1", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t))->TopN([test.t.e],0,1)", - }, - // Test TopN to Limit in double read. - { - sql: "select * from t where t.c = 1 and t.e = 1 order by t.d limit 1", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)])->Limit, Table(t))->Limit", - }, - // Test TopN to Limit in index single read. - { - sql: "select c from t where t.c = 1 and t.e = 1 order by t.d limit 1", - best: "IndexReader(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)])->Limit)->Limit->Projection", - }, - // Test TopN to Limit in table single read. - { - sql: "select c from t order by t.a limit 1", - best: "TableReader(Table(t)->Limit)->Limit->Projection", - }, - // Test TopN push down in table single read. - { - sql: "select c from t order by t.a + t.b limit 1", - best: "TableReader(Table(t)->TopN([plus(test.t.a, test.t.b)],0,1))->Projection->TopN([col_3],0,1)->Projection->Projection", - }, - // Test Limit push down in table single read. - { - sql: "select c from t limit 1", - best: "TableReader(Table(t)->Limit)->Limit", - }, - // Test Limit push down in index single read. - { - sql: "select c from t where c = 1 limit 1", - best: "IndexReader(Index(t.c_d_e)[[1,1]]->Limit)->Limit", - }, - // Test index single read and Selection. - { - sql: "select c from t where c = 1", - best: "IndexReader(Index(t.c_d_e)[[1,1]])", - }, - // Test index single read and Sort. - { - sql: "select c from t order by c", - best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])", - }, - // Test index single read and Sort. - { - sql: "select c from t where c = 1 order by e", - best: "IndexReader(Index(t.c_d_e)[[1,1]])->Sort->Projection", - }, - // Test Limit push down in double single read. - { - sql: "select c, b from t where c = 1 limit 1", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Limit, Table(t))->Limit->Projection", - }, - // Test Selection + Limit push down in double single read. - { - sql: "select c, b from t where c = 1 and e = 1 and b = 1 limit 1", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)])->Limit)->Limit->Projection", - }, - // Test Order by multi columns. - { - sql: "select c from t where c = 1 order by d, c", - best: "IndexReader(Index(t.c_d_e)[[1,1]])->Sort->Projection", - }, - // Test for index with length. - { - sql: "select c_str from t where e_str = '1' order by d_str, c_str", - best: `IndexLookUp(Index(t.e_d_c_str_prefix)[["1","1"]], Table(t))->Sort->Projection`, - }, - // Test PK in index single read. - { - sql: "select c from t where t.c = 1 and t.a > 1 order by t.d limit 1", - best: "IndexReader(Index(t.c_d_e)[[1,1]]->Sel([gt(test.t.a, 1)])->Limit)->Limit->Projection", - }, - // Test composed index. - // FIXME: The TopN didn't be pushed. - { - sql: "select c from t where t.c = 1 and t.d = 1 order by t.a limit 1", - best: "IndexReader(Index(t.c_d_e)[[1 1,1 1]])->TopN([test.t.a],0,1)->Projection", - }, - // Test PK in index double read. - { - sql: "select * from t where t.c = 1 and t.a > 1 order by t.d limit 1", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([gt(test.t.a, 1)])->Limit, Table(t))->Limit", - }, - // Test index filter condition push down. - { - sql: "select * from t use index(e_d_c_str_prefix) where t.c_str = 'abcdefghijk' and t.d_str = 'd' and t.e_str = 'e'", - best: "IndexLookUp(Index(t.e_d_c_str_prefix)[[\"e\" \"d\" \"abcdefghij\",\"e\" \"d\" \"abcdefghij\"]], Table(t)->Sel([eq(test.t.c_str, abcdefghijk)]))", - }, - { - sql: "select * from t use index(e_d_c_str_prefix) where t.e_str = b'1110000'", - best: "IndexLookUp(Index(t.e_d_c_str_prefix)[[\"p\",\"p\"]], Table(t))", - }, - { - sql: "select * from (select * from t use index() order by b) t left join t t1 on t.a=t1.a limit 10", - best: "IndexJoin{TableReader(Table(t)->TopN([test.t.b],0,10))->TopN([test.t.b],0,10)->TableReader(Table(t))}(test.t.a,test.t1.a)->Limit", - }, - // Test embedded ORDER BY which imposes on different number of columns than outer query. - { - sql: "select * from ((SELECT 1 a,3 b) UNION (SELECT 2,1) ORDER BY (SELECT 2)) t order by a,b", - best: "UnionAll{Dual->Projection->Dual->Projection}->HashAgg->Sort", - }, - { - sql: "select * from ((SELECT 1 a,6 b) UNION (SELECT 2,5) UNION (SELECT 2, 4) ORDER BY 1) t order by 1, 2", - best: "UnionAll{Dual->Projection->Dual->Projection->Dual->Projection}->HashAgg->Sort->Sort", - }, - { - sql: "select * from (select *, NULL as xxx from t) t order by xxx", - best: "TableReader(Table(t))->Projection", - }, - { - sql: "select lead(a, 1) over (partition by null) as c from t", - best: "TableReader(Table(t))->Window(lead(test.t.a, 1) over())->Projection", - }, + var input []string + var output []struct { + SQL string + Best string } - for i, tt := range tests { - comment := Commentf("case:%v sql:%s", i, tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("case:%v sql:%s", i, tt) + stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) err = se.NewTxn(context.Background()) c.Assert(err, IsNil) - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, comment) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, comment) } } @@ -237,231 +102,24 @@ func (s *testPlanSuite) TestDAGPlanBuilderJoin(c *C) { _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) - tests := []struct { - sql string - best string - }{ - { - sql: "select * from t t1 join t t2 on t1.a = t2.c_str", - best: "LeftHashJoin{TableReader(Table(t))->Projection->TableReader(Table(t))->Projection}(cast(test.t1.a),cast(test.t2.c_str))->Projection", - }, - { - sql: "select * from t t1 join t t2 on t1.b = t2.a", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)", - }, - { - sql: "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.a = t3.a", - best: "MergeInnerJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t1.a,test.t3.a)", - }, - { - sql: "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.b = t3.a", - best: "LeftHashJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t1.b,test.t3.a)", - }, - { - sql: "select * from t t1 join t t2 on t1.b = t2.a order by t1.a", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)->Sort", - }, - { - sql: "select * from t t1 join t t2 on t1.b = t2.a order by t1.a limit 1", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)->Limit", - }, - // Test hash join's hint. - { - sql: "select /*+ TIDB_HJ(t1, t2) */ * from t t1 join t t2 on t1.b = t2.a order by t1.a limit 1", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)->TopN([test.t1.a],0,1)", - }, - { - sql: "select * from t t1 left join t t2 on t1.b = t2.a where 1 = 1 limit 1", - best: "IndexJoin{TableReader(Table(t)->Limit)->Limit->TableReader(Table(t))}(test.t1.b,test.t2.a)->Limit", - }, - { - sql: "select * from t t1 join t t2 on t1.b = t2.a and t1.c = 1 and t1.d = 1 and t1.e = 1 order by t1.a limit 1", - best: "IndexJoin{IndexLookUp(Index(t.c_d_e)[[1 1 1,1 1 1]], Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)->TopN([test.t1.a],0,1)", - }, - { - sql: "select * from t t1 join t t2 on t1.b = t2.b join t t3 on t1.b = t3.b", - best: "LeftHashJoin{LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.b)->TableReader(Table(t))}(test.t1.b,test.t3.b)", - }, - { - sql: "select * from t t1 join t t2 on t1.a = t2.a order by t1.a", - best: "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)", - }, - { - sql: "select * from t t1 left outer join t t2 on t1.a = t2.a right outer join t t3 on t1.a = t3.a", - best: "MergeRightOuterJoin{MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t1.a,test.t3.a)", - }, - { - sql: "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.a = t3.a and t1.b = 1 and t3.c = 1", - best: "IndexJoin{IndexJoin{TableReader(Table(t)->Sel([eq(test.t1.b, 1)]))->IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t))}(test.t3.a,test.t1.a)->TableReader(Table(t))}(test.t1.a,test.t2.a)->Projection", - }, - { - sql: "select * from t where t.c in (select b from t s where s.a = t.a)", - best: "MergeSemiJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t.a,test.s.a)", - }, - { - sql: "select t.c in (select b from t s where s.a = t.a) from t", - best: "MergeLeftOuterSemiJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t.a,test.s.a)->Projection", - }, - // Test Single Merge Join. - // Merge Join now enforce a sort. - { - sql: "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.b", - best: "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))->Sort}(test.t1.a,test.t2.b)", - }, - { - sql: "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.a", - best: "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)", - }, - // Test Single Merge Join + Sort. - { - sql: "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.a order by t2.a", - best: "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)", - }, - { - sql: "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.b = t2.b order by t2.a", - best: "MergeInnerJoin{TableReader(Table(t))->Sort->TableReader(Table(t))->Sort}(test.t1.b,test.t2.b)->Sort", - }, - // Test Single Merge Join + Sort + desc. - { - sql: "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.a order by t2.a desc", - best: "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)", - }, - { - sql: "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.b = t2.b order by t2.b desc", - best: "MergeInnerJoin{TableReader(Table(t))->Sort->TableReader(Table(t))->Sort}(test.t1.b,test.t2.b)", - }, - // Test Multi Merge Join. - { - sql: "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.a = t2.a and t2.a = t3.a", - best: "MergeInnerJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t2.a,test.t3.a)", - }, - { - sql: "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.a = t2.b and t2.a = t3.b", - best: "MergeInnerJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))->Sort}(test.t1.a,test.t2.b)->Sort->TableReader(Table(t))->Sort}(test.t2.a,test.t3.b)", - }, - // Test Multi Merge Join with multi keys. - // TODO: More tests should be added. - { - sql: "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.c = t2.c and t1.d = t2.d and t3.c = t1.c and t3.d = t1.d", - best: "MergeInnerJoin{MergeInnerJoin{IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)(test.t1.d,test.t2.d)->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t3.c)(test.t1.d,test.t3.d)", - }, - { - sql: "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.c = t2.c and t1.d = t2.d and t3.c = t1.c and t3.d = t1.d order by t1.c", - best: "MergeInnerJoin{MergeInnerJoin{IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)(test.t1.d,test.t2.d)->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t3.c)(test.t1.d,test.t3.d)", - }, - // Test Multi Merge Join + Outer Join. - { - sql: "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1 left outer join t t2 on t1.a = t2.a left outer join t t3 on t2.a = t3.a", - best: "MergeLeftOuterJoin{MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t2.a,test.t3.a)", - }, - { - sql: "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1 left outer join t t2 on t1.a = t2.a left outer join t t3 on t1.a = t3.a", - best: "MergeLeftOuterJoin{MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t1.a,test.t3.a)", - }, - // Test Index Join + TableScan. - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.a", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)", - }, - // Test Index Join + DoubleRead. - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1, t t2 where t1.a = t2.c", - best: "IndexJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.a,test.t2.c)", - }, - // Test Index Join + SingleRead. - { - sql: "select /*+ TIDB_INLJ(t2) */ t1.a , t2.a from t t1, t t2 where t1.a = t2.c", - best: "IndexJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t1.a,test.t2.c)->Projection", - }, - // Test Index Join + Order by. - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ t1.a, t2.a from t t1, t t2 where t1.a = t2.a order by t1.c", - best: "IndexJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->TableReader(Table(t))}(test.t1.a,test.t2.a)->Projection", - }, - // Test Index Join + Order by. - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ t1.a, t2.a from t t1, t t2 where t1.a = t2.a order by t2.c", - best: "IndexJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.a,test.t1.a)->Projection", - }, - // Test Index Join + TableScan + Rotate. - { - sql: "select /*+ TIDB_INLJ(t1) */ t1.a , t2.a from t t1, t t2 where t1.a = t2.c", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t2.c,test.t1.a)->Projection", - }, - // Test Index Join + OuterJoin + TableScan. - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 left outer join t t2 on t1.a = t2.a and t2.b < 1", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t)->Sel([lt(test.t2.b, 1)]))}(test.t1.a,test.t2.a)", - }, - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 join t t2 on t1.d=t2.d and t2.c = 1", - best: "IndexJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.d,test.t2.d)", - }, - // Test Index Join failed. - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 left outer join t t2 on t1.a = t2.b", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)", - }, - // Test Index Join failed. - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 right outer join t t2 on t1.a = t2.b", - best: "RightHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)", - }, - // Test Semi Join hint success. - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 where t1.a in (select a from t t2)", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->Projection", - }, - // Test Semi Join hint fail. - { - sql: "select /*+ TIDB_INLJ(t1) */ * from t t1 where t1.a in (select a from t t2)", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t2.a,test.t1.a)->Projection", - }, - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.c=t2.c and t1.f=t2.f", - best: "IndexJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)", - }, - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.a = t2.a and t1.f=t2.f", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)", - }, - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.f=t2.f and t1.a=t2.a", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)", - }, - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.a=t2.a and t2.a in (1, 2)", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t)->Sel([in(test.t2.a, 1, 2)]))}(test.t1.a,test.t2.a)", - }, - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.b=t2.c and t1.b=1 and t2.d > t1.d-10 and t2.d < t1.d+10", - best: "IndexJoin{TableReader(Table(t)->Sel([eq(test.t1.b, 1)]))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}", - }, - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.b=t2.b and t1.c=1 and t2.c=1 and t2.d > t1.d-10 and t2.d < t1.d+10", - best: "LeftHashJoin{IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t))->IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t))}(test.t1.b,test.t2.b)", - }, - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t2.c > t1.d-10 and t2.c < t1.d+10", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}", - }, - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.b = t2.c and t2.c=1 and t2.d=2 and t2.e=4", - best: "LeftHashJoin{TableReader(Table(t)->Sel([eq(test.t1.b, 1)]))->IndexLookUp(Index(t.c_d_e)[[1 2 4,1 2 4]], Table(t))}", - }, - { - sql: "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t2.c=1 and t2.d=1 and t2.e > 10 and t2.e < 20", - best: "LeftHashJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[(1 1 10,1 1 20)], Table(t))}", - }, + var input []string + var output []struct { + SQL string + Best string } - for i, tt := range tests { - comment := Commentf("case:%v sql:%s", i, tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("case:%v sql:%s", i, tt) + stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, comment) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, comment) } } @@ -478,60 +136,28 @@ func (s *testPlanSuite) TestDAGPlanBuilderSubquery(c *C) { _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) se.Execute(context.Background(), "set sql_mode='STRICT_TRANS_TABLES'") // disable only full group by - tests := []struct { - sql string - best string - }{ - // Test join key with cast. - { - sql: "select * from t where exists (select s.a from t s having sum(s.a) = t.a )", - best: "LeftHashJoin{TableReader(Table(t))->Projection->TableReader(Table(t)->StreamAgg)->StreamAgg}(cast(test.t.a),sel_agg_1)->Projection", - }, - { - sql: "select * from t where exists (select s.a from t s having sum(s.a) = t.a ) order by t.a", - best: "LeftHashJoin{TableReader(Table(t))->Projection->TableReader(Table(t)->StreamAgg)->StreamAgg}(cast(test.t.a),sel_agg_1)->Projection->Sort", - }, - // FIXME: Report error by resolver. - //{ - // sql: "select * from t where exists (select s.a from t s having s.a = t.a ) order by t.a", - // best: "SemiJoin{TableReader(Table(t))->Projection->TableReader(Table(t)->HashAgg)->HashAgg}(cast(test.t.a),sel_agg_1)->Projection->Sort", - //}, - { - sql: "select * from t where a in (select s.a from t s) order by t.a", - best: "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t.a,test.s.a)->Projection", - }, - // Test Nested sub query. - { - sql: "select * from t where exists (select s.a from t s where s.c in (select c from t as k where k.d = s.d) having sum(s.a) = t.a )", - best: "LeftHashJoin{TableReader(Table(t))->Projection->MergeSemiJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.s.c,test.k.c)(test.s.d,test.k.d)->Projection->StreamAgg}(cast(test.t.a),sel_agg_1)->Projection", - }, - // Test Semi Join + Order by. - { - sql: "select * from t where a in (select a from t) order by b", - best: "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t.a,test.t.a)->Projection->Sort", - }, - // Test Apply. - { - sql: "select t.c in (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t", - best: "Apply{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->TableReader(Table(t)->Sel([eq(test.t1.a, test.t.a)]))}(test.s.a,test.t1.a)->StreamAgg}->Projection", - }, - { - sql: "select (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t", - best: "LeftHashJoin{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.s.a,test.t1.a)->StreamAgg}(test.t.a,test.s.a)->Projection->Projection", - }, - { - sql: "select (select count(*) from t s , t t1 where s.a = t.a and s.a = t1.a) from t order by t.a", - best: "LeftHashJoin{TableReader(Table(t))->MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.s.a,test.t1.a)->StreamAgg}(test.t.a,test.s.a)->Projection->Sort->Projection", - }, + ctx := se.(sessionctx.Context) + sessionVars := ctx.GetSessionVars() + sessionVars.HashAggFinalConcurrency = 1 + sessionVars.HashAggPartialConcurrency = 1 + var input []string + var output []struct { + SQL string + Best string } - for _, tt := range tests { - comment := Commentf("for %s", tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("for %s", tt) + stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, Commentf("for %s", tt)) } } @@ -548,47 +174,24 @@ func (s *testPlanSuite) TestDAGPlanTopN(c *C) { _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) - tests := []struct { - sql string - best string - }{ - { - sql: "select * from t t1 left join t t2 on t1.b = t2.b left join t t3 on t2.b = t3.b order by t1.a limit 1", - best: "LeftHashJoin{LeftHashJoin{TableReader(Table(t)->Limit)->Limit->TableReader(Table(t))}(test.t1.b,test.t2.b)->TopN([test.t1.a],0,1)->TableReader(Table(t))}(test.t2.b,test.t3.b)->TopN([test.t1.a],0,1)", - }, - { - sql: "select * from t t1 left join t t2 on t1.b = t2.b left join t t3 on t2.b = t3.b order by t1.b limit 1", - best: "LeftHashJoin{LeftHashJoin{TableReader(Table(t)->TopN([test.t1.b],0,1))->TopN([test.t1.b],0,1)->TableReader(Table(t))}(test.t1.b,test.t2.b)->TopN([test.t1.b],0,1)->TableReader(Table(t))}(test.t2.b,test.t3.b)->TopN([test.t1.b],0,1)", - }, - { - sql: "select * from t t1 left join t t2 on t1.b = t2.b left join t t3 on t2.b = t3.b limit 1", - best: "LeftHashJoin{LeftHashJoin{TableReader(Table(t)->Limit)->Limit->TableReader(Table(t))}(test.t1.b,test.t2.b)->Limit->TableReader(Table(t))}(test.t2.b,test.t3.b)->Limit", - }, - { - sql: "select * from t where b = 1 and c = 1 order by c limit 1", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t)->Sel([eq(test.t.b, 1)]))->Limit", - }, - { - sql: "select * from t where c = 1 order by c limit 1", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Limit, Table(t))->Limit", - }, - { - sql: "select * from t order by a limit 1", - best: "TableReader(Table(t)->Limit)->Limit", - }, - { - sql: "select c from t order by c limit 1", - best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Limit)->Limit", - }, + var input []string + var output []struct { + SQL string + Best string } - for i, tt := range tests { - comment := Commentf("case:%v sql:%s", i, tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("case:%v sql:%s", i, tt) + stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, comment) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, comment) } } @@ -606,93 +209,25 @@ func (s *testPlanSuite) TestDAGPlanBuilderBasePhysicalPlan(c *C) { _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) - tests := []struct { - sql string - best string - }{ - // Test for update. - { - sql: "select * from t order by b limit 1 for update", - // TODO: This is not reasonable. Mysql do like this because the limit of InnoDB, should TiDB keep consistency with MySQL? - best: "TableReader(Table(t))->Lock->TopN([test.t.b],0,1)", - }, - // Test complex update. - { - sql: "update t set a = 5 where b < 1 order by d limit 1", - best: "TableReader(Table(t)->Sel([lt(test.t.b, 1)])->TopN([test.t.d],0,1))->TopN([test.t.d],0,1)->Update", - }, - // Test simple update. - { - sql: "update t set a = 5", - best: "TableReader(Table(t))->Update", - }, - // TODO: Test delete/update with join. - // Test join hint for delete and update - { - sql: "delete /*+ TIDB_INLJ(t1, t2) */ t1 from t t1, t t2 where t1.c=t2.c", - best: "IndexJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)->Delete", - }, - { - sql: "delete /*+ TIDB_SMJ(t1, t2) */ from t1 using t t1, t t2 where t1.c=t2.c", - best: "MergeInnerJoin{IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)->Delete", - }, - { - sql: "update /*+ TIDB_SMJ(t1, t2) */ t t1, t t2 set t1.a=1, t2.a=1 where t1.a=t2.a", - best: "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->Update", - }, - { - sql: "update /*+ TIDB_HJ(t1, t2) */ t t1, t t2 set t1.a=1, t2.a=1 where t1.a=t2.a", - best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->Update", - }, - // Test complex delete. - { - sql: "delete from t where b < 1 order by d limit 1", - best: "TableReader(Table(t)->Sel([lt(test.t.b, 1)])->TopN([test.t.d],0,1))->TopN([test.t.d],0,1)->Delete", - }, - // Test simple delete. - { - sql: "delete from t", - best: "TableReader(Table(t))->Delete", - }, - // Test "USE INDEX" hint in delete statement from single table - { - sql: "delete from t use index(c_d_e) where b = 1", - best: "IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t)->Sel([eq(test.t.b, 1)]))->Delete", - }, - // Test complex insert. - { - sql: "insert into t select * from t where b < 1 order by d limit 1", - best: "TableReader(Table(t)->Sel([lt(test.t.b, 1)])->TopN([test.t.d],0,1))->TopN([test.t.d],0,1)->Insert", - }, - // Test simple insert. - { - sql: "insert into t (a, b, c, e, f, g) values(0,0,0,0,0,0)", - best: "Insert", - }, - // Test dual. - { - sql: "select 1", - best: "Dual->Projection", - }, - { - sql: "select * from t where false", - best: "Dual", - }, - // Test show. - { - sql: "show tables", - best: "Show", - }, + var input []string + var output []struct { + SQL string + Best string } - for _, tt := range tests { - comment := Commentf("for %s", tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("for %s", tt) + stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) core.Preprocess(se, stmt, s.is) - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, Commentf("for %s", tt)) } } @@ -702,276 +237,106 @@ func (s *testPlanSuite) TestDAGPlanBuilderUnion(c *C) { c.Assert(err, IsNil) defer func() { dom.Close() - store.Close() - }() - se, err := session.CreateSession4Test(store) - c.Assert(err, IsNil) - _, err = se.Execute(context.Background(), "use test") - c.Assert(err, IsNil) - - tests := []struct { - sql string - best string - }{ - // Test simple union. - { - sql: "select * from t union all select * from t", - best: "UnionAll{TableReader(Table(t))->TableReader(Table(t))}", - }, - // Test Order by + Union. - { - sql: "select * from t union all (select * from t) order by a ", - best: "UnionAll{TableReader(Table(t))->TableReader(Table(t))}->Sort", - }, - // Test Limit + Union. - { - sql: "select * from t union all (select * from t) limit 1", - best: "UnionAll{TableReader(Table(t)->Limit)->Limit->TableReader(Table(t)->Limit)->Limit}->Limit", - }, - // Test TopN + Union. - { - sql: "select a from t union all (select c from t) order by a limit 1", - best: "UnionAll{TableReader(Table(t)->Limit)->Limit->IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Limit)->Limit}->TopN([a],0,1)", - }, - } - for i, tt := range tests { - comment := Commentf("case:%v sql:%s", i, tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") - c.Assert(err, IsNil, comment) - - p, err := planner.Optimize(se, stmt, s.is) - c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, comment) - } -} - -func (s *testPlanSuite) TestDAGPlanBuilderUnionScan(c *C) { - defer testleak.AfterTest(c)() - store, dom, err := newStoreWithBootstrap() - c.Assert(err, IsNil) - defer func() { - dom.Close() - store.Close() - }() - se, err := session.CreateSession4Test(store) - c.Assert(err, IsNil) - _, err = se.Execute(context.Background(), "use test") - c.Assert(err, IsNil) - - tests := []struct { - sql string - best string - }{ - // Read table. - { - sql: "select * from t", - best: "TableReader(Table(t))->UnionScan([])", - }, - { - sql: "select * from t where b = 1", - best: "TableReader(Table(t)->Sel([eq(test.t.b, 1)]))->UnionScan([eq(test.t.b, 1)])", - }, - { - sql: "select * from t where a = 1", - best: "TableReader(Table(t))->UnionScan([eq(test.t.a, 1)])", - }, - { - sql: "select * from t where a = 1 order by a", - best: "TableReader(Table(t))->UnionScan([eq(test.t.a, 1)])", - }, - { - sql: "select * from t where a = 1 order by b", - best: "TableReader(Table(t))->UnionScan([eq(test.t.a, 1)])->Sort", - }, - { - sql: "select * from t where a = 1 limit 1", - best: "TableReader(Table(t))->UnionScan([eq(test.t.a, 1)])->Limit", - }, - { - sql: "select * from t where c = 1", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t))->UnionScan([eq(test.t.c, 1)])", - }, - { - sql: "select c from t where c = 1", - best: "IndexReader(Index(t.c_d_e)[[1,1]])->UnionScan([eq(test.t.c, 1)])->Projection", - }, - } - for _, tt := range tests { - comment := Commentf("for %s", tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") - c.Assert(err, IsNil, comment) - - err = se.NewTxn(context.Background()) - c.Assert(err, IsNil) - // Make txn not read only. - txn, err := se.Txn(true) - c.Assert(err, IsNil) - txn.Set(kv.Key("AAA"), []byte("BBB")) - c.Assert(se.StmtCommit(), IsNil) - p, err := planner.Optimize(se, stmt, s.is) - c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) - } -} - -func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) { - defer testleak.AfterTest(c)() - store, dom, err := newStoreWithBootstrap() - c.Assert(err, IsNil) - defer func() { - dom.Close() - store.Close() - }() - se, err := session.CreateSession4Test(store) - c.Assert(err, IsNil) - se.Execute(context.Background(), "use test") - se.Execute(context.Background(), "set sql_mode='STRICT_TRANS_TABLES'") // disable only full group by - c.Assert(err, IsNil) - - tests := []struct { - sql string - best string - }{ - // Test distinct. - { - sql: "select distinct b from t", - best: "TableReader(Table(t)->HashAgg)->HashAgg", - }, - { - sql: "select count(*) from (select * from t order by b) t group by b", - best: "TableReader(Table(t))->Sort->StreamAgg", - }, - { - sql: "select count(*), x from (select b as bbb, a + 1 as x from (select * from t order by b) t) t group by bbb", - best: "TableReader(Table(t))->Sort->Projection->StreamAgg", - }, - // Test agg + table. - { - sql: "select sum(a), avg(b + c) from t group by d", - best: "TableReader(Table(t)->HashAgg)->HashAgg", - }, - { - sql: "select sum(distinct a), avg(b + c) from t group by d", - best: "TableReader(Table(t))->Projection->HashAgg", - }, - // Test group by (c + d) - { - sql: "select sum(e), avg(e + c) from t where c = 1 group by (c + d)", - best: "IndexReader(Index(t.c_d_e)[[1,1]]->HashAgg)->HashAgg", - }, - // Test stream agg + index single. - { - sql: "select sum(e), avg(e + c) from t where c = 1 group by c", - best: "IndexReader(Index(t.c_d_e)[[1,1]]->StreamAgg)->StreamAgg", - }, - // Test hash agg + index single. - { - sql: "select sum(e), avg(e + c) from t where c = 1 group by d", - best: "IndexReader(Index(t.c_d_e)[[1,1]]->HashAgg)->HashAgg", - }, - // Test hash agg + index double. - { - sql: "select sum(e), avg(b + c) from t where c = 1 and e = 1 group by d", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t))->Projection->HashAgg", - }, - // Test stream agg + index double. - { - sql: "select sum(e), avg(b + c) from t where c = 1 and b = 1 group by c", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t)->Sel([eq(test.t.b, 1)]))->Projection->Projection->StreamAgg", - }, - // Test hash agg + order. - { - sql: "select sum(e) as k, avg(b + c) from t where c = 1 and b = 1 and e = 1 group by d order by k", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)]))->Projection->HashAgg->Sort", - }, - // Test stream agg + order. - { - sql: "select sum(e) as k, avg(b + c) from t where c = 1 and b = 1 and e = 1 group by c order by k", - best: "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)]))->Projection->Projection->StreamAgg->Sort", - }, - // Test agg can't push down. - { - sql: "select sum(to_base64(e)) from t where c = 1", - best: "IndexReader(Index(t.c_d_e)[[1,1]])->Projection->StreamAgg", - }, - { - sql: "select (select count(1) k from t s where s.a = t.a having k != 0) from t", - best: "MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t.a,test.s.a)->Projection->Projection", - }, - // Test stream agg with multi group by columns. - { - sql: "select sum(to_base64(e)) from t group by e,d,c order by c", - best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Projection->StreamAgg->Projection", - }, - { - sql: "select sum(e+1) from t group by e,d,c order by c", - best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->StreamAgg)->StreamAgg->Projection", - }, - { - sql: "select sum(to_base64(e)) from t group by e,d,c order by c,e", - best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Projection->StreamAgg->Sort->Projection", - }, - { - sql: "select sum(e+1) from t group by e,d,c order by c,e", - best: "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->StreamAgg)->StreamAgg->Sort->Projection", - }, - // Test stream agg + limit or sort - { - sql: "select count(*) from t group by g order by g limit 10", - best: "IndexReader(Index(t.g)[[NULL,+inf]]->StreamAgg)->StreamAgg->Limit->Projection", - }, - { - sql: "select count(*) from t group by g limit 10", - best: "IndexReader(Index(t.g)[[NULL,+inf]]->StreamAgg)->StreamAgg->Limit", - }, - { - sql: "select count(*) from t group by g order by g", - best: "IndexReader(Index(t.g)[[NULL,+inf]]->StreamAgg)->StreamAgg->Projection", - }, - { - sql: "select count(*) from t group by g order by g desc limit 1", - best: "IndexReader(Index(t.g)[[NULL,+inf]]->StreamAgg)->StreamAgg->Limit->Projection", - }, - // Test hash agg + limit or sort - { - sql: "select count(*) from t group by b order by b limit 10", - best: "TableReader(Table(t)->HashAgg)->HashAgg->TopN([test.t.b],0,10)->Projection", - }, - { - sql: "select count(*) from t group by b order by b", - best: "TableReader(Table(t)->HashAgg)->HashAgg->Sort->Projection", - }, - { - sql: "select count(*) from t group by b limit 10", - best: "TableReader(Table(t)->HashAgg)->HashAgg->Limit", - }, - // Test merge join + stream agg - { - sql: "select sum(a.g), sum(b.g) from t a join t b on a.g = b.g group by a.g", - best: "MergeInnerJoin{IndexReader(Index(t.g)[[NULL,+inf]])->IndexReader(Index(t.g)[[NULL,+inf]])}(test.a.g,test.b.g)->Projection->StreamAgg", - }, - // Test index join + stream agg - { - sql: "select /*+ tidb_inlj(a,b) */ sum(a.g), sum(b.g) from t a join t b on a.g = b.g and a.g > 60 group by a.g order by a.g limit 1", - best: "IndexJoin{IndexReader(Index(t.g)[(60,+inf]])->IndexReader(Index(t.g)[[NULL,+inf]]->Sel([gt(test.b.g, 60)]))}(test.a.g,test.b.g)->Projection->StreamAgg->Limit->Projection", - }, - { - sql: "select sum(a.g), sum(b.g) from t a join t b on a.g = b.g and a.a>5 group by a.g order by a.g limit 1", - best: "IndexJoin{IndexReader(Index(t.g)[[NULL,+inf]]->Sel([gt(test.a.a, 5)]))->IndexReader(Index(t.g)[[NULL,+inf]])}(test.a.g,test.b.g)->Projection->StreamAgg->Limit->Projection", - }, - { - sql: "select sum(d) from t", - best: "TableReader(Table(t)->StreamAgg)->StreamAgg", - }, + store.Close() + }() + se, err := session.CreateSession4Test(store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "use test") + c.Assert(err, IsNil) + + var input []string + var output []struct { + SQL string + Best string } - for _, tt := range tests { - comment := Commentf("for %s", tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("case:%v sql:%s", i, tt) + stmt, err := s.ParseOneStmt(tt, "", "") + c.Assert(err, IsNil, comment) + + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) + c.Assert(err, IsNil) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, comment) + } +} + +func (s *testPlanSuite) TestDAGPlanBuilderUnionScan(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + se, err := session.CreateSession4Test(store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "use test") + c.Assert(err, IsNil) + + var input []string + var output []struct { + SQL string + Best string + } + for i, tt := range input { + comment := Commentf("for %s", tt) + stmt, err := s.ParseOneStmt(tt, "", "") + c.Assert(err, IsNil, comment) + + err = se.NewTxn(context.Background()) + c.Assert(err, IsNil) + // Make txn not read only. + txn, err := se.Txn(true) + c.Assert(err, IsNil) + txn.Set(kv.Key("AAA"), []byte("BBB")) + c.Assert(se.StmtCommit(), IsNil) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) + c.Assert(err, IsNil) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, Commentf("for %s", tt)) + } +} + +func (s *testPlanSuite) TestDAGPlanBuilderAgg(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + se, err := session.CreateSession4Test(store) + c.Assert(err, IsNil) + se.Execute(context.Background(), "use test") + se.Execute(context.Background(), "set sql_mode='STRICT_TRANS_TABLES'") // disable only full group by + c.Assert(err, IsNil) + + var input []string + var output []struct { + SQL string + Best string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("for %s", tt) + stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, Commentf("for %s", tt)) } } @@ -988,228 +353,29 @@ func (s *testPlanSuite) TestRefine(c *C) { _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) - tests := []struct { - sql string - best string - }{ - { - sql: "select a from t where c is not null", - best: "IndexReader(Index(t.c_d_e)[[-inf,+inf]])->Projection", - }, - { - sql: "select a from t where c >= 4", - best: "IndexReader(Index(t.c_d_e)[[4,+inf]])->Projection", - }, - { - sql: "select a from t where c <= 4", - best: "IndexReader(Index(t.c_d_e)[[-inf,4]])->Projection", - }, - { - sql: "select a from t where c = 4 and d = 5 and e = 6", - best: "IndexReader(Index(t.c_d_e)[[4 5 6,4 5 6]])->Projection", - }, - { - sql: "select a from t where d = 4 and c = 5", - best: "IndexReader(Index(t.c_d_e)[[5 4,5 4]])->Projection", - }, - { - sql: "select a from t where c = 4 and e < 5", - best: "IndexReader(Index(t.c_d_e)[[4,4]]->Sel([lt(test.t.e, 5)]))->Projection", - }, - { - sql: "select a from t where c = 4 and d <= 5 and d > 3", - best: "IndexReader(Index(t.c_d_e)[(4 3,4 5]])->Projection", - }, - { - sql: "select a from t where d <= 5 and d > 3", - best: "TableReader(Table(t)->Sel([le(test.t.d, 5) gt(test.t.d, 3)]))->Projection", - }, - { - sql: "select a from t where c between 1 and 2", - best: "IndexReader(Index(t.c_d_e)[[1,2]])->Projection", - }, - { - sql: "select a from t where c not between 1 and 2", - best: "IndexReader(Index(t.c_d_e)[[-inf,1) (2,+inf]])->Projection", - }, - { - sql: "select a from t where c <= 5 and c >= 3 and d = 1", - best: "IndexReader(Index(t.c_d_e)[[3,5]]->Sel([eq(test.t.d, 1)]))->Projection", - }, - { - sql: "select a from t where c = 1 or c = 2 or c = 3", - best: "IndexReader(Index(t.c_d_e)[[1,3]])->Projection", - }, - { - sql: "select b from t where c = 1 or c = 2 or c = 3 or c = 4 or c = 5", - best: "IndexLookUp(Index(t.c_d_e)[[1,5]], Table(t))->Projection", - }, - { - sql: "select a from t where c = 5", - best: "IndexReader(Index(t.c_d_e)[[5,5]])->Projection", - }, - { - sql: "select a from t where c = 5 and b = 1", - best: "IndexLookUp(Index(t.c_d_e)[[5,5]], Table(t)->Sel([eq(test.t.b, 1)]))->Projection", - }, - { - sql: "select a from t where not a", - best: "TableReader(Table(t)->Sel([not(test.t.a)]))", - }, - { - sql: "select a from t where c in (1)", - best: "IndexReader(Index(t.c_d_e)[[1,1]])->Projection", - }, - { - sql: "select a from t where c in ('1')", - best: "IndexReader(Index(t.c_d_e)[[1,1]])->Projection", - }, - { - sql: "select a from t where c = 1.0", - best: "IndexReader(Index(t.c_d_e)[[1,1]])->Projection", - }, - { - sql: "select a from t where c in (1) and d > 3", - best: "IndexReader(Index(t.c_d_e)[(1 3,1 +inf]])->Projection", - }, - { - sql: "select a from t where c in (1, 2, 3) and (d > 3 and d < 4 or d > 5 and d < 6)", - best: "Dual->Projection", - }, - { - sql: "select a from t where c in (1, 2, 3) and (d > 2 and d < 4 or d > 5 and d < 7)", - best: "IndexReader(Index(t.c_d_e)[(1 2,1 4) (1 5,1 7) (2 2,2 4) (2 5,2 7) (3 2,3 4) (3 5,3 7)])->Projection", - }, - { - sql: "select a from t where c in (1, 2, 3)", - best: "IndexReader(Index(t.c_d_e)[[1,1] [2,2] [3,3]])->Projection", - }, - { - sql: "select a from t where c in (1, 2, 3) and d in (1,2) and e = 1", - best: "IndexReader(Index(t.c_d_e)[[1 1 1,1 1 1] [1 2 1,1 2 1] [2 1 1,2 1 1] [2 2 1,2 2 1] [3 1 1,3 1 1] [3 2 1,3 2 1]])->Projection", - }, - { - sql: "select a from t where d in (1, 2, 3)", - best: "TableReader(Table(t)->Sel([in(test.t.d, 1, 2, 3)]))->Projection", - }, - { - sql: "select a from t where c not in (1)", - best: "IndexReader(Index(t.c_d_e)[(NULL,1) (1,+inf]])->Projection", - }, - // test like - { - sql: "select a from t use index(c_d_e) where c != 1", - best: "IndexReader(Index(t.c_d_e)[[-inf,1) (1,+inf]])->Projection", - }, - { - sql: "select a from t where c_str like ''", - best: `IndexReader(Index(t.c_d_e_str)[["",""]])->Projection`, - }, - { - sql: "select a from t where c_str like 'abc'", - best: `IndexReader(Index(t.c_d_e_str)[["abc","abc"]])->Projection`, - }, - { - sql: "select a from t where c_str not like 'abc'", - best: `IndexReader(Index(t.c_d_e_str)[[-inf,"abc") ("abc",+inf]])->Projection`, - }, - { - sql: "select a from t where not (c_str like 'abc' or c_str like 'abd')", - best: `IndexReader(Index(t.c_d_e_str)[[-inf,"abc") ("abc","abd") ("abd",+inf]])->Projection`, - }, - { - sql: "select a from t where c_str like '_abc'", - best: "TableReader(Table(t)->Sel([like(test.t.c_str, _abc, 92)]))->Projection", - }, - { - sql: `select a from t where c_str like 'abc%'`, - best: `IndexReader(Index(t.c_d_e_str)[["abc","abd")])->Projection`, - }, - { - sql: "select a from t where c_str like 'abc_'", - best: `IndexReader(Index(t.c_d_e_str)[("abc","abd")]->Sel([like(test.t.c_str, abc_, 92)]))->Projection`, - }, - { - sql: "select a from t where c_str like 'abc%af'", - best: `IndexReader(Index(t.c_d_e_str)[["abc","abd")]->Sel([like(test.t.c_str, abc%af, 92)]))->Projection`, - }, - { - sql: `select a from t where c_str like 'abc\\_' escape ''`, - best: `IndexReader(Index(t.c_d_e_str)[["abc_","abc_"]])->Projection`, - }, - { - sql: `select a from t where c_str like 'abc\\_'`, - best: `IndexReader(Index(t.c_d_e_str)[["abc_","abc_"]])->Projection`, - }, - { - sql: `select a from t where c_str like 'abc\\\\_'`, - best: "IndexReader(Index(t.c_d_e_str)[(\"abc\\\",\"abc]\")]->Sel([like(test.t.c_str, abc\\\\_, 92)]))->Projection", - }, - { - sql: `select a from t where c_str like 'abc\\_%'`, - best: "IndexReader(Index(t.c_d_e_str)[[\"abc_\",\"abc`\")])->Projection", - }, - { - sql: `select a from t where c_str like 'abc=_%' escape '='`, - best: "IndexReader(Index(t.c_d_e_str)[[\"abc_\",\"abc`\")])->Projection", - }, - { - sql: `select a from t where c_str like 'abc\\__'`, - best: "IndexReader(Index(t.c_d_e_str)[(\"abc_\",\"abc`\")]->Sel([like(test.t.c_str, abc\\__, 92)]))->Projection", - }, - { - // Check that 123 is converted to string '123'. index can be used. - sql: `select a from t where c_str like 123`, - best: "IndexReader(Index(t.c_d_e_str)[[\"123\",\"123\"]])->Projection", - }, - // c is type int which will be added cast to specified type when building function signature, - // and rewrite predicate like to predicate '=' when exact match , index still can be used. - { - sql: `select a from t where c like '1'`, - best: "IndexReader(Index(t.c_d_e)[[1,1]])->Projection", - }, - { - sql: `select a from t where c = 1.9 and d > 3`, - best: "Dual", - }, - { - sql: `select a from t where c < 1.1`, - best: "IndexReader(Index(t.c_d_e)[[-inf,2)])->Projection", - }, - { - sql: `select a from t where c <= 1.9`, - best: "IndexReader(Index(t.c_d_e)[[-inf,1]])->Projection", - }, - { - sql: `select a from t where c >= 1.1`, - best: "IndexReader(Index(t.c_d_e)[[2,+inf]])->Projection", - }, - { - sql: `select a from t where c > 1.9`, - best: "IndexReader(Index(t.c_d_e)[(1,+inf]])->Projection", - }, - { - sql: `select a from t where c = 123456789098765432101234`, - best: "Dual", - }, - { - sql: `select a from t where c = 'hanfei'`, - best: "TableReader(Table(t))->Sel([eq(cast(test.t.c), cast(hanfei))])->Projection", - }, + var input []string + var output []struct { + SQL string + Best string } - for _, tt := range tests { - comment := Commentf("for %s", tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("for %s", tt) + stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) sc := se.(sessionctx.Context).GetSessionVars().StmtCtx sc.IgnoreTruncate = false - p, err := planner.Optimize(se, stmt, s.is) - c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) + c.Assert(err, IsNil, comment) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, comment) } } -func (s *testPlanSuite) TestAggEliminater(c *C) { +func (s *testPlanSuite) TestAggEliminator(c *C) { defer testleak.AfterTest(c)() store, dom, err := newStoreWithBootstrap() c.Assert(err, IsNil) @@ -1222,61 +388,25 @@ func (s *testPlanSuite) TestAggEliminater(c *C) { _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) se.Execute(context.Background(), "set sql_mode='STRICT_TRANS_TABLES'") // disable only full group by - tests := []struct { - sql string - best string - }{ - // Max to Limit + Sort-Desc. - { - sql: "select max(a) from t;", - best: "TableReader(Table(t)->Limit)->Limit->StreamAgg", - }, - // Min to Limit + Sort. - { - sql: "select min(a) from t;", - best: "TableReader(Table(t)->Limit)->Limit->StreamAgg", - }, - // Min to Limit + Sort, and isnull() should be added. - { - sql: "select min(c_str) from t;", - best: "IndexReader(Index(t.c_d_e_str)[[-inf,+inf]]->Limit)->Limit->StreamAgg", - }, - // Do nothing to max + firstrow. - { - sql: "select max(a), b from t;", - best: "TableReader(Table(t)->StreamAgg)->StreamAgg", - }, - // If max/min contains scalar function, we can still do transformation. - { - sql: "select max(a+1) from t;", - best: "TableReader(Table(t)->Sel([not(isnull(plus(test.t.a, 1)))])->TopN([plus(test.t.a, 1) true],0,1))->Projection->TopN([col_1 true],0,1)->Projection->Projection->StreamAgg", - }, - // Do nothing to max+min. - { - sql: "select max(a), min(a) from t;", - best: "TableReader(Table(t)->StreamAgg)->StreamAgg", - }, - // Do nothing to max with groupby. - { - sql: "select max(a) from t group by b;", - best: "TableReader(Table(t)->HashAgg)->HashAgg", - }, - // If inner is not a data source, we can still do transformation. - { - sql: "select max(a) from (select t1.a from t t1 join t t2 on t1.a=t2.a) t", - best: "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->Limit->StreamAgg", - }, + var input []string + var output []struct { + SQL string + Best string } - - for _, tt := range tests { - comment := Commentf("for %s", tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("for %s", tt) + stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) sc := se.(sessionctx.Context).GetSessionVars().StmtCtx sc.IgnoreTruncate = false - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, Commentf("for %s", tt)) } } @@ -1311,7 +441,7 @@ func (s *testPlanSuite) TestRequestTypeSupportedOff(c *C) { stmt, err := s.ParseOneStmt(sql, "", "") c.Assert(err, IsNil) - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) c.Assert(core.ToString(p), Equals, expect, Commentf("for %s", sql)) } @@ -1329,53 +459,15 @@ func (s *testPlanSuite) TestIndexJoinUnionScan(c *C) { _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) - definitions := []model.PartitionDefinition{ - { - ID: 41, - Name: model.NewCIStr("p1"), - LessThan: []string{"16"}, - }, - { - ID: 42, - Name: model.NewCIStr("p2"), - LessThan: []string{"32"}, - }, - } - pis := core.MockPartitionInfoSchema(definitions) - - tests := []struct { - sql string - best string - is infoschema.InfoSchema - }{ - // Test Index Join + UnionScan + TableScan. - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.a", - best: "IndexJoin{TableReader(Table(t))->UnionScan([])->TableReader(Table(t))->UnionScan([])}(test.t1.a,test.t2.a)", - is: s.is, - }, - // Test Index Join + UnionScan + DoubleRead. - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.c", - best: "IndexJoin{TableReader(Table(t))->UnionScan([])->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->UnionScan([])}(test.t1.a,test.t2.c)", - is: s.is, - }, - // Test Index Join + UnionScan + IndexScan. - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ t1.a , t2.c from t t1, t t2 where t1.a = t2.c", - best: "IndexJoin{TableReader(Table(t))->UnionScan([])->IndexReader(Index(t.c_d_e)[[NULL,+inf]])->UnionScan([])}(test.t1.a,test.t2.c)->Projection", - is: s.is, - }, - // Index Join + Union Scan + Union All is not supported now. - { - sql: "select /*+ TIDB_INLJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.a", - best: "LeftHashJoin{UnionAll{TableReader(Table(t))->UnionScan([])->TableReader(Table(t))->UnionScan([])}->UnionAll{TableReader(Table(t))->UnionScan([])->TableReader(Table(t))->UnionScan([])}}(test.t1.a,test.t2.a)", - is: pis, - }, + var input []string + var output []struct { + SQL string + Best string } - for i, tt := range tests { - comment := Commentf("case:%v sql:%s", i, tt.sql) - stmt, err := s.ParseOneStmt(tt.sql, "", "") + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + comment := Commentf("case:%v sql:%s", i, tt) + stmt, err := s.ParseOneStmt(tt, "", "") c.Assert(err, IsNil, comment) err = se.NewTxn(context.Background()) c.Assert(err, IsNil) @@ -1384,9 +476,13 @@ func (s *testPlanSuite) TestIndexJoinUnionScan(c *C) { c.Assert(err, IsNil) txn.Set(kv.Key("AAA"), []byte("BBB")) c.Assert(se.StmtCommit(), IsNil) - p, err := planner.Optimize(se, stmt, tt.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil, comment) - c.Assert(core.ToString(p), Equals, tt.best, comment) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best, Commentf("for %s", tt)) } } @@ -1415,7 +511,7 @@ func (s *testPlanSuite) TestDoSubquery(c *C) { comment := Commentf("for %s", tt.sql) stmt, err := s.ParseOneStmt(tt.sql, "", "") c.Assert(err, IsNil, comment) - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) c.Assert(core.ToString(p), Equals, tt.best, comment) } @@ -1436,7 +532,7 @@ func (s *testPlanSuite) TestIndexLookupCartesianJoin(c *C) { sql := "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 join t t2" stmt, err := s.ParseOneStmt(sql, "", "") c.Assert(err, IsNil) - p, err := planner.Optimize(se, stmt, s.is) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) c.Assert(core.ToString(p), Equals, "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}") warnings := se.GetSessionVars().StmtCtx.GetWarnings() @@ -1457,12 +553,23 @@ func (s *testPlanSuite) TestSemiJoinToInner(c *C) { c.Assert(err, IsNil) _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) - sql := "select t1.a, (select count(t2.a) from t t2 where t2.g in (select t3.d from t t3 where t3.c = t1.a)) as agg_col from t t1;" - stmt, err := s.ParseOneStmt(sql, "", "") - c.Assert(err, IsNil) - p, err := planner.Optimize(se, stmt, s.is) - c.Assert(err, IsNil) - c.Assert(core.ToString(p), Equals, "Apply{TableReader(Table(t))->IndexJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]]->HashAgg)->HashAgg->IndexReader(Index(t.g)[[NULL,+inf]])}(test.t3.d,test.t2.g)}->StreamAgg") + var input []string + var output []struct { + SQL string + Best string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + stmt, err := s.ParseOneStmt(tt, "", "") + c.Assert(err, IsNil) + p, err := planner.Optimize(context.TODO(), se, stmt, s.is) + c.Assert(err, IsNil) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Best = core.ToString(p) + }) + c.Assert(core.ToString(p), Equals, output[i].Best) + } } func (s *testPlanSuite) TestUnmatchedTableInHint(c *C) { @@ -1477,44 +584,80 @@ func (s *testPlanSuite) TestUnmatchedTableInHint(c *C) { c.Assert(err, IsNil) _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) - tests := []struct { - sql string - warning string - }{ - { - sql: "SELECT /*+ TIDB_SMJ(t3, t4) */ * from t t1, t t2 where t1.a = t2.a", - warning: "[planner:1815]There are no matching table names for (t3, t4) in optimizer hint /*+ TIDB_SMJ(t3, t4) */. Maybe you can use the table alias name", - }, - { - sql: "SELECT /*+ TIDB_HJ(t3, t4) */ * from t t1, t t2 where t1.a = t2.a", - warning: "[planner:1815]There are no matching table names for (t3, t4) in optimizer hint /*+ TIDB_HJ(t3, t4) */. Maybe you can use the table alias name", - }, - { - sql: "SELECT /*+ TIDB_INLJ(t3, t4) */ * from t t1, t t2 where t1.a = t2.a", - warning: "[planner:1815]There are no matching table names for (t3, t4) in optimizer hint /*+ TIDB_INLJ(t3, t4) */. Maybe you can use the table alias name", - }, - { - sql: "SELECT /*+ TIDB_SMJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.a", - warning: "", - }, - { - sql: "SELECT /*+ TIDB_SMJ(t3, t4) */ * from t t1, t t2, t t3 where t1.a = t2.a and t2.a = t3.a", - warning: "[planner:1815]There are no matching table names for (t4) in optimizer hint /*+ TIDB_SMJ(t3, t4) */. Maybe you can use the table alias name", - }, + var input []string + var output []struct { + SQL string + Warning string } - for _, test := range tests { + s.testData.GetTestCases(c, &input, &output) + for i, test := range input { se.GetSessionVars().StmtCtx.SetWarnings(nil) - stmt, err := s.ParseOneStmt(test.sql, "", "") + stmt, err := s.ParseOneStmt(test, "", "") c.Assert(err, IsNil) - _, err = planner.Optimize(se, stmt, s.is) + _, err = planner.Optimize(context.TODO(), se, stmt, s.is) c.Assert(err, IsNil) warnings := se.GetSessionVars().StmtCtx.GetWarnings() - if test.warning == "" { + s.testData.OnRecord(func() { + output[i].SQL = test + if len(warnings) > 0 { + output[i].Warning = warnings[0].Err.Error() + } + }) + if output[i].Warning == "" { c.Assert(len(warnings), Equals, 0) } else { c.Assert(len(warnings), Equals, 1) c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning) - c.Assert(warnings[0].Err.Error(), Equals, test.warning) + c.Assert(warnings[0].Err.Error(), Equals, output[i].Warning) + } + } +} + +func (s *testPlanSuite) TestIndexJoinHint(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + se, err := session.CreateSession4Test(store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "use test") + c.Assert(err, IsNil) + + var input []string + var output []struct { + SQL string + Best string + Warning string + } + s.testData.GetTestCases(c, &input, &output) + ctx := context.Background() + for i, test := range input { + comment := Commentf("case:%v sql:%s", i, test) + stmt, err := s.ParseOneStmt(test, "", "") + c.Assert(err, IsNil, comment) + + se.GetSessionVars().StmtCtx.SetWarnings(nil) + p, err := planner.Optimize(ctx, se, stmt, s.is) + c.Assert(err, IsNil) + warnings := se.GetSessionVars().StmtCtx.GetWarnings() + + s.testData.OnRecord(func() { + output[i].SQL = test + output[i].Best = core.ToString(p) + if len(warnings) > 0 { + output[i].Warning = warnings[0].Err.Error() + } + }) + c.Assert(core.ToString(p), Equals, output[i].Best) + if output[i].Warning == "" { + c.Assert(len(warnings), Equals, 0) + } else { + c.Assert(len(warnings), Equals, 1, Commentf("%v", warnings)) + c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning) + c.Assert(warnings[0].Err.Error(), Equals, output[i].Warning) } } } diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 32d257090270d..b49992f301197 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -72,6 +72,12 @@ type PhysicalIndexReader struct { OutputColumns []*expression.Column } +// PushedDownLimit is the limit operator pushed down into PhysicalIndexLookUpReader. +type PushedDownLimit struct { + Offset uint64 + Count uint64 +} + // PhysicalIndexLookUpReader is the index look up reader in tidb. It's used in case of double reading. type PhysicalIndexLookUpReader struct { physicalSchemaProducer @@ -82,6 +88,9 @@ type PhysicalIndexLookUpReader struct { TablePlans []PhysicalPlan indexPlan PhysicalPlan tablePlan PhysicalPlan + + // PushedLimit is used to avoid unnecessary table scan tasks of IndexLookUpReader. + PushedLimit *PushedDownLimit } // PhysicalIndexScan represents an index scan plan. @@ -119,6 +128,8 @@ type PhysicalIndexScan struct { // The index scan may be on a partition. isPartition bool physicalTableID int64 + + GenExprs map[model.TableColumnID]expression.Expression } // PhysicalMemTable reads memory table. @@ -160,6 +171,9 @@ type PhysicalTableScan struct { physicalTableID int64 rangeDecidedBy []*expression.Column + + // HandleIdx is the index of handle, which is only used for admin check table. + HandleIdx int } // IsPartition returns true and partition ID if it's actually a partition. @@ -232,6 +246,8 @@ type PhysicalIndexJoin struct { Ranges []*ranger.Range // KeyOff2IdxOff maps the offsets in join key to the offsets in the index. KeyOff2IdxOff []int + // IdxColLens stores the length of each index column. + IdxColLens []int // CompareFilters stores the filters for last column if those filters need to be evaluated during execution. // e.g. select * from t where t.a = t1.a and t.b > t1.b and t.b < t1.b+10 // If there's index(t.a, t.b). All the filters can be used to construct index range but t.b > t1.b and t.b < t1.b=10 @@ -378,6 +394,10 @@ type PhysicalTableDual struct { physicalSchemaProducer RowCount int + // placeHolder indicates if this dual plan is a place holder in query optimization + // for data sources like `Show`, if true, the dual plan would be substituted by + // `Show` in the final plan. + placeHolder bool } // PhysicalWindow is the physical operator of window function. diff --git a/planner/core/plan.go b/planner/core/plan.go index a141115e36140..35d8a96f7d306 100644 --- a/planner/core/plan.go +++ b/planner/core/plan.go @@ -35,6 +35,10 @@ type Plan interface { Schema() *expression.Schema // Get the ID. ID() int + + // TP get the plan type. + TP() string + // Get the ID in explain statement ExplainID() fmt.Stringer // replaceExprColumns replace all the column reference in the plan's expression node. @@ -113,6 +117,9 @@ type LogicalPlan interface { // SetChildren sets the children for the plan. SetChildren(...LogicalPlan) + + // SetChild sets the ith child for the plan. + SetChild(i int, child LogicalPlan) } // PhysicalPlan is a tree of the physical operators. @@ -141,6 +148,9 @@ type PhysicalPlan interface { // SetChildren sets the children for the plan. SetChildren(...PhysicalPlan) + // SetChild sets the ith child for the plan. + SetChild(i int, child PhysicalPlan) + // ResolveIndices resolves the indices for columns. After doing this, the columns can evaluate the rows by their indices. ResolveIndices() error } @@ -266,6 +276,11 @@ func (p *basePlan) ExplainID() fmt.Stringer { }) } +// TP implements Plan interface. +func (p *basePlan) TP() string { + return p.tp +} + // Schema implements Plan Schema interface. func (p *baseLogicalPlan) Schema() *expression.Schema { return p.children[0].Schema() @@ -296,6 +311,16 @@ func (p *basePhysicalPlan) SetChildren(children ...PhysicalPlan) { p.children = children } +// SetChild implements LogicalPlan SetChild interface. +func (p *baseLogicalPlan) SetChild(i int, child LogicalPlan) { + p.children[i] = child +} + +// SetChild implements PhysicalPlan SetChild interface. +func (p *basePhysicalPlan) SetChild(i int, child PhysicalPlan) { + p.children[i] = child +} + func (p *basePlan) context() sessionctx.Context { return p.ctx } diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index be8b3843f59ba..479c1c2c99a61 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -15,6 +15,7 @@ package core import ( "bytes" + "context" "fmt" "strings" @@ -26,6 +27,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" @@ -36,7 +38,10 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" + "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/ranger" + "go.uber.org/zap" ) type visitInfo struct { @@ -140,7 +145,6 @@ const ( onClause orderByClause whereClause - windowClause groupByClause showStatement globalOrderByClause @@ -156,9 +160,17 @@ var clauseMsg = map[clauseCode]string{ groupByClause: "group statement", showStatement: "show statement", globalOrderByClause: "global ORDER clause", - windowClause: "field list", // For window functions that in field list. } +type capFlagType = uint64 + +const ( + _ capFlagType = iota + // canExpandAST indicates whether the origin AST can be expanded during plan + // building. ONLY used for `CreateViewStmt` now. + canExpandAST +) + // PlanBuilder builds Plan from an ast.Node. // It just builds the ast node straightforwardly. type PlanBuilder struct { @@ -171,7 +183,10 @@ type PlanBuilder struct { // visitInfo is used for privilege check. visitInfo []visitInfo tableHintInfo []tableHintInfo - optFlag uint64 + // optFlag indicates the flags of the optimizer rules. + optFlag uint64 + // capFlag indicates the capability flags. + capFlag capFlagType curClause clauseCode @@ -227,43 +242,43 @@ func NewPlanBuilder(sctx sessionctx.Context, is infoschema.InfoSchema) *PlanBuil } // Build builds the ast node to a Plan. -func (b *PlanBuilder) Build(node ast.Node) (Plan, error) { +func (b *PlanBuilder) Build(ctx context.Context, node ast.Node) (Plan, error) { b.optFlag = flagPrunColumns switch x := node.(type) { case *ast.AdminStmt: - return b.buildAdmin(x) + return b.buildAdmin(ctx, x) case *ast.DeallocateStmt: return &Deallocate{Name: x.Name}, nil case *ast.DeleteStmt: - return b.buildDelete(x) + return b.buildDelete(ctx, x) case *ast.ExecuteStmt: - return b.buildExecute(x) + return b.buildExecute(ctx, x) case *ast.ExplainStmt: - return b.buildExplain(x) + return b.buildExplain(ctx, x) case *ast.ExplainForStmt: return b.buildExplainFor(x) case *ast.TraceStmt: return b.buildTrace(x) case *ast.InsertStmt: - return b.buildInsert(x) + return b.buildInsert(ctx, x) case *ast.LoadDataStmt: - return b.buildLoadData(x) + return b.buildLoadData(ctx, x) case *ast.LoadStatsStmt: return b.buildLoadStats(x), nil case *ast.PrepareStmt: return b.buildPrepare(x), nil case *ast.SelectStmt: - return b.buildSelect(x) + return b.buildSelect(ctx, x) case *ast.UnionStmt: - return b.buildUnion(x) + return b.buildUnion(ctx, x) case *ast.UpdateStmt: - return b.buildUpdate(x) + return b.buildUpdate(ctx, x) case *ast.ShowStmt: - return b.buildShow(x) + return b.buildShow(ctx, x) case *ast.DoStmt: - return b.buildDo(x) + return b.buildDo(ctx, x) case *ast.SetStmt: - return b.buildSet(x) + return b.buildSet(ctx, x) case *ast.AnalyzeTableStmt: return b.buildAnalyze(x) case *ast.BinlogStmt, *ast.FlushStmt, *ast.UseStmt, @@ -272,15 +287,15 @@ func (b *PlanBuilder) Build(node ast.Node) (Plan, error) { *ast.GrantRoleStmt, *ast.RevokeRoleStmt, *ast.SetRoleStmt, *ast.SetDefaultRoleStmt: return b.buildSimple(node.(ast.StmtNode)) case ast.DDLNode: - return b.buildDDL(x) + return b.buildDDL(ctx, x) case *ast.CreateBindingStmt: return b.buildCreateBindPlan(x) case *ast.DropBindingStmt: return b.buildDropBindPlan(x) case *ast.ChangeStmt: return b.buildChange(x) - case *ast.SplitIndexRegionStmt: - return b.buildSplitIndexRegion(x) + case *ast.SplitRegionStmt: + return b.buildSplitRegion(x) } return nil, ErrUnsupportedType.GenWithStack("Unsupported type %T", node) } @@ -292,10 +307,10 @@ func (b *PlanBuilder) buildChange(v *ast.ChangeStmt) (Plan, error) { return exe, nil } -func (b *PlanBuilder) buildExecute(v *ast.ExecuteStmt) (Plan, error) { +func (b *PlanBuilder) buildExecute(ctx context.Context, v *ast.ExecuteStmt) (Plan, error) { vars := make([]expression.Expression, 0, len(v.UsingVars)) for _, expr := range v.UsingVars { - newExpr, _, err := b.rewrite(expr, nil, nil, true) + newExpr, _, err := b.rewrite(ctx, expr, nil, nil, true) if err != nil { return nil, err } @@ -305,7 +320,7 @@ func (b *PlanBuilder) buildExecute(v *ast.ExecuteStmt) (Plan, error) { return exe, nil } -func (b *PlanBuilder) buildDo(v *ast.DoStmt) (Plan, error) { +func (b *PlanBuilder) buildDo(ctx context.Context, v *ast.DoStmt) (Plan, error) { var p LogicalPlan dual := LogicalTableDual{RowCount: 1}.Init(b.ctx) dual.SetSchema(expression.NewSchema()) @@ -313,7 +328,7 @@ func (b *PlanBuilder) buildDo(v *ast.DoStmt) (Plan, error) { proj := LogicalProjection{Exprs: make([]expression.Expression, 0, len(v.Exprs))}.Init(b.ctx) schema := expression.NewSchema(make([]*expression.Column, 0, len(v.Exprs))...) for _, astExpr := range v.Exprs { - expr, np, err := b.rewrite(astExpr, p, nil, true) + expr, np, err := b.rewrite(ctx, astExpr, p, nil, true) if err != nil { return nil, err } @@ -331,7 +346,7 @@ func (b *PlanBuilder) buildDo(v *ast.DoStmt) (Plan, error) { return proj, nil } -func (b *PlanBuilder) buildSet(v *ast.SetStmt) (Plan, error) { +func (b *PlanBuilder) buildSet(ctx context.Context, v *ast.SetStmt) (Plan, error) { p := &Set{} for _, vars := range v.Variables { if vars.IsGlobal { @@ -350,7 +365,7 @@ func (b *PlanBuilder) buildSet(v *ast.SetStmt) (Plan, error) { } mockTablePlan := LogicalTableDual{}.Init(b.ctx) var err error - assign.Expr, _, err = b.rewrite(vars.Value, mockTablePlan, nil, true) + assign.Expr, _, err = b.rewrite(ctx, vars.Value, mockTablePlan, nil, true) if err != nil { return nil, err } @@ -473,6 +488,17 @@ func getPossibleAccessPaths(indexHints []*ast.IndexHint, tblInfo *model.TableInf } hasScanHint = true + + // It is syntactically valid to omit index_list for USE INDEX, which means “use no indexes”. + // Omitting index_list for FORCE INDEX or IGNORE INDEX is a syntax error. + // See https://dev.mysql.com/doc/refman/8.0/en/index-hints.html. + if hint.IndexNames == nil && hint.HintType != ast.HintIgnore { + if path := getTablePath(publicPaths); path != nil { + hasUseOrForce = true + path.forced = true + available = append(available, path) + } + } for _, idxName := range hint.IndexNames { path := getPathByIndexName(publicPaths, idxName, tblInfo) if path == nil { @@ -540,7 +566,7 @@ func (b *PlanBuilder) buildPrepare(x *ast.PrepareStmt) Plan { return p } -func (b *PlanBuilder) buildCheckIndex(dbName model.CIStr, as *ast.AdminStmt) (Plan, error) { +func (b *PlanBuilder) buildCheckIndex(ctx context.Context, dbName model.CIStr, as *ast.AdminStmt) (Plan, error) { tblName := as.Tables[0] tbl, err := b.is.TableByName(dbName, tblName.Name) if err != nil { @@ -563,56 +589,21 @@ func (b *PlanBuilder) buildCheckIndex(dbName model.CIStr, as *ast.AdminStmt) (Pl return nil, errors.Errorf("index %s state %s isn't public", as.Index, idx.State) } - id := 1 - columns := make([]*model.ColumnInfo, 0, len(idx.Columns)) - schema := expression.NewSchema(make([]*expression.Column, 0, len(idx.Columns))...) - for _, idxCol := range idx.Columns { - for _, col := range tblInfo.Columns { - if idxCol.Name.L == col.Name.L { - columns = append(columns, col) - schema.Append(&expression.Column{ - ColName: col.Name, - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: &col.FieldType, - }) - } - } - } - is := PhysicalIndexScan{ - Table: tblInfo, - TableAsName: &tblName.Name, - DBName: dbName, - Columns: columns, - Index: idx, - dataSourceSchema: schema, - Ranges: ranger.FullRange(), - KeepOrder: false, - }.Init(b.ctx) - is.stats = &property.StatsInfo{} - cop := &copTask{indexPlan: is} - // It's double read case. - ts := PhysicalTableScan{Columns: columns, Table: is.Table}.Init(b.ctx) - ts.SetSchema(is.dataSourceSchema) - cop.tablePlan = ts - is.initSchema(id, idx, true) - t := finishCopTask(b.ctx, cop) - - rootT := t.(*rootTask) - return rootT.p, nil + return b.buildPhysicalIndexLookUpReader(ctx, dbName, tbl, idx) } -func (b *PlanBuilder) buildAdmin(as *ast.AdminStmt) (Plan, error) { +func (b *PlanBuilder) buildAdmin(ctx context.Context, as *ast.AdminStmt) (Plan, error) { var ret Plan var err error switch as.Tp { case ast.AdminCheckTable: - ret, err = b.buildAdminCheckTable(as) + ret, err = b.buildAdminCheckTable(ctx, as) if err != nil { return ret, err } case ast.AdminCheckIndex: dbName := as.Tables[0].Schema - readerPlan, err := b.buildCheckIndex(dbName, as) + readerPlan, err := b.buildCheckIndex(ctx, dbName, as) if err != nil { return ret, err } @@ -667,6 +658,14 @@ func (b *PlanBuilder) buildAdmin(as *ast.AdminStmt) (Plan, error) { p := &ShowSlow{ShowSlow: as.ShowSlow} p.SetSchema(buildShowSlowSchema()) ret = p + case ast.AdminReloadExprPushdownBlacklist: + return &ReloadExprPushdownBlacklist{}, nil + case ast.AdminReloadOptRuleBlacklist: + return &ReloadOptRuleBlacklist{}, nil + case ast.AdminPluginEnable: + return &AdminPlugins{Action: Enable, Plugins: as.Plugins}, nil + case ast.AdminPluginDisable: + return &AdminPlugins{Action: Disable, Plugins: as.Plugins}, nil default: return nil, ErrUnsupportedType.GenWithStack("Unsupported ast.AdminStmt(%T) for buildAdmin", as) } @@ -676,43 +675,208 @@ func (b *PlanBuilder) buildAdmin(as *ast.AdminStmt) (Plan, error) { return ret, nil } -func (b *PlanBuilder) buildAdminCheckTable(as *ast.AdminStmt) (*CheckTable, error) { - p := &CheckTable{Tables: as.Tables} - p.GenExprs = make(map[model.TableColumnID]expression.Expression, len(p.Tables)) - +// getGenExprs gets generated expressions map. +func (b *PlanBuilder) getGenExprs(ctx context.Context, dbName model.CIStr, tbl table.Table, idx *model.IndexInfo) ( + map[model.TableColumnID]expression.Expression, error) { + tblInfo := tbl.Meta() + genExprsMap := make(map[model.TableColumnID]expression.Expression) + exprs := make([]expression.Expression, 0, len(tbl.Cols())) + genExprIdxs := make([]model.TableColumnID, len(tbl.Cols())) mockTablePlan := LogicalTableDual{}.Init(b.ctx) - for _, tbl := range p.Tables { - tableInfo := tbl.TableInfo - schema := expression.TableInfo2SchemaWithDBName(b.ctx, tbl.Schema, tableInfo) - table, ok := b.is.TableByID(tableInfo.ID) - if !ok { - return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(tbl.DBInfo.Name.O, tableInfo.Name.O) + mockTablePlan.SetSchema(expression.TableInfo2SchemaWithDBName(b.ctx, dbName, tblInfo)) + for i, colExpr := range mockTablePlan.Schema().Columns { + col := tbl.Cols()[i] + var expr expression.Expression + expr = colExpr + if col.IsGenerated() && !col.GeneratedStored { + var err error + expr, _, err = b.rewrite(ctx, col.GeneratedExpr, mockTablePlan, nil, true) + if err != nil { + return nil, errors.Trace(err) + } + expr = expression.BuildCastFunction(b.ctx, expr, colExpr.GetType()) + found := false + for _, column := range idx.Columns { + if strings.EqualFold(col.Name.L, column.Name.L) { + found = true + break + } + } + if found { + genColumnID := model.TableColumnID{TableID: tblInfo.ID, ColumnID: col.ColumnInfo.ID} + genExprsMap[genColumnID] = expr + genExprIdxs[i] = genColumnID + } } + exprs = append(exprs, expr) + } + // Re-iterate expressions to handle those virtual generated columns that refers to the other generated columns. + for i, expr := range exprs { + exprs[i] = expression.ColumnSubstitute(expr, mockTablePlan.Schema(), exprs) + if _, ok := genExprsMap[genExprIdxs[i]]; ok { + genExprsMap[genExprIdxs[i]] = exprs[i] + } + } + return genExprsMap, nil +} - mockTablePlan.SetSchema(schema) - - // Calculate generated columns. - columns := table.Cols() - for _, column := range columns { - if !column.IsGenerated() { - continue +func (b *PlanBuilder) buildPhysicalIndexLookUpReader(ctx context.Context, dbName model.CIStr, tbl table.Table, idx *model.IndexInfo) (Plan, error) { + // Get generated columns. + var genCols []*expression.Column + pkOffset := -1 + tblInfo := tbl.Meta() + colsMap := make(map[int64]struct{}) + schema := expression.NewSchema(make([]*expression.Column, 0, len(idx.Columns))...) + idxReaderCols := make([]*model.ColumnInfo, 0, len(idx.Columns)) + tblReaderCols := make([]*model.ColumnInfo, 0, len(tbl.Cols())) + genExprsMap, err := b.getGenExprs(ctx, dbName, tbl, idx) + if err != nil { + return nil, errors.Trace(err) + } + for _, idxCol := range idx.Columns { + for _, col := range tblInfo.Columns { + if idxCol.Name.L == col.Name.L { + idxReaderCols = append(idxReaderCols, col) + tblReaderCols = append(tblReaderCols, col) + schema.Append(&expression.Column{ + ColName: col.Name, + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: &col.FieldType}) + colsMap[col.ID] = struct{}{} + if mysql.HasPriKeyFlag(col.Flag) { + pkOffset = len(tblReaderCols) - 1 + } } - columnName := &ast.ColumnName{Name: column.Name} - columnName.SetText(column.Name.O) - - colExpr, _, err := mockTablePlan.findColumn(columnName) - if err != nil { - return nil, err + genColumnID := model.TableColumnID{TableID: tblInfo.ID, ColumnID: col.ID} + if expr, ok := genExprsMap[genColumnID]; ok { + cols := expression.ExtractColumns(expr) + genCols = append(genCols, cols...) } + } + } + // Add generated columns to tblSchema and tblReaderCols. + tblSchema := schema.Clone() + for _, col := range genCols { + if _, ok := colsMap[col.ID]; !ok { + c := table.FindCol(tbl.Cols(), col.ColName.O) + if c != nil { + col.Index = len(tblReaderCols) + tblReaderCols = append(tblReaderCols, c.ColumnInfo) + tblSchema.Append(&expression.Column{ + ColName: c.Name, + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: &c.FieldType}) + colsMap[c.ID] = struct{}{} + if mysql.HasPriKeyFlag(c.Flag) { + pkOffset = len(tblReaderCols) - 1 + } + } + } + } + if !tbl.Meta().PKIsHandle || pkOffset == -1 { + tblReaderCols = append(tblReaderCols, model.NewExtraHandleColInfo()) + handleCol := &expression.Column{ + DBName: dbName, + TblName: tblInfo.Name, + ColName: model.ExtraHandleName, + RetType: types.NewFieldType(mysql.TypeLonglong), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + ID: model.ExtraHandleID, + } + tblSchema.Append(handleCol) + pkOffset = len(tblReaderCols) - 1 + } - expr, _, err := b.rewrite(column.GeneratedExpr, mockTablePlan, nil, true) - if err != nil { - return nil, err + is := PhysicalIndexScan{ + Table: tblInfo, + TableAsName: &tblInfo.Name, + DBName: dbName, + Columns: idxReaderCols, + Index: idx, + dataSourceSchema: schema, + Ranges: ranger.FullRange(), + GenExprs: genExprsMap, + }.Init(b.ctx) + is.stats = property.NewSimpleStats(0) + // It's double read case. + ts := PhysicalTableScan{Columns: tblReaderCols, Table: is.Table}.Init(b.ctx) + ts.SetSchema(tblSchema) + if tbl.Meta().GetPartitionInfo() != nil { + pid := tbl.(table.PhysicalTable).GetPhysicalID() + is.physicalTableID = pid + is.isPartition = true + ts.physicalTableID = pid + ts.isPartition = true + } + cop := &copTask{ + indexPlan: is, + tablePlan: ts, + } + ts.HandleIdx = pkOffset + is.initSchema(idx, true) + rootT := finishCopTask(b.ctx, cop).(*rootTask) + return rootT.p, nil +} + +func (b *PlanBuilder) buildPhysicalIndexLookUpReaders(ctx context.Context, dbName model.CIStr, tbl table.Table) ([]Plan, []*model.IndexInfo, error) { + tblInfo := tbl.Meta() + // get index information + indexInfos := make([]*model.IndexInfo, 0, len(tblInfo.Indices)) + indexLookUpReaders := make([]Plan, 0, len(tblInfo.Indices)) + for _, idx := range tbl.Indices() { + idxInfo := idx.Meta() + if idxInfo.State != model.StatePublic { + logutil.Logger(context.Background()).Info("build physical index lookup reader, the index isn't public", + zap.String("index", idxInfo.Name.O), zap.Stringer("state", idxInfo.State), zap.String("table", tblInfo.Name.O)) + continue + } + indexInfos = append(indexInfos, idxInfo) + // For partition tables. + if pi := tbl.Meta().GetPartitionInfo(); pi != nil { + for _, def := range pi.Definitions { + t := tbl.(table.PartitionedTable).GetPartition(def.ID) + reader, err := b.buildPhysicalIndexLookUpReader(ctx, dbName, t, idxInfo) + if err != nil { + return nil, nil, err + } + indexLookUpReaders = append(indexLookUpReaders, reader) } - expr = expression.BuildCastFunction(b.ctx, expr, colExpr.GetType()) - p.GenExprs[model.TableColumnID{TableID: tableInfo.ID, ColumnID: column.ColumnInfo.ID}] = expr + continue + } + // For non-partition tables. + reader, err := b.buildPhysicalIndexLookUpReader(ctx, dbName, tbl, idxInfo) + if err != nil { + return nil, nil, err } + indexLookUpReaders = append(indexLookUpReaders, reader) + } + if len(indexLookUpReaders) == 0 { + return nil, nil, nil + } + return indexLookUpReaders, indexInfos, nil +} + +func (b *PlanBuilder) buildAdminCheckTable(ctx context.Context, as *ast.AdminStmt) (*CheckTable, error) { + tbl := as.Tables[0] + tableInfo := as.Tables[0].TableInfo + table, ok := b.is.TableByID(tableInfo.ID) + if !ok { + return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(tbl.DBInfo.Name.O, tableInfo.Name.O) + } + p := &CheckTable{ + DBName: tbl.Schema.O, + Table: table, + } + readerPlans, indexInfos, err := b.buildPhysicalIndexLookUpReaders(ctx, tbl.Schema, table) + if err != nil { + return nil, errors.Trace(err) } + readers := make([]*PhysicalIndexLookUpReader, 0, len(readerPlans)) + for _, plan := range readerPlans { + readers = append(readers, plan.(*PhysicalIndexLookUpReader)) + } + p.IndexInfos = indexInfos + p.IndexLookUpReaders = readers return p, nil } @@ -754,6 +918,10 @@ func (b *PlanBuilder) buildCheckIndexSchema(tn *ast.TableName, indexName string) func getColsInfo(tn *ast.TableName) (indicesInfo []*model.IndexInfo, colsInfo []*model.ColumnInfo, pkCol *model.ColumnInfo) { tbl := tn.TableInfo for _, col := range tbl.Columns { + // The virtual column will not store any data in TiKV, so it should be ignored when collect statistics + if col.IsGenerated() && !col.GeneratedStored { + continue + } if tbl.PKIsHandle && mysql.HasPriKeyFlag(col.Flag) { pkCol = col } else { @@ -976,6 +1144,29 @@ func buildShowDDLJobsFields() *expression.Schema { return schema } +func buildTableRegionsSchema() *expression.Schema { + schema := expression.NewSchema(make([]*expression.Column, 0, 11)...) + schema.Append(buildColumn("", "REGION_ID", mysql.TypeLonglong, 4)) + schema.Append(buildColumn("", "START_KEY", mysql.TypeVarchar, 64)) + schema.Append(buildColumn("", "END_KEY", mysql.TypeVarchar, 64)) + schema.Append(buildColumn("", "LEADER_ID", mysql.TypeLonglong, 4)) + schema.Append(buildColumn("", "LEADER_STORE_ID", mysql.TypeLonglong, 4)) + schema.Append(buildColumn("", "PEERS", mysql.TypeVarchar, 64)) + schema.Append(buildColumn("", "SCATTERING", mysql.TypeTiny, 1)) + schema.Append(buildColumn("", "WRITTEN_BYTES", mysql.TypeLonglong, 4)) + schema.Append(buildColumn("", "READ_BYTES", mysql.TypeLonglong, 4)) + schema.Append(buildColumn("", "APPROXIMATE_SIZE(MB)", mysql.TypeLonglong, 4)) + schema.Append(buildColumn("", "APPROXIMATE_KEYS", mysql.TypeLonglong, 4)) + return schema +} + +func buildSplitRegionsSchema() *expression.Schema { + schema := expression.NewSchema(make([]*expression.Column, 0, 2)...) + schema.Append(buildColumn("", "TOTAL_SPLIT_REGION", mysql.TypeLonglong, 4)) + schema.Append(buildColumn("", "SCATTER_FINISH_RATIO", mysql.TypeDouble, 8)) + return schema +} + func buildShowDDLJobQueriesFields() *expression.Schema { schema := expression.NewSchema(make([]*expression.Column, 0, 1)...) schema.Append(buildColumn("", "QUERY", mysql.TypeVarchar, 256)) @@ -1057,12 +1248,13 @@ func splitWhere(where ast.ExprNode) []ast.ExprNode { return conditions } -func (b *PlanBuilder) buildShow(show *ast.ShowStmt) (Plan, error) { +func (b *PlanBuilder) buildShow(ctx context.Context, show *ast.ShowStmt) (Plan, error) { p := Show{ Tp: show.Tp, DBName: show.DBName, Table: show.Table, Column: show.Column, + IndexName: show.IndexName, Flag: show.Flag, Full: show.Full, User: show.User, @@ -1079,6 +1271,8 @@ func (b *PlanBuilder) buildShow(show *ast.ShowStmt) (Plan, error) { p.SetSchema(buildShowEventsSchema()) case ast.ShowWarnings, ast.ShowErrors: p.SetSchema(buildShowWarningsSchema()) + case ast.ShowRegions: + p.SetSchema(buildTableRegionsSchema()) default: isView := false switch showTp { @@ -1105,56 +1299,63 @@ func (b *PlanBuilder) buildShow(show *ast.ShowStmt) (Plan, error) { for _, col := range p.schema.Columns { col.UniqueID = b.ctx.GetSessionVars().AllocPlanColumnID() } - mockTablePlan := LogicalTableDual{}.Init(b.ctx) + mockTablePlan := LogicalTableDual{placeHolder: true}.Init(b.ctx) mockTablePlan.SetSchema(p.schema) + var err error + var np LogicalPlan + np = mockTablePlan if show.Pattern != nil { show.Pattern.Expr = &ast.ColumnNameExpr{ Name: &ast.ColumnName{Name: p.Schema().Columns[0].ColName}, } - expr, _, err := b.rewrite(show.Pattern, mockTablePlan, nil, false) + np, err = b.buildSelection(ctx, np, show.Pattern, nil) if err != nil { return nil, err } - p.Conditions = append(p.Conditions, expr) } if show.Where != nil { - conds := splitWhere(show.Where) - for _, cond := range conds { - expr, _, err := b.rewrite(cond, mockTablePlan, nil, false) - if err != nil { - return nil, err - } - p.Conditions = append(p.Conditions, expr) + np, err = b.buildSelection(ctx, np, show.Where, nil) + if err != nil { + return nil, err + } + } + if np != mockTablePlan { + fieldsLen := len(mockTablePlan.schema.Columns) + proj := LogicalProjection{Exprs: make([]expression.Expression, 0, fieldsLen)}.Init(b.ctx) + schema := expression.NewSchema(make([]*expression.Column, 0, fieldsLen)...) + for _, col := range mockTablePlan.schema.Columns { + proj.Exprs = append(proj.Exprs, col) + newCol := col.Clone().(*expression.Column) + newCol.UniqueID = b.ctx.GetSessionVars().AllocPlanColumnID() + schema.Append(newCol) } - err := p.ResolveIndices() + proj.SetSchema(schema) + proj.SetChildren(np) + physical, err := DoOptimize(ctx, b.optFlag|flagEliminateProjection, proj) if err != nil { return nil, err } + return substitutePlaceHolderDual(physical, p), nil } return p, nil } +func substitutePlaceHolderDual(src PhysicalPlan, dst PhysicalPlan) PhysicalPlan { + if dual, ok := src.(*PhysicalTableDual); ok && dual.placeHolder { + return dst + } + for i, child := range src.Children() { + newChild := substitutePlaceHolderDual(child, dst) + src.SetChild(i, newChild) + } + return src +} + func (b *PlanBuilder) buildSimple(node ast.StmtNode) (Plan, error) { p := &Simple{Statement: node} switch raw := node.(type) { - case *ast.CreateUserStmt: - if raw.IsCreateRole { - err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE ROLE") - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateRolePriv, "", "", "", err) - } else { - err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err) - } - case *ast.DropUserStmt: - if raw.IsDropRole { - err := ErrSpecificAccessDenied.GenWithStackByArgs("DROP ROLE") - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DropRolePriv, "", "", "", err) - } else { - err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err) - } - case *ast.AlterUserStmt, *ast.SetDefaultRoleStmt: + case *ast.AlterUserStmt: err := ErrSpecificAccessDenied.GenWithStackByArgs("CREATE USER") b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateUserPriv, "", "", "", err) case *ast.GrantStmt: @@ -1238,7 +1439,7 @@ func (b *PlanBuilder) findDefaultValue(cols []*table.Column, name *ast.ColumnNam // resolveGeneratedColumns resolves generated columns with their generation // expressions respectively. onDups indicates which columns are in on-duplicate list. -func (b *PlanBuilder) resolveGeneratedColumns(columns []*table.Column, onDups map[string]struct{}, mockPlan LogicalPlan) (igc InsertGeneratedColumns, err error) { +func (b *PlanBuilder) resolveGeneratedColumns(ctx context.Context, columns []*table.Column, onDups map[string]struct{}, mockPlan LogicalPlan) (igc InsertGeneratedColumns, err error) { for _, column := range columns { if !column.IsGenerated() { continue @@ -1251,7 +1452,7 @@ func (b *PlanBuilder) resolveGeneratedColumns(columns []*table.Column, onDups ma return igc, err } - expr, _, err := b.rewrite(column.GeneratedExpr, mockPlan, nil, true) + expr, _, err := b.rewrite(ctx, column.GeneratedExpr, mockPlan, nil, true) if err != nil { return igc, err } @@ -1273,7 +1474,7 @@ func (b *PlanBuilder) resolveGeneratedColumns(columns []*table.Column, onDups ma return igc, nil } -func (b *PlanBuilder) buildInsert(insert *ast.InsertStmt) (Plan, error) { +func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) (Plan, error) { ts, ok := insert.Table.TableRefs.Left.(*ast.TableSource) if !ok { return nil, infoschema.ErrTableNotExists.GenWithStackByArgs() @@ -1329,19 +1530,19 @@ func (b *PlanBuilder) buildInsert(insert *ast.InsertStmt) (Plan, error) { if len(insert.Setlist) > 0 { // Branch for `INSERT ... SET ...`. - err := b.buildSetValuesOfInsert(insert, insertPlan, mockTablePlan, checkRefColumn) + err := b.buildSetValuesOfInsert(ctx, insert, insertPlan, mockTablePlan, checkRefColumn) if err != nil { return nil, err } } else if len(insert.Lists) > 0 { // Branch for `INSERT ... VALUES ...`. - err := b.buildValuesListOfInsert(insert, insertPlan, mockTablePlan, checkRefColumn) + err := b.buildValuesListOfInsert(ctx, insert, insertPlan, mockTablePlan, checkRefColumn) if err != nil { return nil, err } } else { // Branch for `INSERT ... SELECT ...`. - err := b.buildSelectPlanOfInsert(insert, insertPlan) + err := b.buildSelectPlanOfInsert(ctx, insert, insertPlan) if err != nil { return nil, err } @@ -1358,7 +1559,7 @@ func (b *PlanBuilder) buildInsert(insert *ast.InsertStmt) (Plan, error) { } for i, assign := range insert.OnDuplicate { // Construct the function which calculates the assign value of the column. - expr, err1 := b.rewriteInsertOnDuplicateUpdate(assign.Expr, mockTablePlan, insertPlan) + expr, err1 := b.rewriteInsertOnDuplicateUpdate(ctx, assign.Expr, mockTablePlan, insertPlan) if err1 != nil { return nil, err1 } @@ -1371,7 +1572,7 @@ func (b *PlanBuilder) buildInsert(insert *ast.InsertStmt) (Plan, error) { // Calculate generated columns. mockTablePlan.schema = insertPlan.tableSchema - insertPlan.GenCols, err = b.resolveGeneratedColumns(insertPlan.Table.Cols(), onDupColSet, mockTablePlan) + insertPlan.GenCols, err = b.resolveGeneratedColumns(ctx, insertPlan.Table.Cols(), onDupColSet, mockTablePlan) if err != nil { return nil, err } @@ -1426,7 +1627,7 @@ func (b *PlanBuilder) getAffectCols(insertStmt *ast.InsertStmt, insertPlan *Inse return affectedValuesCols, nil } -func (b *PlanBuilder) buildSetValuesOfInsert(insert *ast.InsertStmt, insertPlan *Insert, mockTablePlan *LogicalTableDual, checkRefColumn func(n ast.Node) ast.Node) error { +func (b *PlanBuilder) buildSetValuesOfInsert(ctx context.Context, insert *ast.InsertStmt, insertPlan *Insert, mockTablePlan *LogicalTableDual, checkRefColumn func(n ast.Node) ast.Node) error { tableInfo := insertPlan.Table.Meta() colNames := make([]string, 0, len(insert.Setlist)) exprCols := make([]*expression.Column, 0, len(insert.Setlist)) @@ -1454,7 +1655,7 @@ func (b *PlanBuilder) buildSetValuesOfInsert(insert *ast.InsertStmt, insertPlan } for i, assign := range insert.Setlist { - expr, _, err := b.rewriteWithPreprocess(assign.Expr, mockTablePlan, nil, nil, true, checkRefColumn) + expr, _, err := b.rewriteWithPreprocess(ctx, assign.Expr, mockTablePlan, nil, nil, true, checkRefColumn) if err != nil { return err } @@ -1467,7 +1668,7 @@ func (b *PlanBuilder) buildSetValuesOfInsert(insert *ast.InsertStmt, insertPlan return nil } -func (b *PlanBuilder) buildValuesListOfInsert(insert *ast.InsertStmt, insertPlan *Insert, mockTablePlan *LogicalTableDual, checkRefColumn func(n ast.Node) ast.Node) error { +func (b *PlanBuilder) buildValuesListOfInsert(ctx context.Context, insert *ast.InsertStmt, insertPlan *Insert, mockTablePlan *LogicalTableDual, checkRefColumn func(n ast.Node) ast.Node) error { affectedValuesCols, err := b.getAffectCols(insert, insertPlan) if err != nil { return err @@ -1515,7 +1716,7 @@ func (b *PlanBuilder) buildValuesListOfInsert(insert *ast.InsertStmt, insertPlan RetType: &x.Type, } default: - expr, _, err = b.rewriteWithPreprocess(valueItem, mockTablePlan, nil, nil, true, checkRefColumn) + expr, _, err = b.rewriteWithPreprocess(ctx, valueItem, mockTablePlan, nil, nil, true, checkRefColumn) } if err != nil { return err @@ -1528,12 +1729,12 @@ func (b *PlanBuilder) buildValuesListOfInsert(insert *ast.InsertStmt, insertPlan return nil } -func (b *PlanBuilder) buildSelectPlanOfInsert(insert *ast.InsertStmt, insertPlan *Insert) error { +func (b *PlanBuilder) buildSelectPlanOfInsert(ctx context.Context, insert *ast.InsertStmt, insertPlan *Insert) error { affectedValuesCols, err := b.getAffectCols(insert, insertPlan) if err != nil { return err } - selectPlan, err := b.Build(insert.Select) + selectPlan, err := b.Build(ctx, insert.Select) if err != nil { return err } @@ -1556,7 +1757,7 @@ func (b *PlanBuilder) buildSelectPlanOfInsert(insert *ast.InsertStmt, insertPlan } } - insertPlan.SelectPlan, err = DoOptimize(b.optFlag, selectPlan.(LogicalPlan)) + insertPlan.SelectPlan, err = DoOptimize(ctx, b.optFlag, selectPlan.(LogicalPlan)) if err != nil { return err } @@ -1581,7 +1782,7 @@ func (b *PlanBuilder) buildSelectPlanOfInsert(insert *ast.InsertStmt, insertPlan return nil } -func (b *PlanBuilder) buildLoadData(ld *ast.LoadDataStmt) (Plan, error) { +func (b *PlanBuilder) buildLoadData(ctx context.Context, ld *ast.LoadDataStmt) (Plan, error) { p := &LoadData{ IsLocal: ld.IsLocal, OnDuplicate: ld.OnDuplicate, @@ -1603,7 +1804,7 @@ func (b *PlanBuilder) buildLoadData(ld *ast.LoadDataStmt) (Plan, error) { mockTablePlan.SetSchema(schema) var err error - p.GenCols, err = b.resolveGeneratedColumns(tableInPlan.Cols(), nil, mockTablePlan) + p.GenCols, err = b.resolveGeneratedColumns(ctx, tableInPlan.Cols(), nil, mockTablePlan) if err != nil { return nil, err } @@ -1615,47 +1816,188 @@ func (b *PlanBuilder) buildLoadStats(ld *ast.LoadStatsStmt) Plan { return p } -func (b *PlanBuilder) buildSplitIndexRegion(node *ast.SplitIndexRegionStmt) (Plan, error) { +func (b *PlanBuilder) buildSplitRegion(node *ast.SplitRegionStmt) (Plan, error) { + if len(node.IndexName.L) != 0 { + return b.buildSplitIndexRegion(node) + } + return b.buildSplitTableRegion(node) +} + +func (b *PlanBuilder) buildSplitIndexRegion(node *ast.SplitRegionStmt) (Plan, error) { tblInfo := node.Table.TableInfo - indexInfo := tblInfo.FindIndexByName(strings.ToLower(node.IndexName)) + indexInfo := tblInfo.FindIndexByName(node.IndexName.L) if indexInfo == nil { return nil, ErrKeyDoesNotExist.GenWithStackByArgs(node.IndexName, tblInfo.Name) } + mockTablePlan := LogicalTableDual{}.Init(b.ctx) + schema := expression.TableInfo2SchemaWithDBName(b.ctx, node.Table.Schema, tblInfo) + mockTablePlan.SetSchema(schema) - indexValues := make([][]types.Datum, 0, len(node.ValueLists)) - for i, valuesItem := range node.ValueLists { + p := &SplitRegion{ + TableInfo: tblInfo, + IndexInfo: indexInfo, + } + p.SetSchema(buildSplitRegionsSchema()) + // Split index regions by user specified value lists. + if len(node.SplitOpt.ValueLists) > 0 { + indexValues := make([][]types.Datum, 0, len(node.SplitOpt.ValueLists)) + for i, valuesItem := range node.SplitOpt.ValueLists { + if len(valuesItem) > len(indexInfo.Columns) { + return nil, ErrWrongValueCountOnRow.GenWithStackByArgs(i + 1) + } + values, err := b.convertValue2ColumnType(valuesItem, mockTablePlan, indexInfo, tblInfo) + if err != nil { + return nil, err + } + indexValues = append(indexValues, values) + } + p.ValueLists = indexValues + return p, nil + } + + // Split index regions by lower, upper value. + checkLowerUpperValue := func(valuesItem []ast.ExprNode, name string) ([]types.Datum, error) { + if len(valuesItem) == 0 { + return nil, errors.Errorf("Split index `%v` region %s value count should more than 0", indexInfo.Name, name) + } if len(valuesItem) > len(indexInfo.Columns) { - return nil, ErrWrongValueCountOnRow.GenWithStackByArgs(i + 1) + return nil, errors.Errorf("Split index `%v` region column count doesn't match value count at %v", indexInfo.Name, name) } - valueList := make([]types.Datum, 0, len(valuesItem)) - for j, valueItem := range valuesItem { - x, ok := valueItem.(*driver.ValueExpr) - if !ok { - return nil, errors.New("expect constant values") + return b.convertValue2ColumnType(valuesItem, mockTablePlan, indexInfo, tblInfo) + } + lowerValues, err := checkLowerUpperValue(node.SplitOpt.Lower, "lower") + if err != nil { + return nil, err + } + upperValues, err := checkLowerUpperValue(node.SplitOpt.Upper, "upper") + if err != nil { + return nil, err + } + p.Lower = lowerValues + p.Upper = upperValues + + maxSplitRegionNum := int64(config.GetGlobalConfig().SplitRegionMaxNum) + if node.SplitOpt.Num > maxSplitRegionNum { + return nil, errors.Errorf("Split index region num exceeded the limit %v", maxSplitRegionNum) + } else if node.SplitOpt.Num < 1 { + return nil, errors.Errorf("Split index region num should more than 0") + } + p.Num = int(node.SplitOpt.Num) + return p, nil +} + +func (b *PlanBuilder) convertValue2ColumnType(valuesItem []ast.ExprNode, mockTablePlan LogicalPlan, indexInfo *model.IndexInfo, tblInfo *model.TableInfo) ([]types.Datum, error) { + values := make([]types.Datum, 0, len(valuesItem)) + for j, valueItem := range valuesItem { + colOffset := indexInfo.Columns[j].Offset + value, err := b.convertValue(valueItem, mockTablePlan, tblInfo.Columns[colOffset]) + if err != nil { + return nil, err + } + values = append(values, value) + } + return values, nil +} + +func (b *PlanBuilder) convertValue(valueItem ast.ExprNode, mockTablePlan LogicalPlan, col *model.ColumnInfo) (d types.Datum, err error) { + var expr expression.Expression + switch x := valueItem.(type) { + case *driver.ValueExpr: + expr = &expression.Constant{ + Value: x.Datum, + RetType: &x.Type, + } + default: + expr, _, err = b.rewrite(context.TODO(), valueItem, mockTablePlan, nil, true) + if err != nil { + return d, err + } + } + constant, ok := expr.(*expression.Constant) + if !ok { + return d, errors.New("Expect constant values") + } + value, err := constant.Eval(chunk.Row{}) + if err != nil { + return d, err + } + d, err = value.ConvertTo(b.ctx.GetSessionVars().StmtCtx, &col.FieldType) + if err != nil { + if !types.ErrTruncated.Equal(err) { + return d, err + } + valStr, err1 := value.ToString() + if err1 != nil { + return d, err + } + return d, types.ErrTruncated.GenWithStack("Incorrect value: '%-.128s' for column '%.192s'", valStr, col.Name.O) + } + return d, nil +} + +func (b *PlanBuilder) buildSplitTableRegion(node *ast.SplitRegionStmt) (Plan, error) { + tblInfo := node.Table.TableInfo + var pkCol *model.ColumnInfo + if tblInfo.PKIsHandle { + if col := tblInfo.GetPkColInfo(); col != nil { + pkCol = col + } + } + if pkCol == nil { + pkCol = model.NewExtraHandleColInfo() + } + mockTablePlan := LogicalTableDual{}.Init(b.ctx) + schema := expression.TableInfo2SchemaWithDBName(b.ctx, node.Table.Schema, tblInfo) + mockTablePlan.SetSchema(schema) + + p := &SplitRegion{ + TableInfo: tblInfo, + } + p.SetSchema(buildSplitRegionsSchema()) + if len(node.SplitOpt.ValueLists) > 0 { + values := make([][]types.Datum, 0, len(node.SplitOpt.ValueLists)) + for i, valuesItem := range node.SplitOpt.ValueLists { + if len(valuesItem) > 1 { + return nil, ErrWrongValueCountOnRow.GenWithStackByArgs(i + 1) } - colOffset := indexInfo.Columns[j].Offset - value, err := x.Datum.ConvertTo(b.ctx.GetSessionVars().StmtCtx, &tblInfo.Columns[colOffset].FieldType) + value, err := b.convertValue(valuesItem[0], mockTablePlan, pkCol) if err != nil { return nil, err } + values = append(values, []types.Datum{value}) + } + p.ValueLists = values + return p, nil + } - valueList = append(valueList, value) + checkLowerUpperValue := func(valuesItem []ast.ExprNode, name string) (types.Datum, error) { + if len(valuesItem) != 1 { + return types.Datum{}, errors.Errorf("Split table region %s value count should be 1", name) } - indexValues = append(indexValues, valueList) + return b.convertValue(valuesItem[0], mockTablePlan, pkCol) } - tableInPlan, ok := b.is.TableByID(tblInfo.ID) - if !ok { - return nil, errors.Errorf("Can't get table %s.", tblInfo.Name.O) + lowerValues, err := checkLowerUpperValue(node.SplitOpt.Lower, "lower") + if err != nil { + return nil, err + } + upperValue, err := checkLowerUpperValue(node.SplitOpt.Upper, "upper") + if err != nil { + return nil, err } - return &SplitIndexRegion{ - Table: tableInPlan, - IndexInfo: indexInfo, - ValueLists: indexValues, - }, nil + p.Lower = []types.Datum{lowerValues} + p.Upper = []types.Datum{upperValue} + maxSplitRegionNum := int64(config.GetGlobalConfig().SplitRegionMaxNum) + if node.SplitOpt.Num > maxSplitRegionNum { + return nil, errors.Errorf("Split table region num exceeded the limit %v", maxSplitRegionNum) + } else if node.SplitOpt.Num < 1 { + return nil, errors.Errorf("Split table region num should more than 0") + } + p.Num = int(node.SplitOpt.Num) + return p, nil } -func (b *PlanBuilder) buildDDL(node ast.DDLNode) (Plan, error) { +func (b *PlanBuilder) buildDDL(ctx context.Context, node ast.DDLNode) (Plan, error) { var authErr error switch v := node.(type) { case *ast.AlterDatabaseStmt: @@ -1738,20 +2080,24 @@ func (b *PlanBuilder) buildDDL(node ast.DDLNode) (Plan, error) { v.ReferTable.Name.L, "", authErr) } case *ast.CreateViewStmt: - plan, err := b.Build(v.Select) + b.capFlag |= canExpandAST + defer func() { + b.capFlag &= ^canExpandAST + }() + plan, err := b.Build(ctx, v.Select) if err != nil { return nil, err } schema := plan.Schema() - if v.Cols != nil && len(v.Cols) != schema.Len() { - return nil, ddl.ErrViewWrongList + if v.Cols == nil { + v.Cols = make([]model.CIStr, len(schema.Columns)) + for i, col := range schema.Columns { + v.Cols[i] = col.ColName + } } - // we use fieldList to store schema.Columns temporary - var fieldList = make([]*ast.SelectField, schema.Len()) - for i, col := range schema.Columns { - fieldList[i] = &ast.SelectField{AsName: col.ColName} + if len(v.Cols) != schema.Len() { + return nil, ddl.ErrViewWrongList } - v.Select.(*ast.SelectStmt).Fields.Fields = fieldList if _, ok := plan.(LogicalPlan); ok { if b.ctx.GetSessionVars().User != nil { authErr = ErrTableaccessDenied.GenWithStackByArgs("CREATE VIEW", b.ctx.GetSessionVars().User.Hostname, @@ -1760,7 +2106,7 @@ func (b *PlanBuilder) buildDDL(node ast.DDLNode) (Plan, error) { b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreateViewPriv, v.ViewName.Schema.L, v.ViewName.Name.L, "", authErr) } - if v.Definer.CurrentUser { + if v.Definer.CurrentUser && b.ctx.GetSessionVars().User != nil { v.Definer = b.ctx.GetSessionVars().User } if b.ctx.GetSessionVars().User != nil && v.Definer.String() != b.ctx.GetSessionVars().User.String() { @@ -1913,11 +2259,11 @@ func (b *PlanBuilder) buildExplainFor(explainFor *ast.ExplainForStmt) (Plan, err return b.buildExplainPlan(targetPlan, explainFor.Format, false, nil) } -func (b *PlanBuilder) buildExplain(explain *ast.ExplainStmt) (Plan, error) { +func (b *PlanBuilder) buildExplain(ctx context.Context, explain *ast.ExplainStmt) (Plan, error) { if show, ok := explain.Stmt.(*ast.ShowStmt); ok { - return b.buildShow(show) + return b.buildShow(ctx, show) } - targetPlan, err := OptimizeAstNode(b.ctx, explain.Stmt, b.is) + targetPlan, err := OptimizeAstNode(ctx, b.ctx, explain.Stmt, b.is) if err != nil { return nil, err } diff --git a/planner/core/planbuilder_test.go b/planner/core/planbuilder_test.go index 3c1487d1a3637..eca40338aa44a 100644 --- a/planner/core/planbuilder_test.go +++ b/planner/core/planbuilder_test.go @@ -14,6 +14,8 @@ package core import ( + "context" + . "github.com/pingcap/check" "github.com/pingcap/parser" "github.com/pingcap/parser/ast" @@ -101,7 +103,7 @@ func (s *testPlanBuilderSuite) TestRewriterPool(c *C) { // Make sure PlanBuilder.getExpressionRewriter() provides clean rewriter from pool. // First, pick one rewriter from the pool and make it dirty. builder.rewriterCounter++ - dirtyRewriter := builder.getExpressionRewriter(nil) + dirtyRewriter := builder.getExpressionRewriter(context.TODO(), nil) dirtyRewriter.asScalar = true dirtyRewriter.aggrMap = make(map[*ast.AggregateFuncExpr]int) dirtyRewriter.preprocess = func(ast.Node) ast.Node { return nil } @@ -111,7 +113,7 @@ func (s *testPlanBuilderSuite) TestRewriterPool(c *C) { builder.rewriterCounter-- // Then, pick again and check if it's cleaned up. builder.rewriterCounter++ - cleanRewriter := builder.getExpressionRewriter(nil) + cleanRewriter := builder.getExpressionRewriter(context.TODO(), nil) c.Assert(cleanRewriter, Equals, dirtyRewriter) // Rewriter should be reused. c.Assert(cleanRewriter.asScalar, Equals, false) c.Assert(cleanRewriter.aggrMap, IsNil) @@ -151,7 +153,7 @@ func (s *testPlanBuilderSuite) TestDisableFold(c *C) { builder := &PlanBuilder{ctx: ctx} builder.rewriterCounter++ - rewriter := builder.getExpressionRewriter(nil) + rewriter := builder.getExpressionRewriter(context.TODO(), nil) c.Assert(rewriter, NotNil) c.Assert(rewriter.disableFoldCounter, Equals, 0) rewritenExpression, _, err := builder.rewriteExprNode(rewriter, expr, true) diff --git a/planner/core/point_get_plan.go b/planner/core/point_get_plan.go index 0be899e4fb87e..1323171805f89 100644 --- a/planner/core/point_get_plan.go +++ b/planner/core/point_get_plan.go @@ -22,12 +22,15 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/opcode" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" + "github.com/pingcap/tidb/util/plancodec" "github.com/pingcap/tipb/go-tipb" ) @@ -36,6 +39,7 @@ import ( // This plan is much faster to build and to execute because it avoid the optimization and coprocessor cost. type PointGetPlan struct { basePlan + dbName string schema *expression.Schema TblInfo *model.TableInfo IndexInfo *model.IndexInfo @@ -46,6 +50,9 @@ type PointGetPlan struct { IndexValueParams []*driver.ParamMarkerExpr expr expression.Expression ctx sessionctx.Context + IsTableDual bool + Lock bool + IsForUpdate bool } type nameValuePair struct { @@ -84,7 +91,14 @@ func (p *PointGetPlan) ExplainInfo() string { } } } else { - fmt.Fprintf(buffer, ", handle:%d", p.Handle) + if p.UnsignedHandle { + fmt.Fprintf(buffer, ", handle:%d", uint64(p.Handle)) + } else { + fmt.Fprintf(buffer, ", handle:%d", p.Handle) + } + } + if p.Lock { + fmt.Fprintf(buffer, ", lock") } return buffer.String() } @@ -116,6 +130,9 @@ func (p *PointGetPlan) Children() []PhysicalPlan { // SetChildren sets the children for the plan. func (p *PointGetPlan) SetChildren(...PhysicalPlan) {} +// SetChild sets a specific child for the plan. +func (p *PointGetPlan) SetChild(i int, child PhysicalPlan) {} + // ResolveIndices resolves the indices for columns. After doing this, the columns can evaluate the rows by their indices. func (p *PointGetPlan) ResolveIndices() error { return nil @@ -130,6 +147,22 @@ func TryFastPlan(ctx sessionctx.Context, node ast.Node) Plan { if checkFastPlanPrivilege(ctx, fp, mysql.SelectPriv) != nil { return nil } + if fp.IsTableDual { + tableDual := PhysicalTableDual{} + tableDual.SetSchema(fp.Schema()) + return tableDual.Init(ctx, &property.StatsInfo{}) + } + if x.LockTp == ast.SelectLockForUpdate { + // Locking of rows for update using SELECT FOR UPDATE only applies when autocommit + // is disabled (either by beginning transaction with START TRANSACTION or by setting + // autocommit to 0. If autocommit is enabled, the rows matching the specification are not locked. + // See https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html + sessVars := ctx.GetSessionVars() + if !sessVars.IsAutocommit() || sessVars.InTxn() { + fp.Lock = true + fp.IsForUpdate = true + } + } return fp } case *ast.UpdateStmt: @@ -148,7 +181,7 @@ func TryFastPlan(ctx sessionctx.Context, node ast.Node) Plan { // 3. All the columns must be public and generated. // 4. The condition is an access path that the range is a unique key. func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetPlan { - if selStmt.Having != nil || selStmt.LockTp != ast.SelectLockNone { + if selStmt.Having != nil { return nil } else if selStmt.Limit != nil { count, offset, err := extractLimitCountOffset(ctx, selStmt.Limit) @@ -156,7 +189,7 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP return nil } } - tblName := getSingleTableName(selStmt.From) + tblName, tblAlias := getSingleTableNameAndAlias(selStmt.From) if tblName == nil { return nil } @@ -182,23 +215,41 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP } } pairs := make([]nameValuePair, 0, 4) - pairs = getNameValuePairs(pairs, selStmt.Where) + pairs = getNameValuePairs(pairs, tblAlias, selStmt.Where) if pairs == nil { return nil } - handlePair, unsigned := findPKHandle(tbl, pairs) + handlePair, fieldType := findPKHandle(tbl, pairs) if handlePair.value.Kind() != types.KindNull && len(pairs) == 1 { - schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, selStmt.Fields.Fields) + schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields) if schema == nil { return nil } - p := newPointGetPlan(ctx, schema, tbl) - var err error - p.Handle, err = handlePair.value.ToInt64(ctx.GetSessionVars().StmtCtx) + dbName := tblName.Schema.L + if dbName == "" { + dbName = ctx.GetSessionVars().CurrentDB + } + p := newPointGetPlan(ctx, dbName, schema, tbl) + intDatum, err := handlePair.value.ConvertTo(ctx.GetSessionVars().StmtCtx, fieldType) + if err != nil { + if terror.ErrorEqual(types.ErrOverflow, err) { + p.IsTableDual = true + return p + } + // some scenarios cast to int with error, but we may use this value in point get + if !terror.ErrorEqual(types.ErrTruncatedWrongVal, err) { + return nil + } + } + cmp, err := intDatum.CompareDatum(ctx.GetSessionVars().StmtCtx, &handlePair.value) if err != nil { return nil + } else if cmp != 0 { + p.IsTableDual = true + return p } - p.UnsignedHandle = unsigned + p.Handle = intDatum.GetInt64() + p.UnsignedHandle = mysql.HasUnsignedFlag(fieldType.Flag) p.HandleParam = handlePair.param return p } @@ -214,11 +265,15 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP if idxValues == nil { continue } - schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, selStmt.Fields.Fields) + schema := buildSchemaFromFields(ctx, tblName.Schema, tbl, tblAlias, selStmt.Fields.Fields) if schema == nil { return nil } - p := newPointGetPlan(ctx, schema, tbl) + dbName := tblName.Schema.L + if dbName == "" { + dbName = ctx.GetSessionVars().CurrentDB + } + p := newPointGetPlan(ctx, dbName, schema, tbl) p.IndexInfo = idxInfo p.IndexValues = idxValues p.IndexValueParams = idxValueParams @@ -227,12 +282,14 @@ func tryPointGetPlan(ctx sessionctx.Context, selStmt *ast.SelectStmt) *PointGetP return nil } -func newPointGetPlan(ctx sessionctx.Context, schema *expression.Schema, tbl *model.TableInfo) *PointGetPlan { +func newPointGetPlan(ctx sessionctx.Context, dbName string, schema *expression.Schema, tbl *model.TableInfo) *PointGetPlan { p := &PointGetPlan{ - basePlan: newBasePlan(ctx, "Point_Get"), + basePlan: newBasePlan(ctx, plancodec.TypePointGet), + dbName: dbName, schema: schema, TblInfo: tbl, } + ctx.GetSessionVars().StmtCtx.Tables = []stmtctx.TableEntry{{DB: ctx.GetSessionVars().CurrentDB, Table: tbl.Name.L}} return p } @@ -241,32 +298,37 @@ func checkFastPlanPrivilege(ctx sessionctx.Context, fastPlan *PointGetPlan, chec if pm == nil { return nil } - dbName := ctx.GetSessionVars().CurrentDB for _, checkType := range checkTypes { - if !pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, dbName, fastPlan.TblInfo.Name.L, "", checkType) { + if !pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, fastPlan.dbName, fastPlan.TblInfo.Name.L, "", checkType) { return errors.New("privilege check fail") } } return nil } -func buildSchemaFromFields(ctx sessionctx.Context, dbName model.CIStr, tbl *model.TableInfo, fields []*ast.SelectField) *expression.Schema { +func buildSchemaFromFields(ctx sessionctx.Context, dbName model.CIStr, tbl *model.TableInfo, tblName model.CIStr, fields []*ast.SelectField) *expression.Schema { if dbName.L == "" { dbName = model.NewCIStr(ctx.GetSessionVars().CurrentDB) } columns := make([]*expression.Column, 0, len(tbl.Columns)+1) - if len(fields) == 1 && fields[0].WildCard != nil { - for _, col := range tbl.Columns { - columns = append(columns, colInfoToColumn(dbName, tbl.Name, col.Name, col, len(columns))) - } - return expression.NewSchema(columns...) - } if len(fields) > 0 { for _, field := range fields { + if field.WildCard != nil { + if field.WildCard.Table.L != "" && field.WildCard.Table.L != tblName.L { + return nil + } + for _, col := range tbl.Columns { + columns = append(columns, colInfoToColumn(dbName, tbl.Name, tblName, col.Name, col, len(columns))) + } + continue + } colNameExpr, ok := field.Expr.(*ast.ColumnNameExpr) if !ok { return nil } + if colNameExpr.Name.Table.L != "" && colNameExpr.Name.Table.L != tblName.L { + return nil + } col := findCol(tbl, colNameExpr.Name) if col == nil { return nil @@ -275,21 +337,21 @@ func buildSchemaFromFields(ctx sessionctx.Context, dbName model.CIStr, tbl *mode if field.AsName.L != "" { asName = field.AsName } - columns = append(columns, colInfoToColumn(dbName, tbl.Name, asName, col, len(columns))) + columns = append(columns, colInfoToColumn(dbName, tbl.Name, tblName, asName, col, len(columns))) } return expression.NewSchema(columns...) } // fields len is 0 for update and delete. var handleCol *expression.Column for _, col := range tbl.Columns { - column := colInfoToColumn(dbName, tbl.Name, col.Name, col, len(columns)) + column := colInfoToColumn(dbName, tbl.Name, tblName, col.Name, col, len(columns)) if tbl.PKIsHandle && mysql.HasPriKeyFlag(col.Flag) { handleCol = column } columns = append(columns, column) } if handleCol == nil { - handleCol = colInfoToColumn(dbName, tbl.Name, model.ExtraHandleName, model.NewExtraHandleColInfo(), len(columns)) + handleCol = colInfoToColumn(dbName, tbl.Name, tblName, model.ExtraHandleName, model.NewExtraHandleColInfo(), len(columns)) columns = append(columns, handleCol) } schema := expression.NewSchema(columns...) @@ -298,36 +360,40 @@ func buildSchemaFromFields(ctx sessionctx.Context, dbName model.CIStr, tbl *mode return schema } -func getSingleTableName(tableRefs *ast.TableRefsClause) *ast.TableName { +// getSingleTableNameAndAlias return the ast node of queried table name and the alias string. +// `tblName` is `nil` if there are multiple tables in the query. +// `tblAlias` will be the real table name if there is no table alias in the query. +func getSingleTableNameAndAlias(tableRefs *ast.TableRefsClause) (tblName *ast.TableName, tblAlias model.CIStr) { if tableRefs == nil || tableRefs.TableRefs == nil || tableRefs.TableRefs.Right != nil { - return nil + return nil, tblAlias } tblSrc, ok := tableRefs.TableRefs.Left.(*ast.TableSource) if !ok { - return nil - } - if tblSrc.AsName.L != "" { - return nil + return nil, tblAlias } - tblName, ok := tblSrc.Source.(*ast.TableName) + tblName, ok = tblSrc.Source.(*ast.TableName) if !ok { - return nil + return nil, tblAlias } - return tblName + tblAlias = tblSrc.AsName + if tblSrc.AsName.L == "" { + tblAlias = tblName.Name + } + return tblName, tblAlias } // getNameValuePairs extracts `column = constant/paramMarker` conditions from expr as name value pairs. -func getNameValuePairs(nvPairs []nameValuePair, expr ast.ExprNode) []nameValuePair { +func getNameValuePairs(nvPairs []nameValuePair, tblName model.CIStr, expr ast.ExprNode) []nameValuePair { binOp, ok := expr.(*ast.BinaryOperationExpr) if !ok { return nil } if binOp.Op == opcode.LogicAnd { - nvPairs = getNameValuePairs(nvPairs, binOp.L) + nvPairs = getNameValuePairs(nvPairs, tblName, binOp.L) if nvPairs == nil { return nil } - nvPairs = getNameValuePairs(nvPairs, binOp.R) + nvPairs = getNameValuePairs(nvPairs, tblName, binOp.R) if nvPairs == nil { return nil } @@ -359,25 +425,28 @@ func getNameValuePairs(nvPairs []nameValuePair, expr ast.ExprNode) []nameValuePa if d.IsNull() { return nil } + if colName.Name.Table.L != "" && colName.Name.Table.L != tblName.L { + return nil + } return append(nvPairs, nameValuePair{colName: colName.Name.Name.L, value: d, param: param}) } return nil } -func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair, unsigned bool) { +func findPKHandle(tblInfo *model.TableInfo, pairs []nameValuePair) (handlePair nameValuePair, fieldType *types.FieldType) { if !tblInfo.PKIsHandle { - return handlePair, unsigned + return handlePair, nil } for _, col := range tblInfo.Columns { if mysql.HasPriKeyFlag(col.Flag) { i := findInPairs(col.Name.L, pairs) if i == -1 { - return handlePair, unsigned + return handlePair, nil } - return pairs[i], mysql.HasUnsignedFlag(col.Flag) + return pairs[i], &col.FieldType } } - return handlePair, unsigned + return handlePair, nil } func getIndexValues(idxInfo *model.IndexInfo, pairs []nameValuePair) ([]types.Datum, []*driver.ParamMarkerExpr) { @@ -427,6 +496,12 @@ func tryUpdatePointPlan(ctx sessionctx.Context, updateStmt *ast.UpdateStmt) Plan if checkFastPlanPrivilege(ctx, fastSelect, mysql.SelectPriv, mysql.UpdatePriv) != nil { return nil } + if fastSelect.IsTableDual { + return PhysicalTableDual{}.Init(ctx, &property.StatsInfo{}) + } + if ctx.GetSessionVars().TxnCtx.IsPessimistic { + fastSelect.Lock = true + } orderedList := buildOrderedList(ctx, fastSelect, updateStmt.List) if orderedList == nil { return nil @@ -484,6 +559,12 @@ func tryDeletePointPlan(ctx sessionctx.Context, delStmt *ast.DeleteStmt) Plan { if checkFastPlanPrivilege(ctx, fastSelect, mysql.SelectPriv, mysql.DeletePriv) != nil { return nil } + if fastSelect.IsTableDual { + return PhysicalTableDual{}.Init(ctx, &property.StatsInfo{}) + } + if ctx.GetSessionVars().TxnCtx.IsPessimistic { + fastSelect.Lock = true + } delPlan := Delete{ SelectPlan: fastSelect, }.Init(ctx) @@ -500,10 +581,10 @@ func findCol(tbl *model.TableInfo, colName *ast.ColumnName) *model.ColumnInfo { return nil } -func colInfoToColumn(db model.CIStr, tblName model.CIStr, asName model.CIStr, col *model.ColumnInfo, idx int) *expression.Column { +func colInfoToColumn(db model.CIStr, origTblName model.CIStr, tblName model.CIStr, asName model.CIStr, col *model.ColumnInfo, idx int) *expression.Column { return &expression.Column{ ColName: asName, - OrigTblName: tblName, + OrigTblName: origTblName, DBName: db, TblName: tblName, RetType: &col.FieldType, diff --git a/planner/core/point_get_plan_test.go b/planner/core/point_get_plan_test.go index 7d8eece9df1f8..eecffc3e6fa5b 100644 --- a/planner/core/point_get_plan_test.go +++ b/planner/core/point_get_plan_test.go @@ -14,9 +14,13 @@ package core_test import ( + "fmt" "math" + "strings" . "github.com/pingcap/check" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/util/testkit" @@ -27,20 +31,31 @@ import ( var _ = Suite(&testPointGetSuite{}) type testPointGetSuite struct { + store kv.Storage + dom *domain.Domain } -func (s *testPointGetSuite) TestPointGetPlanCache(c *C) { - defer testleak.AfterTest(c)() +func (s *testPointGetSuite) SetUpSuite(c *C) { + testleak.BeforeTest() store, dom, err := newStoreWithBootstrap() c.Assert(err, IsNil) - tk := testkit.NewTestKit(c, store) + s.store = store + s.dom = dom +} + +func (s *testPointGetSuite) TearDownSuite(c *C) { + s.dom.Close() + s.store.Close() + testleak.AfterTest(c)() +} + +func (s *testPointGetSuite) TestPointGetPlanCache(c *C) { + tk := testkit.NewTestKit(c, s.store) orgEnable := core.PreparedPlanCacheEnabled() orgCapacity := core.PreparedPlanCacheCapacity orgMemGuardRatio := core.PreparedPlanCacheMemoryGuardRatio orgMaxMemory := core.PreparedPlanCacheMaxMemory defer func() { - dom.Close() - store.Close() core.SetPreparedPlanCache(orgEnable) core.PreparedPlanCacheCapacity = orgCapacity core.PreparedPlanCacheMemoryGuardRatio = orgMemGuardRatio @@ -150,5 +165,33 @@ func (s *testPointGetSuite) TestPointGetPlanCache(c *C) { tk.MustQuery("execute stmt7 using @p2").Check(testkit.Rows("1")) counter.Write(pb) hit = pb.GetCounter().GetValue() - c.Check(hit, Equals, float64(3)) + c.Check(hit, Equals, float64(2)) +} + +func (s *testPointGetSuite) TestPointGetForUpdate(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table fu (id int primary key, val int)") + tk.MustExec("insert into fu values (6, 6)") + + // In autocommit mode, outside a transaction, "for update" doesn't take effect. + checkUseForUpdate(tk, c, false) + + tk.MustExec("begin") + checkUseForUpdate(tk, c, true) + tk.MustExec("rollback") + + tk.MustExec("set @@session.autocommit = 0") + checkUseForUpdate(tk, c, true) + tk.MustExec("rollback") +} + +func checkUseForUpdate(tk *testkit.TestKit, c *C, expectLock bool) { + res := tk.MustQuery("explain select * from fu where id = 6 for update") + // Point_Get_1 1.00 root table:fu, handle:6 + opInfo := res.Rows()[0][3] + selectLock := strings.Contains(fmt.Sprintf("%s", opInfo), "lock") + c.Assert(selectLock, Equals, expectLock) + + tk.MustQuery("select * from fu where id = 6 for update").Check(testkit.Rows("6 6")) } diff --git a/planner/core/prepare_test.go b/planner/core/prepare_test.go index b681b19f4a64a..829ec099fface 100644 --- a/planner/core/prepare_test.go +++ b/planner/core/prepare_test.go @@ -14,6 +14,7 @@ package core_test import ( + "context" "math" "strconv" "time" @@ -157,17 +158,18 @@ func (s *testPlanSuite) TestPrepareCacheDeferredFunction(c *C) { metrics.ResettablePlanCacheCounterFortTest = true metrics.PlanCacheCounter.Reset() counter := metrics.PlanCacheCounter.WithLabelValues("prepare") + ctx := context.TODO() for i := 0; i < 2; i++ { stmt, err := s.ParseOneStmt(sql1, "", "") c.Check(err, IsNil) is := tk.Se.GetSessionVars().TxnCtx.InfoSchema.(infoschema.InfoSchema) builder := core.NewPlanBuilder(tk.Se, is) - p, err := builder.Build(stmt) + p, err := builder.Build(ctx, stmt) c.Check(err, IsNil) execPlan, ok := p.(*core.Execute) c.Check(ok, IsTrue) executor.ResetContextOfStmt(tk.Se, stmt) - err = execPlan.OptimizePreparedPlan(tk.Se, is) + err = execPlan.OptimizePreparedPlan(ctx, tk.Se, is) c.Check(err, IsNil) planStr[i] = core.ToString(execPlan.Plan) c.Check(planStr[i], Matches, expectedPattern, Commentf("for %s", sql1)) @@ -343,3 +345,28 @@ func (s *testPrepareSuite) TestPrepareWithWindowFunction(c *C) { tk.MustExec("set @a=0, @b=1;") tk.MustQuery("execute stmt2 using @a, @b").Check(testkit.Rows("0", "0")) } + +func (s *testPrepareSuite) TestPrepareForGroupByItems(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(id int, v int)") + tk.MustExec("insert into t(id, v) values(1, 2),(1, 2),(2, 3);") + tk.MustExec("prepare s1 from 'select max(v) from t group by floor(id/?)';") + tk.MustExec("set @a=2;") + tk.MustQuery("execute s1 using @a;").Sort().Check(testkit.Rows("2", "3")) + + tk.MustExec("prepare s1 from 'select max(v) from t group by ?';") + tk.MustExec("set @a=2;") + err = tk.ExecToErr("execute s1 using @a;") + c.Assert(err.Error(), Equals, "Unknown column '2' in 'group statement'") + tk.MustExec("set @a=2.0;") + tk.MustQuery("execute s1 using @a;").Check(testkit.Rows("3")) +} diff --git a/planner/core/preprocess.go b/planner/core/preprocess.go index 26d93eabaf9ab..3ef613bb48995 100644 --- a/planner/core/preprocess.go +++ b/planner/core/preprocess.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/parser_driver" @@ -84,6 +85,7 @@ func (p *preprocessor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { switch node := in.(type) { case *ast.CreateTableStmt: p.flag |= inCreateOrDropTable + p.resolveCreateTableStmt(node) p.checkCreateTableGrammar(node) case *ast.CreateViewStmt: p.flag |= inCreateOrDropTable @@ -178,7 +180,12 @@ func (p *preprocessor) Leave(in ast.Node) (out ast.Node, ok bool) { p.err = expression.ErrIncorrectParameterCount.GenWithStackByArgs(x.FnName.L) } else { _, isValueExpr1 := x.Args[0].(*driver.ValueExpr) - _, isValueExpr2 := x.Args[1].(*driver.ValueExpr) + isValueExpr2 := false + switch x.Args[1].(type) { + case *driver.ValueExpr, *ast.UnaryOperationExpr: + isValueExpr2 = true + } + if !isValueExpr1 || !isValueExpr2 { p.err = ErrWrongArguments.GenWithStackByArgs("NAME_CONST") } @@ -290,7 +297,7 @@ func (p *preprocessor) checkAutoIncrement(stmt *ast.CreateTableStmt) { } } if (autoIncrementMustBeKey && !isKey) || count > 1 { - p.err = errors.New("Incorrect table definition; there can be only one auto column and it must be defined as a key") + p.err = autoid.ErrWrongAutoKey.GenWithStackByArgs() } switch autoIncrementCol.Tp.Tp { @@ -608,7 +615,7 @@ func checkColumn(colDef *ast.ColumnDef) error { if len(tp.Elems) > mysql.MaxTypeSetMembers { return types.ErrTooBigSet.GenWithStack("Too many strings for column %s and SET", colDef.Name.Name.O) } - // Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html . + // Check set elements. See https://dev.mysql.com/doc/refman/5.7/en/set.html. for _, str := range colDef.Tp.Elems { if strings.Contains(str, ",") { return types.ErrIllegalValueForType.GenWithStackByArgs(types.TypeStr(tp.Tp), str) @@ -732,6 +739,14 @@ func (p *preprocessor) resolveShowStmt(node *ast.ShowStmt) { } } +func (p *preprocessor) resolveCreateTableStmt(node *ast.CreateTableStmt) { + for _, val := range node.Constraints { + if val.Refer != nil && val.Refer.Table.Schema.String() == "" { + val.Refer.Table.Schema = node.Table.Schema + } + } +} + func (p *preprocessor) resolveAlterTableStmt(node *ast.AlterTableStmt) { for _, spec := range node.Specs { if spec.Tp == ast.AlterTableRenameTable { diff --git a/planner/core/preprocess_test.go b/planner/core/preprocess_test.go index 31fb17cf7f38f..01d04b9261208 100644 --- a/planner/core/preprocess_test.go +++ b/planner/core/preprocess_test.go @@ -53,11 +53,11 @@ func (s *testValidatorSuite) TestValidator(c *C) { {"create table t(id int auto_increment default null, primary key (id))", true, nil}, {"create table t(id int default null auto_increment, primary key (id))", true, nil}, {"create table t(id int not null auto_increment)", true, - errors.New("Incorrect table definition; there can be only one auto column and it must be defined as a key")}, + errors.New("[autoid:1075]Incorrect table definition; there can be only one auto column and it must be defined as a key")}, {"create table t(id int not null auto_increment, c int auto_increment, key (id, c))", true, - errors.New("Incorrect table definition; there can be only one auto column and it must be defined as a key")}, + errors.New("[autoid:1075]Incorrect table definition; there can be only one auto column and it must be defined as a key")}, {"create table t(id int not null auto_increment, c int, key (c, id))", true, - errors.New("Incorrect table definition; there can be only one auto column and it must be defined as a key")}, + errors.New("[autoid:1075]Incorrect table definition; there can be only one auto column and it must be defined as a key")}, {"create table t(id decimal auto_increment, key (id))", true, errors.New("Incorrect column specifier for column 'id'")}, {"create table t(id float auto_increment, key (id))", true, nil}, @@ -219,7 +219,7 @@ func (s *testValidatorSuite) TestValidator(c *C) { _, err = se.Execute(context.Background(), "use test") c.Assert(err, IsNil) ctx := se.(sessionctx.Context) - is := infoschema.MockInfoSchema([]*model.TableInfo{core.MockTable()}) + is := infoschema.MockInfoSchema([]*model.TableInfo{core.MockSignedTable()}) for _, tt := range tests { stmts, err1 := session.Parse(ctx, tt.sql) c.Assert(err1, IsNil) diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index def6cea1f32f5..2340e747bb731 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -480,17 +480,6 @@ func (p *Insert) ResolveIndices() (err error) { return } -// ResolveIndices implements Plan interface. -func (p *Show) ResolveIndices() (err error) { - for i, expr := range p.Conditions { - p.Conditions[i], err = expr.ResolveIndices(p.schema) - if err != nil { - return err - } - } - return err -} - func (p *physicalSchemaProducer) ResolveIndices() (err error) { err = p.basePhysicalPlan.ResolveIndices() if err != nil { diff --git a/planner/core/rule_aggregation_elimination.go b/planner/core/rule_aggregation_elimination.go index 95b2f7449e407..1b3a91ae1f937 100644 --- a/planner/core/rule_aggregation_elimination.go +++ b/planner/core/rule_aggregation_elimination.go @@ -14,6 +14,7 @@ package core import ( + "context" "math" "github.com/pingcap/parser/ast" @@ -136,10 +137,10 @@ func (a *aggregationEliminateChecker) wrapCastFunction(ctx sessionctx.Context, a return expression.BuildCastFunction(ctx, arg, targetTp) } -func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { +func (a *aggregationEliminator) optimize(ctx context.Context, p LogicalPlan) (LogicalPlan, error) { newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { - newChild, err := a.optimize(child) + newChild, err := a.optimize(ctx, child) if err != nil { return nil, err } @@ -155,3 +156,7 @@ func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { } return p, nil } + +func (*aggregationEliminator) name() string { + return "aggregation_eliminate" +} diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index bca18b6305b1e..6f7488667b0ac 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -13,6 +13,7 @@ package core import ( + "context" "fmt" "github.com/pingcap/parser/ast" @@ -188,22 +189,25 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a // tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't // process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator. // If the pushed aggregation is grouped by unique key, it's no need to push it down. -func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) LogicalPlan { +func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) (_ LogicalPlan, err error) { child := join.children[childIdx] if aggregation.IsAllFirstRow(aggFuncs) { - return child + return child, nil } // If the join is multiway-join, we forbid pushing down. if _, ok := join.children[childIdx].(*LogicalJoin); ok { - return child + return child, nil } tmpSchema := expression.NewSchema(gbyCols...) for _, key := range child.Schema().Keys { if tmpSchema.ColumnsIndices(key) != nil { - return child + return child, nil } } - agg := a.makeNewAgg(join.ctx, aggFuncs, gbyCols) + agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols) + if err != nil { + return nil, err + } agg.SetChildren(child) // If agg has no group-by item, it will return a default value, which may cause some bugs. // So here we add a group-by item forcely. @@ -216,10 +220,10 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg var existsDefaultValues bool join.DefaultValues, existsDefaultValues = a.getDefaultValues(agg) if !existsDefaultValues { - return child + return child, nil } } - return agg + return agg, nil } func (a *aggregationPushDownSolver) getDefaultValues(agg *LogicalAggregation) ([]types.Datum, bool) { @@ -243,7 +247,7 @@ func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation. return false } -func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) *LogicalAggregation { +func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) (*LogicalAggregation, error) { agg := LogicalAggregation{ GroupByItems: expression.Column2Exprs(gbyCols), groupByCols: gbyCols, @@ -257,7 +261,10 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs newAggFuncDescs = append(newAggFuncDescs, newFuncs...) } for _, gbyCol := range gbyCols { - firstRow := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{gbyCol}, false) + firstRow, err := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{gbyCol}, false) + if err != nil { + return nil, err + } newCol, _ := gbyCol.Clone().(*expression.Column) newCol.RetType = firstRow.RetTp newAggFuncDescs = append(newAggFuncDescs, firstRow) @@ -267,7 +274,7 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs agg.SetSchema(schema) // TODO: Add a Projection if any argument of aggregate funcs or group by items are scalar functions. // agg.buildProjectionIfNecessary() - return agg + return agg, nil } // pushAggCrossUnion will try to push the agg down to the union. If the new aggregation's group-by columns doesn't contain unique key. @@ -308,16 +315,15 @@ func (a *aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, u return newAgg } -func (a *aggregationPushDownSolver) optimize(p LogicalPlan) (LogicalPlan, error) { +func (a *aggregationPushDownSolver) optimize(ctx context.Context, p LogicalPlan) (LogicalPlan, error) { if !p.context().GetSessionVars().AllowAggPushDown { return p, nil } - a.aggPushDown(p) - return p, nil + return a.aggPushDown(p) } // aggPushDown tries to push down aggregate functions to join paths. -func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) LogicalPlan { +func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, err error) { if agg, ok := p.(*LogicalAggregation); ok { proj := a.tryToEliminateAggregation(agg) if proj != nil { @@ -334,12 +340,18 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) LogicalPlan { if rightInvalid { rChild = join.children[1] } else { - rChild = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) + rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) + if err != nil { + return nil, err + } } if leftInvalid { lChild = join.children[0] } else { - lChild = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) + lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) + if err != nil { + return nil, err + } } join.SetChildren(lChild, rChild) join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema())) @@ -368,7 +380,10 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) LogicalPlan { } else if union, ok1 := child.(*LogicalUnionAll); ok1 { var gbyCols []*expression.Column gbyCols = expression.ExtractColumnsFromExpressions(gbyCols, agg.GroupByItems, nil) - pushedAgg := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols) + pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols) + if err != nil { + return nil, err + } newChildren := make([]LogicalPlan, 0, len(union.children)) for _, child := range union.children { newChild := a.pushAggCrossUnion(pushedAgg, union.Schema(), child) @@ -381,9 +396,16 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) LogicalPlan { } newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { - newChild := a.aggPushDown(child) + newChild, err := a.aggPushDown(child) + if err != nil { + return nil, err + } newChildren = append(newChildren, newChild) } p.SetChildren(newChildren...) - return p + return p, nil +} + +func (*aggregationPushDownSolver) name() string { + return "aggregation_push_down" } diff --git a/planner/core/rule_build_key_info.go b/planner/core/rule_build_key_info.go index 92c2e67a99bf0..fefba82f10e9b 100644 --- a/planner/core/rule_build_key_info.go +++ b/planner/core/rule_build_key_info.go @@ -14,6 +14,8 @@ package core import ( + "context" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" @@ -21,7 +23,7 @@ import ( type buildKeySolver struct{} -func (s *buildKeySolver) optimize(lp LogicalPlan) (LogicalPlan, error) { +func (s *buildKeySolver) optimize(ctx context.Context, lp LogicalPlan) (LogicalPlan, error) { lp.buildKeyInfo() return lp, nil } @@ -218,3 +220,7 @@ func (ds *DataSource) buildKeyInfo() { } } } + +func (*buildKeySolver) name() string { + return "build_keys" +} diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 1a78a8ecb4d90..35bdc1931712e 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -14,19 +14,23 @@ package core import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/types" ) type columnPruner struct { } -func (s *columnPruner) optimize(lp LogicalPlan) (LogicalPlan, error) { +func (s *columnPruner) optimize(ctx context.Context, lp LogicalPlan) (LogicalPlan, error) { err := lp.PruneColumns(lp.Schema().Columns) return lp, err } @@ -49,17 +53,17 @@ func getUsedList(usedCols []*expression.Column, schema *expression.Schema) ([]bo return used, nil } -// exprHasSetVar checks if the expression has SetVar function. -func exprHasSetVar(expr expression.Expression) bool { +// exprHasSetVarOrSleep checks if the expression has SetVar function or Sleep function. +func exprHasSetVarOrSleep(expr expression.Expression) bool { scalaFunc, isScalaFunc := expr.(*expression.ScalarFunction) if !isScalaFunc { return false } - if scalaFunc.FuncName.L == ast.SetVar { + if scalaFunc.FuncName.L == ast.SetVar || scalaFunc.FuncName.L == ast.Sleep { return true } for _, arg := range scalaFunc.GetArgs() { - if exprHasSetVar(arg) { + if exprHasSetVarOrSleep(arg) { return true } } @@ -67,7 +71,7 @@ func exprHasSetVar(expr expression.Expression) bool { } // PruneColumns implements LogicalPlan interface. -// If any expression has SetVar functions, we do not prune it. +// If any expression has SetVar function or Sleep function, we do not prune it. func (p *LogicalProjection) PruneColumns(parentUsedCols []*expression.Column) error { child := p.children[0] used, err := getUsedList(parentUsedCols, p.schema) @@ -76,7 +80,7 @@ func (p *LogicalProjection) PruneColumns(parentUsedCols []*expression.Column) er } for i := len(used) - 1; i >= 0; i-- { - if !used[i] && !exprHasSetVar(p.Exprs[i]) { + if !used[i] && !exprHasSetVarOrSleep(p.Exprs[i]) { p.schema.Columns = append(p.schema.Columns[:i], p.schema.Columns[i+1:]...) p.Exprs = append(p.Exprs[:i], p.Exprs[i+1:]...) } @@ -117,6 +121,21 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) for _, aggrFunc := range la.AggFuncs { selfUsedCols = expression.ExtractColumnsFromExpressions(selfUsedCols, aggrFunc.Args, nil) } + if len(la.AggFuncs) == 0 { + // If all the aggregate functions are pruned, we should add an aggregate function to keep the correctness. + one, err := aggregation.NewAggFuncDesc(la.ctx, ast.AggFuncFirstRow, []expression.Expression{expression.One}, false) + if err != nil { + return err + } + la.AggFuncs = []*aggregation.AggFuncDesc{one} + col := &expression.Column{ + ColName: model.NewCIStr("dummy_agg"), + UniqueID: la.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: types.NewFieldType(mysql.TypeLonglong), + } + la.schema.Columns = []*expression.Column{col} + } + if len(la.GroupByItems) > 0 { for i := len(la.GroupByItems) - 1; i >= 0; i-- { cols := expression.ExtractColumns(la.GroupByItems[i]) @@ -164,6 +183,9 @@ func (p *LogicalUnionAll) PruneColumns(parentUsedCols []*expression.Column) erro hasBeenUsed := false for i := range used { hasBeenUsed = hasBeenUsed || used[i] + if hasBeenUsed { + break + } } if !hasBeenUsed { parentUsedCols = make([]*expression.Column, len(p.schema.Columns)) @@ -245,8 +267,15 @@ func (p *LogicalTableDual) PruneColumns(parentUsedCols []*expression.Column) err } } for k, cols := range p.schema.TblID2Handle { - if p.schema.ColumnIndex(cols[0]) == -1 { + for i := len(cols) - 1; i >= 0; i-- { + if p.schema.ColumnIndex(cols[i]) == -1 { + cols = append(cols[:i], cols[i+1:]...) + } + } + if len(cols) == 0 { delete(p.schema.TblID2Handle, k) + } else { + p.schema.TblID2Handle[k] = cols } } return nil @@ -388,3 +417,7 @@ func (p *LogicalWindow) extractUsedCols(parentUsedCols []*expression.Column) []* } return parentUsedCols } + +func (*columnPruner) name() string { + return "column_prune" +} diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index fb9e1c54fe43b..2d3eeb6650fb8 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -14,6 +14,7 @@ package core import ( + "context" "math" "github.com/pingcap/parser/ast" @@ -97,7 +98,7 @@ func (s *decorrelateSolver) aggDefaultValueMap(agg *LogicalAggregation) map[int] } // optimize implements logicalOptRule interface. -func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { +func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan) (LogicalPlan, error) { if apply, ok := p.(*LogicalApply); ok { outerPlan := apply.children[0] innerPlan := apply.children[1] @@ -117,12 +118,12 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { apply.attachOnConds(newConds) innerPlan = sel.children[0] apply.SetChildren(outerPlan, innerPlan) - return s.optimize(p) + return s.optimize(ctx, p) } else if m, ok := innerPlan.(*LogicalMaxOneRow); ok { if m.children[0].MaxOneRow() { innerPlan = m.children[0] apply.SetChildren(outerPlan, innerPlan) - return s.optimize(p) + return s.optimize(ctx, p) } } else if proj, ok := innerPlan.(*LogicalProjection); ok { for i, expr := range proj.Exprs { @@ -135,14 +136,14 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { proj.SetSchema(apply.Schema()) proj.Exprs = append(expression.Column2Exprs(outerPlan.Schema().Clone().Columns), proj.Exprs...) apply.SetSchema(expression.MergeSchema(outerPlan.Schema(), innerPlan.Schema())) - np, err := s.optimize(p) + np, err := s.optimize(ctx, p) if err != nil { return nil, err } proj.SetChildren(np) return proj, nil } - return s.optimize(p) + return s.optimize(ctx, p) } else if agg, ok := innerPlan.(*LogicalAggregation); ok { if apply.canPullUpAgg() && agg.canPullUp() { innerPlan = agg.children[0] @@ -154,7 +155,10 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { outerColsInSchema := make([]*expression.Column, 0, outerPlan.Schema().Len()) for i, col := range outerPlan.Schema().Columns { - first := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + first, err := aggregation.NewAggFuncDesc(agg.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, err + } newAggFuncs = append(newAggFuncs, first) outerCol, _ := outerPlan.Schema().Columns[i].Clone().(*expression.Column) @@ -164,7 +168,7 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { newAggFuncs = append(newAggFuncs, agg.AggFuncs...) agg.AggFuncs = newAggFuncs apply.SetSchema(expression.MergeSchema(expression.NewSchema(outerColsInSchema...), innerPlan.Schema())) - np, err := s.optimize(p) + np, err := s.optimize(ctx, p) if err != nil { return nil, err } @@ -201,7 +205,10 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { clonedCol := eqCond.GetArgs()[1] // If the join key is not in the aggregation's schema, add first row function. if agg.schema.ColumnIndex(eqCond.GetArgs()[1].(*expression.Column)) == -1 { - newFunc := aggregation.NewAggFuncDesc(apply.ctx, ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) + newFunc, err := aggregation.NewAggFuncDesc(apply.ctx, ast.AggFuncFirstRow, []expression.Expression{clonedCol}, false) + if err != nil { + return nil, err + } agg.AggFuncs = append(agg.AggFuncs, newFunc) agg.schema.Append(clonedCol.(*expression.Column)) agg.schema.Columns[agg.schema.Len()-1].RetType = newFunc.RetTp @@ -230,7 +237,7 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { proj.SetChildren(apply) p = proj } - return s.optimize(p) + return s.optimize(ctx, p) } sel.Conditions = originalExpr apply.corCols = extractCorColumnsBySchema(apply.children[1], apply.children[0].Schema()) @@ -240,7 +247,7 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { } newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { - np, err := s.optimize(child) + np, err := s.optimize(ctx, child) if err != nil { return nil, err } @@ -249,3 +256,7 @@ func (s *decorrelateSolver) optimize(p LogicalPlan) (LogicalPlan, error) { p.SetChildren(newChildren...) return p, nil } + +func (*decorrelateSolver) name() string { + return "decorrelate" +} diff --git a/planner/core/rule_eliminate_projection.go b/planner/core/rule_eliminate_projection.go index 5e72fdafe204b..5e0f07c758d10 100644 --- a/planner/core/rule_eliminate_projection.go +++ b/planner/core/rule_eliminate_projection.go @@ -14,6 +14,8 @@ package core import ( + "context" + "github.com/pingcap/tidb/expression" ) @@ -104,7 +106,7 @@ type projectionEliminater struct { } // optimize implements the logicalOptRule interface. -func (pe *projectionEliminater) optimize(lp LogicalPlan) (LogicalPlan, error) { +func (pe *projectionEliminater) optimize(ctx context.Context, lp LogicalPlan) (LogicalPlan, error) { root := pe.eliminate(lp, make(map[string]*expression.Column), false) return root, nil } @@ -220,3 +222,7 @@ func (p *LogicalWindow) replaceExprColumns(replace map[string]*expression.Column resolveColumnAndReplace(item.Col, replace) } } + +func (*projectionEliminater) name() string { + return "projection_eliminate" +} diff --git a/planner/core/rule_inject_extra_projection.go b/planner/core/rule_inject_extra_projection.go index 69d44294032c6..a2bcbcf28e3eb 100644 --- a/planner/core/rule_inject_extra_projection.go +++ b/planner/core/rule_inject_extra_projection.go @@ -99,9 +99,10 @@ func injectProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDes } projExprs = append(projExprs, arg) newArg := &expression.Column{ - RetType: arg.GetType(), - ColName: model.NewCIStr(fmt.Sprintf("col_%d", len(projSchemaCols))), - Index: cursor, + UniqueID: aggPlan.context().GetSessionVars().AllocPlanColumnID(), + RetType: arg.GetType(), + ColName: model.NewCIStr(fmt.Sprintf("col_%d", len(projSchemaCols))), + Index: cursor, } projSchemaCols = append(projSchemaCols, newArg) f.Args[i] = newArg @@ -157,7 +158,7 @@ func injectProjBelowSort(p PhysicalPlan, orderByItems []*ByItems) PhysicalPlan { topProjExprs := make([]expression.Expression, 0, p.Schema().Len()) for i := range p.Schema().Columns { - col := p.Schema().Columns[i] + col := p.Schema().Columns[i].Clone().(*expression.Column) col.Index = i topProjExprs = append(topProjExprs, col) } @@ -172,9 +173,10 @@ func injectProjBelowSort(p PhysicalPlan, orderByItems []*ByItems) PhysicalPlan { bottomProjSchemaCols := make([]*expression.Column, 0, len(childPlan.Schema().Columns)+numOrderByItems) bottomProjExprs := make([]expression.Expression, 0, len(childPlan.Schema().Columns)+numOrderByItems) for i, col := range childPlan.Schema().Columns { - col.Index = i - bottomProjSchemaCols = append(bottomProjSchemaCols, col) - bottomProjExprs = append(bottomProjExprs, col) + newCol := col.Clone().(*expression.Column) + newCol.Index = i + bottomProjSchemaCols = append(bottomProjSchemaCols, newCol) + bottomProjExprs = append(bottomProjExprs, newCol) } for _, item := range orderByItems { diff --git a/planner/core/rule_inject_extra_projection_test.go b/planner/core/rule_inject_extra_projection_test.go index 66e842837f86d..6b1f44e0e53e6 100644 --- a/planner/core/rule_inject_extra_projection_test.go +++ b/planner/core/rule_inject_extra_projection_test.go @@ -41,9 +41,10 @@ func (s *testInjectProjSuite) TestWrapCastForAggFuncs(c *C) { for _, mode := range modes { for _, retType := range retTypes { sctx := mock.NewContext() - aggFunc := aggregation.NewAggFuncDesc(sctx, name, + aggFunc, err := aggregation.NewAggFuncDesc(sctx, name, []expression.Expression{&expression.Constant{Value: types.Datum{}, RetType: types.NewFieldType(retType)}}, hasDistinct) + c.Assert(err, IsNil) aggFunc.Mode = mode aggFuncs = append(aggFuncs, aggFunc) } diff --git a/planner/core/rule_join_elimination.go b/planner/core/rule_join_elimination.go index 983da726cdd01..21807e9b101e6 100644 --- a/planner/core/rule_join_elimination.go +++ b/planner/core/rule_join_elimination.go @@ -14,8 +14,11 @@ package core import ( + "context" + "github.com/pingcap/parser/ast" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/util/set" ) type outerJoinEliminator struct { @@ -28,7 +31,7 @@ type outerJoinEliminator struct { // 2. outer join elimination with duplicate agnostic aggregate functions: For example left outer join. // If the parent only use the columns from left table with 'distinct' label. The left outer join can // be eliminated. -func (o *outerJoinEliminator) tryToEliminateOuterJoin(p *LogicalJoin, aggCols []*expression.Column, parentSchema *expression.Schema) (LogicalPlan, error) { +func (o *outerJoinEliminator) tryToEliminateOuterJoin(p *LogicalJoin, aggCols []*expression.Column, parentCols []*expression.Column) (LogicalPlan, bool, error) { var innerChildIdx int switch p.JoinType { case LeftOuterJoin: @@ -36,32 +39,42 @@ func (o *outerJoinEliminator) tryToEliminateOuterJoin(p *LogicalJoin, aggCols [] case RightOuterJoin: innerChildIdx = 0 default: - return p, nil + return p, false, nil } outerPlan := p.children[1^innerChildIdx] innerPlan := p.children[innerChildIdx] + outerUniqueIDs := set.NewInt64Set() + for _, outerCol := range outerPlan.Schema().Columns { + outerUniqueIDs.Insert(outerCol.UniqueID) + } + matched := o.isColsAllFromOuterTable(parentCols, outerUniqueIDs) + if !matched { + return p, false, nil + } // outer join elimination with duplicate agnostic aggregate functions - matched, err := o.isAggColsAllFromOuterTable(outerPlan, aggCols) - if err != nil || matched { - return outerPlan, err + matched = o.isColsAllFromOuterTable(aggCols, outerUniqueIDs) + if matched { + return outerPlan, true, nil } // outer join elimination without duplicate agnostic aggregate functions - matched, err = o.isParentColsAllFromOuterTable(outerPlan, parentSchema) - if err != nil || !matched { - return p, err - } innerJoinKeys := o.extractInnerJoinKeys(p, innerChildIdx) contain, err := o.isInnerJoinKeysContainUniqueKey(innerPlan, innerJoinKeys) - if err != nil || contain { - return outerPlan, err + if err != nil { + return p, false, err + } + if contain { + return outerPlan, true, nil } contain, err = o.isInnerJoinKeysContainIndex(innerPlan, innerJoinKeys) - if err != nil || contain { - return outerPlan, err + if err != nil { + return p, false, err + } + if contain { + return outerPlan, true, nil } - return p, nil + return p, false, nil } // extract join keys as a schema for inner child of a outer join @@ -73,33 +86,20 @@ func (o *outerJoinEliminator) extractInnerJoinKeys(join *LogicalJoin, innerChild return expression.NewSchema(joinKeys...) } -func (o *outerJoinEliminator) isAggColsAllFromOuterTable(outerPlan LogicalPlan, aggCols []*expression.Column) (bool, error) { - if len(aggCols) == 0 { - return false, nil - } - for _, col := range aggCols { - columnName := &ast.ColumnName{Schema: col.DBName, Table: col.TblName, Name: col.ColName} - c, err := outerPlan.Schema().FindColumn(columnName) - if err != nil || c == nil { - return false, err - } - } - return true, nil -} - -// check whether schema cols of join's parent plan are all from outer join table -func (o *outerJoinEliminator) isParentColsAllFromOuterTable(outerPlan LogicalPlan, parentSchema *expression.Schema) (bool, error) { - if parentSchema == nil { - return false, nil - } - for _, col := range parentSchema.Columns { - columnName := &ast.ColumnName{Schema: col.DBName, Table: col.TblName, Name: col.ColName} - c, err := outerPlan.Schema().FindColumn(columnName) - if err != nil || c == nil { - return false, err +// check whether the cols all from outer plan +func (o *outerJoinEliminator) isColsAllFromOuterTable(cols []*expression.Column, outerUniqueIDs set.Int64Set) bool { + // There are two cases "return false" here: + // 1. If cols represents aggCols, then "len(cols) == 0" means not all aggregate functions are duplicate agnostic before. + // 2. If cols represents parentCols, then "len(cols) == 0" means no parent logical plan of this join plan. + if len(cols) == 0 { + return false + } + for _, col := range cols { + if !outerUniqueIDs.Exist(col.UniqueID) { + return false } } - return true, nil + return true } // check whether one of unique keys sets is contained by inner join keys @@ -157,54 +157,87 @@ func (o *outerJoinEliminator) isInnerJoinKeysContainIndex(innerPlan LogicalPlan, return false, nil } -// Check whether a LogicalPlan is a LogicalAggregation and its all aggregate functions is duplicate agnostic. -// Also, check all the args are expression.Column. -func (o *outerJoinEliminator) isDuplicateAgnosticAgg(p LogicalPlan) (_ bool, cols []*expression.Column) { +// getDupAgnosticAggCols checks whether a LogicalPlan is LogicalAggregation. +// It extracts all the columns from the duplicate agnostic aggregate functions. +// The returned column set is nil if not all the aggregate functions are duplicate agnostic. +// Only the following functions are considered to be duplicate agnostic: +// 1. MAX(arg) +// 2. MIN(arg) +// 3. FIRST_ROW(arg) +// 4. Other agg functions with DISTINCT flag, like SUM(DISTINCT arg) +func (o *outerJoinEliminator) getDupAgnosticAggCols( + p LogicalPlan, + oldAggCols []*expression.Column, // Reuse the original buffer. +) (isAgg bool, newAggCols []*expression.Column) { agg, ok := p.(*LogicalAggregation) if !ok { return false, nil } - cols = agg.groupByCols + newAggCols = oldAggCols[:0] for _, aggDesc := range agg.AggFuncs { if !aggDesc.HasDistinct && aggDesc.Name != ast.AggFuncFirstRow && aggDesc.Name != ast.AggFuncMax && aggDesc.Name != ast.AggFuncMin { - return false, nil + // If not all aggregate functions are duplicate agnostic, + // we should clean the aggCols, so `return true, newAggCols[:0]`. + return true, newAggCols[:0] } for _, expr := range aggDesc.Args { - if col, ok := expr.(*expression.Column); ok { - cols = append(cols, col) - } else { - return false, nil - } + newAggCols = append(newAggCols, expression.ExtractColumns(expr)...) } } - return true, cols + return true, newAggCols } -func (o *outerJoinEliminator) doOptimize(p LogicalPlan, aggCols []*expression.Column, parentSchema *expression.Schema) (LogicalPlan, error) { - // check the duplicate agnostic aggregate functions - if ok, newCols := o.isDuplicateAgnosticAgg(p); ok { +func (o *outerJoinEliminator) doOptimize(p LogicalPlan, aggCols []*expression.Column, parentCols []*expression.Column) (LogicalPlan, error) { + var err error + var isEliminated bool + for join, isJoin := p.(*LogicalJoin); isJoin; join, isJoin = p.(*LogicalJoin) { + p, isEliminated, err = o.tryToEliminateOuterJoin(join, aggCols, parentCols) + if err != nil { + return p, err + } + if !isEliminated { + break + } + } + + switch x := p.(type) { + case *LogicalProjection: + parentCols = parentCols[:0] + for _, expr := range x.Exprs { + parentCols = append(parentCols, expression.ExtractColumns(expr)...) + } + case *LogicalAggregation: + parentCols = append(parentCols[:0], x.groupByCols...) + for _, aggDesc := range x.AggFuncs { + for _, expr := range aggDesc.Args { + parentCols = append(parentCols, expression.ExtractColumns(expr)...) + } + } + default: + parentCols = append(parentCols[:0], p.Schema().Columns...) + } + + if ok, newCols := o.getDupAgnosticAggCols(p, aggCols); ok { aggCols = newCols } - newChildren := make([]LogicalPlan, 0, len(p.Children())) - for _, child := range p.Children() { - newChild, err := o.doOptimize(child, aggCols, p.Schema()) + for i, child := range p.Children() { + newChild, err := o.doOptimize(child, aggCols, parentCols) if err != nil { return nil, err } - newChildren = append(newChildren, newChild) - } - p.SetChildren(newChildren...) - join, isJoin := p.(*LogicalJoin) - if !isJoin { - return p, nil + p.SetChild(i, newChild) } - return o.tryToEliminateOuterJoin(join, aggCols, parentSchema) + return p, nil } -func (o *outerJoinEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { +func (o *outerJoinEliminator) optimize(ctx context.Context, p LogicalPlan) (LogicalPlan, error) { return o.doOptimize(p, nil, nil) } + +func (*outerJoinEliminator) name() string { + return "outer_join_eliminate" +} diff --git a/planner/core/rule_join_reorder.go b/planner/core/rule_join_reorder.go index fac63d725cbb5..562ab9bf5a438 100644 --- a/planner/core/rule_join_reorder.go +++ b/planner/core/rule_join_reorder.go @@ -14,6 +14,8 @@ package core import ( + "context" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" ) @@ -52,7 +54,7 @@ type jrNode struct { cumCost float64 } -func (s *joinReOrderSolver) optimize(p LogicalPlan) (LogicalPlan, error) { +func (s *joinReOrderSolver) optimize(ctx context.Context, p LogicalPlan) (LogicalPlan, error) { return s.optimizeRecursive(p.context(), p) } @@ -166,3 +168,7 @@ func (s *baseSingleGroupJoinOrderSolver) newJoinWithEdges(lChild, rChild Logical func (s *baseSingleGroupJoinOrderSolver) calcJoinCumCost(join LogicalPlan, lNode, rNode *jrNode) float64 { return join.statsInfo().RowCount + lNode.cumCost + rNode.cumCost } + +func (*joinReOrderSolver) name() string { + return "join_reorder" +} diff --git a/planner/core/rule_max_min_eliminate.go b/planner/core/rule_max_min_eliminate.go index cd600872338e6..fddedf0440b70 100644 --- a/planner/core/rule_max_min_eliminate.go +++ b/planner/core/rule_max_min_eliminate.go @@ -13,6 +13,8 @@ package core import ( + "context" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" @@ -25,7 +27,7 @@ import ( type maxMinEliminator struct { } -func (a *maxMinEliminator) optimize(p LogicalPlan) (LogicalPlan, error) { +func (a *maxMinEliminator) optimize(ctx context.Context, p LogicalPlan) (LogicalPlan, error) { a.eliminateMaxMin(p) return p, nil } @@ -82,3 +84,7 @@ func (a *maxMinEliminator) eliminateMaxMin(p LogicalPlan) { a.eliminateMaxMin(child) } } + +func (*maxMinEliminator) name() string { + return "max_min_eliminate" +} diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index b3a6200151ce8..af1105c263cb7 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -13,6 +13,8 @@ package core import ( + "context" + "github.com/pingcap/errors" "github.com/pingcap/parser/model" "github.com/pingcap/tidb/expression" @@ -20,6 +22,7 @@ import ( "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/util/plancodec" "github.com/pingcap/tidb/util/ranger" ) @@ -39,7 +42,7 @@ import ( // partitionProcessor is here because it's easier to prune partition after predicate push down. type partitionProcessor struct{} -func (s *partitionProcessor) optimize(lp LogicalPlan) (LogicalPlan, error) { +func (s *partitionProcessor) optimize(ctx context.Context, lp LogicalPlan) (LogicalPlan, error) { return s.rewriteDataSource(lp) } @@ -60,9 +63,10 @@ func (s *partitionProcessor) rewriteDataSource(lp LogicalPlan) (LogicalPlan, err // Union->(UnionScan->DataSource1), (UnionScan->DataSource2) children := make([]LogicalPlan, 0, len(ua.Children())) for _, child := range ua.Children() { - us := LogicalUnionScan{}.Init(ua.ctx) - us.SetChildren(child) - children = append(children, us) + usChild := LogicalUnionScan{}.Init(ua.ctx) + usChild.conditions = us.conditions + usChild.SetChildren(child) + children = append(children, usChild) } ua.SetChildren(children...) return ua, nil @@ -127,7 +131,7 @@ func (s *partitionProcessor) prune(ds *DataSource) (LogicalPlan, error) { // Not a deep copy. newDataSource := *ds - newDataSource.baseLogicalPlan = newBaseLogicalPlan(ds.context(), TypeTableScan, &newDataSource) + newDataSource.baseLogicalPlan = newBaseLogicalPlan(ds.context(), plancodec.TypeTableScan, &newDataSource) newDataSource.isPartition = true newDataSource.physicalTableID = pi.Definitions[i].ID // There are many expression nodes in the plan tree use the original datasource @@ -203,3 +207,7 @@ func (s *partitionProcessor) findByName(partitionNames []model.CIStr, partitionN } return false } + +func (*partitionProcessor) name() string { + return "partition_processor" +} diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index ccf1a0855ccd1..fc4ab9510c5e9 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -13,6 +13,8 @@ package core import ( + "context" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" @@ -23,7 +25,7 @@ import ( type ppdSolver struct{} -func (s *ppdSolver) optimize(lp LogicalPlan) (LogicalPlan, error) { +func (s *ppdSolver) optimize(ctx context.Context, lp LogicalPlan) (LogicalPlan, error) { _, p := lp.PredicatePushDown(nil) return p, nil } @@ -147,7 +149,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret p.LeftConditions = nil ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) ret = append(ret, leftPushCond...) - case SemiJoin, AntiSemiJoin, InnerJoin: + case SemiJoin, InnerJoin: tempCond := make([]expression.Expression, 0, len(p.LeftConditions)+len(p.RightConditions)+len(p.EqualConditions)+len(p.OtherConditions)+len(predicates)) tempCond = append(tempCond, p.LeftConditions...) tempCond = append(tempCond, p.RightConditions...) @@ -156,13 +158,10 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret tempCond = append(tempCond, predicates...) tempCond = expression.ExtractFiltersFromDNFs(p.ctx, tempCond) tempCond = expression.PropagateConstant(p.ctx, tempCond) - // Return table dual when filter is constant false or null. Not applicable to AntiSemiJoin. - // TODO: For AntiSemiJoin, we can use outer plan to substitute LogicalJoin actually. - if p.JoinType != AntiSemiJoin { - dual := conds2TableDual(p, tempCond) - if dual != nil { - return ret, dual - } + // Return table dual when filter is constant false or null. + dual := conds2TableDual(p, tempCond) + if dual != nil { + return ret, dual } equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(tempCond, true, true) p.LeftConditions = nil @@ -171,6 +170,24 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret p.OtherConditions = otherCond leftCond = leftPushCond rightCond = rightPushCond + case AntiSemiJoin: + predicates = expression.PropagateConstant(p.ctx, predicates) + // Return table dual when filter is constant false or null. + dual := conds2TableDual(p, predicates) + if dual != nil { + return ret, dual + } + // `predicates` should only contain left conditions or constant filters. + _, leftPushCond, rightPushCond, _ = p.extractOnCondition(predicates, true, true) + // Do not derive `is not null` for anti join, since it may cause wrong results. + // For example: + // `select * from t t1 where t1.a not in (select b from t t2)` does not imply `t2.b is not null`, + // `select * from t t1 where t1.a not in (select a from t t2 where t1.b = t2.b` does not imply `t1.b is not null`, + // `select * from t t1 where not exists (select * from t t2 where t2.a = t1.a)` does not imply `t1.a is not null`, + leftCond = leftPushCond + rightCond = append(p.RightConditions, rightPushCond...) + p.RightConditions = nil + } leftCond = expression.RemoveDupExprs(p.ctx, leftCond) rightCond = expression.RemoveDupExprs(p.ctx, rightCond) @@ -518,3 +535,7 @@ func (p *LogicalWindow) PredicatePushDown(predicates []expression.Expression) ([ p.baseLogicalPlan.PredicatePushDown(nil) return predicates, p } + +func (*ppdSolver) name() string { + return "predicate_push_down" +} diff --git a/planner/core/rule_topn_push_down.go b/planner/core/rule_topn_push_down.go index ddebf02121632..cfa3cc73c6319 100644 --- a/planner/core/rule_topn_push_down.go +++ b/planner/core/rule_topn_push_down.go @@ -14,6 +14,8 @@ package core import ( + "context" + "github.com/cznic/mathutil" "github.com/pingcap/tidb/expression" ) @@ -22,7 +24,7 @@ import ( type pushDownTopNOptimizer struct { } -func (s *pushDownTopNOptimizer) optimize(p LogicalPlan) (LogicalPlan, error) { +func (s *pushDownTopNOptimizer) optimize(ctx context.Context, p LogicalPlan) (LogicalPlan, error) { return p.pushDownTopN(nil), nil } @@ -165,3 +167,7 @@ func (p *LogicalJoin) pushDownTopN(topN *LogicalTopN) LogicalPlan { } return p.self } + +func (*pushDownTopNOptimizer) name() string { + return "topn_push_down" +} diff --git a/planner/core/stats.go b/planner/core/stats.go index 38b73f3981d2d..c08e6d4262fe9 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -90,31 +90,29 @@ func (ds *DataSource) getColumnNDV(colID int64) (ndv float64) { return ndv } -func (ds *DataSource) getStatsByFilter(conds expression.CNFExprs) (*property.StatsInfo, *statistics.HistColl) { - profile := &property.StatsInfo{ +func (ds *DataSource) deriveStatsByFilter(conds expression.CNFExprs) { + tableStats := &property.StatsInfo{ RowCount: float64(ds.statisticTable.Count), Cardinality: make([]float64, len(ds.Columns)), HistColl: ds.statisticTable.GenerateHistCollFromColumnInfo(ds.Columns, ds.schema.Columns), StatsVersion: ds.statisticTable.Version, } if ds.statisticTable.Pseudo { - profile.StatsVersion = statistics.PseudoVersion + tableStats.StatsVersion = statistics.PseudoVersion } - for i, col := range ds.Columns { - profile.Cardinality[i] = ds.getColumnNDV(col.ID) + tableStats.Cardinality[i] = ds.getColumnNDV(col.ID) } - ds.stats = profile - selectivity, nodes, err := profile.HistColl.Selectivity(ds.ctx, conds) + ds.tableStats = tableStats + selectivity, nodes, err := tableStats.HistColl.Selectivity(ds.ctx, conds) if err != nil { - logutil.Logger(context.Background()).Warn("an error happened, use the default selectivity", zap.Error(err)) + logutil.Logger(context.Background()).Debug("an error happened, use the default selectivity", zap.Error(err)) selectivity = selectionFactor } - if ds.ctx.GetSessionVars().OptimizerSelectivityLevel >= 1 && ds.stats.HistColl != nil { - finalHist := ds.stats.HistColl.NewHistCollBySelectivity(ds.ctx.GetSessionVars().StmtCtx, nodes) - return profile, finalHist + ds.stats = tableStats.Scale(selectivity) + if ds.ctx.GetSessionVars().OptimizerSelectivityLevel >= 1 { + ds.stats.HistColl = ds.stats.HistColl.NewHistCollBySelectivity(ds.ctx.GetSessionVars().StmtCtx, nodes) } - return profile.Scale(selectivity), nil } // DeriveStats implement LogicalPlan DeriveStats interface. @@ -123,8 +121,7 @@ func (ds *DataSource) DeriveStats(childStats []*property.StatsInfo) (*property.S for i, expr := range ds.pushedDownConds { ds.pushedDownConds[i] = expression.PushDownNot(nil, expr, false) } - var finalHist *statistics.HistColl - ds.stats, finalHist = ds.getStatsByFilter(ds.pushedDownConds) + ds.deriveStatsByFilter(ds.pushedDownConds) for _, path := range ds.possibleAccessPaths { if path.isTablePath { noIntervalRanges, err := ds.deriveTablePathStats(path) @@ -150,9 +147,6 @@ func (ds *DataSource) DeriveStats(childStats []*property.StatsInfo) (*property.S break } } - if ds.ctx.GetSessionVars().OptimizerSelectivityLevel >= 1 { - ds.stats.HistColl = finalHist - } return ds.stats, nil } diff --git a/planner/core/stringer.go b/planner/core/stringer.go index beba86d8cd527..f4178509bc742 100644 --- a/planner/core/stringer.go +++ b/planner/core/stringer.go @@ -114,11 +114,7 @@ func toString(in Plan, strs []string, idxs []int) ([]string, []int) { case *ShowDDL: str = "ShowDDL" case *Show: - if len(x.Conditions) == 0 { - str = "Show" - } else { - str = fmt.Sprintf("Show(%s)", x.Conditions) - } + str = "Show" case *LogicalSort, *PhysicalSort: str = "Sort" case *LogicalJoin: diff --git a/planner/core/task.go b/planner/core/task.go index a1dabd6bc4ef3..36e23245b67ab 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/plancodec" ) // task is a new version of `PhysicalPlanInfo`. It stores cost information for a task. @@ -264,6 +265,7 @@ func (t *rootTask) plan() PhysicalPlan { func (p *PhysicalLimit) attach2Task(tasks ...task) task { t := tasks[0].copy() + sunk := false if cop, ok := t.(*copTask); ok { // If the table/index scans data by order and applies a double read, the limit cannot be pushed to the table side. if !cop.keepOrder || !cop.indexPlanFinished || cop.indexPlan == nil { @@ -272,9 +274,42 @@ func (p *PhysicalLimit) attach2Task(tasks ...task) task { cop = attachPlan2Task(pushedDownLimit, cop).(*copTask) } t = finishCopTask(p.ctx, cop) + sunk = p.sinkIntoIndexLookUp(t) } - t = attachPlan2Task(p, t) - return t + if sunk { + return t + } + return attachPlan2Task(p, t) +} + +func (p *PhysicalLimit) sinkIntoIndexLookUp(t task) bool { + root := t.(*rootTask) + reader, isDoubleRead := root.p.(*PhysicalIndexLookUpReader) + proj, isProj := root.p.(*PhysicalProjection) + if !isDoubleRead && !isProj { + return false + } + if isProj { + reader, isDoubleRead = proj.Children()[0].(*PhysicalIndexLookUpReader) + if !isDoubleRead { + return false + } + } + // We can sink Limit into IndexLookUpReader only if tablePlan contains no Selection. + ts, isTableScan := reader.tablePlan.(*PhysicalTableScan) + if !isTableScan { + return false + } + reader.PushedLimit = &PushedDownLimit{ + Offset: p.Offset, + Count: p.Count, + } + ts.stats = p.stats + reader.stats = p.stats + if isProj { + proj.stats = p.stats + } + return true } // GetCost computes the cost of in memory sort. @@ -454,7 +489,7 @@ func (p *basePhysicalAgg) newPartialAggregate() (partial, final PhysicalPlan) { } // Create physical "final" aggregation. - if p.tp == TypeStreamAgg { + if p.tp == plancodec.TypeStreamAgg { finalAgg := basePhysicalAgg{ AggFuncs: finalAggFuncs, GroupByItems: groupByItems, @@ -474,18 +509,27 @@ func (p *basePhysicalAgg) newPartialAggregate() (partial, final PhysicalPlan) { func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task { t := tasks[0].copy() if cop, ok := t.(*copTask); ok { - partialAgg, finalAgg := p.newPartialAggregate() - if partialAgg != nil { - if cop.tablePlan != nil { - partialAgg.SetChildren(cop.tablePlan) - cop.tablePlan = partialAgg - } else { - partialAgg.SetChildren(cop.indexPlan) - cop.indexPlan = partialAgg + // We should not push agg down across double read, since the data of second read is ordered by handle instead of index. + // The `doubleReadNeedProj` is always set if the double read needs to keep order. So we just use it to decided + // whether the following plan is double read with order reserved. + if !cop.doubleReadNeedProj { + partialAgg, finalAgg := p.newPartialAggregate() + if partialAgg != nil { + if cop.tablePlan != nil { + cop.finishIndexPlan() + partialAgg.SetChildren(cop.tablePlan) + cop.tablePlan = partialAgg + } else { + partialAgg.SetChildren(cop.indexPlan) + cop.indexPlan = partialAgg + } } + t = finishCopTask(p.ctx, cop) + attachPlan2Task(finalAgg, t) + } else { + t = finishCopTask(p.ctx, cop) + attachPlan2Task(p, t) } - t = finishCopTask(p.ctx, cop) - attachPlan2Task(finalAgg, t) } else { attachPlan2Task(p, t) } diff --git a/planner/core/testdata/analyze_suite_in.json b/planner/core/testdata/analyze_suite_in.json new file mode 100644 index 0000000000000..ccadb717d2b46 --- /dev/null +++ b/planner/core/testdata/analyze_suite_in.json @@ -0,0 +1,63 @@ +[ + { + "name": "TestLimitCrossEstimation", + "cases": [ + // Pseudo stats. + [ + "set session tidb_opt_correlation_exp_factor = 0", + "EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1;" + ], + // Positive correlation. + [ + "insert into t (a, b) values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 1),(8, 1),(9, 1),(10, 1),(11, 1),(12, 1),(13, 1),(14, 1),(15, 1),(16, 1),(17, 1),(18, 1),(19, 1),(20, 2),(21, 2),(22, 2),(23, 2),(24, 2),(25, 2)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1" + ], + // Negative correlation. + [ + "truncate table t", + "insert into t (a, b) values (1, 25),(2, 24),(3, 23),(4, 23),(5, 21),(6, 20),(7, 19),(8, 18),(9, 17),(10, 16),(11, 15),(12, 14),(13, 13),(14, 12),(15, 11),(16, 10),(17, 9),(18, 8),(19, 7),(20, 6),(21, 5),(22, 4),(23, 3),(24, 2),(25, 1)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b <= 6 ORDER BY a limit 1" + ], + // Outer plan of index join (to test that correct column ID is used). + [ + "EXPLAIN SELECT *, t1.a IN (SELECT t2.b FROM t t2) FROM t t1 WHERE t1.b <= 6 ORDER BY t1.a limit 1" + ], + // Desc TableScan. + [ + "truncate table t", + "insert into t (a, b) values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 2),(8, 2),(9, 2),(10, 2),(11, 2),(12, 2),(13, 2),(14, 2),(15, 2),(16, 2),(17, 2),(18, 2),(19, 2),(20, 2),(21, 2),(22, 2),(23, 2),(24, 2),(25, 2)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b = 1 ORDER BY a desc limit 1" + ], + // Correlation threshold not met. + [ + "truncate table t", + "insert into t (a, b) values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 1),(8, 1),(9, 2),(10, 1),(11, 1),(12, 1),(13, 1),(14, 2),(15, 2),(16, 1),(17, 2),(18, 1),(19, 2),(20, 1),(21, 2),(22, 1),(23, 1),(24, 1),(25, 1)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1" + ], + [ + "set session tidb_opt_correlation_exp_factor = 1", + "EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1" + ], + // TableScan has access conditions, but correlation is 1. + [ + "set session tidb_opt_correlation_exp_factor = 0", + "truncate table t", + "insert into t (a, b) values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 1),(8, 1),(9, 1),(10, 1),(11, 1),(12, 1),(13, 1),(14, 1),(15, 1),(16, 1),(17, 1),(18, 1),(19, 1),(20, 2),(21, 2),(22, 2),(23, 2),(24, 2),(25, 2)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b = 2 and a > 0 ORDER BY a limit 1" + ], + // Multi-column filter. + [ + "drop table t", + "create table t(a int primary key, b int, c int, d bigint default 2147483648, e bigint default 2147483648, f bigint default 2147483648, index idx(b,d,a,c))", + "insert into t(a, b, c) values (1, 1, 1),(2, 1, 2),(3, 1, 1),(4, 1, 2),(5, 1, 1),(6, 1, 2),(7, 1, 1),(8, 1, 2),(9, 1, 1),(10, 1, 2),(11, 1, 1),(12, 1, 2),(13, 1, 1),(14, 1, 2),(15, 1, 1),(16, 1, 2),(17, 1, 1),(18, 1, 2),(19, 1, 1),(20, 2, 2),(21, 2, 1),(22, 2, 2),(23, 2, 1),(24, 2, 2),(25, 2, 1)", + "analyze table t", + "EXPLAIN SELECT a FROM t WHERE b = 2 and c > 0 ORDER BY a limit 1" + ] + ] + } +] diff --git a/planner/core/testdata/analyze_suite_out.json b/planner/core/testdata/analyze_suite_out.json new file mode 100644 index 0000000000000..05e4b7800850c --- /dev/null +++ b/planner/core/testdata/analyze_suite_out.json @@ -0,0 +1,135 @@ +[ + { + "Name": "TestLimitCrossEstimation", + "Cases": [ + { + "SQL": [ + "set session tidb_opt_correlation_exp_factor = 0", + "EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1;" + ], + "Plan": [ + "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", + "└─IndexReader_16 1.00 root index:TopN_15", + " └─TopN_15 1.00 cop test.t.a:asc, offset:0, count:1", + " └─IndexScan_14 10.00 cop table:t, index:b, c, range:[2,2], keep order:false, stats:pseudo" + ] + }, + { + "SQL": [ + "insert into t (a, b) values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 1),(8, 1),(9, 1),(10, 1),(11, 1),(12, 1),(13, 1),(14, 1),(15, 1),(16, 1),(17, 1),(18, 1),(19, 1),(20, 2),(21, 2),(22, 2),(23, 2),(24, 2),(25, 2)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1" + ], + "Plan": [ + "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", + "└─IndexReader_16 1.00 root index:TopN_15", + " └─TopN_15 1.00 cop test.t.a:asc, offset:0, count:1", + " └─IndexScan_14 6.00 cop table:t, index:b, c, range:[2,2], keep order:false" + ] + }, + { + "SQL": [ + "truncate table t", + "insert into t (a, b) values (1, 25),(2, 24),(3, 23),(4, 23),(5, 21),(6, 20),(7, 19),(8, 18),(9, 17),(10, 16),(11, 15),(12, 14),(13, 13),(14, 12),(15, 11),(16, 10),(17, 9),(18, 8),(19, 7),(20, 6),(21, 5),(22, 4),(23, 3),(24, 2),(25, 1)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b <= 6 ORDER BY a limit 1" + ], + "Plan": [ + "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", + "└─IndexReader_16 1.00 root index:TopN_15", + " └─TopN_15 1.00 cop test.t.a:asc, offset:0, count:1", + " └─IndexScan_14 6.00 cop table:t, index:b, c, range:[-inf,6], keep order:false" + ] + }, + { + "SQL": [ + "EXPLAIN SELECT *, t1.a IN (SELECT t2.b FROM t t2) FROM t t1 WHERE t1.b <= 6 ORDER BY t1.a limit 1" + ], + "Plan": [ + "Limit_17 1.00 root offset:0, count:1", + "└─IndexJoin_58 1.00 root left outer semi join, inner:IndexReader_57, outer key:test.t1.a, inner key:test.t2.b", + " ├─TopN_23 1.00 root test.t1.a:asc, offset:0, count:1", + " │ └─IndexReader_31 1.00 root index:TopN_30", + " │ └─TopN_30 1.00 cop test.t1.a:asc, offset:0, count:1", + " │ └─IndexScan_29 6.00 cop table:t1, index:b, c, range:[-inf,6], keep order:false", + " └─IndexReader_57 1.04 root index:IndexScan_56", + " └─IndexScan_56 1.04 cop table:t2, index:b, c, range: decided by [eq(test.t2.b, test.t1.a)], keep order:false" + ] + }, + { + "SQL": [ + "truncate table t", + "insert into t (a, b) values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 2),(8, 2),(9, 2),(10, 2),(11, 2),(12, 2),(13, 2),(14, 2),(15, 2),(16, 2),(17, 2),(18, 2),(19, 2),(20, 2),(21, 2),(22, 2),(23, 2),(24, 2),(25, 2)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b = 1 ORDER BY a desc limit 1" + ], + "Plan": [ + "TopN_8 1.00 root test.t.a:desc, offset:0, count:1", + "└─IndexReader_16 1.00 root index:TopN_15", + " └─TopN_15 1.00 cop test.t.a:desc, offset:0, count:1", + " └─IndexScan_14 6.00 cop table:t, index:b, c, range:[1,1], keep order:false" + ] + }, + { + "SQL": [ + "truncate table t", + "insert into t (a, b) values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 1),(8, 1),(9, 2),(10, 1),(11, 1),(12, 1),(13, 1),(14, 2),(15, 2),(16, 1),(17, 2),(18, 1),(19, 2),(20, 1),(21, 2),(22, 1),(23, 1),(24, 1),(25, 1)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1" + ], + "Plan": [ + "Limit_11 1.00 root offset:0, count:1", + "└─TableReader_22 1.00 root data:Limit_21", + " └─Limit_21 1.00 cop offset:0, count:1", + " └─Selection_20 1.00 cop eq(test.t.b, 2)", + " └─TableScan_19 4.17 cop table:t, range:[-inf,+inf], keep order:true" + ] + }, + { + "SQL": [ + "set session tidb_opt_correlation_exp_factor = 1", + "EXPLAIN SELECT * FROM t WHERE b = 2 ORDER BY a limit 1" + ], + "Plan": [ + "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", + "└─IndexReader_16 1.00 root index:TopN_15", + " └─TopN_15 1.00 cop test.t.a:asc, offset:0, count:1", + " └─IndexScan_14 6.00 cop table:t, index:b, c, range:[2,2], keep order:false" + ] + }, + { + "SQL": [ + "set session tidb_opt_correlation_exp_factor = 0", + "truncate table t", + "insert into t (a, b) values (1, 1),(2, 1),(3, 1),(4, 1),(5, 1),(6, 1),(7, 1),(8, 1),(9, 1),(10, 1),(11, 1),(12, 1),(13, 1),(14, 1),(15, 1),(16, 1),(17, 1),(18, 1),(19, 1),(20, 2),(21, 2),(22, 2),(23, 2),(24, 2),(25, 2)", + "analyze table t", + "EXPLAIN SELECT * FROM t WHERE b = 2 and a > 0 ORDER BY a limit 1" + ], + "Plan": [ + "TopN_8 1.00 root test.t.a:asc, offset:0, count:1", + "└─IndexReader_19 1.00 root index:TopN_18", + " └─TopN_18 1.00 cop test.t.a:asc, offset:0, count:1", + " └─Selection_17 6.00 cop gt(test.t.a, 0)", + " └─IndexScan_16 6.00 cop table:t, index:b, c, range:[2,2], keep order:false" + ] + }, + { + "SQL": [ + "drop table t", + "create table t(a int primary key, b int, c int, d bigint default 2147483648, e bigint default 2147483648, f bigint default 2147483648, index idx(b,d,a,c))", + "insert into t(a, b, c) values (1, 1, 1),(2, 1, 2),(3, 1, 1),(4, 1, 2),(5, 1, 1),(6, 1, 2),(7, 1, 1),(8, 1, 2),(9, 1, 1),(10, 1, 2),(11, 1, 1),(12, 1, 2),(13, 1, 1),(14, 1, 2),(15, 1, 1),(16, 1, 2),(17, 1, 1),(18, 1, 2),(19, 1, 1),(20, 2, 2),(21, 2, 1),(22, 2, 2),(23, 2, 1),(24, 2, 2),(25, 2, 1)", + "analyze table t", + "EXPLAIN SELECT a FROM t WHERE b = 2 and c > 0 ORDER BY a limit 1" + ], + "Plan": [ + "Projection_7 1.00 root test.t.a", + "└─TopN_8 1.00 root test.t.a:asc, offset:0, count:1", + " └─IndexReader_17 1.00 root index:TopN_16", + " └─TopN_16 1.00 cop test.t.a:asc, offset:0, count:1", + " └─Selection_15 6.00 cop gt(test.t.c, 0)", + " └─IndexScan_14 6.00 cop table:t, index:b, d, a, c, range:[2,2], keep order:false" + ] + } + ] + } +] diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json new file mode 100644 index 0000000000000..7cac9b50c37d7 --- /dev/null +++ b/planner/core/testdata/integration_suite_in.json @@ -0,0 +1,29 @@ +[ + { + "name": "TestPushLimitDownIndexLookUpReader", + "cases": [ + // Limit should be pushed down into IndexLookUpReader, row count of IndexLookUpReader and TableScan should be 1.00. + "explain select * from tbl use index(idx_b_c) where b > 1 limit 2,1", + // Projection atop IndexLookUpReader, Limit should be pushed down into IndexLookUpReader, and Projection should have row count 1.00 as well. + "explain select * from tbl use index(idx_b_c) where b > 1 order by b desc limit 2,1", + // Limit should be pushed down into IndexLookUpReader when Selection on top of IndexScan. + "explain select * from tbl use index(idx_b_c) where b > 1 and c > 1 limit 2,1", + // Limit should NOT be pushed down into IndexLookUpReader when Selection on top of TableScan. + "explain select * from tbl use index(idx_b_c) where b > 1 and a > 1 limit 2,1" + ] + }, + { + "name": "TestIsFromUnixtimeNullRejective", + "cases": [ + // fix #12385 + "explain select * from t t1 left join t t2 on t1.a=t2.a where from_unixtime(t2.b);" + ] + }, + { + "name": "TestSimplifyOuterJoinWithCast", + "cases": [ + // LeftOuterJoin should no be simplified to InnerJoin. + "explain select * from t t1 left join t t2 on t1.a = t2.a where cast(t1.b as date) >= '2019-01-01'" + ] + } +] diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json new file mode 100644 index 0000000000000..13f71330350cd --- /dev/null +++ b/planner/core/testdata/integration_suite_out.json @@ -0,0 +1,40 @@ +[ + { + "Name": "TestPushLimitDownIndexLookUpReader", + "Cases": null + }, + { + "Name": "TestIsFromUnixtimeNullRejective", + "Cases": [ + { + "SQL": "explain select * from t t1 left join t t2 on t1.a=t2.a where from_unixtime(t2.b);", + "Plan": [ + "HashLeftJoin_8 9990.00 root inner join, inner:Selection_13, equal:[eq(test.t1.a, test.t2.a)]", + "├─TableReader_12 9990.00 root data:Selection_11", + "│ └─Selection_11 9990.00 cop not(isnull(test.t1.a))", + "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─Selection_13 7992.00 root from_unixtime(cast(test.t2.b))", + " └─TableReader_16 9990.00 root data:Selection_15", + " └─Selection_15 9990.00 cop not(isnull(test.t2.a))", + " └─TableScan_14 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo" + ] + } + ] + }, + { + "Name": "TestSimplifyOuterJoinWithCast", + "Cases": [ + { + "SQL": "explain select * from t t1 left join t t2 on t1.a = t2.a where cast(t1.b as date) >= '2019-01-01'", + "Plan": [ + "HashLeftJoin_8 10000.00 root left outer join, inner:TableReader_13, equal:[eq(test.t1.a, test.t2.a)]", + "├─Selection_9 8000.00 root ge(cast(test.t1.b), 2019-01-01 00:00:00.000000)", + "│ └─TableReader_11 10000.00 root data:TableScan_10", + "│ └─TableScan_10 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + "└─TableReader_13 10000.00 root data:TableScan_12", + " └─TableScan_12 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo" + ] + } + ] + } +] diff --git a/planner/core/testdata/plan_suite_in.json b/planner/core/testdata/plan_suite_in.json new file mode 100644 index 0000000000000..030fd3f91cc28 --- /dev/null +++ b/planner/core/testdata/plan_suite_in.json @@ -0,0 +1,408 @@ +[ + { + "name": "TestIndexHint", + "cases": [ + // simple case + "select /*+ USE_INDEX(t, c_d_e) */ * from t", + "select /*+ USE_INDEX(t, c_d_e) */ * from t t1", + "select /*+ USE_INDEX(t1, c_d_e) */ * from t t1", + "select /*+ USE_INDEX(t1, c_d_e), USE_INDEX(t2, f) */ * from t t1, t t2 where t1.a = t2.b", + // test multiple indexes + "select /*+ USE_INDEX(t, c_d_e, f, g) */ * from t order by f", + // use TablePath when the hint only contains table. + "select /*+ USE_INDEX(t) */ f from t where f > 10", + // there will be a warning instead of error when index not exist + "select /*+ USE_INDEX(t, no_such_index) */ * from t" + ] + }, + { + "name": "TestDAGPlanBuilderSimpleCase", + "cases":[ + // Test index hint. + "select * from t t1 use index(c_d_e)", + "select f from t use index() where f = 1", + // Test ts + Sort vs. DoubleRead + filter. + "select a from t where a between 1 and 2 order by c", + // Test DNF condition + Double Read. + "select * from t where (t.c > 0 and t.c < 2) or (t.c > 4 and t.c < 6) or (t.c > 8 and t.c < 10) or (t.c > 12 and t.c < 14) or (t.c > 16 and t.c < 18)", + "select * from t where (t.c > 0 and t.c < 1) or (t.c > 2 and t.c < 3) or (t.c > 4 and t.c < 5) or (t.c > 6 and t.c < 7) or (t.c > 9 and t.c < 10)", + // Test TopN to table branch in double read. + "select * from t where t.c = 1 and t.e = 1 order by t.b limit 1", + // Test Null Range + "select * from t where t.e_str is null", + // Test Null Range but the column has not null flag. + "select * from t where t.c is null", + // Test TopN to index branch in double read. + "select * from t where t.c = 1 and t.e = 1 order by t.e limit 1", + // Test TopN to Limit in double read. + "select * from t where t.c = 1 and t.e = 1 order by t.d limit 1", + // Test TopN to Limit in index single read. + "select c from t where t.c = 1 and t.e = 1 order by t.d limit 1", + // Test TopN to Limit in table single read. + "select c from t order by t.a limit 1", + // Test TopN push down in table single read. + "select c from t order by t.a + t.b limit 1", + // Test Limit push down in table single read. + "select c from t limit 1", + // Test Limit push down in index single read. + "select c from t where c = 1 limit 1", + // Test index single read and Selection. + "select c from t where c = 1", + // Test index single read and Sort. + "select c from t order by c", + // Test index single read and Sort. + "select c from t where c = 1 order by e", + // Test Limit push down in double single read. + "select c, b from t where c = 1 limit 1", + // Test Selection + Limit push down in double single read. + "select c, b from t where c = 1 and e = 1 and b = 1 limit 1", + // Test Order by multi columns. + "select c from t where c = 1 order by d, c", + // Test for index with length. + "select c_str from t where e_str = '1' order by d_str, c_str", + // Test PK in index single read. + "select c from t where t.c = 1 and t.a > 1 order by t.d limit 1", + // Test composed index. + // FIXME: The TopN didn't be pushed. + "select c from t where t.c = 1 and t.d = 1 order by t.a limit 1", + // Test PK in index double read. + "select * from t where t.c = 1 and t.a > 1 order by t.d limit 1", + // Test index filter condition push down. + "select * from t use index(e_d_c_str_prefix) where t.c_str = 'abcdefghijk' and t.d_str = 'd' and t.e_str = 'e'", + "select * from t use index(e_d_c_str_prefix) where t.e_str = b'1110000'", + "select * from (select * from t use index() order by b) t left join t t1 on t.a=t1.a limit 10", + // Test embedded ORDER BY which imposes on different number of columns than outer query. + "select * from ((SELECT 1 a,3 b) UNION (SELECT 2,1) ORDER BY (SELECT 2)) t order by a,b", + "select * from ((SELECT 1 a,6 b) UNION (SELECT 2,5) UNION (SELECT 2, 4) ORDER BY 1) t order by 1, 2", + "select * from (select *, NULL as xxx from t) t order by xxx", + "select lead(a, 1) over (partition by null) as c from t", + "select * from t use index(f) where f = 1 and a = 1", + "select * from t2 use index(b) where b = 1 and a = 1" + ] + }, + { + "name": "TestDAGPlanBuilderJoin", + "cases": [ + "select * from t t1 join t t2 on t1.a = t2.c_str", + "select * from t t1 join t t2 on t1.b = t2.a", + "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.a = t3.a", + "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.b = t3.a", + "select * from t t1 join t t2 on t1.b = t2.a order by t1.a", + "select * from t t1 join t t2 on t1.b = t2.a order by t1.a limit 1", + // Test hash join's hint. + "select /*+ TIDB_HJ(t1, t2) */ * from t t1 join t t2 on t1.b = t2.a order by t1.a limit 1", + "select * from t t1 left join t t2 on t1.b = t2.a where 1 = 1 limit 1", + "select * from t t1 join t t2 on t1.b = t2.a and t1.c = 1 and t1.d = 1 and t1.e = 1 order by t1.a limit 1", + "select * from t t1 join t t2 on t1.b = t2.b join t t3 on t1.b = t3.b", + "select * from t t1 join t t2 on t1.a = t2.a order by t1.a", + "select * from t t1 left outer join t t2 on t1.a = t2.a right outer join t t3 on t1.a = t3.a", + "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.a = t3.a and t1.b = 1 and t3.c = 1", + "select * from t where t.c in (select b from t s where s.a = t.a)", + "select t.c in (select b from t s where s.a = t.a) from t", + // Test Single Merge Join. + // Merge Join now enforce a sort. + "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.b", + "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.a", + // Test Single Merge Join + Sort. + "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.a order by t2.a", + "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.b = t2.b order by t2.a", + // Test Single Merge Join + Sort + desc. + "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.a order by t2.a desc", + "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.b = t2.b order by t2.b desc", + // Test Multi Merge Join. + "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.a = t2.a and t2.a = t3.a", + "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.a = t2.b and t2.a = t3.b", + // Test Multi Merge Join with multi keys. + // TODO: More tests should be added. + "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.c = t2.c and t1.d = t2.d and t3.c = t1.c and t3.d = t1.d", + "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.c = t2.c and t1.d = t2.d and t3.c = t1.c and t3.d = t1.d order by t1.c", + // Test Multi Merge Join + Outer Join. + "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1 left outer join t t2 on t1.a = t2.a left outer join t t3 on t2.a = t3.a", + "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1 left outer join t t2 on t1.a = t2.a left outer join t t3 on t1.a = t3.a", + // Test Index Join + TableScan. + "select /*+ TIDB_INLJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.a", + // Test Index Join + DoubleRead. + "select /*+ TIDB_INLJ(t2) */ * from t t1, t t2 where t1.a = t2.c", + // Test Index Join + SingleRead. + "select /*+ TIDB_INLJ(t2) */ t1.a , t2.a from t t1, t t2 where t1.a = t2.c", + // Test Index Join + Order by. + "select /*+ TIDB_INLJ(t1, t2) */ t1.a, t2.a from t t1, t t2 where t1.a = t2.a order by t1.c", + // Test Index Join + Order by. + "select /*+ TIDB_INLJ(t1, t2) */ t1.a, t2.a from t t1, t t2 where t1.a = t2.a order by t2.c", + // Test Index Join + TableScan + Rotate. + "select /*+ TIDB_INLJ(t1) */ t1.a , t2.a from t t1, t t2 where t1.a = t2.c", + // Test Index Join + OuterJoin + TableScan. + "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 left outer join t t2 on t1.a = t2.a and t2.b < 1", + "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 join t t2 on t1.d=t2.d and t2.c = 1", + // Test Index Join failed. + "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 left outer join t t2 on t1.a = t2.b", + // Test Index Join failed. + "select /*+ TIDB_INLJ(t2) */ * from t t1 right outer join t t2 on t1.a = t2.b", + // Test Semi Join hint success. + "select /*+ TIDB_INLJ(t2) */ * from t t1 where t1.a in (select a from t t2)", + // Test Semi Join hint fail. + "select /*+ TIDB_INLJ(t1) */ * from t t1 where t1.a in (select a from t t2)", + "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.c=t2.c and t1.f=t2.f", + "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.a = t2.a and t1.f=t2.f", + "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.f=t2.f and t1.a=t2.a", + "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.a=t2.a and t2.a in (1, 2)", + "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.b=t2.c and t1.b=1 and t2.d > t1.d-10 and t2.d < t1.d+10", + "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.b=t2.b and t1.c=1 and t2.c=1 and t2.d > t1.d-10 and t2.d < t1.d+10", + "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t2.c > t1.d-10 and t2.c < t1.d+10", + "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.b = t2.c and t2.c=1 and t2.d=2 and t2.e=4", + "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t2.c=1 and t2.d=1 and t2.e > 10 and t2.e < 20" + ] + }, + { + "name": "TestDAGPlanBuilderSubquery", + "cases": [ + // Test join key with cast. + "select * from t where exists (select s.a from t s having sum(s.a) = t.a )", + "select * from t where exists (select s.a from t s having sum(s.a) = t.a ) order by t.a", + // FIXME: Report error by resolver. + // "select * from t where exists (select s.a from t s having s.a = t.a ) order by t.a", + "select * from t where a in (select s.a from t s) order by t.a", + // Test Nested sub query. + "select * from t where exists (select s.a from t s where s.c in (select c from t as k where k.d = s.d) having sum(s.a) = t.a )", + // Test Semi Join + Order by. + "select * from t where a in (select a from t) order by b", + // Test Apply. + "select t.c in (select count(*) from t s, t t1 where s.a = t.a and s.a = t1.a) from t", + "select (select count(*) from t s, t t1 where s.a = t.a and s.a = t1.a) from t", + "select (select count(*) from t s, t t1 where s.a = t.a and s.a = t1.a) from t order by t.a" + ] + }, + { + "name": "TestDAGPlanTopN", + "cases": [ + "select * from t t1 left join t t2 on t1.b = t2.b left join t t3 on t2.b = t3.b order by t1.a limit 1", + "select * from t t1 left join t t2 on t1.b = t2.b left join t t3 on t2.b = t3.b order by t1.b limit 1", + "select * from t t1 left join t t2 on t1.b = t2.b left join t t3 on t2.b = t3.b limit 1", + "select * from t where b = 1 and c = 1 order by c limit 1", + "select * from t where c = 1 order by c limit 1", + "select * from t order by a limit 1", + "select c from t order by c limit 1" + ] + }, + { + "name": "TestDAGPlanBuilderBasePhysicalPlan", + "cases": [ + // Test for update. + // TODO: This is not reasonable. Mysql do like this because the limit of InnoDB, should TiDB keep consistency with MySQL? + "select * from t order by b limit 1 for update", + // Test complex update. + "update t set a = 5 where b < 1 order by d limit 1", + // Test simple update. + "update t set a = 5", + // TODO: Test delete/update with join. + // Test join hint for delete and update + "delete /*+ TIDB_INLJ(t1, t2) */ t1 from t t1, t t2 where t1.c=t2.c", + "delete /*+ TIDB_SMJ(t1, t2) */ from t1 using t t1, t t2 where t1.c=t2.c", + "update /*+ TIDB_SMJ(t1, t2) */ t t1, t t2 set t1.a=1, t2.a=1 where t1.a=t2.a", + "update /*+ TIDB_HJ(t1, t2) */ t t1, t t2 set t1.a=1, t2.a=1 where t1.a=t2.a", + // Test complex delete. + "delete from t where b < 1 order by d limit 1", + // Test simple delete. + "delete from t", + // Test "USE INDEX" hint in delete statement from single table + "delete from t use index(c_d_e) where b = 1", + // Test complex insert. + "insert into t select * from t where b < 1 order by d limit 1", + // Test simple insert. + "insert into t (a, b, c, e, f, g) values(0,0,0,0,0,0)", + // Test dual. + "select 1", + "select * from t where false", + // Test show. + "show tables" + ] + }, + { + "name": "TestDAGPlanBuilderUnion", + "cases": [ + // Test simple union. + "select * from t union all select * from t", + // Test Order by + Union. + "select * from t union all (select * from t) order by a ", + // Test Limit + Union. + "select * from t union all (select * from t) limit 1", + // Test TopN + Union. + "select a from t union all (select c from t) order by a limit 1" + ] + }, + { + "name": "TestDAGPlanBuilderUnionScan", + "cases": [ + // Read table. + "select * from t", + "select * from t where b = 1", + "select * from t where a = 1", + "select * from t where a = 1 order by a", + "select * from t where a = 1 order by b", + "select * from t where a = 1 limit 1", + "select * from t where c = 1", + "select c from t where c = 1" + ] + }, + { + "name": "TestDAGPlanBuilderAgg", + "cases": [ + // Test distinct. + "select distinct b from t", + "select count(*) from (select * from t order by b) t group by b", + "select count(*), x from (select b as bbb, a + 1 as x from (select * from t order by b) t) t group by bbb", + // Test agg + table. + "select sum(a), avg(b + c) from t group by d", + "select sum(distinct a), avg(b + c) from t group by d", + // Test group by (c + d) + "select sum(e), avg(e + c) from t where c = 1 group by (c + d)", + // Test stream agg + index single. + "select sum(e), avg(e + c) from t where c = 1 group by c", + // Test hash agg + index single. + "select sum(e), avg(e + c) from t where c = 1 group by e", + // Test hash agg + index double. + "select sum(e), avg(b + c) from t where c = 1 and e = 1 group by d", + // Test stream agg + index double. + "select sum(e), avg(b + c) from t where c = 1 and b = 1", + // Test hash agg + order. + "select sum(e) as k, avg(b + c) from t where c = 1 and b = 1 and e = 1 group by d order by k", + // Test stream agg + order. + "select sum(e) as k, avg(b + c) from t where c = 1 and b = 1 and e = 1 group by c order by k", + // Test agg can't push down. + "select sum(to_base64(e)) from t where c = 1", + "select (select count(1) k from t s where s.a = t.a having k != 0) from t", + // Test stream agg with multi group by columns. + "select sum(to_base64(e)) from t group by e,d,c order by c", + "select sum(e+1) from t group by e,d,c order by c", + "select sum(to_base64(e)) from t group by e,d,c order by c,e", + "select sum(e+1) from t group by e,d,c order by c,e", + // Test stream agg + limit or sort + "select count(*) from t group by g order by g limit 10", + "select count(*) from t group by g limit 10", + "select count(*) from t group by g order by g", + "select count(*) from t group by g order by g desc limit 1", + // Test hash agg + limit or sort + "select count(*) from t group by b order by b limit 10", + "select count(*) from t group by b order by b", + "select count(*) from t group by b limit 10", + // Test merge join + stream agg + "select sum(a.g), sum(b.g) from t a join t b on a.g = b.g group by a.g", + // Test index join + stream agg + "select /*+ tidb_inlj(a,b) */ sum(a.g), sum(b.g) from t a join t b on a.g = b.g and a.g > 60 group by a.g order by a.g limit 1", + "select sum(a.g), sum(b.g) from t a join t b on a.g = b.g and a.a>5 group by a.g order by a.g limit 1", + "select sum(d) from t" + ] + }, + { + "name": "TestRefine", + "cases": [ + "select a from t where c is not null", + "select a from t where c >= 4", + "select a from t where c <= 4", + "select a from t where c = 4 and d = 5 and e = 6", + "select a from t where d = 4 and c = 5", + "select a from t where c = 4 and e < 5", + "select a from t where c = 4 and d <= 5 and d > 3", + "select a from t where d <= 5 and d > 3", + "select a from t where c between 1 and 2", + "select a from t where c not between 1 and 2", + "select a from t where c <= 5 and c >= 3 and d = 1", + "select a from t where c = 1 or c = 2 or c = 3", + "select b from t where c = 1 or c = 2 or c = 3 or c = 4 or c = 5", + "select a from t where c = 5", + "select a from t where c = 5 and b = 1", + "select a from t where not a", + "select a from t where c in (1)", + "select a from t where c in ('1')", + "select a from t where c = 1.0", + "select a from t where c in (1) and d > 3", + "select a from t where c in (1, 2, 3) and (d > 3 and d < 4 or d > 5 and d < 6)", + "select a from t where c in (1, 2, 3) and (d > 2 and d < 4 or d > 5 and d < 7)", + "select a from t where c in (1, 2, 3)", + "select a from t where c in (1, 2, 3) and d in (1,2) and e = 1", + "select a from t where d in (1, 2, 3)", + "select a from t where c not in (1)", + "select a from t use index(c_d_e) where c != 1", + // test like + "select a from t where c_str like ''", + "select a from t where c_str like 'abc'", + "select a from t where c_str not like 'abc'", + "select a from t where not (c_str like 'abc' or c_str like 'abd')", + "select a from t where c_str like '_abc'", + "select a from t where c_str like 'abc%'", + "select a from t where c_str like 'abc_'", + "select a from t where c_str like 'abc%af'", + "select a from t where c_str like 'abc\\_' escape ''", + "select a from t where c_str like 'abc\\_'", + "select a from t where c_str like 'abc\\\\_'", + "select a from t where c_str like 'abc\\_%'", + "select a from t where c_str like 'abc=_%' escape '='", + "select a from t where c_str like 'abc\\__'", + // Check that 123 is converted to string '123'. index can be used. + "select a from t where c_str like 123", + "select a from t where c = 1.9 and d > 3", + "select a from t where c < 1.1", + "select a from t where c <= 1.9", + "select a from t where c >= 1.1", + "select a from t where c > 1.9", + "select a from t where c = 123456789098765432101234", + "select a from t where c = 'hanfei'" + ] + }, + { + "name": "TestAggEliminator", + "cases": [ + // Max to Limit + Sort-Desc. + "select max(a) from t;", + // Min to Limit + Sort. + "select min(a) from t;", + // Min to Limit + Sort, and isnull() should be added. + "select min(c_str) from t;", + // Do nothing to max + firstrow. + "select max(a), b from t;", + // If max/min contains scalar function, we can still do transformation. + "select max(a+1) from t;", + // Do nothing to max+min. + "select max(a), min(a) from t;", + // Do nothing to max with groupby. + "select max(a) from t group by b;", + // If inner is not a data source, we can still do transformation. + "select max(a) from (select t1.a from t t1 join t t2 on t1.a=t2.a) t" + ] + }, + { + "name": "TestUnmatchedTableInHint", + "cases": [ + "SELECT /*+ TIDB_SMJ(t3, t4) */ * from t t1, t t2 where t1.a = t2.a", + "SELECT /*+ TIDB_HJ(t3, t4) */ * from t t1, t t2 where t1.a = t2.a", + "SELECT /*+ TIDB_INLJ(t3, t4) */ * from t t1, t t2 where t1.a = t2.a", + "SELECT /*+ TIDB_SMJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.a", + "SELECT /*+ TIDB_SMJ(t3, t4) */ * from t t1, t t2, t t3 where t1.a = t2.a and t2.a = t3.a" + ] + }, + { + "name": "TestIndexJoinHint", + "cases": [ + "select /*+ TIDB_INLJ(t1) */ t1.a, t2.a, t3.a from t t1, t t2, t t3 where t1.a = t2.a and t2.a = t3.a;", + "select /*+ TIDB_INLJ(t1) */ t1.b, t2.a from t t1, t t2 where t1.b = t2.a;", + "select /*+ TIDB_INLJ(t2) */ t1.b, t2.a from t2 t1, t2 t2 where t1.b=t2.b and t2.c=-1;" + ] + }, + { + "name": "TestIndexJoinUnionScan", + "cases": [ + // Test Index Join + UnionScan + TableScan. + "select /*+ TIDB_INLJ(t2) */ * from t t1, t t2 where t1.a = t2.a", + // Test Index Join + UnionScan + DoubleRead. + "select /*+ TIDB_INLJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.c", + // Test Index Join + UnionScan + IndexScan. + "select /*+ TIDB_INLJ(t1, t2) */ t1.a , t2.c from t t1, t t2 where t1.a = t2.c" + ] + }, + { + "name": "TestSemiJoinToInner", + "cases": [ + "select t1.a, (select count(t2.a) from t t2 where t2.g in (select t3.d from t t3 where t3.c = t1.a)) as agg_col from t t1;" + ] + } +] diff --git a/planner/core/testdata/plan_suite_out.json b/planner/core/testdata/plan_suite_out.json new file mode 100644 index 0000000000000..37f089b700b23 --- /dev/null +++ b/planner/core/testdata/plan_suite_out.json @@ -0,0 +1,934 @@ +[ + { + "Name": "TestIndexHint", + "Cases": null + }, + { + "Name": "TestDAGPlanBuilderSimpleCase", + "Cases": [ + { + "SQL": "select * from t t1 use index(c_d_e)", + "Best": "IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))" + }, + { + "SQL": "select f from t use index() where f = 1", + "Best": "TableReader(Table(t)->Sel([eq(test.t.f, 1)]))" + }, + { + "SQL": "select a from t where a between 1 and 2 order by c", + "Best": "TableReader(Table(t))->Sort->Projection" + }, + { + "SQL": "select * from t where (t.c > 0 and t.c < 2) or (t.c > 4 and t.c < 6) or (t.c > 8 and t.c < 10) or (t.c > 12 and t.c < 14) or (t.c > 16 and t.c < 18)", + "Best": "IndexLookUp(Index(t.c_d_e)[(0,2) (4,6) (8,10) (12,14) (16,18)], Table(t))" + }, + { + "SQL": "select * from t where (t.c > 0 and t.c < 1) or (t.c > 2 and t.c < 3) or (t.c > 4 and t.c < 5) or (t.c > 6 and t.c < 7) or (t.c > 9 and t.c < 10)", + "Best": "Dual" + }, + { + "SQL": "select * from t where t.c = 1 and t.e = 1 order by t.b limit 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t))->TopN([test.t.b],0,1)" + }, + { + "SQL": "select * from t where t.e_str is null", + "Best": "IndexLookUp(Index(t.e_d_c_str_prefix)[[NULL,NULL]], Table(t))" + }, + { + "SQL": "select * from t where t.c is null", + "Best": "Dual" + }, + { + "SQL": "select * from t where t.c = 1 and t.e = 1 order by t.e limit 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t))->TopN([test.t.e],0,1)" + }, + { + "SQL": "select * from t where t.c = 1 and t.e = 1 order by t.d limit 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)])->Limit, Table(t))" + }, + { + "SQL": "select c from t where t.c = 1 and t.e = 1 order by t.d limit 1", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)])->Limit)->Limit->Projection" + }, + { + "SQL": "select c from t order by t.a limit 1", + "Best": "TableReader(Table(t)->Limit)->Limit->Projection" + }, + { + "SQL": "select c from t order by t.a + t.b limit 1", + "Best": "TableReader(Table(t)->TopN([plus(test.t.a, test.t.b)],0,1))->Projection->TopN([col_3],0,1)->Projection->Projection" + }, + { + "SQL": "select c from t limit 1", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Limit)->Limit" + }, + { + "SQL": "select c from t where c = 1 limit 1", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]]->Limit)->Limit" + }, + { + "SQL": "select c from t where c = 1", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]])" + }, + { + "SQL": "select c from t order by c", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]])" + }, + { + "SQL": "select c from t where c = 1 order by e", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]])->Sort->Projection" + }, + { + "SQL": "select c, b from t where c = 1 limit 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Limit, Table(t))->Projection" + }, + { + "SQL": "select c, b from t where c = 1 and e = 1 and b = 1 limit 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)])->Limit)->Limit->Projection" + }, + { + "SQL": "select c from t where c = 1 order by d, c", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]])->Sort->Projection" + }, + { + "SQL": "select c_str from t where e_str = '1' order by d_str, c_str", + "Best": "IndexLookUp(Index(t.e_d_c_str_prefix)[[\"1\",\"1\"]], Table(t))->Sort->Projection" + }, + { + "SQL": "select c from t where t.c = 1 and t.a > 1 order by t.d limit 1", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]]->Sel([gt(test.t.a, 1)])->Limit)->Limit->Projection" + }, + { + "SQL": "select c from t where t.c = 1 and t.d = 1 order by t.a limit 1", + "Best": "IndexReader(Index(t.c_d_e)[[1 1,1 1]])->TopN([test.t.a],0,1)->Projection" + }, + { + "SQL": "select * from t where t.c = 1 and t.a > 1 order by t.d limit 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([gt(test.t.a, 1)])->Limit, Table(t))" + }, + { + "SQL": "select * from t use index(e_d_c_str_prefix) where t.c_str = 'abcdefghijk' and t.d_str = 'd' and t.e_str = 'e'", + "Best": "IndexLookUp(Index(t.e_d_c_str_prefix)[[\"e\" \"d\" \"abcdefghij\",\"e\" \"d\" \"abcdefghij\"]], Table(t)->Sel([eq(test.t.c_str, abcdefghijk)]))" + }, + { + "SQL": "select * from t use index(e_d_c_str_prefix) where t.e_str = b'1110000'", + "Best": "IndexLookUp(Index(t.e_d_c_str_prefix)[[\"p\",\"p\"]], Table(t))" + }, + { + "SQL": "select * from (select * from t use index() order by b) t left join t t1 on t.a=t1.a limit 10", + "Best": "IndexJoin{TableReader(Table(t)->TopN([test.t.b],0,10))->TopN([test.t.b],0,10)->TableReader(Table(t))}(test.t.a,test.t1.a)->Limit" + }, + { + "SQL": "select * from ((SELECT 1 a,3 b) UNION (SELECT 2,1) ORDER BY (SELECT 2)) t order by a,b", + "Best": "UnionAll{Dual->Projection->Dual->Projection}->HashAgg->Sort" + }, + { + "SQL": "select * from ((SELECT 1 a,6 b) UNION (SELECT 2,5) UNION (SELECT 2, 4) ORDER BY 1) t order by 1, 2", + "Best": "UnionAll{Dual->Projection->Dual->Projection->Dual->Projection}->HashAgg->Sort->Sort" + }, + { + "SQL": "select * from (select *, NULL as xxx from t) t order by xxx", + "Best": "TableReader(Table(t))->Projection" + }, + { + "SQL": "select lead(a, 1) over (partition by null) as c from t", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Window(lead(test.t.a, 1) over())->Projection" + }, + { + "SQL": "select * from t use index(f) where f = 1 and a = 1", + "Best": "IndexLookUp(Index(t.f)[[1,1]]->Sel([eq(test.t.a, 1)]), Table(t))" + }, + { + "SQL": "select * from t2 use index(b) where b = 1 and a = 1", + "Best": "IndexLookUp(Index(t2.b)[[1,1]]->Sel([eq(test.t2.a, 1)]), Table(t2))" + } + ] + }, + { + "Name": "TestDAGPlanBuilderJoin", + "Cases": [ + { + "SQL": "select * from t t1 join t t2 on t1.a = t2.c_str", + "Best": "LeftHashJoin{TableReader(Table(t))->Projection->TableReader(Table(t))->Projection}(cast(test.t1.a),cast(test.t2.c_str))->Projection" + }, + { + "SQL": "select * from t t1 join t t2 on t1.b = t2.a", + "Best": "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)" + }, + { + "SQL": "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.a = t3.a", + "Best": "MergeInnerJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t1.a,test.t3.a)" + }, + { + "SQL": "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.b = t3.a", + "Best": "LeftHashJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t1.b,test.t3.a)" + }, + { + "SQL": "select * from t t1 join t t2 on t1.b = t2.a order by t1.a", + "Best": "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)->Sort" + }, + { + "SQL": "select * from t t1 join t t2 on t1.b = t2.a order by t1.a limit 1", + "Best": "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)->Limit" + }, + { + "SQL": "select /*+ TIDB_HJ(t1, t2) */ * from t t1 join t t2 on t1.b = t2.a order by t1.a limit 1", + "Best": "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)->TopN([test.t1.a],0,1)" + }, + { + "SQL": "select * from t t1 left join t t2 on t1.b = t2.a where 1 = 1 limit 1", + "Best": "IndexJoin{TableReader(Table(t)->Limit)->Limit->TableReader(Table(t))}(test.t1.b,test.t2.a)->Limit" + }, + { + "SQL": "select * from t t1 join t t2 on t1.b = t2.a and t1.c = 1 and t1.d = 1 and t1.e = 1 order by t1.a limit 1", + "Best": "IndexJoin{IndexLookUp(Index(t.c_d_e)[[1 1 1,1 1 1]], Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.a)->TopN([test.t1.a],0,1)" + }, + { + "SQL": "select * from t t1 join t t2 on t1.b = t2.b join t t3 on t1.b = t3.b", + "Best": "LeftHashJoin{LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.b,test.t2.b)->TableReader(Table(t))}(test.t1.b,test.t3.b)" + }, + { + "SQL": "select * from t t1 join t t2 on t1.a = t2.a order by t1.a", + "Best": "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select * from t t1 left outer join t t2 on t1.a = t2.a right outer join t t3 on t1.a = t3.a", + "Best": "MergeRightOuterJoin{MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t1.a,test.t3.a)" + }, + { + "SQL": "select * from t t1 join t t2 on t1.a = t2.a join t t3 on t1.a = t3.a and t1.b = 1 and t3.c = 1", + "Best": "IndexJoin{IndexJoin{TableReader(Table(t)->Sel([eq(test.t1.b, 1)]))->IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t))}(test.t3.a,test.t1.a)->TableReader(Table(t))}(test.t1.a,test.t2.a)->Projection" + }, + { + "SQL": "select * from t where t.c in (select b from t s where s.a = t.a)", + "Best": "MergeSemiJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t.a,test.s.a)" + }, + { + "SQL": "select t.c in (select b from t s where s.a = t.a) from t", + "Best": "MergeLeftOuterSemiJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t.a,test.s.a)->Projection" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.b", + "Best": "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))->Sort}(test.t1.a,test.t2.b)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.a", + "Best": "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.a order by t2.a", + "Best": "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.b = t2.b order by t2.a", + "Best": "MergeInnerJoin{TableReader(Table(t))->Sort->TableReader(Table(t))->Sort}(test.t1.b,test.t2.b)->Sort" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.a = t2.a order by t2.a desc", + "Best": "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2)*/ * from t t1, t t2 where t1.b = t2.b order by t2.b desc", + "Best": "MergeInnerJoin{TableReader(Table(t))->Sort->TableReader(Table(t))->Sort}(test.t1.b,test.t2.b)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.a = t2.a and t2.a = t3.a", + "Best": "MergeInnerJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t2.a,test.t3.a)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.a = t2.b and t2.a = t3.b", + "Best": "MergeInnerJoin{MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))->Sort}(test.t1.a,test.t2.b)->Sort->TableReader(Table(t))->Sort}(test.t2.a,test.t3.b)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.c = t2.c and t1.d = t2.d and t3.c = t1.c and t3.d = t1.d", + "Best": "MergeInnerJoin{MergeInnerJoin{IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)(test.t1.d,test.t2.d)->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t3.c)(test.t1.d,test.t3.d)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1, t t2, t t3 where t1.c = t2.c and t1.d = t2.d and t3.c = t1.c and t3.d = t1.d order by t1.c", + "Best": "MergeInnerJoin{MergeInnerJoin{IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)(test.t1.d,test.t2.d)->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t3.c)(test.t1.d,test.t3.d)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1 left outer join t t2 on t1.a = t2.a left outer join t t3 on t2.a = t3.a", + "Best": "MergeLeftOuterJoin{MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t2.a,test.t3.a)" + }, + { + "SQL": "select /*+ TIDB_SMJ(t1,t2,t3)*/ * from t t1 left outer join t t2 on t1.a = t2.a left outer join t t3 on t1.a = t3.a", + "Best": "MergeLeftOuterJoin{MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->TableReader(Table(t))}(test.t1.a,test.t3.a)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.a", + "Best": "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1, t t2 where t1.a = t2.c", + "Best": "IndexJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.a,test.t2.c)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ t1.a , t2.a from t t1, t t2 where t1.a = t2.c", + "Best": "IndexJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t1.a,test.t2.c)->Projection" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1, t2) */ t1.a, t2.a from t t1, t t2 where t1.a = t2.a order by t1.c", + "Best": "IndexJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->TableReader(Table(t))}(test.t1.a,test.t2.a)->Projection" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1, t2) */ t1.a, t2.a from t t1, t t2 where t1.a = t2.a order by t2.c", + "Best": "IndexJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.a,test.t1.a)->Projection" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1) */ t1.a , t2.a from t t1, t t2 where t1.a = t2.c", + "Best": "IndexJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.c,test.t1.a)->Projection" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 left outer join t t2 on t1.a = t2.a and t2.b < 1", + "Best": "IndexJoin{TableReader(Table(t))->TableReader(Table(t)->Sel([lt(test.t2.b, 1)]))}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 join t t2 on t1.d=t2.d and t2.c = 1", + "Best": "IndexJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.d,test.t2.d)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1, t2) */ * from t t1 left outer join t t2 on t1.a = t2.b", + "Best": "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 right outer join t t2 on t1.a = t2.b", + "Best": "RightHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 where t1.a in (select a from t t2)", + "Best": "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->Projection" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1) */ * from t t1 where t1.a in (select a from t t2)", + "Best": "IndexJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t2.a,test.t1.a)->Projection" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.c=t2.c and t1.f=t2.f", + "Best": "IndexJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.a = t2.a and t1.f=t2.f", + "Best": "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.f=t2.f and t1.a=t2.a", + "Best": "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.a=t2.a and t2.a in (1, 2)", + "Best": "IndexJoin{TableReader(Table(t))->TableReader(Table(t)->Sel([in(test.t2.a, 1, 2)]))}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.b=t2.c and t1.b=1 and t2.d > t1.d-10 and t2.d < t1.d+10", + "Best": "IndexJoin{TableReader(Table(t)->Sel([eq(test.t1.b, 1)]))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.b=t2.b and t1.c=1 and t2.c=1 and t2.d > t1.d-10 and t2.d < t1.d+10", + "Best": "LeftHashJoin{IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t))->IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t))}(test.t1.b,test.t2.b)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t2.c > t1.d-10 and t2.c < t1.d+10", + "Best": "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t1.b = t2.c and t2.c=1 and t2.d=2 and t2.e=4", + "Best": "LeftHashJoin{TableReader(Table(t)->Sel([eq(test.t1.b, 1)]))->IndexLookUp(Index(t.c_d_e)[[1 2 4,1 2 4]], Table(t))}" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1 join t t2 where t2.c=1 and t2.d=1 and t2.e > 10 and t2.e < 20", + "Best": "LeftHashJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[(1 1 10,1 1 20)], Table(t))}" + } + ] + }, + { + "Name": "TestDAGPlanBuilderSubquery", + "Cases": [ + { + "SQL": "select * from t where exists (select s.a from t s having sum(s.a) = t.a )", + "Best": "LeftHashJoin{TableReader(Table(t))->Projection->IndexReader(Index(t.c_d_e)[[NULL,+inf]]->StreamAgg)->StreamAgg}(cast(test.t.a),sel_agg_1)->Projection" + }, + { + "SQL": "select * from t where exists (select s.a from t s having sum(s.a) = t.a ) order by t.a", + "Best": "LeftHashJoin{TableReader(Table(t))->Projection->IndexReader(Index(t.c_d_e)[[NULL,+inf]]->StreamAgg)->StreamAgg}(cast(test.t.a),sel_agg_1)->Projection->Sort" + }, + { + "SQL": "select * from t where a in (select s.a from t s) order by t.a", + "Best": "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t.a,test.s.a)->Projection" + }, + { + "SQL": "select * from t where exists (select s.a from t s where s.c in (select c from t as k where k.d = s.d) having sum(s.a) = t.a )", + "Best": "LeftHashJoin{TableReader(Table(t))->Projection->MergeSemiJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.s.c,test.k.c)(test.s.d,test.k.d)->Projection->StreamAgg}(cast(test.t.a),sel_agg_1)->Projection" + }, + { + "SQL": "select * from t where a in (select a from t) order by b", + "Best": "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t.a,test.t.a)->Projection->Sort" + }, + { + "SQL": "select t.c in (select count(*) from t s, t t1 where s.a = t.a and s.a = t1.a) from t", + "Best": "Apply{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->IndexJoin{TableReader(Table(t))->TableReader(Table(t)->Sel([eq(test.t1.a, test.t.a)]))}(test.s.a,test.t1.a)->StreamAgg}->Projection" + }, + { + "SQL": "select (select count(*) from t s, t t1 where s.a = t.a and s.a = t1.a) from t", + "Best": "LeftHashJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.s.a,test.t1.a)->StreamAgg}(test.t.a,test.s.a)->Projection->Projection" + }, + { + "SQL": "select (select count(*) from t s, t t1 where s.a = t.a and s.a = t1.a) from t order by t.a", + "Best": "LeftHashJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.s.a,test.t1.a)->StreamAgg}(test.t.a,test.s.a)->Projection->Sort->Projection" + } + ] + }, + { + "Name": "TestDAGPlanTopN", + "Cases": [ + { + "SQL": "select * from t t1 left join t t2 on t1.b = t2.b left join t t3 on t2.b = t3.b order by t1.a limit 1", + "Best": "LeftHashJoin{LeftHashJoin{TableReader(Table(t)->Limit)->Limit->TableReader(Table(t))}(test.t1.b,test.t2.b)->TopN([test.t1.a],0,1)->TableReader(Table(t))}(test.t2.b,test.t3.b)->TopN([test.t1.a],0,1)" + }, + { + "SQL": "select * from t t1 left join t t2 on t1.b = t2.b left join t t3 on t2.b = t3.b order by t1.b limit 1", + "Best": "LeftHashJoin{LeftHashJoin{TableReader(Table(t)->TopN([test.t1.b],0,1))->TopN([test.t1.b],0,1)->TableReader(Table(t))}(test.t1.b,test.t2.b)->TopN([test.t1.b],0,1)->TableReader(Table(t))}(test.t2.b,test.t3.b)->TopN([test.t1.b],0,1)" + }, + { + "SQL": "select * from t t1 left join t t2 on t1.b = t2.b left join t t3 on t2.b = t3.b limit 1", + "Best": "LeftHashJoin{LeftHashJoin{TableReader(Table(t)->Limit)->Limit->TableReader(Table(t))}(test.t1.b,test.t2.b)->Limit->TableReader(Table(t))}(test.t2.b,test.t3.b)->Limit" + }, + { + "SQL": "select * from t where b = 1 and c = 1 order by c limit 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t)->Sel([eq(test.t.b, 1)]))->Limit" + }, + { + "SQL": "select * from t where c = 1 order by c limit 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Limit, Table(t))" + }, + { + "SQL": "select * from t order by a limit 1", + "Best": "TableReader(Table(t)->Limit)->Limit" + }, + { + "SQL": "select c from t order by c limit 1", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Limit)->Limit" + } + ] + }, + { + "Name": "TestDAGPlanBuilderBasePhysicalPlan", + "Cases": [ + { + "SQL": "select * from t order by b limit 1 for update", + "Best": "TableReader(Table(t))->Lock->TopN([test.t.b],0,1)" + }, + { + "SQL": "update t set a = 5 where b < 1 order by d limit 1", + "Best": "TableReader(Table(t)->Sel([lt(test.t.b, 1)])->TopN([test.t.d],0,1))->TopN([test.t.d],0,1)->Update" + }, + { + "SQL": "update t set a = 5", + "Best": "TableReader(Table(t))->Update" + }, + { + "SQL": "delete /*+ TIDB_INLJ(t1, t2) */ t1 from t t1, t t2 where t1.c=t2.c", + "Best": "IndexJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)->Delete" + }, + { + "SQL": "delete /*+ TIDB_SMJ(t1, t2) */ from t1 using t t1, t t2 where t1.c=t2.c", + "Best": "MergeInnerJoin{IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))}(test.t1.c,test.t2.c)->Delete" + }, + { + "SQL": "update /*+ TIDB_SMJ(t1, t2) */ t t1, t t2 set t1.a=1, t2.a=1 where t1.a=t2.a", + "Best": "MergeInnerJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->Update" + }, + { + "SQL": "update /*+ TIDB_HJ(t1, t2) */ t t1, t t2 set t1.a=1, t2.a=1 where t1.a=t2.a", + "Best": "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->Update" + }, + { + "SQL": "delete from t where b < 1 order by d limit 1", + "Best": "TableReader(Table(t)->Sel([lt(test.t.b, 1)])->TopN([test.t.d],0,1))->TopN([test.t.d],0,1)->Delete" + }, + { + "SQL": "delete from t", + "Best": "TableReader(Table(t))->Delete" + }, + { + "SQL": "delete from t use index(c_d_e) where b = 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t)->Sel([eq(test.t.b, 1)]))->Delete" + }, + { + "SQL": "insert into t select * from t where b < 1 order by d limit 1", + "Best": "TableReader(Table(t)->Sel([lt(test.t.b, 1)])->TopN([test.t.d],0,1))->TopN([test.t.d],0,1)->Insert" + }, + { + "SQL": "insert into t (a, b, c, e, f, g) values(0,0,0,0,0,0)", + "Best": "Insert" + }, + { + "SQL": "select 1", + "Best": "Dual->Projection" + }, + { + "SQL": "select * from t where false", + "Best": "Dual" + }, + { + "SQL": "show tables", + "Best": "Show" + } + ] + }, + { + "Name": "TestDAGPlanBuilderUnion", + "Cases": [ + { + "SQL": "select * from t union all select * from t", + "Best": "UnionAll{TableReader(Table(t))->TableReader(Table(t))}" + }, + { + "SQL": "select * from t union all (select * from t) order by a ", + "Best": "UnionAll{TableReader(Table(t))->TableReader(Table(t))}->Sort" + }, + { + "SQL": "select * from t union all (select * from t) limit 1", + "Best": "UnionAll{TableReader(Table(t)->Limit)->Limit->TableReader(Table(t)->Limit)->Limit}->Limit" + }, + { + "SQL": "select a from t union all (select c from t) order by a limit 1", + "Best": "UnionAll{TableReader(Table(t)->Limit)->Limit->IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Limit)->Limit}->TopN([a],0,1)" + } + ] + }, + { + "Name": "TestDAGPlanBuilderUnionScan", + "Cases": null + }, + { + "Name": "TestDAGPlanBuilderAgg", + "Cases": [ + { + "SQL": "select distinct b from t", + "Best": "TableReader(Table(t)->HashAgg)->HashAgg" + }, + { + "SQL": "select count(*) from (select * from t order by b) t group by b", + "Best": "TableReader(Table(t))->Sort->StreamAgg" + }, + { + "SQL": "select count(*), x from (select b as bbb, a + 1 as x from (select * from t order by b) t) t group by bbb", + "Best": "TableReader(Table(t))->Sort->Projection->StreamAgg" + }, + { + "SQL": "select sum(a), avg(b + c) from t group by d", + "Best": "TableReader(Table(t)->HashAgg)->HashAgg" + }, + { + "SQL": "select sum(distinct a), avg(b + c) from t group by d", + "Best": "TableReader(Table(t))->Projection->HashAgg" + }, + { + "SQL": "select sum(e), avg(e + c) from t where c = 1 group by (c + d)", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]]->HashAgg)->HashAgg" + }, + { + "SQL": "select sum(e), avg(e + c) from t where c = 1 group by c", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]]->StreamAgg)->StreamAgg" + }, + { + "SQL": "select sum(e), avg(e + c) from t where c = 1 group by e", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]]->HashAgg)->HashAgg" + }, + { + "SQL": "select sum(e), avg(b + c) from t where c = 1 and e = 1 group by d", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t))->Projection->HashAgg" + }, + { + "SQL": "select sum(e), avg(b + c) from t where c = 1 and b = 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]], Table(t)->Sel([eq(test.t.b, 1)]))->Projection->StreamAgg" + }, + { + "SQL": "select sum(e) as k, avg(b + c) from t where c = 1 and b = 1 and e = 1 group by d order by k", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)]))->Projection->HashAgg->Sort" + }, + { + "SQL": "select sum(e) as k, avg(b + c) from t where c = 1 and b = 1 and e = 1 group by c order by k", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,1]]->Sel([eq(test.t.e, 1)]), Table(t)->Sel([eq(test.t.b, 1)]))->Projection->Projection->StreamAgg->Sort" + }, + { + "SQL": "select sum(to_base64(e)) from t where c = 1", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]])->Projection->StreamAgg" + }, + { + "SQL": "select (select count(1) k from t s where s.a = t.a having k != 0) from t", + "Best": "MergeLeftOuterJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t.a,test.s.a)->Projection->Projection" + }, + { + "SQL": "select sum(to_base64(e)) from t group by e,d,c order by c", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Projection->StreamAgg->Projection" + }, + { + "SQL": "select sum(e+1) from t group by e,d,c order by c", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->StreamAgg)->StreamAgg->Projection" + }, + { + "SQL": "select sum(to_base64(e)) from t group by e,d,c order by c,e", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Projection->StreamAgg->Sort->Projection" + }, + { + "SQL": "select sum(e+1) from t group by e,d,c order by c,e", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->StreamAgg)->StreamAgg->Sort->Projection" + }, + { + "SQL": "select count(*) from t group by g order by g limit 10", + "Best": "IndexReader(Index(t.g)[[NULL,+inf]]->StreamAgg)->StreamAgg->Limit->Projection" + }, + { + "SQL": "select count(*) from t group by g limit 10", + "Best": "IndexReader(Index(t.g)[[NULL,+inf]]->StreamAgg)->StreamAgg->Limit" + }, + { + "SQL": "select count(*) from t group by g order by g", + "Best": "IndexReader(Index(t.g)[[NULL,+inf]]->StreamAgg)->StreamAgg->Projection" + }, + { + "SQL": "select count(*) from t group by g order by g desc limit 1", + "Best": "IndexReader(Index(t.g)[[NULL,+inf]]->StreamAgg)->StreamAgg->Limit->Projection" + }, + { + "SQL": "select count(*) from t group by b order by b limit 10", + "Best": "TableReader(Table(t)->HashAgg)->HashAgg->TopN([test.t.b],0,10)->Projection" + }, + { + "SQL": "select count(*) from t group by b order by b", + "Best": "TableReader(Table(t)->HashAgg)->HashAgg->Sort->Projection" + }, + { + "SQL": "select count(*) from t group by b limit 10", + "Best": "TableReader(Table(t)->HashAgg)->HashAgg->Limit" + }, + { + "SQL": "select sum(a.g), sum(b.g) from t a join t b on a.g = b.g group by a.g", + "Best": "MergeInnerJoin{IndexReader(Index(t.g)[[NULL,+inf]])->IndexReader(Index(t.g)[[NULL,+inf]])}(test.a.g,test.b.g)->Projection->StreamAgg" + }, + { + "SQL": "select /*+ tidb_inlj(a,b) */ sum(a.g), sum(b.g) from t a join t b on a.g = b.g and a.g > 60 group by a.g order by a.g limit 1", + "Best": "IndexJoin{IndexReader(Index(t.g)[(60,+inf]])->IndexReader(Index(t.g)[[NULL,+inf]]->Sel([gt(test.b.g, 60)]))}(test.a.g,test.b.g)->Projection->StreamAgg->Limit->Projection" + }, + { + "SQL": "select sum(a.g), sum(b.g) from t a join t b on a.g = b.g and a.a>5 group by a.g order by a.g limit 1", + "Best": "IndexJoin{IndexReader(Index(t.g)[[NULL,+inf]]->Sel([gt(test.a.a, 5)]))->IndexReader(Index(t.g)[[NULL,+inf]])}(test.a.g,test.b.g)->Projection->StreamAgg->Limit->Projection" + }, + { + "SQL": "select sum(d) from t", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->StreamAgg)->StreamAgg" + } + ] + }, + { + "Name": "TestRefine", + "Cases": [ + { + "SQL": "select a from t where c is not null", + "Best": "IndexReader(Index(t.c_d_e)[[-inf,+inf]])->Projection" + }, + { + "SQL": "select a from t where c >= 4", + "Best": "IndexReader(Index(t.c_d_e)[[4,+inf]])->Projection" + }, + { + "SQL": "select a from t where c <= 4", + "Best": "IndexReader(Index(t.c_d_e)[[-inf,4]])->Projection" + }, + { + "SQL": "select a from t where c = 4 and d = 5 and e = 6", + "Best": "IndexReader(Index(t.c_d_e)[[4 5 6,4 5 6]])->Projection" + }, + { + "SQL": "select a from t where d = 4 and c = 5", + "Best": "IndexReader(Index(t.c_d_e)[[5 4,5 4]])->Projection" + }, + { + "SQL": "select a from t where c = 4 and e < 5", + "Best": "IndexReader(Index(t.c_d_e)[[4,4]]->Sel([lt(test.t.e, 5)]))->Projection" + }, + { + "SQL": "select a from t where c = 4 and d <= 5 and d > 3", + "Best": "IndexReader(Index(t.c_d_e)[(4 3,4 5]])->Projection" + }, + { + "SQL": "select a from t where d <= 5 and d > 3", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Sel([le(test.t.d, 5) gt(test.t.d, 3)]))->Projection" + }, + { + "SQL": "select a from t where c between 1 and 2", + "Best": "IndexReader(Index(t.c_d_e)[[1,2]])->Projection" + }, + { + "SQL": "select a from t where c not between 1 and 2", + "Best": "IndexReader(Index(t.c_d_e)[[-inf,1) (2,+inf]])->Projection" + }, + { + "SQL": "select a from t where c <= 5 and c >= 3 and d = 1", + "Best": "IndexReader(Index(t.c_d_e)[[3,5]]->Sel([eq(test.t.d, 1)]))->Projection" + }, + { + "SQL": "select a from t where c = 1 or c = 2 or c = 3", + "Best": "IndexReader(Index(t.c_d_e)[[1,3]])->Projection" + }, + { + "SQL": "select b from t where c = 1 or c = 2 or c = 3 or c = 4 or c = 5", + "Best": "IndexLookUp(Index(t.c_d_e)[[1,5]], Table(t))->Projection" + }, + { + "SQL": "select a from t where c = 5", + "Best": "IndexReader(Index(t.c_d_e)[[5,5]])->Projection" + }, + { + "SQL": "select a from t where c = 5 and b = 1", + "Best": "IndexLookUp(Index(t.c_d_e)[[5,5]], Table(t)->Sel([eq(test.t.b, 1)]))->Projection" + }, + { + "SQL": "select a from t where not a", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Sel([not(test.t.a)]))" + }, + { + "SQL": "select a from t where c in (1)", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]])->Projection" + }, + { + "SQL": "select a from t where c in ('1')", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]])->Projection" + }, + { + "SQL": "select a from t where c = 1.0", + "Best": "IndexReader(Index(t.c_d_e)[[1,1]])->Projection" + }, + { + "SQL": "select a from t where c in (1) and d > 3", + "Best": "IndexReader(Index(t.c_d_e)[(1 3,1 +inf]])->Projection" + }, + { + "SQL": "select a from t where c in (1, 2, 3) and (d > 3 and d < 4 or d > 5 and d < 6)", + "Best": "Dual->Projection" + }, + { + "SQL": "select a from t where c in (1, 2, 3) and (d > 2 and d < 4 or d > 5 and d < 7)", + "Best": "IndexReader(Index(t.c_d_e)[(1 2,1 4) (1 5,1 7) (2 2,2 4) (2 5,2 7) (3 2,3 4) (3 5,3 7)])->Projection" + }, + { + "SQL": "select a from t where c in (1, 2, 3)", + "Best": "IndexReader(Index(t.c_d_e)[[1,1] [2,2] [3,3]])->Projection" + }, + { + "SQL": "select a from t where c in (1, 2, 3) and d in (1,2) and e = 1", + "Best": "IndexReader(Index(t.c_d_e)[[1 1 1,1 1 1] [1 2 1,1 2 1] [2 1 1,2 1 1] [2 2 1,2 2 1] [3 1 1,3 1 1] [3 2 1,3 2 1]])->Projection" + }, + { + "SQL": "select a from t where d in (1, 2, 3)", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Sel([in(test.t.d, 1, 2, 3)]))->Projection" + }, + { + "SQL": "select a from t where c not in (1)", + "Best": "IndexReader(Index(t.c_d_e)[(NULL,1) (1,+inf]])->Projection" + }, + { + "SQL": "select a from t use index(c_d_e) where c != 1", + "Best": "IndexReader(Index(t.c_d_e)[[-inf,1) (1,+inf]])->Projection" + }, + { + "SQL": "select a from t where c_str like ''", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"\",\"\"]])->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc'", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"abc\",\"abc\"]])->Projection" + }, + { + "SQL": "select a from t where c_str not like 'abc'", + "Best": "IndexReader(Index(t.c_d_e_str)[[-inf,\"abc\") (\"abc\",+inf]])->Projection" + }, + { + "SQL": "select a from t where not (c_str like 'abc' or c_str like 'abd')", + "Best": "IndexReader(Index(t.c_d_e_str)[[-inf,\"abc\") (\"abc\",\"abd\") (\"abd\",+inf]])->Projection" + }, + { + "SQL": "select a from t where c_str like '_abc'", + "Best": "IndexReader(Index(t.c_d_e_str)[[NULL,+inf]]->Sel([like(test.t.c_str, _abc, 92)]))->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc%'", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"abc\",\"abd\")])->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc_'", + "Best": "IndexReader(Index(t.c_d_e_str)[(\"abc\",\"abd\")]->Sel([like(test.t.c_str, abc_, 92)]))->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc%af'", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"abc\",\"abd\")]->Sel([like(test.t.c_str, abc%af, 92)]))->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc\\_' escape ''", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"abc_\",\"abc_\"]])->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc\\_'", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"abc_\",\"abc_\"]])->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc\\\\_'", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"abc_\",\"abc_\"]])->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc\\_%'", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"abc_\",\"abc`\")])->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc=_%' escape '='", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"abc_\",\"abc`\")])->Projection" + }, + { + "SQL": "select a from t where c_str like 'abc\\__'", + "Best": "IndexReader(Index(t.c_d_e_str)[(\"abc_\",\"abc`\")]->Sel([like(test.t.c_str, abc\\__, 92)]))->Projection" + }, + { + "SQL": "select a from t where c_str like 123", + "Best": "IndexReader(Index(t.c_d_e_str)[[\"123\",\"123\"]])->Projection" + }, + { + "SQL": "select a from t where c = 1.9 and d > 3", + "Best": "Dual" + }, + { + "SQL": "select a from t where c < 1.1", + "Best": "IndexReader(Index(t.c_d_e)[[-inf,2)])->Projection" + }, + { + "SQL": "select a from t where c <= 1.9", + "Best": "IndexReader(Index(t.c_d_e)[[-inf,1]])->Projection" + }, + { + "SQL": "select a from t where c >= 1.1", + "Best": "IndexReader(Index(t.c_d_e)[[2,+inf]])->Projection" + }, + { + "SQL": "select a from t where c > 1.9", + "Best": "IndexReader(Index(t.c_d_e)[(1,+inf]])->Projection" + }, + { + "SQL": "select a from t where c = 123456789098765432101234", + "Best": "Dual" + }, + { + "SQL": "select a from t where c = 'hanfei'", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]])->Sel([eq(cast(test.t.c), cast(hanfei))])->Projection" + } + ] + }, + { + "Name": "TestAggEliminator", + "Cases": [ + { + "SQL": "select max(a) from t;", + "Best": "TableReader(Table(t)->Limit)->Limit->StreamAgg" + }, + { + "SQL": "select min(a) from t;", + "Best": "TableReader(Table(t)->Limit)->Limit->StreamAgg" + }, + { + "SQL": "select min(c_str) from t;", + "Best": "IndexReader(Index(t.c_d_e_str)[[-inf,+inf]]->Limit)->Limit->StreamAgg" + }, + { + "SQL": "select max(a), b from t;", + "Best": "TableReader(Table(t)->StreamAgg)->StreamAgg" + }, + { + "SQL": "select max(a+1) from t;", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->Sel([not(isnull(plus(test.t.a, 1)))])->TopN([plus(test.t.a, 1) true],0,1))->Projection->TopN([col_1 true],0,1)->Projection->Projection->StreamAgg" + }, + { + "SQL": "select max(a), min(a) from t;", + "Best": "IndexReader(Index(t.c_d_e)[[NULL,+inf]]->StreamAgg)->StreamAgg" + }, + { + "SQL": "select max(a) from t group by b;", + "Best": "TableReader(Table(t)->HashAgg)->HashAgg" + }, + { + "SQL": "select max(a) from (select t1.a from t t1 join t t2 on t1.a=t2.a) t", + "Best": "IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.a)->Limit->StreamAgg" + } + ] + }, + { + "Name": "TestUnmatchedTableInHint", + "Cases": [ + { + "SQL": "SELECT /*+ TIDB_SMJ(t3, t4) */ * from t t1, t t2 where t1.a = t2.a", + "Warning": "[planner:1815]There are no matching table names for (t3, t4) in optimizer hint /*+ TIDB_SMJ(t3, t4) */. Maybe you can use the table alias name" + }, + { + "SQL": "SELECT /*+ TIDB_HJ(t3, t4) */ * from t t1, t t2 where t1.a = t2.a", + "Warning": "[planner:1815]There are no matching table names for (t3, t4) in optimizer hint /*+ TIDB_HJ(t3, t4) */. Maybe you can use the table alias name" + }, + { + "SQL": "SELECT /*+ TIDB_INLJ(t3, t4) */ * from t t1, t t2 where t1.a = t2.a", + "Warning": "[planner:1815]There are no matching table names for (t3, t4) in optimizer hint /*+ TIDB_INLJ(t3, t4) */. Maybe you can use the table alias name" + }, + { + "SQL": "SELECT /*+ TIDB_SMJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.a", + "Warning": "" + }, + { + "SQL": "SELECT /*+ TIDB_SMJ(t3, t4) */ * from t t1, t t2, t t3 where t1.a = t2.a and t2.a = t3.a", + "Warning": "[planner:1815]There are no matching table names for (t4) in optimizer hint /*+ TIDB_SMJ(t3, t4) */. Maybe you can use the table alias name" + } + ] + }, + { + "Name": "TestIndexJoinHint", + "Cases": [ + { + "SQL": "select /*+ TIDB_INLJ(t1) */ t1.a, t2.a, t3.a from t t1, t t2, t t3 where t1.a = t2.a and t2.a = t3.a;", + "Best": "MergeInnerJoin{TableReader(Table(t))->IndexJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t2.a,test.t1.a)}(test.t3.a,test.t2.a)->Projection", + "Warning": "" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1) */ t1.b, t2.a from t t1, t t2 where t1.b = t2.a;", + "Best": "LeftHashJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])}(test.t1.b,test.t2.a)", + "Warning": "[planner:1815]Optimizer Hint /*+ TIDB_INLJ(t1) */ is inapplicable" + }, + { + "SQL": "select /*+ TIDB_INLJ(t2) */ t1.b, t2.a from t2 t1, t2 t2 where t1.b=t2.b and t2.c=-1;", + "Best": "IndexJoin{IndexReader(Index(t2.b)[[NULL,+inf]])->IndexReader(Index(t2.b_c)[[NULL,+inf]]->Sel([eq(test.t2.c, -1)]))}(test.t2.b,test.t1.b)->Projection", + "Warning": "[planner:1815]Optimizer Hint /*+ TIDB_INLJ(t2) */ is inapplicable" + } + ] + }, + { + "Name": "TestIndexJoinUnionScan", + "Cases": [ + { + "SQL": "select /*+ TIDB_INLJ(t2) */ * from t t1, t t2 where t1.a = t2.a", + "Best": "IndexJoin{TableReader(Table(t))->UnionScan([])->TableReader(Table(t))->UnionScan([])}(test.t1.a,test.t2.a)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1, t2) */ * from t t1, t t2 where t1.a = t2.c", + "Best": "IndexJoin{TableReader(Table(t))->UnionScan([])->IndexLookUp(Index(t.c_d_e)[[NULL,+inf]], Table(t))->UnionScan([])}(test.t1.a,test.t2.c)" + }, + { + "SQL": "select /*+ TIDB_INLJ(t1, t2) */ t1.a , t2.c from t t1, t t2 where t1.a = t2.c", + "Best": "IndexJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]])->UnionScan([])->IndexReader(Index(t.c_d_e)[[NULL,+inf]])->UnionScan([])}(test.t1.a,test.t2.c)->Projection" + } + ] + }, + { + "Name": "TestSemiJoinToInner", + "Cases": [ + { + "SQL": "select t1.a, (select count(t2.a) from t t2 where t2.g in (select t3.d from t t3 where t3.c = t1.a)) as agg_col from t t1;", + "Best": "Apply{TableReader(Table(t))->IndexJoin{IndexReader(Index(t.c_d_e)[[NULL,+inf]]->HashAgg)->HashAgg->IndexReader(Index(t.g)[[NULL,+inf]])}(test.t3.d,test.t2.g)}->StreamAgg" + } + ] + } +] diff --git a/planner/core/util.go b/planner/core/util.go index 1dd8fcd762413..4d6fd39973712 100644 --- a/planner/core/util.go +++ b/planner/core/util.go @@ -129,16 +129,30 @@ func (s *baseSchemaProducer) SetSchema(schema *expression.Schema) { s.schema = schema } +// Schema implements the Plan.Schema interface. +func (p *LogicalMaxOneRow) Schema() *expression.Schema { + s := p.Children()[0].Schema().Clone() + resetNotNullFlag(s, 0, s.Len()) + return s +} + func buildLogicalJoinSchema(joinType JoinType, join LogicalPlan) *expression.Schema { + leftSchema := join.Children()[0].Schema() switch joinType { case SemiJoin, AntiSemiJoin: - return join.Children()[0].Schema().Clone() + return leftSchema.Clone() case LeftOuterSemiJoin, AntiLeftOuterSemiJoin: - newSchema := join.Children()[0].Schema().Clone() + newSchema := leftSchema.Clone() newSchema.Append(join.Schema().Columns[join.Schema().Len()-1]) return newSchema } - return expression.MergeSchema(join.Children()[0].Schema(), join.Children()[1].Schema()) + newSchema := expression.MergeSchema(leftSchema, join.Children()[1].Schema()) + if joinType == LeftOuterJoin { + resetNotNullFlag(newSchema, leftSchema.Len(), newSchema.Len()) + } else if joinType == RightOuterJoin { + resetNotNullFlag(newSchema, 0, leftSchema.Len()) + } + return newSchema } func buildPhysicalJoinSchema(joinType JoinType, join PhysicalPlan) *expression.Schema { @@ -152,3 +166,27 @@ func buildPhysicalJoinSchema(joinType JoinType, join PhysicalPlan) *expression.S } return expression.MergeSchema(join.Children()[0].Schema(), join.Children()[1].Schema()) } + +// GetStatsInfo gets the statistics info from a physical plan tree. +func GetStatsInfo(i interface{}) map[string]uint64 { + p := i.(Plan) + var physicalPlan PhysicalPlan + switch x := p.(type) { + case *Insert: + physicalPlan = x.SelectPlan + case *Update: + physicalPlan = x.SelectPlan + case *Delete: + physicalPlan = x.SelectPlan + case PhysicalPlan: + physicalPlan = x + } + + if physicalPlan == nil { + return nil + } + + statsInfos := make(map[string]uint64) + statsInfos = CollectPlanStatsVersion(physicalPlan, statsInfos) + return statsInfos +} diff --git a/planner/implementation/base_test.go b/planner/implementation/base_test.go index a27b92df96df8..589d8e65207a3 100644 --- a/planner/implementation/base_test.go +++ b/planner/implementation/base_test.go @@ -40,7 +40,7 @@ type testImplSuite struct { func (s *testImplSuite) SetUpSuite(c *C) { testleak.BeforeTest() - s.is = infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockTable()}) + s.is = infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockSignedTable()}) s.sctx = plannercore.MockContext() s.Parser = parser.New() } diff --git a/planner/memo/group_test.go b/planner/memo/group_test.go index 46fe680be1d48..51ccf1f029096 100644 --- a/planner/memo/group_test.go +++ b/planner/memo/group_test.go @@ -42,7 +42,7 @@ type testMemoSuite struct { func (s *testMemoSuite) SetUpSuite(c *C) { testleak.BeforeTest() - s.is = infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockTable()}) + s.is = infoschema.MockInfoSchema([]*model.TableInfo{plannercore.MockSignedTable()}) s.sctx = plannercore.MockContext() s.Parser = parser.New() } diff --git a/planner/optimize.go b/planner/optimize.go index 820983d505b97..a0f3008802dc0 100644 --- a/planner/optimize.go +++ b/planner/optimize.go @@ -14,6 +14,8 @@ package planner import ( + "context" + "github.com/pingcap/parser/ast" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/planner/cascades" @@ -24,27 +26,27 @@ import ( // Optimize does optimization and creates a Plan. // The node must be prepared first. -func Optimize(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (plannercore.Plan, error) { - fp := plannercore.TryFastPlan(ctx, node) +func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (plannercore.Plan, error) { + fp := plannercore.TryFastPlan(sctx, node) if fp != nil { return fp, nil } // build logical plan - ctx.GetSessionVars().PlanID = 0 - ctx.GetSessionVars().PlanColumnID = 0 - builder := plannercore.NewPlanBuilder(ctx, is) - p, err := builder.Build(node) + sctx.GetSessionVars().PlanID = 0 + sctx.GetSessionVars().PlanColumnID = 0 + builder := plannercore.NewPlanBuilder(sctx, is) + p, err := builder.Build(ctx, node) if err != nil { return nil, err } - ctx.GetSessionVars().StmtCtx.Tables = builder.GetDBTableInfo() - activeRoles := ctx.GetSessionVars().ActiveRoles + sctx.GetSessionVars().StmtCtx.Tables = builder.GetDBTableInfo() + activeRoles := sctx.GetSessionVars().ActiveRoles // Check privilege. Maybe it's better to move this to the Preprocess, but // we need the table information to check privilege, which is collected // into the visitInfo in the logical plan builder. - if pm := privilege.GetPrivilegeManager(ctx); pm != nil { + if pm := privilege.GetPrivilegeManager(sctx); pm != nil { if err := plannercore.CheckPrivilege(activeRoles, pm, builder.GetVisitInfo()); err != nil { return nil, err } @@ -52,7 +54,7 @@ func Optimize(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) ( // Handle the execute statement. if execPlan, ok := p.(*plannercore.Execute); ok { - err := execPlan.OptimizePreparedPlan(ctx, is) + err := execPlan.OptimizePreparedPlan(ctx, sctx, is) return p, err } @@ -63,10 +65,10 @@ func Optimize(ctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) ( } // Handle the logical plan statement, use cascades planner if enabled. - if ctx.GetSessionVars().EnableCascadesPlanner { - return cascades.FindBestPlan(ctx, logic) + if sctx.GetSessionVars().EnableCascadesPlanner { + return cascades.FindBestPlan(sctx, logic) } - return plannercore.DoOptimize(builder.GetOptFlag(), logic) + return plannercore.DoOptimize(ctx, builder.GetOptFlag(), logic) } func init() { diff --git a/plugin/audit.go b/plugin/audit.go index 8ad556495ac62..f1471562fc657 100644 --- a/plugin/audit.go +++ b/plugin/audit.go @@ -16,7 +16,6 @@ package plugin import ( "context" - "github.com/pingcap/parser/auth" "github.com/pingcap/tidb/sessionctx/variable" ) @@ -77,7 +76,7 @@ type AuditManifest struct { Manifest // OnConnectionEvent will be called when TiDB receive or disconnect from client. // return error will ignore and close current connection. - OnConnectionEvent func(ctx context.Context, identity *auth.UserIdentity, event ConnectionEvent, info *variable.ConnectionInfo) error + OnConnectionEvent func(ctx context.Context, event ConnectionEvent, info *variable.ConnectionInfo) error // OnGeneralEvent will be called during TiDB execution. OnGeneralEvent func(ctx context.Context, sctx *variable.SessionVars, event GeneralEvent, cmd string) // OnGlobalVariableEvent will be called when Change GlobalVariable. @@ -85,3 +84,8 @@ type AuditManifest struct { // OnParseEvent will be called around parse logic. OnParseEvent func(ctx context.Context, sctx *variable.SessionVars, event ParseEvent) error } + +const ( + // ExecStartTimeCtxKey indicates stmt start execution time. + ExecStartTimeCtxKey = "ExecStartTime" +) diff --git a/plugin/const_test.go b/plugin/const_test.go new file mode 100644 index 0000000000000..dd366b41d2c4e --- /dev/null +++ b/plugin/const_test.go @@ -0,0 +1,42 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "fmt" + "testing" +) + +func TestConstToString(t *testing.T) { + kinds := map[fmt.Stringer]string{ + Audit: "Audit", + Authentication: "Authentication", + Schema: "Schema", + Daemon: "Daemon", + Uninitialized: "Uninitialized", + Ready: "Ready", + Dying: "Dying", + Disable: "Disable", + Connected: "Connected", + Disconnect: "Disconnect", + ChangeUser: "ChangeUser", + PreAuth: "PreAuth", + ConnectionEvent(byte(15)): "", + } + for key, value := range kinds { + if key.String() != value { + t.Errorf("kind %s != %s", key.String(), kinds) + } + } +} diff --git a/plugin/helper_test.go b/plugin/helper_test.go new file mode 100644 index 0000000000000..1bb3fc71420ec --- /dev/null +++ b/plugin/helper_test.go @@ -0,0 +1,54 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import "testing" + +func TestPluginDeclare(t *testing.T) { + auditRaw := &AuditManifest{Manifest: Manifest{}} + auditExport := ExportManifest(auditRaw) + audit2 := DeclareAuditManifest(auditExport) + if audit2 != auditRaw { + t.Errorf("declare audit fail") + } + + authRaw := &AuthenticationManifest{Manifest: Manifest{}} + authExport := ExportManifest(authRaw) + auth2 := DeclareAuthenticationManifest(authExport) + if auth2 != authRaw { + t.Errorf("declare auth fail") + } + + schemaRaw := &SchemaManifest{Manifest: Manifest{}} + schemaExport := ExportManifest(schemaRaw) + schema2 := DeclareSchemaManifest(schemaExport) + if schema2 != schemaRaw { + t.Errorf("declare schema fail") + } + + daemonRaw := &DaemonManifest{Manifest: Manifest{}} + daemonExport := ExportManifest(daemonRaw) + daemon2 := DeclareDaemonManifest(daemonExport) + if daemon2 != daemonRaw { + t.Errorf("declare daemon fail") + } +} + +func TestDecode(t *testing.T) { + failID := ID("fail") + _, _, err := failID.Decode() + if err == nil { + t.Errorf("'fail' should not decode success") + } +} diff --git a/plugin/plugin.go b/plugin/plugin.go index ee5a7d68c6de6..3821ccb401d3e 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -49,8 +49,9 @@ type plugins struct { // clone deep copies plugins info. func (p *plugins) clone() *plugins { np := &plugins{ - plugins: make(map[Kind][]Plugin, len(p.plugins)), - versions: make(map[string]uint16, len(p.versions)), + plugins: make(map[Kind][]Plugin, len(p.plugins)), + versions: make(map[string]uint16, len(p.versions)), + dyingPlugins: make([]Plugin, len(p.dyingPlugins)), } for key, value := range p.plugins { np.plugins[key] = append([]Plugin(nil), value...) @@ -94,39 +95,31 @@ type Config struct { // Plugin presents a TiDB plugin. type Plugin struct { *Manifest - library *gplugin.Plugin - State State - Path string + library *gplugin.Plugin + State State + Path string + Disabled uint32 } -type validateMode int - -const ( - initMode validateMode = iota - reloadMode -) +// StateValue returns readable state string. +func (p *Plugin) StateValue() string { + flag := "enable" + if atomic.LoadUint32(&p.Disabled) == 1 { + flag = "disable" + } + return p.State.String() + "-" + flag +} -func (p *Plugin) validate(ctx context.Context, tiPlugins *plugins, mode validateMode) error { - if mode == reloadMode { - var oldPlugin *Plugin - for i, item := range tiPlugins.plugins[p.Kind] { - if item.Name == p.Name { - oldPlugin = &tiPlugins.plugins[p.Kind][i] - break - } - } - if oldPlugin == nil { - return errUnsupportedReloadPlugin.GenWithStackByArgs(p.Name) - } - if len(p.SysVars) != len(oldPlugin.SysVars) { - return errUnsupportedReloadPluginVar.GenWithStackByArgs("") - } - for varName, varVal := range p.SysVars { - if oldPlugin.SysVars[varName] == nil || *oldPlugin.SysVars[varName] != *varVal { - return errUnsupportedReloadPluginVar.GenWithStackByArgs(varVal) - } - } +// DisableFlag changes the disable flag of plugin. +func (p *Plugin) DisableFlag(disable bool) { + if disable { + atomic.StoreUint32(&p.Disabled, 1) + } else { + atomic.StoreUint32(&p.Disabled, 0) } +} + +func (p *Plugin) validate(ctx context.Context, tiPlugins *plugins) error { if p.RequireVersion != nil { for component, reqVer := range p.RequireVersion { if ver, ok := tiPlugins.versions[component]; !ok || ver < reqVer { @@ -197,7 +190,7 @@ func Load(ctx context.Context, cfg Config) (err error) { // Cross validate & Load plugins. for kind := range tiPlugins.plugins { for i := range tiPlugins.plugins[kind] { - if err = tiPlugins.plugins[kind][i].validate(ctx, tiPlugins, initMode); err != nil { + if err = tiPlugins.plugins[kind][i].validate(ctx, tiPlugins); err != nil { if cfg.SkipWhenFail { logutil.Logger(ctx).Warn("validate plugin fail and disable plugin", zap.String("plugin", tiPlugins.plugins[kind][i].Name), zap.Error(err)) @@ -251,6 +244,7 @@ func Init(ctx context.Context, cfg Config) (err error) { path: pluginWatchPrefix + tiPlugins.plugins[kind][i].Name, etcd: cfg.EtcdClient, manifest: tiPlugins.plugins[kind][i].Manifest, + plugin: &tiPlugins.plugins[kind][i], } tiPlugins.plugins[kind][i].flushWatcher = watcher go util.WithRecovery(watcher.watchLoop, nil) @@ -267,6 +261,7 @@ type flushWatcher struct { path string etcd *clientv3.Client manifest *Manifest + plugin *Plugin } func (w *flushWatcher) watchLoop() { @@ -276,7 +271,16 @@ func (w *flushWatcher) watchLoop() { case <-w.ctx.Done(): return case <-watchChan: - err := w.manifest.OnFlush(w.ctx, w.manifest) + disabled, err := w.getPluginDisabledFlag() + if err != nil { + logutil.Logger(context.Background()).Error("get plugin disabled flag failure", zap.String("plugin", w.manifest.Name), zap.Error(err)) + } + if disabled { + atomic.StoreUint32(&w.manifest.flushWatcher.plugin.Disabled, 1) + } else { + atomic.StoreUint32(&w.manifest.flushWatcher.plugin.Disabled, 0) + } + err = w.manifest.OnFlush(w.ctx, w.manifest) if err != nil { logutil.Logger(context.Background()).Error("notify plugin flush event failed", zap.String("plugin", w.manifest.Name), zap.Error(err)) } @@ -284,26 +288,39 @@ func (w *flushWatcher) watchLoop() { } } -func loadOne(dir string, pluginID ID) (plugin Plugin, err error) { - plugin.Path = filepath.Join(dir, string(pluginID)+LibrarySuffix) - plugin.library, err = gplugin.Open(plugin.Path) +func (w *flushWatcher) getPluginDisabledFlag() (bool, error) { + if w == nil || w.etcd == nil { + return true, errors.New("etcd is need to get plugin enable status") + } + resp, err := w.etcd.Get(context.Background(), w.manifest.flushWatcher.path) if err != nil { - err = errors.Trace(err) - return + return true, errors.Trace(err) } - manifestSym, err := plugin.library.Lookup(ManifestSymbol) + if len(resp.Kvs) == 0 { + return false, nil + } + return string(resp.Kvs[0].Value) == "1", nil +} + +type loadFn func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) + +var testHook *struct { + loadOne loadFn +} + +func loadOne(dir string, pluginID ID) (plugin Plugin, err error) { + pName, pVersion, err := pluginID.Decode() if err != nil { err = errors.Trace(err) return } - manifest, ok := manifestSym.(func() *Manifest) - if !ok { - err = errInvalidPluginManifest.GenWithStackByArgs(string(pluginID)) - return + var manifest func() *Manifest + if testHook == nil { + manifest, err = loadManifestByGoPlugin(&plugin, dir, pluginID) + } else { + manifest, err = testHook.loadOne(&plugin, dir, pluginID) } - pName, pVersion, err := pluginID.Decode() if err != nil { - err = errors.Trace(err) return } plugin.Manifest = manifest() @@ -318,6 +335,27 @@ func loadOne(dir string, pluginID ID) (plugin Plugin, err error) { return } +func loadManifestByGoPlugin(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) { + plugin.Path = filepath.Join(dir, string(pluginID)+LibrarySuffix) + plugin.library, err = gplugin.Open(plugin.Path) + if err != nil { + err = errors.Trace(err) + return + } + manifestSym, err := plugin.library.Lookup(ManifestSymbol) + if err != nil { + err = errors.Trace(err) + return + } + var ok bool + manifest, ok = manifestSym.(func() *Manifest) + if !ok { + err = errInvalidPluginManifest.GenWithStackByArgs(string(pluginID)) + return + } + return +} + // Shutdown cleanups all plugin resources. // Notice: it just cleanups the resource of plugin, but cannot unload plugins(limited by go plugin). func Shutdown(ctx context.Context) { @@ -332,6 +370,9 @@ func Shutdown(ctx context.Context) { if p.flushWatcher != nil { p.flushWatcher.cancel() } + if p.OnShutdown == nil { + continue + } if err := p.OnShutdown(ctx, p.Manifest); err != nil { logutil.Logger(ctx).Error("call OnShutdown for failure", zap.String("plugin", p.Name), zap.Error(err)) @@ -369,6 +410,9 @@ func ForeachPlugin(kind Kind, fn func(plugin *Plugin) error) error { if p.State != Ready { continue } + if atomic.LoadUint32(&p.Disabled) == 1 { + continue + } err := fn(p) if err != nil { return err @@ -377,6 +421,21 @@ func ForeachPlugin(kind Kind, fn func(plugin *Plugin) error) error { return nil } +// IsEnable checks plugin's enable state. +func IsEnable(kind Kind) bool { + plugins := pluginGlobal.plugins() + if plugins == nil { + return false + } + for i := range plugins.plugins[kind] { + p := &plugins.plugins[kind][i] + if p.State == Ready && atomic.LoadUint32(&p.Disabled) != 1 { + return true + } + } + return false +} + // GetAll finds and returns all plugins. func GetAll() map[Kind][]Plugin { plugins := pluginGlobal.plugins() @@ -392,7 +451,25 @@ func NotifyFlush(dom *domain.Domain, pluginName string) error { if p == nil || p.Manifest.flushWatcher == nil || p.State != Ready { return errors.Errorf("plugin %s doesn't exists or unsupported flush or doesn't start with PD", pluginName) } - _, err := dom.GetEtcdClient().KV.Put(context.Background(), p.Manifest.flushWatcher.path, "") + _, err := dom.GetEtcdClient().KV.Put(context.Background(), p.Manifest.flushWatcher.path, strconv.Itoa(int(p.Disabled))) + if err != nil { + return err + } + return nil +} + +// ChangeDisableFlagAndFlush changes plugin disable flag and notify other nodes to do same change. +func ChangeDisableFlagAndFlush(dom *domain.Domain, pluginName string, disable bool) error { + p := getByName(pluginName) + if p == nil || p.Manifest.flushWatcher == nil || p.State != Ready { + return errors.Errorf("plugin %s doesn't exists or unsupported flush or doesn't start with PD", pluginName) + } + disableInt := uint32(0) + if disable { + disableInt = 1 + } + atomic.StoreUint32(&p.Disabled, disableInt) + _, err := dom.GetEtcdClient().KV.Put(context.Background(), p.Manifest.flushWatcher.path, strconv.Itoa(int(disableInt))) if err != nil { return err } diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go new file mode 100644 index 0000000000000..0f5acb6b26616 --- /dev/null +++ b/plugin/plugin_test.go @@ -0,0 +1,281 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "context" + "io" + "strconv" + "testing" + + "github.com/pingcap/tidb/sessionctx/variable" +) + +func TestLoadPluginSuccess(t *testing.T) { + ctx := context.Background() + + pluginName := "tplugin" + pluginVersion := uint16(1) + pluginSign := pluginName + "-" + strconv.Itoa(int(pluginVersion)) + + cfg := Config{ + Plugins: []string{pluginSign}, + PluginDir: "", + GlobalSysVar: &variable.SysVars, + PluginVarNames: &variable.PluginVarNames, + EnvVersion: map[string]uint16{"go": 1112}, + } + + // setup load test hook. + testHook = &struct{ loadOne loadFn }{loadOne: func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) { + return func() *Manifest { + m := &AuditManifest{ + Manifest: Manifest{ + Kind: Authentication, + Name: pluginName, + Version: pluginVersion, + SysVars: map[string]*variable.SysVar{pluginName + "_key": {Scope: variable.ScopeGlobal, Name: pluginName + "_key", Value: "v1"}}, + OnInit: func(ctx context.Context, manifest *Manifest) error { + return nil + }, + OnShutdown: func(ctx context.Context, manifest *Manifest) error { + return nil + }, + Validate: func(ctx context.Context, manifest *Manifest) error { + return nil + }, + }, + OnGeneralEvent: func(ctx context.Context, sctx *variable.SessionVars, event GeneralEvent, cmd string) { + }, + } + return ExportManifest(m) + }, nil + }} + defer func() { + testHook = nil + }() + + // trigger load. + err := Load(ctx, cfg) + if err != nil { + t.Errorf("load plugin [%s] fail", pluginSign) + } + + err = Init(ctx, cfg) + if err != nil { + t.Errorf("init plugin [%s] fail", pluginSign) + } + + // load all. + ps := GetAll() + if len(ps) != 1 { + t.Errorf("loaded plugins is empty") + } + + // find plugin by type and name + p := Get(Authentication, "tplugin") + if p == nil { + t.Errorf("tplugin can not be load") + } + p = Get(Authentication, "tplugin2") + if p != nil { + t.Errorf("found miss plugin") + } + p = getByName("tplugin") + if p == nil { + t.Errorf("can not find miss plugin") + } + + // foreach plugin + err = ForeachPlugin(Authentication, func(plugin *Plugin) error { + return nil + }) + if err != nil { + t.Errorf("foreach error %v", err) + } + err = ForeachPlugin(Authentication, func(plugin *Plugin) error { + return io.EOF + }) + if err != io.EOF { + t.Errorf("foreach should return EOF error") + } + + Shutdown(ctx) +} + +func TestLoadPluginSkipError(t *testing.T) { + ctx := context.Background() + + pluginName := "tplugin" + pluginVersion := uint16(1) + pluginSign := pluginName + "-" + strconv.Itoa(int(pluginVersion)) + + cfg := Config{ + Plugins: []string{pluginSign, pluginSign, "notExists-2"}, + PluginDir: "", + PluginVarNames: &variable.PluginVarNames, + EnvVersion: map[string]uint16{"go": 1112}, + SkipWhenFail: true, + } + + // setup load test hook. + testHook = &struct{ loadOne loadFn }{loadOne: func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) { + return func() *Manifest { + m := &AuditManifest{ + Manifest: Manifest{ + Kind: Audit, + Name: pluginName, + Version: pluginVersion, + SysVars: map[string]*variable.SysVar{pluginName + "_key": {Scope: variable.ScopeGlobal, Name: pluginName + "_key", Value: "v1"}}, + OnInit: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + OnShutdown: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + Validate: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + }, + OnGeneralEvent: func(ctx context.Context, sctx *variable.SessionVars, event GeneralEvent, cmd string) { + }, + } + return ExportManifest(m) + }, nil + }} + defer func() { + testHook = nil + }() + + // trigger load. + err := Load(ctx, cfg) + if err != nil { + t.Errorf("load plugin [%s] fail %v", pluginSign, err) + } + + err = Init(ctx, cfg) + if err != nil { + t.Errorf("init plugin [%s] fail", pluginSign) + } + + // load all. + ps := GetAll() + if len(ps) != 1 { + t.Errorf("loaded plugins is empty") + } + + // find plugin by type and name + p := Get(Audit, "tplugin") + if p == nil { + t.Errorf("tplugin can not be load") + } + p = Get(Audit, "tplugin2") + if p != nil { + t.Errorf("found miss plugin") + } + p = getByName("tplugin") + if p == nil { + t.Errorf("can not find miss plugin") + } + p = getByName("not exists") + if p != nil { + t.Errorf("got not exists plugin") + } + + // foreach plugin + readyCount := 0 + err = ForeachPlugin(Authentication, func(plugin *Plugin) error { + readyCount++ + return nil + }) + if err != nil { + t.Errorf("foreach meet error %v", err) + } + if readyCount != 0 { + t.Errorf("validate fail can be load but no ready") + } + + Shutdown(ctx) +} + +func TestLoadFail(t *testing.T) { + ctx := context.Background() + + pluginName := "tplugin" + pluginVersion := uint16(1) + pluginSign := pluginName + "-" + strconv.Itoa(int(pluginVersion)) + + cfg := Config{ + Plugins: []string{pluginSign, pluginSign, "notExists-2"}, + PluginDir: "", + PluginVarNames: &variable.PluginVarNames, + EnvVersion: map[string]uint16{"go": 1112}, + SkipWhenFail: false, + } + + // setup load test hook. + testHook = &struct{ loadOne loadFn }{loadOne: func(plugin *Plugin, dir string, pluginID ID) (manifest func() *Manifest, err error) { + return func() *Manifest { + m := &AuditManifest{ + Manifest: Manifest{ + Kind: Audit, + Name: pluginName, + Version: pluginVersion, + SysVars: map[string]*variable.SysVar{pluginName + "_key": {Scope: variable.ScopeGlobal, Name: pluginName + "_key", Value: "v1"}}, + OnInit: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + OnShutdown: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + Validate: func(ctx context.Context, manifest *Manifest) error { + return io.EOF + }, + }, + OnGeneralEvent: func(ctx context.Context, sctx *variable.SessionVars, event GeneralEvent, cmd string) { + }, + } + return ExportManifest(m) + }, nil + }} + defer func() { + testHook = nil + }() + + err := Load(ctx, cfg) + if err == nil { + t.Errorf("load plugin should fail") + } +} + +func TestPluginsClone(t *testing.T) { + ps := &plugins{ + plugins: map[Kind][]Plugin{ + Audit: {{}}, + }, + versions: map[string]uint16{ + "whitelist": 1, + }, + dyingPlugins: []Plugin{{}}, + } + cps := ps.clone() + ps.dyingPlugins = append(ps.dyingPlugins, Plugin{}) + ps.versions["w"] = 2 + as := ps.plugins[Audit] + ps.plugins[Audit] = append(as, Plugin{}) + + if len(cps.plugins) != 1 || len(cps.plugins[Audit]) != 1 || len(cps.versions) != 1 || len(cps.dyingPlugins) != 1 { + t.Errorf("clone plugins failure") + } +} diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index 238b9e930545f..699bfe616dacc 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -380,7 +380,7 @@ func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string, defer terror.Call(rs.Close) fs := rs.Fields() - req := rs.NewRecordBatch() + req := rs.NewChunk() for { err = rs.Next(context.TODO(), req) if err != nil { @@ -389,7 +389,7 @@ func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string, if req.NumRows() == 0 { return nil } - it := chunk.NewIterator4Chunk(req.Chunk) + it := chunk.NewIterator4Chunk(req) for row := it.Begin(); row != it.End(); row = it.Next() { err = decodeTableRow(row, fs) if err != nil { @@ -399,7 +399,7 @@ func (p *MySQLPrivilege) loadTable(sctx sessionctx.Context, sql string, // NOTE: decodeTableRow decodes data from a chunk Row, that is a shallow copy. // The result will reference memory in the chunk, so the chunk must not be reused // here, otherwise some werid bug will happen! - req.Chunk = chunk.Renew(req.Chunk, sctx.GetSessionVars().MaxChunkSize) + req = chunk.Renew(req, sctx.GetSessionVars().MaxChunkSize) } } @@ -848,13 +848,19 @@ func (p *MySQLPrivilege) showGrants(user, host string, roles []*auth.RoleIdentit edgeTable, ok := p.RoleGraph[graphKey] g = "" if ok { + sortedRes := make([]string, 0, 10) for k := range edgeTable.roleList { role := strings.Split(k, "@") roleName, roleHost := role[0], role[1] - if g != "" { + tmp := fmt.Sprintf("'%s'@'%s'", roleName, roleHost) + sortedRes = append(sortedRes, tmp) + } + sort.Strings(sortedRes) + for i, r := range sortedRes { + g += r + if i != len(sortedRes)-1 { g += ", " } - g += fmt.Sprintf("'%s'@'%s'", roleName, roleHost) } s := fmt.Sprintf(`GRANT %s TO '%s'@'%s'`, g, user, host) gs = append(gs, s) diff --git a/privilege/privileges/cache_test.go b/privilege/privileges/cache_test.go index 6a3ffbdfa9fc1..b89d4c85f5e28 100644 --- a/privilege/privileges/cache_test.go +++ b/privilege/privileges/cache_test.go @@ -35,7 +35,7 @@ type testCacheSuite struct { func (s *testCacheSuite) SetUpSuite(c *C) { store, err := mockstore.NewMockTikvStore() session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() c.Assert(err, IsNil) s.domain, err = session.BootstrapSession(store) c.Assert(err, IsNil) @@ -276,7 +276,7 @@ func (s *testCacheSuite) TestAbnormalMySQLTable(c *C) { c.Assert(err, IsNil) defer store.Close() session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() dom, err := session.BootstrapSession(store) c.Assert(err, IsNil) diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index b9a74b09df24f..01956064882b6 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -52,6 +52,11 @@ func (p *UserPrivileges) RequestVerification(activeRoles []*auth.RoleIdentity, d // Skip check for INFORMATION_SCHEMA database. // See https://dev.mysql.com/doc/refman/5.7/en/information-schema.html if strings.EqualFold(db, "INFORMATION_SCHEMA") { + switch priv { + case mysql.CreatePriv, mysql.AlterPriv, mysql.DropPriv, mysql.IndexPriv, mysql.CreateViewPriv, + mysql.InsertPriv, mysql.UpdatePriv, mysql.DeletePriv: + return false + } return true } @@ -185,6 +190,9 @@ func (p *UserPrivileges) UserPrivilegesTable() [][]types.Datum { // ShowGrants implements privilege.Manager ShowGrants interface. func (p *UserPrivileges) ShowGrants(ctx sessionctx.Context, user *auth.UserIdentity, roles []*auth.RoleIdentity) (grants []string, err error) { + if SkipWithGrant { + return nil, errNonexistingGrant.GenWithStackByArgs("root", "%") + } mysqlPrivilege := p.Handle.Get() u := user.Username h := user.Hostname @@ -202,6 +210,9 @@ func (p *UserPrivileges) ShowGrants(ctx sessionctx.Context, user *auth.UserIdent // ActiveRoles implements privilege.Manager ActiveRoles interface. func (p *UserPrivileges) ActiveRoles(ctx sessionctx.Context, roleList []*auth.RoleIdentity) (bool, string) { + if SkipWithGrant { + return true, "" + } mysqlPrivilege := p.Handle.Get() u := p.user h := p.host @@ -218,6 +229,9 @@ func (p *UserPrivileges) ActiveRoles(ctx sessionctx.Context, roleList []*auth.Ro // FindEdge implements privilege.Manager FindRelationship interface. func (p *UserPrivileges) FindEdge(ctx sessionctx.Context, role *auth.RoleIdentity, user *auth.UserIdentity) bool { + if SkipWithGrant { + return false + } mysqlPrivilege := p.Handle.Get() ok := mysqlPrivilege.FindRole(user.Username, user.Hostname, role) if !ok { @@ -229,6 +243,9 @@ func (p *UserPrivileges) FindEdge(ctx sessionctx.Context, role *auth.RoleIdentit // GetDefaultRoles returns all default roles for certain user. func (p *UserPrivileges) GetDefaultRoles(user, host string) []*auth.RoleIdentity { + if SkipWithGrant { + return make([]*auth.RoleIdentity, 0, 10) + } mysqlPrivilege := p.Handle.Get() ret := mysqlPrivilege.getDefaultRoles(user, host) return ret @@ -236,6 +253,10 @@ func (p *UserPrivileges) GetDefaultRoles(user, host string) []*auth.RoleIdentity // GetAllRoles return all roles of user. func (p *UserPrivileges) GetAllRoles(user, host string) []*auth.RoleIdentity { + if SkipWithGrant { + return make([]*auth.RoleIdentity, 0, 10) + } + mysqlPrivilege := p.Handle.Get() return mysqlPrivilege.getAllRoles(user, host) } diff --git a/privilege/privileges/privileges_test.go b/privilege/privileges/privileges_test.go index 82b80cbc8c752..6ada150296ddb 100644 --- a/privilege/privileges/privileges_test.go +++ b/privilege/privileges/privileges_test.go @@ -129,6 +129,24 @@ func (s *testPrivilegeSuite) TestCheckDBPrivilege(c *C) { c.Assert(pc.RequestVerification(activeRoles, "test", "", "", mysql.UpdatePriv), IsTrue) } +func (s *testPrivilegeSuite) TestCheckPointGetDBPrivilege(c *C) { + rootSe := newSession(c, s.store, s.dbName) + mustExec(c, rootSe, `CREATE USER 'tester'@'localhost';`) + mustExec(c, rootSe, `GRANT SELECT,UPDATE ON test.* TO 'tester'@'localhost';`) + mustExec(c, rootSe, `flush privileges;`) + mustExec(c, rootSe, `create database test2`) + mustExec(c, rootSe, `create table test2.t(id int, v int, primary key(id))`) + mustExec(c, rootSe, `insert into test2.t(id, v) values(1, 1)`) + + se := newSession(c, s.store, s.dbName) + c.Assert(se.Auth(&auth.UserIdentity{Username: "tester", Hostname: "localhost"}, nil, nil), IsTrue) + mustExec(c, se, `use test;`) + _, err := se.Execute(context.Background(), `select * from test2.t where id = 1`) + c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue) + _, err = se.Execute(context.Background(), "update test2.t set v = 2 where id = 1") + c.Assert(terror.ErrorEqual(err, core.ErrTableaccessDenied), IsTrue) +} + func (s *testPrivilegeSuite) TestCheckTablePrivilege(c *C) { rootSe := newSession(c, s.store, s.dbName) mustExec(c, rootSe, `CREATE USER 'test1'@'localhost';`) @@ -162,6 +180,26 @@ func (s *testPrivilegeSuite) TestCheckTablePrivilege(c *C) { c.Assert(pc2.RequestVerification(activeRoles, "test", "test", "", mysql.IndexPriv), IsTrue) } +func (s *testPrivilegeSuite) TestCheckViewPrivilege(c *C) { + rootSe := newSession(c, s.store, s.dbName) + mustExec(c, rootSe, `CREATE USER 'vuser'@'localhost';`) + mustExec(c, rootSe, `CREATE VIEW v AS SELECT * FROM test;`) + + se := newSession(c, s.store, s.dbName) + activeRoles := make([]*auth.RoleIdentity, 0) + c.Assert(se.Auth(&auth.UserIdentity{Username: "vuser", Hostname: "localhost"}, nil, nil), IsTrue) + pc := privilege.GetPrivilegeManager(se) + c.Assert(pc.RequestVerification(activeRoles, "test", "v", "", mysql.SelectPriv), IsFalse) + + mustExec(c, rootSe, `GRANT SELECT ON test.v TO 'vuser'@'localhost';`) + c.Assert(pc.RequestVerification(activeRoles, "test", "v", "", mysql.SelectPriv), IsTrue) + c.Assert(pc.RequestVerification(activeRoles, "test", "v", "", mysql.ShowViewPriv), IsFalse) + + mustExec(c, rootSe, `GRANT SHOW VIEW ON test.v TO 'vuser'@'localhost';`) + c.Assert(pc.RequestVerification(activeRoles, "test", "v", "", mysql.SelectPriv), IsTrue) + c.Assert(pc.RequestVerification(activeRoles, "test", "v", "", mysql.ShowViewPriv), IsTrue) +} + func (s *testPrivilegeSuite) TestCheckPrivilegeWithRoles(c *C) { rootSe := newSession(c, s.store, s.dbName) mustExec(c, rootSe, `CREATE USER 'test_role'@'localhost';`) @@ -474,7 +512,6 @@ func (s *testPrivilegeSuite) TestUseDB(c *C) { mustExec(c, se, `CREATE USER 'dev'@'localhost'`) mustExec(c, se, `GRANT 'app_developer' TO 'dev'@'localhost'`) mustExec(c, se, `SET DEFAULT ROLE 'app_developer' TO 'dev'@'localhost'`) - mustExec(c, se, `FLUSH PRIVILEGES`) c.Assert(se.Auth(&auth.UserIdentity{Username: "dev", Hostname: "localhost", AuthUsername: "dev", AuthHostname: "localhost"}, nil, nil), IsTrue) _, err = se.Execute(context.Background(), "use app_db") c.Assert(err, IsNil) @@ -556,7 +593,7 @@ func (s *testPrivilegeSuite) TestAnalyzeTable(c *C) { c.Assert(err.Error(), Equals, "[planner:1142]INSERT command denied to user 'anobody'@'%' for table 't1'") _, err = se.Execute(context.Background(), "select * from t1") - c.Assert(err.Error(), Equals, "[planner:1142]SELECT command denied to user 'localhost'@'anobody' for table 't1'") + c.Assert(err.Error(), Equals, "[planner:1142]SELECT command denied to user 'anobody'@'localhost' for table 't1'") // try again after SELECT privilege granted c.Assert(se.Auth(&auth.UserIdentity{Username: "asuper", Hostname: "localhost", AuthUsername: "asuper", AuthHostname: "%"}, nil, nil), IsTrue) @@ -582,6 +619,12 @@ func (s *testPrivilegeSuite) TestInformationSchema(c *C) { c.Assert(se.Auth(&auth.UserIdentity{Username: "u1", Hostname: "localhost"}, nil, nil), IsTrue) mustExec(c, se, `select * from information_schema.tables`) mustExec(c, se, `select * from information_schema.key_column_usage`) + _, err := se.Execute(context.Background(), "create table information_schema.t(a int)") + c.Assert(strings.Contains(err.Error(), "denied to user"), IsTrue) + _, err = se.Execute(context.Background(), "drop table information_schema.tables") + c.Assert(strings.Contains(err.Error(), "denied to user"), IsTrue) + _, err = se.Execute(context.Background(), "update information_schema.tables set table_name = 'tst' where table_name = 'mysql'") + c.Assert(strings.Contains(err.Error(), "privilege check fail"), IsTrue) } func (s *testPrivilegeSuite) TestAdminCommand(c *C) { @@ -639,7 +682,7 @@ func mustExec(c *C, se session.Session, sql string) { func newStore(c *C, dbPath string) (*domain.Domain, kv.Storage) { store, err := mockstore.NewMockTikvStore() session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() c.Assert(err, IsNil) dom, err := session.BootstrapSession(store) c.Assert(err, IsNil) diff --git a/server/conn.go b/server/conn.go index 24d16c0af9b50..0c589786cb8d5 100644 --- a/server/conn.go +++ b/server/conn.go @@ -45,12 +45,12 @@ import ( "runtime" "strconv" "strings" - "sync" "sync/atomic" "time" "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -65,6 +65,7 @@ import ( "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/sqlexec" "go.uber.org/zap" ) @@ -150,18 +151,11 @@ type clientConn struct { peerHost string // peer host peerPort string // peer port lastCode uint16 // last error code - - // mu is used for cancelling the execution of current transaction. - mu struct { - sync.RWMutex - cancelFunc context.CancelFunc - resultSets []ResultSet - } } func (cc *clientConn) String() string { collationStr := mysql.Collations[cc.collation] - return fmt.Sprintf("id:%d, addr:%s status:%d, collation:%s, user:%s", + return fmt.Sprintf("id:%d, addr:%s status:%b, collation:%s, user:%s", cc.connectionID, cc.bufReadConn.RemoteAddr(), cc.ctx.Status(), collationStr, cc.user, ) } @@ -270,6 +264,11 @@ func (cc *clientConn) readPacket() ([]byte, error) { } func (cc *clientConn) writePacket(data []byte) error { + failpoint.Inject("FakeClientConn", func() { + if cc.pkt == nil { + failpoint.Return(nil) + } + }) return cc.pkt.writePacket(data) } @@ -534,6 +533,20 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con return err } +func (cc *clientConn) SessionStatusToString() string { + status := cc.ctx.Status() + inTxn, autoCommit := 0, 0 + if status&mysql.ServerStatusInTrans > 0 { + inTxn = 1 + } + if status&mysql.ServerStatusAutocommit > 0 { + autoCommit = 1 + } + return fmt.Sprintf("inTxn:%d, autocommit:%d", + inTxn, autoCommit, + ) +} + func (cc *clientConn) openSessionAndDoAuth(authData []byte) error { var tlsStatePtr *tls.ConnectionState if cc.tlsConn != nil { @@ -637,7 +650,8 @@ func (cc *clientConn) Run(ctx context.Context) { } else { errStack := errors.ErrorStack(err) if !strings.Contains(errStack, "use of closed network connection") { - logutil.Logger(ctx).Error("read packet failed, close this connection", zap.Error(err)) + logutil.Logger(ctx).Warn("read packet failed, close this connection", + zap.Error(errors.SuspendStack(err))) } } } @@ -667,6 +681,8 @@ func (cc *clientConn) Run(ctx context.Context) { } logutil.Logger(ctx).Warn("dispatch error", zap.String("connInfo", cc.String()), + zap.String("command", mysql.Command2Str[data[0]]), + zap.String("status", cc.SessionStatusToString()), zap.String("sql", queryStrForLog(string(data[1:]))), zap.String("err", errStrForLog(err)), ) @@ -847,22 +863,23 @@ func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) { func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { span := opentracing.StartSpan("server.dispatch") - ctx1, cancelFunc := context.WithCancel(ctx) - cc.mu.Lock() - cc.mu.cancelFunc = cancelFunc - cc.mu.Unlock() - t := time.Now() cmd := data[0] data = data[1:] cc.lastCmd = string(hack.String(data)) token := cc.server.getToken() defer func() { - cc.ctx.SetProcessInfo("", t, mysql.ComSleep) + // if handleChangeUser failed, cc.ctx may be nil + if cc.ctx != nil { + cc.ctx.SetProcessInfo("", t, mysql.ComSleep, 0) + } + cc.server.releaseToken(token) span.Finish() }() + vars := cc.ctx.GetSessionVars() + atomic.StoreUint32(&vars.Killed, 0) if cmd < mysql.ComEnd { cc.ctx.SetCommandValue(cmd) } @@ -871,9 +888,9 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { switch cmd { case mysql.ComPing, mysql.ComStmtClose, mysql.ComStmtSendLongData, mysql.ComStmtReset, mysql.ComSetOption, mysql.ComChangeUser: - cc.ctx.SetProcessInfo("", t, cmd) + cc.ctx.SetProcessInfo("", t, cmd, 0) case mysql.ComInitDB: - cc.ctx.SetProcessInfo("use "+dataStr, t, cmd) + cc.ctx.SetProcessInfo("use "+dataStr, t, cmd, 0) } switch cmd { @@ -893,11 +910,11 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { data = data[:len(data)-1] dataStr = string(hack.String(data)) } - return cc.handleQuery(ctx1, dataStr) + return cc.handleQuery(ctx, dataStr) case mysql.ComPing: return cc.writeOK() case mysql.ComInitDB: - if err := cc.useDB(ctx1, dataStr); err != nil { + if err := cc.useDB(ctx, dataStr); err != nil { return err } return cc.writeOK() @@ -906,9 +923,9 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { case mysql.ComStmtPrepare: return cc.handleStmtPrepare(dataStr) case mysql.ComStmtExecute: - return cc.handleStmtExecute(ctx1, data) + return cc.handleStmtExecute(ctx, data) case mysql.ComStmtFetch: - return cc.handleStmtFetch(ctx1, data) + return cc.handleStmtFetch(ctx, data) case mysql.ComStmtClose: return cc.handleStmtClose(data) case mysql.ComStmtSendLongData: @@ -918,7 +935,7 @@ func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { case mysql.ComSetOption: return cc.handleSetOption(data) case mysql.ComChangeUser: - return cc.handleChangeUser(ctx1, data) + return cc.handleChangeUser(ctx, data) default: return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd) } @@ -936,11 +953,20 @@ func (cc *clientConn) useDB(ctx context.Context, db string) (err error) { } func (cc *clientConn) flush() error { + failpoint.Inject("FakeClientConn", func() { + if cc.pkt == nil { + failpoint.Return(nil) + } + }) return cc.pkt.flush() } func (cc *clientConn) writeOK() error { msg := cc.ctx.LastMessage() + return cc.writeOkWith(msg, cc.ctx.AffectedRows(), cc.ctx.LastInsertID(), cc.ctx.Status(), cc.ctx.WarningCount()) +} + +func (cc *clientConn) writeOkWith(msg string, affectedRows, lastInsertID uint64, status, warnCnt uint16) error { enclen := 0 if len(msg) > 0 { enclen = lengthEncodedIntSize(uint64(len(msg))) + len(msg) @@ -948,11 +974,11 @@ func (cc *clientConn) writeOK() error { data := cc.alloc.AllocWithLen(4, 32+enclen) data = append(data, mysql.OKHeader) - data = dumpLengthEncodedInt(data, cc.ctx.AffectedRows()) - data = dumpLengthEncodedInt(data, cc.ctx.LastInsertID()) + data = dumpLengthEncodedInt(data, affectedRows) + data = dumpLengthEncodedInt(data, lastInsertID) if cc.capability&mysql.ClientProtocol41 > 0 { - data = dumpUint16(data, cc.ctx.Status()) - data = dumpUint16(data, cc.ctx.WarningCount()) + data = dumpUint16(data, status) + data = dumpUint16(data, warnCnt) } if enclen > 0 { // although MySQL manual says the info message is string(https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html), @@ -1038,7 +1064,7 @@ func insertDataWithCommit(ctx context.Context, prevData, curData []byte, loadDat var err error var reachLimit bool for { - prevData, reachLimit, err = loadDataInfo.InsertData(prevData, curData) + prevData, reachLimit, err = loadDataInfo.InsertData(ctx, prevData, curData) if err != nil { return nil, err } @@ -1166,25 +1192,23 @@ func (cc *clientConn) handleLoadStats(ctx context.Context, loadStatsInfo *execut // There is a special query `load data` that does not return result, which is handled differently. // Query `load stats` does not return result either. func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) { - rs, err := cc.ctx.Execute(ctx, sql) + rss, err := cc.ctx.Execute(ctx, sql) if err != nil { metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc() return err } - cc.mu.Lock() - cc.mu.resultSets = rs status := atomic.LoadInt32(&cc.status) - if status == connStatusShutdown || status == connStatusWaitShutdown { - cc.mu.Unlock() - killConn(cc) - return errors.New("killed by another connection") - } - cc.mu.Unlock() - if rs != nil { - if len(rs) == 1 { - err = cc.writeResultset(ctx, rs[0], false, 0, 0) + if rss != nil && (status == connStatusShutdown || status == connStatusWaitShutdown) { + for _, rs := range rss { + terror.Call(rs.Close) + } + return executor.ErrQueryInterrupted + } + if rss != nil { + if len(rss) == 1 { + err = cc.writeResultset(ctx, rss[0], false, 0, 0) } else { - err = cc.writeMultiResultset(ctx, rs, false) + err = cc.writeMultiResultset(ctx, rss, false) } } else { loadDataInfo := cc.ctx.Value(executor.LoadDataVarKey) @@ -1270,6 +1294,7 @@ func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary b if err != nil { return err } + return cc.flush() } @@ -1294,7 +1319,7 @@ func (cc *clientConn) writeColumnInfo(columns []*ColumnInfo, serverStatus uint16 // serverStatus, a flag bit represents server information func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16) error { data := cc.alloc.AllocWithLen(4, 1024) - req := rs.NewRecordBatch() + req := rs.NewChunk() gotColumnInfo := false for { // Here server.tidbResultSet implements Next method. @@ -1342,7 +1367,7 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet fetchedRows := rs.GetFetchedRows() // if fetchedRows is not enough, getting data from recordSet. - req := rs.NewRecordBatch() + req := rs.NewChunk() for len(fetchedRows) < fetchSize { // Here server.tidbResultSet implements Next method. err := rs.Next(ctx, req) @@ -1357,7 +1382,7 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet for i := 0; i < rowCount; i++ { fetchedRows = append(fetchedRows, req.GetRow(i)) } - req.Chunk = chunk.Renew(req.Chunk, cc.ctx.GetSessionVars().MaxChunkSize) + req = chunk.Renew(req, cc.ctx.GetSessionVars().MaxChunkSize) } // tell the client COM_STMT_FETCH has finished by setting proper serverStatus, @@ -1391,16 +1416,34 @@ func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet return err } } + if cl, ok := rs.(fetchNotifier); ok { + cl.OnFetchReturned() + } return cc.writeEOF(serverStatus) } func (cc *clientConn) writeMultiResultset(ctx context.Context, rss []ResultSet, binary bool) error { - for _, rs := range rss { - if err := cc.writeResultset(ctx, rs, binary, mysql.ServerMoreResultsExists, 0); err != nil { + for i, rs := range rss { + lastRs := i == len(rss)-1 + if r, ok := rs.(*tidbResultSet).recordSet.(sqlexec.MultiQueryNoDelayResult); ok { + status := r.Status() + if !lastRs { + status |= mysql.ServerMoreResultsExists + } + if err := cc.writeOkWith(r.LastMessage(), r.AffectedRows(), r.LastInsertID(), status, r.WarnCount()); err != nil { + return err + } + continue + } + status := uint16(0) + if !lastRs { + status |= mysql.ServerMoreResultsExists + } + if err := cc.writeResultset(ctx, rs, binary, status, 0); err != nil { return err } } - return cc.writeOK() + return nil } func (cc *clientConn) setConn(conn net.Conn) { @@ -1449,11 +1492,15 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error { return err } + if plugin.IsEnable(plugin.Audit) { + cc.ctx.GetSessionVars().ConnectionInfo = cc.connectInfo() + } + err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { - connInfo := cc.connectInfo() - err = authPlugin.OnConnectionEvent(context.Background(), &auth.UserIdentity{Hostname: connInfo.Host}, plugin.ChangeUser, connInfo) + connInfo := cc.ctx.GetSessionVars().ConnectionInfo + err = authPlugin.OnConnectionEvent(context.Background(), plugin.ChangeUser, connInfo) if err != nil { return err } diff --git a/server/conn_stmt.go b/server/conn_stmt.go index d55ba0eea3a0f..b4b1ebcfd0ef6 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -176,12 +176,12 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e err = parseStmtArgs(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues) stmt.Reset() if err != nil { - return errors.Annotatef(err, "%s", cc.preparedStmt2String(stmtID)) + return errors.Annotate(err, cc.preparedStmt2String(stmtID)) } } rs, err := stmt.Execute(ctx, args...) if err != nil { - return errors.Annotatef(err, "%s", cc.preparedStmt2String(stmtID)) + return errors.Annotate(err, cc.preparedStmt2String(stmtID)) } if rs == nil { return cc.writeOK() @@ -196,10 +196,17 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e if err != nil { return err } + if cl, ok := rs.(fetchNotifier); ok { + cl.OnFetchReturned() + } // explicitly flush columnInfo to client. return cc.flush() } - return cc.writeResultset(ctx, rs, true, 0, 0) + err = cc.writeResultset(ctx, rs, true, 0, 0) + if err != nil { + return errors.Annotate(err, cc.preparedStmt2String(stmtID)) + } + return nil } // maxFetchSize constants @@ -208,6 +215,7 @@ const ( ) func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) { + cc.ctx.GetSessionVars().StartTime = time.Now() stmtID, fetchSize, err := parseStmtFetchCmd(data) if err != nil { @@ -216,21 +224,25 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err stmt := cc.ctx.GetStatement(int(stmtID)) if stmt == nil { - return mysql.NewErr(mysql.ErrUnknownStmtHandler, - strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch") + return errors.Annotate(mysql.NewErr(mysql.ErrUnknownStmtHandler, + strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch"), cc.preparedStmt2String(stmtID)) } sql := "" if prepared, ok := cc.ctx.GetStatement(int(stmtID)).(*TiDBStatement); ok { sql = prepared.sql } - cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute) + cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute, 0) rs := stmt.GetResultSet() if rs == nil { - return mysql.NewErr(mysql.ErrUnknownStmtHandler, - strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs") + return errors.Annotate(mysql.NewErr(mysql.ErrUnknownStmtHandler, + strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch_rs"), cc.preparedStmt2String(stmtID)) } - return cc.writeResultset(ctx, rs, true, mysql.ServerStatusCursorExists, int(fetchSize)) + err = cc.writeResultset(ctx, rs, true, mysql.ServerStatusCursorExists, int(fetchSize)) + if err != nil { + return errors.Annotate(err, cc.preparedStmt2String(stmtID)) + } + return nil } func parseStmtFetchCmd(data []byte) (uint32, uint32, error) { @@ -540,6 +552,7 @@ func (cc *clientConn) handleStmtReset(data []byte) (err error) { strconv.Itoa(stmtID), "stmt_reset") } stmt.Reset() + stmt.StoreResultSet(nil) return cc.writeOK() } @@ -569,7 +582,7 @@ func (cc *clientConn) handleSetOption(data []byte) (err error) { func (cc *clientConn) preparedStmt2String(stmtID uint32) string { sv := cc.ctx.GetSessionVars() if prepared, ok := sv.PreparedStmts[stmtID]; ok { - return prepared.Stmt.Text() + sv.GetExecuteArgumentsInfo() + return prepared.Stmt.Text() + sv.PreparedParams.String() } - return fmt.Sprintf("prepared statement not found, ID: %d", stmtID) + return "prepared statement not found, ID: " + strconv.FormatUint(uint64(stmtID), 10) } diff --git a/server/conn_test.go b/server/conn_test.go index d139fbd3aff32..ba61a43acd49f 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -18,13 +18,21 @@ import ( "bytes" "context" "encoding/binary" + "fmt" + "time" . "github.com/pingcap/check" + "github.com/pingcap/failpoint" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/util/arena" + "github.com/pingcap/tidb/util/chunk" ) type ConnTestSuite struct { @@ -207,3 +215,111 @@ func mapBelong(m1, m2 map[string]string) bool { } return true } + +func (ts ConnTestSuite) TestConnExecutionTimeout(c *C) { + //There is no underlying netCon, use failpoint to avoid panic + c.Assert(failpoint.Enable("github.com/pingcap/tidb/server/FakeClientConn", "return(1)"), IsNil) + + c.Parallel() + var err error + ts.store, err = mockstore.NewMockTikvStore() + c.Assert(err, IsNil) + ts.dom, err = session.BootstrapSession(ts.store) + c.Assert(err, IsNil) + se, err := session.CreateSession4Test(ts.store) + c.Assert(err, IsNil) + + connID := 1 + se.SetConnectionID(uint64(connID)) + tc := &TiDBContext{ + session: se, + stmts: make(map[int]*TiDBStatement), + } + cc := &clientConn{ + connectionID: uint32(connID), + server: &Server{ + capability: defaultCapability, + }, + ctx: tc, + alloc: arena.NewAllocator(32 * 1024), + } + srv := &Server{ + clients: map[uint32]*clientConn{ + uint32(connID): cc, + }, + } + handle := ts.dom.ExpensiveQueryHandle().SetSessionManager(srv) + go handle.Run() + defer handle.Close() + + _, err = se.Execute(context.Background(), "use test;") + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "CREATE TABLE testTable2 (id bigint PRIMARY KEY, age int)") + c.Assert(err, IsNil) + for i := 0; i < 10; i++ { + str := fmt.Sprintf("insert into testTable2 values(%d, %d)", i, i%80) + _, err = se.Execute(context.Background(), str) + c.Assert(err, IsNil) + } + + _, err = se.Execute(context.Background(), "select SLEEP(1);") + c.Assert(err, IsNil) + + _, err = se.Execute(context.Background(), "set @@max_execution_time = 500;") + c.Assert(err, IsNil) + + now := time.Now() + err = cc.handleQuery(context.Background(), "select * FROM testTable2 WHERE SLEEP(3);") + c.Assert(err, IsNil) + c.Assert(time.Since(now) < 3*time.Second, IsTrue) + + _, err = se.Execute(context.Background(), "set @@max_execution_time = 0;") + c.Assert(err, IsNil) + + now = time.Now() + err = cc.handleQuery(context.Background(), "select * FROM testTable2 WHERE SLEEP(1);") + c.Assert(err, IsNil) + c.Assert(time.Since(now) > 500*time.Millisecond, IsTrue) + + now = time.Now() + err = cc.handleQuery(context.Background(), "select /*+ MAX_EXECUTION_TIME(100)*/ * FROM testTable2 WHERE SLEEP(3);") + c.Assert(err, IsNil) + c.Assert(time.Since(now) < 3*time.Second, IsTrue) + + c.Assert(failpoint.Disable("github.com/pingcap/tidb/server/FakeClientConn"), IsNil) +} + +type mockTiDBCtx struct { + TiDBContext + rs []ResultSet + err error +} + +func (c *mockTiDBCtx) Execute(ctx context.Context, sql string) ([]ResultSet, error) { + return c.rs, c.err +} + +func (c *mockTiDBCtx) GetSessionVars() *variable.SessionVars { + return &variable.SessionVars{} +} + +type mockRecordSet struct{} + +func (m mockRecordSet) Fields() []*ast.ResultField { return nil } +func (m mockRecordSet) Next(ctx context.Context, req *chunk.Chunk) error { return nil } +func (m mockRecordSet) NewChunk() *chunk.Chunk { return nil } +func (m mockRecordSet) Close() error { return nil } + +func (ts *ConnTestSuite) TestShutDown(c *C) { + cc := &clientConn{} + + rs := &tidbResultSet{recordSet: mockRecordSet{}} + // mock delay response + cc.ctx = &mockTiDBCtx{rs: []ResultSet{rs}, err: nil} + // set killed flag + cc.status = connStatusShutdown + // assert ErrQueryInterrupted + err := cc.handleQuery(context.Background(), "dummy") + c.Assert(err, Equals, executor.ErrQueryInterrupted) + c.Assert(rs.closed, Equals, int32(1)) +} diff --git a/server/driver.go b/server/driver.go index d9212855201f6..362ce1d643a95 100644 --- a/server/driver.go +++ b/server/driver.go @@ -51,7 +51,7 @@ type QueryCtx interface { // SetValue saves a value associated with this context for key. SetValue(key fmt.Stringer, value interface{}) - SetProcessInfo(sql string, t time.Time, command byte) + SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) // CommitTxn commits the transaction operations. CommitTxn(ctx context.Context) error @@ -87,7 +87,7 @@ type QueryCtx interface { Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool // ShowProcess shows the information about the session. - ShowProcess() util.ProcessInfo + ShowProcess() *util.ProcessInfo // GetSessionVars return SessionVars. GetSessionVars() *variable.SessionVars @@ -136,9 +136,16 @@ type PreparedStatement interface { // ResultSet is the result set of an query. type ResultSet interface { Columns() []*ColumnInfo - NewRecordBatch() *chunk.RecordBatch - Next(context.Context, *chunk.RecordBatch) error + NewChunk() *chunk.Chunk + Next(context.Context, *chunk.Chunk) error StoreFetchedRows(rows []chunk.Row) GetFetchedRows() []chunk.Row Close() error } + +// fetchNotifier represents notifier will be called in COM_FETCH. +type fetchNotifier interface { + // OnFetchReturned be called when COM_FETCH returns. + // it will be used in server-side cursor. + OnFetchReturned() +} diff --git a/server/driver_tidb.go b/server/driver_tidb.go index abf08eae00939..e1689eafef544 100644 --- a/server/driver_tidb.go +++ b/server/driver_tidb.go @@ -212,8 +212,8 @@ func (tc *TiDBContext) CommitTxn(ctx context.Context) error { } // SetProcessInfo implements QueryCtx SetProcessInfo method. -func (tc *TiDBContext) SetProcessInfo(sql string, t time.Time, command byte) { - tc.session.SetProcessInfo(sql, t, command) +func (tc *TiDBContext) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) { + tc.session.SetProcessInfo(sql, t, command, maxExecutionTime) } // RollbackTxn implements QueryCtx RollbackTxn method. @@ -336,7 +336,7 @@ func (tc *TiDBContext) Prepare(sql string) (statement PreparedStatement, columns } // ShowProcess implements QueryCtx ShowProcess method. -func (tc *TiDBContext) ShowProcess() util.ProcessInfo { +func (tc *TiDBContext) ShowProcess() *util.ProcessInfo { return tc.session.ShowProcess() } @@ -357,11 +357,11 @@ type tidbResultSet struct { closed int32 } -func (trs *tidbResultSet) NewRecordBatch() *chunk.RecordBatch { - return trs.recordSet.NewRecordBatch() +func (trs *tidbResultSet) NewChunk() *chunk.Chunk { + return trs.recordSet.NewChunk() } -func (trs *tidbResultSet) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (trs *tidbResultSet) Next(ctx context.Context, req *chunk.Chunk) error { return trs.recordSet.Next(ctx, req) } @@ -383,6 +383,13 @@ func (trs *tidbResultSet) Close() error { return trs.recordSet.Close() } +// OnFetchReturned implements fetchNotifier#OnFetchReturned +func (trs *tidbResultSet) OnFetchReturned() { + if cl, ok := trs.recordSet.(fetchNotifier); ok { + cl.OnFetchReturned() + } +} + func (trs *tidbResultSet) Columns() []*ColumnInfo { if trs.columns == nil { fields := trs.recordSet.Fields() diff --git a/server/http_handler.go b/server/http_handler.go index 3066317006480..bc362104a5fb6 100644 --- a/server/http_handler.go +++ b/server/http_handler.go @@ -23,7 +23,6 @@ import ( "math" "net/http" "net/url" - "sort" "strconv" "strings" "sync/atomic" @@ -133,17 +132,33 @@ func (s *Server) newTikvHandlerTool() *tikvHandlerTool { } type mvccKV struct { - Key string `json:"key"` - Value *kvrpcpb.MvccGetByKeyResponse `json:"value"` + Key string `json:"key"` + RegionID uint64 `json:"region_id"` + Value *kvrpcpb.MvccGetByKeyResponse `json:"value"` +} + +func (t *tikvHandlerTool) getRegionIDByKey(encodedKey []byte) (uint64, error) { + keyLocation, err := t.RegionCache.LocateKey(tikv.NewBackoffer(context.Background(), 500), encodedKey) + if err != nil { + return 0, err + } + return keyLocation.Region.GetID(), nil } func (t *tikvHandlerTool) getMvccByHandle(tableID, handle int64) (*mvccKV, error) { encodedKey := tablecodec.EncodeRowKeyWithHandle(tableID, handle) data, err := t.GetMvccByEncodedKey(encodedKey) - return &mvccKV{Key: strings.ToUpper(hex.EncodeToString(encodedKey)), Value: data}, err + if err != nil { + return nil, err + } + regionID, err := t.getRegionIDByKey(encodedKey) + if err != nil { + return nil, err + } + return &mvccKV{Key: strings.ToUpper(hex.EncodeToString(encodedKey)), Value: data, RegionID: regionID}, err } -func (t *tikvHandlerTool) getMvccByStartTs(startTS uint64, startKey, endKey []byte) (*kvrpcpb.MvccGetByStartTsResponse, error) { +func (t *tikvHandlerTool) getMvccByStartTs(startTS uint64, startKey, endKey []byte) (*mvccKV, error) { bo := tikv.NewBackoffer(context.Background(), 5000) for { curRegion, err := t.RegionCache.LocateKey(bo, startKey) @@ -198,7 +213,8 @@ func (t *tikvHandlerTool) getMvccByStartTs(startTS uint64, startKey, endKey []by key := data.GetKey() if len(key) > 0 { - return data, nil + resp := &kvrpcpb.MvccGetByKeyResponse{Info: data.Info, RegionError: data.RegionError, Error: data.Error} + return &mvccKV{Key: strings.ToUpper(hex.EncodeToString(key)), Value: resp, RegionID: curRegion.Region.GetID()}, nil } if len(endKey) > 0 && curRegion.Contains(endKey) { @@ -229,7 +245,14 @@ func (t *tikvHandlerTool) getMvccByIdxValue(idx table.Index, values url.Values, return nil, errors.Trace(err) } data, err := t.GetMvccByEncodedKey(encodedKey) - return &mvccKV{strings.ToUpper(hex.EncodeToString(encodedKey)), data}, err + if err != nil { + return nil, err + } + regionID, err := t.getRegionIDByKey(encodedKey) + if err != nil { + return nil, err + } + return &mvccKV{strings.ToUpper(hex.EncodeToString(encodedKey)), regionID, data}, err } // formValue2DatumRow converts URL query string to a Datum Row. @@ -288,12 +311,20 @@ func (t *tikvHandlerTool) schema() (infoschema.InfoSchema, error) { return domain.GetDomain(session.(sessionctx.Context)).InfoSchema(), nil } -func (t *tikvHandlerTool) handleMvccGetByHex(params map[string]string) (interface{}, error) { +func (t *tikvHandlerTool) handleMvccGetByHex(params map[string]string) (*mvccKV, error) { encodedKey, err := hex.DecodeString(params[pHexKey]) if err != nil { return nil, errors.Trace(err) } - return t.GetMvccByEncodedKey(encodedKey) + data, err := t.GetMvccByEncodedKey(encodedKey) + if err != nil { + return nil, errors.Trace(err) + } + regionID, err := t.getRegionIDByKey(encodedKey) + if err != nil { + return nil, err + } + return &mvccKV{Key: strings.ToUpper(params[pHexKey]), Value: data, RegionID: regionID}, nil } // settingsHandler is the handler for list tidb server settings. @@ -483,14 +514,32 @@ type RegionDetail struct { func (rt *RegionDetail) addTableInRange(dbName string, curTable *model.TableInfo, r *helper.RegionFrameRange) { tName := curTable.Name.String() tID := curTable.ID - + pi := curTable.GetPartitionInfo() for _, index := range curTable.Indices { - if f := r.GetIndexFrame(tID, index.ID, dbName, tName, index.Name.String()); f != nil { - rt.Frames = append(rt.Frames, f) + if pi != nil { + for _, def := range pi.Definitions { + if f := r.GetIndexFrame(def.ID, index.ID, dbName, fmt.Sprintf("%s(%s)", tName, def.Name.O), index.Name.String()); f != nil { + rt.Frames = append(rt.Frames, f) + } + } + } else { + if f := r.GetIndexFrame(tID, index.ID, dbName, tName, index.Name.String()); f != nil { + rt.Frames = append(rt.Frames, f) + } } + } - if f := r.GetRecordFrame(tID, dbName, tName); f != nil { - rt.Frames = append(rt.Frames, f) + + if pi != nil { + for _, def := range pi.Definitions { + if f := r.GetRecordFrame(def.ID, dbName, fmt.Sprintf("%s(%s)", tName, def.Name.O)); f != nil { + rt.Frames = append(rt.Frames, f) + } + } + } else { + if f := r.GetRecordFrame(tID, dbName, tName); f != nil { + rt.Frames = append(rt.Frames, f) + } } } @@ -827,8 +876,8 @@ func (h tableHandler) addScatterSchedule(startKey, endKey []byte, name string) e } input := map[string]string{ "name": "scatter-range", - "start_key": string(startKey), - "end_key": string(endKey), + "start_key": url.QueryEscape(string(startKey)), + "end_key": url.QueryEscape(string(endKey)), "range_name": name, } v, err := json.Marshal(input) @@ -917,18 +966,43 @@ func (h tableHandler) handleStopScatterTableRequest(schema infoschema.InfoSchema } func (h tableHandler) handleRegionRequest(schema infoschema.InfoSchema, tbl table.Table, w http.ResponseWriter, req *http.Request) { - tableID := tbl.Meta().ID - // for record - startKey, endKey := tablecodec.GetTableHandleKeyRange(tableID) - recordRegionIDs, err := h.RegionCache.ListRegionIDsInKeyRange(tikv.NewBackoffer(context.Background(), 500), startKey, endKey) + pi := tbl.Meta().GetPartitionInfo() + if pi != nil { + // Partitioned table. + var data []*TableRegions + for _, def := range pi.Definitions { + tableRegions, err := h.getRegionsByID(tbl, def.ID, def.Name.O) + if err != nil { + writeError(w, err) + return + } + + data = append(data, tableRegions) + } + writeData(w, data) + return + } + + meta := tbl.Meta() + tableRegions, err := h.getRegionsByID(tbl, meta.ID, meta.Name.O) if err != nil { writeError(w, err) return } + + writeData(w, tableRegions) +} + +func (h tableHandler) getRegionsByID(tbl table.Table, id int64, name string) (*TableRegions, error) { + // for record + startKey, endKey := tablecodec.GetTableHandleKeyRange(id) + recordRegionIDs, err := h.RegionCache.ListRegionIDsInKeyRange(tikv.NewBackoffer(context.Background(), 500), startKey, endKey) + if err != nil { + return nil, err + } recordRegions, err := h.getRegionsMeta(recordRegionIDs) if err != nil { - writeError(w, err) - return + return nil, err } // for indices @@ -937,27 +1011,23 @@ func (h tableHandler) handleRegionRequest(schema infoschema.InfoSchema, tbl tabl indexID := index.Meta().ID indices[i].Name = index.Meta().Name.String() indices[i].ID = indexID - startKey, endKey := tablecodec.GetTableIndexKeyRange(tableID, indexID) + startKey, endKey := tablecodec.GetTableIndexKeyRange(id, indexID) rIDs, err := h.RegionCache.ListRegionIDsInKeyRange(tikv.NewBackoffer(context.Background(), 500), startKey, endKey) if err != nil { - writeError(w, err) - return + return nil, err } indices[i].Regions, err = h.getRegionsMeta(rIDs) if err != nil { - writeError(w, err) - return + return nil, err } } - tableRegions := &TableRegions{ - TableName: tbl.Meta().Name.O, - TableID: tableID, + return &TableRegions{ + TableName: name, + TableID: id, Indices: indices, RecordRegions: recordRegions, - } - - writeData(w, tableRegions) + }, nil } // pdRegionStats is the json response from PD. @@ -1068,17 +1138,9 @@ func (h regionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { writeError(w, err) return } - asSortedEntry := func(metric map[helper.TblIndex]helper.RegionMetric) hotRegions { - hs := make(hotRegions, 0, len(metric)) - for key, value := range metric { - hs = append(hs, hotRegion{key, value}) - } - sort.Sort(hs) - return hs - } writeData(w, map[string]interface{}{ - "write": asSortedEntry(hotWrite), - "read": asSortedEntry(hotRead), + "write": hotWrite, + "read": hotRead, }) return } @@ -1167,7 +1229,7 @@ func NewFrameItemFromRegionKey(key []byte) (frame *FrameItem, err error) { } // bigger than tablePrefix, means is bigger than all tables. frame.TableID = math.MaxInt64 - frame.TableID = math.MaxInt64 + frame.IndexID = math.MaxInt64 frame.IsRecord = true return } @@ -1472,8 +1534,8 @@ func (h dbTableHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { dbTblInfo.TableInfo = tbl.Meta() dbInfo, ok := schema.SchemaByTable(dbTblInfo.TableInfo) if !ok { - log.Warnf("can not find the database of table id: %v, table name: %v", dbTblInfo.TableInfo.ID, dbTblInfo.TableInfo.Name) - writeData(w, dbTblInfo) + logutil.Logger(context.Background()).Error("can not find the database of the table", zap.Int64("table id", dbTblInfo.TableInfo.ID), zap.String("table name", dbTblInfo.TableInfo.Name.L)) + writeError(w, infoschema.ErrTableNotExists.GenWithStack("Table which ID = %s does not exist.", tableID)) return } dbTblInfo.DBInfo = dbInfo diff --git a/server/http_handler_test.go b/server/http_handler_test.go index 4fb1b8db5cabf..060744990baa5 100644 --- a/server/http_handler_test.go +++ b/server/http_handler_test.go @@ -204,13 +204,28 @@ func regionContainsTable(c *C, regionID uint64, tableID int64) bool { return false } -func (ts *HTTPHandlerTestSuite) TestListTableRegionsWithError(c *C) { +func (ts *HTTPHandlerTestSuite) TestListTableRegions(c *C) { ts.startServer(c) defer ts.stopServer(c) + ts.prepareData(c) + // Test list table regions with error resp, err := http.Get("http://127.0.0.1:10090/tables/fdsfds/aaa/regions") c.Assert(err, IsNil) defer resp.Body.Close() c.Assert(resp.StatusCode, Equals, http.StatusBadRequest) + + resp, err = http.Get("http://127.0.0.1:10090/tables/tidb/pt/regions") + c.Assert(err, IsNil) + defer resp.Body.Close() + + var data []*TableRegions + dec := json.NewDecoder(resp.Body) + err = dec.Decode(&data) + c.Assert(err, IsNil) + + region := data[1] + resp, err = http.Get(fmt.Sprintf("http://127.0.0.1:10090/regions/%d", region.TableID)) + c.Assert(err, IsNil) } func (ts *HTTPHandlerTestSuite) TestGetRegionByIDWithError(c *C) { @@ -305,6 +320,11 @@ func (ts *HTTPHandlerTestSuite) prepareData(c *C) { c.Assert(err, IsNil) dbt.mustExec("alter table tidb.test add index idx1 (a, b);") dbt.mustExec("alter table tidb.test add unique index idx2 (a, b);") + + dbt.mustExec(`create table tidb.pt (a int) partition by range (a) +(partition p0 values less than (256), + partition p1 values less than (512), + partition p2 values less than (1024))`) } func decodeKeyMvcc(closer io.ReadCloser, c *C, valid bool) { @@ -566,7 +586,7 @@ func (ts *HTTPHandlerTestSuite) TestGetSchema(c *C) { decoder = json.NewDecoder(resp.Body) err = decoder.Decode(<) c.Assert(err, IsNil) - c.Assert(lt[0].Name.L, Equals, "test") + c.Assert(len(lt), Greater, 0) _, err = http.Get(fmt.Sprintf("http://127.0.0.1:10090/schema/abc")) c.Assert(err, IsNil) diff --git a/server/http_status.go b/server/http_status.go index c7ca853b252de..2dff937958b5e 100644 --- a/server/http_status.go +++ b/server/http_status.go @@ -49,6 +49,22 @@ func (s *Server) startStatusHTTP() { go s.startHTTPServer() } +func serveError(w http.ResponseWriter, status int, txt string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") + w.Header().Del("Content-Disposition") + w.WriteHeader(status) + _, err := fmt.Fprintln(w, txt) + terror.Log(err) +} + +func sleepWithCtx(ctx context.Context, d time.Duration) { + select { + case <-time.After(d): + case <-ctx.Done(): + } +} + func (s *Server) startHTTPServer() { router := mux.NewRouter() @@ -123,26 +139,6 @@ func (s *Server) startHTTPServer() { serverMux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) serverMux.HandleFunc("/debug/pprof/trace", pprof.Trace) - serveError := func(w http.ResponseWriter, status int, txt string) { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Go-Pprof", "1") - w.Header().Del("Content-Disposition") - w.WriteHeader(status) - _, err := fmt.Fprintln(w, txt) - terror.Log(err) - } - - sleep := func(w http.ResponseWriter, d time.Duration) { - var clientGone <-chan bool - if cn, ok := w.(http.CloseNotifier); ok { - clientGone = cn.CloseNotify() - } - select { - case <-time.After(d): - case <-clientGone: - } - } - serverMux.HandleFunc("/debug/zip", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="tidb_debug"`+time.Now().Format("20060102150405")+".zip")) @@ -191,7 +187,7 @@ func (s *Server) startHTTPServer() { if sec <= 0 || err != nil { sec = 10 } - sleep(w, time.Duration(sec)*time.Second) + sleepWithCtx(r.Context(), time.Duration(sec)*time.Second) rpprof.StopCPUProfile() // dump config @@ -220,11 +216,13 @@ func (s *Server) startHTTPServer() { err = zw.Close() terror.Log(err) }) + fetcher := sqlInfoFetcher{store: tikvHandlerTool.Store} + serverMux.HandleFunc("/debug/sub-optimal-plan", fetcher.zipInfoForSQL) var ( - err error httpRouterPage bytes.Buffer pathTemplate string + err error ) httpRouterPage.WriteString("TiDB Status and Metrics Report

TiDB Status and Metrics Report

") err = router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { diff --git a/server/server.go b/server/server.go index 58705781fef14..2a3bdc8019aa4 100644 --- a/server/server.go +++ b/server/server.go @@ -49,7 +49,6 @@ import ( "github.com/blacktear23/go-proxyprotocol" "github.com/pingcap/errors" - "github.com/pingcap/parser/auth" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/config" @@ -84,12 +83,13 @@ func init() { } var ( - errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type") - errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length") - errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence") - errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type") - errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version") - errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied]) + errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type") + errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length") + errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence") + errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type") + errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version") + errAccessDenied = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied]) + errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded]) ) // DefaultCapability is the capability of the server when it is created using the default configuration. @@ -107,7 +107,7 @@ type Server struct { driver IDriver listener net.Listener socket net.Listener - rwlock *sync.RWMutex + rwlock sync.RWMutex concurrentLimiter *TokenLimiter clients map[uint32]*clientConn capability uint32 @@ -199,7 +199,6 @@ func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { cfg: cfg, driver: driver, concurrentLimiter: NewTokenLimiter(cfg.TokenLimit), - rwlock: &sync.RWMutex{}, clients: make(map[uint32]*clientConn), stopListenerCh: make(chan struct{}, 1), } @@ -342,7 +341,7 @@ func (s *Server) Run() error { terror.Log(clientConn.Close()) return errors.Trace(err) } - err = authPlugin.OnConnectionEvent(context.Background(), &auth.UserIdentity{Hostname: host}, plugin.PreAuth, nil) + err = authPlugin.OnConnectionEvent(context.Background(), plugin.PreAuth, &variable.ConnectionInfo{Host: host}) if err != nil { logutil.Logger(context.Background()).Info("do connection event failed", zap.Error(err)) terror.Log(clientConn.Close()) @@ -422,11 +421,14 @@ func (s *Server) onConn(conn *clientConn) { s.rwlock.Unlock() metrics.ConnGauge.Set(float64(connections)) + if plugin.IsEnable(plugin.Audit) { + conn.ctx.GetSessionVars().ConnectionInfo = conn.connectInfo() + } err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { - connInfo := conn.connectInfo() - return authPlugin.OnConnectionEvent(context.Background(), conn.ctx.GetSessionVars().User, plugin.Connected, connInfo) + sessionVars := conn.ctx.GetSessionVars() + return authPlugin.OnConnectionEvent(context.Background(), plugin.Connected, sessionVars.ConnectionInfo) } return nil }) @@ -440,9 +442,9 @@ func (s *Server) onConn(conn *clientConn) { err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { - connInfo := conn.connectInfo() - connInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond) - err := authPlugin.OnConnectionEvent(context.Background(), conn.ctx.GetSessionVars().User, plugin.Disconnect, connInfo) + sessionVars := conn.ctx.GetSessionVars() + sessionVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond) + err := authPlugin.OnConnectionEvent(context.Background(), plugin.Disconnect, sessionVars.ConnectionInfo) if err != nil { logutil.Logger(context.Background()).Warn("do connection event failed", zap.String("plugin", authPlugin.Name), zap.Error(err)) } @@ -483,27 +485,28 @@ func (cc *clientConn) connectInfo() *variable.ConnectionInfo { } // ShowProcessList implements the SessionManager interface. -func (s *Server) ShowProcessList() map[uint64]util.ProcessInfo { +func (s *Server) ShowProcessList() map[uint64]*util.ProcessInfo { s.rwlock.RLock() - rs := make(map[uint64]util.ProcessInfo, len(s.clients)) + rs := make(map[uint64]*util.ProcessInfo, len(s.clients)) for _, client := range s.clients { if atomic.LoadInt32(&client.status) == connStatusWaitShutdown { continue } - pi := client.ctx.ShowProcess() - rs[pi.ID] = pi + if pi := client.ctx.ShowProcess(); pi != nil { + rs[pi.ID] = pi + } } s.rwlock.RUnlock() return rs } // GetProcessInfo implements the SessionManager interface. -func (s *Server) GetProcessInfo(id uint64) (util.ProcessInfo, bool) { +func (s *Server) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) { s.rwlock.RLock() conn, ok := s.clients[uint32(id)] s.rwlock.RUnlock() if !ok || atomic.LoadInt32(&conn.status) == connStatusWaitShutdown { - return util.ProcessInfo{}, false + return &util.ProcessInfo{}, false } return conn.ctx.ShowProcess(), ok } @@ -529,19 +532,8 @@ func (s *Server) Kill(connectionID uint64, query bool) { } func killConn(conn *clientConn) { - conn.mu.RLock() - resultSets := conn.mu.resultSets - cancelFunc := conn.mu.cancelFunc - conn.mu.RUnlock() - for _, resultSet := range resultSets { - // resultSet.Close() is reentrant so it's safe to kill a same connID multiple times - if err := resultSet.Close(); err != nil { - logutil.Logger(context.Background()).Error("close result set error", zap.Uint32("connID", conn.connectionID), zap.Error(err)) - } - } - if cancelFunc != nil { - cancelFunc() - } + sessVars := conn.ctx.GetSessionVars() + atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1) } // KillAllConnections kills all connections when server is not gracefully shutdown. @@ -628,14 +620,16 @@ const ( codeInvalidSequence = 3 codeInvalidType = 4 - codeNotAllowedCommand = 1148 - codeAccessDenied = mysql.ErrAccessDenied + codeNotAllowedCommand = 1148 + codeAccessDenied = mysql.ErrAccessDenied + codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded ) func init() { serverMySQLErrCodes := map[terror.ErrCode]uint16{ - codeNotAllowedCommand: mysql.ErrNotAllowedCommand, - codeAccessDenied: mysql.ErrAccessDenied, + codeNotAllowedCommand: mysql.ErrNotAllowedCommand, + codeAccessDenied: mysql.ErrAccessDenied, + codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded, } terror.ErrClassToMySQLCodes[terror.ClassServer] = serverMySQLErrCodes } diff --git a/server/server_test.go b/server/server_test.go index f23d2b0c793ed..c7ef21ae1a3b9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -427,7 +427,7 @@ func runTestLoadData(c *C, server *Server) { dbt.Assert(err, IsNil) lastID, err = rs.LastInsertId() dbt.Assert(err, IsNil) - dbt.Assert(lastID, Equals, int64(7)) + dbt.Assert(lastID, Equals, int64(6)) affectedRows, err = rs.RowsAffected() dbt.Assert(err, IsNil) dbt.Assert(affectedRows, Equals, int64(4)) @@ -466,7 +466,7 @@ func runTestLoadData(c *C, server *Server) { dbt.Assert(err, IsNil) lastID, err = rs.LastInsertId() dbt.Assert(err, IsNil) - dbt.Assert(lastID, Equals, int64(11)) + dbt.Assert(lastID, Equals, int64(10)) affectedRows, err = rs.RowsAffected() dbt.Assert(err, IsNil) dbt.Assert(affectedRows, Equals, int64(799)) diff --git a/server/sql_info_fetcher.go b/server/sql_info_fetcher.go new file mode 100644 index 0000000000000..8f07a340c13d6 --- /dev/null +++ b/server/sql_info_fetcher.go @@ -0,0 +1,327 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "archive/zip" + "context" + "encoding/json" + "fmt" + "net/http" + "runtime/pprof" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/parser" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/model" + "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/statistics/handle" + "github.com/pingcap/tidb/store/tikv" + "github.com/pingcap/tidb/util/sqlexec" + "github.com/pingcap/tidb/util/testkit" +) + +type sqlInfoFetcher struct { + store tikv.Storage + do *domain.Domain + s session.Session +} + +type tableNamePair struct { + DBName string + TableName string +} + +type tableNameExtractor struct { + curDB string + names map[tableNamePair]struct{} +} + +func (tne *tableNameExtractor) Enter(in ast.Node) (ast.Node, bool) { + if _, ok := in.(*ast.TableName); ok { + return in, true + } + return in, false +} + +func (tne *tableNameExtractor) Leave(in ast.Node) (ast.Node, bool) { + if t, ok := in.(*ast.TableName); ok { + tp := tableNamePair{DBName: t.Schema.L, TableName: t.Name.L} + if tp.DBName == "" { + tp.DBName = tne.curDB + } + if _, ok := tne.names[tp]; !ok { + tne.names[tp] = struct{}{} + } + } + return in, true +} + +func (sh *sqlInfoFetcher) zipInfoForSQL(w http.ResponseWriter, r *http.Request) { + var err error + sh.s, err = session.CreateSession(sh.store) + if err != nil { + serveError(w, http.StatusInternalServerError, fmt.Sprintf("create session failed, err: %v", err)) + return + } + defer sh.s.Close() + sh.do = domain.GetDomain(sh.s) + reqCtx := r.Context() + sql := r.FormValue("sql") + pprofTimeString := r.FormValue("pprof_time") + timeoutString := r.FormValue("timeout") + curDB := strings.ToLower(r.FormValue("current_db")) + if curDB != "" { + _, err = sh.s.Execute(reqCtx, "use %v"+curDB) + if err != nil { + serveError(w, http.StatusInternalServerError, fmt.Sprintf("use database %v failed, err: %v", curDB, err)) + return + } + } + var ( + pprofTime int + timeout int + ) + if pprofTimeString != "" { + pprofTime, err = strconv.Atoi(pprofTimeString) + if err != nil { + serveError(w, http.StatusBadRequest, "invalid value for pprof_time, please input a int value larger than 5") + return + } + } + if pprofTimeString != "" && pprofTime < 5 { + serveError(w, http.StatusBadRequest, "pprof time is too short, please input a int value larger than 5") + } + if timeoutString != "" { + timeout, err = strconv.Atoi(timeoutString) + if err != nil { + serveError(w, http.StatusBadRequest, "invalid value for timeout") + return + } + } + if timeout < pprofTime { + timeout = pprofTime + } + pairs, err := sh.extractTableNames(sql, curDB) + if err != nil { + serveError(w, http.StatusBadRequest, fmt.Sprintf("invalid SQL text, err: %v", err)) + return + } + zw := zip.NewWriter(w) + defer func() { + terror.Log(zw.Close()) + }() + for pair := range pairs { + jsonTbl, err := sh.getStatsForTable(pair) + if err != nil { + err = sh.writeErrFile(zw, fmt.Sprintf("%v.%v.stats.err.txt", pair.DBName, pair.TableName), err) + terror.Log(err) + continue + } + statsFw, err := zw.Create(fmt.Sprintf("%v.%v.json", pair.DBName, pair.TableName)) + if err != nil { + terror.Log(err) + continue + } + data, err := json.Marshal(jsonTbl) + if err != nil { + err = sh.writeErrFile(zw, fmt.Sprintf("%v.%v.stats.err.txt", pair.DBName, pair.TableName), err) + terror.Log(err) + continue + } + _, err = statsFw.Write(data) + if err != nil { + err = sh.writeErrFile(zw, fmt.Sprintf("%v.%v.stats.err.txt", pair.DBName, pair.TableName), err) + terror.Log(err) + continue + } + } + for pair := range pairs { + err = sh.getShowCreateTable(pair, zw) + if err != nil { + err = sh.writeErrFile(zw, fmt.Sprintf("%v.%v.schema.err.txt", pair.DBName, pair.TableName), err) + terror.Log(err) + return + } + } + // If we don't catch profile. We just get a explain result. + if pprofTime == 0 { + recordSets, err := sh.s.(sqlexec.SQLExecutor).Execute(reqCtx, fmt.Sprintf("explain %s", sql)) + if len(recordSets) > 0 { + defer terror.Call(recordSets[0].Close) + } + if err != nil { + err = sh.writeErrFile(zw, "explain.err.txt", err) + terror.Log(err) + return + } + sRows, err := testkit.ResultSetToStringSlice(reqCtx, sh.s, recordSets[0]) + if err != nil { + err = sh.writeErrFile(zw, "explain.err.txt", err) + terror.Log(err) + return + } + fw, err := zw.Create("explain.txt") + if err != nil { + terror.Log(err) + return + } + for _, row := range sRows { + fmt.Fprintf(fw, "%s\n", strings.Join(row, "\t")) + } + } else { + // Otherwise we catch a profile and run `EXPLAIN ANALYZE` result. + ctx, cancelFunc := context.WithCancel(reqCtx) + timer := time.NewTimer(time.Second * time.Duration(timeout)) + resultChan := make(chan *explainAnalyzeResult) + go sh.getExplainAnalyze(ctx, sql, resultChan) + errChan := make(chan error) + go sh.catchCPUProfile(reqCtx, pprofTime, zw, errChan) + select { + case result := <-resultChan: + timer.Stop() + cancelFunc() + if result.err != nil { + err = sh.writeErrFile(zw, "explain_analyze.err.txt", result.err) + terror.Log(err) + return + } + if len(result.rows) == 0 { + break + } + fw, err := zw.Create("explain_analyze.txt") + if err != nil { + terror.Log(err) + break + } + for _, row := range result.rows { + fmt.Fprintf(fw, "%s\n", strings.Join(row, "\t")) + } + case <-timer.C: + cancelFunc() + } + err = <-errChan + if err != nil { + err = sh.writeErrFile(zw, "profile.err.txt", err) + terror.Log(err) + return + } + } +} + +func (sh *sqlInfoFetcher) writeErrFile(zw *zip.Writer, name string, err error) error { + fw, err1 := zw.Create(name) + if err1 != nil { + return err1 + } + fmt.Fprintf(fw, "error: %v", err) + return nil +} + +type explainAnalyzeResult struct { + rows [][]string + err error +} + +func (sh *sqlInfoFetcher) getExplainAnalyze(ctx context.Context, sql string, resultChan chan<- *explainAnalyzeResult) { + recordSets, err := sh.s.(sqlexec.SQLExecutor).Execute(ctx, fmt.Sprintf("explain analyze %s", sql)) + if len(recordSets) > 0 { + defer terror.Call(recordSets[0].Close) + } + if err != nil { + resultChan <- &explainAnalyzeResult{err: err} + return + } + rows, err := testkit.ResultSetToStringSlice(ctx, sh.s, recordSets[0]) + if err != nil { + terror.Log(err) + rows = nil + return + } + resultChan <- &explainAnalyzeResult{rows: rows} +} + +func (sh *sqlInfoFetcher) catchCPUProfile(ctx context.Context, sec int, zw *zip.Writer, errChan chan<- error) { + // dump profile + fw, err := zw.Create("profile") + if err != nil { + errChan <- err + return + } + if err := pprof.StartCPUProfile(fw); err != nil { + errChan <- err + return + } + sleepWithCtx(ctx, time.Duration(sec)*time.Second) + pprof.StopCPUProfile() + errChan <- nil +} + +func (sh *sqlInfoFetcher) getStatsForTable(pair tableNamePair) (*handle.JSONTable, error) { + is := sh.do.InfoSchema() + h := sh.do.StatsHandle() + tbl, err := is.TableByName(model.NewCIStr(pair.DBName), model.NewCIStr(pair.TableName)) + if err != nil { + return nil, err + } + js, err := h.DumpStatsToJSON(pair.DBName, tbl.Meta(), nil) + return js, err +} + +func (sh *sqlInfoFetcher) getShowCreateTable(pair tableNamePair, zw *zip.Writer) error { + recordSets, err := sh.s.(sqlexec.SQLExecutor).Execute(context.TODO(), fmt.Sprintf("show create table `%v`.`%v`", pair.DBName, pair.TableName)) + if len(recordSets) > 0 { + defer terror.Call(recordSets[0].Close) + } + if err != nil { + return err + } + sRows, err := testkit.ResultSetToStringSlice(context.Background(), sh.s, recordSets[0]) + if err != nil { + terror.Log(err) + return nil + } + fw, err := zw.Create(fmt.Sprintf("%v.%v.schema.txt", pair.DBName, pair.TableName)) + if err != nil { + terror.Log(err) + return nil + } + for _, row := range sRows { + fmt.Fprintf(fw, "%s\n", strings.Join(row, "\t")) + } + return nil +} + +func (sh *sqlInfoFetcher) extractTableNames(sql, curDB string) (map[tableNamePair]struct{}, error) { + p := parser.New() + charset, collation := sh.s.GetSessionVars().GetCharsetInfo() + stmts, _, err := p.Parse(sql, charset, collation) + if err != nil { + return nil, err + } + if len(stmts) > 1 { + return nil, errors.Errorf("Only 1 statement is allowed") + } + extractor := &tableNameExtractor{ + curDB: curDB, + names: make(map[tableNamePair]struct{}), + } + stmts[0].Accept(extractor) + return extractor.names, nil +} diff --git a/server/statistics_handler_test.go b/server/statistics_handler_test.go index 00c58581958b5..41fa67f6461c0 100644 --- a/server/statistics_handler_test.go +++ b/server/statistics_handler_test.go @@ -47,7 +47,7 @@ func (ds *testDumpStatsSuite) startServer(c *C) { var err error ds.store, err = mockstore.NewMockTikvStore(mockstore.WithMVCCStore(mvccStore)) c.Assert(err, IsNil) - session.SetStatsLease(0) + session.DisableStats4Test() ds.domain, err = session.BootstrapSession(ds.store) c.Assert(err, IsNil) ds.domain.SetStatsUpdating(true) diff --git a/server/tidb_test.go b/server/tidb_test.go index faafef9c164ae..fb22dcb77b9b9 100644 --- a/server/tidb_test.go +++ b/server/tidb_test.go @@ -53,7 +53,7 @@ func (ts *TidbTestSuite) SetUpSuite(c *C) { metrics.RegisterMetrics() var err error ts.store, err = mockstore.NewMockTikvStore() - session.SetStatsLease(0) + session.DisableStats4Test() c.Assert(err, IsNil) ts.domain, err = session.BootstrapSession(ts.store) c.Assert(err, IsNil) @@ -437,7 +437,7 @@ func (ts *TidbTestSuite) TestCreateTableFlen(c *C) { c.Assert(err, IsNil) rs, err := qctx.Execute(ctx, "show create table t1") c.Assert(err, IsNil) - req := rs[0].NewRecordBatch() + req := rs[0].NewChunk() err = rs[0].Next(ctx, req) c.Assert(err, IsNil) cols := rs[0].Columns() @@ -467,7 +467,7 @@ func (ts *TidbTestSuite) TestShowTablesFlen(c *C) { c.Assert(err, IsNil) rs, err := qctx.Execute(ctx, "show tables") c.Assert(err, IsNil) - req := rs[0].NewRecordBatch() + req := rs[0].NewChunk() err = rs[0].Next(ctx, req) c.Assert(err, IsNil) cols := rs[0].Columns() diff --git a/session/bench_test.go b/session/bench_test.go index cd48f881dbcd8..8590a79ac36d3 100644 --- a/session/bench_test.go +++ b/session/bench_test.go @@ -88,7 +88,7 @@ func prepareJoinBenchData(se Session, colType string, valueFormat string, valueC } func readResult(ctx context.Context, rs sqlexec.RecordSet, count int) { - req := rs.NewRecordBatch() + req := rs.NewChunk() for count > 0 { err := rs.Next(ctx, req) if err != nil { diff --git a/session/bootstrap.go b/session/bootstrap.go index b6e8917994542..5f6c735e0b813 100644 --- a/session/bootstrap.go +++ b/session/bootstrap.go @@ -259,10 +259,21 @@ const ( count bigint(64) UNSIGNED NOT NULL, index tbl(table_id, is_index, hist_id) );` + + // CreateExprPushdownBlacklist stores the expressions which are not allowed to be pushed down. + CreateExprPushdownBlacklist = `CREATE TABLE IF NOT EXISTS mysql.expr_pushdown_blacklist ( + name char(100) NOT NULL + );` + + // CreateOptRuleBlacklist stores the list of disabled optimizing operations. + CreateOptRuleBlacklist = `CREATE TABLE IF NOT EXISTS mysql.opt_rule_blacklist ( + name char(100) NOT NULL + );` ) // bootstrap initiates system DB for a store. func bootstrap(s Session) { + startTime := time.Now() dom := domain.GetDomain(s) for { b, err := checkBootstrapped(s) @@ -273,6 +284,8 @@ func bootstrap(s Session) { // For rolling upgrade, we can't do upgrade only in the owner. if b { upgrade(s) + logutil.Logger(context.Background()).Info("upgrade successful in bootstrap", + zap.Duration("take time", time.Since(startTime))) return } // To reduce conflict when multiple TiDB-server start at the same time. @@ -280,6 +293,8 @@ func bootstrap(s Session) { if dom.DDL().OwnerManager().IsOwner() { doDDLWorks(s) doDMLWorks(s) + logutil.Logger(context.Background()).Info("bootstrap successful", + zap.Duration("take time", time.Since(startTime))) return } time.Sleep(200 * time.Millisecond) @@ -331,6 +346,9 @@ const ( version30 = 30 version31 = 31 version32 = 32 + version33 = 33 + version34 = 34 + version35 = 35 ) func checkBootstrapped(s Session) (bool, error) { @@ -373,7 +391,7 @@ func getTiDBVar(s Session, name string) (sVal string, isNull bool, e error) { } r := rs[0] defer terror.Call(r.Close) - req := r.NewRecordBatch() + req := r.NewChunk() err = r.Next(ctx, req) if err != nil || req.NumRows() == 0 { return "", true, errors.Trace(err) @@ -518,22 +536,36 @@ func upgrade(s Session) { upgradeToVer29(s) } + if ver < version33 { + upgradeToVer33(s) + } + + if ver < version34 { + upgradeToVer34(s) + } + + if ver < version35 { + upgradeToVer35(s) + } + updateBootstrapVer(s) _, err = s.Execute(context.Background(), "COMMIT") if err != nil { - time.Sleep(1 * time.Second) + sleepTime := 1 * time.Second + logutil.Logger(context.Background()).Info("update bootstrap ver failed", + zap.Error(err), zap.Duration("sleeping time", sleepTime)) + time.Sleep(sleepTime) // Check if TiDB is already upgraded. v, err1 := getBootstrapVersion(s) if err1 != nil { - logutil.Logger(context.Background()).Fatal("upgrade error", - zap.Error(err1)) + logutil.Logger(context.Background()).Fatal("upgrade failed", zap.Error(err1)) } if v >= currentBootstrapVersion { // It is already bootstrapped/upgraded by a higher version TiDB server. return } - logutil.Logger(context.Background()).Fatal("[Upgrade] upgrade error", + logutil.Logger(context.Background()).Fatal("[Upgrade] upgrade failed", zap.Int64("from", ver), zap.Int("to", currentBootstrapVersion), zap.Error(err)) @@ -641,8 +673,8 @@ func upgradeToVer12(s Session) { r := rs[0] sqls := make([]string, 0, 1) defer terror.Call(r.Close) - req := r.NewRecordBatch() - it := chunk.NewIterator4Chunk(req.Chunk) + req := r.NewChunk() + it := chunk.NewIterator4Chunk(req) err = r.Next(ctx, req) for err == nil && req.NumRows() != 0 { for row := it.Begin(); row != it.End(); row = it.Next() { @@ -818,6 +850,20 @@ func upgradeToVer32(s Session) { doReentrantDDL(s, "ALTER TABLE mysql.tables_priv MODIFY table_priv SET('Select','Insert','Update','Delete','Create','Drop','Grant', 'Index', 'Alter', 'Create View', 'Show View', 'Trigger', 'References')") } +func upgradeToVer33(s Session) { + doReentrantDDL(s, CreateExprPushdownBlacklist) +} + +func upgradeToVer34(s Session) { + doReentrantDDL(s, CreateOptRuleBlacklist) +} + +func upgradeToVer35(s Session) { + sql := fmt.Sprintf("UPDATE HIGH_PRIORITY %s.%s SET VARIABLE_NAME = '%s' WHERE VARIABLE_NAME = 'tidb_back_off_weight'", + mysql.SystemDB, mysql.GlobalVariablesTable, variable.TiDBBackOffWeight) + mustExecute(s, sql) +} + // updateBootstrapVer updates bootstrap version variable in mysql.TiDB table. func updateBootstrapVer(s Session) { // Update bootstrap version. @@ -876,6 +922,10 @@ func doDDLWorks(s Session) { mustExecute(s, CreateBindInfoTable) // Create stats_topn_store table. mustExecute(s, CreateStatsTopNTable) + // Create expr_pushdown_blacklist table. + mustExecute(s, CreateExprPushdownBlacklist) + // Create opt_rule_blacklist table. + mustExecute(s, CreateOptRuleBlacklist) } // doDMLWorks executes DML statements in bootstrap stage. @@ -912,16 +962,18 @@ func doDMLWorks(s Session) { writeSystemTZ(s) _, err := s.Execute(context.Background(), "COMMIT") if err != nil { - time.Sleep(1 * time.Second) + sleepTime := 1 * time.Second + logutil.Logger(context.Background()).Info("doDMLWorks failed", zap.Error(err), zap.Duration("sleeping time", sleepTime)) + time.Sleep(sleepTime) // Check if TiDB is already bootstrapped. b, err1 := checkBootstrapped(s) if err1 != nil { - logutil.Logger(context.Background()).Fatal("doDMLWorks error", zap.Error(err1)) + logutil.Logger(context.Background()).Fatal("doDMLWorks failed", zap.Error(err1)) } if b { return } - logutil.Logger(context.Background()).Fatal("doDMLWorks error", zap.Error(err)) + logutil.Logger(context.Background()).Fatal("doDMLWorks failed", zap.Error(err)) } } diff --git a/session/bootstrap_test.go b/session/bootstrap_test.go index 8004b7923ef3e..2c2cef919bcfe 100644 --- a/session/bootstrap_test.go +++ b/session/bootstrap_test.go @@ -51,7 +51,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { r := mustExecSQL(c, se, `select * from user;`) c.Assert(r, NotNil) ctx := context.Background() - req := r.NewRecordBatch() + req := r.NewChunk() err := r.Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -67,7 +67,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { // Check privilege tables. r = mustExecSQL(c, se, "SELECT COUNT(*) from mysql.global_variables;") c.Assert(r, NotNil) - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.GetRow(0).GetInt64(0), Equals, globalVarsCount()) @@ -88,7 +88,7 @@ func (s *testBootstrapSuite) TestBootstrap(c *C) { r = mustExecSQL(c, se, "select * from t") c.Assert(r, NotNil) - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) datums = statistics.RowToDatums(req.GetRow(0), r.Fields()) @@ -154,7 +154,7 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { se := newSession(c, store, s.dbNameBootstrap) mustExecSQL(c, se, "USE mysql;") r := mustExecSQL(c, se, `select * from user;`) - req := r.NewRecordBatch() + req := r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -173,7 +173,7 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { mustExecSQL(c, se, "SELECT * from mysql.default_roles;") // Check global variables. r = mustExecSQL(c, se, "SELECT COUNT(*) from mysql.global_variables;") - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) v := req.GetRow(0) @@ -181,7 +181,7 @@ func (s *testBootstrapSuite) TestBootstrapWithError(c *C) { c.Assert(r.Close(), IsNil) r = mustExecSQL(c, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="bootstrapped";`) - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -202,7 +202,7 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { // bootstrap with currentBootstrapVersion r := mustExecSQL(c, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version";`) - req := r.NewRecordBatch() + req := r.NewChunk() err := r.Next(ctx, req) row := req.GetRow(0) c.Assert(err, IsNil) @@ -232,7 +232,7 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { delete(storeBootstrapped, store.UUID()) // Make sure the version is downgraded. r = mustExecSQL(c, se1, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version";`) - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsTrue) @@ -248,7 +248,7 @@ func (s *testBootstrapSuite) TestUpgrade(c *C) { defer dom1.Close() se2 := newSession(c, store, s.dbName) r = mustExecSQL(c, se2, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="tidb_server_version";`) - req = r.NewRecordBatch() + req = r.NewChunk() err = r.Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -296,3 +296,16 @@ func (s *testBootstrapSuite) TestOldPasswordUpgrade(c *C) { c.Assert(err, IsNil) c.Assert(newpwd, Equals, "*0D3CED9BEC10A777AEC23CCC353A8C08A633045E") } + +func (s *testBootstrapSuite) TestBootstrapInitExpensiveQueryHandle(c *C) { + defer testleak.AfterTest(c)() + store := newStore(c, s.dbName) + defer store.Close() + se, err := createSession(store) + c.Assert(err, IsNil) + dom := domain.GetDomain(se) + c.Assert(dom, NotNil) + defer dom.Close() + dom.InitExpensiveQueryHandle() + c.Assert(dom.ExpensiveQueryHandle(), NotNil) +} diff --git a/session/isolation_test.go b/session/isolation_test.go index 630daa5dbc177..4e9f57e220a69 100644 --- a/session/isolation_test.go +++ b/session/isolation_test.go @@ -47,7 +47,7 @@ func (s *testIsolationSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() s.dom, err = session.BootstrapSession(s.store) c.Assert(err, IsNil) diff --git a/session/pessimistic_test.go b/session/pessimistic_test.go index 4eafa4a73b27f..7ec471c148f11 100644 --- a/session/pessimistic_test.go +++ b/session/pessimistic_test.go @@ -15,15 +15,22 @@ package session_test import ( "fmt" + "sync" + "sync/atomic" "time" . "github.com/pingcap/check" - "github.com/pingcap/tidb/config" + "github.com/pingcap/errors" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockstore" "github.com/pingcap/tidb/store/mockstore/mocktikv" + "github.com/pingcap/tidb/store/tikv" + "github.com/pingcap/tidb/tablecodec" + "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" ) @@ -39,7 +46,8 @@ type testPessimisticSuite struct { func (s *testPessimisticSuite) SetUpSuite(c *C) { testleak.BeforeTest() - config.GetGlobalConfig().PessimisticTxn.Enable = true + // Set it to 300ms for testing lock resolve. + tikv.PessimisticLockTTL = 300 s.cluster = mocktikv.NewCluster() mocktikv.BootstrapWithSingleStore(s.cluster) s.mvccStore = mocktikv.MustNewMVCCStore() @@ -50,15 +58,15 @@ func (s *testPessimisticSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() s.dom, err = session.BootstrapSession(s.store) + s.dom.GetGlobalVarsCache().Disable() c.Assert(err, IsNil) } func (s *testPessimisticSuite) TearDownSuite(c *C) { s.dom.Close() s.store.Close() - config.GetGlobalConfig().PessimisticTxn.Enable = false testleak.AfterTest(c)() } @@ -79,6 +87,7 @@ func (s *testPessimisticSuite) TestPessimisticTxn(c *C) { // Update can see the change, so this statement affects 0 roews. tk1.MustExec("update pessimistic set v = 3 where v = 1") c.Assert(tk1.Se.AffectedRows(), Equals, uint64(0)) + c.Assert(session.GetHistory(tk1.Se).Count(), Equals, 0) // select for update can see the change of another transaction. tk1.MustQuery("select * from pessimistic for update").Check(testkit.Rows("1 2")) // plain select can not see the change of another transaction. @@ -112,30 +121,19 @@ func (s *testPessimisticSuite) TestTxnMode(c *C) { tests := []struct { beginStmt string txnMode string - configDefault bool isPessimistic bool }{ - {"pessimistic", "pessimistic", false, true}, - {"pessimistic", "pessimistic", true, true}, - {"pessimistic", "optimistic", false, true}, - {"pessimistic", "optimistic", true, true}, - {"pessimistic", "", false, true}, - {"pessimistic", "", true, true}, - {"optimistic", "pessimistic", false, false}, - {"optimistic", "pessimistic", true, false}, - {"optimistic", "optimistic", false, false}, - {"optimistic", "optimistic", true, false}, - {"optimistic", "", false, false}, - {"optimistic", "", true, false}, - {"", "pessimistic", false, true}, - {"", "pessimistic", true, true}, - {"", "optimistic", false, false}, - {"", "optimistic", true, false}, - {"", "", false, false}, - {"", "", true, true}, + {"pessimistic", "pessimistic", true}, + {"pessimistic", "optimistic", true}, + {"pessimistic", "", true}, + {"optimistic", "pessimistic", false}, + {"optimistic", "optimistic", false}, + {"optimistic", "", false}, + {"", "pessimistic", true}, + {"", "optimistic", false}, + {"", "", false}, } for _, tt := range tests { - config.GetGlobalConfig().PessimisticTxn.Default = tt.configDefault tk.MustExec(fmt.Sprintf("set @@tidb_txn_mode = '%s'", tt.txnMode)) tk.MustExec("begin " + tt.beginStmt) c.Check(tk.Se.GetSessionVars().TxnCtx.IsPessimistic, Equals, tt.isPessimistic) @@ -146,22 +144,279 @@ func (s *testPessimisticSuite) TestTxnMode(c *C) { tk.MustExec("create table if not exists txn_mode (a int)") tests2 := []struct { txnMode string - configDefault bool isPessimistic bool }{ - {"pessimistic", false, true}, - {"pessimistic", true, true}, - {"optimistic", false, false}, - {"optimistic", true, false}, - {"", false, false}, - {"", true, true}, + {"pessimistic", true}, + {"optimistic", false}, + {"", false}, } for _, tt := range tests2 { - config.GetGlobalConfig().PessimisticTxn.Default = tt.configDefault tk.MustExec(fmt.Sprintf("set @@tidb_txn_mode = '%s'", tt.txnMode)) tk.MustExec("rollback") tk.MustExec("insert txn_mode values (1)") c.Check(tk.Se.GetSessionVars().TxnCtx.IsPessimistic, Equals, tt.isPessimistic) tk.MustExec("rollback") } + tk.MustExec("set @@global.tidb_txn_mode = 'pessimistic'") + tk1 := testkit.NewTestKitWithInit(c, s.store) + tk1.MustQuery("select @@tidb_txn_mode").Check(testkit.Rows("pessimistic")) + tk1.MustExec("set @@autocommit = 0") + tk1.MustExec("insert txn_mode values (2)") + c.Check(tk1.Se.GetSessionVars().TxnCtx.IsPessimistic, IsTrue) + tk1.MustExec("set @@tidb_txn_mode = ''") + tk1.MustExec("rollback") + tk1.MustExec("insert txn_mode values (2)") + c.Check(tk1.Se.GetSessionVars().TxnCtx.IsPessimistic, IsFalse) + tk1.MustExec("rollback") +} + +func (s *testPessimisticSuite) TestDeadlock(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists deadlock") + tk.MustExec("create table deadlock (k int primary key, v int)") + tk.MustExec("insert into deadlock values (1, 1), (2, 1)") + + syncCh := make(chan struct{}) + go func() { + tk1 := testkit.NewTestKitWithInit(c, s.store) + tk1.MustExec("begin pessimistic") + tk1.MustExec("update deadlock set v = v + 1 where k = 2") + <-syncCh + tk1.MustExec("update deadlock set v = v + 1 where k = 1") + <-syncCh + }() + tk.MustExec("begin pessimistic") + tk.MustExec("update deadlock set v = v + 1 where k = 1") + syncCh <- struct{}{} + time.Sleep(time.Millisecond * 10) + _, err := tk.Exec("update deadlock set v = v + 1 where k = 2") + e, ok := errors.Cause(err).(*terror.Error) + c.Assert(ok, IsTrue) + c.Assert(int(e.Code()), Equals, mysql.ErrLockDeadlock) + syncCh <- struct{}{} +} + +func (s *testPessimisticSuite) TestSingleStatementRollback(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + + tk.MustExec("drop table if exists pessimistic") + tk.MustExec("create table single_statement (id int primary key, v int)") + tk.MustExec("insert into single_statement values (1, 1), (2, 1), (3, 1), (4, 1)") + tblID := tk.GetTableID("single_statement") + s.cluster.SplitTable(s.mvccStore, tblID, 2) + region1Key := codec.EncodeBytes(nil, tablecodec.EncodeRowKeyWithHandle(tblID, 1)) + region1, _ := s.cluster.GetRegionByKey(region1Key) + region1ID := region1.Id + region2Key := codec.EncodeBytes(nil, tablecodec.EncodeRowKeyWithHandle(tblID, 3)) + region2, _ := s.cluster.GetRegionByKey(region2Key) + region2ID := region2.Id + + syncCh := make(chan bool) + go func() { + tk2.MustExec("begin pessimistic") + <-syncCh + s.cluster.ScheduleDelay(tk2.Se.GetSessionVars().TxnCtx.StartTS, region2ID, time.Millisecond*3) + tk2.MustExec("update single_statement set v = v + 1") + tk2.MustExec("commit") + <-syncCh + }() + tk.MustExec("begin pessimistic") + syncCh <- true + s.cluster.ScheduleDelay(tk.Se.GetSessionVars().TxnCtx.StartTS, region1ID, time.Millisecond*3) + tk.MustExec("update single_statement set v = v + 1") + tk.MustExec("commit") + syncCh <- true +} + +func (s *testPessimisticSuite) TestFirstStatementFail(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists first") + tk.MustExec("create table first (k int unique)") + tk.MustExec("insert first values (1)") + tk.MustExec("begin pessimistic") + _, err := tk.Exec("insert first values (1)") + c.Assert(err, NotNil) + tk.MustExec("insert first values (2)") + tk.MustExec("commit") +} + +func (s *testPessimisticSuite) TestKeyExistsCheck(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists chk") + tk.MustExec("create table chk (k int primary key)") + tk.MustExec("insert chk values (1)") + tk.MustExec("delete from chk where k = 1") + tk.MustExec("begin pessimistic") + tk.MustExec("insert chk values (1)") + tk.MustExec("commit") + + tk1 := testkit.NewTestKitWithInit(c, s.store) + tk1.MustExec("begin optimistic") + tk1.MustExec("insert chk values (1), (2), (3)") + _, err := tk1.Exec("commit") + c.Assert(err, NotNil) + + tk.MustExec("begin pessimistic") + tk.MustExec("insert chk values (2)") + tk.MustExec("commit") +} + +func (s *testPessimisticSuite) TestInsertOnDup(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists dup") + tk.MustExec("create table dup (id int primary key, c int)") + tk.MustExec("begin pessimistic") + + tk2.MustExec("insert dup values (1, 1)") + tk.MustExec("insert dup values (1, 1) on duplicate key update c = c + 1") + tk.MustExec("commit") + tk.MustQuery("select * from dup").Check(testkit.Rows("1 2")) +} + +func (s *testPessimisticSuite) TestPointGetKeyLock(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists point") + tk.MustExec("create table point (id int primary key, u int unique, c int)") + syncCh := make(chan struct{}) + + tk.MustExec("begin pessimistic") + tk.MustExec("update point set c = c + 1 where id = 1") + tk.MustExec("delete from point where u = 2") + go func() { + tk2.MustExec("begin pessimistic") + _, err1 := tk2.Exec("insert point values (1, 1, 1)") + c.Check(kv.ErrKeyExists.Equal(err1), IsTrue) + _, err1 = tk2.Exec("insert point values (2, 2, 2)") + c.Check(kv.ErrKeyExists.Equal(err1), IsTrue) + tk2.MustExec("rollback") + <-syncCh + }() + time.Sleep(time.Millisecond * 10) + tk.MustExec("insert point values (1, 1, 1)") + tk.MustExec("insert point values (2, 2, 2)") + tk.MustExec("commit") + syncCh <- struct{}{} + + tk.MustExec("begin pessimistic") + tk.MustExec("select * from point where id = 3 for update") + tk.MustExec("select * from point where u = 4 for update") + go func() { + tk2.MustExec("begin pessimistic") + _, err1 := tk2.Exec("insert point values (3, 3, 3)") + c.Check(kv.ErrKeyExists.Equal(err1), IsTrue) + _, err1 = tk2.Exec("insert point values (4, 4, 4)") + c.Check(kv.ErrKeyExists.Equal(err1), IsTrue) + tk2.MustExec("rollback") + <-syncCh + }() + time.Sleep(time.Millisecond * 10) + tk.MustExec("insert point values (3, 3, 3)") + tk.MustExec("insert point values (4, 4, 4)") + tk.MustExec("commit") + syncCh <- struct{}{} +} + +func (s *testPessimisticSuite) TestBankTransfer(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists accounts") + tk.MustExec("create table accounts (id int primary key, c int)") + tk.MustExec("insert accounts values (1, 100), (2, 100), (3, 100)") + syncCh := make(chan struct{}) + + tk.MustExec("begin pessimistic") + tk.MustQuery("select * from accounts where id = 1 for update").Check(testkit.Rows("1 100")) + go func() { + tk2.MustExec("begin pessimistic") + tk2.MustExec("select * from accounts where id = 2 for update") + <-syncCh + tk2.MustExec("select * from accounts where id = 3 for update") + tk2.MustExec("update accounts set c = 50 where id = 2") + tk2.MustExec("update accounts set c = 150 where id = 3") + tk2.MustExec("commit") + <-syncCh + }() + syncCh <- struct{}{} + tk.MustQuery("select * from accounts where id = 2 for update").Check(testkit.Rows("2 50")) + tk.MustExec("update accounts set c = 50 where id = 1") + tk.MustExec("update accounts set c = 100 where id = 2") + tk.MustExec("commit") + syncCh <- struct{}{} + tk.MustQuery("select sum(c) from accounts").Check(testkit.Rows("300")) +} + +func (s *testPessimisticSuite) TestOptimisticConflicts(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists conflict") + tk.MustExec("create table conflict (id int primary key, c int)") + tk.MustExec("insert conflict values (1, 1)") + tk.MustExec("begin pessimistic") + tk.MustQuery("select * from conflict where id = 1 for update") + syncCh := make(chan struct{}) + go func() { + tk2.MustExec("update conflict set c = 3 where id = 1") + <-syncCh + }() + time.Sleep(time.Millisecond * 10) + tk.MustExec("update conflict set c = 2 where id = 1") + tk.MustExec("commit") + syncCh <- struct{}{} + tk.MustQuery("select c from conflict where id = 1").Check(testkit.Rows("3")) + + // Check pessimistic lock is not resolved. + tk.MustExec("begin pessimistic") + tk.MustExec("update conflict set c = 4 where id = 1") + tk2.MustExec("begin optimistic") + tk2.MustExec("update conflict set c = 5 where id = 1") + // TODO: ResolveLock block until timeout, takes about 40s, makes CI slow! + _, err := tk2.Exec("commit") + c.Check(err, NotNil) + + // Update snapshotTS after a conflict, invalidate snapshot cache. + tk.MustExec("truncate table conflict") + tk.MustExec("insert into conflict values (1, 2)") + tk.MustExec("begin pessimistic") + // This SQL use BatchGet and cache data in the txn snapshot. + // It can be changed to other SQLs that use BatchGet. + tk.MustExec("insert ignore into conflict values (1, 2)") + + tk2.MustExec("update conflict set c = c - 1") + + // Make the txn update its forUpdateTS. + tk.MustQuery("select * from conflict where id = 1 for update").Check(testkit.Rows("1 1")) + // Cover a bug that the txn snapshot doesn't invalidate cache after ts change. + tk.MustExec("insert into conflict values (1, 999) on duplicate key update c = c + 2") + tk.MustExec("commit") + tk.MustQuery("select * from conflict").Check(testkit.Rows("1 3")) +} + +func (s *testPessimisticSuite) TestWaitLockKill(c *C) { + // Test kill command works on waiting pessimistic lock. + tk := testkit.NewTestKitWithInit(c, s.store) + tk2 := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists test_kill") + tk.MustExec("create table test_kill (id int primary key, c int)") + tk.MustExec("insert test_kill values (1, 1)") + tk.MustExec("begin pessimistic") + tk2.MustExec("begin pessimistic") + tk.MustQuery("select * from test_kill where id = 1 for update") + + var wg sync.WaitGroup + wg.Add(1) + go func() { + time.Sleep(500 * time.Millisecond) + sessVars := tk2.Se.GetSessionVars() + succ := atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1) + c.Assert(succ, IsTrue) + wg.Wait() + }() + _, err := tk2.Exec("update test_kill set c = c + 1 where id = 1") + wg.Done() + c.Assert(err, NotNil) + c.Assert(terror.ErrorEqual(err, tikv.ErrQueryInterrupted), IsTrue) + tk.MustExec("rollback") } diff --git a/session/session.go b/session/session.go index e5879c5b9d70e..a68aa8ed4c131 100644 --- a/session/session.go +++ b/session/session.go @@ -80,13 +80,14 @@ var ( transactionDurationGeneralOK = metrics.TransactionDuration.WithLabelValues(metrics.LblGeneral, "ok") transactionDurationGeneralError = metrics.TransactionDuration.WithLabelValues(metrics.LblGeneral, "error") - transactionCounterInternalOK = metrics.TransactionCounter.WithLabelValues(metrics.LblInternal, metrics.LblOK) - transactionCounterInternalErr = metrics.TransactionCounter.WithLabelValues(metrics.LblInternal, metrics.LblError) - transactionCounterGeneralOK = metrics.TransactionCounter.WithLabelValues(metrics.LblGeneral, metrics.LblOK) - transactionCounterGeneralErr = metrics.TransactionCounter.WithLabelValues(metrics.LblGeneral, metrics.LblError) - - transactionRollbackCounterInternal = metrics.TransactionCounter.WithLabelValues(metrics.LblInternal, metrics.LblRollback) - transactionRollbackCounterGeneral = metrics.TransactionCounter.WithLabelValues(metrics.LblGeneral, metrics.LblRollback) + transactionCounterInternalOK = metrics.TransactionCounter.WithLabelValues(metrics.LblInternal, metrics.LblOK) + transactionCounterInternalErr = metrics.TransactionCounter.WithLabelValues(metrics.LblInternal, metrics.LblError) + transactionCounterGeneralOK = metrics.TransactionCounter.WithLabelValues(metrics.LblGeneral, metrics.LblOK) + transactionCounterGeneralErr = metrics.TransactionCounter.WithLabelValues(metrics.LblGeneral, metrics.LblError) + transactionCounterInternalCommitRollback = metrics.TransactionCounter.WithLabelValues(metrics.LblInternal, metrics.LblComRol) + transactionCounterGeneralCommitRollback = metrics.TransactionCounter.WithLabelValues(metrics.LblGeneral, metrics.LblComRol) + transactionRollbackCounterInternal = metrics.TransactionCounter.WithLabelValues(metrics.LblInternal, metrics.LblRollback) + transactionRollbackCounterGeneral = metrics.TransactionCounter.WithLabelValues(metrics.LblGeneral, metrics.LblRollback) sessionExecuteRunDurationInternal = metrics.SessionExecuteRunDuration.WithLabelValues(metrics.LblInternal) sessionExecuteRunDurationGeneral = metrics.SessionExecuteRunDuration.WithLabelValues(metrics.LblGeneral) @@ -97,7 +98,7 @@ var ( sessionExecuteParseDurationGeneral = metrics.SessionExecuteParseDuration.WithLabelValues(metrics.LblGeneral) ) -// Session context +// Session context, it is consistent with the lifecycle of a client connection. type Session interface { sessionctx.Context Status() uint16 // Flag of current status, such as autocommit. @@ -116,13 +117,13 @@ type Session interface { SetClientCapability(uint32) // Set client capability flags. SetConnectionID(uint64) SetCommandValue(byte) - SetProcessInfo(string, time.Time, byte) + SetProcessInfo(string, time.Time, byte, uint64) SetTLSState(*tls.ConnectionState) SetCollation(coID int) error SetSessionManager(util.SessionManager) Close() Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool - ShowProcess() util.ProcessInfo + ShowProcess() *util.ProcessInfo // PrePareTxnCtx is exported for test. PrepareTxnCtx(context.Context) // FieldList returns fields list of a table. @@ -469,7 +470,7 @@ func (s *session) CommitTxn(ctx context.Context) error { s.sessionVars.StmtCtx.MergeExecDetails(nil, commitDetail) } s.sessionVars.TxnCtx.Cleanup() - s.recordTransactionCounter(err) + s.recordTransactionCounter(nil, err) return err } @@ -579,7 +580,7 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { connID := s.sessionVars.ConnectionID s.sessionVars.RetryInfo.Retrying = true if s.sessionVars.TxnCtx.ForUpdate { - err = errForUpdateCantRetry.GenWithStackByArgs(connID) + err = ErrForUpdateCantRetry.GenWithStackByArgs(connID) return err } @@ -594,9 +595,12 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { for i, sr := range nh.history { st := sr.st s.sessionVars.StmtCtx = sr.stmtCtx + s.sessionVars.StartTime = time.Now() + s.sessionVars.DurationCompile = time.Duration(0) + s.sessionVars.DurationParse = time.Duration(0) s.sessionVars.StmtCtx.ResetForRetry() s.sessionVars.PreparedParams = s.sessionVars.PreparedParams[:0] - schemaVersion, err = st.RebuildPlan() + schemaVersion, err = st.RebuildPlan(ctx) if err != nil { return err } @@ -608,7 +612,7 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { zap.Int64("schemaVersion", schemaVersion), zap.Uint("retryCnt", retryCnt), zap.Int("queryNum", i), - zap.String("sql", sqlForLog(st.OriginText())+sessVars.GetExecuteArgumentsInfo())) + zap.String("sql", sqlForLog(st.OriginText())+sessVars.PreparedParams.String())) } else { logutil.Logger(ctx).Warn("retrying", zap.Int64("schemaVersion", schemaVersion), @@ -780,6 +784,10 @@ func createSessionFunc(store kv.Storage) pools.Factory { if err != nil { return nil, err } + err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxExecutionTime, types.NewUintDatum(0)) + if err != nil { + return nil, errors.Trace(err) + } se.sessionVars.CommonGlobalLoaded = true se.sessionVars.InRestrictedSQL = true return se, nil @@ -796,6 +804,10 @@ func createSessionWithDomainFunc(store kv.Storage) func(*domain.Domain) (pools.R if err != nil { return nil, err } + err = variable.SetSessionSystemVar(se.sessionVars, variable.MaxExecutionTime, types.NewUintDatum(0)) + if err != nil { + return nil, errors.Trace(err) + } se.sessionVars.CommonGlobalLoaded = true se.sessionVars.InRestrictedSQL = true return se, nil @@ -804,17 +816,17 @@ func createSessionWithDomainFunc(store kv.Storage) func(*domain.Domain) (pools.R func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet) ([]chunk.Row, error) { var rows []chunk.Row - req := rs.NewRecordBatch() + req := rs.NewChunk() for { err := rs.Next(ctx, req) if err != nil || req.NumRows() == 0 { return rows, err } - iter := chunk.NewIterator4Chunk(req.Chunk) + iter := chunk.NewIterator4Chunk(req) for r := iter.Begin(); r != iter.End(); r = iter.Next() { rows = append(rows, r) } - req.Chunk = chunk.Renew(req.Chunk, se.sessionVars.MaxChunkSize) + req = chunk.Renew(req, se.sessionVars.MaxChunkSize) } } @@ -907,24 +919,37 @@ func (s *session) ParseSQL(ctx context.Context, sql, charset, collation string) return s.parser.Parse(sql, charset, collation) } -func (s *session) SetProcessInfo(sql string, t time.Time, command byte) { +func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) { + var db interface{} + if len(s.sessionVars.CurrentDB) > 0 { + db = s.sessionVars.CurrentDB + } + + var info interface{} + if len(sql) > 0 { + info = sql + } pi := util.ProcessInfo{ - ID: s.sessionVars.ConnectionID, - DB: s.sessionVars.CurrentDB, - Command: command, - Plan: s.currentPlan, - Time: t, - State: s.Status(), - Info: sql, + ID: s.sessionVars.ConnectionID, + DB: db, + Command: command, + Plan: s.currentPlan, + Time: t, + State: s.Status(), + Info: info, + CurTxnStartTS: s.sessionVars.TxnCtx.StartTS, + StmtCtx: s.sessionVars.StmtCtx, + StatsInfo: plannercore.GetStatsInfo, + MaxExecutionTime: maxExecutionTime, } if s.sessionVars.User != nil { pi.User = s.sessionVars.User.Username pi.Host = s.sessionVars.User.Hostname } - s.processInfo.Store(pi) + s.processInfo.Store(&pi) } -func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet) ([]sqlexec.RecordSet, error) { +func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode ast.StmtNode, stmt sqlexec.Statement, recordSets []sqlexec.RecordSet, inMulitQuery bool) ([]sqlexec.RecordSet, error) { s.SetValue(sessionctx.QueryString, stmt.OriginText()) if _, ok := stmtNode.(ast.DDLNode); ok { s.SetValue(sessionctx.LastExecuteDDL, true) @@ -943,12 +968,23 @@ func (s *session) executeStatement(ctx context.Context, connID uint64, stmtNode } return nil, err } + s.recordTransactionCounter(stmtNode, err) if s.isInternal() { sessionExecuteRunDurationInternal.Observe(time.Since(startTime).Seconds()) } else { sessionExecuteRunDurationGeneral.Observe(time.Since(startTime).Seconds()) } + if inMulitQuery && recordSet == nil { + recordSet = &multiQueryNoDelayRecordSet{ + affectedRows: s.AffectedRows(), + lastMessage: s.LastMessage(), + warnCount: s.sessionVars.StmtCtx.WarningCount(), + lastInsertID: s.sessionVars.StmtCtx.LastInsertID, + status: s.sessionVars.Status, + } + } + if recordSet != nil { recordSets = append(recordSets, recordSet) } @@ -978,6 +1014,7 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec // Step1: Compile query string to abstract syntax trees(ASTs). startTS := time.Now() + s.GetSessionVars().StartTime = startTS stmtNodes, warns, err := s.ParseSQL(ctx, sql, charsetInfo, collation) if err != nil { s.rollbackOnError(ctx) @@ -986,15 +1023,18 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec zap.String("sql", sql)) return nil, util.SyntaxError(err) } + durParse := time.Since(startTS) + s.GetSessionVars().DurationParse = durParse isInternal := s.isInternal() if isInternal { - sessionExecuteParseDurationInternal.Observe(time.Since(startTS).Seconds()) + sessionExecuteParseDurationInternal.Observe(durParse.Seconds()) } else { - sessionExecuteParseDurationGeneral.Observe(time.Since(startTS).Seconds()) + sessionExecuteParseDurationGeneral.Observe(durParse.Seconds()) } var tempStmtNodes []ast.StmtNode compiler := executor.Compiler{Ctx: s} + multiQuery := len(stmtNodes) > 1 for idx, stmtNode := range stmtNodes { s.PrepareTxnCtx(ctx) @@ -1023,15 +1063,17 @@ func (s *session) execute(ctx context.Context, sql string) (recordSets []sqlexec } s.handleInvalidBindRecord(ctx, stmtNode) } + durCompile := time.Since(startTS) + s.GetSessionVars().DurationCompile = durCompile if isInternal { - sessionExecuteCompileDurationInternal.Observe(time.Since(startTS).Seconds()) + sessionExecuteCompileDurationInternal.Observe(durCompile.Seconds()) } else { - sessionExecuteCompileDurationGeneral.Observe(time.Since(startTS).Seconds()) + sessionExecuteCompileDurationGeneral.Observe(durCompile.Seconds()) } s.currentPlan = stmt.Plan // Step3: Execute the physical plan. - if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets); err != nil { + if recordSets, err = s.executeStatement(ctx, connID, stmtNode, stmt, recordSets, multiQuery); err != nil { return nil, err } } @@ -1190,7 +1232,8 @@ func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, args . } s.PrepareTxnCtx(ctx) - st, err := executor.CompileExecutePreparedStmt(s, stmtID, args...) + s.sessionVars.StartTime = time.Now() + st, err := executor.CompileExecutePreparedStmt(ctx, s, stmtID, args...) if err != nil { return nil, err } @@ -1286,6 +1329,10 @@ func (s *session) Close() { if s.statsCollector != nil { s.statsCollector.Delete() } + bindValue := s.Value(bindinfo.SessionBindInfoKeyType) + if bindValue != nil { + bindValue.(*bindinfo.SessionHandle).Close() + } ctx := context.TODO() s.RollbackTxn(ctx) if s.sessionVars != nil { @@ -1343,13 +1390,6 @@ func getHostByIP(ip string) []string { return addrs } -func chooseMinLease(n1 time.Duration, n2 time.Duration) time.Duration { - if n1 <= n2 { - return n1 - } - return n2 -} - // CreateSession4Test creates a new session environment for test. func CreateSession4Test(store kv.Storage) (Session, error) { s, err := CreateSession(store) @@ -1402,7 +1442,7 @@ func loadSystemTZ(se *session) (string, error) { logutil.Logger(context.Background()).Error("close result set error", zap.Error(err)) } }() - req := rss[0].NewRecordBatch() + req := rss[0].NewChunk() if err := rss[0].Next(context.Background(), req); err != nil { return "", err } @@ -1445,6 +1485,7 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { timeutil.SetSystemTZ(tz) dom := domain.GetDomain(se) + dom.InitExpensiveQueryHandle() if !config.GetGlobalConfig().Security.SkipGrantTable { err = dom.LoadPrivilegeLoop(se) @@ -1460,6 +1501,16 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { } } + err = executor.LoadExprPushdownBlacklist(se) + if err != nil { + return nil, err + } + + err = executor.LoadOptRuleBlacklist(se) + if err != nil { + return nil, err + } + se1, err := createSession(store) if err != nil { return nil, err @@ -1476,7 +1527,6 @@ func BootstrapSession(store kv.Storage) (*domain.Domain, error) { if err != nil { return nil, err } - if raw, ok := store.(tikv.EtcdBackend); ok { err = raw.StartGCWorker() if err != nil { @@ -1497,14 +1547,11 @@ func GetDomain(store kv.Storage) (*domain.Domain, error) { // bootstrap quickly, after bootstrapped, we will reset the lease time. // TODO: Using a bootstrap tool for doing this may be better later. func runInBootstrapSession(store kv.Storage, bootstrap func(Session)) { - saveLease := schemaLease - schemaLease = chooseMinLease(schemaLease, 100*time.Millisecond) s, err := createSession(store) if err != nil { // Bootstrap fail will cause program exit. logutil.Logger(context.Background()).Fatal("createSession error", zap.Error(err)) } - schemaLease = saveLease s.SetValue(sessionctx.Initing, true) bootstrap(s) @@ -1564,7 +1611,7 @@ func createSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, er const ( notBootstrapped = 0 - currentBootstrapVersion = 32 + currentBootstrapVersion = 35 ) func getStoreBootstrapVersion(store kv.Storage) int64 { @@ -1625,6 +1672,7 @@ var builtinGlobalVariable = []string{ variable.WaitTimeout, variable.InteractiveTimeout, variable.MaxPreparedStmtCount, + variable.MaxExecutionTime, /* TiDB specific global variables: */ variable.TiDBSkipUTF8Check, variable.TiDBIndexJoinBatchSize, @@ -1653,6 +1701,10 @@ var builtinGlobalVariable = []string{ variable.TiDBDisableTxnAutoRetry, variable.TiDBEnableWindowFunction, variable.TiDBEnableFastAnalyze, + variable.TiDBExpensiveQueryTimeThreshold, + variable.TiDBTxnMode, + variable.TiDBEnableStmtSummary, + variable.TiDBMaxDeltaSchemaCount, } var ( @@ -1741,11 +1793,7 @@ func (s *session) PrepareTxnCtx(ctx context.Context) { if !s.sessionVars.IsAutocommit() { pessTxnConf := config.GetGlobalConfig().PessimisticTxn if pessTxnConf.Enable { - txnMode := s.sessionVars.TxnMode - if txnMode == "" && pessTxnConf.Default { - txnMode = ast.Pessimistic - } - if txnMode == ast.Pessimistic { + if s.sessionVars.TxnMode == ast.Pessimistic { s.sessionVars.TxnCtx.IsPessimistic = true } } @@ -1786,11 +1834,11 @@ func (s *session) GetStore() kv.Storage { return s.store } -func (s *session) ShowProcess() util.ProcessInfo { - var pi util.ProcessInfo +func (s *session) ShowProcess() *util.ProcessInfo { + var pi *util.ProcessInfo tmp := s.processInfo.Load() if tmp != nil { - pi = tmp.(util.ProcessInfo) + pi = tmp.(*util.ProcessInfo) } return pi } @@ -1832,7 +1880,7 @@ func logQuery(query string, vars *variable.SessionVars) { zap.Int64("schemaVersion", vars.TxnCtx.SchemaVersion), zap.Uint64("txnStartTS", vars.TxnCtx.StartTS), zap.String("current_db", vars.CurrentDB), - zap.String("sql", query+vars.GetExecuteArgumentsInfo())) + zap.String("sql", query+vars.PreparedParams.String())) } } @@ -1856,18 +1904,81 @@ func (s *session) recordOnTransactionExecution(err error, counter int, duration } } -func (s *session) recordTransactionCounter(err error) { - if s.isInternal() { - if err != nil { - transactionCounterInternalErr.Inc() +func (s *session) recordTransactionCounter(stmtNode ast.StmtNode, err error) { + if stmtNode == nil { + if s.isInternal() { + if err != nil { + transactionCounterInternalErr.Inc() + } else { + transactionCounterInternalOK.Inc() + } } else { - transactionCounterInternalOK.Inc() + if err != nil { + transactionCounterGeneralErr.Inc() + } else { + transactionCounterGeneralOK.Inc() + } } + return + } + + var isTxn bool + switch stmtNode.(type) { + case *ast.CommitStmt: + isTxn = true + case *ast.RollbackStmt: + isTxn = true + } + if !isTxn { + return + } + if s.isInternal() { + transactionCounterInternalCommitRollback.Inc() } else { - if err != nil { - transactionCounterGeneralErr.Inc() - } else { - transactionCounterGeneralOK.Inc() - } + transactionCounterGeneralCommitRollback.Inc() } } + +type multiQueryNoDelayRecordSet struct { + affectedRows uint64 + lastMessage string + status uint16 + warnCount uint16 + lastInsertID uint64 +} + +func (c *multiQueryNoDelayRecordSet) Fields() []*ast.ResultField { + panic("unsupported method") +} + +func (c *multiQueryNoDelayRecordSet) Next(ctx context.Context, chk *chunk.Chunk) error { + panic("unsupported method") +} + +func (c *multiQueryNoDelayRecordSet) NewChunk() *chunk.Chunk { + panic("unsupported method") +} + +func (c *multiQueryNoDelayRecordSet) Close() error { + return nil +} + +func (c *multiQueryNoDelayRecordSet) AffectedRows() uint64 { + return c.affectedRows +} + +func (c *multiQueryNoDelayRecordSet) LastMessage() string { + return c.lastMessage +} + +func (c *multiQueryNoDelayRecordSet) WarnCount() uint16 { + return c.warnCount +} + +func (c *multiQueryNoDelayRecordSet) Status() uint16 { + return c.status +} + +func (c *multiQueryNoDelayRecordSet) LastInsertID() uint64 { + return c.lastInsertID +} diff --git a/session/session_test.go b/session/session_test.go index 7cc34eab65959..2bce04245a38d 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -72,7 +72,7 @@ func (s *testSessionSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() s.dom, err = session.BootstrapSession(s.store) c.Assert(err, IsNil) } @@ -451,6 +451,13 @@ func (s *testSessionSuite) TestGlobalVarAccessor(c *C) { c.Assert(err, IsNil) c.Assert(v, Equals, varValue2) + // For issue 10955, make sure the new session load `max_execution_time` into sessionVars. + s.dom.GetGlobalVarsCache().Disable() + tk1.MustExec("set @@global.max_execution_time = 100") + tk2 := testkit.NewTestKitWithInit(c, s.store) + c.Assert(tk2.Se.GetSessionVars().MaxExecutionTime, Equals, uint64(100)) + tk1.MustExec("set @@global.max_execution_time = 0") + result := tk.MustQuery("show global variables where variable_name='sql_select_limit';") result.Check(testkit.Rows("sql_select_limit 18446744073709551615")) result = tk.MustQuery("show session variables where variable_name='sql_select_limit';") @@ -702,7 +709,10 @@ func (s *testSessionSuite) TestSkipWithGrant(c *C) { c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "xxx", Hostname: `%`}, []byte("yyy"), []byte("zzz")), IsTrue) c.Assert(tk.Se.Auth(&auth.UserIdentity{Username: "root", Hostname: `%`}, []byte(""), []byte("")), IsTrue) tk.MustExec("create table t (id int)") - + tk.MustExec("create role r_1") + tk.MustExec("grant r_1 to root") + tk.MustExec("set role all") + tk.MustExec("show grants for root") privileges.SkipWithGrant = save2 } @@ -832,8 +842,11 @@ func (s *testSessionSuite) TestAutoIncrementID(c *C) { tk.MustExec("insert into autoid values();") tk.MustExec("insert into autoid values();") tk.MustQuery("select * from autoid").Check(testkit.Rows("9223372036854775808", "9223372036854775810", "9223372036854775812")) - tk.MustExec("insert into autoid values(18446744073709551614);") - _, err := tk.Exec("insert into autoid values()") + // In TiDB : _tidb_rowid will also consume the autoID when the auto_increment column is not the primary key. + // Using the MaxUint64 and MaxInt64 as the autoID upper limit like MySQL will cause _tidb_rowid allocation fail here. + _, err := tk.Exec("insert into autoid values(18446744073709551614)") + c.Assert(terror.ErrorEqual(err, autoid.ErrAutoincReadFailed), IsTrue) + _, err = tk.Exec("insert into autoid values()") c.Assert(terror.ErrorEqual(err, autoid.ErrAutoincReadFailed), IsTrue) // FixMe: MySQL works fine with the this sql. _, err = tk.Exec("insert into autoid values(18446744073709551615)") @@ -847,35 +860,38 @@ func (s *testSessionSuite) TestAutoIncrementID(c *C) { tk.MustQuery("select * from autoid").Check(testkit.Rows("1", "5000")) _, err = tk.Exec("update autoid set auto_inc_id = 8000") c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) - tk.MustQuery("select * from autoid").Check(testkit.Rows("1", "5000")) + tk.MustQuery("select * from autoid use index()").Check(testkit.Rows("1", "5000")) tk.MustExec("update autoid set auto_inc_id = 9000 where auto_inc_id=1") - tk.MustQuery("select * from autoid").Check(testkit.Rows("9000", "5000")) + tk.MustQuery("select * from autoid use index()").Check(testkit.Rows("9000", "5000")) tk.MustExec("insert into autoid values()") - tk.MustQuery("select * from autoid").Check(testkit.Rows("9000", "5000", "9001")) + tk.MustQuery("select * from autoid use index()").Check(testkit.Rows("9000", "5000", "9001")) // Corner cases for signed bigint auto_increment Columns. tk.MustExec("drop table if exists autoid") tk.MustExec("create table autoid(`auto_inc_id` bigint(20) NOT NULL AUTO_INCREMENT,UNIQUE KEY `auto_inc_id` (`auto_inc_id`))") - tk.MustExec("insert into autoid values(9223372036854775806);") - tk.MustQuery("select auto_inc_id, _tidb_rowid from autoid").Check(testkit.Rows("9223372036854775806 9223372036854775807")) + // In TiDB : _tidb_rowid will also consume the autoID when the auto_increment column is not the primary key. + // Using the MaxUint64 and MaxInt64 as autoID upper limit like MySQL will cause insert fail if the values is + // 9223372036854775806. Because _tidb_rowid will be allocated 9223372036854775807 at same time. + tk.MustExec("insert into autoid values(9223372036854775805);") + tk.MustQuery("select auto_inc_id, _tidb_rowid from autoid use index()").Check(testkit.Rows("9223372036854775805 9223372036854775806")) _, err = tk.Exec("insert into autoid values();") c.Assert(terror.ErrorEqual(err, autoid.ErrAutoincReadFailed), IsTrue) - tk.MustQuery("select auto_inc_id, _tidb_rowid from autoid").Check(testkit.Rows("9223372036854775806 9223372036854775807")) - tk.MustQuery("select auto_inc_id, _tidb_rowid from autoid use index(auto_inc_id)").Check(testkit.Rows("9223372036854775806 9223372036854775807")) + tk.MustQuery("select auto_inc_id, _tidb_rowid from autoid use index()").Check(testkit.Rows("9223372036854775805 9223372036854775806")) + tk.MustQuery("select auto_inc_id, _tidb_rowid from autoid use index(auto_inc_id)").Check(testkit.Rows("9223372036854775805 9223372036854775806")) tk.MustExec("drop table if exists autoid") tk.MustExec("create table autoid(`auto_inc_id` bigint(20) NOT NULL AUTO_INCREMENT,UNIQUE KEY `auto_inc_id` (`auto_inc_id`))") tk.MustExec("insert into autoid values()") - tk.MustQuery("select * from autoid").Check(testkit.Rows("1")) + tk.MustQuery("select * from autoid use index()").Check(testkit.Rows("1")) tk.MustExec("insert into autoid values(5000)") - tk.MustQuery("select * from autoid").Check(testkit.Rows("1", "5000")) + tk.MustQuery("select * from autoid use index()").Check(testkit.Rows("1", "5000")) _, err = tk.Exec("update autoid set auto_inc_id = 8000") c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue) - tk.MustQuery("select * from autoid").Check(testkit.Rows("1", "5000")) + tk.MustQuery("select * from autoid use index()").Check(testkit.Rows("1", "5000")) tk.MustExec("update autoid set auto_inc_id = 9000 where auto_inc_id=1") - tk.MustQuery("select * from autoid").Check(testkit.Rows("9000", "5000")) + tk.MustQuery("select * from autoid use index()").Check(testkit.Rows("9000", "5000")) tk.MustExec("insert into autoid values()") - tk.MustQuery("select * from autoid").Check(testkit.Rows("9000", "5000", "9001")) + tk.MustQuery("select * from autoid use index()").Check(testkit.Rows("9000", "5000", "9001")) } func (s *testSessionSuite) TestAutoIncrementWithRetry(c *C) { @@ -1152,7 +1168,7 @@ func (s *testSessionSuite) TestResultType(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) rs, err := tk.Exec(`select cast(null as char(30))`) c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(context.Background(), req) c.Assert(err, IsNil) c.Assert(req.GetRow(0).IsNull(0), IsTrue) @@ -1502,12 +1518,12 @@ func (s *testSessionSuite) TestUnique(c *C) { c.Assert(err, NotNil) // Check error type and error message c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue, Commentf("err %v", err)) - c.Assert(err.Error(), Equals, "[kv:1062]Duplicate entry '1' for key 'PRIMARY'") + c.Assert(err.Error(), Equals, "previous statement: insert into test(id, val) values(1, 1);: [kv:1062]Duplicate entry '1' for key 'PRIMARY'") _, err = tk1.Exec("commit") c.Assert(err, NotNil) c.Assert(terror.ErrorEqual(err, kv.ErrKeyExists), IsTrue, Commentf("err %v", err)) - c.Assert(err.Error(), Equals, "[kv:1062]Duplicate entry '2' for key 'val'") + c.Assert(err.Error(), Equals, "previous statement: insert into test(id, val) values(2, 2);: [kv:1062]Duplicate entry '2' for key 'val'") // Test for https://github.com/pingcap/tidb/issues/463 tk.MustExec("drop table test;") @@ -1700,7 +1716,7 @@ func (s *testSchemaSuite) SetUpSuite(c *C) { s.store = store s.lease = 20 * time.Millisecond session.SetSchemaLease(s.lease) - session.SetStatsLease(0) + session.DisableStats4Test() dom, err := session.BootstrapSession(s.store) c.Assert(err, IsNil) s.dom = dom @@ -1912,7 +1928,7 @@ func (s *testSchemaSuite) TestTableReaderChunk(c *C) { }() rs, err := tk.Exec("select * from chk") c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() var count int var numChunks int for { @@ -1949,7 +1965,7 @@ func (s *testSchemaSuite) TestInsertExecChunk(c *C) { c.Assert(err, IsNil) var idx int for { - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) if req.NumRows() == 0 { @@ -1983,7 +1999,7 @@ func (s *testSchemaSuite) TestUpdateExecChunk(c *C) { c.Assert(err, IsNil) var idx int for { - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) if req.NumRows() == 0 { @@ -2018,7 +2034,7 @@ func (s *testSchemaSuite) TestDeleteExecChunk(c *C) { rs, err := tk.Exec("select * from chk") c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) c.Assert(req.NumRows(), Equals, 1) @@ -2050,7 +2066,7 @@ func (s *testSchemaSuite) TestDeleteMultiTableExecChunk(c *C) { var idx int for { - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) @@ -2070,7 +2086,7 @@ func (s *testSchemaSuite) TestDeleteMultiTableExecChunk(c *C) { rs, err = tk.Exec("select * from chk2") c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() err = rs.Next(context.TODO(), req) c.Assert(err, IsNil) c.Assert(req.NumRows(), Equals, 0) @@ -2093,7 +2109,7 @@ func (s *testSchemaSuite) TestIndexLookUpReaderChunk(c *C) { tk.Se.GetSessionVars().IndexLookupSize = 10 rs, err := tk.Exec("select * from chk order by k") c.Assert(err, IsNil) - req := rs.NewRecordBatch() + req := rs.NewChunk() var count int for { err = rs.Next(context.TODO(), req) @@ -2113,7 +2129,7 @@ func (s *testSchemaSuite) TestIndexLookUpReaderChunk(c *C) { rs, err = tk.Exec("select k from chk where c < 90 order by k") c.Assert(err, IsNil) - req = rs.NewRecordBatch() + req = rs.NewChunk() count = 0 for { err = rs.Next(context.TODO(), req) @@ -2335,7 +2351,7 @@ func (s *testSessionSuite) TestKVVars(c *C) { tk.MustExec("insert kvvars values (1, 1)") tk2 := testkit.NewTestKitWithInit(c, s.store) tk2.MustExec("set @@tidb_backoff_lock_fast = 1") - tk2.MustExec("set @@tidb_back_off_weight = 100") + tk2.MustExec("set @@tidb_backoff_weight = 100") backoffVal := new(int64) backOffWeightVal := new(int32) tk2.Se.GetSessionVars().KVVars.Hook = func(name string, vars *kv.Variables) { @@ -2637,6 +2653,32 @@ func (s *testSessionSuite) TestTxnGoString(c *C) { c.Assert(fmt.Sprintf("%#v", txn), Equals, "Txn{state=invalid}") } +func (s *testSessionSuite) TestMaxExeucteTime(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + + tk.MustExec("create table MaxExecTime( id int,name varchar(128),age int);") + tk.MustExec("begin") + tk.MustExec("insert into MaxExecTime (id,name,age) values (1,'john',18),(2,'lary',19),(3,'lily',18);") + + tk.MustQuery("select @@MAX_EXECUTION_TIME;").Check(testkit.Rows("0")) + tk.MustQuery("select @@global.MAX_EXECUTION_TIME;").Check(testkit.Rows("0")) + tk.MustQuery("select /*+ MAX_EXECUTION_TIME(1000) */ * FROM MaxExecTime;") + + tk.MustExec("set @@global.MAX_EXECUTION_TIME = 300;") + tk.MustQuery("select * FROM MaxExecTime;") + + tk.MustExec("set @@MAX_EXECUTION_TIME = 150;") + tk.MustQuery("select * FROM MaxExecTime;") + + tk.MustQuery("select @@global.MAX_EXECUTION_TIME;").Check(testkit.Rows("300")) + tk.MustQuery("select @@MAX_EXECUTION_TIME;").Check(testkit.Rows("150")) + + tk.MustExec("set @@global.MAX_EXECUTION_TIME = 0;") + tk.MustExec("set @@MAX_EXECUTION_TIME = 0;") + tk.MustExec("commit") + tk.MustExec("drop table if exists MaxExecTime;") +} + func (s *testSessionSuite) TestGrantViewRelated(c *C) { tkRoot := testkit.NewTestKitWithInit(c, s.store) tkUser := testkit.NewTestKitWithInit(c, s.store) diff --git a/session/tidb.go b/session/tidb.go index 0863a1fc3f63c..5acf60bc896c5 100644 --- a/session/tidb.go +++ b/session/tidb.go @@ -35,6 +35,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/logutil" @@ -64,10 +65,8 @@ func (dm *domainMap) Get(store kv.Storage) (d *domain.Domain, err error) { return } - ddlLease := time.Duration(0) - statisticLease := time.Duration(0) - ddlLease = schemaLease - statisticLease = statsLease + ddlLease := schemaLease + statisticLease := statsLease err = util.RunWithRetry(util.DefaultMaxRetries, util.RetryInterval, func() (retry bool, err1 error) { logutil.Logger(context.Background()).Info("new domain", zap.String("store", store.UUID()), @@ -131,6 +130,11 @@ func SetStatsLease(lease time.Duration) { statsLease = lease } +// DisableStats4Test disables the stats for tests. +func DisableStats4Test() { + statsLease = -1 +} + // Parse parses a query string to raw ast.StmtNode. func Parse(ctx sessionctx.Context, src string) ([]ast.StmtNode, error) { logutil.Logger(context.Background()).Debug("compiling", zap.String("source", src)) @@ -158,21 +162,27 @@ func Compile(ctx context.Context, sctx sessionctx.Context, stmtNode ast.StmtNode return stmt, err } -func finishStmt(ctx context.Context, sctx sessionctx.Context, se *session, sessVars *variable.SessionVars, meetsErr error) error { +func finishStmt(ctx context.Context, sctx sessionctx.Context, se *session, sessVars *variable.SessionVars, + meetsErr error, sql sqlexec.Statement) error { if meetsErr != nil { if !sessVars.InTxn() { logutil.Logger(context.Background()).Info("rollbackTxn for ddl/autocommit error.") se.RollbackTxn(ctx) - } else if se.txn.Valid() && se.txn.IsPessimistic() && strings.Contains(meetsErr.Error(), "deadlock") { + } else if se.txn.Valid() && se.txn.IsPessimistic() && executor.ErrDeadlock.Equal(meetsErr) { logutil.Logger(context.Background()).Info("rollbackTxn for deadlock error", zap.Uint64("txn", se.txn.StartTS())) - meetsErr = errDeadlock se.RollbackTxn(ctx) } return meetsErr } if !sessVars.InTxn() { - return se.CommitTxn(ctx) + if err := se.CommitTxn(ctx); err != nil { + if _, ok := sql.(*executor.ExecStmt).StmtNode.(*ast.CommitStmt); ok { + err = errors.Annotatef(err, "previous statement: %s", se.GetSessionVars().PrevStmt) + } + return err + } + return nil } return checkStmtLimit(ctx, sctx, se, sessVars) @@ -207,11 +217,17 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) defer span1.Finish() } se := sctx.(*session) + sessVars := se.GetSessionVars() + // Save origTxnCtx here to avoid it reset in the transaction retry. + origTxnCtx := sessVars.TxnCtx defer func() { // If it is not a select statement, we record its slow log here, // then it could include the transaction commit time. if rs == nil { - s.(*executor.ExecStmt).LogSlowQuery(se.GetSessionVars().TxnCtx.StartTS, err != nil) + s.(*executor.ExecStmt).LogSlowQuery(origTxnCtx.StartTS, err == nil, false) + s.(*executor.ExecStmt).SummaryStmt() + pps := types.CloneRow(sessVars.PreparedParams) + sessVars.PrevStmt = executor.FormatSQL(s.OriginText(), pps) } }() @@ -220,11 +236,10 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) return nil, err } rs, err = s.Exec(ctx) - sessVars := se.GetSessionVars() // All the history should be added here. sessVars.TxnCtx.StatementCount++ if !s.IsReadOnly(sessVars) { - if err == nil { + if err == nil && !sessVars.TxnCtx.IsPessimistic { GetHistory(sctx).Add(0, s, se.sessionVars.StmtCtx) } if txn, err1 := sctx.Txn(false); err1 == nil { @@ -239,8 +254,7 @@ func runStmt(ctx context.Context, sctx sessionctx.Context, s sqlexec.Statement) logutil.Logger(context.Background()).Error("get txn error", zap.Error(err1)) } } - - err = finishStmt(ctx, sctx, se, sessVars, err) + err = finishStmt(ctx, sctx, se, sessVars, err, s) if se.txn.pending() { // After run statement finish, txn state is still pending means the // statement never need a Txn(), such as: @@ -272,11 +286,9 @@ func GetRows4Test(ctx context.Context, sctx sessionctx.Context, rs sqlexec.Recor return nil, nil } var rows []chunk.Row - req := rs.NewRecordBatch() + req := rs.NewChunk() + // Must reuse `req` for imitating server.(*clientConn).writeChunks for { - // Since we collect all the rows, we can not reuse the chunk. - iter := chunk.NewIterator4Chunk(req.Chunk) - err := rs.Next(ctx, req) if err != nil { return nil, err @@ -285,10 +297,10 @@ func GetRows4Test(ctx context.Context, sctx sessionctx.Context, rs sqlexec.Recor break } + iter := chunk.NewIterator4Chunk(req.CopyConstruct()) for row := iter.Begin(); row != iter.End(); row = iter.Next() { rows = append(rows, row) } - req.Chunk = chunk.Renew(req.Chunk, sctx.GetSessionVars().MaxChunkSize) } return rows, nil } @@ -325,21 +337,19 @@ func IsQuery(sql string) bool { return false } +// Session errors. var ( - errForUpdateCantRetry = terror.ClassSession.New(codeForUpdateCantRetry, + ErrForUpdateCantRetry = terror.ClassSession.New(codeForUpdateCantRetry, mysql.MySQLErrName[mysql.ErrForUpdateCantRetry]) - errDeadlock = terror.ClassSession.New(codeDeadlock, mysql.MySQLErrName[mysql.ErrLockDeadlock]) ) const ( codeForUpdateCantRetry terror.ErrCode = mysql.ErrForUpdateCantRetry - codeDeadlock terror.ErrCode = mysql.ErrLockDeadlock ) func init() { sessionMySQLErrCodes := map[terror.ErrCode]uint16{ codeForUpdateCantRetry: mysql.ErrForUpdateCantRetry, - codeDeadlock: mysql.ErrLockDeadlock, } terror.ErrClassToMySQLCodes[terror.ClassSession] = sessionMySQLErrCodes } diff --git a/session/txn.go b/session/txn.go index d226614da3e7f..254529b764e93 100644 --- a/session/txn.go +++ b/session/txn.go @@ -314,13 +314,7 @@ func (st *TxnState) KeysNeedToLock() ([]kv.Key, error) { if !keyNeedToLock(k, v) { return nil } - if mb := st.Transaction.GetMemBuffer(); mb != nil { - _, err1 := mb.Get(k) - if err1 == nil { - // Key is already in txn MemBuffer, must already been locked, we don't need to lock it again. - return nil - } - } + // If the key is already locked, it will be deduplicated in LockKeys method later. // The statement MemBuffer will be reused, so we must copy the key here. keys = append(keys, append([]byte{}, k...)) return nil diff --git a/sessionctx/binloginfo/binloginfo.go b/sessionctx/binloginfo/binloginfo.go index d4c7bb0be8ef1..67c1953fc9114 100644 --- a/sessionctx/binloginfo/binloginfo.go +++ b/sessionctx/binloginfo/binloginfo.go @@ -15,6 +15,7 @@ package binloginfo import ( "context" + "math" "regexp" "strings" "sync" @@ -29,7 +30,7 @@ import ( "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/logutil" - binlog "github.com/pingcap/tipb/go-binlog" + "github.com/pingcap/tipb/go-binlog" "go.uber.org/zap" "google.golang.org/grpc" ) @@ -42,6 +43,8 @@ func init() { // shared by all sessions. var pumpsClient *pumpcli.PumpsClient var pumpsClientLock sync.RWMutex +var shardPat = regexp.MustCompile(`SHARD_ROW_ID_BITS\s*=\s*\d+\s*`) +var preSplitPat = regexp.MustCompile(`PRE_SPLIT_REGIONS\s*=\s*\d+\s*`) // BinlogInfo contains binlog data and binlog client. type BinlogInfo struct { @@ -136,7 +139,7 @@ func SetDDLBinlog(client *pumpcli.PumpsClient, txn kv.Transaction, jobID int64, return } - ddlQuery = addSpecialComment(ddlQuery) + ddlQuery = AddSpecialComment(ddlQuery) info := &BinlogInfo{ Data: &binlog.Binlog{ Tp: binlog.BinlogType_Prewrite, @@ -150,18 +153,52 @@ func SetDDLBinlog(client *pumpcli.PumpsClient, txn kv.Transaction, jobID int64, const specialPrefix = `/*!90000 ` -func addSpecialComment(ddlQuery string) string { +// AddSpecialComment uses to add comment for table option in DDL query. +// Export for testing. +func AddSpecialComment(ddlQuery string) string { if strings.Contains(ddlQuery, specialPrefix) { return ddlQuery } + return addSpecialCommentByRegexps(ddlQuery, shardPat, preSplitPat) +} + +// addSpecialCommentByRegexps uses to add special comment for the worlds in the ddlQuery with match the regexps. +func addSpecialCommentByRegexps(ddlQuery string, regs ...*regexp.Regexp) string { upperQuery := strings.ToUpper(ddlQuery) - reg, err := regexp.Compile(`SHARD_ROW_ID_BITS\s*=\s*\d+`) - terror.Log(err) - loc := reg.FindStringIndex(upperQuery) - if len(loc) < 2 { - return ddlQuery + var specialComments []string + minIdx := math.MaxInt64 + for i := 0; i < len(regs); { + reg := regs[i] + loc := reg.FindStringIndex(upperQuery) + if len(loc) < 2 { + i++ + continue + } + specialComments = append(specialComments, ddlQuery[loc[0]:loc[1]]) + if loc[0] < minIdx { + minIdx = loc[0] + } + ddlQuery = ddlQuery[:loc[0]] + ddlQuery[loc[1]:] + upperQuery = upperQuery[:loc[0]] + upperQuery[loc[1]:] + } + if minIdx != math.MaxInt64 { + query := ddlQuery[:minIdx] + specialPrefix + for _, comment := range specialComments { + if query[len(query)-1] != ' ' { + query += " " + } + query += comment + } + if query[len(query)-1] != ' ' { + query += " " + } + query += "*/" + if len(ddlQuery[minIdx:]) > 0 { + return query + " " + ddlQuery[minIdx:] + } + return query } - return ddlQuery[:loc[0]] + specialPrefix + ddlQuery[loc[0]:loc[1]] + ` */` + ddlQuery[loc[1]:] + return ddlQuery } // MockPumpsClient creates a PumpsClient, used for test. diff --git a/sessionctx/binloginfo/binloginfo_test.go b/sessionctx/binloginfo/binloginfo_test.go index f960d14ef95c4..a10b064811d9a 100644 --- a/sessionctx/binloginfo/binloginfo_test.go +++ b/sessionctx/binloginfo/binloginfo_test.go @@ -129,8 +129,8 @@ func (s *testBinlogSuite) TestBinlog(c *C) { tk.Se.GetSessionVars().BinlogClient = s.client pump := s.pump tk.MustExec("drop table if exists local_binlog") - ddlQuery := "create table local_binlog (id int primary key, name varchar(10)) shard_row_id_bits=1" - binlogDDLQuery := "create table local_binlog (id int primary key, name varchar(10)) /*!90000 shard_row_id_bits=1 */" + ddlQuery := "create table local_binlog (id int unique key, name varchar(10)) shard_row_id_bits=1" + binlogDDLQuery := "create table local_binlog (id int unique key, name varchar(10)) /*!90000 shard_row_id_bits=1 */" tk.MustExec(ddlQuery) var matched bool // got matched pre DDL and commit DDL for i := 0; i < 10; i++ { @@ -155,7 +155,7 @@ func (s *testBinlogSuite) TestBinlog(c *C) { {types.NewIntDatum(1), types.NewStringDatum("abc")}, {types.NewIntDatum(2), types.NewStringDatum("cde")}, } - gotRows := mutationRowsToRows(c, prewriteVal.Mutations[0].InsertedRows, 0, 2) + gotRows := mutationRowsToRows(c, prewriteVal.Mutations[0].InsertedRows, 2, 4) c.Assert(gotRows, DeepEquals, expected) tk.MustExec("update local_binlog set name = 'xyz' where id = 2") @@ -169,7 +169,7 @@ func (s *testBinlogSuite) TestBinlog(c *C) { gotRows = mutationRowsToRows(c, prewriteVal.Mutations[0].UpdatedRows, 1, 3) c.Assert(gotRows, DeepEquals, oldRow) - gotRows = mutationRowsToRows(c, prewriteVal.Mutations[0].UpdatedRows, 5, 7) + gotRows = mutationRowsToRows(c, prewriteVal.Mutations[0].UpdatedRows, 7, 9) c.Assert(gotRows, DeepEquals, newRow) tk.MustExec("delete from local_binlog where id = 1") @@ -431,3 +431,44 @@ func (s *testBinlogSuite) TestDeleteSchema(c *C) { tk.MustExec("delete from b1 where job_id in (select job_id from b2 where batch_class = 'TEST') or split_job_id in (select job_id from b2 where batch_class = 'TEST');") tk.MustExec("delete b1 from b2 right join b1 on b1.job_id = b2.job_id and batch_class = 'TEST';") } + +func (s *testBinlogSuite) TestAddSpecialComment(c *C) { + testCase := []struct { + input string + result string + }{ + { + "create table t1 (id int ) shard_row_id_bits=2;", + "create table t1 (id int ) /*!90000 shard_row_id_bits=2 */ ;", + }, + { + "create table t1 (id int ) shard_row_id_bits=2 pre_split_regions=2;", + "create table t1 (id int ) /*!90000 shard_row_id_bits=2 pre_split_regions=2 */ ;", + }, + { + "create table t1 (id int ) shard_row_id_bits=2 pre_split_regions=2;", + "create table t1 (id int ) /*!90000 shard_row_id_bits=2 pre_split_regions=2 */ ;", + }, + + { + "create table t1 (id int ) shard_row_id_bits=2 engine=innodb pre_split_regions=2;", + "create table t1 (id int ) /*!90000 shard_row_id_bits=2 pre_split_regions=2 */ engine=innodb ;", + }, + { + "create table t1 (id int ) pre_split_regions=2 shard_row_id_bits=2;", + "create table t1 (id int ) /*!90000 shard_row_id_bits=2 pre_split_regions=2 */ ;", + }, + { + "create table t6 (id int ) shard_row_id_bits=2 shard_row_id_bits=3 pre_split_regions=2;", + "create table t6 (id int ) /*!90000 shard_row_id_bits=2 shard_row_id_bits=3 pre_split_regions=2 */ ;", + }, + { + "alter table t shard_row_id_bits=2 ", + "alter table t /*!90000 shard_row_id_bits=2 */", + }, + } + for _, ca := range testCase { + re := binloginfo.AddSpecialComment(ca.input) + c.Assert(re, Equals, ca.result) + } +} diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index d22d65f3ff4a3..2812ffea60bd7 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/memory" + "go.uber.org/zap" ) const ( @@ -54,6 +55,7 @@ type StatementContext struct { InDeleteStmt bool InSelectStmt bool InLoadDataStmt bool + InExplainStmt bool IgnoreTruncate bool IgnoreZeroInDate bool DupKeyAsWarning bool @@ -67,6 +69,11 @@ type StatementContext struct { BatchCheck bool InNullRejectCheck bool AllowInvalidDate bool + // CastStrToIntStrict is used to control the way we cast float format string to int. + // If ConvertStrToIntStrict is false, we convert it to a valid float string first, + // then cast the float string to int string. Otherwise, we cast string to integer + // prefix in a strict way, only extract 0-9 and (+ or - in first bit). + CastStrToIntStrict bool // mu struct holds variables that change during execution. mu struct { @@ -116,9 +123,9 @@ type StatementContext struct { MemTracker *memory.Tracker RuntimeStatsColl *execdetails.RuntimeStatsColl TableIDs []int64 - IndexIDs []int64 - NowTs time.Time - SysTs time.Time + IndexNames []string + nowTs time.Time // use this variable for now/current_timestamp calculation/cache for one stmt + stmtTimeCached bool StmtType string OriginalSQL string digestMemo struct { @@ -129,6 +136,21 @@ type StatementContext struct { Tables []TableEntry } +// GetNowTsCached getter for nowTs, if not set get now time and cache it +func (sc *StatementContext) GetNowTsCached() time.Time { + if !sc.stmtTimeCached { + now := time.Now() + sc.nowTs = now + sc.stmtTimeCached = true + } + return sc.nowTs +} + +// ResetNowTs resetter for nowTs, clear cached time flag +func (sc *StatementContext) ResetNowTs() { + sc.stmtTimeCached = false +} + // SQLDigest gets normalized and digest for provided sql. // it will cache result after first calling. func (sc *StatementContext) SQLDigest() (normalized, sqlDigest string) { @@ -397,7 +419,7 @@ func (sc *StatementContext) ResetForRetry() { sc.mu.allExecDetails = make([]*execdetails.ExecDetails, 0, 4) sc.mu.Unlock() sc.TableIDs = sc.TableIDs[:0] - sc.IndexIDs = sc.IndexIDs[:0] + sc.IndexNames = sc.IndexNames[:0] } // MergeExecDetails merges a single region execution details into self, used to print @@ -486,3 +508,21 @@ type CopTasksDetails struct { MaxWaitAddress string MaxWaitTime time.Duration } + +// ToZapFields wraps the CopTasksDetails as zap.Fileds. +func (d *CopTasksDetails) ToZapFields() (fields []zap.Field) { + if d.NumCopTasks == 0 { + return + } + fields = make([]zap.Field, 0, 10) + fields = append(fields, zap.Int("num_cop_tasks", d.NumCopTasks)) + fields = append(fields, zap.String("process_avg_time", strconv.FormatFloat(d.AvgProcessTime.Seconds(), 'f', -1, 64)+"s")) + fields = append(fields, zap.String("process_p90_time", strconv.FormatFloat(d.P90ProcessTime.Seconds(), 'f', -1, 64)+"s")) + fields = append(fields, zap.String("process_max_time", strconv.FormatFloat(d.MaxProcessTime.Seconds(), 'f', -1, 64)+"s")) + fields = append(fields, zap.String("process_max_addr", d.MaxProcessAddress)) + fields = append(fields, zap.String("wait_avg_time", strconv.FormatFloat(d.AvgWaitTime.Seconds(), 'f', -1, 64)+"s")) + fields = append(fields, zap.String("wait_p90_time", strconv.FormatFloat(d.P90WaitTime.Seconds(), 'f', -1, 64)+"s")) + fields = append(fields, zap.String("wait_max_time", strconv.FormatFloat(d.MaxWaitTime.Seconds(), 'f', -1, 64)+"s")) + fields = append(fields, zap.String("wait_max_addr", d.MaxWaitAddress)) + return fields +} diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 1d27574a76875..349d0762c073d 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -111,7 +111,7 @@ type TransactionContext struct { TableDeltaMap map[int64]TableDelta IsPessimistic bool - // For metrics. + // CreateTime For metrics. CreateTime time.Time StatementCount int } @@ -205,15 +205,14 @@ type SessionVars struct { PreparedStmtNameToID map[string]uint32 // preparedStmtID is id of prepared statement. preparedStmtID uint32 - // params for prepared statements - PreparedParams []types.Datum + // PreparedParams params for prepared statements + PreparedParams PreparedParams // ActiveRoles stores active roles for current user ActiveRoles []*auth.RoleIdentity - // retry information RetryInfo *RetryInfo - // Should be reset on transaction finished. + // TxnCtx Should be reset on transaction finished. TxnCtx *TransactionContext // KVVars is the variables for KV storage. @@ -221,9 +220,9 @@ type SessionVars struct { // TxnIsolationLevelOneShot is used to implements "set transaction isolation level ..." TxnIsolationLevelOneShot struct { - // state 0 means default - // state 1 means it's set in current transaction. - // state 2 means it should be used in current transaction. + // State 0 means default + // State 1 means it's set in current transaction. + // State 2 means it should be used in current transaction. State int Value string } @@ -342,8 +341,11 @@ type SessionVars struct { // DDLReorgPriority is the operation priority of adding indices. DDLReorgPriority int - // WaitTableSplitFinish defines the create table pre-split behaviour is sync or async. - WaitTableSplitFinish bool + // WaitSplitRegionFinish defines the split region behaviour is sync or async. + WaitSplitRegionFinish bool + + // WaitSplitRegionTimeout defines the split region timeout. + WaitSplitRegionTimeout uint64 // EnableStreaming indicates whether the coprocessor request can use streaming API. // TODO: remove this after tidb-server configuration "enable-streaming' removed. @@ -364,7 +366,7 @@ type SessionVars struct { // CommandValue indicates which command current session is doing. CommandValue uint32 - // TIDBOptJoinOrderAlgoThreshold defines the minimal number of join nodes + // TiDBOptJoinReorderThreshold defines the minimal number of join nodes // to use the greedy join reorder algorithm. TiDBOptJoinReorderThreshold int @@ -379,6 +381,42 @@ type SessionVars struct { // LowResolutionTSO is used for reading data with low resolution TSO which is updated once every two seconds. LowResolutionTSO bool + + // MaxExecutionTime is the timeout for select statement, in milliseconds. + // If the value is 0, timeouts are not enabled. + // See https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_max_execution_time + MaxExecutionTime uint64 + + // Killed is a flag to indicate that this query is killed. + Killed uint32 + + // ConnectionInfo indicates current connection info used by current session, only be lazy assigned by plugin. + ConnectionInfo *ConnectionInfo + + // StartTime is the start time of the last query. + StartTime time.Time + + // DurationParse is the duration of parsing SQL string to AST of the last query. + DurationParse time.Duration + + // DurationCompile is the duration of compiling AST to execution plan of the last query. + DurationCompile time.Duration + + // PrevStmt is used to store the previous executed statement in the current session. + PrevStmt fmt.Stringer + + // AllowRemoveAutoInc indicates whether a user can drop the auto_increment column attribute or not. + AllowRemoveAutoInc bool +} + +// PreparedParams contains the parameters of the current prepared statement when executing it. +type PreparedParams []types.Datum + +func (pps PreparedParams) String() string { + if len(pps) == 0 { + return "" + } + return " [arguments: " + types.DatumsToStrNoErr(pps) + "]" } // ConnectionInfo present connection used by audit. @@ -429,6 +467,9 @@ func NewSessionVars() *SessionVars { CommandValue: uint32(mysql.ComSleep), TiDBOptJoinReorderThreshold: DefTiDBOptJoinReorderThreshold, SlowQueryFile: config.GetGlobalConfig().Log.SlowQueryFile, + WaitSplitRegionFinish: DefTiDBWaitSplitRegionFinish, + WaitSplitRegionTimeout: DefWaitSplitRegionTimeout, + AllowRemoveAutoInc: DefTiDBAllowRemoveAutoInc, } vars.Concurrency = Concurrency{ IndexLookupConcurrency: DefIndexLookupConcurrency, @@ -473,6 +514,11 @@ func (s *SessionVars) GetWriteStmtBufs() *WriteStmtBufs { return &s.writeStmtBufs } +// GetSplitRegionTimeout gets split region timeout. +func (s *SessionVars) GetSplitRegionTimeout() time.Duration { + return time.Duration(s.WaitSplitRegionTimeout) * time.Second +} + // CleanBuffers cleans the temporary bufs func (s *SessionVars) CleanBuffers() { if !s.LightningMode { @@ -548,26 +594,6 @@ func (s *SessionVars) Location() *time.Location { return loc } -// GetExecuteArgumentsInfo gets the argument list as a string of execute statement. -func (s *SessionVars) GetExecuteArgumentsInfo() string { - if len(s.PreparedParams) == 0 { - return "" - } - args := make([]string, 0, len(s.PreparedParams)) - for _, v := range s.PreparedParams { - if v.IsNull() { - args = append(args, "") - } else { - str, err := v.ToString() - if err != nil { - terror.Log(err) - } - args = append(args, str) - } - } - return fmt.Sprintf(" [arguments: %s]", strings.Join(args, ", ")) -} - // GetSystemVar gets the string value of a system variable. func (s *SessionVars) GetSystemVar(name string) (string, bool) { val, ok := s.systems[name] @@ -688,6 +714,9 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { if isAutocommit { s.SetStatusFlag(mysql.ServerStatusInTrans, false) } + case MaxExecutionTime: + timeoutMS := tidbOptPositiveInt32(val, 0) + s.MaxExecutionTime = uint64(timeoutMS) case TiDBSkipUTF8Check: s.SkipUTF8Check = TiDBOptOn(val) case TiDBOptAggPushDown: @@ -699,7 +728,7 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { case TiDBOptCorrelationThreshold: s.CorrelationThreshold = tidbOptFloat64(val, DefOptCorrelationThreshold) case TiDBOptCorrelationExpFactor: - s.CorrelationExpFactor = tidbOptPositiveInt32(val, DefOptCorrelationExpFactor) + s.CorrelationExpFactor = int(tidbOptInt64(val, DefOptCorrelationExpFactor)) case TiDBIndexLookupConcurrency: s.IndexLookupConcurrency = tidbOptPositiveInt32(val, DefIndexLookupConcurrency) case TiDBIndexLookupJoinConcurrency: @@ -760,6 +789,8 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { atomic.StoreUint32(&ProcessGeneralLog, uint32(tidbOptPositiveInt32(val, DefTiDBGeneralLog))) case TiDBSlowLogThreshold: atomic.StoreUint64(&config.GetGlobalConfig().Log.SlowThreshold, uint64(tidbOptInt64(val, logutil.DefaultSlowThreshold))) + case TiDBRecordPlanInSlowLog: + atomic.StoreUint32(&config.GetGlobalConfig().Log.RecordPlanInSlowLog, uint32(tidbOptInt64(val, logutil.DefaultRecordPlanInSlowLog))) case TiDBDDLSlowOprThreshold: atomic.StoreUint32(&DDLSlowOprThreshold, uint32(tidbOptPositiveInt32(val, DefTiDBDDLSlowOprThreshold))) case TiDBQueryLogMaxLen: @@ -792,14 +823,21 @@ func (s *SessionVars) SetSystemVar(name string, val string) error { s.SlowQueryFile = val case TiDBEnableFastAnalyze: s.EnableFastAnalyze = TiDBOptOn(val) - case TiDBWaitTableSplitFinish: - s.WaitTableSplitFinish = TiDBOptOn(val) + case TiDBWaitSplitRegionFinish: + s.WaitSplitRegionFinish = TiDBOptOn(val) + case TiDBWaitSplitRegionTimeout: + s.WaitSplitRegionTimeout = uint64(tidbOptPositiveInt32(val, DefWaitSplitRegionTimeout)) + case TiDBExpensiveQueryTimeThreshold: + atomic.StoreUint64(&ExpensiveQueryTimeThreshold, uint64(tidbOptPositiveInt32(val, DefTiDBExpensiveQueryTimeThreshold))) case TiDBTxnMode: - if err := s.setTxnMode(val); err != nil { - return err - } + s.TxnMode = strings.ToUpper(val) case TiDBLowResolutionTSO: s.LowResolutionTSO = TiDBOptOn(val) + case TiDBAllowRemoveAutoInc: + s.AllowRemoveAutoInc = TiDBOptOn(val) + // It's a global variable, but it also wants to be cached in server. + case TiDBMaxDeltaSchemaCount: + SetMaxDeltaSchemaCount(tidbOptInt64(val, DefTiDBMaxDeltaSchemaCount)) } s.systems[name] = val return nil @@ -840,6 +878,7 @@ const ( TxnIsolation = "tx_isolation" TransactionIsolation = "transaction_isolation" TxnIsolationOneShot = "tx_isolation_one_shot" + MaxExecutionTime = "max_execution_time" ) // these variables are useless for TiDB, but still need to validate their values for some compatible issues. @@ -886,7 +925,7 @@ type Concurrency struct { // HashAggPartialConcurrency is the number of concurrent hash aggregation partial worker. HashAggPartialConcurrency int - // HashAggPartialConcurrency is the number of concurrent hash aggregation final worker. + // HashAggFinalConcurrency is the number of concurrent hash aggregation final worker. HashAggFinalConcurrency int // IndexSerialScanConcurrency is the number of concurrent index serial scan worker. @@ -949,16 +988,22 @@ const ( SlowLogTxnStartTSStr = "Txn_start_ts" // SlowLogUserStr is slow log field name. SlowLogUserStr = "User" + // SlowLogHostStr only for slow_query table usage. + SlowLogHostStr = "Host" // SlowLogConnIDStr is slow log field name. SlowLogConnIDStr = "Conn_ID" // SlowLogQueryTimeStr is slow log field name. SlowLogQueryTimeStr = "Query_time" + // SlowLogParseTimeStr is the parse sql time. + SlowLogParseTimeStr = "Parse_time" + // SlowLogCompileTimeStr is the compile plan time. + SlowLogCompileTimeStr = "Compile_time" // SlowLogDBStr is slow log field name. SlowLogDBStr = "DB" // SlowLogIsInternalStr is slow log field name. SlowLogIsInternalStr = "Is_internal" - // SlowLogIndexIDsStr is slow log field name. - SlowLogIndexIDsStr = "Index_ids" + // SlowLogIndexNamesStr is slow log field name. + SlowLogIndexNamesStr = "Index_names" // SlowLogDigestStr is slow log field name. SlowLogDigestStr = "Digest" // SlowLogQuerySQLStr is slow log field name. @@ -985,8 +1030,45 @@ const ( SlowLogCopWaitAddr = "Cop_wait_addr" // SlowLogMemMax is the max number bytes of memory used in this statement. SlowLogMemMax = "Mem_max" + // SlowLogPrepared is used to indicate whether this sql execute in prepare. + SlowLogPrepared = "Prepared" + // SlowLogHasMoreResults is used to indicate whether this sql has more following results. + SlowLogHasMoreResults = "Has_more_results" + // SlowLogSucc is used to indicate whether this sql execute successfully. + SlowLogSucc = "Succ" + // SlowLogPrevStmt is used to show the previous executed statement. + SlowLogPrevStmt = "Prev_stmt" + // SlowLogPlan is used to record the query plan. + SlowLogPlan = "Plan" + // SlowLogPlanPrefix is the prefix of the plan value. + SlowLogPlanPrefix = ast.TiDBDecodePlan + "('" + // SlowLogPlanSuffix is the suffix of the plan value. + SlowLogPlanSuffix = "')" + // SlowLogPrevStmtPrefix is the prefix of Prev_stmt in slow log file. + SlowLogPrevStmtPrefix = SlowLogPrevStmt + SlowLogSpaceMarkStr ) +// SlowQueryLogItems is a collection of items that should be included in the +// slow query log. +type SlowQueryLogItems struct { + TxnTS uint64 + SQL string + Digest string + TimeTotal time.Duration + TimeParse time.Duration + TimeCompile time.Duration + IndexNames string + StatsInfos map[string]uint64 + CopTasks *stmtctx.CopTasksDetails + ExecDetail execdetails.ExecDetails + MemMax int64 + Succ bool + Prepared bool + HasMoreResults bool + PrevStmt string + Plan string +} + // SlowLogFormat uses for formatting slow log. // The slow log output is like below: // # Time: 2019-04-28T15:24:04.309074+08:00 @@ -996,7 +1078,7 @@ const ( // # Query_time: 4.895492 // # Process_time: 0.161 Request_count: 1 Total_keys: 100001 Processed_keys: 100000 // # DB: test -// # Index_ids: [1,2] +// # Index_names: [t1.idx1,t2.idx2] // # Is_internal: false // # Digest: 42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772 // # Stats: t1:1,t2:2 @@ -1004,37 +1086,43 @@ const ( // # Cop_process: Avg_time: 1s P90_time: 2s Max_time: 3s Max_addr: 10.6.131.78 // # Cop_wait: Avg_time: 10ms P90_time: 20ms Max_time: 30ms Max_Addr: 10.6.131.79 // # Memory_max: 4096 +// # Succ: true +// # Prev_stmt: begin; // select * from t_slim; -func (s *SessionVars) SlowLogFormat(txnTS uint64, costTime time.Duration, execDetail execdetails.ExecDetails, indexIDs string, digest string, - statsInfos map[string]uint64, copTasks *stmtctx.CopTasksDetails, memMax int64, sql string) string { +func (s *SessionVars) SlowLogFormat(logItems *SlowQueryLogItems) string { var buf bytes.Buffer - execDetailStr := execDetail.String() - buf.WriteString(SlowLogRowPrefixStr + SlowLogTxnStartTSStr + SlowLogSpaceMarkStr + strconv.FormatUint(txnTS, 10) + "\n") + + writeSlowLogItem(&buf, SlowLogTxnStartTSStr, strconv.FormatUint(logItems.TxnTS, 10)) if s.User != nil { - buf.WriteString(SlowLogRowPrefixStr + SlowLogUserStr + SlowLogSpaceMarkStr + s.User.String() + "\n") + writeSlowLogItem(&buf, SlowLogUserStr, s.User.String()) } if s.ConnectionID != 0 { - buf.WriteString(SlowLogRowPrefixStr + SlowLogConnIDStr + SlowLogSpaceMarkStr + strconv.FormatUint(s.ConnectionID, 10) + "\n") + writeSlowLogItem(&buf, SlowLogConnIDStr, strconv.FormatUint(s.ConnectionID, 10)) } - buf.WriteString(SlowLogRowPrefixStr + SlowLogQueryTimeStr + SlowLogSpaceMarkStr + strconv.FormatFloat(costTime.Seconds(), 'f', -1, 64) + "\n") - if len(execDetailStr) > 0 { + writeSlowLogItem(&buf, SlowLogQueryTimeStr, strconv.FormatFloat(logItems.TimeTotal.Seconds(), 'f', -1, 64)) + writeSlowLogItem(&buf, SlowLogParseTimeStr, strconv.FormatFloat(logItems.TimeParse.Seconds(), 'f', -1, 64)) + writeSlowLogItem(&buf, SlowLogCompileTimeStr, strconv.FormatFloat(logItems.TimeCompile.Seconds(), 'f', -1, 64)) + + if execDetailStr := logItems.ExecDetail.String(); len(execDetailStr) > 0 { buf.WriteString(SlowLogRowPrefixStr + execDetailStr + "\n") } + if len(s.CurrentDB) > 0 { - buf.WriteString(SlowLogRowPrefixStr + SlowLogDBStr + SlowLogSpaceMarkStr + s.CurrentDB + "\n") + writeSlowLogItem(&buf, SlowLogDBStr, s.CurrentDB) } - if len(indexIDs) > 0 { - buf.WriteString(SlowLogRowPrefixStr + SlowLogIndexIDsStr + SlowLogSpaceMarkStr + indexIDs + "\n") + if len(logItems.IndexNames) > 0 { + writeSlowLogItem(&buf, SlowLogIndexNamesStr, logItems.IndexNames) } - buf.WriteString(SlowLogRowPrefixStr + SlowLogIsInternalStr + SlowLogSpaceMarkStr + strconv.FormatBool(s.InRestrictedSQL) + "\n") - if len(digest) > 0 { - buf.WriteString(SlowLogRowPrefixStr + SlowLogDigestStr + SlowLogSpaceMarkStr + digest + "\n") + + writeSlowLogItem(&buf, SlowLogIsInternalStr, strconv.FormatBool(s.InRestrictedSQL)) + if len(logItems.Digest) > 0 { + writeSlowLogItem(&buf, SlowLogDigestStr, logItems.Digest) } - if len(statsInfos) > 0 { + if len(logItems.StatsInfos) > 0 { buf.WriteString(SlowLogRowPrefixStr + SlowLogStatsInfoStr + SlowLogSpaceMarkStr) firstComma := false vStr := "" - for k, v := range statsInfos { + for k, v := range logItems.StatsInfos { if v == 0 { vStr = "pseudo" } else { @@ -1050,28 +1138,54 @@ func (s *SessionVars) SlowLogFormat(txnTS uint64, costTime time.Duration, execDe } buf.WriteString("\n") } - if copTasks != nil { - buf.WriteString(SlowLogRowPrefixStr + SlowLogNumCopTasksStr + SlowLogSpaceMarkStr + strconv.FormatInt(int64(copTasks.NumCopTasks), 10) + "\n") - buf.WriteString(SlowLogRowPrefixStr + fmt.Sprintf("%v%v%v %v%v%v %v%v%v %v%v%v", - SlowLogCopProcAvg, SlowLogSpaceMarkStr, copTasks.AvgProcessTime.Seconds(), - SlowLogCopProcP90, SlowLogSpaceMarkStr, copTasks.P90ProcessTime.Seconds(), - SlowLogCopProcMax, SlowLogSpaceMarkStr, copTasks.MaxProcessTime.Seconds(), - SlowLogCopProcAddr, SlowLogSpaceMarkStr, copTasks.MaxProcessAddress) + "\n") - buf.WriteString(SlowLogRowPrefixStr + fmt.Sprintf("%v%v%v %v%v%v %v%v%v %v%v%v", - SlowLogCopWaitAvg, SlowLogSpaceMarkStr, copTasks.AvgWaitTime.Seconds(), - SlowLogCopWaitP90, SlowLogSpaceMarkStr, copTasks.P90WaitTime.Seconds(), - SlowLogCopWaitMax, SlowLogSpaceMarkStr, copTasks.MaxWaitTime.Seconds(), - SlowLogCopWaitAddr, SlowLogSpaceMarkStr, copTasks.MaxWaitAddress) + "\n") + if logItems.CopTasks != nil { + writeSlowLogItem(&buf, SlowLogNumCopTasksStr, strconv.FormatInt(int64(logItems.CopTasks.NumCopTasks), 10)) + if logItems.CopTasks.NumCopTasks > 0 { + if logItems.CopTasks.NumCopTasks == 1 { + buf.WriteString(SlowLogRowPrefixStr + fmt.Sprintf("%v%v%v %v%v%v", + SlowLogCopProcAvg, SlowLogSpaceMarkStr, logItems.CopTasks.AvgProcessTime.Seconds(), + SlowLogCopProcAddr, SlowLogSpaceMarkStr, logItems.CopTasks.MaxProcessAddress) + "\n") + buf.WriteString(SlowLogRowPrefixStr + fmt.Sprintf("%v%v%v %v%v%v", + SlowLogCopWaitAvg, SlowLogSpaceMarkStr, logItems.CopTasks.AvgWaitTime.Seconds(), + SlowLogCopWaitAddr, SlowLogSpaceMarkStr, logItems.CopTasks.MaxWaitAddress) + "\n") + + } else { + buf.WriteString(SlowLogRowPrefixStr + fmt.Sprintf("%v%v%v %v%v%v %v%v%v %v%v%v", + SlowLogCopProcAvg, SlowLogSpaceMarkStr, logItems.CopTasks.AvgProcessTime.Seconds(), + SlowLogCopProcP90, SlowLogSpaceMarkStr, logItems.CopTasks.P90ProcessTime.Seconds(), + SlowLogCopProcMax, SlowLogSpaceMarkStr, logItems.CopTasks.MaxProcessTime.Seconds(), + SlowLogCopProcAddr, SlowLogSpaceMarkStr, logItems.CopTasks.MaxProcessAddress) + "\n") + buf.WriteString(SlowLogRowPrefixStr + fmt.Sprintf("%v%v%v %v%v%v %v%v%v %v%v%v", + SlowLogCopWaitAvg, SlowLogSpaceMarkStr, logItems.CopTasks.AvgWaitTime.Seconds(), + SlowLogCopWaitP90, SlowLogSpaceMarkStr, logItems.CopTasks.P90WaitTime.Seconds(), + SlowLogCopWaitMax, SlowLogSpaceMarkStr, logItems.CopTasks.MaxWaitTime.Seconds(), + SlowLogCopWaitAddr, SlowLogSpaceMarkStr, logItems.CopTasks.MaxWaitAddress) + "\n") + } + } + } + if logItems.MemMax > 0 { + writeSlowLogItem(&buf, SlowLogMemMax, strconv.FormatInt(logItems.MemMax, 10)) } - if memMax > 0 { - buf.WriteString(SlowLogRowPrefixStr + SlowLogMemMax + SlowLogSpaceMarkStr + strconv.FormatInt(memMax, 10) + "\n") + + writeSlowLogItem(&buf, SlowLogPrepared, strconv.FormatBool(logItems.Prepared)) + writeSlowLogItem(&buf, SlowLogHasMoreResults, strconv.FormatBool(logItems.HasMoreResults)) + writeSlowLogItem(&buf, SlowLogSucc, strconv.FormatBool(logItems.Succ)) + if len(logItems.Plan) != 0 { + writeSlowLogItem(&buf, SlowLogPlan, logItems.Plan) } - if len(sql) == 0 { - sql = ";" + + if logItems.PrevStmt != "" { + writeSlowLogItem(&buf, SlowLogPrevStmt, logItems.PrevStmt) } - buf.WriteString(sql) - if sql[len(sql)-1] != ';' { + + buf.WriteString(logItems.SQL) + if len(logItems.SQL) == 0 || logItems.SQL[len(logItems.SQL)-1] != ';' { buf.WriteString(";") } return buf.String() } + +// writeSlowLogItem writes a slow log item in the form of: "# ${key}:${value}" +func writeSlowLogItem(buf *bytes.Buffer, key, value string) { + buf.WriteString(SlowLogRowPrefixStr + key + SlowLogSpaceMarkStr + value + "\n") +} diff --git a/sessionctx/variable/session_test.go b/sessionctx/variable/session_test.go index c8d78344e165c..c10fd18a85d24 100644 --- a/sessionctx/variable/session_test.go +++ b/sessionctx/variable/session_test.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/parser" "github.com/pingcap/parser/auth" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/execdetails" "github.com/pingcap/tidb/util/mock" ) @@ -123,9 +124,11 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { # User: root@192.168.0.1 # Conn_ID: 1 # Query_time: 1 +# Parse_time: 0.00000001 +# Compile_time: 0.00000001 # Process_time: 2 Wait_time: 60 Backoff_time: 0.001 Request_count: 2 Total_keys: 10000 Process_keys: 20001 # DB: test -# Index_ids: [1,2] +# Index_names: [t1:a,t2:b] # Is_internal: true # Digest: 42a1c8aae6f133e934d4bf0147491709a8812ea05ff8819ec522780fe657b772 # Stats: t1:pseudo @@ -133,9 +136,27 @@ func (*testSessionSuite) TestSlowLogFormat(c *C) { # Cop_proc_avg: 1 Cop_proc_p90: 2 Cop_proc_max: 3 Cop_proc_addr: 10.6.131.78 # Cop_wait_avg: 0.01 Cop_wait_p90: 0.02 Cop_wait_max: 0.03 Cop_wait_addr: 10.6.131.79 # Mem_max: 2333 +# Prepared: true +# Has_more_results: true +# Succ: true select * from t;` sql := "select * from t" digest := parser.DigestHash(sql) - logString := seVar.SlowLogFormat(txnTS, costTime, execDetail, "[1,2]", digest, statsInfos, copTasks, memMax, sql) + logString := seVar.SlowLogFormat(&variable.SlowQueryLogItems{ + TxnTS: txnTS, + SQL: sql, + Digest: digest, + TimeTotal: costTime, + TimeParse: time.Duration(10), + TimeCompile: time.Duration(10), + IndexNames: "[t1:a,t2:b]", + StatsInfos: statsInfos, + CopTasks: copTasks, + ExecDetail: execDetail, + MemMax: memMax, + Prepared: true, + HasMoreResults: true, + Succ: true, + }) c.Assert(logString, Equals, resultString) } diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 4d30c6e383df1..a6996b154349b 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -117,6 +117,14 @@ func BoolToIntStr(b bool) string { return "0" } +// BoolToInt32 converts bool to int32 +func BoolToInt32(b bool) int32 { + if b { + return 1 + } + return 0 +} + // we only support MySQL now var defaultSysVars = []*SysVar{ {ScopeGlobal, "gtid_mode", "OFF"}, @@ -174,6 +182,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "innodb_max_undo_log_size", ""}, {ScopeGlobal | ScopeSession, "range_alloc_block_size", "4096"}, {ScopeGlobal, ConnectTimeout, "10"}, + {ScopeGlobal | ScopeSession, MaxExecutionTime, "0"}, {ScopeGlobal | ScopeSession, "collation_server", mysql.DefaultCollationName}, {ScopeNone, "have_rtree_keys", "YES"}, {ScopeGlobal, "innodb_old_blocks_pct", "37"}, @@ -628,6 +637,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, "innodb_online_alter_log_max_size", "134217728"}, {ScopeSession, WarningCount, "0"}, {ScopeSession, ErrorCount, "0"}, + {ScopeGlobal, "thread_pool_size", "16"}, /* TiDB specific variables */ {ScopeSession, TiDBSnapshot, ""}, {ScopeSession, TiDBOptAggPushDown, BoolToIntStr(DefOptAggPushDown)}, @@ -675,7 +685,7 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal | ScopeSession, TiDBRetryLimit, strconv.Itoa(DefTiDBRetryLimit)}, {ScopeGlobal | ScopeSession, TiDBDisableTxnAutoRetry, BoolToIntStr(DefTiDBDisableTxnAutoRetry)}, {ScopeGlobal | ScopeSession, TiDBConstraintCheckInPlace, BoolToIntStr(DefTiDBConstraintCheckInPlace)}, - {ScopeSession, TiDBTxnMode, DefTiDBTxnMode}, + {ScopeGlobal | ScopeSession, TiDBTxnMode, DefTiDBTxnMode}, {ScopeSession, TiDBOptimizerSelectivityLevel, strconv.Itoa(DefTiDBOptimizerSelectivityLevel)}, {ScopeGlobal | ScopeSession, TiDBEnableWindowFunction, BoolToIntStr(DefEnableWindowFunction)}, {ScopeGlobal | ScopeSession, TiDBEnableFastAnalyze, BoolToIntStr(DefTiDBUseFastAnalyze)}, @@ -683,6 +693,7 @@ var defaultSysVars = []*SysVar{ /* The following variable is defined as session scope but is actually server scope. */ {ScopeSession, TiDBGeneralLog, strconv.Itoa(DefTiDBGeneralLog)}, {ScopeSession, TiDBSlowLogThreshold, strconv.Itoa(logutil.DefaultSlowThreshold)}, + {ScopeSession, TiDBRecordPlanInSlowLog, strconv.Itoa(logutil.DefaultRecordPlanInSlowLog)}, {ScopeSession, TiDBDDLSlowOprThreshold, strconv.Itoa(DefTiDBDDLSlowOprThreshold)}, {ScopeSession, TiDBQueryLogMaxLen, strconv.Itoa(logutil.DefaultQueryLogMaxLen)}, {ScopeSession, TiDBConfig, ""}, @@ -690,13 +701,19 @@ var defaultSysVars = []*SysVar{ {ScopeGlobal, TiDBDDLReorgBatchSize, strconv.Itoa(DefTiDBDDLReorgBatchSize)}, {ScopeGlobal, TiDBDDLErrorCountLimit, strconv.Itoa(DefTiDBDDLErrorCountLimit)}, {ScopeSession, TiDBDDLReorgPriority, "PRIORITY_LOW"}, + {ScopeGlobal, TiDBMaxDeltaSchemaCount, strconv.Itoa(DefTiDBMaxDeltaSchemaCount)}, {ScopeSession, TiDBForcePriority, mysql.Priority2Str[DefTiDBForcePriority]}, {ScopeSession, TiDBEnableRadixJoin, BoolToIntStr(DefTiDBUseRadixJoin)}, {ScopeGlobal | ScopeSession, TiDBOptJoinReorderThreshold, strconv.Itoa(DefTiDBOptJoinReorderThreshold)}, {ScopeSession, TiDBCheckMb4ValueInUTF8, BoolToIntStr(config.GetGlobalConfig().CheckMb4ValueInUTF8)}, {ScopeSession, TiDBSlowQueryFile, ""}, - {ScopeSession, TiDBWaitTableSplitFinish, BoolToIntStr(DefTiDBWaitTableSplitFinish)}, + {ScopeGlobal, TiDBScatterRegion, BoolToIntStr(DefTiDBScatterRegion)}, + {ScopeSession, TiDBWaitSplitRegionFinish, BoolToIntStr(DefTiDBWaitSplitRegionFinish)}, + {ScopeSession, TiDBWaitSplitRegionTimeout, strconv.Itoa(DefWaitSplitRegionTimeout)}, {ScopeSession, TiDBLowResolutionTSO, "0"}, + {ScopeSession, TiDBExpensiveQueryTimeThreshold, strconv.Itoa(DefTiDBExpensiveQueryTimeThreshold)}, + {ScopeSession, TiDBAllowRemoveAutoInc, BoolToIntStr(DefTiDBAllowRemoveAutoInc)}, + {ScopeGlobal | ScopeSession, TiDBEnableStmtSummary, "0"}, } // SynonymsSysVariables is synonyms of system variables. @@ -933,6 +950,8 @@ const ( InnodbTableLocks = "innodb_table_locks" // InnodbStatusOutput is the name for 'innodb_status_output' system variable. InnodbStatusOutput = "innodb_status_output" + // ThreadPoolSize is the name of 'thread_pool_size' variable. + ThreadPoolSize = "thread_pool_size" ) // GlobalVarAccessor is the interface for accessing global scope system and status variables. diff --git a/sessionctx/variable/sysvar_test.go b/sessionctx/variable/sysvar_test.go index f4b2e50db44d2..a084c39414014 100644 --- a/sessionctx/variable/sysvar_test.go +++ b/sessionctx/variable/sysvar_test.go @@ -61,3 +61,8 @@ func (*testSysVarSuite) TestTxnMode(c *C) { err = seVar.setTxnMode("something else") c.Assert(err, NotNil) } + +func (*testSysVarSuite) TestBoolToInt32(c *C) { + c.Assert(BoolToInt32(true), Equals, int32(1)) + c.Assert(BoolToInt32(false), Equals, int32(0)) +} diff --git a/sessionctx/variable/tidb_vars.go b/sessionctx/variable/tidb_vars.go index 9f4941f6a7fc3..3d22e952775d3 100644 --- a/sessionctx/variable/tidb_vars.go +++ b/sessionctx/variable/tidb_vars.go @@ -29,7 +29,8 @@ import ( 5. Update the `NewSessionVars` function to set the field to its default value. 6. Update the `variable.SetSessionSystemVar` function to use the new value when SET statement is executed. 7. If it is a global variable, add it in `session.loadCommonGlobalVarsSQL`. - 8. Use this variable to control the behavior in code. + 8. Update ValidateSetSystemVar if the variable's value need to be validated. + 9. Use this variable to control the behavior in code. */ // TiDB system variable names that only in session scope. @@ -107,6 +108,9 @@ const ( // tidb_slow_log_threshold is used to set the slow log threshold in the server. TiDBSlowLogThreshold = "tidb_slow_log_threshold" + // tidb_record_plan_in_slow_log is used to log the plan of the slow query. + TiDBRecordPlanInSlowLog = "tidb_record_plan_in_slow_log" + // tidb_query_log_max_len is used to set the max length of the query in the log. TiDBQueryLogMaxLen = "tidb_query_log_max_len" @@ -141,6 +145,9 @@ const ( // TiDBLowResolutionTSO is used for reading data with low resolution TSO which is updated once every two seconds TiDBLowResolutionTSO = "tidb_low_resolution_tso" + + // TiDBAllowRemoveAutoInc indicates whether a user can drop the auto_increment column attribute or not. + TiDBAllowRemoveAutoInc = "tidb_allow_remove_auto_inc" ) // TiDB system variable names that both in session and global scope. @@ -225,11 +232,11 @@ const ( // tidb_backoff_lock_fast is used for tikv backoff base time in milliseconds. TiDBBackoffLockFast = "tidb_backoff_lock_fast" - // tidb_back_off_weight is used to control the max back off time in TiDB. + // tidb_backoff_weight is used to control the max back off time in TiDB. // The default maximum back off time is a small value. // BackOffWeight could multiply it to let the user adjust the maximum time for retrying. // Only positive integers can be accepted, which means that the maximum back off time can only grow. - TiDBBackOffWeight = "tidb_back_off_weight" + TiDBBackOffWeight = "tidb_backoff_weight" // tidb_ddl_reorg_worker_cnt defines the count of ddl reorg workers. TiDBDDLReorgWorkerCount = "tidb_ddl_reorg_worker_cnt" @@ -244,8 +251,18 @@ const ( // It can be: PRIORITY_LOW, PRIORITY_NORMAL, PRIORITY_HIGH TiDBDDLReorgPriority = "tidb_ddl_reorg_priority" - // TiDBWaitTableSplitFinish defines the create table pre-split behaviour is sync or async. - TiDBWaitTableSplitFinish = "tidb_wait_table_split_finish" + // tidb_max_delta_schema_count defines the max length of deltaSchemaInfos. + // deltaSchemaInfos is a queue that maintains the history of schema changes. + TiDBMaxDeltaSchemaCount = "tidb_max_delta_schema_count" + + // tidb_scatter_region will scatter the regions for DDLs when it is ON. + TiDBScatterRegion = "tidb_scatter_region" + + // TiDBWaitSplitRegionFinish defines the split region behaviour is sync or async. + TiDBWaitSplitRegionFinish = "tidb_wait_split_region_finish" + + // TiDBWaitSplitRegionTimeout uses to set the split and scatter region back off time. + TiDBWaitSplitRegionTimeout = "tidb_wait_split_region_timeout" // tidb_force_priority defines the operations priority of all statements. // It can be "NO_PRIORITY", "LOW_PRIORITY", "HIGH_PRIORITY", "DELAYED" @@ -271,66 +288,77 @@ const ( // TiDBEnableFastAnalyze indicates to use fast analyze. TiDBEnableFastAnalyze = "tidb_enable_fast_analyze" + + // TiDBExpensiveQueryTimeThreshold indicates the time threshold of expensive query. + TiDBExpensiveQueryTimeThreshold = "tidb_expensive_query_time_threshold" + + // TiDBEnableStmtSummary indicates whether the statement summary is enabled. + TiDBEnableStmtSummary = "tidb_enable_stmt_summary" ) // Default TiDB system variable values. const ( - DefHostname = "localhost" - DefIndexLookupConcurrency = 4 - DefIndexLookupJoinConcurrency = 4 - DefIndexSerialScanConcurrency = 1 - DefIndexJoinBatchSize = 25000 - DefIndexLookupSize = 20000 - DefDistSQLScanConcurrency = 15 - DefBuildStatsConcurrency = 4 - DefAutoAnalyzeRatio = 0.5 - DefAutoAnalyzeStartTime = "00:00 +0000" - DefAutoAnalyzeEndTime = "23:59 +0000" - DefChecksumTableConcurrency = 4 - DefSkipUTF8Check = false - DefOptAggPushDown = false - DefOptWriteRowID = false - DefOptCorrelationThreshold = 0.9 - DefOptCorrelationExpFactor = 0 - DefOptInSubqToJoinAndAgg = true - DefBatchInsert = false - DefBatchDelete = false - DefBatchCommit = false - DefCurretTS = 0 - DefInitChunkSize = 32 - DefMaxChunkSize = 1024 - DefDMLBatchSize = 20000 - DefMaxPreparedStmtCount = -1 - DefWaitTimeout = 0 - DefTiDBMemQuotaHashJoin = 32 << 30 // 32GB. - DefTiDBMemQuotaMergeJoin = 32 << 30 // 32GB. - DefTiDBMemQuotaSort = 32 << 30 // 32GB. - DefTiDBMemQuotaTopn = 32 << 30 // 32GB. - DefTiDBMemQuotaIndexLookupReader = 32 << 30 // 32GB. - DefTiDBMemQuotaIndexLookupJoin = 32 << 30 // 32GB. - DefTiDBMemQuotaNestedLoopApply = 32 << 30 // 32GB. - DefTiDBMemQuotaDistSQL = 32 << 30 // 32GB. - DefTiDBGeneralLog = 0 - DefTiDBRetryLimit = 10 - DefTiDBDisableTxnAutoRetry = true - DefTiDBConstraintCheckInPlace = false - DefTiDBHashJoinConcurrency = 5 - DefTiDBProjectionConcurrency = 4 - DefTiDBOptimizerSelectivityLevel = 0 - DefTiDBTxnMode = "" - DefTiDBDDLReorgWorkerCount = 16 - DefTiDBDDLReorgBatchSize = 1024 - DefTiDBDDLErrorCountLimit = 512 - DefTiDBHashAggPartialConcurrency = 4 - DefTiDBHashAggFinalConcurrency = 4 - DefTiDBForcePriority = mysql.NoPriority - DefTiDBUseRadixJoin = false - DefEnableWindowFunction = false - DefTiDBOptJoinReorderThreshold = 0 - DefTiDBDDLSlowOprThreshold = 300 - DefTiDBUseFastAnalyze = false - DefTiDBSkipIsolationLevelCheck = false - DefTiDBWaitTableSplitFinish = false + DefHostname = "localhost" + DefIndexLookupConcurrency = 4 + DefIndexLookupJoinConcurrency = 4 + DefIndexSerialScanConcurrency = 1 + DefIndexJoinBatchSize = 25000 + DefIndexLookupSize = 20000 + DefDistSQLScanConcurrency = 15 + DefBuildStatsConcurrency = 4 + DefAutoAnalyzeRatio = 0.5 + DefAutoAnalyzeStartTime = "00:00 +0000" + DefAutoAnalyzeEndTime = "23:59 +0000" + DefChecksumTableConcurrency = 4 + DefSkipUTF8Check = false + DefOptAggPushDown = false + DefOptWriteRowID = false + DefOptCorrelationThreshold = 0.9 + DefOptCorrelationExpFactor = 1 + DefOptInSubqToJoinAndAgg = true + DefBatchInsert = false + DefBatchDelete = false + DefBatchCommit = false + DefCurretTS = 0 + DefInitChunkSize = 32 + DefMaxChunkSize = 1024 + DefDMLBatchSize = 20000 + DefMaxPreparedStmtCount = -1 + DefWaitTimeout = 0 + DefTiDBMemQuotaHashJoin = 32 << 30 // 32GB. + DefTiDBMemQuotaMergeJoin = 32 << 30 // 32GB. + DefTiDBMemQuotaSort = 32 << 30 // 32GB. + DefTiDBMemQuotaTopn = 32 << 30 // 32GB. + DefTiDBMemQuotaIndexLookupReader = 32 << 30 // 32GB. + DefTiDBMemQuotaIndexLookupJoin = 32 << 30 // 32GB. + DefTiDBMemQuotaNestedLoopApply = 32 << 30 // 32GB. + DefTiDBMemQuotaDistSQL = 32 << 30 // 32GB. + DefTiDBGeneralLog = 0 + DefTiDBRetryLimit = 10 + DefTiDBDisableTxnAutoRetry = true + DefTiDBConstraintCheckInPlace = false + DefTiDBHashJoinConcurrency = 5 + DefTiDBProjectionConcurrency = 4 + DefTiDBOptimizerSelectivityLevel = 0 + DefTiDBTxnMode = "" + DefTiDBDDLReorgWorkerCount = 4 + DefTiDBDDLReorgBatchSize = 256 + DefTiDBDDLErrorCountLimit = 512 + DefTiDBMaxDeltaSchemaCount = 1024 + DefTiDBHashAggPartialConcurrency = 4 + DefTiDBHashAggFinalConcurrency = 4 + DefTiDBForcePriority = mysql.NoPriority + DefTiDBUseRadixJoin = false + DefEnableWindowFunction = true + DefTiDBOptJoinReorderThreshold = 0 + DefTiDBDDLSlowOprThreshold = 300 + DefTiDBUseFastAnalyze = false + DefTiDBSkipIsolationLevelCheck = false + DefTiDBScatterRegion = false + DefTiDBWaitSplitRegionFinish = true + DefTiDBExpensiveQueryTimeThreshold = 60 // 60s + DefWaitSplitRegionTimeout = 300 // 300s + DefTiDBAllowRemoveAutoInc = false ) // Process global variables. @@ -340,12 +368,15 @@ var ( maxDDLReorgWorkerCount int32 = 128 ddlReorgBatchSize int32 = DefTiDBDDLReorgBatchSize ddlErrorCountlimit int64 = DefTiDBDDLErrorCountLimit + maxDeltaSchemaCount int64 = DefTiDBMaxDeltaSchemaCount // Export for testing. MaxDDLReorgBatchSize int32 = 10240 MinDDLReorgBatchSize int32 = 32 // DDLSlowOprThreshold is the threshold for ddl slow operations, uint is millisecond. - DDLSlowOprThreshold uint32 = DefTiDBDDLSlowOprThreshold - ForcePriority = int32(DefTiDBForcePriority) - ServerHostname, _ = os.Hostname() - MaxOfMaxAllowedPacket uint64 = 1073741824 + DDLSlowOprThreshold uint32 = DefTiDBDDLSlowOprThreshold + ForcePriority = int32(DefTiDBForcePriority) + ServerHostname, _ = os.Hostname() + MaxOfMaxAllowedPacket uint64 = 1073741824 + ExpensiveQueryTimeThreshold uint64 = DefTiDBExpensiveQueryTimeThreshold + MinExpensiveQueryTimeThreshold uint64 = 10 //10s ) diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 2180b61391032..778c115cc017b 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -23,6 +23,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/types" @@ -73,6 +74,16 @@ func GetDDLErrorCountLimit() int64 { return atomic.LoadInt64(&ddlErrorCountlimit) } +// SetMaxDeltaSchemaCount sets maxDeltaSchemaCount size. +func SetMaxDeltaSchemaCount(cnt int64) { + atomic.StoreInt64(&maxDeltaSchemaCount, cnt) +} + +// GetMaxDeltaSchemaCount gets maxDeltaSchemaCount size. +func GetMaxDeltaSchemaCount() int64 { + return atomic.LoadInt64(&maxDeltaSchemaCount) +} + // GetSessionSystemVar gets a system variable. // If it is a session only variable, use the default value defined in code. // Returns error if there is no such variable. @@ -103,6 +114,8 @@ func GetSessionOnlySysVars(s *SessionVars, key string) (string, bool, error) { return fmt.Sprintf("%d", s.TxnCtx.StartTS), true, nil case TiDBGeneralLog: return fmt.Sprintf("%d", atomic.LoadUint32(&ProcessGeneralLog)), true, nil + case TiDBExpensiveQueryTimeThreshold: + return fmt.Sprintf("%d", atomic.LoadUint64(&ExpensiveQueryTimeThreshold)), true, nil case TiDBConfig: conf := config.GetGlobalConfig() j, err := json.MarshalIndent(conf, "", "\t") @@ -114,6 +127,8 @@ func GetSessionOnlySysVars(s *SessionVars, key string) (string, bool, error) { return mysql.Priority2Str[mysql.PriorityEnum(atomic.LoadInt32(&ForcePriority))], true, nil case TiDBSlowLogThreshold: return strconv.FormatUint(atomic.LoadUint64(&config.GetGlobalConfig().Log.SlowThreshold), 10), true, nil + case TiDBRecordPlanInSlowLog: + return strconv.FormatUint(uint64(atomic.LoadUint32(&config.GetGlobalConfig().Log.RecordPlanInSlowLog)), 10), true, nil case TiDBDDLSlowOprThreshold: return strconv.FormatUint(uint64(atomic.LoadUint32(&DDLSlowOprThreshold)), 10), true, nil case TiDBQueryLogMaxLen: @@ -318,6 +333,8 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, return checkUInt64SystemVar(name, value, 0, 4294967295, vars) case OldPasswords: return checkUInt64SystemVar(name, value, 0, 2, vars) + case TiDBMaxDeltaSchemaCount: + return checkInt64SystemVar(name, value, 100, 16384, vars) case SessionTrackGtids: if strings.EqualFold(value, "OFF") || value == "0" { return "OFF", nil @@ -379,7 +396,7 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, SQLWarnings, UniqueChecks, OldAlterTable, LogBinTrustFunctionCreators, SQLBigSelects, BinlogDirectNonTransactionalUpdates, SQLQuoteShowCreate, AutomaticSpPrivileges, RelayLogPurge, SQLAutoIsNull, QueryCacheWlockInvalidate, ValidatePasswordCheckUserName, - SuperReadOnly, BinlogOrderCommits, MasterVerifyChecksum, BinlogRowQueryLogEvents, LogSlowSlaveStatements, + SuperReadOnly, BinlogOrderCommits, MasterVerifyChecksum, BinlogRowQueryLogEvents, LogSlowSlaveStatements, TiDBRecordPlanInSlowLog, LogSlowAdminStatements, LogQueriesNotUsingIndexes, Profiling: if strings.EqualFold(value, "ON") { return "1", nil @@ -417,11 +434,15 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, TiDBOptInSubqToJoinAndAgg, TiDBEnableFastAnalyze, TiDBBatchInsert, TiDBDisableTxnAutoRetry, TiDBEnableStreaming, TiDBBatchDelete, TiDBBatchCommit, TiDBEnableCascadesPlanner, TiDBEnableWindowFunction, - TiDBCheckMb4ValueInUTF8, TiDBLowResolutionTSO: + TiDBCheckMb4ValueInUTF8, TiDBLowResolutionTSO, TiDBScatterRegion: if strings.EqualFold(value, "ON") || value == "1" || strings.EqualFold(value, "OFF") || value == "0" { return value, nil } return value, ErrWrongValueForVar.GenWithStackByArgs(name, value) + case MaxExecutionTime: + return checkUInt64SystemVar(name, value, 0, math.MaxUint64, vars) + case ThreadPoolSize: + return checkUInt64SystemVar(name, value, 1, 64, vars) case TiDBEnableTablePartition: switch { case strings.EqualFold(value, "ON") || value == "1": @@ -436,6 +457,8 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, return checkUInt64SystemVar(name, value, uint64(MinDDLReorgBatchSize), uint64(MaxDDLReorgBatchSize), vars) case TiDBDDLErrorCountLimit: return checkUInt64SystemVar(name, value, uint64(0), math.MaxInt64, vars) + case TiDBExpensiveQueryTimeThreshold: + return checkUInt64SystemVar(name, value, MinExpensiveQueryTimeThreshold, math.MaxInt64, vars) case TiDBIndexLookupConcurrency, TiDBIndexLookupJoinConcurrency, TiDBIndexJoinBatchSize, TiDBIndexLookupSize, TiDBHashJoinConcurrency, @@ -548,6 +571,38 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, if v < 0 || v >= 64 { return value, errors.Errorf("tidb_join_order_algo_threshold(%d) cannot be smaller than 0 or larger than 63", v) } + case TiDBWaitSplitRegionTimeout: + v, err := strconv.Atoi(value) + if err != nil { + return value, ErrWrongTypeForVar.GenWithStackByArgs(name) + } + if v <= 0 { + return value, errors.Errorf("tidb_wait_split_region_timeout(%d) cannot be smaller than 1", v) + } + case TiDBTxnMode: + switch strings.ToUpper(value) { + case ast.Pessimistic, ast.Optimistic, "": + default: + return value, ErrWrongValueForVar.GenWithStackByArgs(TiDBTxnMode, value) + } + case TiDBAllowRemoveAutoInc: + switch { + case strings.EqualFold(value, "ON") || value == "1": + return "on", nil + case strings.EqualFold(value, "OFF") || value == "0": + return "off", nil + } + return value, ErrWrongValueForVar.GenWithStackByArgs(name, value) + case TiDBEnableStmtSummary: + switch { + case strings.EqualFold(value, "ON") || value == "1": + return "1", nil + case strings.EqualFold(value, "OFF") || value == "0": + return "0", nil + case value == "": + return "", nil + } + return value, ErrWrongValueForVar.GenWithStackByArgs(name, value) } return value, nil } diff --git a/sessionctx/variable/varsutil_test.go b/sessionctx/variable/varsutil_test.go index f4af0357a682b..ec97078c4d849 100644 --- a/sessionctx/variable/varsutil_test.go +++ b/sessionctx/variable/varsutil_test.go @@ -295,3 +295,67 @@ func (s *testVarsutilSuite) TestVarsutil(c *C) { c.Assert(val, Equals, "0") c.Assert(v.CorrelationThreshold, Equals, float64(0)) } + +func (s *testVarsutilSuite) TestValidate(c *C) { + v := NewSessionVars() + v.GlobalVarsAccessor = NewMockGlobalAccessor() + v.TimeZone = time.UTC + + tests := []struct { + key string + value string + error bool + }{ + {TiDBAutoAnalyzeStartTime, "15:04", false}, + {TiDBAutoAnalyzeStartTime, "15:04 -0700", false}, + {DelayKeyWrite, "ON", false}, + {DelayKeyWrite, "OFF", false}, + {DelayKeyWrite, "ALL", false}, + {DelayKeyWrite, "3", true}, + {ForeignKeyChecks, "3", true}, + {MaxSpRecursionDepth, "256", false}, + {SessionTrackGtids, "OFF", false}, + {SessionTrackGtids, "OWN_GTID", false}, + {SessionTrackGtids, "ALL_GTIDS", false}, + {SessionTrackGtids, "ON", true}, + {EnforceGtidConsistency, "OFF", false}, + {EnforceGtidConsistency, "ON", false}, + {EnforceGtidConsistency, "WARN", false}, + {QueryCacheType, "OFF", false}, + {QueryCacheType, "ON", false}, + {QueryCacheType, "DEMAND", false}, + {QueryCacheType, "3", true}, + {SecureAuth, "1", false}, + {SecureAuth, "3", true}, + {MyISAMUseMmap, "ON", false}, + {MyISAMUseMmap, "OFF", false}, + {TiDBEnableTablePartition, "ON", false}, + {TiDBEnableTablePartition, "OFF", false}, + {TiDBEnableTablePartition, "AUTO", false}, + {TiDBEnableTablePartition, "UN", true}, + {TiDBOptCorrelationExpFactor, "a", true}, + {TiDBOptCorrelationExpFactor, "-10", true}, + {TiDBOptCorrelationThreshold, "a", true}, + {TiDBOptCorrelationThreshold, "-2", true}, + {TxnIsolation, "READ-UNCOMMITTED", true}, + {TiDBInitChunkSize, "a", true}, + {TiDBInitChunkSize, "-1", true}, + {TiDBMaxChunkSize, "a", true}, + {TiDBMaxChunkSize, "-1", true}, + {TiDBOptJoinReorderThreshold, "a", true}, + {TiDBOptJoinReorderThreshold, "-1", true}, + {TiDBTxnMode, "invalid", true}, + {TiDBTxnMode, "pessimistic", false}, + {TiDBTxnMode, "optimistic", false}, + {TiDBTxnMode, "", false}, + } + + for _, t := range tests { + _, err := ValidateSetSystemVar(v, t.key, t.value) + if t.error { + c.Assert(err, NotNil, Commentf("%v got err=%v", t, err)) + } else { + c.Assert(err, IsNil, Commentf("%v got err=%v", t, err)) + } + } +} diff --git a/statistics/cmsketch.go b/statistics/cmsketch.go index 747337c1d50ee..9473ae616d427 100644 --- a/statistics/cmsketch.go +++ b/statistics/cmsketch.go @@ -23,8 +23,8 @@ import ( "github.com/cznic/sortutil" "github.com/pingcap/errors" "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/sqlexec" "github.com/pingcap/tipb/go-tipb" @@ -61,14 +61,18 @@ func NewCMSketch(d, w int32) *CMSketch { return &CMSketch{depth: d, width: w, table: tbl} } +type dataCnt struct { + data []byte + cnt uint64 +} + // topNHelper wraps some variables used when building cmsketch with top n. type topNHelper struct { sampleSize uint64 - counter map[hack.MutableString]uint64 - sorted []uint64 + sorted []dataCnt onlyOnceItems uint64 sumTopN uint64 - lastVal uint64 + actualNumTop uint32 } func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper { @@ -76,20 +80,16 @@ func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper { for i := range sample { counter[hack.String(sample[i])]++ } - sorted, onlyOnceItems := make([]uint64, 0, len(counter)), uint64(0) - for _, cnt := range counter { - sorted = append(sorted, cnt) + sorted, onlyOnceItems := make([]dataCnt, 0, len(counter)), uint64(0) + for key, cnt := range counter { + sorted = append(sorted, dataCnt{hack.Slice(string(key)), cnt}) if cnt == 1 { onlyOnceItems++ } } - sort.Slice(sorted, func(i, j int) bool { - return sorted[i] > sorted[j] - }) + sort.SliceStable(sorted, func(i, j int) bool { return sorted[i].cnt > sorted[j].cnt }) var ( - // last is the last element in top N index should occurres atleast `last` times. - last uint64 sumTopN uint64 sampleNDV = uint32(len(sorted)) ) @@ -98,15 +98,15 @@ func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper { // frequency of the n-th element are added to the TopN statistics. We chose // 2/3 as an empirical value because the average cardinality estimation // error is relatively small compared with 1/2. - for i := uint32(0); i < sampleNDV && i < numTop*2; i++ { - if i >= numTop && sorted[i]*3 < sorted[numTop-1]*2 && last != sorted[i] { + var actualNumTop uint32 + for ; actualNumTop < sampleNDV && actualNumTop < numTop*2; actualNumTop++ { + if actualNumTop >= numTop && sorted[actualNumTop].cnt*3 < sorted[numTop-1].cnt*2 { break } - last = sorted[i] - sumTopN += sorted[i] + sumTopN += sorted[actualNumTop].cnt } - return &topNHelper{uint64(len(sample)), counter, sorted, onlyOnceItems, sumTopN, last} + return &topNHelper{uint64(len(sample)), sorted, onlyOnceItems, sumTopN, actualNumTop} } // NewCMSketchWithTopN returns a new CM sketch with TopN elements, the estimate NDV and the scale ratio. @@ -116,37 +116,44 @@ func NewCMSketchWithTopN(d, w int32, sample [][]byte, numTop uint32, rowCount ui // In some cases, if user triggers fast analyze when rowCount is close to sampleSize, unexpected bahavior might happen. rowCount = mathutil.MaxUint64(rowCount, uint64(len(sample))) estimateNDV, scaleRatio := calculateEstimateNDV(helper, rowCount) - c := buildCMSWithTopN(helper, d, w, scaleRatio) - c.calculateDefaultVal(helper, estimateNDV, scaleRatio, rowCount) + defaultVal := calculateDefaultVal(helper, estimateNDV, scaleRatio, rowCount) + c := buildCMSWithTopN(helper, d, w, scaleRatio, defaultVal) return c, estimateNDV, scaleRatio } -func buildCMSWithTopN(helper *topNHelper, d, w int32, scaleRatio uint64) (c *CMSketch) { +func buildCMSWithTopN(helper *topNHelper, d, w int32, scaleRatio uint64, defaultVal uint64) (c *CMSketch) { c = NewCMSketch(d, w) enableTopN := helper.sampleSize/topNThreshold <= helper.sumTopN if enableTopN { c.topN = make(map[uint64][]*TopNMeta) - } - for counterKey, cnt := range helper.counter { - data, scaledCount := hack.Slice(string(counterKey)), cnt*scaleRatio - if enableTopN && cnt >= helper.lastVal { + for i := uint32(0); i < helper.actualNumTop; i++ { + data, cnt := helper.sorted[i].data, helper.sorted[i].cnt h1, h2 := murmur3.Sum128(data) - c.topN[h1] = append(c.topN[h1], &TopNMeta{h2, data, scaledCount}) - } else { - c.insertBytesByCount(data, scaledCount) + c.topN[h1] = append(c.topN[h1], &TopNMeta{h2, data, cnt * scaleRatio}) + } + helper.sorted = helper.sorted[helper.actualNumTop:] + } + c.defaultValue = defaultVal + for i := range helper.sorted { + data, cnt := helper.sorted[i].data, helper.sorted[i].cnt + // If the value only occurred once in the sample, we assumes that there is no difference with + // value that does not occurred in the sample. + rowCount := defaultVal + if cnt > 1 { + rowCount = cnt * scaleRatio } + c.insertBytesByCount(data, rowCount) } return } -func (c *CMSketch) calculateDefaultVal(helper *topNHelper, estimateNDV, scaleRatio, rowCount uint64) { +func calculateDefaultVal(helper *topNHelper, estimateNDV, scaleRatio, rowCount uint64) uint64 { sampleNDV := uint64(len(helper.sorted)) if rowCount <= (helper.sampleSize-uint64(helper.onlyOnceItems))*scaleRatio { - c.defaultValue = 1 - } else { - estimateRemainingCount := rowCount - (helper.sampleSize-uint64(helper.onlyOnceItems))*scaleRatio - c.defaultValue = estimateRemainingCount / (estimateNDV - uint64(sampleNDV) + helper.onlyOnceItems) + return 1 } + estimateRemainingCount := rowCount - (helper.sampleSize-uint64(helper.onlyOnceItems))*scaleRatio + return estimateRemainingCount / mathutil.MaxUint64(1, estimateNDV-uint64(sampleNDV)+helper.onlyOnceItems) } func (c *CMSketch) findTopNMeta(h1, h2 uint64, d []byte) *TopNMeta { @@ -238,7 +245,7 @@ func (c *CMSketch) setValue(h1, h2 uint64, count uint64) { } func (c *CMSketch) queryValue(sc *stmtctx.StatementContext, val types.Datum) (uint64, error) { - bytes, err := codec.EncodeValue(sc, nil, val) + bytes, err := tablecodec.EncodeValue(sc, val) if err != nil { return 0, errors.Trace(err) } @@ -280,14 +287,46 @@ func (c *CMSketch) queryHashValue(h1, h2 uint64) uint64 { return uint64(res) } +func (c *CMSketch) mergeTopN(lTopN map[uint64][]*TopNMeta, rTopN map[uint64][]*TopNMeta, numTop uint32) { + counter := make(map[hack.MutableString]uint64) + for _, metas := range lTopN { + for _, meta := range metas { + counter[hack.String(meta.Data)] += meta.Count + } + } + for _, metas := range rTopN { + for _, meta := range metas { + counter[hack.String(meta.Data)] += meta.Count + } + } + sorted := make([]uint64, len(counter)) + for _, cnt := range counter { + sorted = append(sorted, cnt) + } + sort.Slice(sorted, func(i, j int) bool { + return sorted[i] > sorted[j] + }) + numTop = mathutil.MinUint32(uint32(len(counter)), numTop) + lastTopCnt := sorted[numTop-1] + c.topN = make(map[uint64][]*TopNMeta) + for value, cnt := range counter { + data := hack.Slice(string(value)) + if cnt >= lastTopCnt { + h1, h2 := murmur3.Sum128(data) + c.topN[h1] = append(c.topN[h1], &TopNMeta{h2, data, cnt}) + } else { + c.insertBytesByCount(data, cnt) + } + } +} + // MergeCMSketch merges two CM Sketch. -// Call with CMSketch with Top-N initialized may downgrade the result -func (c *CMSketch) MergeCMSketch(rc *CMSketch) error { +func (c *CMSketch) MergeCMSketch(rc *CMSketch, numTopN uint32) error { if c.depth != rc.depth || c.width != rc.width { return errors.New("Dimensions of Count-Min Sketch should be the same") } if c.topN != nil || rc.topN != nil { - return errors.New("CMSketch with Top-N does not support merge") + c.mergeTopN(c.topN, rc.topN, numTopN) } c.count += rc.count for i := range c.table { @@ -406,7 +445,9 @@ func LoadCMSketchWithTopN(exec sqlexec.RestrictedSQLExecutor, tableID, isIndex, } topN := make([]*TopNMeta, 0, len(topNRows)) for _, row := range topNRows { - topN = append(topN, &TopNMeta{Data: row.GetBytes(0), Count: row.GetUint64(1)}) + data := make([]byte, len(row.GetBytes(0))) + copy(data, row.GetBytes(0)) + topN = append(topN, &TopNMeta{Data: data, Count: row.GetUint64(1)}) } return decodeCMSketch(cms, topN) } diff --git a/statistics/cmsketch_test.go b/statistics/cmsketch_test.go index 44ee8a57dfed6..ab5b1e3b0b858 100644 --- a/statistics/cmsketch_test.go +++ b/statistics/cmsketch_test.go @@ -131,7 +131,7 @@ func (s *testStatisticsSuite) TestCMSketch(c *C) { c.Assert(err, IsNil) c.Check(avg, LessEqual, t.avgError) - err = lSketch.MergeCMSketch(rSketch) + err = lSketch.MergeCMSketch(rSketch, 0) c.Assert(err, IsNil) for val, count := range rMap { lMap[val] += count @@ -167,21 +167,21 @@ func (s *testStatisticsSuite) TestCMSketchTopN(c *C) { // The first two tests produces almost same avg. { zipfFactor: 1.0000001, - avgError: 48, + avgError: 30, }, { zipfFactor: 1.1, - avgError: 48, + avgError: 30, }, { zipfFactor: 2, - avgError: 128, + avgError: 89, }, // If the most data lies in a narrow range, our guess may have better result. // The error mainly comes from huge numbers. { zipfFactor: 5, - avgError: 256, + avgError: 208, }, } d, w := int32(5), int32(2048) @@ -189,6 +189,7 @@ func (s *testStatisticsSuite) TestCMSketchTopN(c *C) { for _, t := range tests { lSketch, lMap, err := buildCMSketchTopNAndMap(d, w, 20, 1000, 0, total, imax, t.zipfFactor) c.Check(err, IsNil) + c.Assert(len(lSketch.TopN()), LessEqual, 40) avg, err := averageAbsoluteError(lSketch, lMap) c.Assert(err, IsNil) c.Check(avg, LessEqual, t.avgError) diff --git a/statistics/feedback.go b/statistics/feedback.go index b21793b298318..89e5b3c540924 100644 --- a/statistics/feedback.go +++ b/statistics/feedback.go @@ -35,7 +35,6 @@ import ( "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/logutil" "github.com/pingcap/tidb/util/ranger" - "github.com/spaolacci/murmur3" "go.uber.org/atomic" "go.uber.org/zap" ) @@ -122,11 +121,11 @@ func (q *QueryFeedback) DecodeToRanges(isIndex bool) ([]*ranger.Range, error) { if isIndex { var err error // As we do not know the origin length, just use a custom value here. - lowVal, err = codec.DecodeRange(low.GetBytes(), 4) + lowVal, _, err = codec.DecodeRange(low.GetBytes(), 4) if err != nil { return nil, errors.Trace(err) } - highVal, err = codec.DecodeRange(high.GetBytes(), 4) + highVal, _, err = codec.DecodeRange(high.GetBytes(), 4) if err != nil { return nil, errors.Trace(err) } @@ -315,15 +314,21 @@ func buildBucketFeedback(h *Histogram, feedback *QueryFeedback) (map[int]*Bucket if skip { continue } - idx, _ := h.Bounds.LowerBound(0, fb.Lower) + idx := h.Bounds.UpperBound(0, fb.Lower) bktIdx := 0 // The last bucket also stores the feedback that falls outside the upper bound. - if idx >= h.Bounds.NumRows()-2 { + if idx >= h.Bounds.NumRows()-1 { bktIdx = h.Len() - 1 + } else if h.Len() == 1 { + bktIdx = 0 } else { - bktIdx = idx / 2 + if idx == 0 { + bktIdx = 0 + } else { + bktIdx = (idx - 1) / 2 + } // Make sure that this feedback lies within the bucket. - if chunk.Compare(h.Bounds.GetRow(2*bktIdx+1), 0, fb.Upper) < 0 { + if chunk.Compare(h.Bounds.GetRow(2*(bktIdx+1)), 0, fb.Upper) < 0 { continue } } @@ -690,8 +695,11 @@ func buildNewHistogram(h *Histogram, buckets []bucket) *Histogram { type queryFeedback struct { IntRanges []int64 // HashValues is the murmur hash values for each index point. + // Note that index points will be stored in `IndexPoints`, we keep it here only for compatibility. HashValues []uint64 IndexRanges [][]byte + // IndexPoints stores the value of each equal condition. + IndexPoints [][]byte // Counts is the number of scan keys in each range. It first stores the count for `IntRanges`, `IndexRanges` or `ColumnRanges`. // After that, it stores the Ranges for `HashValues`. Counts []int64 @@ -724,8 +732,7 @@ func encodeIndexFeedback(q *QueryFeedback) *queryFeedback { var pointCounts []int64 for _, fb := range q.Feedback { if bytes.Compare(kv.Key(fb.Lower.GetBytes()).PrefixNext(), fb.Upper.GetBytes()) >= 0 { - h1, h2 := murmur3.Sum128(fb.Lower.GetBytes()) - pb.HashValues = append(pb.HashValues, h1, h2) + pb.IndexPoints = append(pb.IndexPoints, fb.Lower.GetBytes()) pointCounts = append(pointCounts, fb.Count) } else { pb.IndexRanges = append(pb.IndexRanges, fb.Lower.GetBytes(), fb.Upper.GetBytes()) @@ -788,9 +795,18 @@ func decodeFeedbackForIndex(q *QueryFeedback, pb *queryFeedback, c *CMSketch) { if c != nil { // decode the index point feedback, just set value count in CM Sketch start := len(pb.IndexRanges) / 2 - for i := 0; i < len(pb.HashValues); i += 2 { - // TODO: update using raw bytes instead of hash values. - c.setValue(pb.HashValues[i], pb.HashValues[i+1], uint64(pb.Counts[start+i/2])) + if len(pb.HashValues) > 0 { + // It needs raw values to update the top n, so just skip it here. + if len(c.topN) > 0 { + return + } + for i := 0; i < len(pb.HashValues); i += 2 { + c.setValue(pb.HashValues[i], pb.HashValues[i+1], uint64(pb.Counts[start+i/2])) + } + return + } + for i := 0; i < len(pb.IndexPoints); i++ { + c.updateValueBytes(pb.IndexPoints[i], uint64(pb.Counts[start+i])) } } } @@ -811,16 +827,40 @@ func decodeFeedbackForPK(q *QueryFeedback, pb *queryFeedback, isUnsigned bool) { } } -func decodeFeedbackForColumn(q *QueryFeedback, pb *queryFeedback) error { +// ConvertDatumsType converts the datums type to `ft`. +func ConvertDatumsType(vals []types.Datum, ft *types.FieldType, loc *time.Location) error { + for i, val := range vals { + if val.Kind() == types.KindMinNotNull || val.Kind() == types.KindMaxValue { + continue + } + newVal, err := tablecodec.UnflattenDatums([]types.Datum{val}, []*types.FieldType{ft}, loc) + if err != nil { + return err + } + vals[i] = newVal[0] + } + return nil +} + +func decodeColumnBounds(data []byte, ft *types.FieldType) ([]types.Datum, error) { + vals, _, err := codec.DecodeRange(data, 1) + if err != nil { + return nil, err + } + err = ConvertDatumsType(vals, ft, time.UTC) + return vals, err +} + +func decodeFeedbackForColumn(q *QueryFeedback, pb *queryFeedback, ft *types.FieldType) error { q.Tp = ColType for i := 0; i < len(pb.ColumnRanges); i += 2 { - low, err := codec.DecodeRange(pb.ColumnRanges[i], 1) + low, err := decodeColumnBounds(pb.ColumnRanges[i], ft) if err != nil { - return errors.Trace(err) + return err } - high, err := codec.DecodeRange(pb.ColumnRanges[i+1], 1) + high, err := decodeColumnBounds(pb.ColumnRanges[i+1], ft) if err != nil { - return errors.Trace(err) + return err } q.Feedback = append(q.Feedback, Feedback{&low[0], &high[0], pb.Counts[i/2], 0}) } @@ -828,7 +868,7 @@ func decodeFeedbackForColumn(q *QueryFeedback, pb *queryFeedback) error { } // DecodeFeedback decodes a byte slice to feedback. -func DecodeFeedback(val []byte, q *QueryFeedback, c *CMSketch, isUnsigned bool) error { +func DecodeFeedback(val []byte, q *QueryFeedback, c *CMSketch, ft *types.FieldType) error { buf := bytes.NewBuffer(val) dec := gob.NewDecoder(buf) pb := &queryFeedback{} @@ -836,12 +876,12 @@ func DecodeFeedback(val []byte, q *QueryFeedback, c *CMSketch, isUnsigned bool) if err != nil { return errors.Trace(err) } - if len(pb.IndexRanges) > 0 || len(pb.HashValues) > 0 { + if len(pb.IndexRanges) > 0 || len(pb.HashValues) > 0 || len(pb.IndexPoints) > 0 { decodeFeedbackForIndex(q, pb, c) } else if len(pb.IntRanges) > 0 { - decodeFeedbackForPK(q, pb, isUnsigned) + decodeFeedbackForPK(q, pb, mysql.HasUnsignedFlag(ft.Flag)) } else { - err := decodeFeedbackForColumn(q, pb) + err := decodeFeedbackForColumn(q, pb, ft) if err != nil { return errors.Trace(err) } @@ -952,7 +992,7 @@ func GetMaxValue(ft *types.FieldType) (max types.Datum) { case mysql.TypeNewDecimal: max.SetMysqlDecimal(types.NewMaxOrMinDec(false, ft.Flen, ft.Decimal)) case mysql.TypeDuration: - max.SetMysqlDuration(types.Duration{Duration: math.MaxInt64}) + max.SetMysqlDuration(types.Duration{Duration: types.MaxTime}) case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { max.SetMysqlTime(types.Time{Time: types.MaxDatetime, Type: ft.Tp}) @@ -987,7 +1027,7 @@ func GetMinValue(ft *types.FieldType) (min types.Datum) { case mysql.TypeNewDecimal: min.SetMysqlDecimal(types.NewMaxOrMinDec(true, ft.Flen, ft.Decimal)) case mysql.TypeDuration: - min.SetMysqlDuration(types.Duration{Duration: math.MinInt64}) + min.SetMysqlDuration(types.Duration{Duration: types.MinTime}) case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { min.SetMysqlTime(types.Time{Time: types.MinDatetime, Type: ft.Tp}) diff --git a/statistics/feedback_test.go b/statistics/feedback_test.go index 2cc91c96c932e..c2e5664b01f58 100644 --- a/statistics/feedback_test.go +++ b/statistics/feedback_test.go @@ -70,14 +70,13 @@ func (s *testFeedbackSuite) TestUpdateHistogram(c *C) { defaultBucketCount = 7 defer func() { defaultBucketCount = originBucketCount }() c.Assert(UpdateHistogram(q.Hist, q).ToString(0), Equals, - "column:0 ndv:10058 totColSize:0\n"+ - "num: 10000 lower_bound: 0 upper_bound: 1 repeats: 0\n"+ - "num: 9 lower_bound: 2 upper_bound: 7 repeats: 0\n"+ - "num: 11 lower_bound: 8 upper_bound: 19 repeats: 0\n"+ - "num: 0 lower_bound: 20 upper_bound: 20 repeats: 0\n"+ - "num: 18 lower_bound: 21 upper_bound: 39 repeats: 0\n"+ - "num: 18 lower_bound: 40 upper_bound: 58 repeats: 0\n"+ - "num: 2 lower_bound: 59 upper_bound: 60 repeats: 0") + "column:0 ndv:10053 totColSize:0\n"+ + "num: 10001 lower_bound: 0 upper_bound: 2 repeats: 0\n"+ + "num: 7 lower_bound: 2 upper_bound: 5 repeats: 0\n"+ + "num: 4 lower_bound: 5 upper_bound: 7 repeats: 0\n"+ + "num: 11 lower_bound: 10 upper_bound: 20 repeats: 0\n"+ + "num: 19 lower_bound: 30 upper_bound: 49 repeats: 0\n"+ + "num: 11 lower_bound: 50 upper_bound: 60 repeats: 0") } func (s *testFeedbackSuite) TestSplitBuckets(c *C) { @@ -236,7 +235,7 @@ func (s *testFeedbackSuite) TestFeedbackEncoding(c *C) { val, err := EncodeFeedback(q) c.Assert(err, IsNil) rq := &QueryFeedback{} - c.Assert(DecodeFeedback(val, rq, nil, false), IsNil) + c.Assert(DecodeFeedback(val, rq, nil, hist.Tp), IsNil) for _, fb := range rq.Feedback { fb.Lower.SetBytes(codec.EncodeInt(nil, fb.Lower.GetInt64())) fb.Upper.SetBytes(codec.EncodeInt(nil, fb.Upper.GetInt64())) @@ -251,7 +250,7 @@ func (s *testFeedbackSuite) TestFeedbackEncoding(c *C) { c.Assert(err, IsNil) rq = &QueryFeedback{} cms := NewCMSketch(4, 4) - c.Assert(DecodeFeedback(val, rq, cms, false), IsNil) + c.Assert(DecodeFeedback(val, rq, cms, hist.Tp), IsNil) c.Assert(cms.QueryBytes(codec.EncodeInt(nil, 0)), Equals, uint64(1)) q.Feedback = q.Feedback[:1] c.Assert(q.Equal(rq), IsTrue) diff --git a/statistics/handle/bootstrap.go b/statistics/handle/bootstrap.go index 3a9109454495b..e4e83963017ac 100644 --- a/statistics/handle/bootstrap.go +++ b/statistics/handle/bootstrap.go @@ -69,8 +69,8 @@ func (h *Handle) initStatsMeta(is infoschema.InfoSchema) (StatsCache, error) { return nil, errors.Trace(err) } tables := StatsCache{} - req := rc[0].NewRecordBatch() - iter := chunk.NewIterator4Chunk(req.Chunk) + req := rc[0].NewChunk() + iter := chunk.NewIterator4Chunk(req) for { err := rc[0].Next(context.TODO(), req) if err != nil { @@ -91,6 +91,7 @@ func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, tables Stat continue } id, ndv, nullCount, version, totColSize := row.GetInt64(2), row.GetInt64(3), row.GetInt64(5), row.GetUint64(4), row.GetInt64(7) + lastAnalyzePos := row.GetDatum(11, types.NewFieldType(mysql.TypeBlob)) tbl, _ := h.getTableByPhysicalID(is, table.PhysicalID) if row.GetInt64(1) > 0 { var idxInfo *model.IndexInfo @@ -109,7 +110,7 @@ func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, tables Stat terror.Log(errors.Trace(err)) } hist := statistics.NewHistogram(id, ndv, nullCount, version, types.NewFieldType(mysql.TypeBlob), chunk.InitialCapacity, 0) - table.Indices[hist.ID] = &statistics.Index{Histogram: *hist, CMSketch: cms, Info: idxInfo, StatsVer: row.GetInt64(8), Flag: row.GetInt64(10), LastAnalyzePos: row.GetDatum(11, types.NewFieldType(mysql.TypeBlob))} + table.Indices[hist.ID] = &statistics.Index{Histogram: *hist, CMSketch: cms, Info: idxInfo, StatsVer: row.GetInt64(8), Flag: row.GetInt64(10), LastAnalyzePos: *lastAnalyzePos.Copy()} } else { var colInfo *model.ColumnInfo for _, col := range tbl.Meta().Columns { @@ -130,7 +131,7 @@ func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, tables Stat Count: nullCount, IsHandle: tbl.Meta().PKIsHandle && mysql.HasPriKeyFlag(colInfo.Flag), Flag: row.GetInt64(10), - LastAnalyzePos: row.GetDatum(11, types.NewFieldType(mysql.TypeBlob)), + LastAnalyzePos: *lastAnalyzePos.Copy(), } } } @@ -147,8 +148,8 @@ func (h *Handle) initStatsHistograms(is infoschema.InfoSchema, tables StatsCache if err != nil { return errors.Trace(err) } - req := rc[0].NewRecordBatch() - iter := chunk.NewIterator4Chunk(req.Chunk) + req := rc[0].NewChunk() + iter := chunk.NewIterator4Chunk(req) for { err := rc[0].Next(context.TODO(), req) if err != nil { @@ -219,8 +220,8 @@ func (h *Handle) initStatsBuckets(tables StatsCache) error { if err != nil { return errors.Trace(err) } - req := rc[0].NewRecordBatch() - iter := chunk.NewIterator4Chunk(req.Chunk) + req := rc[0].NewChunk() + iter := chunk.NewIterator4Chunk(req) for { err := rc[0].Next(context.TODO(), req) if err != nil { diff --git a/statistics/handle/ddl.go b/statistics/handle/ddl.go index 14f6baec5f412..8a84f67b72ec4 100644 --- a/statistics/handle/ddl.go +++ b/statistics/handle/ddl.go @@ -137,7 +137,7 @@ func (h *Handle) insertColStats2KV(physicalID int64, colInfo *model.ColumnInfo) if err != nil { return } - req := rs[0].NewRecordBatch() + req := rs[0].NewChunk() err = rs[0].Next(ctx, req) if err != nil { return diff --git a/statistics/handle/ddl_test.go b/statistics/handle/ddl_test.go index 769a7093bb59d..a85d40e154bef 100644 --- a/statistics/handle/ddl_test.go +++ b/statistics/handle/ddl_test.go @@ -68,7 +68,7 @@ func (s *testStatsSuite) TestDDLTable(c *C) { h := do.StatsHandle() err = h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(err, IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl := h.GetTableStats(tableInfo) c.Assert(statsTbl.Pseudo, IsFalse) @@ -79,7 +79,7 @@ func (s *testStatsSuite) TestDDLTable(c *C) { tableInfo = tbl.Meta() err = h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(err, IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl = h.GetTableStats(tableInfo) c.Assert(statsTbl.Pseudo, IsFalse) @@ -90,7 +90,7 @@ func (s *testStatsSuite) TestDDLTable(c *C) { tableInfo = tbl.Meta() err = h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(err, IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl = h.GetTableStats(tableInfo) c.Assert(statsTbl.Pseudo, IsFalse) } @@ -111,7 +111,7 @@ func (s *testStatsSuite) TestDDLHistogram(c *C) { err := h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(err, IsNil) is := do.InfoSchema() - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) tableInfo := tbl.Meta() @@ -124,7 +124,7 @@ func (s *testStatsSuite) TestDDLHistogram(c *C) { err = h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(err, IsNil) is = do.InfoSchema() - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl, err = is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) tableInfo = tbl.Meta() @@ -142,7 +142,7 @@ func (s *testStatsSuite) TestDDLHistogram(c *C) { err = h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(err, IsNil) is = do.InfoSchema() - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl, err = is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) tableInfo = tbl.Meta() @@ -154,7 +154,7 @@ func (s *testStatsSuite) TestDDLHistogram(c *C) { err = h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(err, IsNil) is = do.InfoSchema() - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl, err = is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) tableInfo = tbl.Meta() @@ -191,7 +191,7 @@ PARTITION BY RANGE ( a ) ( h := do.StatsHandle() err = h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(err, IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) pi := tableInfo.GetPartitionInfo() for _, def := range pi.Definitions { statsTbl := h.GetPartitionStats(tableInfo, def.ID) @@ -204,7 +204,7 @@ PARTITION BY RANGE ( a ) ( err = h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(err, IsNil) is = do.InfoSchema() - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl, err = is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) tableInfo = tbl.Meta() diff --git a/statistics/handle/dump_test.go b/statistics/handle/dump_test.go index 7f080b92ad32c..934518aa44a1c 100644 --- a/statistics/handle/dump_test.go +++ b/statistics/handle/dump_test.go @@ -35,8 +35,8 @@ func (s *testStatsSuite) TestConversion(c *C) { tk.MustExec("insert into t(a,b) values (1, 1),(3, 1),(5, 10)") is := s.do.InfoSchema() h := s.do.StatsHandle() - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) tableInfo, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) @@ -73,7 +73,7 @@ PARTITION BY RANGE ( a ) ( tk.MustExec("analyze table t") is := s.do.InfoSchema() h := s.do.StatsHandle() - h.Update(is) + c.Assert(h.Update(is), IsNil) table, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) @@ -105,9 +105,9 @@ func (s *testStatsSuite) TestDumpAlteredTable(c *C) { tk.MustExec("use test") tk.MustExec("drop table if exists t") h := s.do.StatsHandle() - oriLease := h.Lease - h.Lease = 1 - defer func() { h.Lease = oriLease }() + oriLease := h.Lease() + h.SetLease(1) + defer func() { h.SetLease(oriLease) }() tk.MustExec("create table t(a int, b int)") tk.MustExec("analyze table t") tk.MustExec("alter table t drop column a") @@ -131,7 +131,7 @@ func (s *testStatsSuite) TestDumpCMSketchWithTopN(c *C) { c.Assert(err, IsNil) tableInfo := tbl.Meta() h := s.do.StatsHandle() - h.Update(is) + c.Assert(h.Update(is), IsNil) // Insert 30 fake data fakeData := make([][]byte, 0, 30) @@ -158,3 +158,20 @@ func (s *testStatsSuite) TestDumpCMSketchWithTopN(c *C) { cmsFromJSON := stat.Columns[tableInfo.Columns[0].ID].CMSketch.Copy() c.Check(cms.Equal(cmsFromJSON), IsTrue) } + +func (s *testStatsSuite) TestDumpPseudoColumns(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + testKit.MustExec("use test") + testKit.MustExec("create table t(a int, b int, index idx(a))") + // Force adding an pseudo tables in stats cache. + testKit.MustQuery("select * from t") + testKit.MustExec("analyze table t index idx") + + is := s.do.InfoSchema() + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + h := s.do.StatsHandle() + _, err = h.DumpStatsToJSON("test", tbl.Meta(), nil) + c.Assert(err, IsNil) +} diff --git a/statistics/handle/gc.go b/statistics/handle/gc.go index 0cc79e89667e7..5f1d9b907c3a9 100644 --- a/statistics/handle/gc.go +++ b/statistics/handle/gc.go @@ -29,7 +29,7 @@ import ( func (h *Handle) GCStats(is infoschema.InfoSchema, ddlLease time.Duration) error { // To make sure that all the deleted tables' schema and stats info have been acknowledged to all tidb, // we only garbage collect version before 10 lease. - lease := mathutil.MaxInt64(int64(h.Lease), int64(ddlLease)) + lease := mathutil.MaxInt64(int64(h.Lease()), int64(ddlLease)) offset := DurationToTS(10 * time.Duration(lease)) if h.LastUpdateVersion() < offset { return nil diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 912dbed1e6865..bdf16a79405e3 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -70,7 +70,7 @@ type Handle struct { // feedback is used to store query feedback info. feedback []*statistics.QueryFeedback - Lease time.Duration + lease atomic2.Duration } // Clear the StatsCache, only for test. @@ -83,7 +83,8 @@ func (h *Handle) Clear() { } h.feedback = h.feedback[:0] h.mu.ctx.GetSessionVars().InitChunkSize = 1 - h.mu.ctx.GetSessionVars().MaxChunkSize = 32 + h.mu.ctx.GetSessionVars().MaxChunkSize = 1 + h.mu.ctx.GetSessionVars().ProjectionConcurrency = 0 h.listHead = &SessionStatsCollector{mapper: make(tableDeltaMap), rateMap: make(errorRateDeltaMap)} h.globalMap = make(tableDeltaMap) h.mu.rateMap = make(errorRateDeltaMap) @@ -99,9 +100,9 @@ func NewHandle(ctx sessionctx.Context, lease time.Duration) *Handle { ddlEventCh: make(chan *util.Event, 100), listHead: &SessionStatsCollector{mapper: make(tableDeltaMap), rateMap: make(errorRateDeltaMap)}, globalMap: make(tableDeltaMap), - Lease: lease, feedback: make([]*statistics.QueryFeedback, 0, MaxQueryFeedbackCount.Load()), } + handle.lease.Store(lease) // It is safe to use it concurrently because the exec won't touch the ctx. if exec, ok := ctx.(sqlexec.RestrictedSQLExecutor); ok { handle.restrictedExec = exec @@ -112,6 +113,16 @@ func NewHandle(ctx sessionctx.Context, lease time.Duration) *Handle { return handle } +// Lease returns the stats lease. +func (h *Handle) Lease() time.Duration { + return h.lease.Load() +} + +// SetLease sets the stats lease. +func (h *Handle) SetLease(lease time.Duration) { + h.lease.Store(lease) +} + // GetQueryFeedback gets the query feedback. It is only use in test. func (h *Handle) GetQueryFeedback() []*statistics.QueryFeedback { defer func() { @@ -133,7 +144,7 @@ func (h *Handle) Update(is infoschema.InfoSchema) error { // and A0 < B0 < B1 < A1. We will first read the stats of B, and update the lastVersion to B0, but we cannot read // the table stats of A0 if we read stats that greater than lastVersion which is B0. // We can read the stats if the diff between commit time and version is less than three lease. - offset := DurationToTS(3 * h.Lease) + offset := DurationToTS(3 * h.Lease()) if lastVersion >= offset { lastVersion = lastVersion - offset } else { @@ -343,6 +354,7 @@ func (h *Handle) indexStatsFromStorage(row chunk.Row, table *statistics.Table, t idx := table.Indices[histID] errorRate := statistics.ErrorRate{} flag := row.GetInt64(8) + lastAnalyzePos := row.GetDatum(10, types.NewFieldType(mysql.TypeBlob)) if statistics.IsAnalyzed(flag) { h.mu.Lock() h.mu.rateMap.clear(table.PhysicalID, histID, true) @@ -363,7 +375,7 @@ func (h *Handle) indexStatsFromStorage(row chunk.Row, table *statistics.Table, t if err != nil { return errors.Trace(err) } - idx = &statistics.Index{Histogram: *hg, CMSketch: cms, Info: idxInfo, ErrorRate: errorRate, StatsVer: row.GetInt64(7), Flag: flag, LastAnalyzePos: row.GetDatum(10, types.NewFieldType(mysql.TypeBlob))} + idx = &statistics.Index{Histogram: *hg, CMSketch: cms, Info: idxInfo, ErrorRate: errorRate, StatsVer: row.GetInt64(7), Flag: flag, LastAnalyzePos: *lastAnalyzePos.Copy()} } break } @@ -382,6 +394,7 @@ func (h *Handle) columnStatsFromStorage(row chunk.Row, table *statistics.Table, nullCount := row.GetInt64(5) totColSize := row.GetInt64(6) correlation := row.GetFloat64(9) + lastAnalyzePos := row.GetDatum(10, types.NewFieldType(mysql.TypeBlob)) col := table.Columns[histID] errorRate := statistics.ErrorRate{} flag := row.GetInt64(8) @@ -402,7 +415,7 @@ func (h *Handle) columnStatsFromStorage(row chunk.Row, table *statistics.Table, // 2. this column is not handle, and: // 3. the column doesn't has buckets before, and: // 4. loadAll is false. - notNeedLoad := h.Lease > 0 && + notNeedLoad := h.Lease() > 0 && !isHandle && (col == nil || col.Len() == 0 && col.LastUpdateVersion < histVer) && !loadAll @@ -419,7 +432,7 @@ func (h *Handle) columnStatsFromStorage(row chunk.Row, table *statistics.Table, ErrorRate: errorRate, IsHandle: tableInfo.PKIsHandle && mysql.HasPriKeyFlag(colInfo.Flag), Flag: flag, - LastAnalyzePos: row.GetDatum(10, types.NewFieldType(mysql.TypeBlob)), + LastAnalyzePos: *lastAnalyzePos.Copy(), } col.Histogram.Correlation = correlation break @@ -442,7 +455,7 @@ func (h *Handle) columnStatsFromStorage(row chunk.Row, table *statistics.Table, ErrorRate: errorRate, IsHandle: tableInfo.PKIsHandle && mysql.HasPriKeyFlag(colInfo.Flag), Flag: flag, - LastAnalyzePos: row.GetDatum(10, types.NewFieldType(mysql.TypeBlob)), + LastAnalyzePos: *lastAnalyzePos.Copy(), } break } diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index 2be9b8e798194..bb3b08cb5490e 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -246,7 +246,7 @@ func (s *testStatsSuite) TestVersion(c *C) { unit := oracle.ComposeTS(1, 0) testKit.MustExec("update mysql.stats_meta set version = ? where table_id = ?", 2*unit, tableInfo1.ID) - h.Update(is) + c.Assert(h.Update(is), IsNil) c.Assert(h.LastUpdateVersion(), Equals, 2*unit) statsTbl1 := h.GetTableStats(tableInfo1) c.Assert(statsTbl1.Pseudo, IsFalse) @@ -259,7 +259,7 @@ func (s *testStatsSuite) TestVersion(c *C) { tableInfo2 := tbl2.Meta() // A smaller version write, and we can still read it. testKit.MustExec("update mysql.stats_meta set version = ? where table_id = ?", unit, tableInfo2.ID) - h.Update(is) + c.Assert(h.Update(is), IsNil) c.Assert(h.LastUpdateVersion(), Equals, 2*unit) statsTbl2 := h.GetTableStats(tableInfo2) c.Assert(statsTbl2.Pseudo, IsFalse) @@ -268,7 +268,7 @@ func (s *testStatsSuite) TestVersion(c *C) { testKit.MustExec("analyze table t1") offset := 3 * unit testKit.MustExec("update mysql.stats_meta set version = ? where table_id = ?", offset+4, tableInfo1.ID) - h.Update(is) + c.Assert(h.Update(is), IsNil) c.Assert(h.LastUpdateVersion(), Equals, offset+uint64(4)) statsTbl1 = h.GetTableStats(tableInfo1) c.Assert(statsTbl1.Count, Equals, int64(1)) @@ -277,7 +277,7 @@ func (s *testStatsSuite) TestVersion(c *C) { testKit.MustExec("analyze table t2") // A smaller version write, and we can still read it. testKit.MustExec("update mysql.stats_meta set version = ? where table_id = ?", offset+3, tableInfo2.ID) - h.Update(is) + c.Assert(h.Update(is), IsNil) c.Assert(h.LastUpdateVersion(), Equals, offset+uint64(4)) statsTbl2 = h.GetTableStats(tableInfo2) c.Assert(statsTbl2.Count, Equals, int64(1)) @@ -286,7 +286,7 @@ func (s *testStatsSuite) TestVersion(c *C) { testKit.MustExec("analyze table t2") // A smaller version write, and we cannot read it. Because at this time, lastThree Version is 4. testKit.MustExec("update mysql.stats_meta set version = 1 where table_id = ?", tableInfo2.ID) - h.Update(is) + c.Assert(h.Update(is), IsNil) c.Assert(h.LastUpdateVersion(), Equals, offset+uint64(4)) statsTbl2 = h.GetTableStats(tableInfo2) c.Assert(statsTbl2.Count, Equals, int64(1)) @@ -295,13 +295,13 @@ func (s *testStatsSuite) TestVersion(c *C) { testKit.MustExec("alter table t2 add column c3 int") testKit.MustExec("analyze table t2") // load it with old schema. - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl2 = h.GetTableStats(tableInfo2) c.Assert(statsTbl2.Pseudo, IsFalse) c.Assert(statsTbl2.Columns[int64(3)], IsNil) // Next time DDL updated. is = do.InfoSchema() - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl2 = h.GetTableStats(tableInfo2) c.Assert(statsTbl2.Pseudo, IsFalse) // We can read it without analyze again! Thanks for PrevLastVersion. @@ -330,7 +330,7 @@ func (s *testStatsSuite) TestLoadHist(c *C) { for i := 0; i < rowCount; i++ { testKit.MustExec("insert into t values('bb','sdfga')") } - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) h.Update(do.InfoSchema()) newStatsTbl := h.GetTableStats(tableInfo) // The stats table is updated. @@ -356,7 +356,7 @@ func (s *testStatsSuite) TestLoadHist(c *C) { tbl, err = is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) tableInfo = tbl.Meta() - h.Update(is) + c.Assert(h.Update(is), IsNil) newStatsTbl2 := h.GetTableStats(tableInfo) c.Assert(newStatsTbl2 == newStatsTbl, IsFalse) // The histograms is not updated. @@ -371,7 +371,7 @@ func (s *testStatsSuite) TestInitStats(c *C) { testKit := testkit.NewTestKit(c, s.store) testKit.MustExec("use test") testKit.MustExec("create table t(a int, b int, c int, primary key(a), key idx(b))") - testKit.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3),(4,4,4),(5,5,5),(6,6,6)") + testKit.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3),(4,4,4),(5,5,5),(6,7,8)") testKit.MustExec("analyze table t") h := s.do.StatsHandle() is := s.do.InfoSchema() @@ -379,16 +379,20 @@ func (s *testStatsSuite) TestInitStats(c *C) { c.Assert(err, IsNil) // `Update` will not use load by need strategy when `Lease` is 0, and `InitStats` is only called when // `Lease` is not 0, so here we just change it. - h.Lease = time.Millisecond + h.SetLease(time.Millisecond) h.Clear() c.Assert(h.InitStats(is), IsNil) table0 := h.GetTableStats(tbl.Meta()) + cols := table0.Columns + c.Assert(cols[1].LastAnalyzePos.GetBytes()[0], Equals, uint8(0x36)) + c.Assert(cols[2].LastAnalyzePos.GetBytes()[0], Equals, uint8(0x37)) + c.Assert(cols[3].LastAnalyzePos.GetBytes()[0], Equals, uint8(0x38)) h.Clear() c.Assert(h.Update(is), IsNil) table1 := h.GetTableStats(tbl.Meta()) assertTableEqual(c, table0, table1) - h.Lease = 0 + h.SetLease(0) } func (s *testStatsSuite) TestLoadStats(c *C) { @@ -398,10 +402,10 @@ func (s *testStatsSuite) TestLoadStats(c *C) { testKit.MustExec("create table t(a int, b int, c int, primary key(a), key idx(b))") testKit.MustExec("insert into t values (1,1,1),(2,2,2),(3,3,3)") - oriLease := s.do.StatsHandle().Lease - s.do.StatsHandle().Lease = 1 + oriLease := s.do.StatsHandle().Lease() + s.do.StatsHandle().SetLease(1) defer func() { - s.do.StatsHandle().Lease = oriLease + s.do.StatsHandle().SetLease(oriLease) }() testKit.MustExec("analyze table t") @@ -431,13 +435,13 @@ func (s *testStatsSuite) TestLoadStats(c *C) { c.Assert(hg.Len(), Greater, 0) } -func newStoreWithBootstrap(statsLease time.Duration) (kv.Storage, *domain.Domain, error) { +func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) { store, err := mockstore.NewMockTikvStore() if err != nil { return nil, nil, errors.Trace(err) } session.SetSchemaLease(0) - session.SetStatsLease(statsLease) + session.DisableStats4Test() domain.RunAutoAnalyze = false do, err := session.BootstrapSession(store) do.SetStatsUpdating(true) diff --git a/statistics/handle/update.go b/statistics/handle/update.go index 1084fc5dc1528..38a82f12c5292 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -559,7 +559,7 @@ func (h *Handle) handleSingleHistogramUpdate(is infoschema.InfoSchema, rows []ch } q := &statistics.QueryFeedback{} for _, row := range rows { - err1 := statistics.DecodeFeedback(row.GetBytes(3), q, cms, mysql.HasUnsignedFlag(hist.Tp.Flag)) + err1 := statistics.DecodeFeedback(row.GetBytes(3), q, cms, hist.Tp) if err1 != nil { logutil.Logger(context.Background()).Debug("decode feedback failed", zap.Error(err)) } @@ -740,7 +740,7 @@ func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics if statsTbl.Pseudo || statsTbl.Count < AutoAnalyzeMinCnt { return false } - if needAnalyze, reason := NeedAnalyzeTable(statsTbl, 20*h.Lease, ratio, start, end, time.Now()); needAnalyze { + if needAnalyze, reason := NeedAnalyzeTable(statsTbl, 20*h.Lease(), ratio, start, end, time.Now()); needAnalyze { logutil.Logger(context.Background()).Info("[stats] auto analyze triggered", zap.String("sql", sql), zap.String("reason", reason)) h.execAutoAnalyze(sql) return true @@ -778,11 +778,11 @@ func formatBuckets(hg *statistics.Histogram, lowBkt, highBkt, idxCols int) strin return hg.BucketToString(lowBkt, idxCols) } if lowBkt+1 == highBkt { - return fmt.Sprintf("%s, %s", hg.BucketToString(lowBkt, 0), hg.BucketToString(highBkt, 0)) + return fmt.Sprintf("%s, %s", hg.BucketToString(lowBkt, idxCols), hg.BucketToString(highBkt, idxCols)) } // do not care the middle buckets - return fmt.Sprintf("%s, (%d buckets, total count %d), %s", hg.BucketToString(lowBkt, 0), - highBkt-lowBkt-1, hg.Buckets[highBkt-1].Count-hg.Buckets[lowBkt].Count, hg.BucketToString(highBkt, 0)) + return fmt.Sprintf("%s, (%d buckets, total count %d), %s", hg.BucketToString(lowBkt, idxCols), + highBkt-lowBkt-1, hg.Buckets[highBkt-1].Count-hg.Buckets[lowBkt].Count, hg.BucketToString(highBkt, idxCols)) } func colRangeToStr(c *statistics.Column, ran *ranger.Range, actual int64, factor float64) string { @@ -852,10 +852,13 @@ func logForIndex(prefix string, t *statistics.Table, idx *statistics.Index, rang zap.String("equality", equalityString), zap.Uint64("expected equality", equalityCount), zap.String("range", rangeString)) } else if colHist := t.ColumnByName(colName); colHist != nil && colHist.Histogram.Len() > 0 { - rangeString := colRangeToStr(colHist, &rang, -1, factor) - logutil.Logger(context.Background()).Debug(prefix, zap.String("index", idx.Info.Name.O), zap.Int64("actual", actual[i]), - zap.String("equality", equalityString), zap.Uint64("expected equality", equalityCount), - zap.String("range", rangeString)) + err = convertRangeType(&rang, colHist.Tp, time.UTC) + if err == nil { + rangeString := colRangeToStr(colHist, &rang, -1, factor) + logutil.Logger(context.Background()).Debug(prefix, zap.String("index", idx.Info.Name.O), zap.Int64("actual", actual[i]), + zap.String("equality", equalityString), zap.Uint64("expected equality", equalityCount), + zap.String("range", rangeString)) + } } else { count, err := statistics.GetPseudoRowCountByColumnRanges(sc, float64(t.Count), []*ranger.Range{&rang}, 0) if err == nil { @@ -1004,6 +1007,14 @@ func (h *Handle) dumpRangeFeedback(sc *stmtctx.StatementContext, ran *ranger.Ran return errors.Trace(h.DumpFeedbackToKV(q)) } +func convertRangeType(ran *ranger.Range, ft *types.FieldType, loc *time.Location) error { + err := statistics.ConvertDatumsType(ran.LowVal, ft, loc) + if err != nil { + return err + } + return statistics.ConvertDatumsType(ran.HighVal, ft, loc) +} + // DumpFeedbackForIndex dumps the feedback for index. // For queries that contains both equality and range query, we will split them and Update accordingly. func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics.Table) error { @@ -1033,7 +1044,7 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics continue } equalityCount := float64(idx.CMSketch.QueryBytes(bytes)) * idx.GetIncreaseFactor(t.Count) - rang := ranger.Range{ + rang := &ranger.Range{ LowVal: []types.Datum{ran.LowVal[rangePosition]}, HighVal: []types.Datum{ran.HighVal[rangePosition]}, } @@ -1042,11 +1053,14 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics rangeFB := &statistics.QueryFeedback{PhysicalID: q.PhysicalID} // prefer index stats over column stats if idx := t.IndexStartWithColumn(colName); idx != nil && idx.Histogram.Len() != 0 { - rangeCount, err = t.GetRowCountByIndexRanges(sc, idx.ID, []*ranger.Range{&rang}) + rangeCount, err = t.GetRowCountByIndexRanges(sc, idx.ID, []*ranger.Range{rang}) rangeFB.Tp, rangeFB.Hist = statistics.IndexType, &idx.Histogram } else if col := t.ColumnByName(colName); col != nil && col.Histogram.Len() != 0 { - rangeCount, err = t.GetRowCountByColumnRanges(sc, col.ID, []*ranger.Range{&rang}) - rangeFB.Tp, rangeFB.Hist = statistics.ColType, &col.Histogram + err = convertRangeType(rang, col.Tp, time.UTC) + if err == nil { + rangeCount, err = t.GetRowCountByColumnRanges(sc, col.ID, []*ranger.Range{rang}) + rangeFB.Tp, rangeFB.Hist = statistics.ColType, &col.Histogram + } } else { continue } @@ -1058,7 +1072,7 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics equalityCount, rangeCount = getNewCountForIndex(equalityCount, rangeCount, float64(t.Count), float64(q.Feedback[i].Count)) value := types.NewBytesDatum(bytes) q.Feedback[i] = statistics.Feedback{Lower: &value, Upper: &value, Count: int64(equalityCount)} - err = h.dumpRangeFeedback(sc, &rang, rangeCount, rangeFB) + err = h.dumpRangeFeedback(sc, rang, rangeCount, rangeFB) if err != nil { logutil.Logger(context.Background()).Debug("dump range feedback fail", zap.Error(err)) continue diff --git a/statistics/handle/update_test.go b/statistics/handle/update_test.go index c5adb405a113e..0843cb5efc872 100644 --- a/statistics/handle/update_test.go +++ b/statistics/handle/update_test.go @@ -52,7 +52,7 @@ func (s *testStatsSuite) SetUpSuite(c *C) { // Add the hook here to avoid data race. s.registerHook() var err error - s.store, s.do, err = newStoreWithBootstrap(0) + s.store, s.do, err = newStoreWithBootstrap() c.Assert(err, IsNil) } @@ -95,8 +95,8 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) { h.HandleDDLEvent(<-h.DDLEventCh()) h.HandleDDLEvent(<-h.DDLEventCh()) - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 := h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1)) @@ -111,8 +111,8 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) { for i := 0; i < rowCount1; i++ { testKit.MustExec("insert into t1 values(1, 2)") } - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1*2)) @@ -126,8 +126,8 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) { testKit.MustExec("insert into t1 values(1, 2)") } testKit.MustExec("commit") - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1*3)) @@ -142,8 +142,8 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) { testKit.MustExec("update t2 set c2 = c1") } testKit.MustExec("commit") - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1*3)) stats2 = h.GetTableStats(tableInfo2) @@ -152,8 +152,8 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) { testKit.MustExec("begin") testKit.MustExec("delete from t1") testKit.MustExec("commit") - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(0)) @@ -174,19 +174,19 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) { testKit.MustExec("insert into t1 values (1,2)") } h.DumpStatsDeltaToKV(handle.DumpDelta) - h.Update(is) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1)) // not dumped testKit.MustExec("insert into t1 values (1,2)") h.DumpStatsDeltaToKV(handle.DumpDelta) - h.Update(is) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1)) h.FlushStats() - h.Update(is) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1+1)) } @@ -206,8 +206,8 @@ func (s *testStatsSuite) TestRollback(c *C) { tableInfo := tbl.Meta() h := s.do.StatsHandle() h.HandleDDLEvent(<-h.DDLEventCh()) - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats := h.GetTableStats(tableInfo) c.Assert(stats.Count, Equals, int64(0)) @@ -241,8 +241,8 @@ func (s *testStatsSuite) TestMultiSession(c *C) { h.HandleDDLEvent(<-h.DDLEventCh()) - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 := h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1)) @@ -261,8 +261,8 @@ func (s *testStatsSuite) TestMultiSession(c *C) { testKit.Se.Close() testKit2.Se.Close() - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1*2)) // The session in testKit is already Closed, set it to nil will create a new session. @@ -290,29 +290,29 @@ func (s *testStatsSuite) TestTxnWithFailure(c *C) { for i := 0; i < rowCount1; i++ { testKit.MustExec("insert into t1 values(?, 2)", i) } - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 := h.GetTableStats(tableInfo1) // have not commit c.Assert(stats1.Count, Equals, int64(0)) testKit.MustExec("commit") - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1)) _, err = testKit.Exec("insert into t1 values(0, 2)") c.Assert(err, NotNil) - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1)) testKit.MustExec("insert into t1 values(-1, 2)") - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) stats1 = h.GetTableStats(tableInfo1) c.Assert(stats1.Count, Equals, int64(rowCount1+1)) } @@ -388,16 +388,16 @@ func (s *testStatsSuite) TestAutoUpdate(c *C) { h := do.StatsHandle() h.HandleDDLEvent(<-h.DDLEventCh()) - h.Update(is) + c.Assert(h.Update(is), IsNil) stats := h.GetTableStats(tableInfo) c.Assert(stats.Count, Equals, int64(0)) _, err = testKit.Exec("insert into t values ('ss')") c.Assert(err, IsNil) - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) h.HandleAutoAnalyze(is) - h.Update(is) + c.Assert(h.Update(is), IsNil) stats = h.GetTableStats(tableInfo) c.Assert(stats.Count, Equals, int64(1)) c.Assert(stats.ModifyCount, Equals, int64(0)) @@ -408,14 +408,14 @@ func (s *testStatsSuite) TestAutoUpdate(c *C) { } // Test that even if the table is recently modified, we can still analyze the table. - h.Lease = time.Second - defer func() { h.Lease = 0 }() + h.SetLease(time.Second) + defer func() { h.SetLease(0) }() _, err = testKit.Exec("insert into t values ('fff')") c.Assert(err, IsNil) c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(is), IsNil) h.HandleAutoAnalyze(is) - h.Update(is) + c.Assert(h.Update(is), IsNil) stats = h.GetTableStats(tableInfo) c.Assert(stats.Count, Equals, int64(2)) c.Assert(stats.ModifyCount, Equals, int64(1)) @@ -425,17 +425,17 @@ func (s *testStatsSuite) TestAutoUpdate(c *C) { c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(is), IsNil) h.HandleAutoAnalyze(is) - h.Update(is) + c.Assert(h.Update(is), IsNil) stats = h.GetTableStats(tableInfo) c.Assert(stats.Count, Equals, int64(3)) c.Assert(stats.ModifyCount, Equals, int64(0)) _, err = testKit.Exec("insert into t values ('eee')") c.Assert(err, IsNil) - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) h.HandleAutoAnalyze(is) - h.Update(is) + c.Assert(h.Update(is), IsNil) stats = h.GetTableStats(tableInfo) c.Assert(stats.Count, Equals, int64(4)) // Modify count is non-zero means that we do not analyze the table. @@ -454,7 +454,7 @@ func (s *testStatsSuite) TestAutoUpdate(c *C) { c.Assert(err, IsNil) tableInfo = tbl.Meta() h.HandleAutoAnalyze(is) - h.Update(is) + c.Assert(h.Update(is), IsNil) stats = h.GetTableStats(tableInfo) c.Assert(stats.Count, Equals, int64(4)) c.Assert(stats.ModifyCount, Equals, int64(0)) @@ -487,13 +487,13 @@ func (s *testStatsSuite) TestAutoUpdatePartition(c *C) { pi := tableInfo.GetPartitionInfo() h := do.StatsHandle() - h.Update(is) + c.Assert(h.Update(is), IsNil) stats := h.GetPartitionStats(tableInfo, pi.Definitions[0].ID) c.Assert(stats.Count, Equals, int64(0)) testKit.MustExec("insert into t values (1)") - h.DumpStatsDeltaToKV(handle.DumpAll) - h.Update(is) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.Update(is), IsNil) h.HandleAutoAnalyze(is) stats = h.GetPartitionStats(tableInfo, pi.Definitions[0].ID) c.Assert(stats.Count, Equals, int64(1)) @@ -513,23 +513,23 @@ func (s *testStatsSuite) TestTableAnalyzed(c *C) { tableInfo := tbl.Meta() h := s.do.StatsHandle() - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl := h.GetTableStats(tableInfo) c.Assert(handle.TableAnalyzed(statsTbl), IsFalse) testKit.MustExec("analyze table t") - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl = h.GetTableStats(tableInfo) c.Assert(handle.TableAnalyzed(statsTbl), IsTrue) h.Clear() - oriLease := h.Lease + oriLease := h.Lease() // set it to non-zero so we will use load by need strategy - h.Lease = 1 + h.SetLease(1) defer func() { - h.Lease = oriLease + h.SetLease(oriLease) }() - h.Update(is) + c.Assert(h.Update(is), IsNil) statsTbl = h.GetTableStats(tableInfo) c.Assert(handle.TableAnalyzed(statsTbl), IsTrue) } @@ -538,8 +538,8 @@ func (s *testStatsSuite) TestUpdateErrorRate(c *C) { defer cleanEnv(c, s.store, s.do) h := s.do.StatsHandle() is := s.do.InfoSchema() - h.Lease = 0 - h.Update(is) + h.SetLease(0) + c.Assert(h.Update(is), IsNil) oriProbability := statistics.FeedbackProbability defer func() { @@ -563,7 +563,7 @@ func (s *testStatsSuite) TestUpdateErrorRate(c *C) { testKit.MustExec("insert into t values (12, 3)") c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) is = s.do.InfoSchema() - h.Update(is) + c.Assert(h.Update(is), IsNil) table, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) @@ -580,7 +580,7 @@ func (s *testStatsSuite) TestUpdateErrorRate(c *C) { c.Assert(h.DumpStatsFeedbackToKV(), IsNil) c.Assert(h.HandleUpdateStats(is), IsNil) h.UpdateErrorRate(is) - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl = h.GetTableStats(tblInfo) // The error rate of this column is not larger than MaxErrorRate now. @@ -592,14 +592,14 @@ func (s *testStatsSuite) TestUpdateErrorRate(c *C) { c.Assert(h.DumpStatsFeedbackToKV(), IsNil) c.Assert(h.HandleUpdateStats(is), IsNil) h.UpdateErrorRate(is) - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl = h.GetTableStats(tblInfo) c.Assert(tbl.Indices[bID].NotAccurate(), IsFalse) c.Assert(tbl.Indices[bID].QueryTotal, Equals, int64(1)) testKit.MustExec("analyze table t") c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl = h.GetTableStats(tblInfo) c.Assert(tbl.Indices[bID].QueryTotal, Equals, int64(0)) } @@ -608,8 +608,8 @@ func (s *testStatsSuite) TestUpdatePartitionErrorRate(c *C) { defer cleanEnv(c, s.store, s.do) h := s.do.StatsHandle() is := s.do.InfoSchema() - h.Lease = 0 - h.Update(is) + h.SetLease(0) + c.Assert(h.Update(is), IsNil) oriProbability := statistics.FeedbackProbability defer func() { @@ -634,7 +634,7 @@ func (s *testStatsSuite) TestUpdatePartitionErrorRate(c *C) { testKit.MustExec("insert into t values (12)") c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) is = s.do.InfoSchema() - h.Update(is) + c.Assert(h.Update(is), IsNil) table, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) @@ -651,7 +651,7 @@ func (s *testStatsSuite) TestUpdatePartitionErrorRate(c *C) { c.Assert(h.DumpStatsFeedbackToKV(), IsNil) c.Assert(h.HandleUpdateStats(is), IsNil) h.UpdateErrorRate(is) - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl = h.GetPartitionStats(tblInfo, pid) // The error rate of this column is not larger than MaxErrorRate now. @@ -684,18 +684,18 @@ func (s *testStatsSuite) TestSplitRange(c *C) { { points: []int64{0, 1, 3, 8, 8, 20}, exclude: []bool{true, false, true, false, true, false}, - result: "(0,1],(3,5],(5,7],(7,8],(8,20]", + result: "(0,1],(3,7),[7,8),[8,8],(8,10),[10,20]", }, { points: []int64{8, 10, 20, 30}, exclude: []bool{false, false, true, true}, - result: "[8,8],(8,10],(20,30)", + result: "[8,10),[10,10],(20,30)", }, { // test remove invalid range points: []int64{8, 9}, exclude: []bool{false, true}, - result: "[8,8]", + result: "[8,9)", }, } for _, t := range tests { @@ -743,25 +743,25 @@ func (s *testStatsSuite) TestQueryFeedback(c *C) { // test primary key feedback sql: "select * from t where t.a <= 5", hist: "column:1 ndv:4 totColSize:0\n" + - "num: 1 lower_bound: -9223372036854775808 upper_bound: 1 repeats: 0\n" + - "num: 1 lower_bound: 2 upper_bound: 2 repeats: 1\n" + - "num: 2 lower_bound: 3 upper_bound: 5 repeats: 0", + "num: 1 lower_bound: -9223372036854775808 upper_bound: 2 repeats: 0\n" + + "num: 2 lower_bound: 2 upper_bound: 4 repeats: 0\n" + + "num: 1 lower_bound: 4 upper_bound: 4 repeats: 1", idxCols: 0, }, { // test index feedback by double read sql: "select * from t use index(idx) where t.b <= 5", hist: "index:1 ndv:2\n" + - "num: 2 lower_bound: -inf upper_bound: 2 repeats: 0\n" + - "num: 2 lower_bound: 3 upper_bound: 6 repeats: 0", + "num: 3 lower_bound: -inf upper_bound: 5 repeats: 0\n" + + "num: 1 lower_bound: 5 upper_bound: 5 repeats: 1", idxCols: 1, }, { // test index feedback by single read sql: "select b from t use index(idx) where t.b <= 5", hist: "index:1 ndv:2\n" + - "num: 2 lower_bound: -inf upper_bound: 2 repeats: 0\n" + - "num: 2 lower_bound: 3 upper_bound: 6 repeats: 0", + "num: 3 lower_bound: -inf upper_bound: 5 repeats: 0\n" + + "num: 1 lower_bound: 5 upper_bound: 5 repeats: 1", idxCols: 1, }, } @@ -773,7 +773,7 @@ func (s *testStatsSuite) TestQueryFeedback(c *C) { c.Assert(h.DumpStatsFeedbackToKV(), IsNil) c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) c.Assert(err, IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) tblInfo := table.Meta() tbl := h.GetTableStats(tblInfo) if t.idxCols == 0 { @@ -785,7 +785,7 @@ func (s *testStatsSuite) TestQueryFeedback(c *C) { // Feedback from limit executor may not be accurate. testKit.MustQuery("select * from t where t.a <= 5 limit 1") - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) feedback := h.GetQueryFeedback() c.Assert(len(feedback), Equals, 0) @@ -793,7 +793,7 @@ func (s *testStatsSuite) TestQueryFeedback(c *C) { statistics.MaxNumberOfRanges = 0 for _, t := range tests { testKit.MustQuery(t.sql) - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) feedback := h.GetQueryFeedback() c.Assert(len(feedback), Equals, 0) } @@ -803,7 +803,7 @@ func (s *testStatsSuite) TestQueryFeedback(c *C) { statistics.MaxNumberOfRanges = oriNumber for _, t := range tests { testKit.MustQuery(t.sql) - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) feedback := h.GetQueryFeedback() c.Assert(len(feedback), Equals, 0) } @@ -855,7 +855,7 @@ func (s *testStatsSuite) TestQueryFeedbackForPartition(c *C) { // test primary key feedback sql: "select * from t where t.a <= 5", hist: "column:1 ndv:2 totColSize:0\n" + - "num: 1 lower_bound: -9223372036854775808 upper_bound: 1 repeats: 0\n" + + "num: 1 lower_bound: -9223372036854775808 upper_bound: 2 repeats: 0\n" + "num: 1 lower_bound: 2 upper_bound: 5 repeats: 0", idxCols: 0, }, @@ -896,7 +896,7 @@ func (s *testStatsSuite) TestQueryFeedbackForPartition(c *C) { c.Assert(h.DumpStatsFeedbackToKV(), IsNil) c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) c.Assert(err, IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl := h.GetPartitionStats(tblInfo, pid) if t.idxCols == 0 { c.Assert(tbl.Columns[tblInfo.Columns[0].ID].ToString(0), Equals, tests[i].hist) @@ -987,8 +987,8 @@ func (s *testStatsSuite) TestUpdateStatsByLocalFeedback(c *C) { c.Assert(tbl.Columns[tblInfo.Columns[0].ID].ToString(0), Equals, "column:1 ndv:3 totColSize:0\n"+ "num: 1 lower_bound: 1 upper_bound: 1 repeats: 1\n"+ - "num: 1 lower_bound: 2 upper_bound: 2 repeats: 1\n"+ - "num: 2 lower_bound: 3 upper_bound: 9223372036854775807 repeats: 0") + "num: 2 lower_bound: 2 upper_bound: 4 repeats: 0\n"+ + "num: 1 lower_bound: 4 upper_bound: 9223372036854775807 repeats: 0") sc := &stmtctx.StatementContext{TimeZone: time.Local} low, err := codec.EncodeKey(sc, nil, types.NewIntDatum(5)) c.Assert(err, IsNil) @@ -996,8 +996,8 @@ func (s *testStatsSuite) TestUpdateStatsByLocalFeedback(c *C) { c.Assert(tbl.Indices[tblInfo.Indices[0].ID].CMSketch.QueryBytes(low), Equals, uint64(2)) c.Assert(tbl.Indices[tblInfo.Indices[0].ID].ToString(1), Equals, "index:1 ndv:2\n"+ - "num: 2 lower_bound: -inf upper_bound: 2 repeats: 0\n"+ - "num: 2 lower_bound: 3 upper_bound: 6 repeats: 0") + "num: 2 lower_bound: -inf upper_bound: 5 repeats: 0\n"+ + "num: 1 lower_bound: 5 upper_bound: 5 repeats: 1") // Test that it won't cause panic after update. testKit.MustQuery("select * from t use index(idx) where b > 0") @@ -1038,8 +1038,8 @@ func (s *testStatsSuite) TestUpdatePartitionStatsByLocalFeedback(c *C) { c.Assert(tbl.Columns[tblInfo.Columns[0].ID].ToString(0), Equals, "column:1 ndv:3 totColSize:0\n"+ "num: 1 lower_bound: 1 upper_bound: 1 repeats: 1\n"+ - "num: 1 lower_bound: 2 upper_bound: 2 repeats: 1\n"+ - "num: 2 lower_bound: 3 upper_bound: 9223372036854775807 repeats: 0") + "num: 2 lower_bound: 2 upper_bound: 4 repeats: 0\n"+ + "num: 1 lower_bound: 4 upper_bound: 9223372036854775807 repeats: 0") } type logHook struct { @@ -1086,18 +1086,18 @@ func (s *testStatsSuite) TestLogDetailedInfo(c *C) { oriMinLogCount := handle.MinLogScanCount oriMinError := handle.MinLogErrorRate oriLevel := log.GetLevel() - oriLease := s.do.StatsHandle().Lease + oriLease := s.do.StatsHandle().Lease() defer func() { statistics.FeedbackProbability = oriProbability handle.MinLogScanCount = oriMinLogCount handle.MinLogErrorRate = oriMinError - s.do.StatsHandle().Lease = oriLease + s.do.StatsHandle().SetLease(oriLease) log.SetLevel(oriLevel) }() statistics.FeedbackProbability.Store(1) handle.MinLogScanCount = 0 handle.MinLogErrorRate = 0 - s.do.StatsHandle().Lease = 1 + s.do.StatsHandle().SetLease(1) testKit := testkit.NewTestKit(c, s.store) testKit.MustExec("use test") @@ -1112,13 +1112,13 @@ func (s *testStatsSuite) TestLogDetailedInfo(c *C) { }{ { sql: "select * from t where t.a <= 15", - result: "[stats-feedback] test.t, column=a, rangeStr=range: [-inf,7), actual: 8, expected: 7, buckets: {num: 8 lower_bound: 0 upper_bound: 7 repeats: 1}" + + result: "[stats-feedback] test.t, column=a, rangeStr=range: [-inf,8), actual: 8, expected: 8, buckets: {num: 8 lower_bound: 0 upper_bound: 7 repeats: 1, num: 8 lower_bound: 8 upper_bound: 15 repeats: 1}" + "[stats-feedback] test.t, column=a, rangeStr=range: [8,15), actual: 8, expected: 7, buckets: {num: 8 lower_bound: 8 upper_bound: 15 repeats: 1}", }, { sql: "select * from t use index(idx) where t.b <= 15", - result: "[stats-feedback] test.t, index=idx, rangeStr=range: [-inf,7), actual: 8, expected: 7, histogram: {num: 8 lower_bound: 0 upper_bound: 7 repeats: 1}" + - "[stats-feedback] test.t, index=idx, rangeStr=range: [8,15), actual: 8, expected: 7, histogram: {num: 8 lower_bound: 8 upper_bound: 15 repeats: 1}", + result: "[stats-feedback] test.t, index=idx, rangeStr=range: [-inf,8), actual: 8, expected: 8, histogram: {num: 8 lower_bound: 0 upper_bound: 7 repeats: 1, num: 8 lower_bound: 8 upper_bound: 15 repeats: 1}" + + "[stats-feedback] test.t, index=idx, rangeStr=range: [8,16), actual: 8, expected: 8, histogram: {num: 8 lower_bound: 8 upper_bound: 15 repeats: 1, num: 4 lower_bound: 16 upper_bound: 19 repeats: 1}", }, { sql: "select b from t use index(idx_ba) where b = 1 and a <= 5", @@ -1267,20 +1267,22 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { statistics.FeedbackProbability.Store(1) testKit.MustExec("use test") - testKit.MustExec("create table t (a bigint(64), b bigint(64), c bigint(64), index idx_ab(a,b), index idx_ac(a,c), index idx_b(b))") + testKit.MustExec("create table t (a bigint(64), b bigint(64), c bigint(64), d float, e double, f decimal(17,2), " + + "g time, h date, index idx_b(b), index idx_ab(a,b), index idx_ac(a,c), index idx_ad(a, d), index idx_ae(a, e), index idx_af(a, f)," + + " index idx_ag(a, g), index idx_ah(a, h))") for i := 0; i < 20; i++ { - testKit.MustExec(fmt.Sprintf("insert into t values (1, %d, %d)", i, i)) + testKit.MustExec(fmt.Sprintf(`insert into t values (1, %d, %d, %d, %d, %d, %d, "%s")`, i, i, i, i, i, i, fmt.Sprintf("1000-01-%02d", i+1))) } h := s.do.StatsHandle() h.HandleDDLEvent(<-h.DDLEventCh()) c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) testKit.MustExec("analyze table t with 3 buckets") for i := 0; i < 20; i++ { - testKit.MustExec(fmt.Sprintf("insert into t values (1, %d, %d)", i, i)) + testKit.MustExec(fmt.Sprintf(`insert into t values (1, %d, %d, %d, %d, %d, %d, "%s")`, i, i, i, i, i, i, fmt.Sprintf("1000-01-%02d", i+1))) } c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) is := s.do.InfoSchema() - h.Update(is) + c.Assert(h.Update(is), IsNil) table, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) tblInfo := table.Meta() @@ -1294,23 +1296,78 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { }{ { sql: "select * from t use index(idx_ab) where a = 1 and b < 21", - hist: "index:3 ndv:20\n" + - "num: 16 lower_bound: -inf upper_bound: 7 repeats: 0\n" + - "num: 16 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + hist: "index:1 ndv:20\n" + + "num: 15 lower_bound: -inf upper_bound: 8 repeats: 0\n" + + "num: 15 lower_bound: 8 upper_bound: 16 repeats: 0\n" + "num: 8 lower_bound: 16 upper_bound: 21 repeats: 0", - rangeID: tblInfo.Indices[2].ID, - idxID: tblInfo.Indices[0].ID, + rangeID: tblInfo.Indices[0].ID, + idxID: tblInfo.Indices[1].ID, idxCols: 1, eqCount: 39, }, { sql: "select * from t use index(idx_ac) where a = 1 and c < 21", hist: "column:3 ndv:20 totColSize:20\n" + - "num: 13 lower_bound: -9223372036854775808 upper_bound: 6 repeats: 0\n" + - "num: 13 lower_bound: 7 upper_bound: 13 repeats: 0\n" + - "num: 12 lower_bound: 14 upper_bound: 21 repeats: 0", + "num: 15 lower_bound: -9223372036854775808 upper_bound: 7 repeats: 0\n" + + "num: 14 lower_bound: 7 upper_bound: 14 repeats: 0\n" + + "num: 13 lower_bound: 14 upper_bound: 21 repeats: 0", rangeID: tblInfo.Columns[2].ID, - idxID: tblInfo.Indices[1].ID, + idxID: tblInfo.Indices[2].ID, + idxCols: 0, + eqCount: 35, + }, + { + sql: "select * from t use index(idx_ad) where a = 1 and d < 21", + hist: "column:4 ndv:20 totColSize:160\n" + + "num: 15 lower_bound: -10000000000000 upper_bound: 7 repeats: 0\n" + + "num: 14 lower_bound: 7 upper_bound: 14 repeats: 0\n" + + "num: 13 lower_bound: 14 upper_bound: 21 repeats: 0", + rangeID: tblInfo.Columns[3].ID, + idxID: tblInfo.Indices[3].ID, + idxCols: 0, + eqCount: 35, + }, + { + sql: "select * from t use index(idx_ae) where a = 1 and e < 21", + hist: "column:5 ndv:20 totColSize:160\n" + + "num: 15 lower_bound: -100000000000000000000000 upper_bound: 7 repeats: 0\n" + + "num: 14 lower_bound: 7 upper_bound: 14 repeats: 0\n" + + "num: 13 lower_bound: 14 upper_bound: 21 repeats: 0", + rangeID: tblInfo.Columns[4].ID, + idxID: tblInfo.Indices[4].ID, + idxCols: 0, + eqCount: 35, + }, + { + sql: "select * from t use index(idx_af) where a = 1 and f < 21", + hist: "column:6 ndv:20 totColSize:200\n" + + "num: 15 lower_bound: -999999999999999.99 upper_bound: 7.00 repeats: 0\n" + + "num: 14 lower_bound: 7.00 upper_bound: 14.00 repeats: 0\n" + + "num: 13 lower_bound: 14.00 upper_bound: 21.00 repeats: 0", + rangeID: tblInfo.Columns[5].ID, + idxID: tblInfo.Indices[5].ID, + idxCols: 0, + eqCount: 35, + }, + { + sql: "select * from t use index(idx_ag) where a = 1 and g < 21", + hist: "column:7 ndv:20 totColSize:98\n" + + "num: 15 lower_bound: -838:59:59 upper_bound: 00:00:07 repeats: 0\n" + + "num: 14 lower_bound: 00:00:07 upper_bound: 00:00:14 repeats: 0\n" + + "num: 13 lower_bound: 00:00:14 upper_bound: 00:00:21 repeats: 0", + rangeID: tblInfo.Columns[6].ID, + idxID: tblInfo.Indices[6].ID, + idxCols: 0, + eqCount: 35, + }, + { + sql: `select * from t use index(idx_ah) where a = 1 and h < "1000-01-21"`, + hist: "column:8 ndv:20 totColSize:180\n" + + "num: 15 lower_bound: 1000-01-01 upper_bound: 1000-01-08 repeats: 0\n" + + "num: 14 lower_bound: 1000-01-08 upper_bound: 1000-01-15 repeats: 0\n" + + "num: 13 lower_bound: 1000-01-15 upper_bound: 1000-01-21 repeats: 0", + rangeID: tblInfo.Columns[7].ID, + idxID: tblInfo.Indices[7].ID, idxCols: 0, eqCount: 35, }, @@ -1320,7 +1377,7 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.DumpStatsFeedbackToKV(), IsNil) c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl := h.GetTableStats(tblInfo) if t.idxCols == 0 { c.Assert(tbl.Columns[t.rangeID].ToString(0), Equals, tests[i].hist) @@ -1333,6 +1390,47 @@ func (s *testStatsSuite) TestIndexQueryFeedback(c *C) { } } +func (s *testStatsSuite) TestIndexQueryFeedback4TopN(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + + oriProbability := statistics.FeedbackProbability + defer func() { + statistics.FeedbackProbability = oriProbability + }() + statistics.FeedbackProbability.Store(1) + + testKit.MustExec("use test") + testKit.MustExec("create table t (a bigint(64), index idx(a))") + for i := 0; i < 20; i++ { + testKit.MustExec(`insert into t values (1)`) + } + h := s.do.StatsHandle() + h.HandleDDLEvent(<-h.DDLEventCh()) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + testKit.MustExec("set @@tidb_enable_fast_analyze = 1") + testKit.MustExec("analyze table t with 3 buckets") + for i := 0; i < 20; i++ { + testKit.MustExec(`insert into t values (1)`) + } + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + is := s.do.InfoSchema() + c.Assert(h.Update(is), IsNil) + table, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + c.Assert(err, IsNil) + tblInfo := table.Meta() + + testKit.MustQuery("select * from t use index(idx) where a = 1") + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) + c.Assert(h.DumpStatsFeedbackToKV(), IsNil) + c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) + c.Assert(h.Update(is), IsNil) + tbl := h.GetTableStats(tblInfo) + val, err := codec.EncodeKey(testKit.Se.GetSessionVars().StmtCtx, nil, types.NewIntDatum(1)) + c.Assert(err, IsNil) + c.Assert(tbl.Indices[1].CMSketch.QueryBytes(val), Equals, uint64(40)) +} + func (s *testStatsSuite) TestAbnormalIndexFeedback(c *C) { defer cleanEnv(c, s.store, s.do) testKit := testkit.NewTestKit(c, s.store) @@ -1367,9 +1465,9 @@ func (s *testStatsSuite) TestAbnormalIndexFeedback(c *C) { // The real count of `a = 1` is 0. sql: "select * from t where a = 1 and b < 21", hist: "column:2 ndv:20 totColSize:20\n" + - "num: 4 lower_bound: -9223372036854775808 upper_bound: 6 repeats: 0\n" + - "num: 3 lower_bound: 7 upper_bound: 13 repeats: 0\n" + - "num: 6 lower_bound: 14 upper_bound: 19 repeats: 1", + "num: 5 lower_bound: -9223372036854775808 upper_bound: 7 repeats: 0\n" + + "num: 4 lower_bound: 7 upper_bound: 14 repeats: 0\n" + + "num: 4 lower_bound: 14 upper_bound: 21 repeats: 0", rangeID: tblInfo.Columns[1].ID, idxID: tblInfo.Indices[0].ID, eqCount: 3, @@ -1378,9 +1476,9 @@ func (s *testStatsSuite) TestAbnormalIndexFeedback(c *C) { // The real count of `b > 10` is 0. sql: "select * from t where a = 2 and b > 10", hist: "column:2 ndv:20 totColSize:20\n" + - "num: 4 lower_bound: -9223372036854775808 upper_bound: 6 repeats: 0\n" + - "num: 2 lower_bound: 7 upper_bound: 13 repeats: 0\n" + - "num: 6 lower_bound: 14 upper_bound: 19 repeats: 1", + "num: 5 lower_bound: -9223372036854775808 upper_bound: 7 repeats: 0\n" + + "num: 6 lower_bound: 7 upper_bound: 14 repeats: 0\n" + + "num: 7 lower_bound: 14 upper_bound: 9223372036854775807 repeats: 0", rangeID: tblInfo.Columns[1].ID, idxID: tblInfo.Indices[0].ID, eqCount: 3, @@ -1391,7 +1489,7 @@ func (s *testStatsSuite) TestAbnormalIndexFeedback(c *C) { c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.DumpStatsFeedbackToKV(), IsNil) c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) tbl := h.GetTableStats(tblInfo) c.Assert(tbl.Columns[t.rangeID].ToString(0), Equals, tests[i].hist) val, err := codec.EncodeKey(testKit.Se.GetSessionVars().StmtCtx, nil, types.NewIntDatum(1)) @@ -1430,27 +1528,27 @@ func (s *testStatsSuite) TestFeedbackRanges(c *C) { colID int64 }{ { - sql: "select * from t where a <= 50 or (a > 130 and a < 140)", + sql: "select * from t use index() where a <= 50 or (a > 130 and a < 140)", hist: "column:1 ndv:30 totColSize:0\n" + - "num: 8 lower_bound: -128 upper_bound: 7 repeats: 0\n" + - "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 8 lower_bound: -128 upper_bound: 8 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 16 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 50 repeats: 0", colID: 1, }, { - sql: "select * from t where a >= 10", + sql: "select * from t use index() where a >= 10", hist: "column:1 ndv:30 totColSize:0\n" + - "num: 8 lower_bound: -128 upper_bound: 7 repeats: 0\n" + - "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 8 lower_bound: -128 upper_bound: 8 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 16 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 127 repeats: 0", colID: 1, }, { sql: "select * from t use index(idx) where a = 1 and (b <= 50 or (b > 130 and b < 140))", hist: "column:2 ndv:20 totColSize:20\n" + - "num: 7 lower_bound: -128 upper_bound: 6 repeats: 0\n" + - "num: 7 lower_bound: 7 upper_bound: 13 repeats: 1\n" + - "num: 6 lower_bound: 14 upper_bound: 19 repeats: 1", + "num: 8 lower_bound: -128 upper_bound: 7 repeats: 0\n" + + "num: 8 lower_bound: 7 upper_bound: 14 repeats: 0\n" + + "num: 7 lower_bound: 14 upper_bound: 51 repeats: 0", colID: 2, }, } @@ -1462,7 +1560,7 @@ func (s *testStatsSuite) TestFeedbackRanges(c *C) { c.Assert(h.DumpStatsFeedbackToKV(), IsNil) c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) c.Assert(err, IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) tblInfo := table.Meta() tbl := h.GetTableStats(tblInfo) c.Assert(tbl.Columns[t.colID].ToString(0), Equals, tests[i].hist) @@ -1505,32 +1603,32 @@ func (s *testStatsSuite) TestUnsignedFeedbackRanges(c *C) { { sql: "select * from t where a <= 50", hist: "column:1 ndv:30 totColSize:0\n" + - "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + - "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 8 lower_bound: 0 upper_bound: 8 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 16 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 50 repeats: 0", tblName: "t", }, { sql: "select count(*) from t", hist: "column:1 ndv:30 totColSize:0\n" + - "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + - "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 8 lower_bound: 0 upper_bound: 8 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 16 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 255 repeats: 0", tblName: "t", }, { sql: "select * from t1 where a <= 50", hist: "column:1 ndv:30 totColSize:0\n" + - "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + - "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 8 lower_bound: 0 upper_bound: 8 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 16 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 50 repeats: 0", tblName: "t1", }, { sql: "select count(*) from t1", hist: "column:1 ndv:30 totColSize:0\n" + - "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + - "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 8 lower_bound: 0 upper_bound: 8 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 16 repeats: 0\n" + "num: 14 lower_bound: 16 upper_bound: 18446744073709551615 repeats: 0", tblName: "t1", }, @@ -1544,7 +1642,7 @@ func (s *testStatsSuite) TestUnsignedFeedbackRanges(c *C) { c.Assert(h.DumpStatsFeedbackToKV(), IsNil) c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) c.Assert(err, IsNil) - h.Update(is) + c.Assert(h.Update(is), IsNil) tblInfo := table.Meta() tbl := h.GetTableStats(tblInfo) c.Assert(tbl.Columns[1].ToString(0), Equals, tests[i].hist) @@ -1555,9 +1653,9 @@ func (s *testStatsSuite) TestLoadHistCorrelation(c *C) { defer cleanEnv(c, s.store, s.do) testKit := testkit.NewTestKit(c, s.store) h := s.do.StatsHandle() - origLease := h.Lease - h.Lease = time.Second - defer func() { h.Lease = origLease }() + origLease := h.Lease() + h.SetLease(time.Second) + defer func() { h.SetLease(origLease) }() testKit.MustExec("use test") testKit.MustExec("create table t(c int)") testKit.MustExec("insert into t values(1),(2),(3),(4),(5)") diff --git a/statistics/histogram.go b/statistics/histogram.go index 73adeee757bac..f529c803d20e0 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -216,9 +216,10 @@ func ValueToString(value *types.Datum, idxCols int) (string, error) { if idxCols == 0 { return value.ToString() } - decodedVals, err := codec.DecodeRange(value.GetBytes(), idxCols) - if err != nil { - return "", errors.Trace(err) + // Treat remaining part that cannot decode successfully as bytes. + decodedVals, remained, err := codec.DecodeRange(value.GetBytes(), idxCols) + if err != nil && len(remained) > 0 { + decodedVals = append(decodedVals, types.NewBytesDatum(remained)) } str, err := types.DatumsToString(decodedVals, true) if err != nil { @@ -427,41 +428,43 @@ func (hg *Histogram) typeMatch(ranges []*ranger.Range) bool { return true } -// SplitRange splits the range according to the histogram upper bound. Note that we treat last bucket's upper bound -// as inf, so all the split Ranges will totally fall in one of the (-inf, u(0)], (u(0), u(1)],...(u(n-3), u(n-2)], -// (u(n-2), +inf), where n is the number of buckets, u(i) is the i-th bucket's upper bound. +// SplitRange splits the range according to the histogram lower bound. Note that we treat first bucket's lower bound +// as -inf and last bucket's upper bound as +inf, so all the split ranges will totally fall in one of the (-inf, l(1)), +// [l(1), l(2)),...[l(n-2), l(n-1)), [l(n-1), +inf), where n is the number of buckets, l(i) is the i-th bucket's lower bound. func (hg *Histogram) SplitRange(sc *stmtctx.StatementContext, oldRanges []*ranger.Range, encoded bool) ([]*ranger.Range, bool) { if !hg.typeMatch(oldRanges) { return oldRanges, false } + // Treat the only buckets as (-inf, +inf), so we do not need split it. + if hg.Len() == 1 { + return oldRanges, true + } ranges := make([]*ranger.Range, 0, len(oldRanges)) for _, ran := range oldRanges { ranges = append(ranges, ran.Clone()) } split := make([]*ranger.Range, 0, len(ranges)) for len(ranges) > 0 { - // Find the last bound that greater or equal to the LowVal. + // Find the first bound that greater than the LowVal. idx := hg.Bounds.UpperBound(0, &ranges[0].LowVal[0]) - if !ranges[0].LowExclude && idx > 0 { - cmp := chunk.Compare(hg.Bounds.GetRow(idx-1), 0, &ranges[0].LowVal[0]) - if cmp == 0 { - idx-- - } - } - // Treat last bucket's upper bound as inf, so we do not need split any more. - if idx >= hg.Bounds.NumRows()-2 { + // Treat last bucket's upper bound as +inf, so we do not need split any more. + if idx >= hg.Bounds.NumRows()-1 { split = append(split, ranges...) break } - // Get the corresponding upper bound. - if idx%2 == 0 { + // Treat first buckets's lower bound as -inf, just increase it to the next lower bound. + if idx == 0 { + idx = 2 + } + // Get the next lower bound. + if idx%2 == 1 { idx++ } - upperBound := hg.Bounds.GetRow(idx) + lowerBound := hg.Bounds.GetRow(idx) var i int - // Find the first range that need to be split by the upper bound. + // Find the first range that need to be split by the lower bound. for ; i < len(ranges); i++ { - if chunk.Compare(upperBound, 0, &ranges[i].HighVal[0]) < 0 { + if chunk.Compare(lowerBound, 0, &ranges[i].HighVal[0]) <= 0 { break } } @@ -470,17 +473,20 @@ func (hg *Histogram) SplitRange(sc *stmtctx.StatementContext, oldRanges []*range if len(ranges) == 0 { break } - // Split according to the upper bound. - cmp := chunk.Compare(upperBound, 0, &ranges[0].LowVal[0]) - if cmp > 0 || (cmp == 0 && !ranges[0].LowExclude) { - upper := upperBound.GetDatum(0, hg.Tp) - split = append(split, &ranger.Range{ + // Split according to the lower bound. + cmp := chunk.Compare(lowerBound, 0, &ranges[0].LowVal[0]) + if cmp > 0 { + lower := lowerBound.GetDatum(0, hg.Tp) + newRange := &ranger.Range{ LowExclude: ranges[0].LowExclude, LowVal: []types.Datum{ranges[0].LowVal[0]}, - HighVal: []types.Datum{upper}, - HighExclude: false}) - ranges[0].LowVal[0] = upper - ranges[0].LowExclude = true + HighVal: []types.Datum{lower}, + HighExclude: true} + if validRange(sc, newRange, encoded) { + split = append(split, newRange) + } + ranges[0].LowVal[0] = lower + ranges[0].LowExclude = false if !validRange(sc, ranges[0], encoded) { ranges = ranges[1:] } diff --git a/statistics/histogram_test.go b/statistics/histogram_test.go index e131b143e7306..9bbe9e35a0137 100644 --- a/statistics/histogram_test.go +++ b/statistics/histogram_test.go @@ -50,11 +50,9 @@ func (s *testStatisticsSuite) TestNewHistogramBySelectivity(c *C) { node.Ranges = append(node.Ranges, &ranger.Range{LowVal: types.MakeDatums(25), HighVal: []types.Datum{types.MaxValueDatum()}}) intColResult := `column:1 ndv:16 totColSize:0 num: 30 lower_bound: 0 upper_bound: 2 repeats: 10 -num: 10 lower_bound: 3 upper_bound: 5 repeats: 10 -num: 20 lower_bound: 6 upper_bound: 8 repeats: 10 -num: 20 lower_bound: 9 upper_bound: 11 repeats: 0 +num: 20 lower_bound: 6 upper_bound: 8 repeats: 0 +num: 30 lower_bound: 9 upper_bound: 11 repeats: 0 num: 10 lower_bound: 12 upper_bound: 14 repeats: 0 -num: 20 lower_bound: 24 upper_bound: 26 repeats: 10 num: 30 lower_bound: 27 upper_bound: 29 repeats: 0` stringCol := &Column{} @@ -85,9 +83,9 @@ num: 30 lower_bound: 27 upper_bound: 29 repeats: 0` node2.Ranges = append(node2.Ranges, &ranger.Range{LowVal: types.MakeDatums("ggg"), HighVal: []types.Datum{types.MaxValueDatum()}}) stringColResult := `column:2 ndv:9 totColSize:0 num: 60 lower_bound: a upper_bound: aaaabbbb repeats: 0 -num: 60 lower_bound: bbbb upper_bound: fdsfdsfds repeats: 20 -num: 60 lower_bound: kkkkk upper_bound: ooooo repeats: 20 -num: 60 lower_bound: oooooo upper_bound: sssss repeats: 20 +num: 52 lower_bound: bbbb upper_bound: fdsfdsfds repeats: 0 +num: 54 lower_bound: kkkkk upper_bound: ooooo repeats: 0 +num: 60 lower_bound: oooooo upper_bound: sssss repeats: 0 num: 60 lower_bound: ssssssu upper_bound: yyyyy repeats: 0` newColl := coll.NewHistCollBySelectivity(sc, []*StatsNode{node, node2}) @@ -120,3 +118,14 @@ num: 30 lower_bound: 12 upper_bound: 14 repeats: 10` c.Assert(err, IsNil, Commentf("Test failed: %v", err)) c.Assert(newIdx.String(), Equals, idxResult) } + +func (s *testStatisticsSuite) TestValueToString4InvalidKey(c *C) { + bytes, err := codec.EncodeKey(nil, nil, types.NewDatum(1), types.NewDatum(0.5)) + c.Assert(err, IsNil) + // Append invalid flag. + bytes = append(bytes, 20) + datum := types.NewDatum(bytes) + res, err := ValueToString(&datum, 3) + c.Assert(err, IsNil) + c.Assert(res, Equals, "(1, 0.5, \x14)") +} diff --git a/statistics/sample.go b/statistics/sample.go index 41d139ebdef2b..8dbf0c4676125 100644 --- a/statistics/sample.go +++ b/statistics/sample.go @@ -91,7 +91,7 @@ func (c *SampleCollector) MergeSampleCollector(sc *stmtctx.StatementContext, rc c.TotalSize += rc.TotalSize c.FMSketch.mergeFMSketch(rc.FMSketch) if rc.CMSketch != nil { - err := c.CMSketch.MergeCMSketch(rc.CMSketch) + err := c.CMSketch.MergeCMSketch(rc.CMSketch, 0) terror.Log(errors.Trace(err)) } for _, item := range rc.Samples { @@ -217,8 +217,8 @@ func (s SampleBuilder) CollectColumnStats() ([]*SampleCollector, *SortedBuilder, } } ctx := context.TODO() - req := s.RecordSet.NewRecordBatch() - it := chunk.NewIterator4Chunk(req.Chunk) + req := s.RecordSet.NewChunk() + it := chunk.NewIterator4Chunk(req) for { err := s.RecordSet.Next(ctx, req) if err != nil { diff --git a/statistics/sample_test.go b/statistics/sample_test.go index dfc7b59df597b..cf07ec799a957 100644 --- a/statistics/sample_test.go +++ b/statistics/sample_test.go @@ -66,7 +66,7 @@ func (s *testSampleSuite) TestCollectColumnStats(c *C) { CMSketchWidth: 2048, CMSketchDepth: 8, } - s.rs.Close() + c.Assert(s.rs.Close(), IsNil) collectors, pkBuilder, err := builder.CollectColumnStats() c.Assert(err, IsNil) c.Assert(collectors[0].NullCount+collectors[0].Count, Equals, int64(s.count)) @@ -87,7 +87,7 @@ func (s *testSampleSuite) TestMergeSampleCollector(c *C) { CMSketchWidth: 2048, CMSketchDepth: 8, } - s.rs.Close() + c.Assert(s.rs.Close(), IsNil) sc := &stmtctx.StatementContext{TimeZone: time.Local} collectors, pkBuilder, err := builder.CollectColumnStats() c.Assert(err, IsNil) @@ -113,7 +113,7 @@ func (s *testSampleSuite) TestCollectorProtoConversion(c *C) { CMSketchWidth: 2048, CMSketchDepth: 8, } - s.rs.Close() + c.Assert(s.rs.Close(), IsNil) collectors, pkBuilder, err := builder.CollectColumnStats() c.Assert(err, IsNil) c.Assert(pkBuilder, IsNil) diff --git a/statistics/selectivity_test.go b/statistics/selectivity_test.go index efe34266f3c5d..21694c09d0809 100644 --- a/statistics/selectivity_test.go +++ b/statistics/selectivity_test.go @@ -14,6 +14,7 @@ package statistics_test import ( + "context" "fmt" "math" "os" @@ -60,13 +61,13 @@ func (s *testStatsSuite) SetUpSuite(c *C) { // Add the hook here to avoid data race. s.registerHook() var err error - s.store, s.do, err = newStoreWithBootstrap(0) + s.store, s.do, err = newStoreWithBootstrap() c.Assert(err, IsNil) } func (s *testStatsSuite) TearDownSuite(c *C) { s.do.Close() - s.store.Close() + c.Assert(s.store.Close(), IsNil) testleak.AfterTest(c)() } @@ -115,13 +116,13 @@ func (h *logHook) Check(e zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.Chec return ce } -func newStoreWithBootstrap(statsLease time.Duration) (kv.Storage, *domain.Domain, error) { +func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) { store, err := mockstore.NewMockTikvStore() if err != nil { return nil, nil, errors.Trace(err) } session.SetSchemaLease(0) - session.SetStatsLease(statsLease) + session.DisableStats4Test() domain.RunAutoAnalyze = false do, err := session.BootstrapSession(store) do.SetStatsUpdating(true) @@ -209,7 +210,8 @@ func (s *testStatsSuite) prepareSelectivity(testKit *testkit.TestKit, c *C) *sta statsTbl := mockStatsTable(tbl, 540) // Set the value of columns' histogram. - colValues, _ := s.generateIntDatum(1, 54) + colValues, err := s.generateIntDatum(1, 54) + c.Assert(err, IsNil) for i := 1; i <= 5; i++ { statsTbl.Columns[int64(i)] = &statistics.Column{Histogram: *mockStatsHistogram(int64(i), colValues, 10, types.NewFieldType(mysql.TypeLonglong)), Info: tbl.Columns[i-1]} } @@ -266,17 +268,19 @@ func (s *testStatsSuite) TestSelectivity(c *C) { selectivity: 0, }, } + + ctx := context.Background() for _, tt := range tests { sql := "select * from t where " + tt.exprs comment := Commentf("for %s", tt.exprs) - ctx := testKit.Se.(sessionctx.Context) - stmts, err := session.Parse(ctx, sql) + sctx := testKit.Se.(sessionctx.Context) + stmts, err := session.Parse(sctx, sql) c.Assert(err, IsNil, Commentf("error %v, for expr %s", err, tt.exprs)) c.Assert(stmts, HasLen, 1) - err = plannercore.Preprocess(ctx, stmts[0], is) + err = plannercore.Preprocess(sctx, stmts[0], is) c.Assert(err, IsNil, comment) - p, err := plannercore.BuildLogicalPlan(ctx, stmts[0], is) + p, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for building plan, expr %s", err, tt.exprs)) sel := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) @@ -284,12 +288,12 @@ func (s *testStatsSuite) TestSelectivity(c *C) { histColl := statsTbl.GenerateHistCollFromColumnInfo(ds.Columns, ds.Schema().Columns) - ratio, _, err := histColl.Selectivity(ctx, sel.Conditions) + ratio, _, err := histColl.Selectivity(sctx, sel.Conditions) c.Assert(err, IsNil, comment) c.Assert(math.Abs(ratio-tt.selectivity) < eps, IsTrue, Commentf("for %s, needed: %v, got: %v", tt.exprs, tt.selectivity, ratio)) histColl.Count *= 10 - ratio, _, err = histColl.Selectivity(ctx, sel.Conditions) + ratio, _, err = histColl.Selectivity(sctx, sel.Conditions) c.Assert(err, IsNil, comment) c.Assert(math.Abs(ratio-tt.selectivity) < eps, IsTrue, Commentf("for %s, needed: %v, got: %v", tt.exprs, tt.selectivity, ratio)) } @@ -347,12 +351,12 @@ func (s *testStatsSuite) TestEstimationForUnknownValues(c *C) { testKit.MustExec(fmt.Sprintf("insert into t values (%d, %d)", i, i)) } h := s.do.StatsHandle() - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) testKit.MustExec("analyze table t") for i := 0; i < 10; i++ { testKit.MustExec(fmt.Sprintf("insert into t values (%d, %d)", i+10, i+10)) } - h.DumpStatsDeltaToKV(handle.DumpAll) + c.Assert(h.DumpStatsDeltaToKV(handle.DumpAll), IsNil) c.Assert(h.Update(s.do.InfoSchema()), IsNil) table, err := s.do.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) c.Assert(err, IsNil) @@ -442,23 +446,24 @@ func BenchmarkSelectivity(b *testing.B) { exprs := "a > 1 and b < 2 and c > 3 and d < 4 and e > 5" sql := "select * from t where " + exprs comment := Commentf("for %s", exprs) - ctx := testKit.Se.(sessionctx.Context) - stmts, err := session.Parse(ctx, sql) + sctx := testKit.Se.(sessionctx.Context) + stmts, err := session.Parse(sctx, sql) c.Assert(err, IsNil, Commentf("error %v, for expr %s", err, exprs)) c.Assert(stmts, HasLen, 1) - err = plannercore.Preprocess(ctx, stmts[0], is) + err = plannercore.Preprocess(sctx, stmts[0], is) c.Assert(err, IsNil, comment) - p, err := plannercore.BuildLogicalPlan(ctx, stmts[0], is) + p, err := plannercore.BuildLogicalPlan(context.Background(), sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for building plan, expr %s", err, exprs)) - file, _ := os.Create("cpu.profile") + file, err := os.Create("cpu.profile") + c.Assert(err, IsNil) defer file.Close() pprof.StartCPUProfile(file) b.Run("Selectivity", func(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, _, err := statsTbl.Selectivity(ctx, p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection).Conditions) + _, _, err := statsTbl.Selectivity(sctx, p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection).Conditions) c.Assert(err, IsNil) } b.ReportAllocs() diff --git a/statistics/statistics_test.go b/statistics/statistics_test.go index 4ee497514d156..009ae8064da92 100644 --- a/statistics/statistics_test.go +++ b/statistics/statistics_test.go @@ -83,7 +83,7 @@ func (r *recordSet) getNext() []types.Datum { return row } -func (r *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (r *recordSet) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() row := r.getNext() if row != nil { @@ -94,12 +94,12 @@ func (r *recordSet) Next(ctx context.Context, req *chunk.RecordBatch) error { return nil } -func (r *recordSet) NewRecordBatch() *chunk.RecordBatch { +func (r *recordSet) NewChunk() *chunk.Chunk { fields := make([]*types.FieldType, 0, len(r.fields)) for _, field := range r.fields { fields = append(fields, &field.Column.FieldType) } - return chunk.NewRecordBatch(chunk.NewChunkWithCapacity(fields, 32)) + return chunk.NewChunkWithCapacity(fields, 32) } func (r *recordSet) Close() error { @@ -177,7 +177,7 @@ func buildPK(sctx sessionctx.Context, numBuckets, id int64, records sqlexec.Reco b := NewSortedBuilder(sctx.GetSessionVars().StmtCtx, numBuckets, id, types.NewFieldType(mysql.TypeLonglong)) ctx := context.Background() for { - req := records.NewRecordBatch() + req := records.NewChunk() err := records.Next(ctx, req) if err != nil { return 0, nil, errors.Trace(err) @@ -185,7 +185,7 @@ func buildPK(sctx sessionctx.Context, numBuckets, id int64, records sqlexec.Reco if req.NumRows() == 0 { break } - it := chunk.NewIterator4Chunk(req.Chunk) + it := chunk.NewIterator4Chunk(req) for row := it.Begin(); row != it.End(); row = it.Next() { datums := RowToDatums(row, records.Fields()) err = b.Iterate(datums[0]) @@ -201,8 +201,8 @@ func buildIndex(sctx sessionctx.Context, numBuckets, id int64, records sqlexec.R b := NewSortedBuilder(sctx.GetSessionVars().StmtCtx, numBuckets, id, types.NewFieldType(mysql.TypeBlob)) cms := NewCMSketch(8, 2048) ctx := context.Background() - req := records.NewRecordBatch() - it := chunk.NewIterator4Chunk(req.Chunk) + req := records.NewChunk() + it := chunk.NewIterator4Chunk(req) for { err := records.Next(ctx, req) if err != nil { @@ -238,7 +238,8 @@ func (s *testStatisticsSuite) TestBuild(c *C) { bucketCount := int64(256) ctx := mock.NewContext() sc := ctx.GetSessionVars().StmtCtx - sketch, _, _ := buildFMSketch(sc, s.rc.(*recordSet).data, 1000) + sketch, _, err := buildFMSketch(sc, s.rc.(*recordSet).data, 1000) + c.Assert(err, IsNil) collector := &SampleCollector{ Count: int64(s.count), @@ -277,7 +278,7 @@ func (s *testStatisticsSuite) TestBuild(c *C) { MaxSampleSize: 1000, MaxFMSketchSize: 1000, } - s.pk.Close() + c.Assert(s.pk.Close(), IsNil) collectors, _, err := builder.CollectColumnStats() c.Assert(err, IsNil) c.Assert(len(collectors), Equals, 1) @@ -336,7 +337,7 @@ func (s *testStatisticsSuite) TestBuild(c *C) { func (s *testStatisticsSuite) TestHistogramProtoConversion(c *C) { ctx := mock.NewContext() - s.rc.Close() + c.Assert(s.rc.Close(), IsNil) tblCount, col, _, err := buildIndex(ctx, 256, 1, sqlexec.RecordSet(s.rc)) c.Check(err, IsNil) c.Check(int(tblCount), Equals, 100000) @@ -441,7 +442,8 @@ func (s *testStatisticsSuite) TestColumnRange(c *C) { bucketCount := int64(256) ctx := mock.NewContext() sc := ctx.GetSessionVars().StmtCtx - sketch, _, _ := buildFMSketch(sc, s.rc.(*recordSet).data, 1000) + sketch, _, err := buildFMSketch(sc, s.rc.(*recordSet).data, 1000) + c.Assert(err, IsNil) collector := &SampleCollector{ Count: int64(s.count), diff --git a/statistics/table.go b/statistics/table.go index 9fc6964020fe3..5805b177e347b 100644 --- a/statistics/table.go +++ b/statistics/table.go @@ -16,6 +16,7 @@ package statistics import ( "fmt" "math" + "sort" "strings" "sync" @@ -98,12 +99,22 @@ func (t *Table) Copy() *Table { func (t *Table) String() string { strs := make([]string, 0, len(t.Columns)+1) strs = append(strs, fmt.Sprintf("Table:%d Count:%d", t.PhysicalID, t.Count)) + cols := make([]*Column, 0, len(t.Columns)) for _, col := range t.Columns { - strs = append(strs, col.String()) + cols = append(cols, col) } - for _, col := range t.Indices { + sort.Slice(cols, func(i, j int) bool { return cols[i].ID < cols[j].ID }) + for _, col := range cols { strs = append(strs, col.String()) } + idxs := make([]*Index, 0, len(t.Indices)) + for _, idx := range t.Indices { + idxs = append(idxs, idx) + } + sort.Slice(idxs, func(i, j int) bool { return idxs[i].ID < idxs[j].ID }) + for _, idx := range idxs { + strs = append(strs, idx.String()) + } return strings.Join(strs, "\n") } @@ -443,12 +454,15 @@ func PseudoTable(tblInfo *model.TableInfo) *Table { PhysicalID: fakePhysicalID, Info: col, IsHandle: tblInfo.PKIsHandle && mysql.HasPriKeyFlag(col.Flag), + Histogram: *NewHistogram(col.ID, 0, 0, 0, &col.FieldType, 0, 0), } } } for _, idx := range tblInfo.Indices { if idx.State == model.StatePublic { - t.Indices[idx.ID] = &Index{Info: idx} + t.Indices[idx.ID] = &Index{ + Info: idx, + Histogram: *NewHistogram(idx.ID, 0, 0, 0, types.NewFieldType(mysql.TypeBlob), 0, 0)} } } return t diff --git a/store/helper/helper.go b/store/helper/helper.go index 6e0e250755778..6c2aac6a0ffa9 100644 --- a/store/helper/helper.go +++ b/store/helper/helper.go @@ -17,8 +17,10 @@ import ( "bytes" "context" "encoding/json" + "io" "math" "net/http" + "strconv" "time" "github.com/pingcap/errors" @@ -99,7 +101,7 @@ type RegionMetric struct { } // ScrapeHotInfo gets the needed hot region information by the url given. -func (h *Helper) ScrapeHotInfo(rw string, allSchemas []*model.DBInfo) (map[TblIndex]RegionMetric, error) { +func (h *Helper) ScrapeHotInfo(rw string, allSchemas []*model.DBInfo) ([]HotTableIndex, error) { regionMetrics, err := h.FetchHotRegion(rw) if err != nil { return nil, err @@ -121,9 +123,7 @@ func (h *Helper) FetchHotRegion(rw string) (map[uint64]RegionMetric, error) { if err != nil { return nil, errors.Trace(err) } - timeout, cancelFunc := context.WithTimeout(context.Background(), 50*time.Millisecond) - resp, err := http.DefaultClient.Do(req.WithContext(timeout)) - cancelFunc() + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, errors.Trace(err) } @@ -175,10 +175,22 @@ type RegionFrameRange struct { region *tikv.KeyLocation // the region } -// FetchRegionTableIndex constructs a map that maps a table to its hot region information by the given raw hot region metrics. -func (h *Helper) FetchRegionTableIndex(metrics map[uint64]RegionMetric, allSchemas []*model.DBInfo) (map[TblIndex]RegionMetric, error) { - idxMetrics := make(map[TblIndex]RegionMetric) +// HotTableIndex contains region and its table/index info. +type HotTableIndex struct { + RegionID uint64 `json:"region_id"` + RegionMetric *RegionMetric `json:"region_metric"` + DbName string `json:"db_name"` + TableName string `json:"table_name"` + TableID int64 `json:"table_id"` + IndexName string `json:"index_name"` + IndexID int64 `json:"index_id"` +} + +// FetchRegionTableIndex constructs a map that maps a table to its hot region information by the given raw hot RegionMetric metrics. +func (h *Helper) FetchRegionTableIndex(metrics map[uint64]RegionMetric, allSchemas []*model.DBInfo) ([]HotTableIndex, error) { + hotTables := make([]HotTableIndex, 0, len(metrics)) for regionID, regionMetric := range metrics { + t := HotTableIndex{RegionID: regionID, RegionMetric: ®ionMetric} region, err := h.RegionCache.LocateRegionByID(tikv.NewBackoffer(context.Background(), 500), regionID) if err != nil { logutil.Logger(context.Background()).Error("locate region failed", zap.Error(err)) @@ -189,32 +201,18 @@ func (h *Helper) FetchRegionTableIndex(metrics map[uint64]RegionMetric, allSchem if err != nil { return nil, err } - f := h.FindTableIndexOfRegion(allSchemas, hotRange) if f != nil { - idx := TblIndex{ - DbName: f.DBName, - TableName: f.TableName, - TableID: f.TableID, - IndexName: f.IndexName, - IndexID: f.IndexID, - } - metric, exists := idxMetrics[idx] - if !exists { - metric = regionMetric - metric.Count++ - idxMetrics[idx] = metric - } else { - metric.FlowBytes += regionMetric.FlowBytes - if metric.MaxHotDegree < regionMetric.MaxHotDegree { - metric.MaxHotDegree = regionMetric.MaxHotDegree - } - metric.Count++ - } + t.DbName = f.DBName + t.TableName = f.TableName + t.TableID = f.TableID + t.IndexName = f.IndexName + t.IndexID = f.IndexID } + hotTables = append(hotTables, t) } - return idxMetrics, nil + return hotTables, nil } // FindTableIndexOfRegion finds what table is involved in this hot region. And constructs the new frame item for future use. @@ -410,36 +408,52 @@ type RegionsInfo struct { // GetRegionsInfo gets the region information of current store by using PD's api. func (h *Helper) GetRegionsInfo() (*RegionsInfo, error) { + var regionsInfo RegionsInfo + err := h.requestPD("GET", pdapi.Regions, nil, ®ionsInfo) + return ®ionsInfo, err +} + +// GetRegionInfoByID gets the region information of the region ID by using PD's api. +func (h *Helper) GetRegionInfoByID(regionID uint64) (*RegionInfo, error) { + var regionInfo RegionInfo + err := h.requestPD("GET", pdapi.RegionByID+strconv.FormatUint(regionID, 10), nil, ®ionInfo) + return ®ionInfo, err +} + +// request PD API, decode the response body into res +func (h *Helper) requestPD(method, uri string, body io.Reader, res interface{}) error { etcd, ok := h.Store.(tikv.EtcdBackend) if !ok { - return nil, errors.WithStack(errors.New("not implemented")) + return errors.WithStack(errors.New("not implemented")) } pdHosts := etcd.EtcdAddrs() if len(pdHosts) == 0 { - return nil, errors.New("pd unavailable") + return errors.New("pd unavailable") } - req, err := http.NewRequest("GET", protocol+pdHosts[0]+pdapi.Regions, nil) + + logutil.Logger(context.Background()).Debug("RequestPD URL", zap.String("url", protocol+pdHosts[0]+uri)) + req, err := http.NewRequest(method, protocol+pdHosts[0]+uri, body) if err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } - timeout, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second) - resp, err := http.DefaultClient.Do(req.WithContext(timeout)) - defer cancelFunc() + resp, err := http.DefaultClient.Do(req) if err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } + defer func() { err = resp.Body.Close() if err != nil { logutil.Logger(context.Background()).Error("close body failed", zap.Error(err)) } }() - var regionsInfo RegionsInfo - err = json.NewDecoder(resp.Body).Decode(®ionsInfo) + + err = json.NewDecoder(resp.Body).Decode(res) if err != nil { - return nil, errors.Trace(err) + return errors.Trace(err) } - return ®ionsInfo, nil + + return nil } // StoresStat stores all information get from PD's api. @@ -479,8 +493,8 @@ type StoreDetailStat struct { LeaderScore int64 `json:"leader_score"` LeaderSize int64 `json:"leader_size"` RegionCount int64 `json:"region_count"` - RegionWeight int64 `json:"region_weight"` - RegionScore int64 `json:"region_score"` + RegionWeight float64 `json:"region_weight"` + RegionScore float64 `json:"region_score"` RegionSize int64 `json:"region_size"` StartTs time.Time `json:"start_ts"` LastHeartbeatTs time.Time `json:"last_heartbeat_ts"` @@ -501,9 +515,7 @@ func (h *Helper) GetStoresStat() (*StoresStat, error) { if err != nil { return nil, errors.Trace(err) } - timeout, cancelFunc := context.WithTimeout(context.Background(), 50*time.Millisecond) - resp, err := http.DefaultClient.Do(req.WithContext(timeout)) - defer cancelFunc() + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, errors.Trace(err) } diff --git a/store/mockstore/mocktikv/analyze.go b/store/mockstore/mocktikv/analyze.go index c50f48e12c558..65a85da0f6bb3 100644 --- a/store/mockstore/mocktikv/analyze.go +++ b/store/mockstore/mocktikv/analyze.go @@ -214,7 +214,7 @@ func (e *analyzeColumnsExec) getNext(ctx context.Context) ([]types.Datum, error) return datumRow, nil } -func (e *analyzeColumnsExec) Next(ctx context.Context, req *chunk.RecordBatch) error { +func (e *analyzeColumnsExec) Next(ctx context.Context, req *chunk.Chunk) error { req.Reset() row, err := e.getNext(ctx) if row == nil || err != nil { @@ -226,12 +226,12 @@ func (e *analyzeColumnsExec) Next(ctx context.Context, req *chunk.RecordBatch) e return nil } -func (e *analyzeColumnsExec) NewRecordBatch() *chunk.RecordBatch { +func (e *analyzeColumnsExec) NewChunk() *chunk.Chunk { fields := make([]*types.FieldType, 0, len(e.fields)) for _, field := range e.fields { fields = append(fields, &field.Column.FieldType) } - return chunk.NewRecordBatch(chunk.NewChunkWithCapacity(fields, 1)) + return chunk.NewChunkWithCapacity(fields, 1) } // Close implements the sqlexec.RecordSet Close interface. diff --git a/store/mockstore/mocktikv/cluster.go b/store/mockstore/mocktikv/cluster.go index 014e61042f5c8..460fffd6d39b8 100644 --- a/store/mockstore/mocktikv/cluster.go +++ b/store/mockstore/mocktikv/cluster.go @@ -18,10 +18,12 @@ import ( "context" "math" "sync" + "time" "github.com/golang/protobuf/proto" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/tablecodec" ) @@ -40,14 +42,24 @@ type Cluster struct { id uint64 stores map[uint64]*Store regions map[uint64]*Region + + // delayEvents is used to control the execution sequence of rpc requests for test. + delayEvents map[delayKey]time.Duration + delayMu sync.Mutex +} + +type delayKey struct { + startTS uint64 + regionID uint64 } // NewCluster creates an empty cluster. It needs to be bootstrapped before // providing service. func NewCluster() *Cluster { return &Cluster{ - stores: make(map[uint64]*Store), - regions: make(map[uint64]*Region), + stores: make(map[uint64]*Store), + regions: make(map[uint64]*Region), + delayEvents: make(map[delayKey]time.Duration), } } @@ -307,12 +319,13 @@ func (c *Cluster) Split(regionID, newRegionID uint64, key []byte, peerIDs []uint } // SplitRaw splits a Region at the key (not encoded) and creates new Region. -func (c *Cluster) SplitRaw(regionID, newRegionID uint64, rawKey []byte, peerIDs []uint64, leaderPeerID uint64) { +func (c *Cluster) SplitRaw(regionID, newRegionID uint64, rawKey []byte, peerIDs []uint64, leaderPeerID uint64) *Region { c.Lock() defer c.Unlock() newRegion := c.regions[regionID].split(newRegionID, rawKey, peerIDs, leaderPeerID) c.regions[newRegionID] = newRegion + return newRegion } // Merge merges 2 regions, their key ranges should be adjacent. @@ -340,6 +353,32 @@ func (c *Cluster) SplitIndex(mvccStore MVCCStore, tableID, indexID int64, count c.splitRange(mvccStore, NewMvccKey(indexStart), NewMvccKey(indexEnd), count) } +// SplitKeys evenly splits the start, end key into "count" regions. +// Only works for single store. +func (c *Cluster) SplitKeys(mvccStore MVCCStore, start, end kv.Key, count int) { + c.splitRange(mvccStore, NewMvccKey(start), NewMvccKey(end), count) +} + +// ScheduleDelay schedules a delay event for a transaction on a region. +func (c *Cluster) ScheduleDelay(startTS, regionID uint64, dur time.Duration) { + c.delayMu.Lock() + c.delayEvents[delayKey{startTS: startTS, regionID: regionID}] = dur + c.delayMu.Unlock() +} + +func (c *Cluster) handleDelay(startTS, regionID uint64) { + key := delayKey{startTS: startTS, regionID: regionID} + c.delayMu.Lock() + dur, ok := c.delayEvents[key] + if ok { + delete(c.delayEvents, key) + } + c.delayMu.Unlock() + if ok { + time.Sleep(dur) + } +} + func (c *Cluster) splitRange(mvccStore MVCCStore, start, end MvccKey, count int) { c.Lock() defer c.Unlock() diff --git a/store/mockstore/mocktikv/errors.go b/store/mockstore/mocktikv/errors.go index eb986eeb1e456..8f15c0f689b16 100644 --- a/store/mockstore/mocktikv/errors.go +++ b/store/mockstore/mocktikv/errors.go @@ -22,6 +22,7 @@ type ErrLocked struct { Primary []byte StartTS uint64 TTL uint64 + TxnSize uint64 } // Error formats the lock to a string. @@ -61,7 +62,7 @@ func (e ErrAlreadyCommitted) Error() string { return "txn already committed" } -// ErrConflict is turned when the commitTS of key in the DB is greater than startTS. +// ErrConflict is returned when the commitTS of key in the DB is greater than startTS. type ErrConflict struct { StartTS uint64 ConflictTS uint64 @@ -71,3 +72,14 @@ type ErrConflict struct { func (e *ErrConflict) Error() string { return "write conflict" } + +// ErrDeadlock is returned when deadlock error is detected. +type ErrDeadlock struct { + LockTS uint64 + LockKey []byte + DealockKeyHash uint64 +} + +func (e *ErrDeadlock) Error() string { + return "deadlock" +} diff --git a/store/mockstore/mocktikv/mock_tikv_test.go b/store/mockstore/mocktikv/mock_tikv_test.go index 091a3aa05bbca..f8222603622d6 100644 --- a/store/mockstore/mocktikv/mock_tikv_test.go +++ b/store/mockstore/mocktikv/mock_tikv_test.go @@ -39,6 +39,7 @@ type testMVCCLevelDB struct { } var ( + _ = Suite(&testMockTiKVSuite{}) _ = Suite(&testMVCCLevelDB{}) _ = Suite(testMarshal{}) ) @@ -94,7 +95,12 @@ func (s *testMockTiKVSuite) mustGetRC(c *C, key string, ts uint64, expect string } func (s *testMockTiKVSuite) mustPutOK(c *C, key, value string, startTS, commitTS uint64) { - errs := s.store.Prewrite(putMutations(key, value), []byte(key), startTS, 0) + req := &kvrpcpb.PrewriteRequest{ + Mutations: putMutations(key, value), + PrimaryLock: []byte(key), + StartVersion: startTS, + } + errs := s.store.Prewrite(req) for _, err := range errs { c.Assert(err, IsNil) } @@ -109,7 +115,12 @@ func (s *testMockTiKVSuite) mustDeleteOK(c *C, key string, startTS, commitTS uin Key: []byte(key), }, } - errs := s.store.Prewrite(mutations, []byte(key), startTS, 0) + req := &kvrpcpb.PrewriteRequest{ + Mutations: mutations, + PrimaryLock: []byte(key), + StartVersion: startTS, + } + errs := s.store.Prewrite(req) for _, err := range errs { c.Assert(err, IsNil) } @@ -146,7 +157,16 @@ func (s *testMockTiKVSuite) mustRangeReverseScanOK(c *C, start, end string, limi } func (s *testMockTiKVSuite) mustPrewriteOK(c *C, mutations []*kvrpcpb.Mutation, primary string, startTS uint64) { - errs := s.store.Prewrite(mutations, []byte(primary), startTS, 0) + s.mustPrewriteWithTTLOK(c, mutations, primary, startTS, 0) +} + +func (s *testMockTiKVSuite) mustPrewriteWithTTLOK(c *C, mutations []*kvrpcpb.Mutation, primary string, startTS uint64, ttl uint64) { + req := &kvrpcpb.PrewriteRequest{ + Mutations: mutations, + PrimaryLock: []byte(primary), + StartVersion: startTS, + } + errs := s.store.Prewrite(req) for _, err := range errs { c.Assert(err, IsNil) } @@ -412,7 +432,12 @@ func (s *testMockTiKVSuite) TestCommitConflict(c *C) { // A prewrite. s.mustPrewriteOK(c, putMutations("x", "A"), "x", 5) // B prewrite and find A's lock. - errs := s.store.Prewrite(putMutations("x", "B"), []byte("x"), 10, 0) + req := &kvrpcpb.PrewriteRequest{ + Mutations: putMutations("x", "B"), + PrimaryLock: []byte("x"), + StartVersion: 10, + } + errs := s.store.Prewrite(req) c.Assert(errs[0], NotNil) // B find rollback A because A exist too long. s.mustRollbackOK(c, [][]byte{[]byte("x")}, 5) @@ -470,17 +495,27 @@ func (s *testMockTiKVSuite) TestBatchResolveLock(c *C) { func (s *testMockTiKVSuite) TestRollbackAndWriteConflict(c *C) { s.mustPutOK(c, "test", "test", 1, 3) - - errs := s.store.Prewrite(putMutations("lock", "lock", "test", "test1"), []byte("test"), 2, 2) + req := &kvrpcpb.PrewriteRequest{ + Mutations: putMutations("lock", "lock", "test", "test1"), + PrimaryLock: []byte("test"), + StartVersion: 2, + LockTtl: 2, + } + errs := s.store.Prewrite(req) s.mustWriteWriteConflict(c, errs, 1) s.mustPutOK(c, "test", "test2", 5, 8) // simulate `getTxnStatus` for txn 2. - err := s.store.Cleanup([]byte("test"), 2) + err := s.store.Cleanup([]byte("test"), 2, math.MaxUint64) c.Assert(err, IsNil) - - errs = s.store.Prewrite(putMutations("test", "test3"), []byte("test"), 6, 1) + req = &kvrpcpb.PrewriteRequest{ + Mutations: putMutations("test", "test3"), + PrimaryLock: []byte("test"), + StartVersion: 6, + LockTtl: 1, + } + errs = s.store.Prewrite(req) s.mustWriteWriteConflict(c, errs, 0) } @@ -564,3 +599,22 @@ func (s testMarshal) TestMarshalmvccValue(c *C) { c.Assert(v.commitTS, Equals, v1.commitTS) c.Assert(string(v.value), Equals, string(v.value)) } + +func (s *testMVCCLevelDB) TestTxnHeartBeat(c *C) { + s.mustPrewriteWithTTLOK(c, putMutations("pk", "val"), "pk", 5, 666) + + // Update the ttl + ttl, err := s.store.TxnHeartBeat([]byte("pk"), 5, 888) + c.Assert(err, IsNil) + c.Assert(ttl, Greater, uint64(666)) + + // Advise ttl is small + ttl, err = s.store.TxnHeartBeat([]byte("pk"), 5, 300) + c.Assert(err, IsNil) + c.Assert(ttl, Greater, uint64(300)) + + // The lock has already been clean up + c.Assert(s.store.Cleanup([]byte("pk"), 5, 0), IsNil) + _, err = s.store.TxnHeartBeat([]byte("pk"), 5, 1000) + c.Assert(err, NotNil) +} diff --git a/store/mockstore/mocktikv/mvcc.go b/store/mockstore/mocktikv/mvcc.go index 9607819f018e9..543e6f4f92218 100644 --- a/store/mockstore/mocktikv/mvcc.go +++ b/store/mockstore/mocktikv/mvcc.go @@ -42,11 +42,13 @@ type mvccValue struct { } type mvccLock struct { - startTS uint64 - primary []byte - value []byte - op kvrpcpb.Op - ttl uint64 + startTS uint64 + primary []byte + value []byte + op kvrpcpb.Op + ttl uint64 + forUpdateTS uint64 + txnSize uint64 } type mvccEntry struct { @@ -66,6 +68,8 @@ func (l *mvccLock) MarshalBinary() ([]byte, error) { mh.WriteSlice(&buf, l.value) mh.WriteNumber(&buf, l.op) mh.WriteNumber(&buf, l.ttl) + mh.WriteNumber(&buf, l.forUpdateTS) + mh.WriteNumber(&buf, l.txnSize) return buf.Bytes(), errors.Trace(mh.err) } @@ -78,6 +82,8 @@ func (l *mvccLock) UnmarshalBinary(data []byte) error { mh.ReadSlice(buf, &l.value) mh.ReadNumber(buf, &l.op) mh.ReadNumber(buf, &l.ttl) + mh.ReadNumber(buf, &l.forUpdateTS) + mh.ReadNumber(buf, &l.txnSize) return errors.Trace(mh.err) } @@ -194,6 +200,7 @@ func (l *mvccLock) lockErr(key []byte) error { Primary: l.primary, StartTS: l.startTS, TTL: l.ttl, + TxnSize: l.txnSize, } } @@ -429,11 +436,13 @@ type MVCCStore interface { ReverseScan(startKey, endKey []byte, limit int, startTS uint64, isoLevel kvrpcpb.IsolationLevel) []Pair BatchGet(ks [][]byte, startTS uint64, isoLevel kvrpcpb.IsolationLevel) []Pair PessimisticLock(mutations []*kvrpcpb.Mutation, primary []byte, startTS, forUpdateTS uint64, ttl uint64) []error - Prewrite(mutations []*kvrpcpb.Mutation, primary []byte, startTS uint64, ttl uint64) []error + PessimisticRollback(keys [][]byte, startTS, forUpdateTS uint64) []error + Prewrite(req *kvrpcpb.PrewriteRequest) []error Commit(keys [][]byte, startTS, commitTS uint64) error Rollback(keys [][]byte, startTS uint64) error - Cleanup(key []byte, startTS uint64) error + Cleanup(key []byte, startTS, currentTS uint64) error ScanLock(startKey, endKey []byte, maxTS uint64) ([]*kvrpcpb.LockInfo, error) + TxnHeartBeat(primaryKey []byte, startTS uint64, adviseTTL uint64) (uint64, error) ResolveLock(startKey, endKey []byte, startTS, commitTS uint64) error BatchResolveLock(startKey, endKey []byte, txnInfos map[uint64]uint64) error DeleteRange(startKey, endKey []byte) error diff --git a/store/mockstore/mocktikv/mvcc_leveldb.go b/store/mockstore/mocktikv/mvcc_leveldb.go index d9c0e2ba2317c..3a30ba637d627 100644 --- a/store/mockstore/mocktikv/mvcc_leveldb.go +++ b/store/mockstore/mocktikv/mvcc_leveldb.go @@ -19,6 +19,7 @@ import ( "math" "sync" + "github.com/dgryski/go-farm" "github.com/pingcap/errors" "github.com/pingcap/goleveldb/leveldb" "github.com/pingcap/goleveldb/leveldb/iterator" @@ -27,7 +28,9 @@ import ( "github.com/pingcap/goleveldb/leveldb/util" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/util/codec" + "github.com/pingcap/tidb/util/deadlock" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -60,7 +63,8 @@ type MVCCLevelDB struct { // mu used for lock // leveldb can not guarantee multiple operations to be atomic, for example, read // then write, another write may happen during it, so this lock is necessory. - mu sync.RWMutex + mu sync.RWMutex + deadlockDetector *deadlock.Detector } const lockVer uint64 = math.MaxUint64 @@ -121,7 +125,7 @@ func NewMVCCLevelDB(path string) (*MVCCLevelDB, error) { d, err = leveldb.OpenFile(path, &opt.Options{BlockCacheCapacity: 600 * 1024 * 1024}) } - return &MVCCLevelDB{db: d}, errors.Trace(err) + return &MVCCLevelDB{db: d, deadlockDetector: deadlock.NewDetector()}, errors.Trace(err) } // Iterator wraps iterator.Iterator to provide Valid() method. @@ -507,7 +511,7 @@ func (mvcc *MVCCLevelDB) PessimisticLock(mutations []*kvrpcpb.Mutation, primary batch := &leveldb.Batch{} errs := make([]error, 0, len(mutations)) for _, m := range mutations { - err := pessimisticLockMutation(mvcc.db, batch, m, startTS, forUpdateTS, primary, ttl) + err := mvcc.pessimisticLockMutation(batch, m, startTS, forUpdateTS, primary, ttl) errs = append(errs, err) if err != nil { anyError = true @@ -517,15 +521,15 @@ func (mvcc *MVCCLevelDB) PessimisticLock(mutations []*kvrpcpb.Mutation, primary return errs } if err := mvcc.db.Write(batch, nil); err != nil { - return nil + return []error{err} } return errs } -func pessimisticLockMutation(db *leveldb.DB, batch *leveldb.Batch, mutation *kvrpcpb.Mutation, startTS, forUpdateTS uint64, primary []byte, ttl uint64) error { +func (mvcc *MVCCLevelDB) pessimisticLockMutation(batch *leveldb.Batch, mutation *kvrpcpb.Mutation, startTS, forUpdateTS uint64, primary []byte, ttl uint64) error { startKey := mvccEncode(mutation.Key, lockVer) - iter := newIterator(db, &util.Range{ + iter := newIterator(mvcc.db, &util.Range{ Start: startKey, }) defer iter.Release() @@ -539,19 +543,28 @@ func pessimisticLockMutation(db *leveldb.DB, batch *leveldb.Batch, mutation *kvr } if ok { if dec.lock.startTS != startTS { + errDeadlock := mvcc.deadlockDetector.Detect(startTS, dec.lock.startTS, farm.Fingerprint64(mutation.Key)) + if errDeadlock != nil { + return &ErrDeadlock{ + LockKey: mutation.Key, + LockTS: dec.lock.startTS, + DealockKeyHash: errDeadlock.KeyHash, + } + } return dec.lock.lockErr(mutation.Key) } return nil } - if err = checkConflictValue(iter, mutation.Key, forUpdateTS); err != nil { + if err = checkConflictValue(iter, mutation, forUpdateTS); err != nil { return err } lock := mvccLock{ - startTS: startTS, - primary: primary, - op: kvrpcpb.Op_PessimisticLock, - ttl: ttl, + startTS: startTS, + primary: primary, + op: kvrpcpb.Op_PessimisticLock, + ttl: ttl, + forUpdateTS: forUpdateTS, } writeKey := mvccEncode(mutation.Key, lockVer) writeValue, err := lock.MarshalBinary() @@ -563,15 +576,67 @@ func pessimisticLockMutation(db *leveldb.DB, batch *leveldb.Batch, mutation *kvr return nil } +// PessimisticRollback implements the MVCCStore interface. +func (mvcc *MVCCLevelDB) PessimisticRollback(keys [][]byte, startTS, forUpdateTS uint64) []error { + mvcc.mu.Lock() + defer mvcc.mu.Unlock() + + anyError := false + batch := &leveldb.Batch{} + errs := make([]error, 0, len(keys)) + for _, key := range keys { + err := pessimisticRollbackKey(mvcc.db, batch, key, startTS, forUpdateTS) + errs = append(errs, err) + if err != nil { + anyError = true + } + } + if anyError { + return errs + } + if err := mvcc.db.Write(batch, nil); err != nil { + return []error{err} + } + return errs +} + +func pessimisticRollbackKey(db *leveldb.DB, batch *leveldb.Batch, key []byte, startTS, forUpdateTS uint64) error { + startKey := mvccEncode(key, lockVer) + iter := newIterator(db, &util.Range{ + Start: startKey, + }) + defer iter.Release() + + dec := lockDecoder{ + expectKey: key, + } + ok, err := dec.Decode(iter) + if err != nil { + return errors.Trace(err) + } + if ok { + lock := dec.lock + if lock.op == kvrpcpb.Op_PessimisticLock && lock.startTS == startTS && lock.forUpdateTS <= forUpdateTS { + batch.Delete(startKey) + } + } + return nil +} + // Prewrite implements the MVCCStore interface. -func (mvcc *MVCCLevelDB) Prewrite(mutations []*kvrpcpb.Mutation, primary []byte, startTS uint64, ttl uint64) []error { +func (mvcc *MVCCLevelDB) Prewrite(req *kvrpcpb.PrewriteRequest) []error { + mutations := req.Mutations + primary := req.PrimaryLock + startTS := req.StartVersion + ttl := req.LockTtl mvcc.mu.Lock() defer mvcc.mu.Unlock() anyError := false batch := &leveldb.Batch{} errs := make([]error, 0, len(mutations)) - for _, m := range mutations { + txnSize := req.TxnSize + for i, m := range mutations { // If the operation is Insert, check if key is exists at first. var err error if m.GetOp() == kvrpcpb.Op_Insert { @@ -590,7 +655,8 @@ func (mvcc *MVCCLevelDB) Prewrite(mutations []*kvrpcpb.Mutation, primary []byte, continue } } - err = prewriteMutation(mvcc.db, batch, m, startTS, primary, ttl) + isPessimisticLock := len(req.IsPessimisticLock) > 0 && req.IsPessimisticLock[i] + err = prewriteMutation(mvcc.db, batch, m, startTS, primary, ttl, txnSize, isPessimisticLock) errs = append(errs, err) if err != nil { anyError = true @@ -600,32 +666,53 @@ func (mvcc *MVCCLevelDB) Prewrite(mutations []*kvrpcpb.Mutation, primary []byte, return errs } if err := mvcc.db.Write(batch, nil); err != nil { - return nil + return []error{err} } return errs } -func checkConflictValue(iter *Iterator, key []byte, startTS uint64) error { +func checkConflictValue(iter *Iterator, m *kvrpcpb.Mutation, startTS uint64) error { dec := valueDecoder{ - expectKey: key, + expectKey: m.Key, } ok, err := dec.Decode(iter) if err != nil { return errors.Trace(err) } + if !ok { + return nil + } // Note that it's a write conflict here, even if the value is a rollback one. - if ok && dec.value.commitTS >= startTS { + if dec.value.commitTS >= startTS { return &ErrConflict{ StartTS: startTS, ConflictTS: dec.value.commitTS, - Key: key, + Key: m.Key, + } + } + if m.Op == kvrpcpb.Op_PessimisticLock && m.Assertion == kvrpcpb.Assertion_NotExist { + // Skip rollback keys. + for dec.value.valueType == typeRollback { + ok, err = dec.Decode(iter) + if err != nil { + return errors.Trace(err) + } + if !ok { + return nil + } + } + if dec.value.valueType == typeDelete { + return nil + } + return &ErrKeyAlreadyExist{ + Key: m.Key, } } return nil } -func prewriteMutation(db *leveldb.DB, batch *leveldb.Batch, mutation *kvrpcpb.Mutation, startTS uint64, primary []byte, ttl uint64) error { +func prewriteMutation(db *leveldb.DB, batch *leveldb.Batch, mutation *kvrpcpb.Mutation, startTS uint64, primary []byte, ttl uint64, txnSize uint64, isPessimisticLock bool) error { startKey := mvccEncode(mutation.Key, lockVer) iter := newIterator(db, &util.Range{ Start: startKey, @@ -641,6 +728,12 @@ func prewriteMutation(db *leveldb.DB, batch *leveldb.Batch, mutation *kvrpcpb.Mu } if ok { if dec.lock.startTS != startTS { + if isPessimisticLock { + // NOTE: A special handling. + // When pessimistic txn prewrite meets lock, set the TTL = 0 means + // telling TiDB to rollback the transaction **unconditionly**. + dec.lock.ttl = 0 + } return dec.lock.lockErr(mutation.Key) } if dec.lock.op != kvrpcpb.Op_PessimisticLock { @@ -648,7 +741,10 @@ func prewriteMutation(db *leveldb.DB, batch *leveldb.Batch, mutation *kvrpcpb.Mu } // Overwrite the pessimistic lock. } else { - err = checkConflictValue(iter, mutation.Key, startTS) + if isPessimisticLock { + return ErrAbort("pessimistic lock not found") + } + err = checkConflictValue(iter, mutation, startTS) if err != nil { return err } @@ -664,6 +760,7 @@ func prewriteMutation(db *leveldb.DB, batch *leveldb.Batch, mutation *kvrpcpb.Mu value: mutation.Value, op: op, ttl: ttl, + txnSize: txnSize, } writeKey := mvccEncode(mutation.Key, lockVer) writeValue, err := lock.MarshalBinary() @@ -684,7 +781,10 @@ func prewriteMutation(db *leveldb.DB, batch *leveldb.Batch, mutation *kvrpcpb.Mu // Commit implements the MVCCStore interface. func (mvcc *MVCCLevelDB) Commit(keys [][]byte, startTS, commitTS uint64) error { mvcc.mu.Lock() - defer mvcc.mu.Unlock() + defer func() { + mvcc.mu.Unlock() + mvcc.deadlockDetector.CleanUp(startTS) + }() batch := &leveldb.Batch{} for _, k := range keys { @@ -758,7 +858,10 @@ func commitLock(batch *leveldb.Batch, lock mvccLock, key []byte, startTS, commit // Rollback implements the MVCCStore interface. func (mvcc *MVCCLevelDB) Rollback(keys [][]byte, startTS uint64) error { mvcc.mu.Lock() - defer mvcc.mu.Unlock() + defer func() { + mvcc.mu.Unlock() + mvcc.deadlockDetector.CleanUp(startTS) + }() batch := &leveldb.Batch{} for _, k := range keys { @@ -858,16 +961,117 @@ func getTxnCommitInfo(iter *Iterator, expectKey []byte, startTS uint64) (mvccVal } // Cleanup implements the MVCCStore interface. -func (mvcc *MVCCLevelDB) Cleanup(key []byte, startTS uint64) error { +// Cleanup API is deprecated, use CheckTxnStatus instead. +func (mvcc *MVCCLevelDB) Cleanup(key []byte, startTS, currentTS uint64) error { mvcc.mu.Lock() - defer mvcc.mu.Unlock() + defer func() { + mvcc.mu.Unlock() + mvcc.deadlockDetector.CleanUp(startTS) + }() batch := &leveldb.Batch{} - err := rollbackKey(mvcc.db, batch, key, startTS) + startKey := mvccEncode(key, lockVer) + iter := newIterator(mvcc.db, &util.Range{ + Start: startKey, + }) + defer iter.Release() + + if iter.Valid() { + dec := lockDecoder{ + expectKey: key, + } + ok, err := dec.Decode(iter) + if err != nil { + return err + } + // If current transaction's lock exists. + if ok && dec.lock.startTS == startTS { + // If the lock has already outdated, clean up it. + if currentTS == 0 || uint64(oracle.ExtractPhysical(dec.lock.startTS))+dec.lock.ttl < uint64(oracle.ExtractPhysical(currentTS)) { + if err = rollbackLock(batch, dec.lock, key, startTS); err != nil { + return err + } + return mvcc.db.Write(batch, nil) + } + + // Otherwise, return a locked error with the TTL information. + return dec.lock.lockErr(key) + } + + // If current transaction's lock does not exist. + // If the commit information of the current transaction exist. + c, ok, err := getTxnCommitInfo(iter, key, startTS) + if err != nil { + return errors.Trace(err) + } + if ok { + // If the current transaction has already committed. + if c.valueType != typeRollback { + return ErrAlreadyCommitted(c.commitTS) + } + // If the current transaction has already rollbacked. + return nil + } + } + + // If current transaction is not prewritted before. + value := mvccValue{ + valueType: typeRollback, + startTS: startTS, + commitTS: startTS, + } + writeKey := mvccEncode(key, startTS) + writeValue, err := value.MarshalBinary() if err != nil { return errors.Trace(err) } - return mvcc.db.Write(batch, nil) + batch.Put(writeKey, writeValue) + return nil +} + +// TxnHeartBeat implements the MVCCStore interface. +func (mvcc *MVCCLevelDB) TxnHeartBeat(key []byte, startTS uint64, adviseTTL uint64) (uint64, error) { + mvcc.mu.Lock() + defer mvcc.mu.Unlock() + + startKey := mvccEncode(key, lockVer) + iter := newIterator(mvcc.db, &util.Range{ + Start: startKey, + }) + defer iter.Release() + + if iter.Valid() { + dec := lockDecoder{ + expectKey: key, + } + ok, err := dec.Decode(iter) + if err != nil { + return 0, errors.Trace(err) + } + if ok && dec.lock.startTS == startTS { + if !bytes.Equal(dec.lock.primary, key) { + return 0, errors.New("txnHeartBeat on non-primary key, the code should not run here") + } + + lock := dec.lock + batch := &leveldb.Batch{} + // Increase the ttl of this transaction. + if adviseTTL > lock.ttl { + lock.ttl = adviseTTL + writeKey := mvccEncode(key, lockVer) + writeValue, err := lock.MarshalBinary() + if err != nil { + return 0, errors.Trace(err) + } + batch.Put(writeKey, writeValue) + if err = mvcc.db.Write(batch, nil); err != nil { + return 0, errors.Trace(err) + } + } + return lock.ttl, nil + } + } + return 0, errors.New("lock doesn't exist") } // ScanLock implements the MVCCStore interface. diff --git a/store/mockstore/mocktikv/pd.go b/store/mockstore/mocktikv/pd.go index e79f1a471164a..5e69823fedd52 100644 --- a/store/mockstore/mocktikv/pd.go +++ b/store/mockstore/mocktikv/pd.go @@ -103,7 +103,7 @@ func (c *pdClient) GetAllStores(ctx context.Context, opts ...pd.GetStoreOption) } func (c *pdClient) UpdateGCSafePoint(ctx context.Context, safePoint uint64) (uint64, error) { - panic("unimplemented") + return 0, nil } func (c *pdClient) Close() { @@ -113,6 +113,10 @@ func (c *pdClient) ScatterRegion(ctx context.Context, regionID uint64) error { return nil } +func (c *pdClient) ScanRegions(ctx context.Context, key []byte, limit int) ([]*metapb.Region, []*metapb.Peer, error) { + return nil, nil, nil +} + func (c *pdClient) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) { return &pdpb.GetOperatorResponse{Status: pdpb.OperatorStatus_SUCCESS}, nil } diff --git a/store/mockstore/mocktikv/rpc.go b/store/mockstore/mocktikv/rpc.go old mode 100644 new mode 100755 index eb1afed758b3b..f566cae797e05 --- a/store/mockstore/mocktikv/rpc.go +++ b/store/mockstore/mocktikv/rpc.go @@ -18,10 +18,12 @@ import ( "context" "fmt" "io" + "math" "strconv" "time" "github.com/golang/protobuf/proto" + "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/coprocessor" @@ -56,6 +58,7 @@ func convertToKeyError(err error) *kvrpcpb.KeyError { PrimaryLock: locked.Primary, LockVersion: locked.StartTS, LockTtl: locked.TTL, + TxnSize: locked.TxnSize, }, } } @@ -75,6 +78,15 @@ func convertToKeyError(err error) *kvrpcpb.KeyError { }, } } + if dead, ok := errors.Cause(err).(*ErrDeadlock); ok { + return &kvrpcpb.KeyError{ + Deadlock: &kvrpcpb.Deadlock{ + LockTs: dead.LockTS, + LockKey: dead.LockKey, + DeadlockKeyHash: dead.DealockKeyHash, + }, + } + } if retryable, ok := errors.Cause(err).(ErrRetryable); ok { return &kvrpcpb.KeyError{ Retryable: retryable.Error(), @@ -245,14 +257,32 @@ func (h *rpcHandler) handleKvGet(req *kvrpcpb.GetRequest) *kvrpcpb.GetResponse { } func (h *rpcHandler) handleKvScan(req *kvrpcpb.ScanRequest) *kvrpcpb.ScanResponse { - if !h.checkKeyInRegion(req.GetStartKey()) { - panic("KvScan: startKey not in region") - } - endKey := h.endKey - if len(req.EndKey) > 0 && (len(endKey) == 0 || bytes.Compare(req.EndKey, endKey) < 0) { - endKey = req.EndKey + endKey := MvccKey(h.endKey).Raw() + var pairs []Pair + if !req.Reverse { + if !h.checkKeyInRegion(req.GetStartKey()) { + panic("KvScan: startKey not in region") + } + if len(req.EndKey) > 0 && (len(endKey) == 0 || bytes.Compare(NewMvccKey(req.EndKey), h.endKey) < 0) { + endKey = req.EndKey + } + pairs = h.mvccStore.Scan(req.GetStartKey(), endKey, int(req.GetLimit()), req.GetVersion(), h.isolationLevel) + } else { + // TiKV use range [end_key, start_key) for reverse scan. + // Should use the req.EndKey to check in region. + if !h.checkKeyInRegion(req.GetEndKey()) { + panic("KvScan: startKey not in region") + } + + // TiKV use range [end_key, start_key) for reverse scan. + // So the req.StartKey actually is the end_key. + if len(req.StartKey) > 0 && (len(endKey) == 0 || bytes.Compare(NewMvccKey(req.StartKey), h.endKey) < 0) { + endKey = req.StartKey + } + + pairs = h.mvccStore.ReverseScan(req.EndKey, endKey, int(req.GetLimit()), req.GetVersion(), h.isolationLevel) } - pairs := h.mvccStore.Scan(req.GetStartKey(), endKey, int(req.GetLimit()), req.GetVersion(), h.isolationLevel) + return &kvrpcpb.ScanResponse{ Pairs: convertToPbPairs(pairs), } @@ -264,7 +294,7 @@ func (h *rpcHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.Pre panic("KvPrewrite: key not in region") } } - errs := h.mvccStore.Prewrite(req.Mutations, req.PrimaryLock, req.GetStartVersion(), req.GetLockTtl()) + errs := h.mvccStore.Prewrite(req) return &kvrpcpb.PrewriteResponse{ Errors: convertToKeyErrors(errs), } @@ -273,15 +303,42 @@ func (h *rpcHandler) handleKvPrewrite(req *kvrpcpb.PrewriteRequest) *kvrpcpb.Pre func (h *rpcHandler) handleKvPessimisticLock(req *kvrpcpb.PessimisticLockRequest) *kvrpcpb.PessimisticLockResponse { for _, m := range req.Mutations { if !h.checkKeyInRegion(m.Key) { - panic("KvPrewrite: key not in region") + panic("KvPessimisticLock: key not in region") } } + startTS := req.StartVersion + regionID := req.Context.RegionId + h.cluster.handleDelay(startTS, regionID) errs := h.mvccStore.PessimisticLock(req.Mutations, req.PrimaryLock, req.GetStartVersion(), req.GetForUpdateTs(), req.GetLockTtl()) + + // TODO: remove this when implement sever side wait. + h.simulateServerSideWaitLock(errs) return &kvrpcpb.PessimisticLockResponse{ Errors: convertToKeyErrors(errs), } } +func (h *rpcHandler) simulateServerSideWaitLock(errs []error) { + for _, err := range errs { + if _, ok := err.(*ErrLocked); ok { + time.Sleep(time.Millisecond * 5) + break + } + } +} + +func (h *rpcHandler) handleKvPessimisticRollback(req *kvrpcpb.PessimisticRollbackRequest) *kvrpcpb.PessimisticRollbackResponse { + for _, key := range req.Keys { + if !h.checkKeyInRegion(key) { + panic("KvPessimisticRollback: key not in region") + } + } + errs := h.mvccStore.PessimisticRollback(req.Keys, req.StartVersion, req.ForUpdateTs) + return &kvrpcpb.PessimisticRollbackResponse{ + Errors: convertToKeyErrors(errs), + } +} + func (h *rpcHandler) handleKvCommit(req *kvrpcpb.CommitRequest) *kvrpcpb.CommitResponse { for _, k := range req.Keys { if !h.checkKeyInRegion(k) { @@ -301,7 +358,7 @@ func (h *rpcHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.Clean panic("KvCleanup: key not in region") } var resp kvrpcpb.CleanupResponse - err := h.mvccStore.Cleanup(req.Key, req.GetStartVersion()) + err := h.mvccStore.Cleanup(req.Key, req.GetStartVersion(), req.GetCurrentTs()) if err != nil { if commitTS, ok := errors.Cause(err).(ErrAlreadyCommitted); ok { resp.CommitVersion = uint64(commitTS) @@ -312,6 +369,19 @@ func (h *rpcHandler) handleKvCleanup(req *kvrpcpb.CleanupRequest) *kvrpcpb.Clean return &resp } +func (h *rpcHandler) handleTxnHeartBeat(req *kvrpcpb.TxnHeartBeatRequest) *kvrpcpb.TxnHeartBeatResponse { + if !h.checkKeyInRegion(req.PrimaryLock) { + panic("KvTxnHeartBeat: key not in region") + } + var resp kvrpcpb.TxnHeartBeatResponse + ttl, err := h.mvccStore.TxnHeartBeat(req.PrimaryLock, req.StartVersion, req.AdviseLockTtl) + if err != nil { + resp.Error = convertToKeyError(err) + } + resp.LockTtl = ttl + return &resp +} + func (h *rpcHandler) handleKvBatchGet(req *kvrpcpb.BatchGetRequest) *kvrpcpb.BatchGetResponse { for _, k := range req.Keys { if !h.checkKeyInRegion(k) { @@ -536,14 +606,25 @@ func (h *rpcHandler) handleKvRawScan(req *kvrpcpb.RawScanRequest) *kvrpcpb.RawSc } func (h *rpcHandler) handleSplitRegion(req *kvrpcpb.SplitRegionRequest) *kvrpcpb.SplitRegionResponse { - key := NewMvccKey(req.GetSplitKey()) - region, _ := h.cluster.GetRegionByKey(key) - if bytes.Equal(region.GetStartKey(), key) { - return &kvrpcpb.SplitRegionResponse{} + keys := req.GetSplitKeys() + resp := &kvrpcpb.SplitRegionResponse{Regions: make([]*metapb.Region, 0, len(keys)+1)} + for i, key := range keys { + k := NewMvccKey(key) + region, _ := h.cluster.GetRegionByKey(k) + if bytes.Equal(region.GetStartKey(), key) { + continue + } + if i == 0 { + // Set the leftmost region. + resp.Regions = append(resp.Regions, region) + } + newRegionID, newPeerIDs := h.cluster.AllocID(), h.cluster.AllocIDs(len(region.Peers)) + newRegion := h.cluster.SplitRaw(region.GetId(), newRegionID, k, newPeerIDs, newPeerIDs[0]) + // The mocktikv should return a deep copy of meta info to avoid data race + metaCloned := proto.Clone(newRegion.Meta) + resp.Regions = append(resp.Regions, metaCloned.(*metapb.Region)) } - newRegionID, newPeerIDs := h.cluster.AllocID(), h.cluster.AllocIDs(len(region.Peers)) - h.cluster.SplitRaw(region.GetId(), newRegionID, key, newPeerIDs, newPeerIDs[0]) - return &kvrpcpb.SplitRegionResponse{} + return resp } // RPCClient sends kv RPC calls to mock cluster. RPCClient mocks the behavior of @@ -601,6 +682,12 @@ func (c *RPCClient) checkArgs(ctx context.Context, addr string) (*rpcHandler, er // SendRequest sends a request to mock cluster. func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("RPCClient.SendRequest", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + failpoint.Inject("rpcServerBusy", func(val failpoint.Value) { if val.(bool) { failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{ServerIsBusy: &errorpb.ServerIsBusy{}})) @@ -631,6 +718,16 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R resp.Scan = handler.handleKvScan(r) case tikvrpc.CmdPrewrite: + failpoint.Inject("rpcPrewriteResult", func(val failpoint.Value) { + switch val.(string) { + case "notLeader": + failpoint.Return(&tikvrpc.Response{ + Type: tikvrpc.CmdPrewrite, + Prewrite: &kvrpcpb.PrewriteResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, + }, nil) + } + }) + r := req.Prewrite if err := handler.checkRequest(reqCtx, r.Size()); err != nil { resp.Prewrite = &kvrpcpb.PrewriteResponse{RegionError: err} @@ -644,6 +741,13 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return resp, nil } resp.PessimisticLock = handler.handleKvPessimisticLock(r) + case tikvrpc.CmdPessimisticRollback: + r := req.PessimisticRollback + if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + resp.PessimisticRollback = &kvrpcpb.PessimisticRollbackResponse{RegionError: err} + return resp, nil + } + resp.PessimisticRollback = handler.handleKvPessimisticRollback(r) case tikvrpc.CmdCommit: failpoint.Inject("rpcCommitResult", func(val failpoint.Value) { switch val.(string) { @@ -680,6 +784,13 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R return resp, nil } resp.Cleanup = handler.handleKvCleanup(r) + case tikvrpc.CmdTxnHeartBeat: + r := req.TxnHeartBeat + if err := handler.checkRequest(reqCtx, r.Size()); err != nil { + resp.TxnHeartBeat = &kvrpcpb.TxnHeartBeatResponse{RegionError: err} + return resp, nil + } + resp.TxnHeartBeat = handler.handleTxnHeartBeat(r) case tikvrpc.CmdBatchGet: r := req.BatchGet if err := handler.checkRequest(reqCtx, r.Size()); err != nil { @@ -856,8 +967,8 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R // DebugGetRegionProperties is for fast analyze in mock tikv. case tikvrpc.CmdDebugGetRegionProperties: r := req.DebugGetRegionProperties - region, _ := c.Cluster.GetRegionByID(r.RegionId) - scanResp := handler.handleKvScan(&kvrpcpb.ScanRequest{StartKey: region.StartKey, EndKey: region.EndKey}) + region, _ := c.Cluster.GetRegion(r.RegionId) + scanResp := handler.handleKvScan(&kvrpcpb.ScanRequest{StartKey: MvccKey(region.StartKey).Raw(), EndKey: MvccKey(region.EndKey).Raw(), Version: math.MaxUint64, Limit: math.MaxUint32}) resp.DebugGetRegionProperties = &debugpb.GetRegionPropertiesResponse{ Props: []*debugpb.Property{{ Name: "mvcc.num_rows", diff --git a/store/tikv/2pc.go b/store/tikv/2pc.go index 7f62263b1bfba..1e8b710c7c504 100644 --- a/store/tikv/2pc.go +++ b/store/tikv/2pc.go @@ -21,7 +21,9 @@ import ( "sync" "sync/atomic" "time" + "unsafe" + "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" "github.com/pingcap/failpoint" pb "github.com/pingcap/kvproto/pkg/kvrpcpb" @@ -30,6 +32,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx/binloginfo" + "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/store/tikv/tikvrpc" "github.com/pingcap/tidb/tablecodec" "github.com/pingcap/tidb/util/execdetails" @@ -38,41 +41,59 @@ import ( "go.uber.org/zap" ) -type twoPhaseCommitAction int +type twoPhaseCommitAction interface { + handleSingleBatch(*twoPhaseCommitter, *Backoffer, batchKeys) error + String() string +} -const ( - actionPrewrite twoPhaseCommitAction = 1 + iota - actionCommit - actionCleanup - actionPessimisticLock +type actionPrewrite struct{} +type actionCommit struct{} +type actionCleanup struct{} +type actionPessimisticLock struct{ killed *uint32 } +type actionPessimisticRollback struct{} + +var ( + _ twoPhaseCommitAction = actionPrewrite{} + _ twoPhaseCommitAction = actionCommit{} + _ twoPhaseCommitAction = actionCleanup{} + _ twoPhaseCommitAction = actionPessimisticLock{} + _ twoPhaseCommitAction = actionPessimisticRollback{} ) var ( tikvSecondaryLockCleanupFailureCounterCommit = metrics.TiKVSecondaryLockCleanupFailureCounter.WithLabelValues("commit") tikvSecondaryLockCleanupFailureCounterRollback = metrics.TiKVSecondaryLockCleanupFailureCounter.WithLabelValues("rollback") + tiKVTxnHeartBeatHistogramOK = metrics.TiKVTxnHeartBeatHistogram.WithLabelValues("ok") + tiKVTxnHeartBeatHistogramError = metrics.TiKVTxnHeartBeatHistogram.WithLabelValues("err") ) // Global variable set by config file. var ( - PessimisticLockTTL uint64 + PessimisticLockTTL uint64 = 15000 // 15s ~ 40s ) -func (ca twoPhaseCommitAction) String() string { - switch ca { - case actionPrewrite: - return "prewrite" - case actionCommit: - return "commit" - case actionCleanup: - return "cleanup" - case actionPessimisticLock: - return "pessimistic_lock" - } - return "unknown" +func (actionPrewrite) String() string { + return "prewrite" +} + +func (actionCommit) String() string { + return "commit" +} + +func (actionCleanup) String() string { + return "cleanup" +} + +func (actionPessimisticLock) String() string { + return "pessimistic_lock" +} + +func (actionPessimisticRollback) String() string { + return "pessimistic_rollback" } -// MetricsTag returns detail tag for metrics. -func (ca twoPhaseCommitAction) MetricsTag() string { +// metricsTag returns detail tag for metrics. +func metricsTag(ca twoPhaseCommitAction) string { return "2pc_" + ca.String() } @@ -85,25 +106,31 @@ type twoPhaseCommitter struct { mutations map[string]*mutationEx lockTTL uint64 commitTS uint64 - mu struct { - sync.RWMutex - committed bool - undeterminedErr error // undeterminedErr saves the rpc error we encounter when commit primary key. - } - priority pb.CommandPri - syncLog bool - connID uint64 // connID is used for log. - cleanWg sync.WaitGroup + priority pb.CommandPri + connID uint64 // connID is used for log. + cleanWg sync.WaitGroup // maxTxnTimeUse represents max time a Txn may use (in ms) from its startTS to commitTS. // We use it to guarantee GC worker will not influence any active txn. The value // should be less than GC life time. - maxTxnTimeUse uint64 - detail *execdetails.CommitDetails + maxTxnTimeUse uint64 + detail unsafe.Pointer + primaryKey []byte + forUpdateTS uint64 + pessimisticTTL uint64 + + mu struct { + sync.RWMutex + undeterminedErr error // undeterminedErr saves the rpc error we encounter when commit primary key. + committed bool + } + syncLog bool // For pessimistic transaction isPessimistic bool - primaryKey []byte - forUpdateTS uint64 isFirstLock bool + // regionTxnSize stores the number of keys involved in each region + regionTxnSize map[uint64]int + // Used by pessimistic transaction and large transaction. + ttlManager } type mutationEx struct { @@ -115,13 +142,57 @@ type mutationEx struct { // newTwoPhaseCommitter creates a twoPhaseCommitter. func newTwoPhaseCommitter(txn *tikvTxn, connID uint64) (*twoPhaseCommitter, error) { return &twoPhaseCommitter{ - store: txn.store, - txn: txn, - startTS: txn.StartTS(), - connID: connID, + store: txn.store, + txn: txn, + startTS: txn.StartTS(), + connID: connID, + regionTxnSize: map[uint64]int{}, + ttlManager: ttlManager{ + ch: make(chan struct{}), + }, }, nil } +func sendTxnHeartBeat(bo *Backoffer, store *tikvStore, primary []byte, startTS, ttl uint64) (uint64, error) { + req := &tikvrpc.Request{ + Type: tikvrpc.CmdTxnHeartBeat, + TxnHeartBeat: &pb.TxnHeartBeatRequest{ + PrimaryLock: primary, + StartVersion: startTS, + AdviseLockTtl: ttl, + }, + } + for { + loc, err := store.GetRegionCache().LocateKey(bo, primary) + if err != nil { + return 0, errors.Trace(err) + } + resp, err := store.SendReq(bo, req, loc.Region, readTimeoutShort) + if err != nil { + return 0, errors.Trace(err) + } + regionErr, err := resp.GetRegionError() + if err != nil { + return 0, errors.Trace(err) + } + if regionErr != nil { + err = bo.Backoff(BoRegionMiss, errors.New(regionErr.String())) + if err != nil { + return 0, errors.Trace(err) + } + continue + } + if resp.TxnHeartBeat == nil { + return 0, errors.Trace(ErrBodyMissing) + } + cmdResp := resp.TxnHeartBeat + if keyErr := cmdResp.GetError(); keyErr != nil { + return 0, errors.Errorf("txn %d heartbeat fail, primary key = %v, err = %s", startTS, primary, keyErr.Abort) + } + return cmdResp.GetLockTtl(), nil + } +} + func (c *twoPhaseCommitter) initKeysAndMutations() error { var ( keys [][]byte @@ -265,7 +336,7 @@ func (c *twoPhaseCommitter) initKeysAndMutations() error { c.lockTTL = txnLockTTL(txn.startTime, size) c.priority = getTxnPriority(txn) c.syncLog = getTxnSyncLog(txn) - c.detail = commitDetail + c.setDetail(commitDetail) return nil } @@ -309,18 +380,24 @@ func (c *twoPhaseCommitter) doActionOnKeys(bo *Backoffer, action twoPhaseCommitA if len(keys) == 0 { return nil } - groups, firstRegion, err := c.store.regionCache.GroupKeysByRegion(bo, keys) + groups, firstRegion, err := c.store.regionCache.GroupKeysByRegion(bo, keys, nil) if err != nil { return errors.Trace(err) } - metrics.TiKVTxnRegionsNumHistogram.WithLabelValues(action.MetricsTag()).Observe(float64(len(groups))) + metrics.TiKVTxnRegionsNumHistogram.WithLabelValues(metricsTag(action)).Observe(float64(len(groups))) var batches []batchKeys var sizeFunc = c.keySize - if action == actionPrewrite { + if _, ok := action.(actionPrewrite); ok { + // Do not update regionTxnSize on retries. They are not used when building a PrewriteRequest. + if len(bo.errors) == 0 { + for region, keys := range groups { + c.regionTxnSize[region.id] = len(keys) + } + } sizeFunc = c.keyValueSize - atomic.AddInt32(&c.detail.PrewriteRegionNum, int32(len(groups))) + atomic.AddInt32(&c.getDetail().PrewriteRegionNum, int32(len(groups))) } // Make sure the group that contains primary key goes first. batches = appendBatchBySize(batches, firstRegion, groups[firstRegion], sizeFunc, txnCommitBatchSize) @@ -330,7 +407,9 @@ func (c *twoPhaseCommitter) doActionOnKeys(bo *Backoffer, action twoPhaseCommitA } firstIsPrimary := bytes.Equal(keys[0], c.primary()) - if firstIsPrimary && (action == actionCommit || action == actionCleanup) { + _, actionIsCommit := action.(actionCommit) + _, actionIsCleanup := action.(actionCleanup) + if firstIsPrimary && (actionIsCommit || actionIsCleanup) { // primary should be committed/cleanup first err = c.doActionOnBatches(bo, action, batches[:1]) if err != nil { @@ -338,7 +417,7 @@ func (c *twoPhaseCommitter) doActionOnKeys(bo *Backoffer, action twoPhaseCommitA } batches = batches[1:] } - if action == actionCommit { + if actionIsCommit { // Commit secondary batches in background goroutine to reduce latency. // The backoffer instance is created outside of the goroutine to avoid // potencial data race in unit test since `CommitMaxBackoff` will be updated @@ -365,21 +444,10 @@ func (c *twoPhaseCommitter) doActionOnBatches(bo *Backoffer, action twoPhaseComm if len(batches) == 0 { return nil } - var singleBatchActionFunc func(bo *Backoffer, batch batchKeys) error - switch action { - case actionPrewrite: - singleBatchActionFunc = c.prewriteSingleBatch - case actionCommit: - singleBatchActionFunc = c.commitSingleBatch - case actionCleanup: - singleBatchActionFunc = c.cleanupSingleBatch - case actionPessimisticLock: - singleBatchActionFunc = c.pessimisticLockSingleBatch - } if len(batches) == 1 { - e := singleBatchActionFunc(bo, batches[0]) + e := action.handleSingleBatch(c, bo, batches[0]) if e != nil { - logutil.Logger(context.Background()).Debug("2PC doActionOnBatches failed", + logutil.Logger(bo.ctx).Debug("2PC doActionOnBatches failed", zap.Uint64("conn", c.connID), zap.Stringer("action type", action), zap.Error(e), @@ -391,7 +459,7 @@ func (c *twoPhaseCommitter) doActionOnBatches(bo *Backoffer, action twoPhaseComm // For prewrite, stop sending other requests after receiving first error. backoffer := bo var cancel context.CancelFunc - if action == actionPrewrite { + if _, ok := action.(actionPrewrite); ok { backoffer, cancel = bo.Fork() defer cancel() } @@ -402,7 +470,8 @@ func (c *twoPhaseCommitter) doActionOnBatches(bo *Backoffer, action twoPhaseComm batch := batch1 go func() { - if action == actionCommit { + var singleBatchBackoffer *Backoffer + if _, ok := action.(actionCommit); ok { // Because the secondary batches of the commit actions are implemented to be // committed asynchronously in background goroutines, we should not // fork a child context and call cancel() while the foreground goroutine exits. @@ -410,12 +479,22 @@ func (c *twoPhaseCommitter) doActionOnBatches(bo *Backoffer, action twoPhaseComm // Here we makes a new clone of the original backoffer for this goroutine // exclusively to avoid the data race when using the same backoffer // in concurrent goroutines. - singleBatchBackoffer := backoffer.Clone() - ch <- singleBatchActionFunc(singleBatchBackoffer, batch) + singleBatchBackoffer = backoffer.Clone() } else { - singleBatchBackoffer, singleBatchCancel := backoffer.Fork() + var singleBatchCancel context.CancelFunc + singleBatchBackoffer, singleBatchCancel = backoffer.Fork() defer singleBatchCancel() - ch <- singleBatchActionFunc(singleBatchBackoffer, batch) + } + beforeSleep := singleBatchBackoffer.totalSleep + ch <- action.handleSingleBatch(c, singleBatchBackoffer, batch) + commitDetail := c.getDetail() + if commitDetail != nil { // lock operations of pessimistic-txn will let commitDetail be nil + if delta := singleBatchBackoffer.totalSleep - beforeSleep; delta > 0 { + atomic.AddInt64(&commitDetail.CommitBackoffTime, int64(singleBatchBackoffer.totalSleep-beforeSleep)*int64(time.Millisecond)) + commitDetail.Mu.Lock() + commitDetail.Mu.BackoffTypes = append(commitDetail.Mu.BackoffTypes, singleBatchBackoffer.types...) + commitDetail.Mu.Unlock() + } } }() } @@ -455,7 +534,7 @@ func (c *twoPhaseCommitter) keySize(key []byte) int { return len(key) } -func (c *twoPhaseCommitter) buildPrewriteRequest(batch batchKeys) *tikvrpc.Request { +func (c *twoPhaseCommitter) buildPrewriteRequest(batch batchKeys, txnSize uint64) *tikvrpc.Request { mutations := make([]*pb.Mutation, len(batch.keys)) var isPessimisticLock []bool if c.isPessimistic { @@ -476,6 +555,8 @@ func (c *twoPhaseCommitter) buildPrewriteRequest(batch batchKeys) *tikvrpc.Reque StartVersion: c.startTS, LockTtl: c.lockTTL, IsPessimisticLock: isPessimisticLock, + ForUpdateTs: c.forUpdateTS, + TxnSize: txnSize, }, Context: pb.Context{ Priority: c.priority, @@ -484,8 +565,15 @@ func (c *twoPhaseCommitter) buildPrewriteRequest(batch batchKeys) *tikvrpc.Reque } } -func (c *twoPhaseCommitter) prewriteSingleBatch(bo *Backoffer, batch batchKeys) error { - req := c.buildPrewriteRequest(batch) +func (actionPrewrite) handleSingleBatch(c *twoPhaseCommitter, bo *Backoffer, batch batchKeys) error { + txnSize := uint64(c.regionTxnSize[batch.region.id]) + // When we retry because of a region miss, we don't know the transaction size. We set the transaction size here + // to MaxUint64 to avoid unexpected "resolve lock lite". + if len(bo.errors) > 0 { + txnSize = math.MaxUint64 + } + + req := c.buildPrewriteRequest(batch, txnSize) for { resp, err := c.store.SendReq(bo, req, batch.region, readTimeoutShort) if err != nil { @@ -541,7 +629,7 @@ func (c *twoPhaseCommitter) prewriteSingleBatch(bo *Backoffer, batch batchKeys) if err != nil { return errors.Trace(err) } - atomic.AddInt64(&c.detail.ResolveLockTime, int64(time.Since(start))) + atomic.AddInt64(&c.getDetail().ResolveLockTime, int64(time.Since(start))) if msBeforeExpired > 0 { err = bo.BackoffWithMaxSleep(BoTxnLock, int(msBeforeExpired), errors.Errorf("2PC prewrite lockedKeys: %d", len(locks))) if err != nil { @@ -551,7 +639,81 @@ func (c *twoPhaseCommitter) prewriteSingleBatch(bo *Backoffer, batch batchKeys) } } -func (c *twoPhaseCommitter) pessimisticLockSingleBatch(bo *Backoffer, batch batchKeys) error { +type ttlManagerState uint32 + +const ( + stateUninitialized ttlManagerState = iota + stateRunning + stateClosed +) + +type ttlManager struct { + state ttlManagerState + ch chan struct{} +} + +func (tm *ttlManager) run(c *twoPhaseCommitter) { + // Run only once. + if !atomic.CompareAndSwapUint32((*uint32)(&tm.state), uint32(stateUninitialized), uint32(stateRunning)) { + return + } + go tm.keepAlive(c) +} + +func (tm *ttlManager) close() { + if !atomic.CompareAndSwapUint32((*uint32)(&tm.state), uint32(stateRunning), uint32(stateClosed)) { + return + } + close(tm.ch) +} + +func (tm *ttlManager) keepAlive(c *twoPhaseCommitter) { + // Ticker is set to 1/3 of the PessimisticLockTTL. + ticker := time.NewTicker(time.Duration(PessimisticLockTTL) * time.Millisecond / 3) + defer ticker.Stop() + for { + select { + case <-tm.ch: + return + case <-ticker.C: + bo := NewBackoffer(context.Background(), pessimisticLockMaxBackoff) + now, err := c.store.GetOracle().GetTimestamp(bo.ctx) + if err != nil { + err1 := bo.Backoff(BoPDRPC, err) + if err1 != nil { + logutil.Logger(context.Background()).Warn("keepAlive get tso fail", + zap.Error(err)) + return + } + continue + } + + uptime := uint64(oracle.ExtractPhysical(now) - oracle.ExtractPhysical(c.startTS)) + const c10min = 10 * 60 * 1000 + if uptime > c10min { + // Set a 10min maximum lifetime for the ttlManager, so when something goes wrong + // the key will not be locked forever. + logutil.Logger(context.Background()).Info("ttlManager live up to its lifetime", + zap.Uint64("txnStartTS", c.startTS)) + return + } + + newTTL := uptime + PessimisticLockTTL + startTime := time.Now() + _, err = sendTxnHeartBeat(bo, c.store, c.primary(), c.startTS, newTTL) + if err != nil { + tiKVTxnHeartBeatHistogramError.Observe(time.Since(startTime).Seconds()) + logutil.Logger(context.Background()).Warn("send TxnHeartBeat failed", + zap.Error(err), + zap.Uint64("txnStartTS", c.startTS)) + return + } + tiKVTxnHeartBeatHistogramOK.Observe(time.Since(startTime).Seconds()) + } + } +} + +func (action actionPessimisticLock) handleSingleBatch(c *twoPhaseCommitter, bo *Backoffer, batch batchKeys) error { mutations := make([]*pb.Mutation, len(batch.keys)) for i, k := range batch.keys { mut := &pb.Mutation{ @@ -572,7 +734,7 @@ func (c *twoPhaseCommitter) pessimisticLockSingleBatch(bo *Backoffer, batch batc PrimaryLock: c.primary(), StartVersion: c.startTS, ForUpdateTs: c.forUpdateTS, - LockTtl: PessimisticLockTTL, + LockTtl: c.pessimisticTTL, IsFirstLock: c.isFirstLock, }, Context: pb.Context{ @@ -594,7 +756,7 @@ func (c *twoPhaseCommitter) pessimisticLockSingleBatch(bo *Backoffer, batch batc if err != nil { return errors.Trace(err) } - err = c.pessimisticLockKeys(bo, batch.keys) + err = c.pessimisticLockKeys(bo, action.killed, batch.keys) return errors.Trace(err) } lockResp := resp.PessimisticLock @@ -616,6 +778,9 @@ func (c *twoPhaseCommitter) pessimisticLockSingleBatch(bo *Backoffer, batch batc } return errors.Trace(conditionPair.Err()) } + if deadlock := keyErr.Deadlock; deadlock != nil { + return &ErrDeadlock{Deadlock: deadlock} + } // Extract lock from key error lock, err1 := extractLockFromKeyErr(keyErr) @@ -624,16 +789,53 @@ func (c *twoPhaseCommitter) pessimisticLockSingleBatch(bo *Backoffer, batch batc } locks = append(locks, lock) } - msBeforeExpired, err := c.store.lockResolver.ResolveLocks(bo, locks) + // Because we already waited on tikv, no need to Backoff here. + _, err = c.store.lockResolver.ResolveLocks(bo, locks) if err != nil { return errors.Trace(err) } - if msBeforeExpired > 0 { - err = bo.BackoffWithMaxSleep(BoTxnLock, int(msBeforeExpired), errors.Errorf("2PC prewrite lockedKeys: %d", len(locks))) + + // Handle the killed flag when waiting for the pessimistic lock. + // When a txn runs into LockKeys() and backoff here, it has no chance to call + // executor.Next() and check the killed flag. + if action.killed != nil { + // Do not reset the killed flag here! + // actionPessimisticLock runs on each region parallelly, we have to consider that + // the error may be dropped. + if atomic.LoadUint32(action.killed) == 1 { + return ErrQueryInterrupted + } + } + } +} + +func (actionPessimisticRollback) handleSingleBatch(c *twoPhaseCommitter, bo *Backoffer, batch batchKeys) error { + req := &tikvrpc.Request{ + Type: tikvrpc.CmdPessimisticRollback, + PessimisticRollback: &pb.PessimisticRollbackRequest{ + StartVersion: c.startTS, + ForUpdateTs: c.forUpdateTS, + Keys: batch.keys, + }, + } + for { + resp, err := c.store.SendReq(bo, req, batch.region, readTimeoutShort) + if err != nil { + return errors.Trace(err) + } + regionErr, err := resp.GetRegionError() + if err != nil { + return errors.Trace(err) + } + if regionErr != nil { + err = bo.Backoff(BoRegionMiss, errors.New(regionErr.String())) if err != nil { return errors.Trace(err) } + err = c.pessimisticRollbackKeys(bo, batch.keys) + return errors.Trace(err) } + return nil } } @@ -661,6 +863,14 @@ func kvPriorityToCommandPri(pri int) pb.CommandPri { return pb.CommandPri_Normal } +func (c *twoPhaseCommitter) setDetail(d *execdetails.CommitDetails) { + atomic.StorePointer(&c.detail, unsafe.Pointer(d)) +} + +func (c *twoPhaseCommitter) getDetail() *execdetails.CommitDetails { + return (*execdetails.CommitDetails)(atomic.LoadPointer(&c.detail)) +} + func (c *twoPhaseCommitter) setUndeterminedErr(err error) { c.mu.Lock() defer c.mu.Unlock() @@ -673,7 +883,7 @@ func (c *twoPhaseCommitter) getUndeterminedErr() error { return c.mu.undeterminedErr } -func (c *twoPhaseCommitter) commitSingleBatch(bo *Backoffer, batch batchKeys) error { +func (actionCommit) handleSingleBatch(c *twoPhaseCommitter, bo *Backoffer, batch batchKeys) error { req := &tikvrpc.Request{ Type: tikvrpc.CmdCommit, Commit: &pb.CommitRequest{ @@ -686,7 +896,6 @@ func (c *twoPhaseCommitter) commitSingleBatch(bo *Backoffer, batch batchKeys) er SyncLog: c.syncLog, }, } - req.Context.Priority = c.priority sender := NewRegionRequestSender(c.store.regionCache, c.store.client) resp, err := sender.SendReq(bo, req, batch.region, readTimeoutShort) @@ -753,7 +962,7 @@ func (c *twoPhaseCommitter) commitSingleBatch(bo *Backoffer, batch batchKeys) er return nil } -func (c *twoPhaseCommitter) cleanupSingleBatch(bo *Backoffer, batch batchKeys) error { +func (actionCleanup) handleSingleBatch(c *twoPhaseCommitter, bo *Backoffer, batch batchKeys) error { req := &tikvrpc.Request{ Type: tikvrpc.CmdBatchRollback, BatchRollback: &pb.BatchRollbackRequest{ @@ -792,19 +1001,35 @@ func (c *twoPhaseCommitter) cleanupSingleBatch(bo *Backoffer, batch batchKeys) e } func (c *twoPhaseCommitter) prewriteKeys(bo *Backoffer, keys [][]byte) error { - return c.doActionOnKeys(bo, actionPrewrite, keys) + if span := opentracing.SpanFromContext(bo.ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("twoPhaseCommitter.prewriteKeys", opentracing.ChildOf(span.Context())) + defer span1.Finish() + bo.ctx = opentracing.ContextWithSpan(bo.ctx, span1) + } + + return c.doActionOnKeys(bo, actionPrewrite{}, keys) } func (c *twoPhaseCommitter) commitKeys(bo *Backoffer, keys [][]byte) error { - return c.doActionOnKeys(bo, actionCommit, keys) + if span := opentracing.SpanFromContext(bo.ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("twoPhaseCommitter.commitKeys", opentracing.ChildOf(span.Context())) + defer span1.Finish() + bo.ctx = opentracing.ContextWithSpan(bo.ctx, span1) + } + + return c.doActionOnKeys(bo, actionCommit{}, keys) } func (c *twoPhaseCommitter) cleanupKeys(bo *Backoffer, keys [][]byte) error { - return c.doActionOnKeys(bo, actionCleanup, keys) + return c.doActionOnKeys(bo, actionCleanup{}, keys) } -func (c *twoPhaseCommitter) pessimisticLockKeys(bo *Backoffer, keys [][]byte) error { - return c.doActionOnKeys(bo, actionPessimisticLock, keys) +func (c *twoPhaseCommitter) pessimisticLockKeys(bo *Backoffer, killed *uint32, keys [][]byte) error { + return c.doActionOnKeys(bo, actionPessimisticLock{killed}, keys) +} + +func (c *twoPhaseCommitter) pessimisticRollbackKeys(bo *Backoffer, keys [][]byte) error { + return c.doActionOnKeys(bo, actionPessimisticRollback{}, keys) } func (c *twoPhaseCommitter) executeAndWriteFinishBinlog(ctx context.Context) error { @@ -849,8 +1074,14 @@ func (c *twoPhaseCommitter) execute(ctx context.Context) error { prewriteBo := NewBackoffer(ctx, prewriteMaxBackoff).WithVars(c.txn.vars) start := time.Now() err := c.prewriteKeys(prewriteBo, c.keys) - c.detail.PrewriteTime = time.Since(start) - c.detail.TotalBackoffTime += time.Duration(prewriteBo.totalSleep) * time.Millisecond + commitDetail := c.getDetail() + commitDetail.PrewriteTime = time.Since(start) + if prewriteBo.totalSleep > 0 { + atomic.AddInt64(&commitDetail.CommitBackoffTime, int64(prewriteBo.totalSleep)*int64(time.Millisecond)) + commitDetail.Mu.Lock() + commitDetail.Mu.BackoffTypes = append(commitDetail.Mu.BackoffTypes, prewriteBo.types...) + commitDetail.Mu.Unlock() + } if binlogChan != nil { binlogErr := <-binlogChan if binlogErr != nil { @@ -872,7 +1103,7 @@ func (c *twoPhaseCommitter) execute(ctx context.Context) error { zap.Uint64("txnStartTS", c.startTS)) return errors.Trace(err) } - c.detail.GetCommitTsTime = time.Since(start) + commitDetail.GetCommitTsTime = time.Since(start) // check commitTS if commitTS <= c.startTS { @@ -901,8 +1132,13 @@ func (c *twoPhaseCommitter) execute(ctx context.Context) error { start = time.Now() commitBo := NewBackoffer(ctx, CommitMaxBackoff).WithVars(c.txn.vars) err = c.commitKeys(commitBo, c.keys) - c.detail.CommitTime = time.Since(start) - c.detail.TotalBackoffTime += time.Duration(commitBo.totalSleep) * time.Millisecond + commitDetail.CommitTime = time.Since(start) + if commitBo.totalSleep > 0 { + atomic.AddInt64(&commitDetail.CommitBackoffTime, int64(commitBo.totalSleep)*int64(time.Millisecond)) + commitDetail.Mu.Lock() + commitDetail.Mu.BackoffTypes = append(commitDetail.Mu.BackoffTypes, commitBo.types...) + commitDetail.Mu.Unlock() + } if err != nil { if undeterminedErr := c.getUndeterminedErr(); undeterminedErr != nil { logutil.Logger(ctx).Error("2PC commit result undetermined", diff --git a/store/tikv/2pc_fail_test.go b/store/tikv/2pc_fail_test.go index 49624f5be5197..b76968f9ffc6b 100644 --- a/store/tikv/2pc_fail_test.go +++ b/store/tikv/2pc_fail_test.go @@ -116,3 +116,27 @@ func (s *testCommitterSuite) TestFailCommitTimeout(c *C) { c.Assert(err, IsNil) c.Assert(len(value), Greater, 0) } + +// TestFailPrewriteRegionError tests data race does not happen on retries +func (s *testCommitterSuite) TestFailPrewriteRegionError(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/mockstore/mocktikv/rpcPrewriteResult", `return("notLeader")`), IsNil) + defer func() { + c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/mockstore/mocktikv/rpcPrewriteResult"), IsNil) + }() + + txn := s.begin(c) + + // Set the value big enough to create many batches. This increases the chance of data races. + var bigVal [18000]byte + for i := 0; i < 1000; i++ { + err := txn.Set([]byte{byte(i)}, bigVal[:]) + c.Assert(err, IsNil) + } + + committer, err := newTwoPhaseCommitterWithInit(txn, 1) + c.Assert(err, IsNil) + + ctx := context.Background() + err = committer.prewriteKeys(NewBackoffer(ctx, 1000), committer.keys) + c.Assert(err, NotNil) +} diff --git a/store/tikv/2pc_test.go b/store/tikv/2pc_test.go index 42f4c5f0cfdd0..7ecb6bb1994ba 100644 --- a/store/tikv/2pc_test.go +++ b/store/tikv/2pc_test.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/mockstore/mocktikv" + "github.com/pingcap/tidb/store/tikv/oracle" "github.com/pingcap/tidb/store/tikv/tikvrpc" ) @@ -442,6 +443,40 @@ func (s *testCommitterSuite) TestWrittenKeysOnConflict(c *C) { c.Assert(totalTime, Less, time.Millisecond*200) } +func (s *testCommitterSuite) TestPrewriteTxnSize(c *C) { + // Prepare two regions first: (, 100) and [100, ) + region, _ := s.cluster.GetRegionByKey([]byte{50}) + newRegionID := s.cluster.AllocID() + newPeerID := s.cluster.AllocID() + s.cluster.Split(region.Id, newRegionID, []byte{100}, []uint64{newPeerID}, newPeerID) + + txn := s.begin(c) + var val [1024]byte + for i := byte(50); i < 120; i++ { + err := txn.Set([]byte{i}, val[:]) + c.Assert(err, IsNil) + } + + commiter, err := newTwoPhaseCommitterWithInit(txn, 1) + c.Assert(err, IsNil) + + ctx := context.Background() + err = commiter.prewriteKeys(NewBackoffer(ctx, prewriteMaxBackoff), commiter.keys) + c.Assert(err, IsNil) + + // Check the written locks in the first region (50 keys) + for i := byte(50); i < 100; i++ { + lock := s.getLockInfo(c, []byte{i}) + c.Assert(int(lock.TxnSize), Equals, 50) + } + + // Check the written locks in the second region (20 keys) + for i := byte(100); i < 120; i++ { + lock := s.getLockInfo(c, []byte{i}) + c.Assert(int(lock.TxnSize), Equals, 20) + } +} + func (s *testCommitterSuite) TestPessimisticPrewriteRequest(c *C) { // This test checks that the isPessimisticLock field is set in the request even when no keys are pessimistic lock. txn := s.begin(c) @@ -450,9 +485,103 @@ func (s *testCommitterSuite) TestPessimisticPrewriteRequest(c *C) { c.Assert(err, IsNil) commiter, err := newTwoPhaseCommitterWithInit(txn, 0) c.Assert(err, IsNil) + commiter.forUpdateTS = 100 var batch batchKeys batch.keys = append(batch.keys, []byte("t1")) batch.region = RegionVerID{1, 1, 1} - req := commiter.buildPrewriteRequest(batch) + req := commiter.buildPrewriteRequest(batch, 1) c.Assert(len(req.Prewrite.IsPessimisticLock), Greater, 0) + c.Assert(req.Prewrite.ForUpdateTs, Equals, uint64(100)) +} + +func (s *testCommitterSuite) TestUnsetPrimaryKey(c *C) { + // This test checks that the isPessimisticLock field is set in the request even when no keys are pessimistic lock. + key := kv.Key("key") + txn := s.begin(c) + c.Assert(txn.Set(key, key), IsNil) + c.Assert(txn.Commit(context.Background()), IsNil) + + txn = s.begin(c) + txn.SetOption(kv.Pessimistic, true) + txn.SetOption(kv.PresumeKeyNotExists, nil) + _, _ = txn.us.Get(key) + c.Assert(txn.Set(key, key), IsNil) + txn.DelOption(kv.PresumeKeyNotExists) + err := txn.LockKeys(context.Background(), nil, txn.startTS, key) + c.Assert(err, NotNil) + c.Assert(txn.Delete(key), IsNil) + key2 := kv.Key("key2") + c.Assert(txn.Set(key2, key2), IsNil) + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) +} + +func (s *testCommitterSuite) TestPessimisticLockedKeysDedup(c *C) { + txn := s.begin(c) + txn.SetOption(kv.Pessimistic, true) + err := txn.LockKeys(context.Background(), nil, 100, kv.Key("abc"), kv.Key("def")) + c.Assert(err, IsNil) + err = txn.LockKeys(context.Background(), nil, 100, kv.Key("abc"), kv.Key("def")) + c.Assert(err, IsNil) + c.Assert(txn.lockKeys, HasLen, 2) +} + +func (s *testCommitterSuite) TestPessimisticTTL(c *C) { + key := kv.Key("key") + txn := s.begin(c) + txn.SetOption(kv.Pessimistic, true) + time.Sleep(time.Millisecond * 100) + err := txn.LockKeys(context.Background(), nil, txn.startTS, key) + c.Assert(err, IsNil) + time.Sleep(time.Millisecond * 100) + key2 := kv.Key("key2") + err = txn.LockKeys(context.Background(), nil, txn.startTS, key2) + c.Assert(err, IsNil) + lockInfo := s.getLockInfo(c, key) + elapsedTTL := lockInfo.LockTtl - PessimisticLockTTL + c.Assert(elapsedTTL, GreaterEqual, uint64(100)) + + lr := newLockResolver(s.store) + bo := NewBackoffer(context.Background(), getMaxBackoff) + status, err := lr.getTxnStatus(bo, txn.startTS, key2, txn.startTS) + c.Assert(err, IsNil) + c.Assert(status.ttl, Equals, lockInfo.LockTtl) + + // Check primary lock TTL is auto increasing while the pessimistic txn is ongoing. + for i := 0; i < 50; i++ { + lockInfoNew := s.getLockInfo(c, key) + if lockInfoNew.LockTtl > lockInfo.LockTtl { + currentTS, err := lr.store.GetOracle().GetTimestamp(bo.ctx) + c.Assert(err, IsNil) + // Check that the TTL is update to a reasonable range. + expire := oracle.ExtractPhysical(txn.startTS) + int64(lockInfoNew.LockTtl) + now := oracle.ExtractPhysical(currentTS) + c.Assert(expire > now, IsTrue) + c.Assert(uint64(expire-now) <= PessimisticLockTTL, IsTrue) + return + } + time.Sleep(100 * time.Millisecond) + } + c.Assert(false, IsTrue, Commentf("update pessimistic ttl fail")) +} + +func (s *testCommitterSuite) getLockInfo(c *C, key []byte) *kvrpcpb.LockInfo { + txn := s.begin(c) + err := txn.Set(key, key) + c.Assert(err, IsNil) + commiter, err := newTwoPhaseCommitterWithInit(txn, 1) + c.Assert(err, IsNil) + bo := NewBackoffer(context.Background(), getMaxBackoff) + loc, err := s.store.regionCache.LocateKey(bo, key) + c.Assert(err, IsNil) + batch := batchKeys{region: loc.Region, keys: [][]byte{key}} + req := commiter.buildPrewriteRequest(batch, 1) + resp, err := s.store.SendReq(bo, req, loc.Region, readTimeoutShort) + c.Assert(err, IsNil) + c.Assert(resp.Prewrite, NotNil) + keyErrs := resp.Prewrite.Errors + c.Assert(keyErrs, HasLen, 1) + locked := keyErrs[0].Locked + c.Assert(locked, NotNil) + return locked } diff --git a/store/tikv/backoff.go b/store/tikv/backoff.go index c34d5eb65643a..0d31e814e275d 100644 --- a/store/tikv/backoff.go +++ b/store/tikv/backoff.go @@ -45,34 +45,42 @@ const ( ) var ( - tikvBackoffCounterRPC = metrics.TiKVBackoffCounter.WithLabelValues("tikvRPC") - tikvBackoffCounterLock = metrics.TiKVBackoffCounter.WithLabelValues("txnLock") - tikvBackoffCounterLockFast = metrics.TiKVBackoffCounter.WithLabelValues("tikvLockFast") - tikvBackoffCounterPD = metrics.TiKVBackoffCounter.WithLabelValues("pdRPC") - tikvBackoffCounterRegionMiss = metrics.TiKVBackoffCounter.WithLabelValues("regionMiss") - tikvBackoffCounterUpdateLeader = metrics.TiKVBackoffCounter.WithLabelValues("updateLeader") - tikvBackoffCounterServerBusy = metrics.TiKVBackoffCounter.WithLabelValues("serverBusy") - tikvBackoffCounterEmpty = metrics.TiKVBackoffCounter.WithLabelValues("") + tikvBackoffCounterRPC = metrics.TiKVBackoffCounter.WithLabelValues("tikvRPC") + tikvBackoffCounterLock = metrics.TiKVBackoffCounter.WithLabelValues("txnLock") + tikvBackoffCounterLockFast = metrics.TiKVBackoffCounter.WithLabelValues("tikvLockFast") + tikvBackoffCounterPD = metrics.TiKVBackoffCounter.WithLabelValues("pdRPC") + tikvBackoffCounterRegionMiss = metrics.TiKVBackoffCounter.WithLabelValues("regionMiss") + tikvBackoffCounterUpdateLeader = metrics.TiKVBackoffCounter.WithLabelValues("updateLeader") + tikvBackoffCounterServerBusy = metrics.TiKVBackoffCounter.WithLabelValues("serverBusy") + tikvBackoffCounterEmpty = metrics.TiKVBackoffCounter.WithLabelValues("") + tikvBackoffHistogramRPC = metrics.TiKVBackoffHistogram.WithLabelValues("tikvRPC") + tikvBackoffHistogramLock = metrics.TiKVBackoffHistogram.WithLabelValues("txnLock") + tikvBackoffHistogramLockFast = metrics.TiKVBackoffHistogram.WithLabelValues("tikvLockFast") + tikvBackoffHistogramPD = metrics.TiKVBackoffHistogram.WithLabelValues("pdRPC") + tikvBackoffHistogramRegionMiss = metrics.TiKVBackoffHistogram.WithLabelValues("regionMiss") + tikvBackoffHistogramUpdateLeader = metrics.TiKVBackoffHistogram.WithLabelValues("updateLeader") + tikvBackoffHistogramServerBusy = metrics.TiKVBackoffHistogram.WithLabelValues("serverBusy") + tikvBackoffHistogramEmpty = metrics.TiKVBackoffHistogram.WithLabelValues("") ) -func (t backoffType) Counter() prometheus.Counter { +func (t backoffType) metric() (prometheus.Counter, prometheus.Observer) { switch t { case boTiKVRPC: - return tikvBackoffCounterRPC + return tikvBackoffCounterRPC, tikvBackoffHistogramRPC case BoTxnLock: - return tikvBackoffCounterLock + return tikvBackoffCounterLock, tikvBackoffHistogramLock case boTxnLockFast: - return tikvBackoffCounterLockFast + return tikvBackoffCounterLockFast, tikvBackoffHistogramLockFast case BoPDRPC: - return tikvBackoffCounterPD + return tikvBackoffCounterPD, tikvBackoffHistogramPD case BoRegionMiss: - return tikvBackoffCounterRegionMiss + return tikvBackoffCounterRegionMiss, tikvBackoffHistogramRegionMiss case BoUpdateLeader: - return tikvBackoffCounterUpdateLeader + return tikvBackoffCounterUpdateLeader, tikvBackoffHistogramUpdateLeader case boServerBusy: - return tikvBackoffCounterServerBusy + return tikvBackoffCounterServerBusy, tikvBackoffHistogramServerBusy } - return tikvBackoffCounterEmpty + return tikvBackoffCounterEmpty, tikvBackoffHistogramEmpty } // NewBackoffFn creates a backoff func which implements exponential backoff with @@ -110,12 +118,12 @@ func NewBackoffFn(base, cap, jitter int) func(ctx context.Context, maxSleepMs in } select { case <-time.After(time.Duration(realSleep) * time.Millisecond): + attempts++ + lastSleep = sleep + return realSleep case <-ctx.Done(): + return 0 } - - attempts++ - lastSleep = sleep - return lastSleep } } @@ -211,9 +219,12 @@ const ( deleteRangeOneRegionMaxBackoff = 100000 rawkvMaxBackoff = 20000 splitRegionBackoff = 20000 + maxSplitRegionsBackoff = 120000 scatterRegionBackoff = 20000 waitScatterRegionFinishBackoff = 120000 locateRegionMaxBackoff = 20000 + pessimisticLockMaxBackoff = 10000 + pessimisticRollbackMaxBackoff = 10000 ) // CommitMaxBackoff is max sleep time of the 'commit' command @@ -227,8 +238,9 @@ type Backoffer struct { maxSleep int totalSleep int errors []error - types []backoffType + types []fmt.Stringer vars *kv.Variables + noop bool } // txnStartKey is a key for transaction start_ts info in context.Context. @@ -243,6 +255,11 @@ func NewBackoffer(ctx context.Context, maxSleep int) *Backoffer { } } +// NewNoopBackoff create a Backoffer do nothing just return error directly +func NewNoopBackoff(ctx context.Context) *Backoffer { + return &Backoffer{ctx: ctx, noop: true} +} + // WithVars sets the kv.Variables to the Backoffer and return it. func (b *Backoffer) WithVars(vars *kv.Variables) *Backoffer { if vars != nil { @@ -274,7 +291,23 @@ func (b *Backoffer) BackoffWithMaxSleep(typ backoffType, maxSleepMs int, err err default: } - typ.Counter().Inc() + b.errors = append(b.errors, errors.Errorf("%s at %s", err.Error(), time.Now().Format(time.RFC3339Nano))) + b.types = append(b.types, typ) + if b.noop || (b.maxSleep > 0 && b.totalSleep >= b.maxSleep) { + errMsg := fmt.Sprintf("%s backoffer.maxSleep %dms is exceeded, errors:", typ.String(), b.maxSleep) + for i, err := range b.errors { + // Print only last 3 errors for non-DEBUG log levels. + if log.GetLevel() == zapcore.DebugLevel || i >= len(b.errors)-3 { + errMsg += "\n" + err.Error() + } + } + logutil.Logger(context.Background()).Warn(errMsg) + // Use the first backoff type to generate a MySQL error. + return b.types[0].(backoffType).TError() + } + + backoffCounter, backoffDuration := typ.metric() + backoffCounter.Inc() // Lazy initialize. if b.fn == nil { b.fn = make(map[backoffType]func(context.Context, int) int) @@ -285,8 +318,9 @@ func (b *Backoffer) BackoffWithMaxSleep(typ backoffType, maxSleepMs int, err err b.fn[typ] = f } - b.totalSleep += f(b.ctx, maxSleepMs) - b.types = append(b.types, typ) + realSleep := f(b.ctx, maxSleepMs) + backoffDuration.Observe(float64(realSleep) / 1000) + b.totalSleep += realSleep var startTs interface{} if ts := b.ctx.Value(txnStartKey); ts != nil { @@ -298,20 +332,6 @@ func (b *Backoffer) BackoffWithMaxSleep(typ backoffType, maxSleepMs int, err err zap.Int("maxSleep", b.maxSleep), zap.Stringer("type", typ), zap.Reflect("txnStartTS", startTs)) - - b.errors = append(b.errors, errors.Errorf("%s at %s", err.Error(), time.Now().Format(time.RFC3339Nano))) - if b.maxSleep > 0 && b.totalSleep >= b.maxSleep { - errMsg := fmt.Sprintf("%s backoffer.maxSleep %dms is exceeded, errors:", typ.String(), b.maxSleep) - for i, err := range b.errors { - // Print only last 3 errors for non-DEBUG log levels. - if log.GetLevel() == zapcore.DebugLevel || i >= len(b.errors)-3 { - errMsg += "\n" + err.Error() - } - } - logutil.Logger(context.Background()).Warn(errMsg) - // Use the first backoff type to generate a MySQL error. - return b.types[0].TError() - } return nil } diff --git a/store/tikv/backoff_test.go b/store/tikv/backoff_test.go new file mode 100644 index 0000000000000..ddf7d1fcf86f6 --- /dev/null +++ b/store/tikv/backoff_test.go @@ -0,0 +1,42 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package tikv + +import ( + "context" + "errors" + . "github.com/pingcap/check" +) + +type testBackoffSuite struct { + OneByOneSuite + store *tikvStore +} + +var _ = Suite(&testBackoffSuite{}) + +func (s *testBackoffSuite) SetUpTest(c *C) { + s.store = NewTestStore(c).(*tikvStore) +} + +func (s *testBackoffSuite) TearDownTest(c *C) { + s.store.Close() +} + +func (s *testBackoffSuite) TestBackoffWithMax(c *C) { + b := NewBackoffer(context.TODO(), 2000) + err := b.BackoffWithMaxSleep(boTxnLockFast, 30, errors.New("test")) + c.Assert(err, IsNil) + c.Assert(b.totalSleep, Equals, 30) +} diff --git a/store/tikv/client.go b/store/tikv/client.go index ebc5ef266fa2c..b32dd0a9848e8 100644 --- a/store/tikv/client.go +++ b/store/tikv/client.go @@ -17,6 +17,7 @@ package tikv import ( "context" "io" + "math" "strconv" "sync" "sync/atomic" @@ -32,21 +33,15 @@ import ( "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/store/tikv/tikvrpc" "github.com/pingcap/tidb/util/logutil" - "go.uber.org/zap" "google.golang.org/grpc" - gcodes "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" - gstatus "google.golang.org/grpc/status" ) -// MaxSendMsgSize set max gRPC request message size sent to server. If any request message size is larger than +// MaxRecvMsgSize set max gRPC receive message size received from server. If any message size is larger than // current value, an error will be reported from gRPC. -var MaxSendMsgSize = 1<<31 - 1 - -// MaxCallMsgSize set max gRPC receive message size received from server. If any message size is larger than -// current value, an error will be reported from gRPC. -var MaxCallMsgSize = 1<<31 - 1 +var MaxRecvMsgSize = math.MaxInt64 // Timeout durations. const ( @@ -78,144 +73,23 @@ type connArray struct { v []*grpc.ClientConn // streamTimeout binds with a background goroutine to process coprocessor streaming timeout. streamTimeout chan *tikvrpc.Lease - - // batchCommandsCh used for batch commands. - batchCommandsCh chan *batchCommandsEntry - batchCommandsClients []*batchCommandsClient - tikvTransportLayerLoad uint64 -} - -type batchCommandsClient struct { - // The target host. - target string - - conn *grpc.ClientConn - client tikvpb.Tikv_BatchCommandsClient - batched sync.Map - idAlloc uint64 - tikvTransportLayerLoad *uint64 - - // closed indicates the batch client is closed explicitly or not. - closed int32 - // clientLock protects client when re-create the streaming. - clientLock sync.Mutex -} - -func (c *batchCommandsClient) isStopped() bool { - return atomic.LoadInt32(&c.closed) != 0 + // batchConn is not null when batch is enabled. + *batchConn } -func (c *batchCommandsClient) failPendingRequests(err error) { - c.batched.Range(func(key, value interface{}) bool { - id, _ := key.(uint64) - entry, _ := value.(*batchCommandsEntry) - entry.err = err - close(entry.res) - c.batched.Delete(id) - return true - }) -} - -func (c *batchCommandsClient) batchRecvLoop(cfg config.TiKVClient) { - defer func() { - if r := recover(); r != nil { - metrics.PanicCounter.WithLabelValues(metrics.LabelBatchRecvLoop).Inc() - logutil.Logger(context.Background()).Error("batchRecvLoop", - zap.Reflect("r", r), - zap.Stack("stack")) - logutil.Logger(context.Background()).Info("restart batchRecvLoop") - go c.batchRecvLoop(cfg) - } - }() - - for { - // When `conn.Close()` is called, `client.Recv()` will return an error. - resp, err := c.client.Recv() - if err != nil { - - now := time.Now() - for { // try to re-create the streaming in the loop. - if c.isStopped() { - return - } - logutil.Logger(context.Background()).Error( - "batchRecvLoop error when receive", - zap.String("target", c.target), - zap.Error(err), - ) - - // Hold the lock to forbid batchSendLoop using the old client. - c.clientLock.Lock() - c.failPendingRequests(err) // fail all pending requests. - - // Re-establish a application layer stream. TCP layer is handled by gRPC. - tikvClient := tikvpb.NewTikvClient(c.conn) - streamClient, err := tikvClient.BatchCommands(context.TODO()) - c.clientLock.Unlock() - - if err == nil { - logutil.Logger(context.Background()).Info( - "batchRecvLoop re-create streaming success", - zap.String("target", c.target), - ) - c.client = streamClient - break - } - logutil.Logger(context.Background()).Error( - "batchRecvLoop re-create streaming fail", - zap.String("target", c.target), - zap.Error(err), - ) - // TODO: Use a more smart backoff strategy. - time.Sleep(time.Second) - } - metrics.TiKVBatchClientUnavailable.Observe(time.Since(now).Seconds()) - continue - } - - responses := resp.GetResponses() - for i, requestID := range resp.GetRequestIds() { - value, ok := c.batched.Load(requestID) - if !ok { - // There shouldn't be any unknown responses because if the old entries - // are cleaned by `failPendingRequests`, the stream must be re-created - // so that old responses will be never received. - panic("batchRecvLoop receives a unknown response") - } - entry := value.(*batchCommandsEntry) - if atomic.LoadInt32(&entry.canceled) == 0 { - // Put the response only if the request is not canceled. - entry.res <- responses[i] - } - c.batched.Delete(requestID) - } - - tikvTransportLayerLoad := resp.GetTransportLayerLoad() - if tikvTransportLayerLoad > 0.0 && cfg.MaxBatchWaitTime > 0 { - // We need to consider TiKV load only if batch-wait strategy is enabled. - atomic.StoreUint64(c.tikvTransportLayerLoad, tikvTransportLayerLoad) - } - } -} - -func newConnArray(maxSize uint, addr string, security config.Security) (*connArray, error) { - cfg := config.GetGlobalConfig() +func newConnArray(maxSize uint, addr string, security config.Security, idleNotify *uint32) (*connArray, error) { a := &connArray{ index: 0, v: make([]*grpc.ClientConn, maxSize), streamTimeout: make(chan *tikvrpc.Lease, 1024), - - batchCommandsCh: make(chan *batchCommandsEntry, cfg.TiKVClient.MaxBatchSize), - batchCommandsClients: make([]*batchCommandsClient, 0, maxSize), - tikvTransportLayerLoad: 0, } - if err := a.Init(addr, security); err != nil { + if err := a.Init(addr, security, idleNotify); err != nil { return nil, err } return a, nil } -func (a *connArray) Init(addr string, security config.Security) error { +func (a *connArray) Init(addr string, security config.Security, idleNotify *uint32) error { a.target = addr opt := grpc.WithInsecure() @@ -238,6 +112,10 @@ func (a *connArray) Init(addr string, security config.Security) error { } allowBatch := cfg.TiKVClient.MaxBatchSize > 0 + if allowBatch { + a.batchConn = newBatchConn(uint(len(a.v)), cfg.TiKVClient.MaxBatchSize, idleNotify) + a.pendingRequests = metrics.TiKVPendingBatchRequests.WithLabelValues(a.target) + } keepAlive := cfg.TiKVClient.GrpcKeepAliveTime keepAliveTimeout := cfg.TiKVClient.GrpcKeepAliveTimeout for i := range a.v { @@ -250,8 +128,7 @@ func (a *connArray) Init(addr string, security config.Security) error { grpc.WithInitialConnWindowSize(grpcInitialConnWindowSize), grpc.WithUnaryInterceptor(unaryInterceptor), grpc.WithStreamInterceptor(streamInterceptor), - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxCallMsgSize)), - grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(MaxSendMsgSize)), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(MaxRecvMsgSize)), grpc.WithBackoffMaxDelay(time.Second*3), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: time.Duration(keepAlive) * time.Second, @@ -268,24 +145,16 @@ func (a *connArray) Init(addr string, security config.Security) error { a.v[i] = conn if allowBatch { - // Initialize batch streaming clients. - tikvClient := tikvpb.NewTikvClient(conn) - streamClient, err := tikvClient.BatchCommands(context.TODO()) - if err != nil { - a.Close() - return errors.Trace(err) - } batchClient := &batchCommandsClient{ - target: a.target, - conn: conn, - client: streamClient, - batched: sync.Map{}, - idAlloc: 0, - tikvTransportLayerLoad: &a.tikvTransportLayerLoad, - closed: 0, + target: a.target, + conn: conn, + batched: sync.Map{}, + idAlloc: 0, + closed: 0, + tikvClientCfg: cfg.TiKVClient, + tikvLoad: &a.tikvTransportLayerLoad, } a.batchCommandsClients = append(a.batchCommandsClients, batchClient) - go batchClient.batchRecvLoop(cfg.TiKVClient) } } go tikvrpc.CheckStreamTimeoutLoop(a.streamTimeout) @@ -302,12 +171,9 @@ func (a *connArray) Get() *grpc.ClientConn { } func (a *connArray) Close() { - // Close all batchRecvLoop. - for _, c := range a.batchCommandsClients { - // After connections are closed, `batchRecvLoop`s will check the flag. - atomic.StoreInt32(&c.closed, 1) + if a.batchConn != nil { + a.batchConn.Close() } - close(a.batchCommandsCh) for i, c := range a.v { if c != nil { @@ -319,182 +185,19 @@ func (a *connArray) Close() { close(a.streamTimeout) } -type batchCommandsEntry struct { - req *tikvpb.BatchCommandsRequest_Request - res chan *tikvpb.BatchCommandsResponse_Response - - // canceled indicated the request is canceled or not. - canceled int32 - err error -} - -// fetchAllPendingRequests fetches all pending requests from the channel. -func fetchAllPendingRequests( - ch chan *batchCommandsEntry, - maxBatchSize int, - entries *[]*batchCommandsEntry, - requests *[]*tikvpb.BatchCommandsRequest_Request, -) { - // Block on the first element. - headEntry := <-ch - if headEntry == nil { - return - } - *entries = append(*entries, headEntry) - *requests = append(*requests, headEntry.req) - - // This loop is for trying best to collect more requests. - for len(*entries) < maxBatchSize { - select { - case entry := <-ch: - if entry == nil { - return - } - *entries = append(*entries, entry) - *requests = append(*requests, entry.req) - default: - return - } - } -} - -// fetchMorePendingRequests fetches more pending requests from the channel. -func fetchMorePendingRequests( - ch chan *batchCommandsEntry, - maxBatchSize int, - batchWaitSize int, - maxWaitTime time.Duration, - entries *[]*batchCommandsEntry, - requests *[]*tikvpb.BatchCommandsRequest_Request, -) { - waitStart := time.Now() - - // Try to collect `batchWaitSize` requests, or wait `maxWaitTime`. - after := time.NewTimer(maxWaitTime) - for len(*entries) < batchWaitSize { - select { - case entry := <-ch: - if entry == nil { - return - } - *entries = append(*entries, entry) - *requests = append(*requests, entry.req) - case waitEnd := <-after.C: - metrics.TiKVBatchWaitDuration.Observe(float64(waitEnd.Sub(waitStart))) - return - } - } - after.Stop() - - // Do an additional non-block try. Here we test the lengh with `maxBatchSize` instead - // of `batchWaitSize` because trying best to fetch more requests is necessary so that - // we can adjust the `batchWaitSize` dynamically. - for len(*entries) < maxBatchSize { - select { - case entry := <-ch: - if entry == nil { - return - } - *entries = append(*entries, entry) - *requests = append(*requests, entry.req) - default: - metrics.TiKVBatchWaitDuration.Observe(float64(time.Since(waitStart))) - return - } - } -} - -func (a *connArray) batchSendLoop(cfg config.TiKVClient) { - defer func() { - if r := recover(); r != nil { - metrics.PanicCounter.WithLabelValues(metrics.LabelBatchSendLoop).Inc() - logutil.Logger(context.Background()).Error("batchSendLoop", - zap.Reflect("r", r), - zap.Stack("stack")) - logutil.Logger(context.Background()).Info("restart batchSendLoop") - go a.batchSendLoop(cfg) - } - }() - - entries := make([]*batchCommandsEntry, 0, cfg.MaxBatchSize) - requests := make([]*tikvpb.BatchCommandsRequest_Request, 0, cfg.MaxBatchSize) - requestIDs := make([]uint64, 0, cfg.MaxBatchSize) - - var bestBatchWaitSize = cfg.BatchWaitSize - for { - // Choose a connection by round-robbin. - next := atomic.AddUint32(&a.index, 1) % uint32(len(a.v)) - batchCommandsClient := a.batchCommandsClients[next] - - entries = entries[:0] - requests = requests[:0] - requestIDs = requestIDs[:0] - - metrics.TiKVPendingBatchRequests.Set(float64(len(a.batchCommandsCh))) - fetchAllPendingRequests(a.batchCommandsCh, int(cfg.MaxBatchSize), &entries, &requests) - - if len(entries) < int(cfg.MaxBatchSize) && cfg.MaxBatchWaitTime > 0 { - tikvTransportLayerLoad := atomic.LoadUint64(batchCommandsClient.tikvTransportLayerLoad) - // If the target TiKV is overload, wait a while to collect more requests. - if uint(tikvTransportLayerLoad) >= cfg.OverloadThreshold { - fetchMorePendingRequests( - a.batchCommandsCh, int(cfg.MaxBatchSize), int(bestBatchWaitSize), - cfg.MaxBatchWaitTime, &entries, &requests, - ) - } - } - length := len(requests) - if uint(length) == 0 { - // The batch command channel is closed. - return - } else if uint(length) < bestBatchWaitSize && bestBatchWaitSize > 1 { - // Waits too long to collect requests, reduce the target batch size. - bestBatchWaitSize -= 1 - } else if uint(length) > bestBatchWaitSize+4 && bestBatchWaitSize < cfg.MaxBatchSize { - bestBatchWaitSize += 1 - } - - maxBatchID := atomic.AddUint64(&batchCommandsClient.idAlloc, uint64(length)) - for i := 0; i < length; i++ { - requestID := uint64(i) + maxBatchID - uint64(length) - requestIDs = append(requestIDs, requestID) - } - - request := &tikvpb.BatchCommandsRequest{ - Requests: requests, - RequestIds: requestIDs, - } - - // Use the lock to protect the stream client won't be replaced by RecvLoop, - // and new added request won't be removed by `failPendingRequests`. - batchCommandsClient.clientLock.Lock() - for i, requestID := range request.RequestIds { - batchCommandsClient.batched.Store(requestID, entries[i]) - } - err := batchCommandsClient.client.Send(request) - batchCommandsClient.clientLock.Unlock() - if err != nil { - logutil.Logger(context.Background()).Error( - "batch commands send error", - zap.String("target", a.target), - zap.Error(err), - ) - batchCommandsClient.failPendingRequests(err) - } - } -} - // rpcClient is RPC client struct. // TODO: Add flow control between RPC clients in TiDB ond RPC servers in TiKV. // Since we use shared client connection to communicate to the same TiKV, it's possible // that there are too many concurrent requests which overload the service of TiKV. -// TODO: Implement background cleanup. It adds a background goroutine to periodically check -// whether there is any connection is idle and then close and remove these idle connections. type rpcClient struct { sync.RWMutex isClosed bool conns map[string]*connArray security config.Security + + // Implement background cleanup. + // Periodically check whether there is any connection that is idle and then close and remove these idle connections. + idleNotify uint32 } func newRPCClient(security config.Security) *rpcClient { @@ -504,6 +207,11 @@ func newRPCClient(security config.Security) *rpcClient { } } +// NewTestRPCClient is for some external tests. +func NewTestRPCClient() Client { + return newRPCClient(config.Security{}) +} + func (c *rpcClient) getConnArray(addr string) (*connArray, error) { c.RLock() if c.isClosed { @@ -529,7 +237,7 @@ func (c *rpcClient) createConnArray(addr string) (*connArray, error) { if !ok { var err error connCount := config.GetGlobalConfig().TiKVClient.GrpcConnectionCount - array, err = newConnArray(connCount, addr, c.security) + array, err = newConnArray(connCount, addr, c.security, &c.idleNotify) if err != nil { return nil, err } @@ -550,41 +258,6 @@ func (c *rpcClient) closeConns() { c.Unlock() } -func sendBatchRequest( - ctx context.Context, - addr string, - connArray *connArray, - req *tikvpb.BatchCommandsRequest_Request, - timeout time.Duration, -) (*tikvrpc.Response, error) { - entry := &batchCommandsEntry{ - req: req, - res: make(chan *tikvpb.BatchCommandsResponse_Response, 1), - canceled: 0, - err: nil, - } - ctx1, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - select { - case connArray.batchCommandsCh <- entry: - case <-ctx1.Done(): - logutil.Logger(context.Background()).Warn("send request is timeout", zap.String("to", addr)) - return nil, errors.Trace(gstatus.Error(gcodes.DeadlineExceeded, "Canceled or timeout")) - } - - select { - case res, ok := <-entry.res: - if !ok { - return nil, errors.Trace(entry.err) - } - return tikvrpc.FromBatchCommandsResponse(res), nil - case <-ctx1.Done(): - atomic.StoreInt32(&entry.canceled, 1) - logutil.Logger(context.Background()).Warn("send request is canceled", zap.String("to", addr)) - return nil, errors.Trace(gstatus.Error(gcodes.DeadlineExceeded, "Canceled or timeout")) - } -} - // SendRequest sends a Request to server and receives Response. func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { start := time.Now() @@ -594,6 +267,10 @@ func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R metrics.TiKVSendReqHistogram.WithLabelValues(reqType, storeID).Observe(time.Since(start).Seconds()) }() + if atomic.CompareAndSwapUint32(&c.idleNotify, 1, 0) { + c.recycleIdleConnArray() + } + connArray, err := c.getConnArray(addr) if err != nil { return nil, errors.Trace(err) @@ -601,18 +278,23 @@ func (c *rpcClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R if config.GetGlobalConfig().TiKVClient.MaxBatchSize > 0 { if batchReq := req.ToBatchCommandsRequest(); batchReq != nil { - return sendBatchRequest(ctx, addr, connArray, batchReq, timeout) + return sendBatchRequest(ctx, addr, connArray.batchConn, batchReq, timeout) } } + clientConn := connArray.Get() + if state := clientConn.GetState(); state == connectivity.TransientFailure { + metrics.GRPCConnTransientFailureCounter.WithLabelValues(addr, storeID).Inc() + } + if req.IsDebugReq() { - client := debugpb.NewDebugClient(connArray.Get()) + client := debugpb.NewDebugClient(clientConn) ctx1, cancel := context.WithTimeout(ctx, timeout) defer cancel() return tikvrpc.CallDebugRPC(ctx1, client, req) } - client := tikvpb.NewTikvClient(connArray.Get()) + client := tikvpb.NewTikvClient(clientConn) if req.Type != tikvrpc.CmdCopStream { ctx1, cancel := context.WithTimeout(ctx, timeout) diff --git a/store/tikv/client_batch.go b/store/tikv/client_batch.go new file mode 100644 index 0000000000000..1f349c70b9a8f --- /dev/null +++ b/store/tikv/client_batch.go @@ -0,0 +1,599 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tikv provides tcp connection to kvserver. +package tikv + +import ( + "context" + "math" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/tikvpb" + "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/metrics" + "github.com/pingcap/tidb/store/tikv/tikvrpc" + "github.com/pingcap/tidb/util/logutil" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" +) + +type batchConn struct { + // An atomic flag indicates whether the batch is idle or not. + // 0 for busy, others for idle. + idle uint32 + + // batchCommandsCh used for batch commands. + batchCommandsCh chan *batchCommandsEntry + batchCommandsClients []*batchCommandsClient + tikvTransportLayerLoad uint64 + closed chan struct{} + + // Notify rpcClient to check the idle flag + idleNotify *uint32 + idleDetect *time.Timer + + pendingRequests prometheus.Gauge + + index uint32 +} + +func newBatchConn(connCount, maxBatchSize uint, idleNotify *uint32) *batchConn { + return &batchConn{ + batchCommandsCh: make(chan *batchCommandsEntry, maxBatchSize), + batchCommandsClients: make([]*batchCommandsClient, 0, connCount), + tikvTransportLayerLoad: 0, + closed: make(chan struct{}), + + idleNotify: idleNotify, + idleDetect: time.NewTimer(idleTimeout), + } +} + +func (a *batchConn) isIdle() bool { + return atomic.LoadUint32(&a.idle) != 0 +} + +// fetchAllPendingRequests fetches all pending requests from the channel. +func (a *batchConn) fetchAllPendingRequests( + maxBatchSize int, + entries *[]*batchCommandsEntry, + requests *[]*tikvpb.BatchCommandsRequest_Request, +) { + // Block on the first element. + var headEntry *batchCommandsEntry + select { + case headEntry = <-a.batchCommandsCh: + if !a.idleDetect.Stop() { + <-a.idleDetect.C + } + a.idleDetect.Reset(idleTimeout) + case <-a.idleDetect.C: + a.idleDetect.Reset(idleTimeout) + atomic.AddUint32(&a.idle, 1) + atomic.CompareAndSwapUint32(a.idleNotify, 0, 1) + // This batchConn to be recycled + return + case <-a.closed: + return + } + if headEntry == nil { + return + } + *entries = append(*entries, headEntry) + *requests = append(*requests, headEntry.req) + + // This loop is for trying best to collect more requests. + for len(*entries) < maxBatchSize { + select { + case entry := <-a.batchCommandsCh: + if entry == nil { + return + } + *entries = append(*entries, entry) + *requests = append(*requests, entry.req) + default: + return + } + } +} + +// fetchMorePendingRequests fetches more pending requests from the channel. +func fetchMorePendingRequests( + ch chan *batchCommandsEntry, + maxBatchSize int, + batchWaitSize int, + maxWaitTime time.Duration, + entries *[]*batchCommandsEntry, + requests *[]*tikvpb.BatchCommandsRequest_Request, +) { + waitStart := time.Now() + + // Try to collect `batchWaitSize` requests, or wait `maxWaitTime`. + after := time.NewTimer(maxWaitTime) + for len(*entries) < batchWaitSize { + select { + case entry := <-ch: + if entry == nil { + return + } + *entries = append(*entries, entry) + *requests = append(*requests, entry.req) + case waitEnd := <-after.C: + metrics.TiKVBatchWaitDuration.Observe(float64(waitEnd.Sub(waitStart))) + return + } + } + after.Stop() + + // Do an additional non-block try. Here we test the lengh with `maxBatchSize` instead + // of `batchWaitSize` because trying best to fetch more requests is necessary so that + // we can adjust the `batchWaitSize` dynamically. + for len(*entries) < maxBatchSize { + select { + case entry := <-ch: + if entry == nil { + return + } + *entries = append(*entries, entry) + *requests = append(*requests, entry.req) + default: + metrics.TiKVBatchWaitDuration.Observe(float64(time.Since(waitStart))) + return + } + } +} + +type tryLock struct { + sync.RWMutex + reCreating bool +} + +func (l *tryLock) tryLockForSend() bool { + l.RLock() + if l.reCreating { + l.RUnlock() + return false + } + return true +} + +func (l *tryLock) unlockForSend() { + l.RUnlock() +} + +func (l *tryLock) lockForRecreate() { + l.Lock() + l.reCreating = true + l.Unlock() + +} + +func (l *tryLock) unlockForRecreate() { + l.Lock() + l.reCreating = false + l.Unlock() +} + +type batchCommandsClient struct { + // The target host. + target string + + conn *grpc.ClientConn + client tikvpb.Tikv_BatchCommandsClient + batched sync.Map + idAlloc uint64 + + tikvClientCfg config.TiKVClient + tikvLoad *uint64 + + // closed indicates the batch client is closed explicitly or not. + closed int32 + // tryLock protects client when re-create the streaming. + tryLock +} + +func (c *batchCommandsClient) isStopped() bool { + return atomic.LoadInt32(&c.closed) != 0 +} + +func (c *batchCommandsClient) send(request *tikvpb.BatchCommandsRequest, entries []*batchCommandsEntry) { + for i, requestID := range request.RequestIds { + c.batched.Store(requestID, entries[i]) + } + + if err := c.initBatchClient(); err != nil { + logutil.Logger(context.Background()).Warn( + "init create streaming fail", + zap.String("target", c.target), + zap.Error(err), + ) + c.failPendingRequests(err) + return + } + + if err := c.client.Send(request); err != nil { + logutil.Logger(context.Background()).Info( + "sending batch commands meets error", + zap.String("target", c.target), + zap.Error(err), + ) + c.failPendingRequests(err) + } +} + +func (c *batchCommandsClient) recv() (*tikvpb.BatchCommandsResponse, error) { + failpoint.Inject("gotErrorInRecvLoop", func(_ failpoint.Value) (*tikvpb.BatchCommandsResponse, error) { + return nil, errors.New("injected error in batchRecvLoop") + }) + // When `conn.Close()` is called, `client.Recv()` will return an error. + return c.client.Recv() +} + +// `failPendingRequests` must be called in locked contexts in order to avoid double closing channels. +func (c *batchCommandsClient) failPendingRequests(err error) { + failpoint.Inject("panicInFailPendingRequests", nil) + c.batched.Range(func(key, value interface{}) bool { + id, _ := key.(uint64) + entry, _ := value.(*batchCommandsEntry) + entry.err = err + c.batched.Delete(id) + close(entry.res) + return true + }) +} + +func (c *batchCommandsClient) waitConnReady() (err error) { + dialCtx, cancel := context.WithTimeout(context.Background(), dialTimeout) + for { + s := c.conn.GetState() + if s == connectivity.Ready { + cancel() + break + } + if !c.conn.WaitForStateChange(dialCtx, s) { + cancel() + err = dialCtx.Err() + return + } + } + return +} + +func (c *batchCommandsClient) reCreateStreamingClientOnce(perr error) error { + c.failPendingRequests(perr) // fail all pending requests. + + err := c.waitConnReady() + // Re-establish a application layer stream. TCP layer is handled by gRPC. + if err == nil { + tikvClient := tikvpb.NewTikvClient(c.conn) + var streamClient tikvpb.Tikv_BatchCommandsClient + streamClient, err = tikvClient.BatchCommands(context.TODO()) + if err == nil { + logutil.Logger(context.Background()).Info( + "batchRecvLoop re-create streaming success", + zap.String("target", c.target), + ) + c.client = streamClient + return nil + } + } + logutil.Logger(context.Background()).Info( + "batchRecvLoop re-create streaming fail", + zap.String("target", c.target), + zap.Error(err), + ) + return err +} + +func (c *batchCommandsClient) batchRecvLoop(cfg config.TiKVClient, tikvTransportLayerLoad *uint64) { + defer func() { + if r := recover(); r != nil { + metrics.PanicCounter.WithLabelValues(metrics.LabelBatchRecvLoop).Inc() + logutil.Logger(context.Background()).Error("batchRecvLoop", + zap.Reflect("r", r), + zap.Stack("stack")) + logutil.Logger(context.Background()).Info("restart batchRecvLoop") + go c.batchRecvLoop(cfg, tikvTransportLayerLoad) + } + }() + + for { + resp, err := c.recv() + if err != nil { + if c.isStopped() { + return + } + logutil.Logger(context.Background()).Info( + "batchRecvLoop fails when receiving, needs to reconnect", + zap.String("target", c.target), + zap.Error(err), + ) + + now := time.Now() + if stopped := c.reCreateStreamingClient(err); stopped { + return + } + metrics.TiKVBatchClientUnavailable.Observe(time.Since(now).Seconds()) + continue + } + + responses := resp.GetResponses() + for i, requestID := range resp.GetRequestIds() { + value, ok := c.batched.Load(requestID) + if !ok { + // There shouldn't be any unknown responses because if the old entries + // are cleaned by `failPendingRequests`, the stream must be re-created + // so that old responses will be never received. + panic("batchRecvLoop receives a unknown response") + } + entry := value.(*batchCommandsEntry) + if atomic.LoadInt32(&entry.canceled) == 0 { + // Put the response only if the request is not canceled. + entry.res <- responses[i] + } + c.batched.Delete(requestID) + } + + transportLayerLoad := resp.GetTransportLayerLoad() + if transportLayerLoad > 0.0 && cfg.MaxBatchWaitTime > 0 { + // We need to consider TiKV load only if batch-wait strategy is enabled. + atomic.StoreUint64(tikvTransportLayerLoad, transportLayerLoad) + } + } +} + +func (c *batchCommandsClient) reCreateStreamingClient(err error) (stopped bool) { + // Forbids the batchSendLoop using the old client. + c.lockForRecreate() + defer c.unlockForRecreate() + + b := NewBackoffer(context.Background(), math.MaxInt32) + for { // try to re-create the streaming in the loop. + if c.isStopped() { + return true + } + err1 := c.reCreateStreamingClientOnce(err) + if err1 == nil { + break + } + + err2 := b.Backoff(boTiKVRPC, err1) + // As timeout is set to math.MaxUint32, err2 should always be nil. + // This line is added to make the 'make errcheck' pass. + terror.Log(err2) + } + return false +} + +type batchCommandsEntry struct { + req *tikvpb.BatchCommandsRequest_Request + res chan *tikvpb.BatchCommandsResponse_Response + + // canceled indicated the request is canceled or not. + canceled int32 + err error +} + +func (b *batchCommandsEntry) isCanceled() bool { + return atomic.LoadInt32(&b.canceled) == 1 +} + +const idleTimeout = 3 * time.Minute + +func (a *batchConn) batchSendLoop(cfg config.TiKVClient) { + defer func() { + if r := recover(); r != nil { + metrics.PanicCounter.WithLabelValues(metrics.LabelBatchSendLoop).Inc() + logutil.Logger(context.Background()).Error("batchSendLoop", + zap.Reflect("r", r), + zap.Stack("stack")) + logutil.Logger(context.Background()).Info("restart batchSendLoop") + go a.batchSendLoop(cfg) + } + }() + + entries := make([]*batchCommandsEntry, 0, cfg.MaxBatchSize) + requests := make([]*tikvpb.BatchCommandsRequest_Request, 0, cfg.MaxBatchSize) + requestIDs := make([]uint64, 0, cfg.MaxBatchSize) + + var bestBatchWaitSize = cfg.BatchWaitSize + for { + entries = entries[:0] + requests = requests[:0] + requestIDs = requestIDs[:0] + + a.pendingRequests.Set(float64(len(a.batchCommandsCh))) + a.fetchAllPendingRequests(int(cfg.MaxBatchSize), &entries, &requests) + + if len(entries) < int(cfg.MaxBatchSize) && cfg.MaxBatchWaitTime > 0 { + // If the target TiKV is overload, wait a while to collect more requests. + if atomic.LoadUint64(&a.tikvTransportLayerLoad) >= uint64(cfg.OverloadThreshold) { + fetchMorePendingRequests( + a.batchCommandsCh, int(cfg.MaxBatchSize), int(bestBatchWaitSize), + cfg.MaxBatchWaitTime, &entries, &requests, + ) + } + } + length := len(requests) + if uint(length) == 0 { + // The batch command channel is closed. + return + } else if uint(length) < bestBatchWaitSize && bestBatchWaitSize > 1 { + // Waits too long to collect requests, reduce the target batch size. + bestBatchWaitSize -= 1 + } else if uint(length) > bestBatchWaitSize+4 && bestBatchWaitSize < cfg.MaxBatchSize { + bestBatchWaitSize += 1 + } + + entries, requests = removeCanceledRequests(entries, requests) + if len(entries) == 0 { + continue // All requests are canceled. + } + + a.getClientAndSend(entries, requests, requestIDs) + } +} + +func (a *batchConn) getClientAndSend(entries []*batchCommandsEntry, requests []*tikvpb.BatchCommandsRequest_Request, requestIDs []uint64) { + // Choose a connection by round-robbin. + var cli *batchCommandsClient = nil + var target string = "" + for i := 0; i < len(a.batchCommandsClients); i++ { + a.index = (a.index + 1) % uint32(len(a.batchCommandsClients)) + target = a.batchCommandsClients[a.index].target + // The lock protects the batchCommandsClient from been closed while it's inuse. + if a.batchCommandsClients[a.index].tryLockForSend() { + cli = a.batchCommandsClients[a.index] + break + } + } + if cli == nil { + logutil.Logger(context.Background()).Warn("no available connections", zap.String("target", target)) + for _, entry := range entries { + // Please ensure the error is handled in region cache correctly. + entry.err = errors.New("no available connections") + close(entry.res) + } + return + } + defer cli.unlockForSend() + + maxBatchID := atomic.AddUint64(&cli.idAlloc, uint64(len(requests))) + for i := 0; i < len(requests); i++ { + requestID := uint64(i) + maxBatchID - uint64(len(requests)) + requestIDs = append(requestIDs, requestID) + } + req := &tikvpb.BatchCommandsRequest{ + Requests: requests, + RequestIds: requestIDs, + } + + cli.send(req, entries) + return +} + +func (c *batchCommandsClient) initBatchClient() error { + if c.client != nil { + return nil + } + + if err := c.waitConnReady(); err != nil { + return err + } + + // Initialize batch streaming clients. + tikvClient := tikvpb.NewTikvClient(c.conn) + streamClient, err := tikvClient.BatchCommands(context.TODO()) + if err != nil { + return errors.Trace(err) + } + c.client = streamClient + go c.batchRecvLoop(c.tikvClientCfg, c.tikvLoad) + return nil +} + +func (a *batchConn) Close() { + // Close all batchRecvLoop. + for _, c := range a.batchCommandsClients { + // After connections are closed, `batchRecvLoop`s will check the flag. + atomic.StoreInt32(&c.closed, 1) + } + // Don't close(batchCommandsCh) because when Close() is called, someone maybe + // calling SendRequest and writing batchCommandsCh, if we close it here the + // writing goroutine will panic. + close(a.closed) +} + +// removeCanceledRequests removes canceled requests before sending. +func removeCanceledRequests(entries []*batchCommandsEntry, + requests []*tikvpb.BatchCommandsRequest_Request) ([]*batchCommandsEntry, []*tikvpb.BatchCommandsRequest_Request) { + validEntries := entries[:0] + validRequets := requests[:0] + for _, e := range entries { + if !e.isCanceled() { + validEntries = append(validEntries, e) + validRequets = append(validRequets, e.req) + } + } + return validEntries, validRequets +} + +func sendBatchRequest( + ctx context.Context, + addr string, + batchConn *batchConn, + req *tikvpb.BatchCommandsRequest_Request, + timeout time.Duration, +) (*tikvrpc.Response, error) { + entry := &batchCommandsEntry{ + req: req, + res: make(chan *tikvpb.BatchCommandsResponse_Response, 1), + canceled: 0, + err: nil, + } + ctx1, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + select { + case batchConn.batchCommandsCh <- entry: + case <-ctx1.Done(): + logutil.Logger(context.Background()).Warn("send request is cancelled", + zap.String("to", addr), zap.String("cause", ctx1.Err().Error())) + return nil, errors.Trace(ctx1.Err()) + } + + select { + case res, ok := <-entry.res: + if !ok { + return nil, errors.Trace(entry.err) + } + return tikvrpc.FromBatchCommandsResponse(res), nil + case <-ctx1.Done(): + atomic.StoreInt32(&entry.canceled, 1) + logutil.Logger(context.Background()).Warn("wait response is cancelled", + zap.String("to", addr), zap.String("cause", ctx1.Err().Error())) + return nil, errors.Trace(ctx1.Err()) + } +} + +func (c *rpcClient) recycleIdleConnArray() { + var addrs []string + c.RLock() + for _, conn := range c.conns { + if conn.isIdle() { + addrs = append(addrs, conn.target) + } + } + c.RUnlock() + + for _, addr := range addrs { + c.Lock() + conn, ok := c.conns[addr] + if ok { + delete(c.conns, addr) + logutil.Logger(context.Background()).Info("recycle idle connection", + zap.String("target", addr)) + } + c.Unlock() + if conn != nil { + conn.Close() + } + } +} diff --git a/store/tikv/client_fail_test.go b/store/tikv/client_fail_test.go new file mode 100644 index 0000000000000..ad49b5040da1b --- /dev/null +++ b/store/tikv/client_fail_test.go @@ -0,0 +1,74 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package tikv + +import ( + "context" + "fmt" + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/tikvpb" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/store/tikv/tikvrpc" +) + +type testClientFailSuite struct { + OneByOneSuite +} + +func (s *testClientFailSuite) SetUpSuite(c *C) { + // This lock make testClientFailSuite runs exclusively. + withTiKVGlobalLock.Lock() +} + +func (s testClientFailSuite) TearDownSuite(c *C) { + withTiKVGlobalLock.Unlock() +} + +func setGrpcConnectionCount(count uint) { + config.GetGlobalConfig().TiKVClient.GrpcConnectionCount = count +} + +func (s *testClientFailSuite) TestPanicInRecvLoop(c *C) { + c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/panicInFailPendingRequests", `panic`), IsNil) + c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/gotErrorInRecvLoop", `return("0")`), IsNil) + + server, port := startMockTikvService() + c.Assert(port > 0, IsTrue) + + grpcConnectionCount := config.GetGlobalConfig().TiKVClient.GrpcConnectionCount + setGrpcConnectionCount(1) + addr := fmt.Sprintf("%s:%d", "127.0.0.1", port) + rpcClient := newRPCClient(config.Security{}) + + // Start batchRecvLoop, and it should panic in `failPendingRequests`. + _, err := rpcClient.getConnArray(addr) + c.Assert(err, IsNil) + + time.Sleep(time.Second) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/gotErrorInRecvLoop"), IsNil) + c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/panicInFailPendingRequests"), IsNil) + time.Sleep(time.Second) + + req := &tikvrpc.Request{ + Type: tikvrpc.CmdEmpty, + Empty: &tikvpb.BatchCommandsEmptyRequest{}, + } + _, err = rpcClient.SendRequest(context.Background(), addr, req, time.Second) + c.Assert(err, IsNil) + server.Stop() + setGrpcConnectionCount(grpcConnectionCount) +} diff --git a/store/tikv/client_test.go b/store/tikv/client_test.go index 5063acb77eeee..5dd9567bcc721 100644 --- a/store/tikv/client_test.go +++ b/store/tikv/client_test.go @@ -14,10 +14,16 @@ package tikv import ( + "context" + "fmt" "testing" + "time" . "github.com/pingcap/check" + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/tikvpb" "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/store/tikv/tikvrpc" ) func TestT(t *testing.T) { @@ -30,10 +36,15 @@ type testClientSuite struct { } var _ = Suite(&testClientSuite{}) +var _ = Suite(&testClientFailSuite{}) + +func setMaxBatchSize(size uint) { + config.GetGlobalConfig().TiKVClient.MaxBatchSize = size +} func (s *testClientSuite) TestConn(c *C) { - globalConfig := config.GetGlobalConfig() - globalConfig.TiKVClient.MaxBatchSize = 0 // Disable batch. + maxBatchSize := config.GetGlobalConfig().TiKVClient.MaxBatchSize + setMaxBatchSize(0) client := newRPCClient(config.Security{}) @@ -49,4 +60,63 @@ func (s *testClientSuite) TestConn(c *C) { conn3, err := client.getConnArray(addr) c.Assert(err, NotNil) c.Assert(conn3, IsNil) + setMaxBatchSize(maxBatchSize) +} + +func (s *testClientSuite) TestRemoveCanceledRequests(c *C) { + req := new(tikvpb.BatchCommandsRequest_Request) + entries := []*batchCommandsEntry{ + {canceled: 1, req: req}, + {canceled: 0, req: req}, + {canceled: 1, req: req}, + {canceled: 1, req: req}, + {canceled: 0, req: req}, + } + entryPtr := &entries[0] + requests := make([]*tikvpb.BatchCommandsRequest_Request, len(entries)) + for i := range entries { + requests[i] = entries[i].req + } + entries, requests = removeCanceledRequests(entries, requests) + c.Assert(len(entries), Equals, 2) + for _, e := range entries { + c.Assert(e.isCanceled(), IsFalse) + } + c.Assert(len(requests), Equals, 2) + newEntryPtr := &entries[0] + c.Assert(entryPtr, Equals, newEntryPtr) +} + +func (s *testClientSuite) TestCancelTimeoutRetErr(c *C) { + req := new(tikvpb.BatchCommandsRequest_Request) + a := newBatchConn(1, 1, nil) + + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + _, err := sendBatchRequest(ctx, "", a, req, 2*time.Second) + c.Assert(errors.Cause(err), Equals, context.Canceled) + + _, err = sendBatchRequest(context.Background(), "", a, req, 0) + c.Assert(errors.Cause(err), Equals, context.DeadlineExceeded) +} + +func (s *testClientSuite) TestSendWhenReconnect(c *C) { + server, port := startMockTikvService() + c.Assert(port > 0, IsTrue) + + rpcClient := newRPCClient(config.Security{}) + addr := fmt.Sprintf("%s:%d", "127.0.0.1", port) + conn, err := rpcClient.getConnArray(addr) + c.Assert(err, IsNil) + + // Suppose all connections are re-establishing. + for _, client := range conn.batchConn.batchCommandsClients { + client.lockForRecreate() + } + + req := &tikvrpc.Request{Type: tikvrpc.CmdEmpty, Empty: &tikvpb.BatchCommandsEmptyRequest{}} + _, err = rpcClient.SendRequest(context.Background(), addr, req, 100*time.Second) + c.Assert(err.Error() == "no available connections", IsTrue) + conn.Close() + server.Stop() } diff --git a/store/tikv/coprocessor.go b/store/tikv/coprocessor.go index 6f3976582ceff..712de3f1b7b30 100644 --- a/store/tikv/coprocessor.go +++ b/store/tikv/coprocessor.go @@ -108,7 +108,9 @@ func (c *CopClient) Send(ctx context.Context, req *kv.Request, vars *kv.Variable // Make sure that there is at least one worker. it.concurrency = 1 } - if !it.req.KeepOrder { + if it.req.KeepOrder { + it.sendRate = newRateLimit(2 * it.concurrency) + } else { it.respChan = make(chan *copResponse, it.concurrency) } it.open(ctx) @@ -259,9 +261,11 @@ func buildCopTasks(bo *Backoffer, cache *RegionCache, ranges *copRanges, desc bo for i := 0; i < rLen; { nextI := mathutil.Min(i+rangesPerTask, rLen) tasks = append(tasks, &copTask{ - region: region, - ranges: ranges.slice(i, nextI), - respChan: make(chan *copResponse, 1), + region: region, + ranges: ranges.slice(i, nextI), + // Channel buffer is 2 for handling region split. + // In a common case, two region split tasks will not be blocked. + respChan: make(chan *copResponse, 2), cmdType: cmdType, }) i = nextI @@ -371,6 +375,9 @@ type copIterator struct { // If keepOrder, results are stored in copTask.respChan, read them out one by one. tasks []*copTask curr int + // sendRate controls the sending rate of copIteratorTaskSender, if keepOrder, + // to prevent all tasks being done (aka. all of the responses are buffered) + sendRate *rateLimit // Otherwise, results are stored in respChan. respChan chan *copResponse @@ -401,11 +408,12 @@ type copIteratorTaskSender struct { tasks []*copTask finishCh <-chan struct{} respChan chan<- *copResponse + sendRate *rateLimit } type copResponse struct { - pbResp *coprocessor.Response - execdetails.ExecDetails + pbResp *coprocessor.Response + detail *execdetails.ExecDetails startKey kv.Key err error respSize int64 @@ -427,7 +435,7 @@ func (rs *copResponse) GetStartKey() kv.Key { } func (rs *copResponse) GetExecDetails() *execdetails.ExecDetails { - return &rs.ExecDetails + return rs.detail } // MemSize returns how many bytes of memory this response use @@ -438,9 +446,11 @@ func (rs *copResponse) MemSize() int64 { // ignore rs.err rs.respSize += int64(cap(rs.startKey)) - rs.respSize += int64(sizeofExecDetails) - if rs.CommitDetail != nil { - rs.respSize += int64(sizeofCommitDetails) + if rs.detail != nil { + rs.respSize += int64(sizeofExecDetails) + if rs.detail.CommitDetail != nil { + rs.respSize += int64(sizeofCommitDetails) + } } if rs.pbResp != nil { // Using a approximate size since it's hard to get a accurate value. @@ -463,9 +473,6 @@ func (worker *copIteratorWorker) run(ctx context.Context) { bo := NewBackoffer(ctx, copNextMaxBackoff).WithVars(worker.vars) worker.handleTask(bo, task, respCh) - if bo.totalSleep > 0 { - metrics.TiKVBackoffHistogram.Observe(float64(bo.totalSleep) / 1000) - } close(task.respChan) select { case <-worker.finishCh: @@ -499,6 +506,7 @@ func (it *copIterator) open(ctx context.Context) { wg: &it.wg, tasks: it.tasks, finishCh: it.finishCh, + sendRate: it.sendRate, } taskSender.respChan = it.respChan go taskSender.run() @@ -507,6 +515,16 @@ func (it *copIterator) open(ctx context.Context) { func (sender *copIteratorTaskSender) run() { // Send tasks to feed the worker goroutines. for _, t := range sender.tasks { + // If keepOrder, we must control the sending rate to prevent all tasks + // being done (aka. all of the responses are buffered) by copIteratorWorker. + // We keep the number of inflight tasks within the number of concurrency * 2. + // It sends one more task if a task has been finished in copIterator.Next. + if sender.sendRate != nil { + exit := sender.sendRate.getToken(sender.finishCh) + if exit { + break + } + } exit := sender.sendToTaskCh(t) if exit { break @@ -594,6 +612,7 @@ func (it *copIterator) Next(ctx context.Context) (kv.ResultSubset, error) { // Switch to next task. it.tasks[it.curr] = nil it.curr++ + it.sendRate.putToken() } } @@ -618,7 +637,7 @@ func (worker *copIteratorWorker) handleTask(bo *Backoffer, task *copTask, respCh zap.Stack("stack trace")) resp := &copResponse{err: errors.Errorf("%v", r)} // if panic has happened, set checkOOM to false to avoid another panic. - worker.sendToRespCh(resp, task.respChan, false) + worker.sendToRespCh(resp, respCh, false) } }() remainTasks := []*copTask{task} @@ -805,19 +824,22 @@ func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *RPCCon } else { resp.startKey = task.ranges.at(0).StartKey } - resp.BackoffTime = time.Duration(bo.totalSleep) * time.Millisecond + if resp.detail == nil { + resp.detail = new(execdetails.ExecDetails) + } + resp.detail.BackoffTime = time.Duration(bo.totalSleep) * time.Millisecond if rpcCtx != nil { - resp.CalleeAddress = rpcCtx.Addr + resp.detail.CalleeAddress = rpcCtx.Addr } if pbDetails := resp.pbResp.ExecDetails; pbDetails != nil { if handleTime := pbDetails.HandleTime; handleTime != nil { - resp.WaitTime = time.Duration(handleTime.WaitMs) * time.Millisecond - resp.ProcessTime = time.Duration(handleTime.ProcessMs) * time.Millisecond + resp.detail.WaitTime = time.Duration(handleTime.WaitMs) * time.Millisecond + resp.detail.ProcessTime = time.Duration(handleTime.ProcessMs) * time.Millisecond } if scanDetail := pbDetails.ScanDetail; scanDetail != nil { if scanDetail.Write != nil { - resp.TotalKeys += scanDetail.Write.Total - resp.ProcessedKeys += scanDetail.Write.Processed + resp.detail.TotalKeys += scanDetail.Write.Total + resp.detail.ProcessedKeys += scanDetail.Write.Processed } } } @@ -857,6 +879,33 @@ func (it *copIterator) Close() error { return nil } +type rateLimit struct { + token chan struct{} +} + +func newRateLimit(n int) *rateLimit { + return &rateLimit{ + token: make(chan struct{}, n), + } +} + +func (r *rateLimit) getToken(done <-chan struct{}) (exit bool) { + select { + case <-done: + return true + case r.token <- struct{}{}: + return false + } +} + +func (r *rateLimit) putToken() { + select { + case <-r.token: + default: + panic("put a redundant token") + } +} + // copErrorResponse returns error when calling Next() type copErrorResponse struct{ error } diff --git a/store/tikv/coprocessor_test.go b/store/tikv/coprocessor_test.go index 6daec69cf5c0c..404434894036f 100644 --- a/store/tikv/coprocessor_test.go +++ b/store/tikv/coprocessor_test.go @@ -15,6 +15,7 @@ package tikv import ( "context" + "time" . "github.com/pingcap/check" "github.com/pingcap/tidb/kv" @@ -294,6 +295,32 @@ func (s *testCoprocessorSuite) TestCopRangeSplit(c *C) { ) } +func (s *testCoprocessorSuite) TestRateLimit(c *C) { + done := make(chan struct{}, 1) + rl := newRateLimit(1) + c.Assert(rl.putToken, PanicMatches, "put a redundant token") + exit := rl.getToken(done) + c.Assert(exit, Equals, false) + rl.putToken() + c.Assert(rl.putToken, PanicMatches, "put a redundant token") + + exit = rl.getToken(done) + c.Assert(exit, Equals, false) + done <- struct{}{} + exit = rl.getToken(done) // blocked but exit + c.Assert(exit, Equals, true) + + sig := make(chan int, 1) + go func() { + exit = rl.getToken(done) // blocked + c.Assert(exit, Equals, false) + close(sig) + }() + time.Sleep(200 * time.Millisecond) + rl.putToken() + <-sig +} + type splitCase struct { key string *copRanges diff --git a/store/tikv/delete_range.go b/store/tikv/delete_range.go index 5d75fc5230a09..0154035f70211 100644 --- a/store/tikv/delete_range.go +++ b/store/tikv/delete_range.go @@ -16,9 +16,9 @@ package tikv import ( "bytes" "context" - "github.com/pingcap/errors" "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/tikv/tikvrpc" ) @@ -27,40 +27,74 @@ import ( // if the task was canceled or not. type DeleteRangeTask struct { completedRegions int - canceled bool store Storage - ctx context.Context startKey []byte endKey []byte + notifyOnly bool + concurrency int } -// NewDeleteRangeTask creates a DeleteRangeTask. Deleting will not be performed right away. -// WARNING: Currently, this API may leave some waste key-value pairs uncleaned in TiKV. Be careful while using it. -func NewDeleteRangeTask(ctx context.Context, store Storage, startKey []byte, endKey []byte) *DeleteRangeTask { +// NewDeleteRangeTask creates a DeleteRangeTask. Deleting will be performed when `Execute` method is invoked. +// Be careful while using this API. This API doesn't keep recent MVCC versions, but will delete all versions of all keys +// in the range immediately. Also notice that frequent invocation to this API may cause performance problems to TiKV. +func NewDeleteRangeTask(store Storage, startKey []byte, endKey []byte, concurrency int) *DeleteRangeTask { return &DeleteRangeTask{ completedRegions: 0, - canceled: false, store: store, - ctx: ctx, startKey: startKey, endKey: endKey, + notifyOnly: false, + concurrency: concurrency, + } +} + +// NewNotifyDeleteRangeTask creates a task that sends delete range requests to all regions in the range, but with the +// flag `notifyOnly` set. TiKV will not actually delete the range after receiving request, but it will be replicated via +// raft. This is used to notify the involved regions before sending UnsafeDestroyRange requests. +func NewNotifyDeleteRangeTask(store Storage, startKey []byte, endKey []byte, concurrency int) *DeleteRangeTask { + task := NewDeleteRangeTask(store, startKey, endKey, concurrency) + task.notifyOnly = true + return task +} + +// getRunnerName returns a name for RangeTaskRunner. +func (t *DeleteRangeTask) getRunnerName() string { + if t.notifyOnly { + return "delete-range-notify" } + return "delete-range" } // Execute performs the delete range operation. -func (t *DeleteRangeTask) Execute() error { - startKey, rangeEndKey := t.startKey, t.endKey +func (t *DeleteRangeTask) Execute(ctx context.Context) error { + runnerName := t.getRunnerName() + + runner := NewRangeTaskRunner(runnerName, t.store, t.concurrency, t.sendReqOnRange) + err := runner.RunOnRange(ctx, t.startKey, t.endKey) + t.completedRegions = int(runner.CompletedRegions()) + + return err +} + +// Execute performs the delete range operation. +func (t *DeleteRangeTask) sendReqOnRange(ctx context.Context, r kv.KeyRange) (int, error) { + startKey, rangeEndKey := r.StartKey, r.EndKey + completedRegions := 0 for { select { - case <-t.ctx.Done(): - t.canceled = true - return nil + case <-ctx.Done(): + return completedRegions, errors.Trace(ctx.Err()) default: } - bo := NewBackoffer(t.ctx, deleteRangeOneRegionMaxBackoff) + + if bytes.Compare(startKey, rangeEndKey) >= 0 { + break + } + + bo := NewBackoffer(ctx, deleteRangeOneRegionMaxBackoff) loc, err := t.store.GetRegionCache().LocateKey(bo, startKey) if err != nil { - return errors.Trace(err) + return completedRegions, errors.Trace(err) } // Delete to the end of the region, except if it's the last region overlapping the range @@ -73,49 +107,42 @@ func (t *DeleteRangeTask) Execute() error { req := &tikvrpc.Request{ Type: tikvrpc.CmdDeleteRange, DeleteRange: &kvrpcpb.DeleteRangeRequest{ - StartKey: startKey, - EndKey: endKey, + StartKey: startKey, + EndKey: endKey, + NotifyOnly: t.notifyOnly, }, } resp, err := t.store.SendReq(bo, req, loc.Region, ReadTimeoutMedium) if err != nil { - return errors.Trace(err) + return completedRegions, errors.Trace(err) } regionErr, err := resp.GetRegionError() if err != nil { - return errors.Trace(err) + return completedRegions, errors.Trace(err) } if regionErr != nil { err = bo.Backoff(BoRegionMiss, errors.New(regionErr.String())) if err != nil { - return errors.Trace(err) + return completedRegions, errors.Trace(err) } continue } deleteRangeResp := resp.DeleteRange if deleteRangeResp == nil { - return errors.Trace(ErrBodyMissing) + return completedRegions, errors.Trace(ErrBodyMissing) } if err := deleteRangeResp.GetError(); err != "" { - return errors.Errorf("unexpected delete range err: %v", err) - } - t.completedRegions++ - if bytes.Equal(endKey, rangeEndKey) { - break + return completedRegions, errors.Errorf("unexpected delete range err: %v", err) } + completedRegions++ startKey = endKey } - return nil + return completedRegions, nil } // CompletedRegions returns the number of regions that are affected by this delete range task func (t *DeleteRangeTask) CompletedRegions() int { return t.completedRegions } - -// IsCanceled returns true if the delete range operation was canceled on the half way -func (t *DeleteRangeTask) IsCanceled() bool { - return t.canceled -} diff --git a/store/tikv/delete_range_test.go b/store/tikv/delete_range_test.go index 1feb88c7eaefc..cbb9206917b9e 100644 --- a/store/tikv/delete_range_test.go +++ b/store/tikv/delete_range_test.go @@ -33,7 +33,7 @@ var _ = Suite(&testDeleteRangeSuite{}) func (s *testDeleteRangeSuite) SetUpTest(c *C) { s.cluster = mocktikv.NewCluster() - mocktikv.BootstrapWithMultiRegions(s.cluster, []byte("a"), []byte("b"), []byte("c")) + mocktikv.BootstrapWithMultiRegions(s.cluster, []byte("b"), []byte("c"), []byte("d")) client, pdClient, err := mocktikv.NewTiKVAndPDClient(s.cluster, nil, "") c.Assert(err, IsNil) @@ -81,12 +81,13 @@ func (s *testDeleteRangeSuite) checkData(c *C, expectedData map[string]string) { c.Assert(data, DeepEquals, expectedData) } -func (s *testDeleteRangeSuite) deleteRange(c *C, startKey []byte, endKey []byte) { - ctx := context.Background() - task := NewDeleteRangeTask(ctx, s.store, startKey, endKey) +func (s *testDeleteRangeSuite) deleteRange(c *C, startKey []byte, endKey []byte) int { + task := NewDeleteRangeTask(s.store, startKey, endKey, 1) - err := task.Execute() + err := task.Execute(context.Background()) c.Assert(err, IsNil) + + return task.CompletedRegions() } // deleteRangeFromMap deletes all keys in a given range from a map @@ -100,10 +101,11 @@ func deleteRangeFromMap(m map[string]string, startKey []byte, endKey []byte) { } // mustDeleteRange does delete range on both the map and the storage, and assert they are equal after deleting -func (s *testDeleteRangeSuite) mustDeleteRange(c *C, startKey []byte, endKey []byte, expected map[string]string) { - s.deleteRange(c, startKey, endKey) +func (s *testDeleteRangeSuite) mustDeleteRange(c *C, startKey []byte, endKey []byte, expected map[string]string, regions int) { + completedRegions := s.deleteRange(c, startKey, endKey) deleteRangeFromMap(expected, startKey, endKey) s.checkData(c, expected) + c.Assert(completedRegions, Equals, regions) } func (s *testDeleteRangeSuite) TestDeleteRange(c *C) { @@ -119,7 +121,8 @@ func (s *testDeleteRangeSuite) TestDeleteRange(c *C) { key := []byte{byte(i), byte(j)} value := []byte{byte(rand.Intn(256)), byte(rand.Intn(256))} testData[string(key)] = string(value) - txn.Set(key, value) + err := txn.Set(key, value) + c.Assert(err, IsNil) } } @@ -128,10 +131,10 @@ func (s *testDeleteRangeSuite) TestDeleteRange(c *C) { s.checkData(c, testData) - s.mustDeleteRange(c, []byte("b"), []byte("c0"), testData) - s.mustDeleteRange(c, []byte("c11"), []byte("c12"), testData) - s.mustDeleteRange(c, []byte("d0"), []byte("d0"), testData) - s.mustDeleteRange(c, []byte("d0\x00"), []byte("d1\x00"), testData) - s.mustDeleteRange(c, []byte("c5"), []byte("d5"), testData) - s.mustDeleteRange(c, []byte("a"), []byte("z"), testData) + s.mustDeleteRange(c, []byte("b"), []byte("c0"), testData, 2) + s.mustDeleteRange(c, []byte("c11"), []byte("c12"), testData, 1) + s.mustDeleteRange(c, []byte("d0"), []byte("d0"), testData, 0) + s.mustDeleteRange(c, []byte("d0\x00"), []byte("d1\x00"), testData, 1) + s.mustDeleteRange(c, []byte("c5"), []byte("d5"), testData, 2) + s.mustDeleteRange(c, []byte("a"), []byte("z"), testData, 4) } diff --git a/store/tikv/error.go b/store/tikv/error.go index 357aa0f76cb07..574e460454912 100644 --- a/store/tikv/error.go +++ b/store/tikv/error.go @@ -15,6 +15,7 @@ package tikv import ( "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" ) @@ -37,8 +38,20 @@ var ( ErrRegionUnavailable = terror.ClassTiKV.New(mysql.ErrRegionUnavailable, mysql.MySQLErrName[mysql.ErrRegionUnavailable]) ErrTiKVServerBusy = terror.ClassTiKV.New(mysql.ErrTiKVServerBusy, mysql.MySQLErrName[mysql.ErrTiKVServerBusy]) ErrGCTooEarly = terror.ClassTiKV.New(mysql.ErrGCTooEarly, mysql.MySQLErrName[mysql.ErrGCTooEarly]) + ErrQueryInterrupted = terror.ClassTiKV.New(mysql.ErrQueryInterrupted, mysql.MySQLErrName[mysql.ErrQueryInterrupted]) ) +// ErrDeadlock wraps *kvrpcpb.Deadlock to implement the error interface. +// It also marks if the deadlock is retryable. +type ErrDeadlock struct { + *kvrpcpb.Deadlock + IsRetryable bool +} + +func (d *ErrDeadlock) Error() string { + return d.Deadlock.String() +} + func init() { tikvMySQLErrCodes := map[terror.ErrCode]uint16{ mysql.ErrTiKVServerTimeout: mysql.ErrTiKVServerTimeout, @@ -48,6 +61,7 @@ func init() { mysql.ErrTiKVServerBusy: mysql.ErrTiKVServerBusy, mysql.ErrGCTooEarly: mysql.ErrGCTooEarly, mysql.ErrTruncatedWrongValue: mysql.ErrTruncatedWrongValue, + mysql.ErrQueryInterrupted: mysql.ErrQueryInterrupted, } terror.ErrClassToMySQLCodes[terror.ClassTiKV] = tikvMySQLErrCodes } diff --git a/store/tikv/gcworker/gc_worker.go b/store/tikv/gcworker/gc_worker.go index 5fa52281438be..24050a6aa6e29 100644 --- a/store/tikv/gcworker/gc_worker.go +++ b/store/tikv/gcworker/gc_worker.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/parser/terror" "github.com/pingcap/pd/client" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl/util" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" @@ -96,6 +97,9 @@ func (w *GCWorker) Close() { } const ( + booleanTrue = "true" + booleanFalse = "false" + gcWorkerTickInterval = time.Minute gcJobLogTickInterval = time.Minute * 10 gcWorkerLease = time.Minute * 2 @@ -111,6 +115,7 @@ const ( gcLifeTimeKey = "tikv_gc_life_time" gcDefaultLifeTime = time.Minute * 10 + gcMinLifeTime = time.Minute * 10 gcSafePointKey = "tikv_gc_safe_point" gcConcurrencyKey = "tikv_gc_concurrency" gcDefaultConcurrency = 2 @@ -120,29 +125,31 @@ const ( gcScanLockLimit = tikv.ResolvedCacheSize / 2 gcEnableKey = "tikv_gc_enable" - gcEnableValue = "true" - gcDisableValue = "false" gcDefaultEnableValue = true gcModeKey = "tikv_gc_mode" gcModeCentral = "central" gcModeDistributed = "distributed" gcModeDefault = gcModeDistributed + + gcAutoConcurrencyKey = "tikv_gc_auto_concurrency" + gcDefaultAutoConcurrency = true ) var gcSafePointCacheInterval = tikv.GcSafePointCacheInterval var gcVariableComments = map[string]string{ - gcLeaderUUIDKey: "Current GC worker leader UUID. (DO NOT EDIT)", - gcLeaderDescKey: "Host name and pid of current GC leader. (DO NOT EDIT)", - gcLeaderLeaseKey: "Current GC worker leader lease. (DO NOT EDIT)", - gcLastRunTimeKey: "The time when last GC starts. (DO NOT EDIT)", - gcRunIntervalKey: "GC run interval, at least 10m, in Go format.", - gcLifeTimeKey: "All versions within life time will not be collected by GC, at least 10m, in Go format.", - gcSafePointKey: "All versions after safe point can be accessed. (DO NOT EDIT)", - gcConcurrencyKey: "[DEPRECATED] How many goroutines used to do GC parallel, [1, 128], default 2", - gcEnableKey: "Current GC enable status", - gcModeKey: "Mode of GC, \"central\" or \"distributed\"", + gcLeaderUUIDKey: "Current GC worker leader UUID. (DO NOT EDIT)", + gcLeaderDescKey: "Host name and pid of current GC leader. (DO NOT EDIT)", + gcLeaderLeaseKey: "Current GC worker leader lease. (DO NOT EDIT)", + gcLastRunTimeKey: "The time when last GC starts. (DO NOT EDIT)", + gcRunIntervalKey: "GC run interval, at least 10m, in Go format.", + gcLifeTimeKey: "All versions within life time will not be collected by GC, at least 10m, in Go format.", + gcSafePointKey: "All versions after safe point can be accessed. (DO NOT EDIT)", + gcConcurrencyKey: "How many goroutines used to do GC parallel, [1, 128], default 2", + gcEnableKey: "Current GC enable status", + gcModeKey: "Mode of GC, \"central\" or \"distributed\"", + gcAutoConcurrencyKey: "Let TiDB pick the concurrency automatically. If set false, tikv_gc_concurrency will be used", } func (w *GCWorker) start(ctx context.Context, wg *sync.WaitGroup) { @@ -258,26 +265,12 @@ func (w *GCWorker) leaderTick(ctx context.Context) error { return nil } - stores, err := w.getUpStores(ctx) - concurrency := len(stores) + concurrency, err := w.getGCConcurrency(ctx) if err != nil { - logutil.Logger(ctx).Error("[gc worker] failed to get up stores to calculate concurrency.", + logutil.Logger(ctx).Info("[gc worker] failed to get gc concurrency.", zap.String("uuid", w.uuid), zap.Error(err)) - - concurrency, err = w.loadGCConcurrencyWithDefault() - if err != nil { - logutil.Logger(ctx).Error("[gc worker] failed to load gc concurrency. use default value.", - zap.String("uuid", w.uuid), - zap.Error(err)) - concurrency = gcDefaultConcurrency - } - } - - if concurrency == 0 { - logutil.Logger(ctx).Error("[gc worker] no store is up", - zap.String("uuid", w.uuid)) - return errors.New("[gc worker] no store is up") + return errors.Trace(err) } w.gcIsRunning = true @@ -359,19 +352,68 @@ func (w *GCWorker) getOracleTime() (time.Time, error) { } func (w *GCWorker) checkGCEnable() (bool, error) { - str, err := w.loadValueFromSysTable(gcEnableKey) + return w.loadBooleanWithDefault(gcEnableKey, gcDefaultEnableValue) +} + +func (w *GCWorker) checkUseAutoConcurrency() (bool, error) { + return w.loadBooleanWithDefault(gcAutoConcurrencyKey, gcDefaultAutoConcurrency) +} + +func (w *GCWorker) loadBooleanWithDefault(key string, defaultValue bool) (bool, error) { + str, err := w.loadValueFromSysTable(key) if err != nil { return false, errors.Trace(err) } if str == "" { // Save default value for gc enable key. The default value is always true. - err = w.saveValueToSysTable(gcEnableKey, gcEnableValue) + defaultValueStr := booleanFalse + if defaultValue { + defaultValueStr = booleanTrue + } + err = w.saveValueToSysTable(key, defaultValueStr) if err != nil { - return gcDefaultEnableValue, errors.Trace(err) + return defaultValue, errors.Trace(err) } - return gcDefaultEnableValue, nil + return defaultValue, nil } - return strings.EqualFold(str, gcEnableValue), nil + return strings.EqualFold(str, booleanTrue), nil +} + +func (w *GCWorker) getGCConcurrency(ctx context.Context) (int, error) { + useAutoConcurrency, err := w.checkUseAutoConcurrency() + if err != nil { + logutil.Logger(ctx).Error("[gc worker] failed to load config gc_auto_concurrency. use default value.", + zap.String("uuid", w.uuid), + zap.Error(err)) + useAutoConcurrency = gcDefaultAutoConcurrency + } + if !useAutoConcurrency { + return w.loadGCConcurrencyWithDefault() + } + + stores, err := w.getUpStores(ctx) + concurrency := len(stores) + if err != nil { + logutil.Logger(ctx).Error("[gc worker] failed to get up stores to calculate concurrency. use config.", + zap.String("uuid", w.uuid), + zap.Error(err)) + + concurrency, err = w.loadGCConcurrencyWithDefault() + if err != nil { + logutil.Logger(ctx).Error("[gc worker] failed to load gc concurrency from config. use default value.", + zap.String("uuid", w.uuid), + zap.Error(err)) + concurrency = gcDefaultConcurrency + } + } + + if concurrency == 0 { + logutil.Logger(ctx).Error("[gc worker] no store is up", + zap.String("uuid", w.uuid)) + return 0, errors.New("[gc worker] no store is up") + } + + return concurrency, nil } func (w *GCWorker) checkGCInterval(now time.Time) (bool, error) { @@ -396,11 +438,36 @@ func (w *GCWorker) checkGCInterval(now time.Time) (bool, error) { return true, nil } +// validateGCLiftTime checks whether life time is small than min gc life time. +func (w *GCWorker) validateGCLiftTime(lifeTime time.Duration) (time.Duration, error) { + minLifeTime := gcMinLifeTime + // max-txn-time-use value is less than gc_life_time - 10s. + maxTxnTime := time.Duration(config.GetGlobalConfig().TiKVClient.MaxTxnTimeUse+10) * time.Second + if minLifeTime < maxTxnTime { + minLifeTime = maxTxnTime + } + + if lifeTime >= minLifeTime { + return lifeTime, nil + } + + logutil.Logger(context.Background()).Info("[gc worker] invalid gc life time", + zap.Duration("get gc life time", lifeTime), + zap.Duration("min gc life time", minLifeTime)) + + err := w.saveDuration(gcLifeTimeKey, minLifeTime) + return minLifeTime, err +} + func (w *GCWorker) calculateNewSafePoint(now time.Time) (*time.Time, error) { lifeTime, err := w.loadDurationWithDefault(gcLifeTimeKey, gcDefaultLifeTime) if err != nil { return nil, errors.Trace(err) } + *lifeTime, err = w.validateGCLiftTime(*lifeTime) + if err != nil { + return nil, err + } metrics.GCConfigGauge.WithLabelValues(gcLifeTimeKey).Set(lifeTime.Seconds()) lastSafePoint, err := w.loadTime(gcSafePointKey) if err != nil { @@ -431,7 +498,21 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency i w.done <- errors.Trace(err) return } - err = w.deleteRanges(ctx, safePoint) + // Save safe point to pd. + err = w.saveSafePoint(w.store.GetSafePointKV(), tikv.GcSavedSafePoint, safePoint) + if err != nil { + logutil.Logger(ctx).Error("[gc worker] failed to save safe point to PD", + zap.String("uuid", w.uuid), + zap.Error(err)) + w.gcIsRunning = false + metrics.GCJobFailureCounter.WithLabelValues("save_safe_point").Inc() + w.done <- errors.Trace(err) + return + } + // Sleep to wait for all other tidb instances update their safepoint cache. + time.Sleep(gcSafePointCacheInterval) + + err = w.deleteRanges(ctx, safePoint, concurrency) if err != nil { logutil.Logger(ctx).Error("[gc worker] delete range returns an error", zap.String("uuid", w.uuid), @@ -440,7 +521,7 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency i w.done <- errors.Trace(err) return } - err = w.redoDeleteRanges(ctx, safePoint) + err = w.redoDeleteRanges(ctx, safePoint, concurrency) if err != nil { logutil.Logger(ctx).Error("[gc worker] redo-delete range returns an error", zap.String("uuid", w.uuid), @@ -487,7 +568,8 @@ func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency i } // deleteRanges processes all delete range records whose ts < safePoint in table `gc_delete_range` -func (w *GCWorker) deleteRanges(ctx context.Context, safePoint uint64) error { +// `concurrency` specifies the concurrency to send NotifyDeleteRange. +func (w *GCWorker) deleteRanges(ctx context.Context, safePoint uint64, concurrency int) error { metrics.GCWorkerCounter.WithLabelValues("delete_range").Inc() se := createSession(w.store) @@ -504,7 +586,7 @@ func (w *GCWorker) deleteRanges(ctx context.Context, safePoint uint64) error { for _, r := range ranges { startKey, endKey := r.Range() - err = w.sendUnsafeDestroyRangeRequest(ctx, startKey, endKey) + err = w.doUnsafeDestroyRangeRequest(ctx, startKey, endKey, concurrency) if err != nil { return errors.Trace(err) } @@ -525,7 +607,8 @@ func (w *GCWorker) deleteRanges(ctx context.Context, safePoint uint64) error { } // redoDeleteRanges checks all deleted ranges whose ts is at least `lifetime + 24h` ago. See TiKV RFC #2. -func (w *GCWorker) redoDeleteRanges(ctx context.Context, safePoint uint64) error { +// `concurrency` specifies the concurrency to send NotifyDeleteRange. +func (w *GCWorker) redoDeleteRanges(ctx context.Context, safePoint uint64, concurrency int) error { metrics.GCWorkerCounter.WithLabelValues("redo_delete_range").Inc() // We check delete range records that are deleted about 24 hours ago. @@ -545,7 +628,7 @@ func (w *GCWorker) redoDeleteRanges(ctx context.Context, safePoint uint64) error for _, r := range ranges { startKey, endKey := r.Range() - err = w.sendUnsafeDestroyRangeRequest(ctx, startKey, endKey) + err = w.doUnsafeDestroyRangeRequest(ctx, startKey, endKey, concurrency) if err != nil { return errors.Trace(err) } @@ -565,7 +648,7 @@ func (w *GCWorker) redoDeleteRanges(ctx context.Context, safePoint uint64) error return nil } -func (w *GCWorker) sendUnsafeDestroyRangeRequest(ctx context.Context, startKey []byte, endKey []byte) error { +func (w *GCWorker) doUnsafeDestroyRangeRequest(ctx context.Context, startKey []byte, endKey []byte, concurrency int) error { // Get all stores every time deleting a region. So the store list is less probably to be stale. stores, err := w.getUpStores(ctx) if err != nil { @@ -604,6 +687,17 @@ func (w *GCWorker) sendUnsafeDestroyRangeRequest(ctx context.Context, startKey [ wg.Wait() + // Notify all affected regions in the range that UnsafeDestroyRange occurs. + notifyTask := tikv.NewNotifyDeleteRangeTask(w.store, startKey, endKey, concurrency) + err = notifyTask.Execute(ctx) + if err != nil { + logutil.Logger(ctx).Error("[gc worker] failed notifying regions affected by UnsafeDestroyRange", + zap.String("uuid", w.uuid), + zap.Binary("startKey", startKey), + zap.Binary("endKey", endKey), + zap.Error(err)) + } + return errors.Trace(err) } @@ -969,15 +1063,6 @@ func (w *GCWorker) genNextGCTask(bo *tikv.Backoffer, safePoint uint64, key kv.Ke func (w *GCWorker) doGC(ctx context.Context, safePoint uint64, concurrency int) error { metrics.GCWorkerCounter.WithLabelValues("do_gc").Inc() - - err := w.saveSafePoint(w.store.GetSafePointKV(), tikv.GcSavedSafePoint, safePoint) - if err != nil { - return errors.Trace(err) - } - - // Sleep to wait for all other tidb instances update their safepoint cache. - time.Sleep(gcSafePointCacheInterval) - logutil.Logger(ctx).Info("[gc worker]", zap.String("uuid", w.uuid), zap.Int("concurrency", concurrency), @@ -1193,7 +1278,7 @@ func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { if err != nil { return "", errors.Trace(err) } - req := rs[0].NewRecordBatch() + req := rs[0].NewChunk() err = rs[0].Next(ctx, req) if err != nil { return "", errors.Trace(err) @@ -1241,6 +1326,14 @@ func RunGCJob(ctx context.Context, s tikv.Storage, safePoint uint64, identifier if concurrency <= 0 { return errors.Errorf("[gc worker] gc concurrency should greater than 0, current concurrency: %v", concurrency) } + + err = gcWorker.saveSafePoint(gcWorker.store.GetSafePointKV(), tikv.GcSavedSafePoint, safePoint) + if err != nil { + return errors.Trace(err) + } + // Sleep to wait for all other tidb instances update their safepoint cache. + time.Sleep(gcSafePointCacheInterval) + err = gcWorker.doGC(ctx, safePoint, concurrency) if err != nil { return errors.Trace(err) @@ -1270,6 +1363,14 @@ func RunDistributedGCJob( return errors.Trace(err) } + // Save safe point to pd. + err = gcWorker.saveSafePoint(gcWorker.store.GetSafePointKV(), tikv.GcSavedSafePoint, safePoint) + if err != nil { + return errors.Trace(err) + } + // Sleep to wait for all other tidb instances update their safepoint cache. + time.Sleep(gcSafePointCacheInterval) + err = gcWorker.uploadSafePointToPD(ctx, safePoint) if err != nil { return errors.Trace(err) @@ -1313,5 +1414,5 @@ func NewMockGCWorker(store tikv.Storage) (*MockGCWorker, error) { // DeleteRanges calls deleteRanges internally, just for test. func (w *MockGCWorker) DeleteRanges(ctx context.Context, safePoint uint64) error { logutil.Logger(ctx).Error("deleteRanges is called") - return w.worker.deleteRanges(ctx, safePoint) + return w.worker.deleteRanges(ctx, safePoint, 1) } diff --git a/store/tikv/gcworker/gc_worker_test.go b/store/tikv/gcworker/gc_worker_test.go index 99e5de1d83d54..e0c3664bc45a3 100644 --- a/store/tikv/gcworker/gc_worker_test.go +++ b/store/tikv/gcworker/gc_worker_test.go @@ -23,6 +23,8 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/errorpb" + pd "github.com/pingcap/pd/client" + "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/store/mockoracle" @@ -37,9 +39,11 @@ func TestT(t *testing.T) { type testGCWorkerSuite struct { store tikv.Storage + cluster *mocktikv.Cluster oracle *mockoracle.MockOracle gcWorker *GCWorker dom *domain.Domain + pdClient pd.Client } var _ = Suite(&testGCWorkerSuite{}) @@ -47,9 +51,9 @@ var _ = Suite(&testGCWorkerSuite{}) func (s *testGCWorkerSuite) SetUpTest(c *C) { tikv.NewGCHandlerFunc = NewGCWorker - cluster := mocktikv.NewCluster() - mocktikv.BootstrapWithSingleStore(cluster) - store, err := mockstore.NewMockTikvStore(mockstore.WithCluster(cluster)) + s.cluster = mocktikv.NewCluster() + mocktikv.BootstrapWithSingleStore(s.cluster) + store, err := mockstore.NewMockTikvStore(mockstore.WithCluster(s.cluster)) s.store = store.(tikv.Storage) c.Assert(err, IsNil) @@ -58,7 +62,8 @@ func (s *testGCWorkerSuite) SetUpTest(c *C) { s.dom, err = session.BootstrapSession(s.store) c.Assert(err, IsNil) - gcWorker, err := NewGCWorker(s.store, mocktikv.NewPDClient(cluster)) + s.pdClient = mocktikv.NewPDClient(s.cluster) + gcWorker, err := NewGCWorker(s.store, s.pdClient) c.Assert(err, IsNil) gcWorker.Start() gcWorker.Close() @@ -154,16 +159,62 @@ func (s *testGCWorkerSuite) TestPrepareGC(c *C) { // Change GC enable status. s.oracle.AddOffset(time.Minute * 40) - err = s.gcWorker.saveValueToSysTable(gcEnableKey, gcDisableValue) + err = s.gcWorker.saveValueToSysTable(gcEnableKey, booleanFalse) c.Assert(err, IsNil) ok, _, err = s.gcWorker.prepare() c.Assert(err, IsNil) c.Assert(ok, IsFalse) - err = s.gcWorker.saveValueToSysTable(gcEnableKey, gcEnableValue) + err = s.gcWorker.saveValueToSysTable(gcEnableKey, booleanTrue) c.Assert(err, IsNil) ok, _, err = s.gcWorker.prepare() c.Assert(err, IsNil) c.Assert(ok, IsTrue) + + // Check gc life time small than min. + s.oracle.AddOffset(time.Minute * 40) + err = s.gcWorker.saveDuration(gcLifeTimeKey, time.Minute) + c.Assert(err, IsNil) + ok, _, err = s.gcWorker.prepare() + c.Assert(err, IsNil) + c.Assert(ok, IsTrue) + lifeTime, err := s.gcWorker.loadDuration(gcLifeTimeKey) + c.Assert(err, IsNil) + c.Assert(*lifeTime, Equals, gcMinLifeTime) + + // Check gc life time small than config.max-txn-use-time + s.oracle.AddOffset(time.Minute * 40) + config.GetGlobalConfig().TiKVClient.MaxTxnTimeUse = 20*60 - 10 // 20min - 10s + err = s.gcWorker.saveDuration(gcLifeTimeKey, time.Minute) + c.Assert(err, IsNil) + ok, _, err = s.gcWorker.prepare() + c.Assert(err, IsNil) + c.Assert(ok, IsTrue) + lifeTime, err = s.gcWorker.loadDuration(gcLifeTimeKey) + c.Assert(err, IsNil) + c.Assert(*lifeTime, Equals, 20*time.Minute) + + // check the tikv_gc_life_time more than config.max-txn-use-time situation. + s.oracle.AddOffset(time.Minute * 40) + err = s.gcWorker.saveDuration(gcLifeTimeKey, time.Minute*30) + c.Assert(err, IsNil) + ok, _, err = s.gcWorker.prepare() + c.Assert(err, IsNil) + c.Assert(ok, IsTrue) + lifeTime, err = s.gcWorker.loadDuration(gcLifeTimeKey) + c.Assert(err, IsNil) + c.Assert(*lifeTime, Equals, 30*time.Minute) + + // Change auto concurrency + err = s.gcWorker.saveValueToSysTable(gcAutoConcurrencyKey, booleanFalse) + c.Assert(err, IsNil) + useAutoConcurrency, err := s.gcWorker.checkUseAutoConcurrency() + c.Assert(err, IsNil) + c.Assert(useAutoConcurrency, IsFalse) + err = s.gcWorker.saveValueToSysTable(gcAutoConcurrencyKey, booleanTrue) + c.Assert(err, IsNil) + useAutoConcurrency, err = s.gcWorker.checkUseAutoConcurrency() + c.Assert(err, IsNil) + c.Assert(useAutoConcurrency, IsTrue) } func (s *testGCWorkerSuite) TestDoGCForOneRegion(c *C) { @@ -199,6 +250,28 @@ func (s *testGCWorkerSuite) TestDoGCForOneRegion(c *C) { c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/tikvStoreSendReqResult"), IsNil) } +func (s *testGCWorkerSuite) TestGetGCConcurrency(c *C) { + // Pick a concurrency that doesn't equal to the number of stores. + concurrencyConfig := 25 + c.Assert(concurrencyConfig, Not(Equals), len(s.cluster.GetAllStores())) + err := s.gcWorker.saveValueToSysTable(gcConcurrencyKey, strconv.Itoa(concurrencyConfig)) + c.Assert(err, IsNil) + + ctx := context.Background() + + err = s.gcWorker.saveValueToSysTable(gcAutoConcurrencyKey, booleanFalse) + c.Assert(err, IsNil) + concurrency, err := s.gcWorker.getGCConcurrency(ctx) + c.Assert(err, IsNil) + c.Assert(concurrency, Equals, concurrencyConfig) + + err = s.gcWorker.saveValueToSysTable(gcAutoConcurrencyKey, booleanTrue) + c.Assert(err, IsNil) + concurrency, err = s.gcWorker.getGCConcurrency(ctx) + c.Assert(err, IsNil) + c.Assert(concurrency, Equals, len(s.cluster.GetAllStores())) +} + func (s *testGCWorkerSuite) TestDoGC(c *C) { var err error ctx := context.Background() @@ -236,18 +309,47 @@ func (s *testGCWorkerSuite) TestCheckGCMode(c *C) { c.Assert(err, IsNil) c.Assert(str, Equals, gcModeDistributed) - s.gcWorker.saveValueToSysTable(gcModeKey, gcModeCentral) + err = s.gcWorker.saveValueToSysTable(gcModeKey, gcModeCentral) + c.Assert(err, IsNil) useDistributedGC, err = s.gcWorker.checkUseDistributedGC() c.Assert(err, IsNil) c.Assert(useDistributedGC, Equals, false) - s.gcWorker.saveValueToSysTable(gcModeKey, gcModeDistributed) + err = s.gcWorker.saveValueToSysTable(gcModeKey, gcModeDistributed) + c.Assert(err, IsNil) useDistributedGC, err = s.gcWorker.checkUseDistributedGC() c.Assert(err, IsNil) c.Assert(useDistributedGC, Equals, true) - s.gcWorker.saveValueToSysTable(gcModeKey, "invalid_mode") + err = s.gcWorker.saveValueToSysTable(gcModeKey, "invalid_mode") + c.Assert(err, IsNil) useDistributedGC, err = s.gcWorker.checkUseDistributedGC() c.Assert(err, IsNil) c.Assert(useDistributedGC, Equals, true) } + +func (s *testGCWorkerSuite) TestRunGCJob(c *C) { + gcSafePointCacheInterval = 0 + err := RunGCJob(context.Background(), s.store, 0, "mock", 1) + c.Assert(err, IsNil) + gcWorker, err := NewGCWorker(s.store, s.pdClient) + c.Assert(err, IsNil) + gcWorker.Start() + useDistributedGC, err := gcWorker.(*GCWorker).checkUseDistributedGC() + c.Assert(useDistributedGC, IsTrue) + c.Assert(err, IsNil) + safePoint := uint64(time.Now().Unix()) + gcWorker.(*GCWorker).runGCJob(context.Background(), safePoint, 1) + getSafePoint, err := loadSafePoint(gcWorker.(*GCWorker).store.GetSafePointKV()) + c.Assert(err, IsNil) + c.Assert(getSafePoint, Equals, safePoint) + gcWorker.Close() +} + +func loadSafePoint(kv tikv.SafePointKV) (uint64, error) { + val, err := kv.Get(tikv.GcSavedSafePoint) + if err != nil { + return 0, err + } + return strconv.ParseUint(val, 10, 64) +} diff --git a/store/tikv/lock_resolver.go b/store/tikv/lock_resolver.go index 59c4fc97bb23d..935060ff72d88 100644 --- a/store/tikv/lock_resolver.go +++ b/store/tikv/lock_resolver.go @@ -33,6 +33,9 @@ import ( // ResolvedCacheSize is max number of cached txn status. const ResolvedCacheSize = 2048 +// bigTxnThreshold : transaction involves keys exceed this threshold can be treated as `big transaction`. +const bigTxnThreshold = 16 + var ( tikvLockResolverCountWithBatchResolve = metrics.TiKVLockResolverCounter.WithLabelValues("batch_resolve") tikvLockResolverCountWithExpired = metrics.TiKVLockResolverCounter.WithLabelValues("expired") @@ -43,6 +46,7 @@ var ( tikvLockResolverCountWithQueryTxnStatusCommitted = metrics.TiKVLockResolverCounter.WithLabelValues("query_txn_status_committed") tikvLockResolverCountWithQueryTxnStatusRolledBack = metrics.TiKVLockResolverCounter.WithLabelValues("query_txn_status_rolled_back") tikvLockResolverCountWithResolveLocks = metrics.TiKVLockResolverCounter.WithLabelValues("query_resolve_locks") + tikvLockResolverCountWithResolveLockLite = metrics.TiKVLockResolverCounter.WithLabelValues("query_resolve_lock_lite") ) // LockResolver resolves locks and also caches resolved txn status. @@ -99,14 +103,17 @@ func NewLockResolver(etcdAddrs []string, security config.Security) (*LockResolve return s.lockResolver, nil } -// TxnStatus represents a txn's final status. It should be Commit or Rollback. -type TxnStatus uint64 +// TxnStatus represents a txn's final status. It should be Lock or Commit or Rollback. +type TxnStatus struct { + ttl uint64 + commitTS uint64 +} // IsCommitted returns true if the txn's final status is Commit. -func (s TxnStatus) IsCommitted() bool { return s > 0 } +func (s TxnStatus) IsCommitted() bool { return s.ttl == 0 && s.commitTS > 0 } // CommitTS returns the txn's commitTS. It is valid iff `IsCommitted` is true. -func (s TxnStatus) CommitTS() uint64 { return uint64(s) } +func (s TxnStatus) CommitTS() uint64 { return uint64(s.commitTS) } // By default, locks after 3000ms is considered unusual (the client created the // lock might be dead). Other client may cleanup this kind of lock. @@ -125,6 +132,7 @@ type Lock struct { Primary []byte TxnID uint64 TTL uint64 + TxnSize uint64 } func (l *Lock) String() string { @@ -133,15 +141,12 @@ func (l *Lock) String() string { // NewLock creates a new *Lock. func NewLock(l *kvrpcpb.LockInfo) *Lock { - ttl := l.GetLockTtl() - if ttl == 0 { - ttl = defaultLockTTL - } return &Lock{ Key: l.GetKey(), Primary: l.GetPrimaryLock(), TxnID: l.GetLockVersion(), - TTL: ttl, + TTL: l.GetLockTtl(), + TxnSize: l.GetTxnSize(), } } @@ -169,7 +174,8 @@ func (lr *LockResolver) getResolved(txnID uint64) (TxnStatus, bool) { return s, ok } -// BatchResolveLocks resolve locks in a batch +// BatchResolveLocks resolve locks in a batch. +// Used it in gcworker only! func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc RegionVerID) (bool, error) { if len(locks) == 0 { return true, nil @@ -177,7 +183,7 @@ func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc Regi tikvLockResolverCountWithBatchResolve.Inc() - var expiredLocks []*Lock + expiredLocks := make([]*Lock, 0, len(locks)) for _, l := range locks { if lr.store.GetOracle().IsExpired(l.TxnID, l.TTL) { tikvLockResolverCountWithExpired.Inc() @@ -200,11 +206,11 @@ func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc Regi continue } - status, err := lr.getTxnStatus(bo, l.TxnID, l.Primary) + status, err := lr.getTxnStatus(bo, l.TxnID, l.Primary, 0) if err != nil { return false, errors.Trace(err) } - txnInfos[l.TxnID] = uint64(status) + txnInfos[l.TxnID] = uint64(status.commitTS) } logutil.Logger(context.Background()).Info("BatchResolveLocks: lookup txn status", zap.Duration("cost time", time.Since(startTime)), @@ -266,9 +272,10 @@ func (lr *LockResolver) BatchResolveLocks(bo *Backoffer, locks []*Lock, loc Regi // commit status. // 3) Send `ResolveLock` cmd to the lock's region to resolve all locks belong to // the same transaction. -func (lr *LockResolver) ResolveLocks(bo *Backoffer, locks []*Lock) (msBeforeTxnExpired int64, err error) { +func (lr *LockResolver) ResolveLocks(bo *Backoffer, locks []*Lock) (int64, error) { + var msBeforeTxnExpired txnExpireTime if len(locks) == 0 { - return + return msBeforeTxnExpired.value(), nil } tikvLockResolverCountWithResolve.Inc() @@ -277,61 +284,111 @@ func (lr *LockResolver) ResolveLocks(bo *Backoffer, locks []*Lock) (msBeforeTxnE for _, l := range locks { msBeforeLockExpired := lr.store.GetOracle().UntilExpired(l.TxnID, l.TTL) if msBeforeLockExpired <= 0 { - tikvLockResolverCountWithExpired.Inc() expiredLocks = append(expiredLocks, l) } else { - if msBeforeTxnExpired == 0 || msBeforeLockExpired < msBeforeTxnExpired { - msBeforeTxnExpired = msBeforeLockExpired - } + msBeforeTxnExpired.update(int64(l.TTL)) tikvLockResolverCountWithNotExpired.Inc() } } - if len(expiredLocks) == 0 { - if msBeforeTxnExpired > 0 { - tikvLockResolverCountWithWaitExpired.Inc() - } - return - } - // TxnID -> []Region, record resolved Regions. // TODO: Maybe put it in LockResolver and share by all txns. cleanTxns := make(map[uint64]map[RegionVerID]struct{}) for _, l := range expiredLocks { - var status TxnStatus - status, err = lr.getTxnStatus(bo, l.TxnID, l.Primary) + status, err := lr.getTxnStatusFromLock(bo, l) if err != nil { - msBeforeTxnExpired = 0 + msBeforeTxnExpired.update(0) err = errors.Trace(err) - return + return msBeforeTxnExpired.value(), err } - cleanRegions, exists := cleanTxns[l.TxnID] - if !exists { - cleanRegions = make(map[RegionVerID]struct{}) - cleanTxns[l.TxnID] = cleanRegions - } + if status.ttl == 0 { + tikvLockResolverCountWithExpired.Inc() + // If the lock is committed or rollbacked, resolve lock. + cleanRegions, exists := cleanTxns[l.TxnID] + if !exists { + cleanRegions = make(map[RegionVerID]struct{}) + cleanTxns[l.TxnID] = cleanRegions + } - err = lr.resolveLock(bo, l, status, cleanRegions) - if err != nil { - msBeforeTxnExpired = 0 - err = errors.Trace(err) - return + err = lr.resolveLock(bo, l, status, cleanRegions) + if err != nil { + msBeforeTxnExpired.update(0) + err = errors.Trace(err) + return msBeforeTxnExpired.value(), err + } + } else { + tikvLockResolverCountWithNotExpired.Inc() + // If the lock is valid, the txn may be a pessimistic transaction. + // Update the txn expire time. + msBeforeLockExpired := lr.store.GetOracle().UntilExpired(l.TxnID, status.ttl) + msBeforeTxnExpired.update(msBeforeLockExpired) } } + + if msBeforeTxnExpired.value() > 0 { + tikvLockResolverCountWithWaitExpired.Inc() + } + return msBeforeTxnExpired.value(), nil +} + +type txnExpireTime struct { + initialized bool + txnExpire int64 +} + +func (t *txnExpireTime) update(lockExpire int64) { + if lockExpire <= 0 { + lockExpire = 0 + } + if !t.initialized { + t.txnExpire = lockExpire + t.initialized = true + return + } + if lockExpire < t.txnExpire { + t.txnExpire = lockExpire + } return } +func (t *txnExpireTime) value() int64 { + if !t.initialized { + return 0 + } + return t.txnExpire +} + // GetTxnStatus queries tikv-server for a txn's status (commit/rollback). // If the primary key is still locked, it will launch a Rollback to abort it. // To avoid unnecessarily aborting too many txns, it is wiser to wait a few // seconds before calling it after Prewrite. func (lr *LockResolver) GetTxnStatus(txnID uint64, primary []byte) (TxnStatus, error) { + var status TxnStatus bo := NewBackoffer(context.Background(), cleanupMaxBackoff) - status, err := lr.getTxnStatus(bo, txnID, primary) - return status, errors.Trace(err) + currentTS, err := lr.store.GetOracle().GetLowResolutionTimestamp(bo.ctx) + if err != nil { + return status, err + } + return lr.getTxnStatus(bo, txnID, primary, currentTS) } -func (lr *LockResolver) getTxnStatus(bo *Backoffer, txnID uint64, primary []byte) (TxnStatus, error) { +func (lr *LockResolver) getTxnStatusFromLock(bo *Backoffer, l *Lock) (TxnStatus, error) { + // NOTE: l.TTL = 0 is a special protocol!!! + // When the pessimistic txn prewrite meets locks of a txn, it should rollback that txn **unconditionally**. + // In this case, TiKV set the lock TTL = 0, and TiDB use currentTS = 0 to call + // getTxnStatus, and getTxnStatus with currentTS = 0 would rollback the transaction. + if l.TTL == 0 { + return lr.getTxnStatus(bo, l.TxnID, l.Primary, 0) + } + + currentTS, err := lr.store.GetOracle().GetLowResolutionTimestamp(bo.ctx) + if err != nil { + return TxnStatus{}, err + } + return lr.getTxnStatus(bo, l.TxnID, l.Primary, currentTS) +} + +func (lr *LockResolver) getTxnStatus(bo *Backoffer, txnID uint64, primary []byte, currentTS uint64) (TxnStatus, error) { if s, ok := lr.getResolved(txnID); ok { return s, nil } @@ -344,6 +401,7 @@ func (lr *LockResolver) getTxnStatus(bo *Backoffer, txnID uint64, primary []byte Cleanup: &kvrpcpb.CleanupRequest{ Key: primary, StartVersion: txnID, + CurrentTs: currentTS, }, } for { @@ -371,12 +429,18 @@ func (lr *LockResolver) getTxnStatus(bo *Backoffer, txnID uint64, primary []byte return status, errors.Trace(ErrBodyMissing) } if keyErr := cmdResp.GetError(); keyErr != nil { + // If the TTL of the primary lock is not outdated, the proto returns a ErrLocked contains the TTL. + if lockInfo := keyErr.GetLocked(); lockInfo != nil { + status.ttl = lockInfo.LockTtl + status.commitTS = 0 + return status, nil + } err = errors.Errorf("unexpected cleanup err: %s, tid: %v", keyErr, txnID) logutil.Logger(context.Background()).Error("getTxnStatus error", zap.Error(err)) return status, err } if cmdResp.CommitVersion != 0 { - status = TxnStatus(cmdResp.GetCommitVersion()) + status = TxnStatus{0, cmdResp.GetCommitVersion()} tikvLockResolverCountWithQueryTxnStatusCommitted.Inc() } else { tikvLockResolverCountWithQueryTxnStatusRolledBack.Inc() @@ -388,6 +452,7 @@ func (lr *LockResolver) getTxnStatus(bo *Backoffer, txnID uint64, primary []byte func (lr *LockResolver) resolveLock(bo *Backoffer, l *Lock, status TxnStatus, cleanRegions map[RegionVerID]struct{}) error { tikvLockResolverCountWithResolveLocks.Inc() + cleanWholeRegion := l.TxnSize >= bigTxnThreshold for { loc, err := lr.store.GetRegionCache().LocateKey(bo, l.Key) if err != nil { @@ -405,6 +470,12 @@ func (lr *LockResolver) resolveLock(bo *Backoffer, l *Lock, status TxnStatus, cl if status.IsCommitted() { req.ResolveLock.CommitVersion = status.CommitTS() } + if l.TxnSize < bigTxnThreshold { + // Only resolve specified keys when it is a small transaction, + // prevent from scanning the whole region in this case. + tikvLockResolverCountWithResolveLockLite.Inc() + req.ResolveLock.Keys = [][]byte{l.Key} + } resp, err := lr.store.SendReq(bo, req, loc.Region, readTimeoutShort) if err != nil { return errors.Trace(err) @@ -429,7 +500,9 @@ func (lr *LockResolver) resolveLock(bo *Backoffer, l *Lock, status TxnStatus, cl logutil.Logger(context.Background()).Error("resolveLock error", zap.Error(err)) return err } - cleanRegions[loc.Region] = struct{}{} + if cleanWholeRegion { + cleanRegions[loc.Region] = struct{}{} + } return nil } } diff --git a/store/tikv/lock_test.go b/store/tikv/lock_test.go index d507ecaa800da..fbab00265ff37 100644 --- a/store/tikv/lock_test.go +++ b/store/tikv/lock_test.go @@ -200,6 +200,69 @@ func (s *testLockSuite) TestGetTxnStatus(c *C) { status, err = s.store.lockResolver.GetTxnStatus(startTS, []byte("a")) c.Assert(err, IsNil) c.Assert(status.IsCommitted(), IsFalse) + c.Assert(status.ttl, Greater, uint64(0)) +} + +func (s *testLockSuite) TestCheckTxnStatusTTL(c *C) { + txn, err := s.store.Begin() + c.Assert(err, IsNil) + txn.Set(kv.Key("key"), []byte("value")) + s.prewriteTxn(c, txn.(*tikvTxn)) + + // Check the lock TTL of a transaction. + bo := NewBackoffer(context.Background(), prewriteMaxBackoff) + lr := newLockResolver(s.store) + status, err := lr.GetTxnStatus(txn.StartTS(), []byte("key")) + c.Assert(err, IsNil) + c.Assert(status.IsCommitted(), IsFalse) + c.Assert(status.ttl, Greater, uint64(0)) + c.Assert(status.CommitTS(), Equals, uint64(0)) + + // Rollback the txn. + lock := s.mustGetLock(c, []byte("key")) + status = TxnStatus{} + cleanRegions := make(map[RegionVerID]struct{}) + err = newLockResolver(s.store).resolveLock(bo, lock, status, cleanRegions) + c.Assert(err, IsNil) + + // Check its status is rollbacked. + status, err = lr.GetTxnStatus(txn.StartTS(), []byte("key")) + c.Assert(err, IsNil) + c.Assert(status.ttl, Equals, uint64(0)) + c.Assert(status.commitTS, Equals, uint64(0)) + + // Check a committed txn. + startTS, commitTS := s.putKV(c, []byte("a"), []byte("a")) + status, err = lr.GetTxnStatus(startTS, []byte("a")) + c.Assert(err, IsNil) + c.Assert(status.ttl, Equals, uint64(0)) + c.Assert(status.commitTS, Equals, commitTS) +} + +func (s *testLockSuite) TestTxnHeartBeat(c *C) { + txn, err := s.store.Begin() + c.Assert(err, IsNil) + txn.Set(kv.Key("key"), []byte("value")) + s.prewriteTxn(c, txn.(*tikvTxn)) + + bo := NewBackoffer(context.Background(), prewriteMaxBackoff) + newTTL, err := sendTxnHeartBeat(bo, s.store, []byte("key"), txn.StartTS(), 666) + c.Assert(err, IsNil) + c.Assert(newTTL, Equals, uint64(666)) + + newTTL, err = sendTxnHeartBeat(bo, s.store, []byte("key"), txn.StartTS(), 555) + c.Assert(err, IsNil) + c.Assert(newTTL, Equals, uint64(666)) + + // The getTxnStatus API is confusing, it really means rollback! + status, err := newLockResolver(s.store).getTxnStatus(bo, txn.StartTS(), []byte("key"), 0) + c.Assert(err, IsNil) + c.Assert(status.ttl, Equals, uint64(0)) + c.Assert(status.commitTS, Equals, uint64(0)) + + newTTL, err = sendTxnHeartBeat(bo, s.store, []byte("key"), txn.StartTS(), 666) + c.Assert(err, NotNil) + c.Assert(newTTL, Equals, uint64(0)) } func (s *testLockSuite) prewriteTxn(c *C, txn *tikvTxn) { @@ -277,6 +340,11 @@ func (s *testLockSuite) TestLockTTL(c *C) { s.ttlEquals(c, l.TTL, defaultLockTTL+uint64(time.Since(start)/time.Millisecond)) } +func (s *testLockSuite) TestNewLockZeroTTL(c *C) { + l := NewLock(&kvrpcpb.LockInfo{}) + c.Assert(l.TTL, Equals, uint64(0)) +} + func init() { // Speed up tests. defaultLockTTL = 3 diff --git a/store/tikv/mock_tikv_service.go b/store/tikv/mock_tikv_service.go new file mode 100644 index 0000000000000..7bebb40c51f0c --- /dev/null +++ b/store/tikv/mock_tikv_service.go @@ -0,0 +1,69 @@ +package tikv + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/pingcap/kvproto/pkg/tikvpb" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +type server struct { + tikvpb.TikvServer +} + +func (s *server) BatchCommands(ss tikvpb.Tikv_BatchCommandsServer) error { + for { + req, err := ss.Recv() + if err != nil { + logutil.Logger(context.Background()).Error("batch commands receive fail", zap.Error(err)) + return err + } + + responses := make([]*tikvpb.BatchCommandsResponse_Response, 0, len(req.GetRequestIds())) + for i := 0; i < len(req.GetRequestIds()); i++ { + responses = append(responses, &tikvpb.BatchCommandsResponse_Response{ + Cmd: &tikvpb.BatchCommandsResponse_Response_Empty{ + Empty: &tikvpb.BatchCommandsEmptyResponse{}, + }, + }) + } + + err = ss.Send(&tikvpb.BatchCommandsResponse{ + Responses: responses, + RequestIds: req.GetRequestIds(), + }) + if err != nil { + logutil.Logger(context.Background()).Error("batch commands send fail", zap.Error(err)) + return err + } + } +} + +// Try to start a gRPC server and retrun the server instance and binded port. +func startMockTikvService() (*grpc.Server, int) { + for port := 40000; port < 50000; port++ { + lis, err := net.Listen("tcp", fmt.Sprintf("%s:%d", "127.0.0.1", port)) + if err != nil { + logutil.Logger(context.Background()).Error("can't listen", zap.Error(err)) + continue + } + s := grpc.NewServer(grpc.ConnectionTimeout(time.Minute)) + tikvpb.RegisterTikvServer(s, &server{}) + go func() { + if err = s.Serve(lis); err != nil { + logutil.Logger(context.Background()).Error( + "can't serve gRPC requests", + zap.Error(err), + ) + } + }() + return s, port + } + logutil.Logger(context.Background()).Error("can't start mock tikv service because no available ports") + return nil, -1 +} diff --git a/store/tikv/range_task.go b/store/tikv/range_task.go index 1fac00a3a2588..dcf3d72069219 100644 --- a/store/tikv/range_task.go +++ b/store/tikv/range_task.go @@ -129,8 +129,6 @@ func (s *RangeTaskRunner) RunOnRange(ctx context.Context, startKey []byte, endKe key := startKey for { select { - case <-ctx.Done(): - return errors.Trace(ctx.Err()) case <-statLogTicker.C: logutil.Logger(ctx).Info("range task in progress", zap.String("name", s.name), @@ -168,7 +166,12 @@ func (s *RangeTaskRunner) RunOnRange(ctx context.Context, startKey []byte, endKe } pushTaskStartTime := time.Now() - taskCh <- task + + select { + case taskCh <- task: + case <-ctx.Done(): + break + } metrics.TiKVRangeTaskPushDuration.WithLabelValues(s.name).Observe(time.Since(pushTaskStartTime).Seconds()) if isLast { @@ -247,8 +250,6 @@ func (w *rangeTaskWorker) run(ctx context.Context, cancel context.CancelFunc) { } completedRegions, err := w.handler(ctx, *r) - atomic.AddInt32(w.completedRegions, int32(completedRegions)) - if err != nil { logutil.Logger(ctx).Info("canceling range task because of error", zap.String("name", w.name), @@ -259,5 +260,6 @@ func (w *rangeTaskWorker) run(ctx context.Context, cancel context.CancelFunc) { cancel() break } + atomic.AddInt32(w.completedRegions, int32(completedRegions)) } } diff --git a/store/tikv/rawkv.go b/store/tikv/rawkv.go index aea8c9abf52dd..b3fbfaf064026 100644 --- a/store/tikv/rawkv.go +++ b/store/tikv/rawkv.go @@ -412,7 +412,7 @@ func (c *RawKVClient) sendReq(key []byte, req *tikvrpc.Request, reverse bool) (* } func (c *RawKVClient) sendBatchReq(bo *Backoffer, keys [][]byte, cmdType tikvrpc.CmdType) (*tikvrpc.Response, error) { // split the keys - groups, _, err := c.regionCache.GroupKeysByRegion(bo, keys) + groups, _, err := c.regionCache.GroupKeysByRegion(bo, keys, nil) if err != nil { return nil, errors.Trace(err) } @@ -570,7 +570,7 @@ func (c *RawKVClient) sendBatchPut(bo *Backoffer, keys, values [][]byte) error { for i, key := range keys { keyToValue[string(key)] = values[i] } - groups, _, err := c.regionCache.GroupKeysByRegion(bo, keys) + groups, _, err := c.regionCache.GroupKeysByRegion(bo, keys, nil) if err != nil { return errors.Trace(err) } diff --git a/store/tikv/region_cache.go b/store/tikv/region_cache.go index 64005740fbc8a..ccca1ebc6ca31 100644 --- a/store/tikv/region_cache.go +++ b/store/tikv/region_cache.go @@ -22,8 +22,10 @@ import ( "time" "unsafe" + "github.com/gogo/protobuf/proto" "github.com/google/btree" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/pd/client" "github.com/pingcap/tidb/metrics" @@ -32,19 +34,23 @@ import ( ) const ( - btreeDegree = 32 - rcDefaultRegionCacheTTLSec = 600 - invalidatedLastAccessTime = -1 + btreeDegree = 32 + invalidatedLastAccessTime = -1 ) +// RegionCacheTTLSec is the max idle time for regions in the region cache. +var RegionCacheTTLSec int64 = 600 + var ( - tikvRegionCacheCounterWithDropRegionFromCacheOK = metrics.TiKVRegionCacheCounter.WithLabelValues("drop_region_from_cache", "ok") - tikvRegionCacheCounterWithGetRegionByIDOK = metrics.TiKVRegionCacheCounter.WithLabelValues("get_region_by_id", "ok") - tikvRegionCacheCounterWithGetRegionByIDError = metrics.TiKVRegionCacheCounter.WithLabelValues("get_region_by_id", "err") - tikvRegionCacheCounterWithGetRegionOK = metrics.TiKVRegionCacheCounter.WithLabelValues("get_region", "ok") - tikvRegionCacheCounterWithGetRegionError = metrics.TiKVRegionCacheCounter.WithLabelValues("get_region", "err") - tikvRegionCacheCounterWithGetStoreOK = metrics.TiKVRegionCacheCounter.WithLabelValues("get_store", "ok") - tikvRegionCacheCounterWithGetStoreError = metrics.TiKVRegionCacheCounter.WithLabelValues("get_store", "err") + tikvRegionCacheCounterWithInvalidateRegionFromCacheOK = metrics.TiKVRegionCacheCounter.WithLabelValues("invalidate_region_from_cache", "ok") + tikvRegionCacheCounterWithSendFail = metrics.TiKVRegionCacheCounter.WithLabelValues("send_fail", "ok") + tikvRegionCacheCounterWithGetRegionByIDOK = metrics.TiKVRegionCacheCounter.WithLabelValues("get_region_by_id", "ok") + tikvRegionCacheCounterWithGetRegionByIDError = metrics.TiKVRegionCacheCounter.WithLabelValues("get_region_by_id", "err") + tikvRegionCacheCounterWithGetRegionOK = metrics.TiKVRegionCacheCounter.WithLabelValues("get_region", "ok") + tikvRegionCacheCounterWithGetRegionError = metrics.TiKVRegionCacheCounter.WithLabelValues("get_region", "err") + tikvRegionCacheCounterWithGetStoreOK = metrics.TiKVRegionCacheCounter.WithLabelValues("get_store", "ok") + tikvRegionCacheCounterWithGetStoreError = metrics.TiKVRegionCacheCounter.WithLabelValues("get_store", "err") + tikvRegionCacheCounterWithInvalidateStoreRegionsOK = metrics.TiKVRegionCacheCounter.WithLabelValues("invalidate_store_regions", "ok") ) const ( @@ -65,13 +71,19 @@ type Region struct { type RegionStore struct { workStoreIdx int32 // point to current work peer in meta.Peers and work store in stores(same idx) stores []*Store // stores in this region + storeFails []uint32 // snapshots of store's fail, need reload when `storeFails[curr] != stores[cur].fail` } // clone clones region store struct. func (r *RegionStore) clone() *RegionStore { + storeFails := make([]uint32, len(r.stores)) + for i, e := range r.storeFails { + storeFails[i] = e + } return &RegionStore{ workStoreIdx: r.workStoreIdx, stores: r.stores, + storeFails: storeFails, } } @@ -82,6 +94,7 @@ func (r *Region) init(c *RegionCache) { rs := &RegionStore{ workStoreIdx: 0, stores: make([]*Store, 0, len(r.meta.Peers)), + storeFails: make([]uint32, 0, len(r.meta.Peers)), } for _, p := range r.meta.Peers { c.storeMu.RLock() @@ -91,6 +104,7 @@ func (r *Region) init(c *RegionCache) { store = c.getStoreByStoreID(p.StoreId) } rs.stores = append(rs.stores, store) + rs.storeFails = append(rs.storeFails, atomic.LoadUint32(&store.fail)) } atomic.StorePointer(&r.store, unsafe.Pointer(rs)) @@ -110,7 +124,7 @@ func (r *Region) compareAndSwapStore(oldStore, newStore *RegionStore) bool { func (r *Region) checkRegionCacheTTL(ts int64) bool { for { lastAccess := atomic.LoadInt64(&r.lastAccess) - if ts-lastAccess > rcDefaultRegionCacheTTLSec { + if ts-lastAccess > RegionCacheTTLSec { return false } if atomic.CompareAndSwapInt64(&r.lastAccess, lastAccess, ts) { @@ -121,6 +135,7 @@ func (r *Region) checkRegionCacheTTL(ts int64) bool { // invalidate invalidates a region, next time it will got null result. func (r *Region) invalidate() { + tikvRegionCacheCounterWithInvalidateRegionFromCacheOK.Inc() atomic.StoreInt64(&r.lastAccess, invalidatedLastAccessTime) } @@ -261,12 +276,27 @@ func (c *RegionCache) GetRPCContext(bo *Backoffer, id RegionVerID) (*RPCContext, if err != nil { return nil, err } + // enable by `curl -XPUT -d '1*return("[some-addr]")->return("")' http://host:port/github.com/pingcap/tidb/store/tikv/injectWrongStoreAddr` + failpoint.Inject("injectWrongStoreAddr", func(val failpoint.Value) { + if a, ok := val.(string); ok && len(a) > 0 { + addr = a + } + }) if store == nil || len(addr) == 0 { // Store not found, region must be out of date. cachedRegion.invalidate() return nil, nil } + storeFailEpoch := atomic.LoadUint32(&store.fail) + if storeFailEpoch != regionStore.storeFails[regionStore.workStoreIdx] { + cachedRegion.invalidate() + logutil.Logger(context.Background()).Info("invalidate current region, because others failed on same store", + zap.Uint64("region", id.GetID()), + zap.String("store", store.addr)) + return nil, nil + } + return &RPCContext{ Region: id, Meta: cachedRegion.meta, @@ -359,13 +389,18 @@ func (c *RegionCache) findRegionByKey(bo *Backoffer, key []byte, isEndKey bool) } // OnSendFail handles send request fail logic. -func (c *RegionCache) OnSendFail(bo *Backoffer, ctx *RPCContext, scheduleReload bool) { +func (c *RegionCache) OnSendFail(bo *Backoffer, ctx *RPCContext, scheduleReload bool, err error) { + tikvRegionCacheCounterWithSendFail.Inc() r := c.getCachedRegionWithRLock(ctx.Region) if r != nil { - c.switchNextPeer(r, ctx.PeerIdx) + c.switchNextPeer(r, ctx.PeerIdx, err) if scheduleReload { r.scheduleReload() } + logutil.Logger(bo.ctx).Info("switch region peer to next due to send request fail", + zap.Stringer("current", ctx), + zap.Bool("needReload", scheduleReload), + zap.Error(err)) } } @@ -414,7 +449,8 @@ func (c *RegionCache) LocateRegionByID(bo *Backoffer, regionID uint64) (*KeyLoca // GroupKeysByRegion separates keys into groups by their belonging Regions. // Specially it also returns the first key's region which may be used as the // 'PrimaryLockKey' and should be committed ahead of others. -func (c *RegionCache) GroupKeysByRegion(bo *Backoffer, keys [][]byte) (map[RegionVerID][][]byte, RegionVerID, error) { +// filter is used to filter some unwanted keys. +func (c *RegionCache) GroupKeysByRegion(bo *Backoffer, keys [][]byte, filter func(key, regionStartKey []byte) bool) (map[RegionVerID][][]byte, RegionVerID, error) { groups := make(map[RegionVerID][][]byte) var first RegionVerID var lastLoc *KeyLocation @@ -425,6 +461,9 @@ func (c *RegionCache) GroupKeysByRegion(bo *Backoffer, keys [][]byte) (map[Regio if err != nil { return nil, first, errors.Trace(err) } + if filter != nil && filter(k, lastLoc.StartKey) { + continue + } } id := lastLoc.Region if i == 0 { @@ -451,13 +490,32 @@ func (c *RegionCache) ListRegionIDsInKeyRange(bo *Backoffer, startKey, endKey [] return regionIDs, nil } +// LoadRegionsInKeyRange lists ids of regions in [start_key,end_key]. +func (c *RegionCache) LoadRegionsInKeyRange(bo *Backoffer, startKey, endKey []byte) (regions []*Region, err error) { + for { + curRegion, err := c.loadRegion(bo, startKey, false) + if err != nil { + return nil, errors.Trace(err) + } + c.mu.Lock() + c.insertRegionToCache(curRegion) + c.mu.Unlock() + + regions = append(regions, curRegion) + if curRegion.Contains(endKey) { + break + } + startKey = curRegion.EndKey() + } + return regions, nil +} + // InvalidateCachedRegion removes a cached Region. func (c *RegionCache) InvalidateCachedRegion(id RegionVerID) { cachedRegion := c.getCachedRegionWithRLock(id) if cachedRegion == nil { return } - tikvRegionCacheCounterWithDropRegionFromCacheOK.Inc() cachedRegion.invalidate() } @@ -472,15 +530,24 @@ func (c *RegionCache) UpdateLeader(regionID RegionVerID, leaderStoreID uint64, c } if leaderStoreID == 0 { - c.switchNextPeer(r, currentPeerIdx) + c.switchNextPeer(r, currentPeerIdx, nil) + logutil.Logger(context.Background()).Info("switch region peer to next due to NotLeader with NULL leader", + zap.Int("currIdx", currentPeerIdx), + zap.Uint64("regionID", regionID.GetID())) return } if !c.switchToPeer(r, leaderStoreID) { - logutil.Logger(context.Background()).Debug("regionCache: cannot find peer when updating leader", + logutil.Logger(context.Background()).Info("invalidate region cache due to cannot find peer when updating leader", zap.Uint64("regionID", regionID.GetID()), + zap.Int("currIdx", currentPeerIdx), zap.Uint64("leaderStoreID", leaderStoreID)) r.invalidate() + } else { + logutil.Logger(context.Background()).Info("switch region leader to specific leader due to kv return NotLeader", + zap.Uint64("regionID", regionID.GetID()), + zap.Int("currIdx", currentPeerIdx), + zap.Uint64("leaderStoreID", leaderStoreID)) } } @@ -712,7 +779,6 @@ func (c *RegionCache) OnRegionEpochNotMatch(bo *Backoffer, ctx *RPCContext, curr if needInvalidateOld { cachedRegion, ok := c.mu.regions[ctx.Region] if ok { - tikvRegionCacheCounterWithDropRegionFromCacheOK.Inc() cachedRegion.invalidate() } } @@ -752,6 +818,29 @@ func (r *Region) GetID() uint64 { return r.meta.GetId() } +// GetMeta returns region meta. +func (r *Region) GetMeta() *metapb.Region { + return proto.Clone(r.meta).(*metapb.Region) +} + +// GetLeaderID returns leader region ID. +func (r *Region) GetLeaderID() uint64 { + store := r.getStore() + if int(store.workStoreIdx) >= len(r.meta.Peers) { + return 0 + } + return r.meta.Peers[int(r.getStore().workStoreIdx)].Id +} + +// GetLeaderStoreID returns the store ID of the leader region. +func (r *Region) GetLeaderStoreID() uint64 { + store := r.getStore() + if int(store.workStoreIdx) >= len(r.meta.Peers) { + return 0 + } + return r.meta.Peers[int(r.getStore().workStoreIdx)].StoreId +} + // WorkStorePeer returns current work store with work peer. func (r *Region) WorkStorePeer(rs *RegionStore) (store *Store, peer *metapb.Peer, idx int) { idx = int(rs.workStoreIdx) @@ -799,15 +888,25 @@ func (c *RegionCache) switchToPeer(r *Region, targetStoreID uint64) (found bool) return } -func (c *RegionCache) switchNextPeer(r *Region, currentPeerIdx int) { - regionStore := r.getStore() - if int(regionStore.workStoreIdx) != currentPeerIdx { +func (c *RegionCache) switchNextPeer(r *Region, currentPeerIdx int, err error) { + rs := r.getStore() + if int(rs.workStoreIdx) != currentPeerIdx { return } - nextIdx := (currentPeerIdx + 1) % len(regionStore.stores) - newRegionStore := regionStore.clone() + + if err != nil { // TODO: refine err, only do this for some errors. + s := rs.stores[rs.workStoreIdx] + epoch := rs.storeFails[rs.workStoreIdx] + if atomic.CompareAndSwapUint32(&s.fail, epoch, epoch+1) { + logutil.Logger(context.Background()).Info("mark store's regions need be refill", zap.String("store", s.addr)) + tikvRegionCacheCounterWithInvalidateStoreRegionsOK.Inc() + } + } + + nextIdx := (currentPeerIdx + 1) % len(rs.stores) + newRegionStore := rs.clone() newRegionStore.workStoreIdx = int32(nextIdx) - r.compareAndSwapStore(regionStore, newRegionStore) + r.compareAndSwapStore(rs, newRegionStore) } func (c *RegionCache) getPeerStoreIndex(r *Region, id uint64) (idx int, found bool) { @@ -860,6 +959,7 @@ type Store struct { storeID uint64 // store's id state uint64 // unsafe store storeState resolveMutex sync.Mutex // protect pd from concurrent init requests + fail uint32 // store fail count, see RegionStore.storeFails } type resolveState uint64 @@ -932,6 +1032,11 @@ func (s *Store) reResolve(c *RegionCache) { return } if store == nil { + // store has be removed in PD, we should invalidate all regions using those store. + logutil.Logger(context.Background()).Info("invalidate regions in removed store", + zap.Uint64("store", s.storeID), zap.String("add", s.addr)) + atomic.AddUint32(&s.fail, 1) + tikvRegionCacheCounterWithInvalidateStoreRegionsOK.Inc() return } diff --git a/store/tikv/region_cache_test.go b/store/tikv/region_cache_test.go index dc78d01413eb9..c1261dadb3d17 100644 --- a/store/tikv/region_cache_test.go +++ b/store/tikv/region_cache_test.go @@ -15,6 +15,7 @@ package tikv import ( "context" + "errors" "fmt" "testing" "time" @@ -121,6 +122,8 @@ func (s *testRegionCacheSuite) TestSimple(c *C) { c.Assert(r.GetID(), Equals, s.region1) c.Assert(s.getAddr(c, []byte("a")), Equals, s.storeAddr(s.store1)) s.checkCache(c, 1) + c.Assert(r.GetMeta(), DeepEquals, r.meta) + c.Assert(r.GetLeaderID(), Equals, r.meta.Peers[r.getStore().workStoreIdx].Id) s.cache.mu.regions[r.VerID()].lastAccess = 0 r = s.cache.searchCachedRegion([]byte("a"), true) c.Assert(r, IsNil) @@ -239,7 +242,7 @@ func (s *testRegionCacheSuite) TestSendFailedButLeaderNotChange(c *C) { c.Assert(len(ctx.Meta.Peers), Equals, 3) // send fail leader switch to 2 - s.cache.OnSendFail(s.bo, ctx, false) + s.cache.OnSendFail(s.bo, ctx, false, nil) ctx, err = s.cache.GetRPCContext(s.bo, loc.Region) c.Assert(err, IsNil) c.Assert(ctx.Peer.Id, Equals, s.peer2) @@ -267,7 +270,7 @@ func (s *testRegionCacheSuite) TestSendFailedInHibernateRegion(c *C) { c.Assert(len(ctx.Meta.Peers), Equals, 3) // send fail leader switch to 2 - s.cache.OnSendFail(s.bo, ctx, false) + s.cache.OnSendFail(s.bo, ctx, false, nil) ctx, err = s.cache.GetRPCContext(s.bo, loc.Region) c.Assert(err, IsNil) c.Assert(ctx.Peer.Id, Equals, s.peer2) @@ -287,6 +290,31 @@ func (s *testRegionCacheSuite) TestSendFailedInHibernateRegion(c *C) { c.Assert(ctx.Peer.Id, Equals, s.peer1) } +func (s *testRegionCacheSuite) TestSendFailInvalidateRegionsInSameStore(c *C) { + // key range: ['' - 'm' - 'z'] + region2 := s.cluster.AllocID() + newPeers := s.cluster.AllocIDs(2) + s.cluster.Split(s.region1, region2, []byte("m"), newPeers, newPeers[0]) + + // Check the two regions. + loc1, err := s.cache.LocateKey(s.bo, []byte("a")) + c.Assert(err, IsNil) + c.Assert(loc1.Region.id, Equals, s.region1) + loc2, err := s.cache.LocateKey(s.bo, []byte("x")) + c.Assert(err, IsNil) + c.Assert(loc2.Region.id, Equals, region2) + + // Send fail on region1 + ctx, _ := s.cache.GetRPCContext(s.bo, loc1.Region) + s.checkCache(c, 2) + s.cache.OnSendFail(s.bo, ctx, false, errors.New("test error")) + + // Get region2 cache will get nil then reload. + ctx2, err := s.cache.GetRPCContext(s.bo, loc2.Region) + c.Assert(ctx2, IsNil) + c.Assert(err, IsNil) +} + func (s *testRegionCacheSuite) TestSendFailedInMultipleNode(c *C) { // 3 nodes and no.1 is leader. store3 := s.cluster.AllocID() @@ -303,13 +331,13 @@ func (s *testRegionCacheSuite) TestSendFailedInMultipleNode(c *C) { c.Assert(len(ctx.Meta.Peers), Equals, 3) // send fail leader switch to 2 - s.cache.OnSendFail(s.bo, ctx, false) + s.cache.OnSendFail(s.bo, ctx, false, nil) ctx, err = s.cache.GetRPCContext(s.bo, loc.Region) c.Assert(err, IsNil) c.Assert(ctx.Peer.Id, Equals, s.peer2) // send 2 fail leader switch to 3 - s.cache.OnSendFail(s.bo, ctx, false) + s.cache.OnSendFail(s.bo, ctx, false, nil) ctx, err = s.cache.GetRPCContext(s.bo, loc.Region) c.Assert(err, IsNil) c.Assert(ctx.Peer.Id, Equals, peer3) @@ -455,6 +483,35 @@ func (s *testRegionCacheSuite) TestUpdateStoreAddr(c *C) { c.Assert(getVal, BytesEquals, testValue) } +func (s *testRegionCacheSuite) TestReplaceAddrWithNewStore(c *C) { + mvccStore := mocktikv.MustNewMVCCStore() + defer mvccStore.Close() + + client := &RawKVClient{ + clusterID: 0, + regionCache: NewRegionCache(mocktikv.NewPDClient(s.cluster)), + rpcClient: mocktikv.NewRPCClient(s.cluster, mvccStore), + } + defer client.Close() + testKey := []byte("test_key") + testValue := []byte("test_value") + err := client.Put(testKey, testValue) + c.Assert(err, IsNil) + + // make store2 using store1's addr and store1 offline + store1Addr := s.storeAddr(s.store1) + s.cluster.UpdateStoreAddr(s.store1, s.storeAddr(s.store2)) + s.cluster.UpdateStoreAddr(s.store2, store1Addr) + s.cluster.RemoveStore(s.store1) + s.cluster.ChangeLeader(s.region1, s.peer2) + s.cluster.RemovePeer(s.region1, s.store1) + + getVal, err := client.Get(testKey) + + c.Assert(err, IsNil) + c.Assert(getVal, BytesEquals, testValue) +} + func (s *testRegionCacheSuite) TestListRegionIDsInCache(c *C) { // ['' - 'm' - 'z'] region2 := s.cluster.AllocID() @@ -541,7 +598,7 @@ func BenchmarkOnRequestFail(b *testing.B) { } r := cache.getCachedRegionWithRLock(rpcCtx.Region) if r == nil { - cache.switchNextPeer(r, rpcCtx.PeerIdx) + cache.switchNextPeer(r, rpcCtx.PeerIdx, nil) } } }) diff --git a/store/tikv/region_request.go b/store/tikv/region_request.go index 78fde3854f853..c1a82fb344a44 100644 --- a/store/tikv/region_request.go +++ b/store/tikv/region_request.go @@ -172,7 +172,7 @@ func (s *RegionRequestSender) onSendFail(bo *Backoffer, ctx *RPCContext, err err } } - s.regionCache.OnSendFail(bo, ctx, s.needReloadRegion(ctx)) + s.regionCache.OnSendFail(bo, ctx, s.needReloadRegion(ctx), err) // Retry on send request failure when it's not canceled. // When a store is not available, the leader of related region should be elected quickly. diff --git a/store/tikv/region_request_test.go b/store/tikv/region_request_test.go index 52cc1636bef7e..3533d59a71a7f 100644 --- a/store/tikv/region_request_test.go +++ b/store/tikv/region_request_test.go @@ -49,7 +49,7 @@ func (s *testRegionRequestSuite) SetUpTest(c *C) { s.store, s.peer, s.region = mocktikv.BootstrapWithSingleStore(s.cluster) pdCli := &codecPDClient{mocktikv.NewPDClient(s.cluster)} s.cache = NewRegionCache(pdCli) - s.bo = NewBackoffer(context.Background(), 1) + s.bo = NewNoopBackoff(context.Background()) s.mvccStore = mocktikv.MustNewMVCCStore() client := mocktikv.NewRPCClient(s.cluster, s.mvccStore) s.regionRequestSender = NewRegionRequestSender(s.cache, client) @@ -100,6 +100,14 @@ func (s *testRegionRequestSuite) TestOnSendFailedWithCloseKnownStoreThenUseNewOn Value: []byte("value"), }, } + + // add new store2 and make store2 as leader. + store2 := s.cluster.AllocID() + peer2 := s.cluster.AllocID() + s.cluster.AddStore(store2, fmt.Sprintf("store%d", store2)) + s.cluster.AddPeer(s.region, store2, peer2) + s.cluster.ChangeLeader(s.region, peer2) + region, err := s.cache.LocateRegionByID(s.bo, s.region) c.Assert(err, IsNil) c.Assert(region, NotNil) @@ -107,27 +115,18 @@ func (s *testRegionRequestSuite) TestOnSendFailedWithCloseKnownStoreThenUseNewOn c.Assert(err, IsNil) c.Assert(resp.RawPut, NotNil) - // add new unknown region - store2 := s.cluster.AllocID() - peer2 := s.cluster.AllocID() - s.cluster.AddStore(store2, fmt.Sprintf("store%d", store2)) - s.cluster.AddPeer(region.Region.id, store2, peer2) - - // stop known region - s.cluster.StopStore(s.store) - - // send to failed store - resp, err = s.regionRequestSender.SendReq(NewBackoffer(context.Background(), 100), req, region.Region, time.Second) - c.Assert(err, NotNil) + // stop store2 and make store1 as new leader. + s.cluster.StopStore(store2) + s.cluster.ChangeLeader(s.region, s.peer) - // retry to send store by old region info - region, err = s.cache.LocateRegionByID(s.bo, s.region) - c.Assert(region, NotNil) + // send to store2 fail and send to new leader store1. + bo2 := NewBackoffer(context.Background(), 100) + resp, err = s.regionRequestSender.SendReq(bo2, req, region.Region, time.Second) c.Assert(err, IsNil) - - // retry again, reload region info and send to new store. - resp, err = s.regionRequestSender.SendReq(NewBackoffer(context.Background(), 100), req, region.Region, time.Second) - c.Assert(err, NotNil) + regionErr, err := resp.GetRegionError() + c.Assert(err, IsNil) + c.Assert(regionErr, IsNil) + c.Assert(resp.RawPut, NotNil) } func (s *testRegionRequestSuite) TestSendReqCtx(c *C) { @@ -311,18 +310,18 @@ func (s *mockTikvGrpcServer) MvccGetByStartTs(context.Context, *kvrpcpb.MvccGetB func (s *mockTikvGrpcServer) SplitRegion(context.Context, *kvrpcpb.SplitRegionRequest) (*kvrpcpb.SplitRegionResponse, error) { return nil, errors.New("unreachable") } - func (s *mockTikvGrpcServer) CoprocessorStream(*coprocessor.Request, tikvpb.Tikv_CoprocessorStreamServer) error { return errors.New("unreachable") } - func (s *mockTikvGrpcServer) BatchCommands(tikvpb.Tikv_BatchCommandsServer) error { return errors.New("unreachable") } - func (s *mockTikvGrpcServer) ReadIndex(context.Context, *kvrpcpb.ReadIndexRequest) (*kvrpcpb.ReadIndexResponse, error) { return nil, errors.New("unreachable") } +func (s *mockTikvGrpcServer) KvTxnHeartBeat(ctx context.Context, in *kvrpcpb.TxnHeartBeatRequest) (*kvrpcpb.TxnHeartBeatResponse, error) { + return nil, errors.New("unreachable") +} func (s *testRegionRequestSuite) TestNoReloadRegionForGrpcWhenCtxCanceled(c *C) { // prepare a mock tikv grpc server diff --git a/store/tikv/scan.go b/store/tikv/scan.go index dd56a1f7ed5a4..4ce5979270c02 100644 --- a/store/tikv/scan.go +++ b/store/tikv/scan.go @@ -35,9 +35,13 @@ type Scanner struct { nextStartKey []byte endKey []byte eof bool + + // Use for reverse scan. + reverse bool + nextEndKey []byte } -func newScanner(snapshot *tikvSnapshot, startKey []byte, endKey []byte, batchSize int) (*Scanner, error) { +func newScanner(snapshot *tikvSnapshot, startKey []byte, endKey []byte, batchSize int, reverse bool) (*Scanner, error) { // It must be > 1. Otherwise scanner won't skipFirst. if batchSize <= 1 { batchSize = scanBatchSize @@ -48,6 +52,8 @@ func newScanner(snapshot *tikvSnapshot, startKey []byte, endKey []byte, batchSiz valid: true, nextStartKey: startKey, endKey: endKey, + reverse: reverse, + nextEndKey: endKey, } err := scanner.Next() if kv.IsErrNotFound(err) { @@ -83,6 +89,7 @@ func (s *Scanner) Next() error { if !s.valid { return errors.New("scanner iterator is invalid") } + var err error for { s.idx++ if s.idx >= len(s.cache) { @@ -90,7 +97,7 @@ func (s *Scanner) Next() error { s.Close() return nil } - err := s.getData(bo) + err = s.getData(bo) if err != nil { s.Close() return errors.Trace(err) @@ -101,7 +108,8 @@ func (s *Scanner) Next() error { } current := s.cache[s.idx] - if len(s.endKey) > 0 && kv.Key(current.Key).Cmp(kv.Key(s.endKey)) >= 0 { + if (!s.reverse && (len(s.endKey) > 0 && kv.Key(current.Key).Cmp(kv.Key(s.endKey)) >= 0)) || + (s.reverse && len(s.nextStartKey) > 0 && kv.Key(current.Key).Cmp(kv.Key(s.nextStartKey)) < 0) { s.eof = true s.Close() return nil @@ -147,18 +155,34 @@ func (s *Scanner) resolveCurrentLock(bo *Backoffer, current *pb.KvPair) error { func (s *Scanner) getData(bo *Backoffer) error { logutil.Logger(context.Background()).Debug("txn getData", zap.Binary("nextStartKey", s.nextStartKey), + zap.Binary("nextEndKey", s.nextEndKey), + zap.Bool("reverse", s.reverse), zap.Uint64("txnStartTS", s.startTS())) sender := NewRegionRequestSender(s.snapshot.store.regionCache, s.snapshot.store.client) - + var reqEndKey, reqStartKey []byte + var loc *KeyLocation + var err error for { - loc, err := s.snapshot.store.regionCache.LocateKey(bo, s.nextStartKey) + if !s.reverse { + loc, err = s.snapshot.store.regionCache.LocateKey(bo, s.nextStartKey) + } else { + loc, err = s.snapshot.store.regionCache.LocateEndKey(bo, s.nextEndKey) + } if err != nil { return errors.Trace(err) } - reqEndKey := s.endKey - if len(reqEndKey) > 0 && len(loc.EndKey) > 0 && bytes.Compare(loc.EndKey, reqEndKey) < 0 { - reqEndKey = loc.EndKey + if !s.reverse { + reqEndKey = s.endKey + if len(reqEndKey) > 0 && len(loc.EndKey) > 0 && bytes.Compare(loc.EndKey, reqEndKey) < 0 { + reqEndKey = loc.EndKey + } + } else { + reqStartKey = s.nextStartKey + if len(reqStartKey) == 0 || + (len(loc.StartKey) > 0 && bytes.Compare(loc.StartKey, reqStartKey) > 0) { + reqStartKey = loc.StartKey + } } req := &tikvrpc.Request{ @@ -175,6 +199,11 @@ func (s *Scanner) getData(bo *Backoffer) error { NotFillCache: s.snapshot.notFillCache, }, } + if s.reverse { + req.Scan.StartKey = s.nextEndKey + req.Scan.EndKey = reqStartKey + req.Scan.Reverse = true + } resp, err := sender.SendReq(bo, req, loc.Region, ReadTimeoutMedium) if err != nil { return errors.Trace(err) @@ -218,8 +247,13 @@ func (s *Scanner) getData(bo *Backoffer) error { if len(kvPairs) < s.batchSize { // No more data in current Region. Next getData() starts // from current Region's endKey. - s.nextStartKey = loc.EndKey - if len(loc.EndKey) == 0 || (len(s.endKey) > 0 && kv.Key(s.nextStartKey).Cmp(kv.Key(s.endKey)) >= 0) { + if !s.reverse { + s.nextStartKey = loc.EndKey + } else { + s.nextEndKey = reqStartKey + } + if (!s.reverse && (len(loc.EndKey) == 0 || (len(s.endKey) > 0 && kv.Key(s.nextStartKey).Cmp(kv.Key(s.endKey)) >= 0))) || + (s.reverse && (len(loc.StartKey) == 0 || (len(s.nextStartKey) > 0 && kv.Key(s.nextStartKey).Cmp(kv.Key(s.nextEndKey)) >= 0))) { // Current Region is the last one. s.eof = true } @@ -230,7 +264,11 @@ func (s *Scanner) getData(bo *Backoffer) error { // may get an empty response if the Region in fact does not have // more data. lastKey := kvPairs[len(kvPairs)-1].GetKey() - s.nextStartKey = kv.Key(lastKey).Next() + if !s.reverse { + s.nextStartKey = kv.Key(lastKey).Next() + } else { + s.nextEndKey = kv.Key(lastKey) + } return nil } } diff --git a/store/tikv/scan_mock_test.go b/store/tikv/scan_mock_test.go index 4cf09c50c6abb..204bcc95783d9 100644 --- a/store/tikv/scan_mock_test.go +++ b/store/tikv/scan_mock_test.go @@ -42,7 +42,7 @@ func (s *testScanMockSuite) TestScanMultipleRegions(c *C) { txn, err = store.Begin() c.Assert(err, IsNil) snapshot := newTiKVSnapshot(store, kv.Version{Ver: txn.StartTS()}) - scanner, err := newScanner(snapshot, []byte("a"), nil, 10) + scanner, err := newScanner(snapshot, []byte("a"), nil, 10, false) c.Assert(err, IsNil) for ch := byte('a'); ch <= byte('z'); ch++ { c.Assert([]byte{ch}, BytesEquals, []byte(scanner.Key())) @@ -50,7 +50,7 @@ func (s *testScanMockSuite) TestScanMultipleRegions(c *C) { } c.Assert(scanner.Valid(), IsFalse) - scanner, err = newScanner(snapshot, []byte("a"), []byte("i"), 10) + scanner, err = newScanner(snapshot, []byte("a"), []byte("i"), 10, false) c.Assert(err, IsNil) for ch := byte('a'); ch <= byte('h'); ch++ { c.Assert([]byte{ch}, BytesEquals, []byte(scanner.Key())) @@ -58,3 +58,36 @@ func (s *testScanMockSuite) TestScanMultipleRegions(c *C) { } c.Assert(scanner.Valid(), IsFalse) } + +func (s *testScanMockSuite) TestReverseScan(c *C) { + store := NewTestStore(c).(*tikvStore) + defer store.Close() + + txn, err := store.Begin() + c.Assert(err, IsNil) + for ch := byte('a'); ch <= byte('z'); ch++ { + err = txn.Set([]byte{ch}, []byte{ch}) + c.Assert(err, IsNil) + } + err = txn.Commit(context.Background()) + c.Assert(err, IsNil) + + txn, err = store.Begin() + c.Assert(err, IsNil) + snapshot := newTiKVSnapshot(store, kv.Version{Ver: txn.StartTS()}) + scanner, err := newScanner(snapshot, nil, []byte("z"), 10, true) + c.Assert(err, IsNil) + for ch := byte('y'); ch >= byte('a'); ch-- { + c.Assert(string([]byte{ch}), Equals, string([]byte(scanner.Key()))) + c.Assert(scanner.Next(), IsNil) + } + c.Assert(scanner.Valid(), IsFalse) + + scanner, err = newScanner(snapshot, []byte("a"), []byte("i"), 10, true) + c.Assert(err, IsNil) + for ch := byte('h'); ch >= byte('a'); ch-- { + c.Assert(string([]byte{ch}), Equals, string([]byte(scanner.Key()))) + c.Assert(scanner.Next(), IsNil) + } + c.Assert(scanner.Valid(), IsFalse) +} diff --git a/store/tikv/scan_test.go b/store/tikv/scan_test.go index 893836fa5448b..41d0280935a97 100644 --- a/store/tikv/scan_test.go +++ b/store/tikv/scan_test.go @@ -91,12 +91,12 @@ func (s *testScanSuite) TestScan(c *C) { c.Assert(err, IsNil) if rowNum > 123 { - err = s.store.SplitRegion(encodeKey(s.prefix, s08d("key", 123))) + _, err = s.store.SplitRegions(context.Background(), [][]byte{encodeKey(s.prefix, s08d("key", 123))}, false) c.Assert(err, IsNil) } if rowNum > 456 { - err = s.store.SplitRegion(encodeKey(s.prefix, s08d("key", 456))) + _, err = s.store.SplitRegions(context.Background(), [][]byte{encodeKey(s.prefix, s08d("key", 456))}, false) c.Assert(err, IsNil) } diff --git a/store/tikv/snapshot.go b/store/tikv/snapshot.go index a73b866255850..3f86eb3cf6000 100644 --- a/store/tikv/snapshot.go +++ b/store/tikv/snapshot.go @@ -56,6 +56,15 @@ type tikvSnapshot struct { syncLog bool keyOnly bool vars *kv.Variables + + // Cache the result of BatchGet. + // The invariance is that calling BatchGet multiple times using the same start ts, + // the result should not change. + // NOTE: This representation here is different from the BatchGet API. + // cached use len(value)=0 to represent a key-value entry doesn't exist (a reliable truth from TiKV). + // In the BatchGet API, it use no key-value entry to represent non-exist. + // It's OK as long as there are no zero-byte values in the protocol. + cached map[string][]byte } // newTiKVSnapshot creates a snapshot of an TiKV store. @@ -68,6 +77,12 @@ func newTiKVSnapshot(store *tikvStore, ver kv.Version) *tikvSnapshot { } } +func (s *tikvSnapshot) setSnapshotTS(ts uint64) { + // Invalidate cache if the snapshotTS change! + s.version.Ver = ts + s.cached = nil +} + func (s *tikvSnapshot) SetPriority(priority int) { s.priority = pb.CommandPri(priority) } @@ -75,7 +90,22 @@ func (s *tikvSnapshot) SetPriority(priority int) { // BatchGet gets all the keys' value from kv-server and returns a map contains key/value pairs. // The map will not contain nonexistent keys. func (s *tikvSnapshot) BatchGet(keys []kv.Key) (map[string][]byte, error) { + // Check the cached value first. m := make(map[string][]byte) + if s.cached != nil { + tmp := keys[:0] + for _, key := range keys { + if val, ok := s.cached[string(key)]; ok { + if len(val) > 0 { + m[string(key)] = val + } + } else { + tmp = append(tmp, key) + } + } + keys = tmp + } + if len(keys) == 0 { return m, nil } @@ -94,6 +124,7 @@ func (s *tikvSnapshot) BatchGet(keys []kv.Key) (map[string][]byte, error) { if len(v) == 0 { return } + mu.Lock() m[string(k)] = v mu.Unlock() @@ -107,11 +138,19 @@ func (s *tikvSnapshot) BatchGet(keys []kv.Key) (map[string][]byte, error) { return nil, errors.Trace(err) } + // Update the cache. + if s.cached == nil { + s.cached = make(map[string][]byte, len(m)) + } + for _, key := range keys { + s.cached[string(key)] = m[string(key)] + } + return m, nil } func (s *tikvSnapshot) batchGetKeysByRegions(bo *Backoffer, keys [][]byte, collectF func(k, v []byte)) error { - groups, _, err := s.store.regionCache.GroupKeysByRegion(bo, keys) + groups, _, err := s.store.regionCache.GroupKeysByRegion(bo, keys, nil) if err != nil { return errors.Trace(err) } @@ -234,6 +273,17 @@ func (s *tikvSnapshot) Get(k kv.Key) ([]byte, error) { } func (s *tikvSnapshot) get(bo *Backoffer, k kv.Key) ([]byte, error) { + // Check the cached values first. + if s.cached != nil { + if value, ok := s.cached[string(k)]; ok { + return value, nil + } + } + + failpoint.Inject("snapshot-get-cache-fail", func(_ failpoint.Value) { + panic("cache miss") + }) + sender := NewRegionRequestSender(s.store.regionCache, s.store.client) req := &tikvrpc.Request{ @@ -295,13 +345,14 @@ func (s *tikvSnapshot) get(bo *Backoffer, k kv.Key) ([]byte, error) { // Iter return a list of key-value pair after `k`. func (s *tikvSnapshot) Iter(k kv.Key, upperBound kv.Key) (kv.Iterator, error) { - scanner, err := newScanner(s, k, upperBound, scanBatchSize) + scanner, err := newScanner(s, k, upperBound, scanBatchSize, false) return scanner, errors.Trace(err) } // IterReverse creates a reversed Iterator positioned on the first entry which key is less than k. func (s *tikvSnapshot) IterReverse(k kv.Key) (kv.Iterator, error) { - return nil, kv.ErrNotImplemented + scanner, err := newScanner(s, nil, k, scanBatchSize, true) + return scanner, errors.Trace(err) } func extractLockFromKeyErr(keyErr *pb.KeyError) (*Lock, error) { diff --git a/store/tikv/snapshot_test.go b/store/tikv/snapshot_test.go index 4b6e1af441e9f..de412fec23aab 100644 --- a/store/tikv/snapshot_test.go +++ b/store/tikv/snapshot_test.go @@ -19,6 +19,7 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/failpoint" pb "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/util/logutil" @@ -117,6 +118,26 @@ func (s *testSnapshotSuite) TestBatchGet(c *C) { } } +func (s *testSnapshotSuite) TestSnapshotCache(c *C) { + txn := s.beginTxn(c) + c.Assert(txn.Set(kv.Key("x"), []byte("x")), IsNil) + c.Assert(txn.Commit(context.Background()), IsNil) + + txn = s.beginTxn(c) + snapshot := newTiKVSnapshot(s.store, kv.Version{Ver: txn.StartTS()}) + _, err := snapshot.BatchGet([]kv.Key{kv.Key("x"), kv.Key("y")}) + c.Assert(err, IsNil) + + c.Assert(failpoint.Enable("github.com/pingcap/tidb/store/tikv/snapshot-get-cache-fail", `return(true)`), IsNil) + _, err = snapshot.Get(kv.Key("x")) + c.Assert(err, IsNil) + + _, err = snapshot.Get(kv.Key("y")) + c.Assert(kv.IsErrNotFound(err), IsTrue) + + c.Assert(failpoint.Disable("github.com/pingcap/tidb/store/tikv/snapshot-get-cache-fail"), IsNil) +} + func (s *testSnapshotSuite) TestBatchGetNotExist(c *C) { for _, rowNum := range s.rowNums { logutil.Logger(context.Background()).Debug("test BatchGetNotExist", diff --git a/store/tikv/split_region.go b/store/tikv/split_region.go index b9c8a6648b0dc..51b5c4a533c49 100644 --- a/store/tikv/split_region.go +++ b/store/tikv/split_region.go @@ -16,93 +16,235 @@ package tikv import ( "bytes" "context" + "math" + "time" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/kvproto/pkg/kvrpcpb" "github.com/pingcap/kvproto/pkg/metapb" "github.com/pingcap/kvproto/pkg/pdpb" - "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/store/tikv/tikvrpc" + "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/stringutil" "go.uber.org/zap" ) -// SplitRegion splits the region contains splitKey into 2 regions: [start, -// splitKey) and [splitKey, end). -func (s *tikvStore) SplitRegion(splitKey kv.Key) error { - _, err := s.splitRegion(splitKey) - return err +func equalRegionStartKey(key, regionStartKey []byte) bool { + if bytes.Equal(key, regionStartKey) { + return true + } + return false } -func (s *tikvStore) splitRegion(splitKey kv.Key) (*metapb.Region, error) { - logutil.Logger(context.Background()).Info("start split region", - zap.Binary("at", splitKey)) - bo := NewBackoffer(context.Background(), splitRegionBackoff) - sender := NewRegionRequestSender(s.regionCache, s.client) - req := &tikvrpc.Request{ - Type: tikvrpc.CmdSplitRegion, - SplitRegion: &kvrpcpb.SplitRegionRequest{ - SplitKey: splitKey, - }, +func (s *tikvStore) splitBatchRegionsReq(bo *Backoffer, keys [][]byte, scatter bool) (*tikvrpc.Response, error) { + // equalRegionStartKey is used to filter split keys. + // If the split key is equal to the start key of the region, then the key has been split, we need to skip the split key. + groups, _, err := s.regionCache.GroupKeysByRegion(bo, keys, equalRegionStartKey) + if err != nil { + return nil, errors.Trace(err) } - req.Context.Priority = kvrpcpb.CommandPri_Normal - for { - loc, err := s.regionCache.LocateKey(bo, splitKey) - if err != nil { - return nil, errors.Trace(err) + + var batches []batch + for regionID, groupKeys := range groups { + batches = appendKeyBatches(batches, regionID, groupKeys, rawBatchPutSize) + } + + if len(batches) == 0 { + return nil, nil + } + // The first time it enters this function. + if bo.totalSleep == 0 { + logutil.Logger(context.Background()).Info("split batch regions request", + zap.Int("split key count", len(keys)), + zap.Int("batch count", len(batches)), + zap.Uint64("first batch, region ID", batches[0].regionID.id), + zap.Binary("first split key", batches[0].keys[0])) + } + if len(batches) == 1 { + resp := s.batchSendSingleRegion(bo, batches[0], scatter) + return resp.resp, errors.Trace(resp.err) + } + ch := make(chan singleBatchResp, len(batches)) + for _, batch1 := range batches { + go func(b batch) { + backoffer, cancel := bo.Fork() + defer cancel() + + util.WithRecovery(func() { + select { + case ch <- s.batchSendSingleRegion(backoffer, b, scatter): + case <-bo.ctx.Done(): + ch <- singleBatchResp{err: bo.ctx.Err()} + } + }, func(r interface{}) { + if r != nil { + ch <- singleBatchResp{err: errors.Errorf("%v", r)} + } + }) + }(batch1) + } + + srResp := &kvrpcpb.SplitRegionResponse{Regions: make([]*metapb.Region, 0, len(keys)*2)} + for i := 0; i < len(batches); i++ { + batchResp := <-ch + if batchResp.err != nil { + logutil.Logger(context.Background()).Debug("batch split regions failed", zap.Error(batchResp.err)) + if err == nil { + err = batchResp.err + } } - if bytes.Equal(splitKey, loc.StartKey) { - logutil.Logger(context.Background()).Info("skip split region", - zap.Binary("at", splitKey)) - return nil, nil + + // If the split succeeds and the scatter fails, we also need to add the region IDs. + if batchResp.resp != nil { + spResp := batchResp.resp.SplitRegion + regions := spResp.GetRegions() + srResp.Regions = append(srResp.Regions, regions...) } - res, err := sender.SendReq(bo, req, loc.Region, readTimeoutShort) - if err != nil { - return nil, errors.Trace(err) + } + return &tikvrpc.Response{SplitRegion: srResp}, errors.Trace(err) +} + +func (s *tikvStore) batchSendSingleRegion(bo *Backoffer, batch batch, scatter bool) singleBatchResp { + failpoint.Inject("MockSplitRegionTimeout", func(val failpoint.Value) { + if val.(bool) { + time.Sleep(time.Second*1 + time.Millisecond*10) } - regionErr, err := res.GetRegionError() + }) + + req := &tikvrpc.Request{ + Type: tikvrpc.CmdSplitRegion, + SplitRegion: &kvrpcpb.SplitRegionRequest{SplitKeys: batch.keys}, + Context: kvrpcpb.Context{Priority: kvrpcpb.CommandPri_Normal}, + } + + sender := NewRegionRequestSender(s.regionCache, s.client) + resp, err := sender.SendReq(bo, req, batch.regionID, readTimeoutShort) + + batchResp := singleBatchResp{resp: resp} + if err != nil { + batchResp.err = errors.Trace(err) + return batchResp + } + regionErr, err := resp.GetRegionError() + if err != nil { + batchResp.err = errors.Trace(err) + return batchResp + } + if regionErr != nil { + err := bo.Backoff(BoRegionMiss, errors.New(regionErr.String())) if err != nil { - return nil, errors.Trace(err) + batchResp.err = errors.Trace(err) + return batchResp } - if regionErr != nil { - err := bo.Backoff(BoRegionMiss, errors.New(regionErr.String())) - if err != nil { - return nil, errors.Trace(err) + resp, err = s.splitBatchRegionsReq(bo, batch.keys, scatter) + batchResp.resp = resp + batchResp.err = err + return batchResp + } + + spResp := resp.SplitRegion + regions := spResp.GetRegions() + if len(regions) > 0 { + // Divide a region into n, one of them may not need to be scattered, + // so n-1 needs to be scattered to other stores. + spResp.Regions = regions[:len(regions)-1] + } + logutil.Logger(context.Background()).Info("batch split regions complete", + zap.Uint64("batch region ID", batch.regionID.id), + zap.Binary("first at", batch.keys[0]), + zap.Stringer("first new region left", stringutil.MemoizeStr(func() string { + if len(spResp.Regions) == 0 { + return "" } + return spResp.Regions[0].String() + })), + zap.Int("new region count", len(spResp.Regions))) + + if !scatter { + if len(spResp.Regions) == 0 { + return batchResp + } + return batchResp + } + + for i, r := range spResp.Regions { + if err = s.scatterRegion(r.Id); err == nil { + logutil.Logger(context.Background()).Info("batch split regions, scatter region complete", + zap.Uint64("batch region ID", batch.regionID.id), + zap.Binary("at", batch.keys[i]), + zap.String("new region left", r.String())) continue } - logutil.Logger(context.Background()).Info("split region complete", - zap.Binary("at", splitKey), - zap.Stringer("new region left", res.SplitRegion.GetLeft()), - zap.Stringer("new region right", res.SplitRegion.GetRight())) - return res.SplitRegion.GetLeft(), nil + + logutil.Logger(context.Background()).Info("batch split regions, scatter region failed", + zap.Uint64("batch region ID", batch.regionID.id), + zap.Binary("at", batch.keys[i]), + zap.Stringer("new region left", r), + zap.Error(err)) + if batchResp.err == nil { + batchResp.err = err + } + if ErrPDServerTimeout.Equal(err) { + break + } + } + return batchResp +} + +// SplitRegions splits regions by splitKeys. +func (s *tikvStore) SplitRegions(ctx context.Context, splitKeys [][]byte, scatter bool) (regionIDs []uint64, err error) { + bo := NewBackoffer(ctx, int(math.Min(float64(len(splitKeys))*splitRegionBackoff, maxSplitRegionsBackoff))) + resp, err := s.splitBatchRegionsReq(bo, splitKeys, scatter) + regionIDs = make([]uint64, 0, len(splitKeys)) + if resp != nil && resp.SplitRegion != nil { + spResp := resp.SplitRegion + for _, r := range spResp.Regions { + regionIDs = append(regionIDs, r.Id) + } + logutil.Logger(context.Background()).Info("split regions complete", + zap.Int("region count", len(regionIDs)), zap.Uint64s("region IDs", regionIDs)) } + return regionIDs, errors.Trace(err) } func (s *tikvStore) scatterRegion(regionID uint64) error { + failpoint.Inject("MockScatterRegionTimeout", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(ErrPDServerTimeout) + } + }) + logutil.Logger(context.Background()).Info("start scatter region", zap.Uint64("regionID", regionID)) bo := NewBackoffer(context.Background(), scatterRegionBackoff) for { err := s.pdClient.ScatterRegion(context.Background(), regionID) + if err == nil { + break + } + err = bo.Backoff(BoPDRPC, errors.New(err.Error())) if err != nil { - err = bo.Backoff(BoRegionMiss, errors.New(err.Error())) - if err != nil { - return errors.Trace(err) - } - continue + return errors.Trace(err) } - break } - logutil.Logger(context.Background()).Info("scatter region complete", + logutil.Logger(context.Background()).Debug("scatter region complete", zap.Uint64("regionID", regionID)) return nil } -func (s *tikvStore) WaitScatterRegionFinish(regionID uint64) error { +// WaitScatterRegionFinish implements SplitableStore interface. +// backOff is the back off time of the wait scatter region.(Milliseconds) +// if backOff <= 0, the default wait scatter back off time will be used. +func (s *tikvStore) WaitScatterRegionFinish(regionID uint64, backOff int) error { + if backOff <= 0 { + backOff = waitScatterRegionFinishBackoff + } logutil.Logger(context.Background()).Info("wait scatter region", - zap.Uint64("regionID", regionID)) - bo := NewBackoffer(context.Background(), waitScatterRegionFinishBackoff) + zap.Uint64("regionID", regionID), zap.Int("backoff(ms)", backOff)) + + bo := NewBackoffer(context.Background(), backOff) logFreq := 0 for { resp, err := s.pdClient.GetOperator(context.Background(), regionID) @@ -115,7 +257,7 @@ func (s *tikvStore) WaitScatterRegionFinish(regionID uint64) error { if logFreq%10 == 0 { logutil.Logger(context.Background()).Info("wait scatter region", zap.Uint64("regionID", regionID), - zap.String("desc", string(resp.Desc)), + zap.String("reverse", string(resp.Desc)), zap.String("status", pdpb.OperatorStatus_name[int32(resp.Status)])) } logFreq++ @@ -129,20 +271,25 @@ func (s *tikvStore) WaitScatterRegionFinish(regionID uint64) error { return errors.Trace(err) } } - } -func (s *tikvStore) SplitRegionAndScatter(splitKey kv.Key) (uint64, error) { - left, err := s.splitRegion(splitKey) - if err != nil { - return 0, err - } - if left == nil { - return 0, nil - } - err = s.scatterRegion(left.Id) - if err != nil { - return 0, err +// CheckRegionInScattering uses to check whether scatter region finished. +func (s *tikvStore) CheckRegionInScattering(regionID uint64) (bool, error) { + bo := NewBackoffer(context.Background(), locateRegionMaxBackoff) + for { + resp, err := s.pdClient.GetOperator(context.Background(), regionID) + if err == nil && resp != nil { + if !bytes.Equal(resp.Desc, []byte("scatter-region")) || resp.Status != pdpb.OperatorStatus_RUNNING { + return false, nil + } + } + if err != nil { + err = bo.Backoff(BoRegionMiss, errors.New(err.Error())) + } else { + return true, nil + } + if err != nil { + return true, errors.Trace(err) + } } - return left.Id, nil } diff --git a/store/tikv/split_test.go b/store/tikv/split_test.go index 3a40c844b14e4..ff0c7066b3924 100644 --- a/store/tikv/split_test.go +++ b/store/tikv/split_test.go @@ -61,7 +61,7 @@ func (s *testSplitSuite) TestSplitBatchGet(c *C) { snapshot := newTiKVSnapshot(s.store, kv.Version{Ver: txn.StartTS()}) keys := [][]byte{{'a'}, {'b'}, {'c'}} - _, region, err := s.store.regionCache.GroupKeysByRegion(s.bo, keys) + _, region, err := s.store.regionCache.GroupKeysByRegion(s.bo, keys, nil) c.Assert(err, IsNil) batch := batchKeys{ region: region, diff --git a/store/tikv/sql_fail_test.go b/store/tikv/sql_fail_test.go index 28fd144853c39..7e58706453cd7 100644 --- a/store/tikv/sql_fail_test.go +++ b/store/tikv/sql_fail_test.go @@ -89,7 +89,7 @@ func (s *testSQLSuite) TestFailBusyServerCop(c *C) { defer terror.Call(rs[0].Close) } c.Assert(err, IsNil) - req := rs[0].NewRecordBatch() + req := rs[0].NewChunk() err = rs[0].Next(context.Background(), req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -124,7 +124,7 @@ func (s *testSQLSuite) TestCoprocessorStreamRecvTimeout(c *C) { res, err := tk.Se.Execute(ctx, "select * from t") c.Assert(err, IsNil) - req := res[0].NewRecordBatch() + req := res[0].NewChunk() for { err := res[0].Next(ctx, req) c.Assert(err, IsNil) diff --git a/store/tikv/store_test.go b/store/tikv/store_test.go index 66bd99c286e43..3524dd9dd0a5b 100644 --- a/store/tikv/store_test.go +++ b/store/tikv/store_test.go @@ -93,6 +93,8 @@ func (s *testStoreSuite) TestOracle(c *C) { wg.Wait() } +var _ pd.Client = &mockPDClient{} + type mockPDClient struct { sync.RWMutex client pd.Client @@ -189,6 +191,10 @@ func (c *mockPDClient) ScatterRegion(ctx context.Context, regionID uint64) error return nil } +func (c *mockPDClient) ScanRegions(ctx context.Context, key []byte, limit int) ([]*metapb.Region, []*metapb.Peer, error) { + return nil, nil, nil +} + func (c *mockPDClient) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) { return &pdpb.GetOperatorResponse{Status: pdpb.OperatorStatus_SUCCESS}, nil } diff --git a/store/tikv/test_util.go b/store/tikv/test_util.go index ffa34637ba0e5..e63abc405309e 100644 --- a/store/tikv/test_util.go +++ b/store/tikv/test_util.go @@ -14,10 +14,10 @@ package tikv import ( + "github.com/google/uuid" "github.com/pingcap/errors" "github.com/pingcap/pd/client" "github.com/pingcap/tidb/kv" - "github.com/twinj/uuid" ) // NewTestTiKVStore creates a test store with Option @@ -32,7 +32,7 @@ func NewTestTiKVStore(client Client, pdClient pd.Client, clientHijack func(Clien } // Make sure the uuid is unique. - uid := uuid.NewV4().String() + uid := uuid.New().String() spkv := NewMockSafePointKV() tikvStore, err := newTikvStore(uid, pdCli, spkv, client, false) diff --git a/store/tikv/ticlient_test.go b/store/tikv/ticlient_test.go index 6662bb9a8ad20..ddbbe3e610f0b 100644 --- a/store/tikv/ticlient_test.go +++ b/store/tikv/ticlient_test.go @@ -28,7 +28,7 @@ import ( ) var ( - withTiKVGlobalLock sync.Mutex + withTiKVGlobalLock sync.RWMutex withTiKV = flag.Bool("with-tikv", false, "run tests with TiKV cluster started. (not use the mock server)") pdAddrs = flag.String("pd-addrs", "127.0.0.1:2379", "pd addrs") ) @@ -119,7 +119,7 @@ func (s *testTiclientSuite) TestSingleKey(c *C) { txn := s.beginTxn(c) err := txn.Set(encodeKey(s.prefix, "key"), []byte("value")) c.Assert(err, IsNil) - err = txn.LockKeys(context.Background(), 0, encodeKey(s.prefix, "key")) + err = txn.LockKeys(context.Background(), nil, 0, encodeKey(s.prefix, "key")) c.Assert(err, IsNil) err = txn.Commit(context.Background()) c.Assert(err, IsNil) diff --git a/store/tikv/tikv_test.go b/store/tikv/tikv_test.go index a4db0b7df60ec..d4a5bfe1d6a6c 100644 --- a/store/tikv/tikv_test.go +++ b/store/tikv/tikv_test.go @@ -18,18 +18,21 @@ import ( ) // OneByOneSuite is a suite, When with-tikv flag is true, there is only one storage, so the test suite have to run one by one. -type OneByOneSuite struct { -} +type OneByOneSuite struct{} -func (s OneByOneSuite) SetUpSuite(c *C) { +func (s *OneByOneSuite) SetUpSuite(c *C) { if *withTiKV { withTiKVGlobalLock.Lock() + } else { + withTiKVGlobalLock.RLock() } } -func (s OneByOneSuite) TearDownSuite(c *C) { +func (s *OneByOneSuite) TearDownSuite(c *C) { if *withTiKV { withTiKVGlobalLock.Unlock() + } else { + withTiKVGlobalLock.RUnlock() } } diff --git a/store/tikv/tikvrpc/tikvrpc.go b/store/tikv/tikvrpc/tikvrpc.go index e62a153410a61..1d88e7b25bcb3 100644 --- a/store/tikv/tikvrpc/tikvrpc.go +++ b/store/tikv/tikvrpc/tikvrpc.go @@ -45,6 +45,8 @@ const ( CmdGC CmdDeleteRange CmdPessimisticLock + CmdPessimisticRollback + CmdTxnHeartBeat CmdRawGet CmdType = 256 + iota CmdRawBatchGet @@ -65,6 +67,8 @@ const ( CmdSplitRegion CmdDebugGetRegionProperties CmdType = 2048 + iota + + CmdEmpty CmdType = 3072 + iota ) func (t CmdType) String() string { @@ -77,6 +81,8 @@ func (t CmdType) String() string { return "Prewrite" case CmdPessimisticLock: return "PessimisticLock" + case CmdPessimisticRollback: + return "PessimisticRollback" case CmdCommit: return "Commit" case CmdCleanup: @@ -123,6 +129,8 @@ func (t CmdType) String() string { return "SplitRegion" case CmdDebugGetRegionProperties: return "DebugGetRegionProperties" + case CmdTxnHeartBeat: + return "TxnHeartBeat" } return "Unknown" } @@ -134,7 +142,6 @@ type Request struct { Get *kvrpcpb.GetRequest Scan *kvrpcpb.ScanRequest Prewrite *kvrpcpb.PrewriteRequest - PessimisticLock *kvrpcpb.PessimisticLockRequest Commit *kvrpcpb.CommitRequest Cleanup *kvrpcpb.CleanupRequest BatchGet *kvrpcpb.BatchGetRequest @@ -157,7 +164,13 @@ type Request struct { MvccGetByStartTs *kvrpcpb.MvccGetByStartTsRequest SplitRegion *kvrpcpb.SplitRegionRequest + PessimisticLock *kvrpcpb.PessimisticLockRequest + PessimisticRollback *kvrpcpb.PessimisticRollbackRequest + DebugGetRegionProperties *debugpb.GetRegionPropertiesRequest + + Empty *tikvpb.BatchCommandsEmptyRequest + TxnHeartBeat *kvrpcpb.TxnHeartBeatRequest } // ToBatchCommandsRequest converts the request to an entry in BatchCommands request. @@ -205,6 +218,12 @@ func (req *Request) ToBatchCommandsRequest() *tikvpb.BatchCommandsRequest_Reques return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Coprocessor{Coprocessor: req.Cop}} case CmdPessimisticLock: return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_PessimisticLock{PessimisticLock: req.PessimisticLock}} + case CmdPessimisticRollback: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_PessimisticRollback{PessimisticRollback: req.PessimisticRollback}} + case CmdEmpty: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Empty{Empty: req.Empty}} + case CmdTxnHeartBeat: + return &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_TxnHeartBeat{TxnHeartBeat: req.TxnHeartBeat}} } return nil } @@ -224,7 +243,6 @@ type Response struct { Get *kvrpcpb.GetResponse Scan *kvrpcpb.ScanResponse Prewrite *kvrpcpb.PrewriteResponse - PessimisticLock *kvrpcpb.PessimisticLockResponse Commit *kvrpcpb.CommitResponse Cleanup *kvrpcpb.CleanupResponse BatchGet *kvrpcpb.BatchGetResponse @@ -248,7 +266,13 @@ type Response struct { MvccGetByStartTS *kvrpcpb.MvccGetByStartTsResponse SplitRegion *kvrpcpb.SplitRegionResponse + PessimisticLock *kvrpcpb.PessimisticLockResponse + PessimisticRollback *kvrpcpb.PessimisticRollbackResponse + DebugGetRegionProperties *debugpb.GetRegionPropertiesResponse + + Empty *tikvpb.BatchCommandsEmptyResponse + TxnHeartBeat *kvrpcpb.TxnHeartBeatResponse } // FromBatchCommandsResponse converts a BatchCommands response to Response. @@ -296,6 +320,12 @@ func FromBatchCommandsResponse(res *tikvpb.BatchCommandsResponse_Response) *Resp return &Response{Type: CmdCop, Cop: res.Coprocessor} case *tikvpb.BatchCommandsResponse_Response_PessimisticLock: return &Response{Type: CmdPessimisticLock, PessimisticLock: res.PessimisticLock} + case *tikvpb.BatchCommandsResponse_Response_PessimisticRollback: + return &Response{Type: CmdPessimisticRollback, PessimisticRollback: res.PessimisticRollback} + case *tikvpb.BatchCommandsResponse_Response_Empty: + return &Response{Type: CmdEmpty, Empty: res.Empty} + case *tikvpb.BatchCommandsResponse_Response_TxnHeartBeat: + return &Response{Type: CmdTxnHeartBeat, TxnHeartBeat: res.TxnHeartBeat} } return nil } @@ -326,6 +356,8 @@ func SetContext(req *Request, region *metapb.Region, peer *metapb.Peer) error { req.Prewrite.Context = ctx case CmdPessimisticLock: req.PessimisticLock.Context = ctx + case CmdPessimisticRollback: + req.PessimisticRollback.Context = ctx case CmdCommit: req.Commit.Context = ctx case CmdCleanup: @@ -370,6 +402,9 @@ func SetContext(req *Request, region *metapb.Region, peer *metapb.Peer) error { req.MvccGetByStartTs.Context = ctx case CmdSplitRegion: req.SplitRegion.Context = ctx + case CmdEmpty: + case CmdTxnHeartBeat: + req.TxnHeartBeat.Context = ctx default: return fmt.Errorf("invalid request type %v", req.Type) } @@ -398,6 +433,10 @@ func GenRegionErrorResp(req *Request, e *errorpb.Error) (*Response, error) { resp.PessimisticLock = &kvrpcpb.PessimisticLockResponse{ RegionError: e, } + case CmdPessimisticRollback: + resp.PessimisticRollback = &kvrpcpb.PessimisticRollbackResponse{ + RegionError: e, + } case CmdCommit: resp.Commit = &kvrpcpb.CommitResponse{ RegionError: e, @@ -488,6 +527,11 @@ func GenRegionErrorResp(req *Request, e *errorpb.Error) (*Response, error) { resp.SplitRegion = &kvrpcpb.SplitRegionResponse{ RegionError: e, } + case CmdEmpty: + case CmdTxnHeartBeat: + resp.TxnHeartBeat = &kvrpcpb.TxnHeartBeatResponse{ + RegionError: e, + } default: return nil, fmt.Errorf("invalid request type %v", req.Type) } @@ -504,6 +548,8 @@ func (resp *Response) GetRegionError() (*errorpb.Error, error) { e = resp.Scan.GetRegionError() case CmdPessimisticLock: e = resp.PessimisticLock.GetRegionError() + case CmdPessimisticRollback: + e = resp.PessimisticRollback.GetRegionError() case CmdPrewrite: e = resp.Prewrite.GetRegionError() case CmdCommit: @@ -550,6 +596,9 @@ func (resp *Response) GetRegionError() (*errorpb.Error, error) { e = resp.MvccGetByStartTS.GetRegionError() case CmdSplitRegion: e = resp.SplitRegion.GetRegionError() + case CmdEmpty: + case CmdTxnHeartBeat: + e = resp.TxnHeartBeat.GetRegionError() default: return nil, fmt.Errorf("invalid response type %v", resp.Type) } @@ -572,6 +621,8 @@ func CallRPC(ctx context.Context, client tikvpb.TikvClient, req *Request) (*Resp resp.Prewrite, err = client.KvPrewrite(ctx, req.Prewrite) case CmdPessimisticLock: resp.PessimisticLock, err = client.KvPessimisticLock(ctx, req.PessimisticLock) + case CmdPessimisticRollback: + resp.PessimisticRollback, err = client.KVPessimisticRollback(ctx, req.PessimisticRollback) case CmdCommit: resp.Commit, err = client.KvCommit(ctx, req.Commit) case CmdCleanup: @@ -620,6 +671,10 @@ func CallRPC(ctx context.Context, client tikvpb.TikvClient, req *Request) (*Resp resp.MvccGetByStartTS, err = client.MvccGetByStartTs(ctx, req.MvccGetByStartTs) case CmdSplitRegion: resp.SplitRegion, err = client.SplitRegion(ctx, req.SplitRegion) + case CmdEmpty: + resp.Empty, err = &tikvpb.BatchCommandsEmptyResponse{}, nil + case CmdTxnHeartBeat: + resp.TxnHeartBeat, err = client.KvTxnHeartBeat(ctx, req.TxnHeartBeat) default: return nil, errors.Errorf("invalid request type: %v", req.Type) } diff --git a/store/tikv/txn.go b/store/tikv/txn.go index 3f2faccaef183..cd87396cb867c 100644 --- a/store/tikv/txn.go +++ b/store/tikv/txn.go @@ -17,10 +17,13 @@ import ( "context" "fmt" "sync" + "sync/atomic" "time" + "github.com/dgryski/go-farm" "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx" @@ -58,6 +61,7 @@ type tikvTxn struct { commitTS uint64 valid bool lockKeys [][]byte + lockedMap map[string]struct{} mu sync.Mutex // For thread-safe LockKeys function. dirty bool setCnt int64 @@ -84,6 +88,7 @@ func newTikvTxnWithStartTS(store *tikvStore, startTS uint64) (*tikvTxn, error) { return &tikvTxn{ snapshot: snapshot, us: kv.NewUnionStore(snapshot), + lockedMap: map[string]struct{}{}, store: store, startTS: startTS, startTime: time.Now(), @@ -223,6 +228,8 @@ func (txn *tikvTxn) SetOption(opt kv.Option, val interface{}) { txn.snapshot.syncLog = val.(bool) case kv.KeyOnly: txn.snapshot.keyOnly = val.(bool) + case kv.SnapshotTS: + txn.snapshot.setSnapshotTS(val.(uint64)) } } @@ -268,6 +275,7 @@ func (txn *tikvTxn) Commit(ctx context.Context) error { return errors.Trace(err) } } + defer committer.ttlManager.close() if err := committer.initKeysAndMutations(); err != nil { return errors.Trace(err) } @@ -282,7 +290,7 @@ func (txn *tikvTxn) Commit(ctx context.Context) error { if *commitDetail != nil { (*commitDetail).TxnRetry += 1 } else { - *commitDetail = committer.detail + *commitDetail = committer.getDetail() } } }() @@ -298,9 +306,10 @@ func (txn *tikvTxn) Commit(ctx context.Context) error { // for transactions which need to acquire latches start = time.Now() lock := txn.store.txnLatches.Lock(committer.startTS, committer.keys) - committer.detail.LocalLatchTime = time.Since(start) - if committer.detail.LocalLatchTime > 0 { - metrics.TiKVLocalLatchWaitTimeHistogram.Observe(committer.detail.LocalLatchTime.Seconds()) + commitDetail := committer.getDetail() + commitDetail.LocalLatchTime = time.Since(start) + if commitDetail.LocalLatchTime > 0 { + metrics.TiKVLocalLatchWaitTimeHistogram.Observe(commitDetail.LocalLatchTime.Seconds()) } defer txn.store.txnLatches.UnLock(lock) if lock.IsStale() { @@ -319,17 +328,17 @@ func (txn *tikvTxn) close() { } func (txn *tikvTxn) Rollback() error { + if !txn.valid { + return kv.ErrInvalidTxn + } // Clean up pessimistic lock. if txn.IsPessimistic() && txn.committer != nil { - err := txn.rollbackPessimisticLock() + err := txn.rollbackPessimisticLocks() + txn.committer.ttlManager.close() if err != nil { logutil.Logger(context.Background()).Error(err.Error()) } } - - if !txn.valid { - return kv.ErrInvalidTxn - } txn.close() logutil.Logger(context.Background()).Debug("[kv] rollback txn", zap.Uint64("txnStartTS", txn.StartTS())) tikvTxnCmdCountWithRollback.Inc() @@ -337,19 +346,23 @@ func (txn *tikvTxn) Rollback() error { return nil } -func (txn *tikvTxn) rollbackPessimisticLock() error { - c := txn.committer - if err := c.initKeysAndMutations(); err != nil { - return errors.Trace(err) - } - if len(c.keys) == 0 { +func (txn *tikvTxn) rollbackPessimisticLocks() error { + if len(txn.lockKeys) == 0 { return nil } - - return c.cleanupKeys(NewBackoffer(context.Background(), cleanupMaxBackoff), c.keys) + return txn.committer.pessimisticRollbackKeys(NewBackoffer(context.Background(), cleanupMaxBackoff), txn.lockKeys) } -func (txn *tikvTxn) LockKeys(ctx context.Context, forUpdateTS uint64, keys ...kv.Key) error { +func (txn *tikvTxn) LockKeys(ctx context.Context, killed *uint32, forUpdateTS uint64, keysInput ...kv.Key) error { + // Exclude keys that are already locked. + keys := make([][]byte, 0, len(keysInput)) + txn.mu.Lock() + for _, key := range keysInput { + if _, ok := txn.lockedMap[string(key)]; !ok { + keys = append(keys, key) + } + } + txn.mu.Unlock() if len(keys) == 0 { return nil } @@ -367,32 +380,96 @@ func (txn *tikvTxn) LockKeys(ctx context.Context, forUpdateTS uint64, keys ...kv if err != nil { return err } + } + if txn.committer.pessimisticTTL == 0 { + // add elapsed time to pessimistic TTL on the first LockKeys request. + elapsed := uint64(time.Since(txn.startTime) / time.Millisecond) + txn.committer.pessimisticTTL = PessimisticLockTTL + elapsed + } + var assignedPrimaryKey bool + if txn.committer.primaryKey == nil { txn.committer.primaryKey = keys[0] + assignedPrimaryKey = true } - bo := NewBackoffer(ctx, prewriteMaxBackoff).WithVars(txn.vars) - keys1 := make([][]byte, len(keys)) - for i, key := range keys { - keys1[i] = key - } + bo := NewBackoffer(ctx, pessimisticLockMaxBackoff).WithVars(txn.vars) txn.committer.forUpdateTS = forUpdateTS - // If the number of keys1 greater than 1, it can be on different region, + // If the number of keys greater than 1, it can be on different region, // concurrently execute on multiple regions may lead to deadlock. - txn.committer.isFirstLock = len(txn.lockKeys) == 0 && len(keys1) == 1 - err := txn.committer.pessimisticLockKeys(bo, keys1) + txn.committer.isFirstLock = len(txn.lockKeys) == 0 && len(keys) == 1 + err := txn.committer.pessimisticLockKeys(bo, killed, keys) + if killed != nil { + // If the kill signal is received during waiting for pessimisticLock, + // pessimisticLockKeys would handle the error but it doesn't reset the flag. + // We need to reset the killed flag here. + atomic.CompareAndSwapUint32(killed, 1, 0) + } if err != nil { + for _, key := range keys { + txn.us.DeleteConditionPair(key) + } + keyMayBeLocked := terror.ErrorNotEqual(kv.ErrWriteConflict, err) && terror.ErrorNotEqual(kv.ErrKeyExists, err) + // If there is only 1 key and lock fails, no need to do pessimistic rollback. + if len(keys) > 1 || keyMayBeLocked { + wg := txn.asyncPessimisticRollback(ctx, keys) + if dl, ok := errors.Cause(err).(*ErrDeadlock); ok && hashInKeys(dl.DeadlockKeyHash, keys) { + dl.IsRetryable = true + // Wait for the pessimistic rollback to finish before we retry the statement. + wg.Wait() + // Sleep a little, wait for the other transaction that blocked by this transaction to acquire the lock. + time.Sleep(time.Millisecond * 5) + } + } + if assignedPrimaryKey { + // unset the primary key if we assigned primary key when failed to lock it. + txn.committer.primaryKey = nil + } return err } + if assignedPrimaryKey { + txn.committer.ttlManager.run(txn.committer) + } } txn.mu.Lock() + txn.lockKeys = append(txn.lockKeys, keys...) for _, key := range keys { - txn.lockKeys = append(txn.lockKeys, key) + txn.lockedMap[string(key)] = struct{}{} } txn.dirty = true txn.mu.Unlock() return nil } +func (txn *tikvTxn) asyncPessimisticRollback(ctx context.Context, keys [][]byte) *sync.WaitGroup { + // Clone a new committer for execute in background. + committer := &twoPhaseCommitter{ + store: txn.committer.store, + connID: txn.committer.connID, + startTS: txn.committer.startTS, + forUpdateTS: txn.committer.forUpdateTS, + primaryKey: txn.committer.primaryKey, + } + wg := new(sync.WaitGroup) + wg.Add(1) + go func() { + err := committer.pessimisticRollbackKeys(NewBackoffer(ctx, pessimisticRollbackMaxBackoff), keys) + if err != nil { + logutil.Logger(ctx).Warn("[kv] pessimisticRollback failed.", zap.Error(err)) + } + wg.Done() + }() + return wg +} + +func hashInKeys(deadlockKeyHash uint64, keys [][]byte) bool { + for _, key := range keys { + if farm.Fingerprint64(key) == deadlockKeyHash { + return true + } + } + return false +} + func (txn *tikvTxn) IsReadOnly() bool { return !txn.dirty } diff --git a/structure/hash.go b/structure/hash.go index 3249884a0b6fc..ddad8d69d0344 100644 --- a/structure/hash.go +++ b/structure/hash.go @@ -216,6 +216,23 @@ func (t *TxStructure) HGetAll(key []byte) ([]HashPair, error) { return res, errors.Trace(err) } +// HGetLastN gets latest N fields and values in hash. +func (t *TxStructure) HGetLastN(key []byte, num int) ([]HashPair, error) { + res := make([]HashPair, 0, num) + err := t.iterReverseHash(key, func(field []byte, value []byte) (bool, error) { + pair := HashPair{ + Field: append([]byte{}, field...), + Value: append([]byte{}, value...), + } + res = append(res, pair) + if len(res) >= num { + return false, nil + } + return true, nil + }) + return res, errors.Trace(err) +} + // HClear removes the hash value of the key. func (t *TxStructure) HClear(key []byte) error { metaKey := t.encodeHashMetaKey(key) @@ -268,6 +285,37 @@ func (t *TxStructure) iterateHash(key []byte, fn func(k []byte, v []byte) error) return nil } +func (t *TxStructure) iterReverseHash(key []byte, fn func(k []byte, v []byte) (bool, error)) error { + dataPrefix := t.hashDataKeyPrefix(key) + it, err := t.reader.IterReverse(dataPrefix.PrefixNext()) + if err != nil { + return errors.Trace(err) + } + + var field []byte + for it.Valid() { + if !it.Key().HasPrefix(dataPrefix) { + break + } + + _, field, err = t.decodeHashDataKey(it.Key()) + if err != nil { + return errors.Trace(err) + } + + more, err := fn(field, it.Value()) + if !more || err != nil { + return errors.Trace(err) + } + + err = it.Next() + if err != nil { + return errors.Trace(err) + } + } + return nil +} + func (t *TxStructure) loadHashMeta(metaKey []byte) (hashMeta, error) { v, err := t.reader.Get(metaKey) if kv.ErrNotExist.Equal(err) { diff --git a/structure/structure_test.go b/structure/structure_test.go index 5ecab9f75c3a9..e6e55fcf5dfef 100644 --- a/structure/structure_test.go +++ b/structure/structure_test.go @@ -244,6 +244,17 @@ func (s *testTxStructureSuite) TestHash(c *C) { {Field: []byte("1"), Value: []byte("1")}, {Field: []byte("2"), Value: []byte("2")}}) + res, err = tx.HGetLastN(key, 1) + c.Assert(err, IsNil) + c.Assert(res, DeepEquals, []structure.HashPair{ + {Field: []byte("2"), Value: []byte("2")}}) + + res, err = tx.HGetLastN(key, 2) + c.Assert(err, IsNil) + c.Assert(res, DeepEquals, []structure.HashPair{ + {Field: []byte("2"), Value: []byte("2")}, + {Field: []byte("1"), Value: []byte("1")}}) + err = tx.HDel(key, []byte("1")) c.Assert(err, IsNil) diff --git a/structure/type.go b/structure/type.go index 89759269871c9..7096d70e86984 100644 --- a/structure/type.go +++ b/structure/type.go @@ -63,6 +63,11 @@ func (t *TxStructure) encodeHashDataKey(key []byte, field []byte) kv.Key { return codec.EncodeBytes(ek, field) } +// EncodeHashDataKey exports for tests. +func (t *TxStructure) EncodeHashDataKey(key []byte, field []byte) kv.Key { + return t.encodeHashDataKey(key, field) +} + func (t *TxStructure) decodeHashDataKey(ek kv.Key) ([]byte, []byte, error) { var ( key []byte diff --git a/table/column.go b/table/column.go index ce9fb65454673..40351928747dc 100644 --- a/table/column.go +++ b/table/column.go @@ -19,6 +19,8 @@ package table import ( "context" + "fmt" + "strconv" "strings" "time" "unicode/utf8" @@ -256,6 +258,13 @@ func NewColDesc(col *Column) *ColDesc { var defaultValue interface{} if !mysql.HasNoDefaultValueFlag(col.Flag) { defaultValue = col.GetDefaultValue() + if defaultValStr, ok := defaultValue.(string); ok { + if (col.Tp == mysql.TypeTimestamp || col.Tp == mysql.TypeDatetime) && + strings.ToUpper(defaultValStr) == strings.ToUpper(ast.CurrentTimestamp) && + col.Decimal > 0 { + defaultValue = fmt.Sprintf("%s(%d)", defaultValStr, col.Decimal) + } + } } extra := "" @@ -264,7 +273,7 @@ func NewColDesc(col *Column) *ColDesc { } else if mysql.HasOnUpdateNowFlag(col.Flag) { //in order to match the rules of mysql 8.0.16 version //see https://github.com/pingcap/tidb/issues/10337 - extra = "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" + extra = "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" + OptionalFsp(&col.FieldType) } else if col.IsGenerated() { if col.GeneratedStored { extra = "STORED GENERATED" @@ -354,6 +363,12 @@ func CheckNotNull(cols []*Column, row []types.Datum) error { // GetColOriginDefaultValue gets default value of the column from original default value. func GetColOriginDefaultValue(ctx sessionctx.Context, col *model.ColumnInfo) (types.Datum, error) { + // If the column type is BIT, both `OriginDefaultValue` and `DefaultValue` of ColumnInfo are corrupted, because + // after JSON marshaling and unmarshaling against the field with type `interface{}`, the content with actual type `[]byte` is changed. + // We need `DefaultValueBit` to restore OriginDefaultValue before reading it. + if col.Tp == mysql.TypeBit && col.DefaultValueBit != nil && col.OriginDefaultValue != nil { + col.OriginDefaultValue = col.DefaultValueBit + } return getColDefaultValue(ctx, col, col.OriginDefaultValue) } @@ -486,3 +501,12 @@ func GetZeroValue(col *model.ColumnInfo) types.Datum { } return d } + +// OptionalFsp convert a FieldType.Decimal to string. +func OptionalFsp(fieldType *types.FieldType) string { + fsp := fieldType.Decimal + if fsp == 0 { + return "" + } + return "(" + strconv.Itoa(fsp) + ")" +} diff --git a/table/table.go b/table/table.go index b0952a6c46096..c06fa60186333 100644 --- a/table/table.go +++ b/table/table.go @@ -18,6 +18,9 @@ package table import ( + "context" + + "github.com/opentracing/opentracing-go" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" @@ -143,8 +146,8 @@ type Table interface { // RemoveRecord removes a row in the table. RemoveRecord(ctx sessionctx.Context, h int64, r []types.Datum) error - // AllocAutoID allocates an auto_increment ID for a new row. - AllocAutoID(ctx sessionctx.Context) (int64, error) + // AllocHandle allocates a handle for a new row. + AllocHandle(ctx sessionctx.Context) (int64, error) // Allocator returns Allocator. Allocator(ctx sessionctx.Context) autoid.Allocator @@ -164,6 +167,28 @@ type Table interface { Type() Type } +// AllocAutoIncrementValue allocates an auto_increment value for a new row. +func AllocAutoIncrementValue(ctx context.Context, t Table, sctx sessionctx.Context) (int64, error) { + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("table.AllocAutoIncrementValue", opentracing.ChildOf(span.Context())) + defer span1.Finish() + } + _, max, err := t.Allocator(sctx).Alloc(t.Meta().ID, uint64(1)) + if err != nil { + return 0, err + } + return max, err +} + +// AllocBatchAutoIncrementValue allocates batch auto_increment value (min and max] for rows. +func AllocBatchAutoIncrementValue(ctx context.Context, t Table, sctx sessionctx.Context, N int) (int64, int64, error) { + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("table.AllocBatchAutoIncrementValue", opentracing.ChildOf(span.Context())) + defer span1.Finish() + } + return t.Allocator(sctx).Alloc(t.Meta().ID, uint64(N)) +} + // PhysicalTable is an abstraction for two kinds of table representation: partition or non-partitioned table. // PhysicalID is a ID that can be used to construct a key ranges, all the data in the key range belongs to the corresponding PhysicalTable. // For a non-partitioned table, its PhysicalID equals to its TableID; For a partition of a partitioned table, its PhysicalID is the partition's ID. diff --git a/table/tables/tables.go b/table/tables/tables.go index 3b1497345ead1..3778ed5453d14 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -443,7 +443,7 @@ func (t *tableCommon) AddRecord(ctx sessionctx.Context, r []types.Datum, opts .. } } if !hasRecordID { - recordID, err = t.AllocAutoID(ctx) + recordID, err = t.AllocHandle(ctx) if err != nil { return 0, err } @@ -665,11 +665,7 @@ func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h int64, co // Row implements table.Table Row interface. func (t *tableCommon) Row(ctx sessionctx.Context, h int64) ([]types.Datum, error) { - r, err := t.RowWithCols(ctx, h, t.Cols()) - if err != nil { - return nil, err - } - return r, nil + return t.RowWithCols(ctx, h, t.Cols()) } // RemoveRecord implements table.Table RemoveRecord interface. @@ -914,9 +910,9 @@ func GetColDefaultValue(ctx sessionctx.Context, col *table.Column, defaultVals [ return colVal, nil } -// AllocAutoID implements table.Table AllocAutoID interface. -func (t *tableCommon) AllocAutoID(ctx sessionctx.Context) (int64, error) { - rowID, err := t.Allocator(ctx).Alloc(t.tableID) +// AllocHandle implements table.Table AllocHandle interface. +func (t *tableCommon) AllocHandle(ctx sessionctx.Context) (int64, error) { + _, rowID, err := t.Allocator(ctx).Alloc(t.tableID, 1) if err != nil { return 0, err } @@ -1042,7 +1038,7 @@ var ( recordPrefixSep = []byte("_r") ) -// FindIndexByColName implements table.Table FindIndexByColName interface. +// FindIndexByColName returns a public table index containing only one column named `name`. func FindIndexByColName(t table.Table, name string) table.Index { for _, idx := range t.Indices() { // only public index can be read. diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index 30e6eac0bb526..2afca28c3607c 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -95,10 +95,14 @@ func (ts *testSuite) TestBasic(c *C) { c.Assert(string(tb.RecordPrefix()), Not(Equals), "") c.Assert(tables.FindIndexByColName(tb, "b"), NotNil) - autoid, err := tb.AllocAutoID(nil) + autoid, err := table.AllocAutoIncrementValue(context.Background(), tb, nil) c.Assert(err, IsNil) c.Assert(autoid, Greater, int64(0)) + handle, err := tb.AllocHandle(nil) + c.Assert(err, IsNil) + c.Assert(handle, Greater, int64(0)) + ctx := ts.se rid, err := tb.AddRecord(ctx, types.MakeDatums(1, "abc")) c.Assert(err, IsNil) @@ -182,7 +186,7 @@ func (ts *testSuite) TestTypes(c *C) { c.Assert(err, IsNil) rs, err := ts.se.Execute(ctx, "select * from test.t where c1 = 1") c.Assert(err, IsNil) - req := rs[0].NewRecordBatch() + req := rs[0].NewChunk() err = rs[0].Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -196,7 +200,7 @@ func (ts *testSuite) TestTypes(c *C) { c.Assert(err, IsNil) rs, err = ts.se.Execute(ctx, "select * from test.t where c1 = 1") c.Assert(err, IsNil) - req = rs[0].NewRecordBatch() + req = rs[0].NewChunk() err = rs[0].Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -212,7 +216,7 @@ func (ts *testSuite) TestTypes(c *C) { c.Assert(err, IsNil) rs, err = ts.se.Execute(ctx, "select c1 + 1 from test.t where c1 = 1") c.Assert(err, IsNil) - req = rs[0].NewRecordBatch() + req = rs[0].NewChunk() err = rs[0].Next(ctx, req) c.Assert(err, IsNil) c.Assert(req.NumRows() == 0, IsFalse) @@ -239,10 +243,15 @@ func (ts *testSuite) TestUniqueIndexMultipleNullEntries(c *C) { c.Assert(string(tb.RecordPrefix()), Not(Equals), "") c.Assert(tables.FindIndexByColName(tb, "b"), NotNil) - autoid, err := tb.AllocAutoID(nil) - sctx := ts.se + handle, err := tb.AllocHandle(nil) + c.Assert(err, IsNil) + c.Assert(handle, Greater, int64(0)) + + autoid, err := table.AllocAutoIncrementValue(context.Background(), tb, nil) c.Assert(err, IsNil) c.Assert(autoid, Greater, int64(0)) + + sctx := ts.se c.Assert(sctx.NewTxn(ctx), IsNil) _, err = tb.AddRecord(sctx, types.MakeDatums(1, nil)) c.Assert(err, IsNil) @@ -373,13 +382,13 @@ func (ts *testSuite) TestTableFromMeta(c *C) { tk.MustExec("create table t_meta (a int) shard_row_id_bits = 15") tb, err = domain.GetDomain(tk.Se).InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t_meta")) c.Assert(err, IsNil) - _, err = tb.AllocAutoID(tk.Se) + _, err = tb.AllocHandle(tk.Se) c.Assert(err, IsNil) maxID := 1<<(64-15-1) - 1 err = tb.RebaseAutoID(tk.Se, int64(maxID), false) c.Assert(err, IsNil) - _, err = tb.AllocAutoID(tk.Se) + _, err = tb.AllocHandle(tk.Se) c.Assert(err, NotNil) } diff --git a/tidb-server/main.go b/tidb-server/main.go index 19d95a04e9702..3396e42fe568b 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -28,8 +28,9 @@ import ( "github.com/pingcap/log" "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" - "github.com/pingcap/pd/client" + pd "github.com/pingcap/pd/client" pumpcli "github.com/pingcap/tidb-tools/tidb-binlog/pump_client" + "github.com/pingcap/tidb/bindinfo" "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/domain" @@ -172,15 +173,19 @@ func main() { signal.SetupSignalHandler(serverShutdown) runServer() cleanup() - exit() + syncLog() } func exit() { + syncLog() + os.Exit(0) +} + +func syncLog() { if err := log.Sync(); err != nil { fmt.Fprintln(os.Stderr, "sync log err:", err) os.Exit(1) } - os.Exit(0) } func registerStores() { @@ -315,6 +320,12 @@ func loadConfig() string { return err.Error() } terror.MustNil(err) + } else { + // configCheck should have the config file specified. + if *configCheck { + fmt.Fprintln(os.Stderr, "config check failed", errors.New("no config file specified for config-check")) + os.Exit(1) + } } return "" } @@ -322,7 +333,7 @@ func loadConfig() string { // hotReloadConfigItems lists all config items which support hot-reload. var hotReloadConfigItems = []string{"Performance.MaxProcs", "Performance.MaxMemory", "Performance.CrossJoin", "Performance.FeedbackProbability", "Performance.QueryFeedbackLimit", "Performance.PseudoEstimateRatio", - "OOMAction", "MemQuotaQuery"} + "OOMAction", "MemQuotaQuery", "StmtSummary.MaxStmtCount", "StmtSummary.MaxSQLLength"} func reloadConfig(nc, c *config.Config) { // Just a part of config items need to be reload explicitly. @@ -448,6 +459,7 @@ func setGlobalVars() { runtime.GOMAXPROCS(int(cfg.Performance.MaxProcs)) statsLeaseDuration := parseDuration(cfg.Performance.StatsLease) session.SetStatsLease(statsLeaseDuration) + bindinfo.Lease = parseDuration(cfg.Performance.BindInfoLease) domain.RunAutoAnalyze = cfg.Performance.RunAutoAnalyze statistics.FeedbackProbability.Store(cfg.Performance.FeedbackProbability) handle.MaxQueryFeedbackCount.Store(int64(cfg.Performance.QueryFeedbackLimit)) @@ -490,6 +502,7 @@ func setGlobalVars() { tikv.CommitMaxBackoff = int(parseDuration(cfg.TiKVClient.CommitTimeout).Seconds() * 1000) tikv.PessimisticLockTTL = uint64(parseDuration(cfg.PessimisticTxn.TTL).Seconds() * 1000) + tikv.RegionCacheTTLSec = int64(cfg.TiKVClient.RegionCacheTTL) } func setupLog() { @@ -514,6 +527,7 @@ func createServer() { svr, err = server.NewServer(cfg, driver) // Both domain and storage have started, so we have to clean them before exiting. terror.MustNil(err, closeDomainAndStorage) + go dom.ExpensiveQueryHandle().SetSessionManager(svr).Run() } func serverShutdown(isgraceful bool) { diff --git a/util/chunk/recordbatch.go b/tidb-server/main_test.go similarity index 63% rename from util/chunk/recordbatch.go rename to tidb-server/main_test.go index 3adebec898f8c..bd34b09481887 100644 --- a/util/chunk/recordbatch.go +++ b/tidb-server/main_test.go @@ -11,15 +11,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package chunk +package main -// RecordBatch is input parameter of Executor.Next` method. -// TODO: remove RecordBatch after finishing chunk size control. -type RecordBatch struct { - *Chunk -} +import "testing" + +var isCoverageServer = "0" -// NewRecordBatch is used to construct a RecordBatch. -func NewRecordBatch(chk *Chunk) *RecordBatch { - return &RecordBatch{chk} +// TestRunMain is a dummy test case, which contains only the main function of tidb-server, +// and it is used to generate coverage_server. +func TestRunMain(t *testing.T) { + if isCoverageServer == "1" { + main() + } } diff --git a/tools/check/check_testSuite.sh b/tools/check/check_testSuite.sh new file mode 100755 index 0000000000000..dd743df3830bf --- /dev/null +++ b/tools/check/check_testSuite.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +set -euo pipefail + +exitCode=0 +for testSuite in $(find . -name "*_test.go" -print0 | xargs -0 grep -P "type test(.*)Suite" | awk '{print $2}'); do + # TODO: ugly regex + # TODO: check code comment + if ! find . -name "*_test.go" -print0 | xargs -0 grep -P "_ = (check\.)?(Suite|SerialSuites)\((&?${testSuite}{|new\(${testSuite}\))" > /dev/null + then + if find . -name "*_test.go" -print0 | xargs -0 grep -P "func \((.* )?\*?${testSuite}\) Test" > /dev/null + then + echo "${testSuite} is not enabled" && exitCode=1 + fi + fi +done +exit ${exitCode} diff --git a/tools/check/go.mod b/tools/check/go.mod index ca5d580f6d6a4..6dfc12cecadbd 100644 --- a/tools/check/go.mod +++ b/tools/check/go.mod @@ -20,3 +20,5 @@ require ( gopkg.in/yaml.v2 v2.2.2 // indirect honnef.co/go/tools v0.0.0-20180920025451-e3ad64cb4ed3 ) + +go 1.13 diff --git a/types/const_test.go b/types/const_test.go index 3f6573f733785..b9a2c64fb37c8 100644 --- a/types/const_test.go +++ b/types/const_test.go @@ -56,7 +56,7 @@ func (s *testMySQLConstSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() } var err error s.dom, err = session.BootstrapSession(s.store) diff --git a/types/convert.go b/types/convert.go index 5e16ad43ac2d1..2dffeb6a44a2d 100644 --- a/types/convert.go +++ b/types/convert.go @@ -98,6 +98,7 @@ func IntergerSignedLowerBound(intType byte) int64 { } // ConvertFloatToInt converts a float64 value to a int value. +// `tp` is used in err msg, if there is overflow, this func will report err according to `tp` func ConvertFloatToInt(fval float64, lowerBound, upperBound int64, tp byte) (int64, error) { val := RoundFloat(fval) if val < float64(lowerBound) { @@ -362,30 +363,63 @@ func NumberToDuration(number int64, fsp int) (Duration, error) { // getValidIntPrefix gets prefix of the string which can be successfully parsed as int. func getValidIntPrefix(sc *stmtctx.StatementContext, str string) (string, error) { - floatPrefix, err := getValidFloatPrefix(sc, str) - if err != nil { - return floatPrefix, errors.Trace(err) + if !sc.CastStrToIntStrict { + floatPrefix, err := getValidFloatPrefix(sc, str) + if err != nil { + return floatPrefix, errors.Trace(err) + } + return floatStrToIntStr(sc, floatPrefix, str) } - return floatStrToIntStr(sc, floatPrefix, str) + + validLen := 0 + + for i := 0; i < len(str); i++ { + c := str[i] + if (c == '+' || c == '-') && i == 0 { + continue + } + + if c >= '0' && c <= '9' { + validLen = i + 1 + continue + } + + break + } + valid := str[:validLen] + if valid == "" { + valid = "0" + } + if validLen == 0 || validLen != len(str) { + return valid, errors.Trace(handleTruncateError(sc, ErrTruncatedWrongVal.GenWithStackByArgs("INTEGER", str))) + } + return valid, nil } -// roundIntStr is to round int string base on the number following dot. +// roundIntStr is to round a **valid int string** base on the number following dot. func roundIntStr(numNextDot byte, intStr string) string { if numNextDot < '5' { return intStr } retStr := []byte(intStr) - for i := len(intStr) - 1; i >= 0; i-- { - if retStr[i] != '9' { - retStr[i]++ + idx := len(intStr) - 1 + for ; idx >= 1; idx-- { + if retStr[idx] != '9' { + retStr[idx]++ break } - if i == 0 { - retStr[i] = '1' + retStr[idx] = '0' + } + if idx == 0 { + if intStr[0] == '9' { + retStr[0] = '1' + retStr = append(retStr, '0') + } else if isDigit(intStr[0]) { + retStr[0]++ + } else { + retStr[1] = '1' retStr = append(retStr, '0') - break } - retStr[i] = '0' } return string(retStr) } @@ -429,6 +463,7 @@ func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr st } return intStr, nil } + // intCnt and digits contain the prefix `+/-` if validFloat[0] is `+/-` var intCnt int digits := make([]byte, 0, len(validFloat)) if dotIdx == -1 { @@ -451,7 +486,7 @@ func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr st intCnt += exp if intCnt <= 0 { intStr = "0" - if intCnt == 0 && len(digits) > 0 { + if intCnt == 0 && len(digits) > 0 && isDigit(digits[0]) { intStr = roundIntStr(digits[0], intStr) } return intStr, nil @@ -525,14 +560,18 @@ func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned if !unsigned { lBound := IntergerSignedLowerBound(mysql.TypeLonglong) uBound := IntergerSignedUpperBound(mysql.TypeLonglong) - return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble) + return ConvertFloatToInt(f, lBound, uBound, mysql.TypeLonglong) } bound := IntergerUnsignedUpperBound(mysql.TypeLonglong) - u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble) + u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeLonglong) return int64(u), errors.Trace(err) case json.TypeCodeString: str := string(hack.String(j.GetString())) - return StrToInt(sc, str) + if !unsigned { + return StrToInt(sc, str) + } + u, err := StrToUint(sc, str) + return int64(u), errors.Trace(err) } return 0, errors.New("Unknown type code in JSON") } @@ -552,8 +591,7 @@ func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float6 case json.TypeCodeInt64: return float64(j.GetInt64()), nil case json.TypeCodeUint64: - u, err := ConvertIntToUint(sc, j.GetInt64(), IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong) - return float64(u), errors.Trace(err) + return float64(j.GetUint64()), nil case json.TypeCodeFloat64: return j.GetFloat64(), nil case json.TypeCodeString: @@ -580,6 +618,9 @@ func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyD // getValidFloatPrefix gets prefix of string which can be successfully parsed as float. func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) { + if (sc.InDeleteStmt || sc.InSelectStmt || sc.InUpdateStmt) && s == "" { + return "0", nil + } var ( sawDot bool sawDigit bool @@ -620,7 +661,7 @@ func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, valid = "0" } if validLen == 0 || validLen != len(s) { - err = errors.Trace(handleTruncateError(sc)) + err = errors.Trace(handleTruncateError(sc, ErrTruncated)) } return valid, err } diff --git a/types/convert_test.go b/types/convert_test.go index 4e39c9dbe0acd..845e5fc4dfde4 100644 --- a/types/convert_test.go +++ b/types/convert_test.go @@ -461,6 +461,42 @@ func (s *testTypeConvertSuite) TestStrToNum(c *C) { testStrToFloat(c, "1e649", math.MaxFloat64, false, nil) testStrToFloat(c, "-1e649", -math.MaxFloat64, true, ErrTruncatedWrongVal) testStrToFloat(c, "-1e649", -math.MaxFloat64, false, nil) + + // for issue #10806, #11179 + testSelectUpdateDeleteEmptyStringError(c) +} + +func testSelectUpdateDeleteEmptyStringError(c *C) { + testCases := []struct { + inSelect bool + inUpdate bool + inDelete bool + }{ + {true, false, false}, + {false, true, false}, + {false, false, true}, + } + sc := new(stmtctx.StatementContext) + for _, tc := range testCases { + sc.InSelectStmt = tc.inSelect + sc.InUpdateStmt = tc.inUpdate + sc.InDeleteStmt = tc.inDelete + + str := "" + expect := 0 + + val, err := StrToInt(sc, str) + c.Assert(err, IsNil) + c.Assert(val, Equals, int64(expect)) + + val1, err := StrToUint(sc, str) + c.Assert(err, IsNil) + c.Assert(val1, Equals, uint64(expect)) + + val2, err := StrToFloat(sc, str) + c.Assert(err, IsNil) + c.Assert(val2, Equals, float64(expect)) + } } func (s *testTypeConvertSuite) TestFieldTypeToStr(c *C) { @@ -618,12 +654,26 @@ func (s *testTypeConvertSuite) TestConvert(c *C) { signedAccept(c, mysql.TypeDouble, "1e+1", "10") // year - signedDeny(c, mysql.TypeYear, 123, "") - signedDeny(c, mysql.TypeYear, 3000, "") + signedDeny(c, mysql.TypeYear, 123, "0") + signedDeny(c, mysql.TypeYear, 3000, "0") signedAccept(c, mysql.TypeYear, "2000", "2000") signedAccept(c, mysql.TypeYear, "abc", "0") signedAccept(c, mysql.TypeYear, "00abc", "2000") signedAccept(c, mysql.TypeYear, "0019", "2019") + signedAccept(c, mysql.TypeYear, 2155, "2155") + signedAccept(c, mysql.TypeYear, 2155.123, "2155") + signedDeny(c, mysql.TypeYear, 2156, "0") + signedDeny(c, mysql.TypeYear, 123.123, "0") + signedDeny(c, mysql.TypeYear, 1900, "0") + signedAccept(c, mysql.TypeYear, 1901, "1901") + signedAccept(c, mysql.TypeYear, 1900.567, "1901") + signedDeny(c, mysql.TypeYear, 1900.456, "0") + signedAccept(c, mysql.TypeYear, 1, "2001") + signedAccept(c, mysql.TypeYear, 69, "2069") + signedAccept(c, mysql.TypeYear, 70, "1970") + signedAccept(c, mysql.TypeYear, 99, "1999") + signedDeny(c, mysql.TypeYear, 100, "0") + signedDeny(c, mysql.TypeYear, "99999999999999999999999999999999999", "0") // time from string signedAccept(c, mysql.TypeDate, "2012-08-23", "2012-08-23") @@ -666,6 +716,101 @@ func (s *testTypeConvertSuite) TestConvert(c *C) { signedAccept(c, mysql.TypeNewDecimal, dec, "-0.00123") } +func (s *testTypeConvertSuite) TestGetValidInt(c *C) { + tests := []struct { + origin string + valid string + signed bool + warning bool + }{ + {"100", "100", true, false}, + {"-100", "-100", true, false}, + {"9223372036854775808", "9223372036854775808", false, false}, + {"1abc", "1", true, true}, + {"-1-1", "-1", true, true}, + {"+1+1", "+1", true, true}, + {"123..34", "123", true, true}, + {"123.23E-10", "123", true, true}, + {"1.1e1.3", "1", true, true}, + {"11e1.3", "11", true, true}, + {"1.", "1", true, true}, + {".1", "0", true, true}, + {"", "0", true, true}, + {"123e+", "123", true, true}, + {"123de", "123", true, true}, + } + sc := new(stmtctx.StatementContext) + sc.TruncateAsWarning = true + sc.CastStrToIntStrict = true + warningCount := 0 + for _, tt := range tests { + prefix, err := getValidIntPrefix(sc, tt.origin) + c.Assert(err, IsNil) + c.Assert(prefix, Equals, tt.valid) + if tt.signed { + _, err = strconv.ParseInt(prefix, 10, 64) + } else { + _, err = strconv.ParseUint(prefix, 10, 64) + } + c.Assert(err, IsNil) + warnings := sc.GetWarnings() + if tt.warning { + c.Assert(warnings, HasLen, warningCount+1) + c.Assert(terror.ErrorEqual(warnings[len(warnings)-1].Err, ErrTruncatedWrongVal), IsTrue) + warningCount += 1 + } else { + c.Assert(warnings, HasLen, warningCount) + } + } + + tests2 := []struct { + origin string + valid string + warning bool + }{ + {"100", "100", false}, + {"-100", "-100", false}, + {"1abc", "1", true}, + {"-1-1", "-1", true}, + {"+1+1", "+1", true}, + {"123..34", "123.", true}, + {"123.23E-10", "0", false}, + {"1.1e1.3", "1.1e1", true}, + {"11e1.3", "11e1", true}, + {"1.", "1", false}, + {".1", "0", false}, + {"", "0", true}, + {"123e+", "123", true}, + {"123de", "123", true}, + } + sc.TruncateAsWarning = false + sc.CastStrToIntStrict = false + for _, tt := range tests2 { + prefix, err := getValidIntPrefix(sc, tt.origin) + if tt.warning { + c.Assert(terror.ErrorEqual(err, ErrTruncated), IsTrue) + } else { + c.Assert(err, IsNil) + } + c.Assert(prefix, Equals, tt.valid) + } +} + +func (s *testTypeConvertSuite) TestRoundIntStr(c *C) { + cases := []struct { + a string + b byte + c string + }{ + {"+999", '5', "+1000"}, + {"999", '5', "1000"}, + {"-999", '5', "-1000"}, + } + for _, cc := range cases { + c.Assert(roundIntStr(cc.b, cc.a), Equals, cc.c) + } +} + func (s *testTypeConvertSuite) TestGetValidFloat(c *C) { tests := []struct { origin string @@ -693,15 +838,31 @@ func (s *testTypeConvertSuite) TestGetValidFloat(c *C) { _, err := strconv.ParseFloat(prefix, 64) c.Assert(err, IsNil) } - floatStr, err := floatStrToIntStr(sc, "1e9223372036854775807", "1e9223372036854775807") - c.Assert(err, IsNil) - c.Assert(floatStr, Equals, "1") - floatStr, err = floatStrToIntStr(sc, "125e342", "125e342.83") - c.Assert(err, IsNil) - c.Assert(floatStr, Equals, "125") - floatStr, err = floatStrToIntStr(sc, "1e21", "1e21") - c.Assert(err, IsNil) - c.Assert(floatStr, Equals, "1") + + tests2 := []struct { + origin string + expected string + }{ + {"1e9223372036854775807", "1"}, + {"125e342", "125"}, + {"1e21", "1"}, + {"1e5", "100000"}, + {"-123.45678e5", "-12345678"}, + {"+0.5", "1"}, + {"-0.5", "-1"}, + {".5e0", "1"}, + {"+.5e0", "+1"}, + {"-.5e0", "-1"}, + {".5", "1"}, + {"123.456789e5", "12345679"}, + {"123.456784e5", "12345678"}, + {"+999.9999e2", "+100000"}, + } + for _, t := range tests2 { + str, err := floatStrToIntStr(sc, t.origin, t.origin) + c.Assert(err, IsNil) + c.Assert(str, Equals, t.expected, Commentf("%v, %v", t.origin, t.expected)) + } } // TestConvertTime tests time related conversion. @@ -790,24 +951,26 @@ func (s *testTypeConvertSuite) TestConvertJSONToInt(c *C) { func (s *testTypeConvertSuite) TestConvertJSONToFloat(c *C) { var tests = []struct { - In string + In interface{} Out float64 + ty json.TypeCode }{ - {`{}`, 0}, - {`[]`, 0}, - {`3`, 3}, - {`-3`, -3}, - {`4.5`, 4.5}, - {`true`, 1}, - {`false`, 0}, - {`null`, 0}, - {`"hello"`, 0}, - {`"123.456hello"`, 123.456}, - {`"1234"`, 1234}, + {make(map[string]interface{}, 0), 0, json.TypeCodeObject}, + {make([]interface{}, 0), 0, json.TypeCodeArray}, + {int64(3), 3, json.TypeCodeInt64}, + {int64(-3), -3, json.TypeCodeInt64}, + {uint64(1 << 63), 1 << 63, json.TypeCodeUint64}, + {float64(4.5), 4.5, json.TypeCodeFloat64}, + {true, 1, json.TypeCodeLiteral}, + {false, 0, json.TypeCodeLiteral}, + {nil, 0, json.TypeCodeLiteral}, + {"hello", 0, json.TypeCodeString}, + {"123.456hello", 123.456, json.TypeCodeString}, + {"1234", 1234, json.TypeCodeString}, } for _, tt := range tests { - j, err := json.ParseBinaryFromString(tt.In) - c.Assert(err, IsNil) + j := json.CreateBinary(tt.In) + c.Assert(j.TypeCode, Equals, tt.ty) casted, _ := ConvertJSONToFloat(new(stmtctx.StatementContext), j) c.Assert(casted, Equals, tt.Out) } diff --git a/types/datum.go b/types/datum.go index 9819a59a1f442..d6421cce35730 100644 --- a/types/datum.go +++ b/types/datum.go @@ -1184,6 +1184,7 @@ func (d *Datum) convertToMysqlYear(sc *stmtctx.StatementContext, target *FieldTy s := d.GetString() y, err = StrToInt(sc, s) if err != nil { + ret.SetInt64(0) return ret, errors.Trace(err) } if len(s) != 4 && len(s) > 0 && s[0:1] == "0" { @@ -1196,16 +1197,18 @@ func (d *Datum) convertToMysqlYear(sc *stmtctx.StatementContext, target *FieldTy default: ret, err = d.convertToInt(sc, NewFieldType(mysql.TypeLonglong)) if err != nil { - return invalidConv(d, target.Tp) + _, err = invalidConv(d, target.Tp) + ret.SetInt64(0) + return ret, err } y = ret.GetInt64() } y, err = AdjustYear(y, adjust) if err != nil { - return invalidConv(d, target.Tp) + _, err = invalidConv(d, target.Tp) } ret.SetInt64(y) - return ret, nil + return ret, err } func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) { @@ -1215,12 +1218,16 @@ func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldTyp switch d.k { case KindString, KindBytes: uintValue, err = BinaryLiteral(d.b).ToInt(sc) + case KindInt64: + // if input kind is int64 (signed), when trans to bit, we need to treat it as unsigned + d.k = KindUint64 + fallthrough default: uintDatum, err1 := d.convertToUint(sc, target) uintValue, err = uintDatum.GetUint64(), err1 } if target.Flen < 64 && uintValue >= 1<<(uint64(target.Flen)) { - return Datum{}, errors.Trace(ErrOverflow.GenWithStackByArgs("BIT", fmt.Sprintf("(%d)", target.Flen))) + return Datum{}, errors.Trace(ErrDataTooLong.GenWithStack("Data Too Long, field len %d", target.Flen)) } byteSize := (target.Flen + 7) >> 3 ret.SetMysqlBit(NewBinaryLiteralFromUint(uintValue, byteSize)) @@ -1781,14 +1788,14 @@ func (ds *datumsSorter) Swap(i, j int) { ds.datums[i], ds.datums[j] = ds.datums[j], ds.datums[i] } -func handleTruncateError(sc *stmtctx.StatementContext) error { +func handleTruncateError(sc *stmtctx.StatementContext, err error) error { if sc.IgnoreTruncate { return nil } if !sc.TruncateAsWarning { - return ErrTruncated + return err } - sc.AppendWarning(ErrTruncated) + sc.AppendWarning(err) return nil } diff --git a/types/datum_test.go b/types/datum_test.go index dc389f3f16a71..a1e8eb468983b 100644 --- a/types/datum_test.go +++ b/types/datum_test.go @@ -250,6 +250,7 @@ func (ts *testDatumSuite) TestToJSON(c *C) { {NewStringDatum("[1, 2, 3]"), `[1, 2, 3]`, true}, {NewStringDatum("{}"), `{}`, true}, {mustParseTimeIntoDatum("2011-11-10 11:11:11.111111", mysql.TypeTimestamp, 6), `"2011-11-10 11:11:11.111111"`, true}, + {NewStringDatum(`{"a": "9223372036854775809"}`), `{"a": "9223372036854775809"}`, true}, // can not parse JSON from this string, so error occurs. {NewStringDatum("hello, 世界"), "", false}, diff --git a/types/etc.go b/types/etc.go index b8b5af64f38d5..e29c91171e5a7 100644 --- a/types/etc.go +++ b/types/etc.go @@ -83,10 +83,7 @@ func IsTemporalWithDate(tp byte) bool { // IsBinaryStr returns a boolean indicating // whether the field type is a binary string type. func IsBinaryStr(ft *FieldType) bool { - if ft.Collate == charset.CollationBin && IsString(ft.Tp) { - return true - } - return false + return ft.Collate == charset.CollationBin && IsString(ft.Tp) } // IsNonBinaryStr returns a boolean indicating diff --git a/types/field_type.go b/types/field_type.go index ddb836f3b3752..d33d83a2d2c2d 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -201,7 +201,7 @@ func DefaultTypeForValue(value interface{}, tp *FieldType) { SetBinChsClnFlag(tp) case HexLiteral: tp.Tp = mysql.TypeVarString - tp.Flen = len(x) + tp.Flen = len(x) * 3 tp.Decimal = 0 tp.Flag |= mysql.UnsignedFlag SetBinChsClnFlag(tp) diff --git a/types/fsp.go b/types/fsp.go index fe5a656cd87bb..c9709822c6453 100644 --- a/types/fsp.go +++ b/types/fsp.go @@ -86,9 +86,12 @@ func ParseFrac(s string, fsp int) (v int, overflow bool, err error) { return } -// alignFrac is used to generate alignment frac, like `100` -> `100000` +// alignFrac is used to generate alignment frac, like `100` -> `100000` ,`-100` -> `-100000` func alignFrac(s string, fsp int) string { sl := len(s) + if sl > 0 && s[0] == '-' { + sl = sl - 1 + } if sl < fsp { return s + strings.Repeat("0", fsp-sl) } diff --git a/types/fsp_test.go b/types/fsp_test.go index 8802e87d5b3e4..b8f29cd4077d7 100644 --- a/types/fsp_test.go +++ b/types/fsp_test.go @@ -115,4 +115,8 @@ func (s *FspTest) TestAlignFrac(c *C) { c.Assert(obtained, Equals, "100000") obtained = alignFrac("10000000000", 6) c.Assert(obtained, Equals, "10000000000") + obtained = alignFrac("-100", 6) + c.Assert(obtained, Equals, "-100000") + obtained = alignFrac("-10000000000", 6) + c.Assert(obtained, Equals, "-10000000000") } diff --git a/types/json/binary_functions.go b/types/json/binary_functions.go index 9cc87569e3542..4ab16f7999887 100644 --- a/types/json/binary_functions.go +++ b/types/json/binary_functions.go @@ -66,19 +66,17 @@ func (bj BinaryJSON) Unquote() (string, error) { switch bj.TypeCode { case TypeCodeString: tmp := string(hack.String(bj.GetString())) - s, err := unquoteString(tmp) - if err != nil { - return "", errors.Trace(err) - } - // Remove prefix and suffix '"'. - slen := len(s) - if slen > 1 { - head, tail := s[0], s[slen-1] - if head == '"' && tail == '"' { - return s[1 : slen-1], nil - } + tlen := len(tmp) + if tlen < 2 { + return tmp, nil + } + head, tail := tmp[0], tmp[tlen-1] + if head == '"' && tail == '"' { + // Remove prefix and suffix '"' before unquoting + return unquoteString(tmp[1 : tlen-1]) } - return s, nil + // if value is not double quoted, do nothing + return tmp, nil default: return bj.String(), nil } diff --git a/types/json/binary_test.go b/types/json/binary_test.go index 6ab74d0739602..e3cb77181a9ea 100644 --- a/types/json/binary_test.go +++ b/types/json/binary_test.go @@ -114,6 +114,7 @@ func (s *testJSONSuite) TestBinaryJSONUnquote(c *C) { }{ {j: `3`, unquoted: "3"}, {j: `"3"`, unquoted: "3"}, + {j: `"[{\"x\":\"{\\\"y\\\":12}\"}]"`, unquoted: `[{"x":"{\"y\":12}"}]`}, {j: `"hello, \"escaped quotes\" world"`, unquoted: "hello, \"escaped quotes\" world"}, {j: "\"\\u4f60\"", unquoted: "你"}, {j: `true`, unquoted: "true"}, diff --git a/types/mydecimal.go b/types/mydecimal.go index 06ca50c4204fb..996323fa8706f 100644 --- a/types/mydecimal.go +++ b/types/mydecimal.go @@ -107,6 +107,14 @@ var ( zeroMyDecimal = MyDecimal{} ) +// get the zero of MyDecimal with the specified result fraction digits +func zeroMyDecimalWithFrac(frac int8) MyDecimal { + zero := MyDecimal{} + zero.digitsFrac = frac + zero.resultFrac = frac + return zero +} + // add adds a and b and carry, returns the sum and new carry. func add(a, b, carry int32) (int32, int32) { sum := a + b + carry @@ -1556,7 +1564,7 @@ func doSub(from1, from2, to *MyDecimal) (cmp int, err error) { if to == nil { return 0, nil } - *to = zeroMyDecimal + *to = zeroMyDecimalWithFrac(to.resultFrac) return 0, nil } } @@ -1911,7 +1919,7 @@ func DecimalMul(from1, from2, to *MyDecimal) error { idx++ /* We got decimal zero */ if idx == end { - *to = zeroMyDecimal + *to = zeroMyDecimalWithFrac(to.resultFrac) break } } @@ -2010,7 +2018,7 @@ func doDivMod(from1, from2, to, mod *MyDecimal, fracIncr int) error { } if prec1 <= 0 { /* short-circuit everything: from1 == 0 */ - *to = zeroMyDecimal + *to = zeroMyDecimalWithFrac(to.resultFrac) return nil } prec1 -= countLeadingZeroes((prec1-1)%digitsPerWord, from1.wordBuf[idx1]) diff --git a/types/mydecimal_test.go b/types/mydecimal_test.go index e799692231c6a..551105987b3d0 100644 --- a/types/mydecimal_test.go +++ b/types/mydecimal_test.go @@ -694,6 +694,7 @@ func (s *testMyDecimalSuite) TestAdd(c *C) { {"-123.45", "12345", "12221.55", nil}, {"5", "-6.0", "-1.0", nil}, {"2" + strings.Repeat("1", 71), strings.Repeat("8", 81), "8888888890" + strings.Repeat("9", 71), nil}, + {"-1234.1234", "1234.1234", "0.0000", nil}, } for _, tt := range tests { a := NewDecFromStringForTest(tt.a) @@ -718,7 +719,7 @@ func (s *testMyDecimalSuite) TestSub(c *C) { {"1234500009876.5", ".00012345000098765", "1234500009876.49987654999901235", nil}, {"9999900000000.5", ".555", "9999899999999.945", nil}, {"1111.5551", "1111.555", "0.0001", nil}, - {".555", ".555", "0", nil}, + {".555", ".555", "0.000", nil}, {"10000000", "1", "9999999", nil}, {"1000001000", ".1", "1000000999.9", nil}, {"1000000000", ".1", "999999999.9", nil}, @@ -728,6 +729,7 @@ func (s *testMyDecimalSuite) TestSub(c *C) { {"-123.45", "-12345", "12221.55", nil}, {"-12345", "123.45", "-12468.45", nil}, {"12345", "-123.45", "12468.45", nil}, + {"12.12", "12.12", "0.00", nil}, } for _, tt := range tests { var a, b, sum MyDecimal @@ -759,6 +761,7 @@ func (s *testMyDecimalSuite) TestMul(c *C) { {"1" + strings.Repeat("0", 60), "1" + strings.Repeat("0", 60), "0", ErrOverflow}, {"0.5999991229316", "0.918755041726043", "0.5512522192246113614062276588", nil}, {"0.5999991229317", "0.918755041726042", "0.5512522192247026369112773314", nil}, + {"0.000", "-1", "0.000", nil}, } for _, tt := range tests { var a, b, product MyDecimal @@ -786,7 +789,7 @@ func (s *testMyDecimalSuite) TestDivMod(c *C) { {"0", "0", "", ErrDivByZero}, {"-12193185.1853376", "98765.4321", "-123.456000000000000000", nil}, {"121931851853376", "987654321", "123456.000000000", nil}, - {"0", "987", "0", nil}, + {"0", "987", "0.00000", nil}, {"1", "3", "0.333333333", nil}, {"1.000000000000", "3", "0.333333333333333333", nil}, {"1", "1", "1.000000000", nil}, @@ -799,7 +802,7 @@ func (s *testMyDecimalSuite) TestDivMod(c *C) { var a, b, to MyDecimal a.FromString([]byte(tt.a)) b.FromString([]byte(tt.b)) - err := doDivMod(&a, &b, &to, nil, 5) + err := DecimalDiv(&a, &b, &to, 5) c.Check(err, Equals, tt.err) if tt.err == ErrDivByZero { continue @@ -816,12 +819,13 @@ func (s *testMyDecimalSuite) TestDivMod(c *C) { {"99999999999999999999999999999999999999", "3", "0", nil}, {"51", "0.003430", "0.002760", nil}, {"0.0000000001", "1.0", "0.0000000001", nil}, + {"0.000", "0.1", "0.000", nil}, } for _, tt := range tests { var a, b, to MyDecimal a.FromString([]byte(tt.a)) b.FromString([]byte(tt.b)) - ec := doDivMod(&a, &b, nil, &to, 0) + ec := DecimalMod(&a, &b, &to) c.Check(ec, Equals, tt.err) if tt.err == ErrDivByZero { continue @@ -836,6 +840,7 @@ func (s *testMyDecimalSuite) TestDivMod(c *C) { {"1", "1.000", "1.0000", nil}, {"2", "3", "0.6667", nil}, {"51", "0.003430", "14868.8047", nil}, + {"0.000", "0.1", "0.0000000", nil}, } for _, tt := range tests { var a, b, to MyDecimal diff --git a/types/time.go b/types/time.go index 1915b01873629..eddbb34bd30a2 100644 --- a/types/time.go +++ b/types/time.go @@ -17,6 +17,7 @@ import ( "bytes" "context" "fmt" + "io" "math" "regexp" "strconv" @@ -194,10 +195,12 @@ const ( // FromGoTime translates time.Time to mysql time internal representation. func FromGoTime(t gotime.Time) MysqlTime { + // Plus 500 nanosecond for rounding of the millisecond part. + t = t.Add(500 * gotime.Nanosecond) + year, month, day := t.Date() hour, minute, second := t.Clock() - // Nanosecond plus 500 then divided 1000 means rounding to microseconds. - microsecond := (t.Nanosecond() + 500) / 1000 + microsecond := t.Nanosecond() / 1000 return FromDate(year, int(month), day, hour, minute, second, microsecond) } @@ -1169,6 +1172,9 @@ func ParseDuration(sc *stmtctx.StatementContext, str string, fsp int) (Duration, return ZeroDuration, ErrTruncatedWrongVal.GenWithStackByArgs("time", origStr) } + if terror.ErrorEqual(err, io.EOF) { + err = ErrTruncatedWrongVal.GenWithStackByArgs("time", origStr) + } if err != nil { return ZeroDuration, errors.Trace(err) } @@ -1470,7 +1476,7 @@ func checkDateRange(t MysqlTime) error { func checkMonthDay(year, month, day int, allowInvalidDate bool) error { if month < 0 || month > 12 { - return errors.Trace(ErrIncorrectDatetimeValue.GenWithStackByArgs(month)) + return errors.Trace(ErrIncorrectDatetimeValue.GenWithStackByArgs(fmt.Sprintf("%d-%d-%d", year, month, day))) } maxDay := 31 @@ -1484,7 +1490,7 @@ func checkMonthDay(year, month, day int, allowInvalidDate bool) error { } if day < 0 || day > maxDay { - return errors.Trace(ErrIncorrectDatetimeValue.GenWithStackByArgs(day)) + return errors.Trace(ErrIncorrectDatetimeValue.GenWithStackByArgs(fmt.Sprintf("%d-%d-%d", year, month, day))) } return nil } @@ -1650,6 +1656,7 @@ func parseSingleTimeValue(unit string, format string, strictCheck bool) (int64, if unit != "SECOND" { err = ErrTruncatedWrongValue.GenWithStackByArgs(format) } + dv *= sign } switch strings.ToUpper(unit) { case "MICROSECOND": @@ -2172,6 +2179,11 @@ func strToDate(t *MysqlTime, date string, format string, ctx map[string]int) boo return true } + if len(date) == 0 { + ctx[token] = 0 + return true + } + dateRemain, succ := matchDateWithToken(t, date, token, ctx) if !succ { return false @@ -2289,21 +2301,14 @@ func GetFormatType(format string) (isDuration, isDate bool) { isDuration, isDate = false, false break } - var durationTokens bool - var dateTokens bool if len(token) >= 2 && token[0] == '%' { switch token[1] { - case 'h', 'H', 'i', 'I', 's', 'S', 'k', 'l': - durationTokens = true + case 'h', 'H', 'i', 'I', 's', 'S', 'k', 'l', 'f': + isDuration = true case 'y', 'Y', 'm', 'M', 'c', 'b', 'D', 'd', 'e': - dateTokens = true + isDate = true } } - if durationTokens { - isDuration = true - } else if dateTokens { - isDate = true - } if isDuration && isDate { break } diff --git a/types/time_test.go b/types/time_test.go index 57c6cca351d6f..7a1995f9eee8a 100644 --- a/types/time_test.go +++ b/types/time_test.go @@ -1640,3 +1640,32 @@ func (s *testTimeSuite) TestCheckMonthDay(c *C) { } } } + +func (s *testTimeSuite) TestFromGoTime(c *C) { + // Test rounding of nanosecond to millisecond. + cases := []struct { + input string + yy int + mm int + dd int + hh int + min int + sec int + micro int + }{ + {"2006-01-02T15:04:05.999999999Z", 2006, 1, 2, 15, 4, 6, 0}, + {"2006-01-02T15:04:05.999999000Z", 2006, 1, 2, 15, 4, 5, 999999}, + {"2006-01-02T15:04:05.999999499Z", 2006, 1, 2, 15, 4, 5, 999999}, + {"2006-01-02T15:04:05.999999500Z", 2006, 1, 2, 15, 4, 6, 0}, + {"2006-01-02T15:04:05.000000501Z", 2006, 1, 2, 15, 4, 5, 1}, + } + + for ith, ca := range cases { + t, err := time.Parse(time.RFC3339Nano, ca.input) + c.Assert(err, IsNil) + + t1 := types.FromGoTime(t) + c.Assert(t1, Equals, types.FromDate(ca.yy, ca.mm, ca.dd, ca.hh, ca.min, ca.sec, ca.micro), Commentf("idx %d", ith)) + } + +} diff --git a/util/admin/admin.go b/util/admin/admin.go index 1d29b768da353..62a3730d2eef9 100644 --- a/util/admin/admin.go +++ b/util/admin/admin.go @@ -17,6 +17,7 @@ import ( "context" "fmt" "io" + "math" "sort" "time" @@ -118,13 +119,18 @@ func CancelJobs(txn kv.Transaction, ids []int64) ([]error, error) { return nil, nil } - jobs, err := GetDDLJobs(txn) + errs := make([]error, len(ids)) + t := meta.NewMeta(txn) + generalJobs, err := getDDLJobsInQueue(t, meta.DefaultJobListKey) + if err != nil { + return nil, errors.Trace(err) + } + addIdxJobs, err := getDDLJobsInQueue(t, meta.AddIndexJobListKey) if err != nil { return nil, errors.Trace(err) } + jobs := append(generalJobs, addIdxJobs...) - errs := make([]error, len(ids)) - t := meta.NewMeta(txn) for i, id := range ids { found := false for j, job := range jobs { @@ -157,7 +163,8 @@ func CancelJobs(txn kv.Transaction, ids []int64) ([]error, error) { continue } if job.Type == model.ActionAddIndex { - err = t.UpdateDDLJob(int64(j), job, true, meta.AddIndexJobListKey) + offset := int64(j - len(generalJobs)) + err = t.UpdateDDLJob(offset, job, true, meta.AddIndexJobListKey) } else { err = t.UpdateDDLJob(int64(j), job, true) } @@ -227,7 +234,7 @@ const DefNumHistoryJobs = 10 // The maximum count of history jobs is num. func GetHistoryDDLJobs(txn kv.Transaction, maxNumJobs int) ([]*model.Job, error) { t := meta.NewMeta(txn) - jobs, err := t.GetAllHistoryDDLJobs() + jobs, err := t.GetLastNHistoryDDLJobs(maxNumJobs) if err != nil { return nil, errors.Trace(err) } @@ -267,28 +274,46 @@ func getCount(ctx sessionctx.Context, sql string) (int64, error) { return rows[0].GetInt64(0), nil } +// Count greater Types +const ( + // TblCntGreater means that the number of table rows is more than the number of index rows. + TblCntGreater byte = 1 + // IdxCntGreater means that the number of index rows is more than the number of table rows. + IdxCntGreater byte = 2 +) + // CheckIndicesCount compares indices count with table count. +// It returns the count greater type, the index offset and an error. // It returns nil if the count from the index is equal to the count from the table columns, -// otherwise it returns an error with a different information. -func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices []string) error { +// otherwise it returns an error and the corresponding index's offset. +func CheckIndicesCount(ctx sessionctx.Context, dbName, tableName string, indices []string) (byte, int, error) { // Add `` for some names like `table name`. - sql := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s`", dbName, tableName) + sql := fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX()", dbName, tableName) tblCnt, err := getCount(ctx, sql) if err != nil { - return errors.Trace(err) + return 0, 0, errors.Trace(err) } - for _, idx := range indices { + for i, idx := range indices { sql = fmt.Sprintf("SELECT COUNT(*) FROM `%s`.`%s` USE INDEX(`%s`)", dbName, tableName, idx) idxCnt, err := getCount(ctx, sql) if err != nil { - return errors.Trace(err) + return 0, i, errors.Trace(err) } - if tblCnt != idxCnt { - return errors.Errorf("table count %d != index(%s) count %d", tblCnt, idx, idxCnt) + logutil.Logger(context.Background()).Info("check indices count", + zap.String("table", tableName), zap.Int64("cnt", tblCnt), zap.Reflect("index", idx), zap.Int64("cnt", idxCnt)) + if tblCnt == idxCnt { + continue } - } - return nil + var ret byte + if tblCnt > idxCnt { + ret = TblCntGreater + } else if idxCnt > tblCnt { + ret = IdxCntGreater + } + return ret, i, errors.Errorf("table count %d != index(%s) count %d", tblCnt, idx, idxCnt) + } + return 0, 0, nil } // ScanIndexData scans the index handles and values in a limited number, according to the index information. @@ -444,7 +469,7 @@ func CheckRecordAndIndex(sessCtx sessionctx.Context, txn kv.Transaction, t table cols[i] = t.Cols()[col.Offset] } - startKey := t.RecordKey(0) + startKey := t.RecordKey(math.MinInt64) filterFunc := func(h1 int64, vals1 []types.Datum, cols []*table.Column) (bool, error) { for i, val := range vals1 { col := cols[i] @@ -477,7 +502,6 @@ func CheckRecordAndIndex(sessCtx sessionctx.Context, txn kv.Transaction, t table return true, nil } err := iterRecords(sessCtx, txn, t, startKey, cols, filterFunc, genExprs) - if err != nil { return errors.Trace(err) } @@ -590,25 +614,16 @@ func CompareTableRecord(sessCtx sessionctx.Context, txn kv.Transaction, t table. } func makeRowDecoder(t table.Table, decodeCol []*table.Column, genExpr map[model.TableColumnID]expression.Expression) *decoder.RowDecoder { - cols := t.Cols() - tblInfo := t.Meta() - decodeColsMap := make(map[int64]decoder.Column, len(decodeCol)) - for _, v := range decodeCol { - col := cols[v.Offset] - tpExpr := decoder.Column{ - Col: col, - } - if col.IsGenerated() && !col.GeneratedStored { - for _, c := range cols { - if _, ok := col.Dependences[c.Name.L]; ok { - decodeColsMap[c.ID] = decoder.Column{ - Col: c, - } - } - } - tpExpr.GenExpr = genExpr[model.TableColumnID{TableID: tblInfo.ID, ColumnID: col.ID}] - } - decodeColsMap[col.ID] = tpExpr + var containsVirtualCol bool + decodeColsMap, ignored := decoder.BuildFullDecodeColMap(decodeCol, t, func(genCol *table.Column) (expression.Expression, error) { + containsVirtualCol = true + return genExpr[model.TableColumnID{TableID: t.Meta().ID, ColumnID: genCol.ID}], nil + }) + _ = ignored + + if containsVirtualCol { + decoder.SubstituteGenColsInDecodeColMap(decodeColsMap) + decoder.RemoveUnusedVirtualCols(decodeColsMap, decodeCol) } return decoder.NewRowDecoder(t, decodeColsMap) } diff --git a/util/admin/admin_integration_test.go b/util/admin/admin_integration_test.go index e2ec3ba6e659c..c6a226e01f4da 100644 --- a/util/admin/admin_integration_test.go +++ b/util/admin/admin_integration_test.go @@ -43,7 +43,7 @@ func (s *testAdminSuite) SetUpSuite(c *C) { c.Assert(err, IsNil) s.store = store session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() d, err := session.BootstrapSession(s.store) c.Assert(err, IsNil) d.SetStatsUpdating(true) diff --git a/util/admin/admin_test.go b/util/admin/admin_test.go index 1c30ffa90155b..4dc0730ea86f9 100644 --- a/util/admin/admin_test.go +++ b/util/admin/admin_test.go @@ -254,6 +254,38 @@ func (s *testSuite) TestCancelJobs(c *C) { c.Assert(errs[0], NotNil) c.Assert(errs[0].Error(), Equals, "[admin:6]This job:101 is almost finished, can't be cancelled now") + // When both types of jobs exist in the DDL queue, + // we first cancel the job with a larger ID. + job = &model.Job{ + ID: 1000, + SchemaID: 1, + TableID: 2, + Type: model.ActionAddIndex, + } + job1 := &model.Job{ + ID: 1001, + SchemaID: 1, + TableID: 2, + Type: model.ActionAddColumn, + } + job2 := &model.Job{ + ID: 1002, + SchemaID: 1, + TableID: 2, + Type: model.ActionAddIndex, + } + err = t.EnQueueDDLJob(job, meta.AddIndexJobListKey) + c.Assert(err, IsNil) + err = t.EnQueueDDLJob(job1) + c.Assert(err, IsNil) + err = t.EnQueueDDLJob(job2, meta.AddIndexJobListKey) + c.Assert(err, IsNil) + errs, err = CancelJobs(txn, []int64{job1.ID, job.ID, job2.ID}) + c.Assert(err, IsNil) + for _, err := range errs { + c.Assert(err, IsNil) + } + err = txn.Rollback() c.Assert(err, IsNil) } @@ -458,7 +490,7 @@ func (s *testSuite) testIndex(c *C, ctx sessionctx.Context, dbName string, tb ta c.Assert(err, IsNil) idxNames := []string{idx.Meta().Name.L} - err = CheckIndicesCount(ctx, dbName, tb.Meta().Name.L, idxNames) + _, _, err = CheckIndicesCount(ctx, dbName, tb.Meta().Name.L, idxNames) c.Assert(err, IsNil) mockCtx := mock.NewContext() @@ -480,7 +512,7 @@ func (s *testSuite) testIndex(c *C, ctx sessionctx.Context, dbName string, tb ta diffMsg := newDiffRetError("index", record1, nil) c.Assert(err.Error(), DeepEquals, diffMsg) - err = CheckIndicesCount(ctx, dbName, tb.Meta().Name.L, idxNames) + _, _, err = CheckIndicesCount(ctx, dbName, tb.Meta().Name.L, idxNames) c.Assert(err, IsNil) // set data to: @@ -539,7 +571,7 @@ func (s *testSuite) testIndex(c *C, ctx sessionctx.Context, dbName string, tb ta diffMsg = newDiffRetError("index", record1, nil) c.Assert(err.Error(), DeepEquals, diffMsg) - err = CheckIndicesCount(ctx, dbName, tb.Meta().Name.L, idxNames) + _, _, err = CheckIndicesCount(ctx, dbName, tb.Meta().Name.L, idxNames) c.Assert(err.Error(), Equals, "table count 3 != index(c) count 4") // set data to: @@ -559,7 +591,7 @@ func (s *testSuite) testIndex(c *C, ctx sessionctx.Context, dbName string, tb ta diffMsg = newDiffRetError("index", nil, record1) c.Assert(err.Error(), DeepEquals, diffMsg) - err = CheckIndicesCount(ctx, dbName, tb.Meta().Name.L, idxNames) + _, _, err = CheckIndicesCount(ctx, dbName, tb.Meta().Name.L, idxNames) c.Assert(err.Error(), Equals, "table count 4 != index(c) count 3") } diff --git a/util/chunk/chunk.go b/util/chunk/chunk.go index 5a375814b7a18..86c6613cc2071 100644 --- a/util/chunk/chunk.go +++ b/util/chunk/chunk.go @@ -77,24 +77,30 @@ func New(fields []*types.FieldType, cap, maxChunkSize int) *Chunk { return chk } -// Renew creates a new Chunk based on an existing Chunk. The newly created Chunk -// has the same data schema with the old Chunk. The capacity of the new Chunk -// might be doubled based on the capacity of the old Chunk and the maxChunkSize. -// chk: old chunk(often used in previous call). -// maxChunkSize: the limit for the max number of rows. -func Renew(chk *Chunk, maxChunkSize int) *Chunk { +// renewWithCapacity creates a new Chunk based on an existing Chunk with capacity. The newly +// created Chunk has the same data schema with the old Chunk. +func renewWithCapacity(chk *Chunk, cap, maxChunkSize int) *Chunk { newChk := new(Chunk) if chk.columns == nil { return newChk } - newCap := reCalcCapacity(chk, maxChunkSize) - newChk.columns = renewColumns(chk.columns, newCap) + newChk.columns = renewColumns(chk.columns, cap) newChk.numVirtualRows = 0 - newChk.capacity = newCap + newChk.capacity = cap newChk.requiredRows = maxChunkSize return newChk } +// Renew creates a new Chunk based on an existing Chunk. The newly created Chunk +// has the same data schema with the old Chunk. The capacity of the new Chunk +// might be doubled based on the capacity of the old Chunk and the maxChunkSize. +// chk: old chunk(often used in previous call). +// maxChunkSize: the limit for the max number of rows. +func Renew(chk *Chunk, maxChunkSize int) *Chunk { + newCap := reCalcCapacity(chk, maxChunkSize) + return renewWithCapacity(chk, newCap, maxChunkSize) +} + // renewColumns creates the columns of a Chunk. The capacity of the newly // created columns is equal to cap. func renewColumns(oldCol []*column, cap int) []*column { diff --git a/util/chunk/row.go b/util/chunk/row.go index 0d282558fc0e1..ee0bbdd2b4dae 100644 --- a/util/chunk/row.go +++ b/util/chunk/row.go @@ -222,3 +222,10 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum { func (r Row) IsNull(colIdx int) bool { return r.c.columns[colIdx].isNull(r.idx) } + +// CopyConstruct creates a new row and copies this row's data into it. +func (r Row) CopyConstruct() Row { + newChk := renewWithCapacity(r.c, 1, 1) + newChk.AppendRow(r) + return newChk.GetRow(0) +} diff --git a/util/codec/codec.go b/util/codec/codec.go index 19cbfff6c7977..f2a9b323b3f5b 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -312,9 +312,9 @@ func Decode(b []byte, size int) ([]types.Datum, error) { // DecodeRange decodes the range values from a byte slice that generated by EncodeKey. // It handles some special values like `MinNotNull` and `MaxValueDatum`. -func DecodeRange(b []byte, size int) ([]types.Datum, error) { +func DecodeRange(b []byte, size int) ([]types.Datum, []byte, error) { if len(b) < 1 { - return nil, errors.New("invalid encoded key: length of key is zero") + return nil, b, errors.New("invalid encoded key: length of key is zero") } var ( @@ -326,7 +326,7 @@ func DecodeRange(b []byte, size int) ([]types.Datum, error) { var d types.Datum b, d, err = DecodeOne(b) if err != nil { - return nil, errors.Trace(err) + return values, b, errors.Trace(err) } values = append(values, d) } @@ -341,10 +341,10 @@ func DecodeRange(b []byte, size int) ([]types.Datum, error) { case maxFlag, maxFlag + 1: values = append(values, types.MaxValueDatum()) default: - return nil, errors.Errorf("invalid encoded key flag %v", b[0]) + return values, b, errors.Errorf("invalid encoded key flag %v", b[0]) } } - return values, nil + return values, nil, nil } // DecodeOne decodes on datum from a byte slice generated with EncodeKey or EncodeValue. diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index b54a6ae8adeb4..90c84869ad18e 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -985,14 +985,14 @@ func chunkForTest(c *C, sc *stmtctx.StatementContext, datums []types.Datum, tps } func (s *testCodecSuite) TestDecodeRange(c *C) { - _, err := DecodeRange(nil, 0) + _, _, err := DecodeRange(nil, 0) c.Assert(err, NotNil) datums := types.MakeDatums(1, "abc", 1.1, []byte("def")) rowData, err := EncodeValue(nil, nil, datums...) c.Assert(err, IsNil) - datums1, err := DecodeRange(rowData, len(datums)) + datums1, _, err := DecodeRange(rowData, len(datums)) c.Assert(err, IsNil) for i, datum := range datums1 { cmp, err := datum.CompareDatum(nil, &datums[i]) @@ -1002,7 +1002,7 @@ func (s *testCodecSuite) TestDecodeRange(c *C) { for _, b := range []byte{NilFlag, bytesFlag, maxFlag, maxFlag + 1} { newData := append(rowData, b) - _, err := DecodeRange(newData, len(datums)+1) + _, _, err := DecodeRange(newData, len(datums)+1) c.Assert(err, IsNil) } } diff --git a/util/deadlock/deadlock.go b/util/deadlock/deadlock.go new file mode 100644 index 0000000000000..5f0d781427f4f --- /dev/null +++ b/util/deadlock/deadlock.go @@ -0,0 +1,130 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package deadlock + +import ( + "fmt" + "sync" +) + +// Detector detects deadlock. +type Detector struct { + waitForMap map[uint64]*txnList + lock sync.Mutex +} + +type txnList struct { + txns []txnKeyHashPair +} + +type txnKeyHashPair struct { + txn uint64 + keyHash uint64 +} + +// NewDetector creates a new Detector. +func NewDetector() *Detector { + return &Detector{ + waitForMap: map[uint64]*txnList{}, + } +} + +// ErrDeadlock is returned when deadlock is detected. +type ErrDeadlock struct { + KeyHash uint64 +} + +func (e *ErrDeadlock) Error() string { + return fmt.Sprintf("deadlock(%d)", e.KeyHash) +} + +// Detect detects deadlock for the sourceTxn on a locked key. +func (d *Detector) Detect(sourceTxn, waitForTxn, keyHash uint64) *ErrDeadlock { + d.lock.Lock() + err := d.doDetect(sourceTxn, waitForTxn) + if err == nil { + d.register(sourceTxn, waitForTxn, keyHash) + } + d.lock.Unlock() + return err +} + +func (d *Detector) doDetect(sourceTxn, waitForTxn uint64) *ErrDeadlock { + list := d.waitForMap[waitForTxn] + if list == nil { + return nil + } + for _, nextTarget := range list.txns { + if nextTarget.txn == sourceTxn { + return &ErrDeadlock{KeyHash: nextTarget.keyHash} + } + if err := d.doDetect(sourceTxn, nextTarget.txn); err != nil { + return err + } + } + return nil +} + +func (d *Detector) register(sourceTxn, waitForTxn, keyHash uint64) { + list := d.waitForMap[sourceTxn] + pair := txnKeyHashPair{txn: waitForTxn, keyHash: keyHash} + if list == nil { + d.waitForMap[sourceTxn] = &txnList{txns: []txnKeyHashPair{pair}} + return + } + for _, tar := range list.txns { + if tar.txn == waitForTxn && tar.keyHash == keyHash { + return + } + } + list.txns = append(list.txns, pair) +} + +// CleanUp removes the wait for entry for the transaction. +func (d *Detector) CleanUp(txn uint64) { + d.lock.Lock() + delete(d.waitForMap, txn) + d.lock.Unlock() +} + +// CleanUpWaitFor removes a key in the wait for entry for the transaction. +func (d *Detector) CleanUpWaitFor(txn, waitForTxn, keyHash uint64) { + pair := txnKeyHashPair{txn: waitForTxn, keyHash: keyHash} + d.lock.Lock() + l := d.waitForMap[txn] + if l != nil { + for i, tar := range l.txns { + if tar == pair { + l.txns = append(l.txns[:i], l.txns[i+1:]...) + break + } + } + if len(l.txns) == 0 { + delete(d.waitForMap, txn) + } + } + d.lock.Unlock() + +} + +// Expire removes entries with TS smaller than minTS. +func (d *Detector) Expire(minTS uint64) { + d.lock.Lock() + for ts := range d.waitForMap { + if ts < minTS { + delete(d.waitForMap, ts) + } + } + d.lock.Unlock() +} diff --git a/util/deadlock/deadlock_test.go b/util/deadlock/deadlock_test.go new file mode 100644 index 0000000000000..0481c6014053c --- /dev/null +++ b/util/deadlock/deadlock_test.go @@ -0,0 +1,69 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package deadlock + +import ( + "fmt" + "testing" + + . "github.com/pingcap/check" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testDeadlockSuite{}) + +type testDeadlockSuite struct{} + +func (s *testDeadlockSuite) TestDeadlock(c *C) { + detector := NewDetector() + err := detector.Detect(1, 2, 100) + c.Assert(err, IsNil) + err = detector.Detect(2, 3, 200) + c.Assert(err, IsNil) + err = detector.Detect(3, 1, 300) + c.Assert(err, NotNil) + c.Assert(err.Error(), Equals, fmt.Sprintf("deadlock(200)")) + detector.CleanUp(2) + list2 := detector.waitForMap[2] + c.Assert(list2, IsNil) + + // After cycle is broken, no deadlock now. + err = detector.Detect(3, 1, 300) + c.Assert(err, IsNil) + list3 := detector.waitForMap[3] + c.Assert(list3.txns, HasLen, 1) + + // Different keyHash grows the list. + err = detector.Detect(3, 1, 400) + c.Assert(err, IsNil) + c.Assert(list3.txns, HasLen, 2) + + // Same waitFor and key hash doesn't grow the list. + err = detector.Detect(3, 1, 400) + c.Assert(err, IsNil) + c.Assert(list3.txns, HasLen, 2) + + detector.CleanUpWaitFor(3, 1, 300) + c.Assert(list3.txns, HasLen, 1) + detector.CleanUpWaitFor(3, 1, 400) + list3 = detector.waitForMap[3] + c.Assert(list3, IsNil) + detector.Expire(1) + c.Assert(detector.waitForMap, HasLen, 1) + detector.Expire(2) + c.Assert(detector.waitForMap, HasLen, 0) +} diff --git a/util/execdetails/execdetails.go b/util/execdetails/execdetails.go index 07a6e8a9090f5..d69de4a600b33 100644 --- a/util/execdetails/execdetails.go +++ b/util/execdetails/execdetails.go @@ -23,6 +23,7 @@ import ( "time" "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" ) // CommitDetailCtxKey presents CommitDetail info key in context. @@ -46,7 +47,11 @@ type CommitDetails struct { PrewriteTime time.Duration CommitTime time.Duration LocalLatchTime time.Duration - TotalBackoffTime time.Duration + CommitBackoffTime int64 + Mu struct { + sync.Mutex + BackoffTypes []fmt.Stringer + } ResolveLockTime int64 WriteKeys int WriteSize int @@ -67,6 +72,28 @@ const ( TotalKeysStr = "Total_keys" // ProcessKeysStr means the total processed keys. ProcessKeysStr = "Process_keys" + // PreWriteTimeStr means the time of pre-write. + PreWriteTimeStr = "Prewrite_time" + // CommitTimeStr means the time of commit. + CommitTimeStr = "Commit_time" + // GetCommitTSTimeStr means the time of getting commit ts. + GetCommitTSTimeStr = "Get_commit_ts_time" + // CommitBackoffTimeStr means the time of commit backoff. + CommitBackoffTimeStr = "Commit_backoff_time" + // BackoffTypesStr means the backoff type. + BackoffTypesStr = "Backoff_types" + // ResolveLockTimeStr means the time of resolving lock. + ResolveLockTimeStr = "Resolve_lock_time" + // LocalLatchWaitTimeStr means the time of waiting in local latch. + LocalLatchWaitTimeStr = "Local_latch_wait_time" + // WriteKeysStr means the count of keys in the transaction. + WriteKeysStr = "Write_keys" + // WriteSizeStr means the key/value size in the transaction. + WriteSizeStr = "Write_size" + // PrewriteRegionStr means the count of region when pre-write. + PrewriteRegionStr = "Prewrite_region" + // TxnRetryStr means the count of transaction retry. + TxnRetryStr = "Txn_retry" ) // String implements the fmt.Stringer interface. @@ -93,41 +120,112 @@ func (d ExecDetails) String() string { commitDetails := d.CommitDetail if commitDetails != nil { if commitDetails.PrewriteTime > 0 { - parts = append(parts, fmt.Sprintf("Prewrite_time: %v", commitDetails.PrewriteTime.Seconds())) + parts = append(parts, PreWriteTimeStr+": "+strconv.FormatFloat(commitDetails.PrewriteTime.Seconds(), 'f', -1, 64)) } if commitDetails.CommitTime > 0 { - parts = append(parts, fmt.Sprintf("Commit_time: %v", commitDetails.CommitTime.Seconds())) + parts = append(parts, CommitTimeStr+": "+strconv.FormatFloat(commitDetails.CommitTime.Seconds(), 'f', -1, 64)) } if commitDetails.GetCommitTsTime > 0 { - parts = append(parts, fmt.Sprintf("Get_commit_ts_time: %v", commitDetails.GetCommitTsTime.Seconds())) + parts = append(parts, GetCommitTSTimeStr+": "+strconv.FormatFloat(commitDetails.GetCommitTsTime.Seconds(), 'f', -1, 64)) + } + commitBackoffTime := atomic.LoadInt64(&commitDetails.CommitBackoffTime) + if commitBackoffTime > 0 { + parts = append(parts, CommitBackoffTimeStr+": "+strconv.FormatFloat(time.Duration(commitBackoffTime).Seconds(), 'f', -1, 64)) } - if commitDetails.TotalBackoffTime > 0 { - parts = append(parts, fmt.Sprintf("Total_backoff_time: %v", commitDetails.TotalBackoffTime.Seconds())) + commitDetails.Mu.Lock() + if len(commitDetails.Mu.BackoffTypes) > 0 { + parts = append(parts, BackoffTypesStr+": "+fmt.Sprintf("%v", commitDetails.Mu.BackoffTypes)) } + commitDetails.Mu.Unlock() resolveLockTime := atomic.LoadInt64(&commitDetails.ResolveLockTime) if resolveLockTime > 0 { - parts = append(parts, fmt.Sprintf("Resolve_lock_time: %v", time.Duration(resolveLockTime).Seconds())) + parts = append(parts, ResolveLockTimeStr+": "+strconv.FormatFloat(time.Duration(resolveLockTime).Seconds(), 'f', -1, 64)) } if commitDetails.LocalLatchTime > 0 { - parts = append(parts, fmt.Sprintf("Local_latch_wait_time: %v", commitDetails.LocalLatchTime.Seconds())) + parts = append(parts, LocalLatchWaitTimeStr+": "+strconv.FormatFloat(commitDetails.LocalLatchTime.Seconds(), 'f', -1, 64)) } if commitDetails.WriteKeys > 0 { - parts = append(parts, fmt.Sprintf("Write_keys: %d", commitDetails.WriteKeys)) + parts = append(parts, WriteKeysStr+": "+strconv.FormatInt(int64(commitDetails.WriteKeys), 10)) } if commitDetails.WriteSize > 0 { - parts = append(parts, fmt.Sprintf("Write_size: %d", commitDetails.WriteSize)) + parts = append(parts, WriteSizeStr+": "+strconv.FormatInt(int64(commitDetails.WriteSize), 10)) } prewriteRegionNum := atomic.LoadInt32(&commitDetails.PrewriteRegionNum) if prewriteRegionNum > 0 { - parts = append(parts, fmt.Sprintf("Prewrite_region: %d", prewriteRegionNum)) + parts = append(parts, PrewriteRegionStr+": "+strconv.FormatInt(int64(prewriteRegionNum), 10)) } if commitDetails.TxnRetry > 0 { - parts = append(parts, fmt.Sprintf("Txn_retry: %d", commitDetails.TxnRetry)) + parts = append(parts, TxnRetryStr+": "+strconv.FormatInt(int64(commitDetails.TxnRetry), 10)) } } return strings.Join(parts, " ") } +// ToZapFields wraps the ExecDetails as zap.Fields. +func (d ExecDetails) ToZapFields() (fields []zap.Field) { + fields = make([]zap.Field, 0, 16) + if d.ProcessTime > 0 { + fields = append(fields, zap.String(strings.ToLower(ProcessTimeStr), strconv.FormatFloat(d.ProcessTime.Seconds(), 'f', -1, 64)+"s")) + } + if d.WaitTime > 0 { + fields = append(fields, zap.String(strings.ToLower(WaitTimeStr), strconv.FormatFloat(d.ProcessTime.Seconds(), 'f', -1, 64)+"s")) + } + if d.BackoffTime > 0 { + fields = append(fields, zap.String(strings.ToLower(BackoffTimeStr), strconv.FormatFloat(d.BackoffTime.Seconds(), 'f', -1, 64)+"s")) + } + if d.RequestCount > 0 { + fields = append(fields, zap.String(strings.ToLower(RequestCountStr), strconv.FormatInt(int64(d.RequestCount), 10))) + } + if d.TotalKeys > 0 { + fields = append(fields, zap.String(strings.ToLower(TotalKeysStr), strconv.FormatInt(d.TotalKeys, 10))) + } + if d.ProcessedKeys > 0 { + fields = append(fields, zap.String(strings.ToLower(ProcessKeysStr), strconv.FormatInt(d.ProcessedKeys, 10))) + } + commitDetails := d.CommitDetail + if commitDetails != nil { + if commitDetails.PrewriteTime > 0 { + fields = append(fields, zap.String("prewrite_time", fmt.Sprintf("%v", strconv.FormatFloat(commitDetails.PrewriteTime.Seconds(), 'f', -1, 64)+"s"))) + } + if commitDetails.CommitTime > 0 { + fields = append(fields, zap.String("commit_time", fmt.Sprintf("%v", strconv.FormatFloat(commitDetails.CommitTime.Seconds(), 'f', -1, 64)+"s"))) + } + if commitDetails.GetCommitTsTime > 0 { + fields = append(fields, zap.String("get_commit_ts_time", fmt.Sprintf("%v", strconv.FormatFloat(commitDetails.GetCommitTsTime.Seconds(), 'f', -1, 64)+"s"))) + } + commitBackoffTime := atomic.LoadInt64(&commitDetails.CommitBackoffTime) + if commitBackoffTime > 0 { + fields = append(fields, zap.String("commit_backoff_time", fmt.Sprintf("%v", strconv.FormatFloat(time.Duration(commitBackoffTime).Seconds(), 'f', -1, 64)+"s"))) + } + commitDetails.Mu.Lock() + if len(commitDetails.Mu.BackoffTypes) > 0 { + fields = append(fields, zap.String("backoff_types", fmt.Sprintf("%v", commitDetails.Mu.BackoffTypes))) + } + commitDetails.Mu.Unlock() + resolveLockTime := atomic.LoadInt64(&commitDetails.ResolveLockTime) + if resolveLockTime > 0 { + fields = append(fields, zap.String("resolve_lock_time", fmt.Sprintf("%v", strconv.FormatFloat(time.Duration(resolveLockTime).Seconds(), 'f', -1, 64)+"s"))) + } + if commitDetails.LocalLatchTime > 0 { + fields = append(fields, zap.String("local_latch_wait_time", fmt.Sprintf("%v", strconv.FormatFloat(commitDetails.LocalLatchTime.Seconds(), 'f', -1, 64)+"s"))) + } + if commitDetails.WriteKeys > 0 { + fields = append(fields, zap.Int("write_keys", commitDetails.WriteKeys)) + } + if commitDetails.WriteSize > 0 { + fields = append(fields, zap.Int("write_size", commitDetails.WriteSize)) + } + prewriteRegionNum := atomic.LoadInt32(&commitDetails.PrewriteRegionNum) + if prewriteRegionNum > 0 { + fields = append(fields, zap.Int32("prewrite_region", prewriteRegionNum)) + } + if commitDetails.TxnRetry > 0 { + fields = append(fields, zap.Int("txn_retry", commitDetails.TxnRetry)) + } + } + return fields +} + // CopRuntimeStats collects cop tasks' execution info. type CopRuntimeStats struct { sync.Mutex diff --git a/util/execdetails/execdetails_test.go b/util/execdetails/execdetails_test.go index de39f456ce09b..7aee493aee457 100644 --- a/util/execdetails/execdetails_test.go +++ b/util/execdetails/execdetails_test.go @@ -14,9 +14,12 @@ package execdetails import ( + "fmt" + "sync" "testing" "time" + "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tipb/go-tipb" ) @@ -33,7 +36,18 @@ func TestString(t *testing.T) { PrewriteTime: time.Second, CommitTime: time.Second, LocalLatchTime: time.Second, - TotalBackoffTime: time.Second, + CommitBackoffTime: int64(time.Second), + Mu: struct { + sync.Mutex + BackoffTypes []fmt.Stringer + }{BackoffTypes: []fmt.Stringer{ + stringutil.MemoizeStr(func() string { + return "backoff1" + }), + stringutil.MemoizeStr(func() string { + return "backoff2" + }), + }}, ResolveLockTime: 1000000000, // 10^9 ns = 1s WriteKeys: 1, WriteSize: 1, @@ -42,7 +56,7 @@ func TestString(t *testing.T) { }, } expected := "Process_time: 2.005 Wait_time: 1 Backoff_time: 1 Request_count: 1 Total_keys: 100 Process_keys: 10 Prewrite_time: 1 Commit_time: 1 " + - "Get_commit_ts_time: 1 Total_backoff_time: 1 Resolve_lock_time: 1 Local_latch_wait_time: 1 Write_keys: 1 Write_size: 1 Prewrite_region: 1 Txn_retry: 1" + "Get_commit_ts_time: 1 Commit_backoff_time: 1 Backoff_types: [backoff1 backoff2] Resolve_lock_time: 1 Local_latch_wait_time: 1 Write_keys: 1 Write_size: 1 Prewrite_region: 1 Txn_retry: 1" if str := detail.String(); str != expected { t.Errorf("got:\n%s\nexpected:\n%s", str, expected) } diff --git a/util/expensivequery/expensivequery.go b/util/expensivequery/expensivequery.go new file mode 100644 index 0000000000000..43dcb9b76987d --- /dev/null +++ b/util/expensivequery/expensivequery.go @@ -0,0 +1,161 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expensivequery + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/logutil" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// Handle is the handler for expensive query. +type Handle struct { + mu sync.RWMutex + exitCh chan struct{} + sm util.SessionManager +} + +// NewExpensiveQueryHandle builds a new expensive query handler. +func NewExpensiveQueryHandle(exitCh chan struct{}) *Handle { + return &Handle{exitCh: exitCh} +} + +// SetSessionManager sets the SessionManager which is used to fetching the info +// of all active sessions. +func (eqh *Handle) SetSessionManager(sm util.SessionManager) *Handle { + eqh.sm = sm + return eqh +} + +// Run starts a expensive query checker goroutine at the start time of the server. +func (eqh *Handle) Run() { + threshold := atomic.LoadUint64(&variable.ExpensiveQueryTimeThreshold) + // use 100ms as tickInterval temply, may use given interval or use defined variable later + tickInterval := time.Millisecond * time.Duration(100) + ticker := time.NewTicker(tickInterval) + for { + select { + case <-ticker.C: + processInfo := eqh.sm.ShowProcessList() + for _, info := range processInfo { + if info.Info == nil || info.ExceedExpensiveTimeThresh { + continue + } + costTime := time.Since(info.Time) + if costTime >= time.Second*time.Duration(threshold) && log.GetLevel() <= zapcore.WarnLevel { + logExpensiveQuery(costTime, info) + info.ExceedExpensiveTimeThresh = true + + } else if info.MaxExecutionTime > 0 && costTime > time.Duration(info.MaxExecutionTime)*time.Millisecond { + eqh.sm.Kill(info.ID, true) + } + } + threshold = atomic.LoadUint64(&variable.ExpensiveQueryTimeThreshold) + case <-eqh.exitCh: + return + } + } +} + +// Close closes the handle and release the background goroutine. +func (eqh *Handle) Close() { + close(eqh.exitCh) +} + +// LogOnQueryExceedMemQuota prints a log when memory usage of connID is out of memory quota. +func (eqh *Handle) LogOnQueryExceedMemQuota(connID uint64) { + if log.GetLevel() > zapcore.WarnLevel { + return + } + info, ok := eqh.sm.GetProcessInfo(connID) + if !ok { + return + } + logExpensiveQuery(time.Since(info.Time), info) +} + +// logExpensiveQuery logs the queries which exceed the time threshold or memory threshold. +func logExpensiveQuery(costTime time.Duration, info *util.ProcessInfo) { + logFields := make([]zap.Field, 0, 20) + logFields = append(logFields, zap.String("cost_time", strconv.FormatFloat(costTime.Seconds(), 'f', -1, 64)+"s")) + execDetail := info.StmtCtx.GetExecDetails() + logFields = append(logFields, execDetail.ToZapFields()...) + if copTaskInfo := info.StmtCtx.CopTasksDetails(); copTaskInfo != nil { + logFields = append(logFields, copTaskInfo.ToZapFields()...) + } + if statsInfo := info.StatsInfo(info.Plan); len(statsInfo) > 0 { + var buf strings.Builder + firstComma := false + vStr := "" + for k, v := range statsInfo { + if v == 0 { + vStr = "pseudo" + } else { + vStr = strconv.FormatUint(v, 10) + } + if firstComma { + buf.WriteString("," + k + ":" + vStr) + } else { + buf.WriteString(k + ":" + vStr) + firstComma = true + } + } + logFields = append(logFields, zap.String("stats", buf.String())) + } + if info.ID != 0 { + logFields = append(logFields, zap.Uint64("conn_id", info.ID)) + } + if len(info.User) > 0 { + logFields = append(logFields, zap.String("user", info.User)) + } + if info.DB != nil && len(info.DB.(string)) > 0 { + logFields = append(logFields, zap.String("database", info.DB.(string))) + } + var tableIDs, indexNames string + if len(info.StmtCtx.TableIDs) > 0 { + tableIDs = strings.Replace(fmt.Sprintf("%v", info.StmtCtx.TableIDs), " ", ",", -1) + logFields = append(logFields, zap.String("table_ids", tableIDs)) + } + if len(info.StmtCtx.IndexNames) > 0 { + indexNames = strings.Replace(fmt.Sprintf("%v", info.StmtCtx.IndexNames), " ", ",", -1) + logFields = append(logFields, zap.String("index_ids", indexNames)) + } + logFields = append(logFields, zap.Uint64("txn_start_ts", info.CurTxnStartTS)) + if memTracker := info.StmtCtx.MemTracker; memTracker != nil { + logFields = append(logFields, zap.String("mem_max", memTracker.BytesToString(memTracker.MaxConsumed()))) + } + + const logSQLLen = 1024 * 8 + var sql string + if info.Info != nil { + sql = info.Info.(string) + } + if len(sql) > logSQLLen { + sql = fmt.Sprintf("%s len(%d)", sql[:logSQLLen], len(sql)) + } + logFields = append(logFields, zap.String("sql", sql)) + + logutil.Logger(context.Background()).Warn("expensive_query", logFields...) +} diff --git a/util/kvcache/simple_lru.go b/util/kvcache/simple_lru.go index 7120e3a5abb7c..b3f18f2871d42 100644 --- a/util/kvcache/simple_lru.go +++ b/util/kvcache/simple_lru.go @@ -38,6 +38,7 @@ type cacheEntry struct { type SimpleLRUCache struct { capacity uint size uint + // 0 indicates no quota quota uint64 guard float64 elements map[string]*list.Element @@ -88,6 +89,17 @@ func (l *SimpleLRUCache) Put(key Key, value Value) { l.elements[hash] = element l.size++ + // Getting used memory is expensive and can be avoided by setting quota to 0. + if l.quota <= 0 { + if l.size > l.capacity { + lru := l.cache.Back() + l.cache.Remove(lru) + delete(l.elements, string(lru.Value.(*cacheEntry).key.Hash())) + l.size-- + } + return + } + memUsed, err := memory.MemUsed() if err != nil { l.DeleteAll() @@ -137,3 +149,13 @@ func (l *SimpleLRUCache) DeleteAll() { func (l *SimpleLRUCache) Size() int { return int(l.size) } + +// Values return all values in cache. +func (l *SimpleLRUCache) Values() []Value { + values := make([]Value, 0, l.cache.Len()) + for ele := l.cache.Front(); ele != nil; ele = ele.Next() { + value := ele.Value.(*cacheEntry).value + values = append(values, value) + } + return values +} diff --git a/util/kvcache/simple_lru_test.go b/util/kvcache/simple_lru_test.go index 0cbf21706df92..14e3d629bec61 100644 --- a/util/kvcache/simple_lru_test.go +++ b/util/kvcache/simple_lru_test.go @@ -56,7 +56,7 @@ func (s *testLRUCacheSuite) TestPut(c *C) { maxMem, err := memory.MemTotal() c.Assert(err, IsNil) - lru := NewSimpleLRUCache(3, 0.1, maxMem) + lru := NewSimpleLRUCache(3, 0, maxMem) c.Assert(lru.capacity, Equals, uint(3)) keys := make([]*mockCacheKey, 5) @@ -106,6 +106,22 @@ func (s *testLRUCacheSuite) TestPut(c *C) { c.Assert(root, IsNil) } +func (s *testLRUCacheSuite) TestZeroQuota(c *C) { + lru := NewSimpleLRUCache(100, 0, 0) + c.Assert(lru.capacity, Equals, uint(100)) + + keys := make([]*mockCacheKey, 100) + vals := make([]int64, 100) + + for i := 0; i < 100; i++ { + keys[i] = newMockHashKey(int64(i)) + vals[i] = int64(i) + lru.Put(keys[i], vals[i]) + } + c.Assert(lru.size, Equals, lru.capacity) + c.Assert(lru.size, Equals, uint(100)) +} + func (s *testLRUCacheSuite) TestOOMGuard(c *C) { maxMem, err := memory.MemTotal() c.Assert(err, IsNil) @@ -135,7 +151,7 @@ func (s *testLRUCacheSuite) TestGet(c *C) { maxMem, err := memory.MemTotal() c.Assert(err, IsNil) - lru := NewSimpleLRUCache(3, 0.1, maxMem) + lru := NewSimpleLRUCache(3, 0, maxMem) keys := make([]*mockCacheKey, 5) vals := make([]int64, 5) @@ -178,7 +194,7 @@ func (s *testLRUCacheSuite) TestDelete(c *C) { maxMem, err := memory.MemTotal() c.Assert(err, IsNil) - lru := NewSimpleLRUCache(3, 0.1, maxMem) + lru := NewSimpleLRUCache(3, 0, maxMem) keys := make([]*mockCacheKey, 3) vals := make([]int64, 3) @@ -207,7 +223,7 @@ func (s *testLRUCacheSuite) TestDeleteAll(c *C) { maxMem, err := memory.MemTotal() c.Assert(err, IsNil) - lru := NewSimpleLRUCache(3, 0.1, maxMem) + lru := NewSimpleLRUCache(3, 0, maxMem) keys := make([]*mockCacheKey, 3) vals := make([]int64, 3) @@ -228,3 +244,25 @@ func (s *testLRUCacheSuite) TestDeleteAll(c *C) { c.Assert(int(lru.size), Equals, 0) } } + +func (s *testLRUCacheSuite) TestValues(c *C) { + maxMem, err := memory.MemTotal() + c.Assert(err, IsNil) + + lru := NewSimpleLRUCache(5, 0, maxMem) + + keys := make([]*mockCacheKey, 5) + vals := make([]int64, 5) + + for i := 0; i < 5; i++ { + keys[i] = newMockHashKey(int64(i)) + vals[i] = int64(i) + lru.Put(keys[i], vals[i]) + } + + values := lru.Values() + c.Assert(len(values), Equals, 5) + for i := 0; i < 5; i++ { + c.Assert(values[i], Equals, int64(4-i)) + } +} diff --git a/util/kvencoder/allocator.go b/util/kvencoder/allocator.go index 36a80382819f2..2ed59d3dac282 100644 --- a/util/kvencoder/allocator.go +++ b/util/kvencoder/allocator.go @@ -36,8 +36,9 @@ type Allocator struct { } // Alloc allocs a next autoID for table with tableID. -func (alloc *Allocator) Alloc(tableID int64) (int64, error) { - return atomic.AddInt64(&alloc.base, 1), nil +func (alloc *Allocator) Alloc(tableID int64, n uint64) (int64, int64, error) { + min := alloc.base + return min, atomic.AddInt64(&alloc.base, int64(n)), nil } // Reset allow newBase smaller than alloc.base, and will set the alloc.base to newBase. diff --git a/util/kvencoder/kv_encoder_test.go b/util/kvencoder/kv_encoder_test.go index aae6797c90952..1aa044dc28eea 100644 --- a/util/kvencoder/kv_encoder_test.go +++ b/util/kvencoder/kv_encoder_test.go @@ -49,7 +49,7 @@ func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) { return nil, nil, errors.Trace(err) } session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() dom, err := session.BootstrapSession(store) return store, dom, errors.Trace(err) } diff --git a/util/logutil/log.go b/util/logutil/log.go index 40061f20be11a..7b54fbc093a4b 100644 --- a/util/logutil/log.go +++ b/util/logutil/log.go @@ -42,6 +42,8 @@ const ( DefaultSlowThreshold = 300 // DefaultQueryLogMaxLen is the default max length of the query in the log. DefaultQueryLogMaxLen = 2048 + // DefaultRecordPlanInSlowLog is the default value for whether enable log query plan in the slow log. + DefaultRecordPlanInSlowLog = 1 ) // EmptyFileLogConfig is an empty FileLogConfig. diff --git a/util/memory/action.go b/util/memory/action.go index bcafa0e274c0d..e8f79ef7a5b12 100644 --- a/util/memory/action.go +++ b/util/memory/action.go @@ -15,6 +15,7 @@ package memory import ( "context" + "fmt" "sync" "github.com/pingcap/parser/mysql" @@ -29,12 +30,22 @@ type ActionOnExceed interface { // Action will be called when memory usage exceeds memory quota by the // corresponding Tracker. Action(t *Tracker) + // SetLogHook binds a log hook which will be triggered and log an detailed + // message for the out-of-memory sql. + SetLogHook(hook func(uint64)) } // LogOnExceed logs a warning only once when memory usage exceeds memory quota. type LogOnExceed struct { - mutex sync.Mutex // For synchronization. - acted bool + mutex sync.Mutex // For synchronization. + acted bool + ConnID uint64 + logHook func(uint64) +} + +// SetLogHook sets a hook for LogOnExceed. +func (a *LogOnExceed) SetLogHook(hook func(uint64)) { + a.logHook = hook } // Action logs a warning only once when memory usage exceeds memory quota. @@ -43,15 +54,26 @@ func (a *LogOnExceed) Action(t *Tracker) { defer a.mutex.Unlock() if !a.acted { a.acted = true - logutil.Logger(context.Background()).Warn("memory exceeds quota", - zap.Error(errMemExceedThreshold.GenWithStackByArgs(t.label, t.BytesConsumed(), t.bytesLimit, t.String()))) + if a.logHook == nil { + logutil.Logger(context.Background()).Warn("memory exceeds quota", + zap.Error(errMemExceedThreshold.GenWithStackByArgs(t.label, t.BytesConsumed(), t.bytesLimit, t.String()))) + return + } + a.logHook(a.ConnID) } } // PanicOnExceed panics when memory usage exceeds memory quota. type PanicOnExceed struct { - mutex sync.Mutex // For synchronization. - acted bool + mutex sync.Mutex // For synchronization. + acted bool + ConnID uint64 + logHook func(uint64) +} + +// SetLogHook sets a hook for PanicOnExceed. +func (a *PanicOnExceed) SetLogHook(hook func(uint64)) { + a.logHook = hook } // Action panics when memory usage exceeds memory quota. @@ -63,7 +85,10 @@ func (a *PanicOnExceed) Action(t *Tracker) { } a.acted = true a.mutex.Unlock() - panic(PanicMemoryExceed + t.String()) + if a.logHook != nil { + a.logHook(a.ConnID) + } + panic(PanicMemoryExceed + fmt.Sprintf("[conn_id=%d]", a.ConnID)) } var ( diff --git a/util/memory/tracker.go b/util/memory/tracker.go index 10be4a9257505..3b935360f0fce 100644 --- a/util/memory/tracker.go +++ b/util/memory/tracker.go @@ -87,11 +87,6 @@ func (t *Tracker) AttachTo(parent *Tracker) { t.parent.Consume(t.BytesConsumed()) } -// Detach detaches this Tracker from its parent. -func (t *Tracker) Detach() { - t.parent.remove(t) -} - func (t *Tracker) remove(oldChild *Tracker) { t.mu.Lock() defer t.mu.Unlock() @@ -144,17 +139,13 @@ func (t *Tracker) Consume(bytes int64) { rootExceed = tracker } - if tracker.parent == nil { - // since we only need a total memory usage during execution, - // we only record max consumed bytes in root(statement-level) for performance. - for { - maxNow := atomic.LoadInt64(&tracker.maxConsumed) - consumed := atomic.LoadInt64(&tracker.bytesConsumed) - if consumed > maxNow && !atomic.CompareAndSwapInt64(&tracker.maxConsumed, maxNow, consumed) { - continue - } - break + for { + maxNow := atomic.LoadInt64(&tracker.maxConsumed) + consumed := atomic.LoadInt64(&tracker.bytesConsumed) + if consumed > maxNow && !atomic.CompareAndSwapInt64(&tracker.maxConsumed, maxNow, consumed) { + continue } + break } } if rootExceed != nil { @@ -172,6 +163,21 @@ func (t *Tracker) MaxConsumed() int64 { return atomic.LoadInt64(&t.maxConsumed) } +// SearchTracker searches the specific tracker under this tracker. +func (t *Tracker) SearchTracker(label string) *Tracker { + if t.label.String() == label { + return t + } + t.mu.Lock() + defer t.mu.Unlock() + for _, child := range t.mu.children { + if result := child.SearchTracker(label); result != nil { + return result + } + } + return nil +} + // String returns the string representation of this Tracker tree. func (t *Tracker) String() string { buffer := bytes.NewBufferString("\n") @@ -182,9 +188,9 @@ func (t *Tracker) String() string { func (t *Tracker) toString(indent string, buffer *bytes.Buffer) { fmt.Fprintf(buffer, "%s\"%s\"{\n", indent, t.label) if t.bytesLimit > 0 { - fmt.Fprintf(buffer, "%s \"quota\": %s\n", indent, t.bytesToString(t.bytesLimit)) + fmt.Fprintf(buffer, "%s \"quota\": %s\n", indent, t.BytesToString(t.bytesLimit)) } - fmt.Fprintf(buffer, "%s \"consumed\": %s\n", indent, t.bytesToString(t.BytesConsumed())) + fmt.Fprintf(buffer, "%s \"consumed\": %s\n", indent, t.BytesToString(t.BytesConsumed())) t.mu.Lock() for i := range t.mu.children { @@ -196,7 +202,8 @@ func (t *Tracker) toString(indent string, buffer *bytes.Buffer) { buffer.WriteString(indent + "}\n") } -func (t *Tracker) bytesToString(numBytes int64) string { +// BytesToString converts the memory consumption to a readable string. +func (t *Tracker) BytesToString(numBytes int64) string { GB := float64(numBytes) / float64(1<<30) if GB > 1 { return fmt.Sprintf("%v GB", GB) diff --git a/util/memory/tracker_test.go b/util/memory/tracker_test.go index 11c6be4848c75..bf7ac98ecc506 100644 --- a/util/memory/tracker_test.go +++ b/util/memory/tracker_test.go @@ -98,6 +98,9 @@ type mockAction struct { called bool } +func (a *mockAction) SetLogHook(hook func(uint64)) { +} + func (a *mockAction) Action(t *Tracker) { a.called = true } diff --git a/util/misc_test.go b/util/misc_test.go index 1ae47e6e2f6a7..7c365a98fdb5a 100644 --- a/util/misc_test.go +++ b/util/misc_test.go @@ -14,10 +14,17 @@ package util import ( + "bytes" "time" . "github.com/pingcap/check" "github.com/pingcap/errors" + "github.com/pingcap/parser" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" + "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/util/memory" + "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tidb/util/testleak" ) @@ -113,3 +120,65 @@ func (s *testMiscSuite) TestCompatibleParseGCTime(c *C) { c.Assert(err, NotNil) } } + +func (s *testMiscSuite) TestBasicFunc(c *C) { + // Test for GetStack. + b := GetStack() + c.Assert(len(b) < 4096, IsTrue) + + // Test for WithRecovery. + var recover interface{} + WithRecovery(func() { + panic("test") + }, func(r interface{}) { + recover = r + }) + c.Assert(recover, Equals, "test") + + // Test for SyntaxError. + c.Assert(SyntaxError(nil), IsNil) + c.Assert(terror.ErrorEqual(SyntaxError(errors.New("test")), parser.ErrParse), IsTrue) + c.Assert(terror.ErrorEqual(SyntaxError(parser.ErrSyntax.GenWithStackByArgs()), parser.ErrSyntax), IsTrue) + + // Test for SyntaxWarn. + c.Assert(SyntaxWarn(nil), IsNil) + c.Assert(terror.ErrorEqual(SyntaxWarn(errors.New("test")), parser.ErrParse), IsTrue) + + // Test for ProcessInfo. + pi := ProcessInfo{ + ID: 1, + User: "test", + Host: "www", + DB: "db", + Command: mysql.ComSleep, + Plan: nil, + Time: time.Now(), + State: 1, + Info: "test", + StmtCtx: &stmtctx.StatementContext{ + MemTracker: memory.NewTracker(stringutil.StringerStr(""), -1), + }, + } + row := pi.ToRowForShow(false) + row2 := pi.ToRowForShow(true) + c.Assert(row, DeepEquals, row2) + c.Assert(len(row), Equals, 8) + c.Assert(row[0], Equals, pi.ID) + c.Assert(row[1], Equals, pi.User) + c.Assert(row[2], Equals, pi.Host) + c.Assert(row[3], Equals, pi.DB) + c.Assert(row[4], Equals, "Sleep") + c.Assert(row[5], Equals, uint64(0)) + c.Assert(row[6], Equals, "1") + c.Assert(row[7], Equals, "test") + + row3 := pi.ToRow() + c.Assert(row3[:8], DeepEquals, row) + c.Assert(row3[8], Equals, int64(0)) + + // Test for RandomBuf. + buf := RandomBuf(5) + c.Assert(len(buf), Equals, 5) + c.Assert(bytes.Contains(buf, []byte("$")), IsFalse) + c.Assert(bytes.Contains(buf, []byte{0}), IsFalse) +} diff --git a/util/mock/context.go b/util/mock/context.go index c3419792ac857..a7b76b282baf3 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -228,6 +228,9 @@ func NewContext() *Context { sctx.sessionVars.MaxChunkSize = 32 sctx.sessionVars.StmtCtx.TimeZone = time.UTC sctx.sessionVars.GlobalVarsAccessor = variable.NewMockGlobalAccessor() + if err := sctx.GetSessionVars().SetSystemVar(variable.MaxAllowedPacket, "67108864"); err != nil { + panic(err) + } return sctx } diff --git a/util/pdapi/const.go b/util/pdapi/const.go index 436784f627245..29d3746e50122 100644 --- a/util/pdapi/const.go +++ b/util/pdapi/const.go @@ -15,8 +15,9 @@ package pdapi // The following constants are the APIs of PD server. const ( - HotRead = "/pd/api/v1/hotspot/regions/read" - HotWrite = "/pd/api/v1/hotspot/regions/write" - Regions = "/pd/api/v1/regions" - Stores = "/pd/api/v1/stores" + HotRead = "/pd/api/v1/hotspot/regions/read" + HotWrite = "/pd/api/v1/hotspot/regions/write" + Regions = "/pd/api/v1/regions" + RegionByID = "/pd/api/v1//region/id/" + Stores = "/pd/api/v1/stores" ) diff --git a/util/plancodec/codec.go b/util/plancodec/codec.go new file mode 100644 index 0000000000000..cd273fd973e13 --- /dev/null +++ b/util/plancodec/codec.go @@ -0,0 +1,281 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plancodec + +import ( + "bytes" + "encoding/base64" + "strconv" + "strings" + "sync" + + "github.com/golang/snappy" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/util/texttree" +) + +const ( + rootTaskType = "0" + copTaskType = "1" +) + +const ( + idSeparator = "_" + lineBreaker = '\n' + lineBreakerStr = "\n" + separator = '\t' + separatorStr = "\t" +) + +var decoderPool = sync.Pool{ + New: func() interface{} { + return &planDecoder{} + }, +} + +// DecodePlan use to decode the string to plan tree. +func DecodePlan(planString string) (string, error) { + if len(planString) == 0 { + return "", nil + } + pd := decoderPool.Get().(*planDecoder) + defer decoderPool.Put(pd) + pd.buf.Reset() + return pd.decode(planString) +} + +type planDecoder struct { + buf bytes.Buffer + depths []int + indents [][]rune + planInfos []*planInfo +} + +type planInfo struct { + depth int + fields []string +} + +func (pd *planDecoder) decode(planString string) (string, error) { + str, err := decompress(planString) + if err != nil { + return "", err + } + + nodes := strings.Split(str, lineBreakerStr) + if len(pd.depths) < len(nodes) { + pd.depths = make([]int, 0, len(nodes)) + pd.planInfos = make([]*planInfo, 0, len(nodes)) + pd.indents = make([][]rune, 0, len(nodes)) + } + pd.depths = pd.depths[:0] + pd.planInfos = pd.planInfos[:0] + planInfos := pd.planInfos + for _, node := range nodes { + p, err := decodePlanInfo(node) + if err != nil { + return "", err + } + if p == nil { + continue + } + planInfos = append(planInfos, p) + pd.depths = append(pd.depths, p.depth) + } + + // Calculated indentation of plans. + pd.initPlanTreeIndents() + for i := 1; i < len(pd.depths); i++ { + parentIndex := pd.findParentIndex(i) + pd.fillIndent(parentIndex, i) + } + // Align the value of plan fields. + pd.alignFields(planInfos) + + for i, p := range planInfos { + if i > 0 { + pd.buf.WriteByte(lineBreaker) + } + // This is for alignment. + pd.buf.WriteByte(separator) + pd.buf.WriteString(string(pd.indents[i])) + for j := 0; j < len(p.fields); j++ { + if j > 0 { + pd.buf.WriteByte(separator) + } + pd.buf.WriteString(p.fields[j]) + } + } + return pd.buf.String(), nil +} + +func (pd *planDecoder) initPlanTreeIndents() { + pd.indents = pd.indents[:0] + for i := 0; i < len(pd.depths); i++ { + indent := make([]rune, 2*pd.depths[i]) + pd.indents = append(pd.indents, indent) + if len(indent) == 0 { + continue + } + for i := 0; i < len(indent)-2; i++ { + indent[i] = ' ' + } + indent[len(indent)-2] = texttree.TreeLastNode + indent[len(indent)-1] = texttree.TreeNodeIdentifier + } +} + +func (pd *planDecoder) findParentIndex(childIndex int) int { + for i := childIndex - 1; i > 0; i-- { + if pd.depths[i]+1 == pd.depths[childIndex] { + return i + } + } + return 0 +} +func (pd *planDecoder) fillIndent(parentIndex, childIndex int) { + depth := pd.depths[childIndex] + if depth == 0 { + return + } + idx := depth*2 - 2 + for i := childIndex - 1; i > parentIndex; i-- { + if pd.indents[i][idx] == texttree.TreeLastNode { + pd.indents[i][idx] = texttree.TreeMiddleNode + break + } + pd.indents[i][idx] = texttree.TreeBody + } +} + +func (pd *planDecoder) alignFields(planInfos []*planInfo) { + if len(planInfos) == 0 { + return + } + fieldsLen := len(planInfos[0].fields) + // Last field no need to align. + fieldsLen-- + for colIdx := 0; colIdx < fieldsLen; colIdx++ { + maxFieldLen := pd.getMaxFieldLength(colIdx, planInfos) + for rowIdx, p := range planInfos { + fillLen := maxFieldLen - pd.getPlanFieldLen(rowIdx, colIdx, p) + for i := 0; i < fillLen; i++ { + p.fields[colIdx] += " " + } + } + } +} + +func (pd *planDecoder) getMaxFieldLength(idx int, planInfos []*planInfo) int { + maxLength := -1 + for rowIdx, p := range planInfos { + l := pd.getPlanFieldLen(rowIdx, idx, p) + if l > maxLength { + maxLength = l + } + } + return maxLength +} + +func (pd *planDecoder) getPlanFieldLen(rowIdx, colIdx int, p *planInfo) int { + if colIdx == 0 { + return len(p.fields[0]) + len(pd.indents[rowIdx]) + } + return len(p.fields[colIdx]) +} + +func decodePlanInfo(str string) (*planInfo, error) { + values := strings.Split(str, separatorStr) + if len(values) < 2 { + return nil, nil + } + + p := &planInfo{ + fields: make([]string, 0, len(values)-1), + } + for i, v := range values { + switch i { + // depth + case 0: + depth, err := strconv.Atoi(v) + if err != nil { + return nil, errors.Errorf("decode plan: %v, depth: %v, error: %v", str, v, err) + } + p.depth = depth + // plan ID + case 1: + ids := strings.Split(v, idSeparator) + if len(ids) != 2 { + return nil, errors.Errorf("decode plan: %v error, invalid plan id: %v", str, v) + } + planID, err := strconv.Atoi(ids[0]) + if err != err { + return nil, errors.Errorf("decode plan: %v, plan id: %v, error: %v", str, v, err) + } + p.fields = append(p.fields, PhysicalIDToTypeString(planID)+idSeparator+ids[1]) + // task type + case 2: + if v == rootTaskType { + p.fields = append(p.fields, "root") + } else { + p.fields = append(p.fields, "cop") + } + default: + p.fields = append(p.fields, v) + } + } + return p, nil +} + +// EncodePlanNode is used to encode the plan to a string. +func EncodePlanNode(depth, pid int, planType string, isRoot bool, rowCount float64, explainInfo string, buf *bytes.Buffer) { + buf.WriteString(strconv.Itoa(depth)) + buf.WriteByte(separator) + buf.WriteString(encodeID(planType, pid)) + buf.WriteByte(separator) + if isRoot { + buf.WriteString(rootTaskType) + } else { + buf.WriteString(copTaskType) + } + buf.WriteByte(separator) + buf.WriteString(strconv.FormatFloat(rowCount, 'f', -1, 64)) + buf.WriteByte(separator) + buf.WriteString(explainInfo) + buf.WriteByte(lineBreaker) +} + +func encodeID(planType string, id int) string { + planID := TypeStringToPhysicalID(planType) + return strconv.Itoa(planID) + idSeparator + strconv.Itoa(id) +} + +// Compress is used to compress the input with zlib. +func Compress(input []byte) string { + compressBytes := snappy.Encode(nil, input) + return base64.StdEncoding.EncodeToString(compressBytes) +} + +func decompress(str string) (string, error) { + decodeBytes, err := base64.StdEncoding.DecodeString(str) + if err != nil { + return "", err + } + + bs, err := snappy.Decode(nil, decodeBytes) + if err != nil { + return "", err + } + return string(bs), nil +} diff --git a/util/plancodec/id.go b/util/plancodec/id.go new file mode 100644 index 0000000000000..8369f7bfd0e59 --- /dev/null +++ b/util/plancodec/id.go @@ -0,0 +1,313 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package plancodec + +import "strconv" + +const ( + // TypeSel is the type of Selection. + TypeSel = "Selection" + // TypeSet is the type of Set. + TypeSet = "Set" + // TypeProj is the type of Projection. + TypeProj = "Projection" + // TypeAgg is the type of Aggregation. + TypeAgg = "Aggregation" + // TypeStreamAgg is the type of StreamAgg. + TypeStreamAgg = "StreamAgg" + // TypeHashAgg is the type of HashAgg. + TypeHashAgg = "HashAgg" + // TypeShow is the type of show. + TypeShow = "Show" + // TypeJoin is the type of Join. + TypeJoin = "Join" + // TypeUnion is the type of Union. + TypeUnion = "Union" + // TypeTableScan is the type of TableScan. + TypeTableScan = "TableScan" + // TypeMemTableScan is the type of TableScan. + TypeMemTableScan = "MemTableScan" + // TypeUnionScan is the type of UnionScan. + TypeUnionScan = "UnionScan" + // TypeIdxScan is the type of IndexScan. + TypeIdxScan = "IndexScan" + // TypeSort is the type of Sort. + TypeSort = "Sort" + // TypeTopN is the type of TopN. + TypeTopN = "TopN" + // TypeLimit is the type of Limit. + TypeLimit = "Limit" + // TypeHashLeftJoin is the type of left hash join. + TypeHashLeftJoin = "HashLeftJoin" + // TypeHashRightJoin is the type of right hash join. + TypeHashRightJoin = "HashRightJoin" + // TypeMergeJoin is the type of merge join. + TypeMergeJoin = "MergeJoin" + // TypeIndexJoin is the type of index look up join. + TypeIndexJoin = "IndexJoin" + // TypeIndexMergeJoin is the type of index look up merge join. + TypeIndexMergeJoin = "IndexMergeJoin" + // TypeIndexHashJoin is the type of index nested loop hash join. + TypeIndexHashJoin = "IndexHashJoin" + // TypeApply is the type of Apply. + TypeApply = "Apply" + // TypeMaxOneRow is the type of MaxOneRow. + TypeMaxOneRow = "MaxOneRow" + // TypeExists is the type of Exists. + TypeExists = "Exists" + // TypeDual is the type of TableDual. + TypeDual = "TableDual" + // TypeLock is the type of SelectLock. + TypeLock = "SelectLock" + // TypeInsert is the type of Insert + TypeInsert = "Insert" + // TypeUpdate is the type of Update. + TypeUpdate = "Update" + // TypeDelete is the type of Delete. + TypeDelete = "Delete" + // TypeIndexLookUp is the type of IndexLookUp. + TypeIndexLookUp = "IndexLookUp" + // TypeTableReader is the type of TableReader. + TypeTableReader = "TableReader" + // TypeIndexReader is the type of IndexReader. + TypeIndexReader = "IndexReader" + // TypeWindow is the type of Window. + TypeWindow = "Window" + // TypeTableGather is the type of TableGather. + TypeTableGather = "TableGather" + // TypeIndexMerge is the type of IndexMergeReader + TypeIndexMerge = "IndexMerge" + // TypePointGet is the type of PointGetPlan. + TypePointGet = "Point_Get" + // TypeShowDDLJobs is the type of show ddl jobs. + TypeShowDDLJobs = "ShowDDLJobs" + // TypeBatchPointGet is the type of BatchPointGetPlan. + TypeBatchPointGet = "Batch_Point_Get" +) + +// plan id. +const ( + typeSelID int = iota + 1 + typeSetID + typeProjID + typeAggID + typeStreamAggID + typeHashAggID + typeShowID + typeJoinID + typeUnionID + typeTableScanID + typeMemTableScanID + typeUnionScanID + typeIdxScanID + typeSortID + typeTopNID + typeLimitID + typeHashLeftJoinID + typeHashRightJoinID + typeMergeJoinID + typeIndexJoinID + typeIndexMergeJoinID + typeIndexHashJoinID + typeApplyID + typeMaxOneRowID + typeExistsID + typeDualID + typeLockID + typeInsertID + typeUpdateID + typeDeleteID + typeIndexLookUpID + typeTableReaderID + typeIndexReaderID + typeWindowID + typeTableGatherID + typeIndexMergeID + typePointGet + typeShowDDLJobs + typeBatchPointGet +) + +// TypeStringToPhysicalID converts the plan type string to plan id. +func TypeStringToPhysicalID(tp string) int { + switch tp { + case TypeSel: + return typeSelID + case TypeSet: + return typeSetID + case TypeProj: + return typeProjID + case TypeAgg: + return typeAggID + case TypeStreamAgg: + return typeStreamAggID + case TypeHashAgg: + return typeHashAggID + case TypeShow: + return typeShowID + case TypeJoin: + return typeJoinID + case TypeUnion: + return typeUnionID + case TypeTableScan: + return typeTableScanID + case TypeMemTableScan: + return typeMemTableScanID + case TypeUnionScan: + return typeUnionScanID + case TypeIdxScan: + return typeIdxScanID + case TypeSort: + return typeSortID + case TypeTopN: + return typeTopNID + case TypeLimit: + return typeLimitID + case TypeHashLeftJoin: + return typeHashLeftJoinID + case TypeHashRightJoin: + return typeHashRightJoinID + case TypeMergeJoin: + return typeMergeJoinID + case TypeIndexJoin: + return typeIndexJoinID + case TypeIndexMergeJoin: + return typeIndexMergeJoinID + case TypeIndexHashJoin: + return typeIndexHashJoinID + case TypeApply: + return typeApplyID + case TypeMaxOneRow: + return typeMaxOneRowID + case TypeExists: + return typeExistsID + case TypeDual: + return typeDualID + case TypeLock: + return typeLockID + case TypeInsert: + return typeInsertID + case TypeUpdate: + return typeUpdateID + case TypeDelete: + return typeDeleteID + case TypeIndexLookUp: + return typeIndexLookUpID + case TypeTableReader: + return typeTableReaderID + case TypeIndexReader: + return typeIndexReaderID + case TypeWindow: + return typeWindowID + case TypeTableGather: + return typeTableGatherID + case TypeIndexMerge: + return typeIndexMergeID + case TypePointGet: + return typePointGet + case TypeShowDDLJobs: + return typeShowDDLJobs + case TypeBatchPointGet: + return typeBatchPointGet + } + // Should never reach here. + return 0 +} + +// PhysicalIDToTypeString converts the plan id to plan type string. +func PhysicalIDToTypeString(id int) string { + switch id { + case typeSelID: + return TypeSel + case typeSetID: + return TypeSet + case typeProjID: + return TypeProj + case typeAggID: + return TypeAgg + case typeStreamAggID: + return TypeStreamAgg + case typeHashAggID: + return TypeHashAgg + case typeShowID: + return TypeShow + case typeJoinID: + return TypeJoin + case typeUnionID: + return TypeUnion + case typeTableScanID: + return TypeTableScan + case typeMemTableScanID: + return TypeMemTableScan + case typeUnionScanID: + return TypeUnionScan + case typeIdxScanID: + return TypeIdxScan + case typeSortID: + return TypeSort + case typeTopNID: + return TypeTopN + case typeLimitID: + return TypeLimit + case typeHashLeftJoinID: + return TypeHashLeftJoin + case typeHashRightJoinID: + return TypeHashRightJoin + case typeMergeJoinID: + return TypeMergeJoin + case typeIndexJoinID: + return TypeIndexJoin + case typeIndexMergeJoinID: + return TypeIndexMergeJoin + case typeIndexHashJoinID: + return TypeIndexHashJoin + case typeApplyID: + return TypeApply + case typeMaxOneRowID: + return TypeMaxOneRow + case typeExistsID: + return TypeExists + case typeDualID: + return TypeDual + case typeLockID: + return TypeLock + case typeInsertID: + return TypeInsert + case typeUpdateID: + return TypeUpdate + case typeDeleteID: + return TypeDelete + case typeIndexLookUpID: + return TypeIndexLookUp + case typeTableReaderID: + return TypeTableReader + case typeIndexReaderID: + return TypeIndexReader + case typeWindowID: + return TypeWindow + case typeTableGatherID: + return TypeTableGather + case typeIndexMergeID: + return TypeIndexMerge + case typePointGet: + return TypePointGet + case typeShowDDLJobs: + return TypeShowDDLJobs + case typeBatchPointGet: + return TypeBatchPointGet + } + + // Should never reach here. + return "UnknownPlanID" + strconv.Itoa(id) +} diff --git a/util/printer/printer.go b/util/printer/printer.go index f322c15e5c127..cd6bbd9a28aff 100644 --- a/util/printer/printer.go +++ b/util/printer/printer.go @@ -33,7 +33,7 @@ var ( TiDBGitBranch = "None" GoVersion = "None" // TiKVMinVersion is the minimum version of TiKV that can be compatible with the current TiDB. - TiKVMinVersion = "2.1.0-alpha.1-ff3dd160846b7d1aed9079c389fc188f7f5ea13e" + TiKVMinVersion = "v3.0.0-60965b006877ca7234adaced7890d7b029ed1306" ) // PrintTiDBInfo prints the TiDB version information. diff --git a/util/processinfo.go b/util/processinfo.go index d6a992da0e318..b09edf810b184 100644 --- a/util/processinfo.go +++ b/util/processinfo.go @@ -18,28 +18,38 @@ import ( "time" "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/sessionctx/stmtctx" ) // ProcessInfo is a struct used for show processlist statement. type ProcessInfo struct { - ID uint64 - User string - Host string - DB string - Command byte - Plan interface{} - Time time.Time - State uint16 - Info string + ID uint64 + User string + Host string + DB interface{} + Command byte + Plan interface{} + Time time.Time + State uint16 + Info interface{} + CurTxnStartTS uint64 + StmtCtx *stmtctx.StatementContext + StatsInfo func(interface{}) map[string]uint64 + ExceedExpensiveTimeThresh bool + // MaxExecutionTime is the timeout for select statement, in milliseconds. + // If the query takes too long, kill it. + MaxExecutionTime uint64 } -// ToRow returns []interface{} for the row data of "show processlist" and "select * from infoschema.processlist". -func (pi *ProcessInfo) ToRow(full bool) []interface{} { - var info string - if full { - info = pi.Info - } else { - info = fmt.Sprintf("%.100v", pi.Info) +// ToRowForShow returns []interface{} for the row data of "SHOW [FULL] PROCESSLIST". +func (pi *ProcessInfo) ToRowForShow(full bool) []interface{} { + var info interface{} + if pi.Info != nil { + if full { + info = pi.Info.(string) + } else { + info = fmt.Sprintf("%.100v", pi.Info.(string)) + } } t := uint64(time.Since(pi.Time) / time.Second) return []interface{}{ @@ -54,11 +64,16 @@ func (pi *ProcessInfo) ToRow(full bool) []interface{} { } } +// ToRow returns []interface{} for the row data of +// "SELECT * FROM INFORMATION_SCHEMA.PROCESSLIST". +func (pi *ProcessInfo) ToRow() []interface{} { + return append(pi.ToRowForShow(true), pi.StmtCtx.MemTracker.BytesConsumed()) +} + // SessionManager is an interface for session manage. Show processlist and // kill statement rely on this interface. type SessionManager interface { - // ShowProcessList returns map[connectionID]ProcessInfo - ShowProcessList() map[uint64]ProcessInfo - GetProcessInfo(id uint64) (ProcessInfo, bool) + ShowProcessList() map[uint64]*ProcessInfo + GetProcessInfo(id uint64) (*ProcessInfo, bool) Kill(connectionID uint64, query bool) } diff --git a/util/profile/profile.go b/util/profile/profile.go new file mode 100644 index 0000000000000..7286e1f070cad --- /dev/null +++ b/util/profile/profile.go @@ -0,0 +1,215 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package profile + +import ( + "bytes" + "fmt" + "io" + "runtime/pprof" + "strconv" + "strings" + "time" + + "github.com/google/pprof/graph" + "github.com/google/pprof/measurement" + "github.com/google/pprof/profile" + "github.com/google/pprof/report" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/texttree" +) + +// CPUProfileInterval represents the duration of sampling CPU +var CPUProfileInterval = 30 * time.Second + +// Collector is used to collect the profile results +type Collector struct { + Rows [][]types.Datum +} + +type perfNode struct { + Name string + Location string + Cum int64 + CumFormat string + Percent string + Children []*perfNode +} + +func (c *Collector) collect(node *perfNode, depth int64, indent string, rootChild int, parentCur int64, isLastChild bool) { + row := types.MakeDatums( + texttree.PrettyIdentifier(node.Name, indent, isLastChild), + node.Percent, + strings.TrimSpace(measurement.Percentage(node.Cum, parentCur)), + rootChild, + depth, + node.Location, + ) + c.Rows = append(c.Rows, row) + + indent4Child := texttree.Indent4Child(indent, isLastChild) + for i, child := range node.Children { + rc := rootChild + if depth == 0 { + rc = i + 1 + } + c.collect(child, depth+1, indent4Child, rc, node.Cum, i == len(node.Children)-1) + } +} + +func (c *Collector) profileReaderToDatums(f io.Reader) ([][]types.Datum, error) { + p, err := profile.Parse(f) + if err != nil { + return nil, err + } + return c.profileToDatums(p) +} + +func (c *Collector) profileToDatums(p *profile.Profile) ([][]types.Datum, error) { + err := p.Aggregate(true, true, true, true, true) + if err != nil { + return nil, err + } + rpt := report.NewDefault(p, report.Options{ + OutputFormat: report.Dot, + CallTree: true, + }) + g, config := report.GetDOT(rpt) + var nodes []*perfNode + nroots := 0 + rootValue := int64(0) + nodeArr := []string{} + nodeMap := map[*graph.Node]*perfNode{} + // Make all nodes and the map, collect the roots. + for _, n := range g.Nodes { + v := n.CumValue() + node := &perfNode{ + Name: n.Info.Name, + Location: fmt.Sprintf("%s:%d", n.Info.File, n.Info.Lineno), + Cum: v, + CumFormat: config.FormatValue(v), + Percent: strings.TrimSpace(measurement.Percentage(v, config.Total)), + } + nodes = append(nodes, node) + if len(n.In) == 0 { + nodes[nroots], nodes[len(nodes)-1] = nodes[len(nodes)-1], nodes[nroots] + nroots++ + rootValue += v + } + nodeMap[n] = node + // Get all node names into an array. + nodeArr = append(nodeArr, n.Info.Name) + } + // Populate the child links. + for _, n := range g.Nodes { + node := nodeMap[n] + for child := range n.Out { + node.Children = append(node.Children, nodeMap[child]) + } + } + + rootNode := &perfNode{ + Name: "root", + Location: "root", + Cum: rootValue, + CumFormat: config.FormatValue(rootValue), + Percent: strings.TrimSpace(measurement.Percentage(rootValue, config.Total)), + Children: nodes[0:nroots], + } + + c.collect(rootNode, 0, "", 0, config.Total, len(rootNode.Children) == 0) + return c.Rows, nil +} + +// cpuProfileGraph returns the CPU profile flamegraph which is organized by tree form +func (c *Collector) cpuProfileGraph() ([][]types.Datum, error) { + buffer := &bytes.Buffer{} + if err := pprof.StartCPUProfile(buffer); err != nil { + panic(err) + } + time.Sleep(CPUProfileInterval) + pprof.StopCPUProfile() + return c.profileReaderToDatums(buffer) +} + +// ProfileGraph returns the CPU/memory/mutex/allocs/block profile flamegraph which is organized by tree form +func (c *Collector) ProfileGraph(name string) ([][]types.Datum, error) { + if strings.ToLower(strings.TrimSpace(name)) == "cpu" { + return c.cpuProfileGraph() + } + + p := pprof.Lookup(name) + if p == nil { + return nil, errors.Errorf("cannot retrieve %s profile", name) + } + buffer := &bytes.Buffer{} + if err := p.WriteTo(buffer, 0); err != nil { + return nil, err + } + return c.profileReaderToDatums(buffer) +} + +// Goroutines returns the groutine list which alive in runtime +func (c *Collector) Goroutines() ([][]types.Datum, error) { + p := pprof.Lookup("goroutine") + if p == nil { + return nil, errors.Errorf("cannot retrieve goroutine profile") + } + + buffer := bytes.Buffer{} + err := p.WriteTo(&buffer, 2) + if err != nil { + return nil, err + } + + goroutines := strings.Split(buffer.String(), "\n\n") + var rows [][]types.Datum + for _, goroutine := range goroutines { + colIndex := strings.Index(goroutine, ":") + if colIndex < 0 { + return nil, errors.New("goroutine incompatible with current go version") + } + + headers := strings.SplitN(strings.TrimSpace(goroutine[len("goroutine")+1:colIndex]), " ", 2) + if len(headers) != 2 { + return nil, errors.Errorf("incompatible goroutine headers: %s", goroutine[len("goroutine")+1:colIndex]) + } + id, err := strconv.Atoi(strings.TrimSpace(headers[0])) + if err != nil { + return nil, errors.Annotatef(err, "invalid goroutine id: %s", headers[0]) + } + state := strings.Trim(headers[1], "[]") + stack := strings.Split(strings.TrimSpace(goroutine[colIndex+1:]), "\n") + for i := 0; i < len(stack)/2; i++ { + fn := stack[i*2] + loc := stack[i*2+1] + var identifier string + if i == 0 { + identifier = fn + } else if i == len(stack)/2-1 { + identifier = string(texttree.TreeLastNode) + string(texttree.TreeNodeIdentifier) + fn + } else { + identifier = string(texttree.TreeMiddleNode) + string(texttree.TreeNodeIdentifier) + fn + } + rows = append(rows, types.MakeDatums( + identifier, + id, + state, + strings.TrimSpace(loc), + )) + } + } + return rows, nil +} diff --git a/util/profile/profile_test.go b/util/profile/profile_test.go new file mode 100644 index 0000000000000..27b245ebf17a7 --- /dev/null +++ b/util/profile/profile_test.go @@ -0,0 +1,67 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package profile_test + +import ( + "testing" + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/domain" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/session" + "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/util/profile" + "github.com/pingcap/tidb/util/testkit" +) + +type profileSuite struct { + store kv.Storage + dom *domain.Domain +} + +var _ = Suite(&profileSuite{}) + +func TestT(t *testing.T) { + TestingT(t) +} + +func (s *profileSuite) SetUpSuite(c *C) { + var err error + s.store, err = mockstore.NewMockTikvStore() + c.Assert(err, IsNil) + session.DisableStats4Test() + s.dom, err = session.BootstrapSession(s.store) + c.Assert(err, IsNil) +} + +func (s *profileSuite) TearDownSuite(c *C) { + s.dom.Close() + s.store.Close() +} + +func (s *profileSuite) TestProfiles(c *C) { + oldValue := profile.CPUProfileInterval + profile.CPUProfileInterval = 2 * time.Second + defer func() { + profile.CPUProfileInterval = oldValue + }() + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("select * from performance_schema.tidb_profile_cpu") + tk.MustExec("select * from performance_schema.tidb_profile_memory") + tk.MustExec("select * from performance_schema.tidb_profile_allocs") + tk.MustExec("select * from performance_schema.tidb_profile_mutex") + tk.MustExec("select * from performance_schema.tidb_profile_block") + tk.MustExec("select * from performance_schema.tidb_profile_goroutines") +} diff --git a/util/ranger/points.go b/util/ranger/points.go index 75ae0a793ed3b..efb2efcda40d2 100644 --- a/util/ranger/points.go +++ b/util/ranger/points.go @@ -250,6 +250,11 @@ func (r *builder) buildFormBinOp(expr *expression.ScalarFunction) []point { return nil } + value, op, isValidRange := handleUnsignedIntCol(ft, value, op) + if !isValidRange { + return nil + } + switch op { case ast.EQ: startPoint := point{value: value, start: true} @@ -339,6 +344,29 @@ func HandlePadCharToFullLength(sc *stmtctx.StatementContext, ft *types.FieldType } } +// handleUnsignedIntCol handles the case when unsigned column meets negative integer value. +// The three returned values are: fixed constant value, fixed operator, and a boolean +// which indicates whether the range is valid or not. +func handleUnsignedIntCol(ft *types.FieldType, val types.Datum, op string) (types.Datum, string, bool) { + isUnsigned := mysql.HasUnsignedFlag(ft.Flag) + isIntegerType := mysql.IsIntegerType(ft.Tp) + isNegativeInteger := (val.Kind() == types.KindInt64 && val.GetInt64() < 0) + + if !isUnsigned || !isIntegerType || !isNegativeInteger { + return val, op, true + } + + // If the operator is GT, GE or NE, the range should be [0, +inf]. + // Otherwise the value is out of valid range. + if op == ast.GT || op == ast.GE || op == ast.NE { + op = ast.GE + val.SetUint64(0) + return val, op, true + } + + return val, op, false +} + func (r *builder) buildFromIsTrue(expr *expression.ScalarFunction, isNot int) []point { if isNot == 1 { // NOT TRUE range is {[null null] [0, 0]} diff --git a/util/ranger/ranger.go b/util/ranger/ranger.go index 015f6ea447eb7..6f858af480a39 100644 --- a/util/ranger/ranger.go +++ b/util/ranger/ranger.go @@ -263,13 +263,17 @@ func buildColumnRange(accessConditions []expression.Expression, sc *stmtctx.Stat } if colLen != types.UnspecifiedLength { for _, ran := range ranges { - if fixRangeDatum(&ran.LowVal[0], colLen, tp) { + if CutDatumByPrefixLen(&ran.LowVal[0], colLen, tp) { ran.LowExclude = false } - if fixRangeDatum(&ran.HighVal[0], colLen, tp) { + if CutDatumByPrefixLen(&ran.HighVal[0], colLen, tp) { ran.HighExclude = false } } + ranges, err = unionRanges(sc, ranges) + if err != nil { + return nil, err + } } return ranges, nil } @@ -425,17 +429,17 @@ func fixPrefixColRange(ranges []*Range, lengths []int, tp []*types.FieldType) bo for _, ran := range ranges { lowTail := len(ran.LowVal) - 1 for i := 0; i < lowTail; i++ { - fixRangeDatum(&ran.LowVal[i], lengths[i], tp[i]) + CutDatumByPrefixLen(&ran.LowVal[i], lengths[i], tp[i]) } - lowCut := fixRangeDatum(&ran.LowVal[lowTail], lengths[lowTail], tp[lowTail]) + lowCut := CutDatumByPrefixLen(&ran.LowVal[lowTail], lengths[lowTail], tp[lowTail]) if lowCut { ran.LowExclude = false } highTail := len(ran.HighVal) - 1 for i := 0; i < highTail; i++ { - fixRangeDatum(&ran.HighVal[i], lengths[i], tp[i]) + CutDatumByPrefixLen(&ran.HighVal[i], lengths[i], tp[i]) } - highCut := fixRangeDatum(&ran.HighVal[highTail], lengths[highTail], tp[highTail]) + highCut := CutDatumByPrefixLen(&ran.HighVal[highTail], lengths[highTail], tp[highTail]) if highCut { ran.HighExclude = false } @@ -444,9 +448,9 @@ func fixPrefixColRange(ranges []*Range, lengths []int, tp []*types.FieldType) bo return hasCut } -func fixRangeDatum(v *types.Datum, length int, tp *types.FieldType) bool { - // If this column is prefix and the prefix length is smaller than the range, cut it. - // In case of UTF8, prefix should be cut by characters rather than bytes +// CutDatumByPrefixLen cuts the datum according to the prefix length. +// If it's UTF8 encoded, we will cut it by characters rather than bytes. +func CutDatumByPrefixLen(v *types.Datum, length int, tp *types.FieldType) bool { if v.Kind() == types.KindString || v.Kind() == types.KindBytes { colCharset := tp.Charset colValue := v.GetBytes() diff --git a/util/ranger/ranger_test.go b/util/ranger/ranger_test.go index fa01ae36281ce..62980b737eae9 100644 --- a/util/ranger/ranger_test.go +++ b/util/ranger/ranger_test.go @@ -14,6 +14,7 @@ package ranger_test import ( + "context" "fmt" "testing" @@ -33,6 +34,7 @@ import ( "github.com/pingcap/tidb/util/ranger" "github.com/pingcap/tidb/util/testkit" "github.com/pingcap/tidb/util/testleak" + "github.com/pingcap/tidb/util/testutil" ) func TestT(t *testing.T) { @@ -43,10 +45,18 @@ var _ = Suite(&testRangerSuite{}) type testRangerSuite struct { *parser.Parser + testData testutil.TestData } func (s *testRangerSuite) SetUpSuite(c *C) { s.Parser = parser.New() + var err error + s.testData, err = testutil.LoadTestSuiteData("testdata", "ranger_suite") + c.Assert(err, IsNil) +} + +func (s *testRangerSuite) TearDownSuite(c *C) { + c.Assert(s.testData.GenerateOutputIfNeeded(), IsNil) } func newDomainStoreWithBootstrap(c *C) (*domain.Domain, kv.Storage, error) { @@ -59,7 +69,7 @@ func newDomainStoreWithBootstrap(c *C) (*domain.Domain, kv.Storage, error) { ) c.Assert(err, IsNil) session.SetSchemaLease(0) - session.SetStatsLease(0) + session.DisableStats4Test() if err != nil { return nil, nil, errors.Trace(err) } @@ -292,27 +302,28 @@ func (s *testRangerSuite) TestTableRange(c *C) { }, } + ctx := context.Background() for _, tt := range tests { sql := "select * from t where " + tt.exprStr - ctx := testKit.Se.(sessionctx.Context) - stmts, err := session.Parse(ctx, sql) + sctx := testKit.Se.(sessionctx.Context) + stmts, err := session.Parse(sctx, sql) c.Assert(err, IsNil, Commentf("error %v, for expr %s", err, tt.exprStr)) c.Assert(stmts, HasLen, 1) - is := domain.GetDomain(ctx).InfoSchema() - err = plannercore.Preprocess(ctx, stmts[0], is) + is := domain.GetDomain(sctx).InfoSchema() + err = plannercore.Preprocess(sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for resolve name, expr %s", err, tt.exprStr)) - p, err := plannercore.BuildLogicalPlan(ctx, stmts[0], is) + p, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for build plan, expr %s", err, tt.exprStr)) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) conds := make([]expression.Expression, 0, len(selection.Conditions)) for _, cond := range selection.Conditions { - conds = append(conds, expression.PushDownNot(ctx, cond, false)) + conds = append(conds, expression.PushDownNot(sctx, cond, false)) } tbl := selection.Children()[0].(*plannercore.DataSource).TableInfo() col := expression.ColInfo2Col(selection.Schema().Columns, tbl.Columns[0]) c.Assert(col, NotNil) var filter []expression.Expression - conds, filter = ranger.DetachCondsForColumn(ctx, conds, col) + conds, filter = ranger.DetachCondsForColumn(sctx, conds, col) c.Assert(fmt.Sprintf("%s", conds), Equals, tt.accessConds, Commentf("wrong access conditions for expr: %s", tt.exprStr)) c.Assert(fmt.Sprintf("%s", filter), Equals, tt.filterConds, Commentf("wrong filter conditions for expr: %s", tt.exprStr)) result, err := ranger.BuildTableRange(conds, new(stmtctx.StatementContext), col.RetType) @@ -575,27 +586,28 @@ func (s *testRangerSuite) TestIndexRange(c *C) { }, } + ctx := context.Background() for _, tt := range tests { sql := "select * from t where " + tt.exprStr - ctx := testKit.Se.(sessionctx.Context) - stmts, err := session.Parse(ctx, sql) + sctx := testKit.Se.(sessionctx.Context) + stmts, err := session.Parse(sctx, sql) c.Assert(err, IsNil, Commentf("error %v, for expr %s", err, tt.exprStr)) c.Assert(stmts, HasLen, 1) - is := domain.GetDomain(ctx).InfoSchema() - err = plannercore.Preprocess(ctx, stmts[0], is) + is := domain.GetDomain(sctx).InfoSchema() + err = plannercore.Preprocess(sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for resolve name, expr %s", err, tt.exprStr)) - p, err := plannercore.BuildLogicalPlan(ctx, stmts[0], is) + p, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for build plan, expr %s", err, tt.exprStr)) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) tbl := selection.Children()[0].(*plannercore.DataSource).TableInfo() c.Assert(selection, NotNil, Commentf("expr:%v", tt.exprStr)) conds := make([]expression.Expression, 0, len(selection.Conditions)) for _, cond := range selection.Conditions { - conds = append(conds, expression.PushDownNot(ctx, cond, false)) + conds = append(conds, expression.PushDownNot(sctx, cond, false)) } cols, lengths := expression.IndexInfo2Cols(selection.Schema().Columns, tbl.Indices[tt.indexPos]) c.Assert(cols, NotNil) - res, err := ranger.DetachCondAndBuildRangeForIndex(ctx, conds, cols, lengths) + res, err := ranger.DetachCondAndBuildRangeForIndex(sctx, conds, cols, lengths) c.Assert(err, IsNil) c.Assert(fmt.Sprintf("%s", res.AccessConds), Equals, tt.accessConds, Commentf("wrong access conditions for expr: %s", tt.exprStr)) c.Assert(fmt.Sprintf("%s", res.RemainedConds), Equals, tt.filterConds, Commentf("wrong filter conditions for expr: %s", tt.exprStr)) @@ -660,29 +672,63 @@ func (s *testRangerSuite) TestIndexRangeForUnsignedInt(c *C) { filterConds: "[]", resultStr: `[(NULL,1) (2,9223372036854775810) (9223372036854775810,+inf]]`, }, + { + indexPos: 0, + exprStr: `a >= -2147483648`, + accessConds: "[ge(test.t.a, -2147483648)]", + filterConds: "[]", + resultStr: `[[0,+inf]]`, + }, + { + indexPos: 0, + exprStr: `a > -2147483648`, + accessConds: "[gt(test.t.a, -2147483648)]", + filterConds: "[]", + resultStr: `[[0,+inf]]`, + }, + { + indexPos: 0, + exprStr: `a != -2147483648`, + accessConds: "[ne(test.t.a, -2147483648)]", + filterConds: "[]", + resultStr: `[[0,+inf]]`, + }, + { + exprStr: "a < -1 or a < 1", + accessConds: "[or(lt(test.t.a, -1), lt(test.t.a, 1))]", + filterConds: "[]", + resultStr: "[[-inf,1)]", + }, + { + exprStr: "a < -1 and a < 1", + accessConds: "[lt(test.t.a, -1) lt(test.t.a, 1)]", + filterConds: "[]", + resultStr: "[]", + }, } + ctx := context.Background() for _, tt := range tests { sql := "select * from t where " + tt.exprStr - ctx := testKit.Se.(sessionctx.Context) - stmts, err := session.Parse(ctx, sql) + sctx := testKit.Se.(sessionctx.Context) + stmts, err := session.Parse(sctx, sql) c.Assert(err, IsNil, Commentf("error %v, for expr %s", err, tt.exprStr)) c.Assert(stmts, HasLen, 1) - is := domain.GetDomain(ctx).InfoSchema() - err = plannercore.Preprocess(ctx, stmts[0], is) + is := domain.GetDomain(sctx).InfoSchema() + err = plannercore.Preprocess(sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for resolve name, expr %s", err, tt.exprStr)) - p, err := plannercore.BuildLogicalPlan(ctx, stmts[0], is) + p, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for build plan, expr %s", err, tt.exprStr)) selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) tbl := selection.Children()[0].(*plannercore.DataSource).TableInfo() c.Assert(selection, NotNil, Commentf("expr:%v", tt.exprStr)) conds := make([]expression.Expression, 0, len(selection.Conditions)) for _, cond := range selection.Conditions { - conds = append(conds, expression.PushDownNot(ctx, cond, false)) + conds = append(conds, expression.PushDownNot(sctx, cond, false)) } cols, lengths := expression.IndexInfo2Cols(selection.Schema().Columns, tbl.Indices[tt.indexPos]) c.Assert(cols, NotNil) - res, err := ranger.DetachCondAndBuildRangeForIndex(ctx, conds, cols, lengths) + res, err := ranger.DetachCondAndBuildRangeForIndex(sctx, conds, cols, lengths) c.Assert(err, IsNil) c.Assert(fmt.Sprintf("%s", res.AccessConds), Equals, tt.accessConds, Commentf("wrong access conditions for expr: %s", tt.exprStr)) c.Assert(fmt.Sprintf("%s", res.RemainedConds), Equals, tt.filterConds, Commentf("wrong filter conditions for expr: %s", tt.exprStr)) @@ -710,6 +756,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds string filterConds string resultStr string + length int }{ { colPos: 0, @@ -717,6 +764,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[eq(test.t.a, 1)]", filterConds: "[gt(test.t.b, 1)]", resultStr: "[[1,1]]", + length: types.UnspecifiedLength, }, { colPos: 1, @@ -724,6 +772,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[gt(test.t.b, 1)]", filterConds: "[]", resultStr: "[(1,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -731,6 +780,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[eq(1, test.t.a)]", filterConds: "[]", resultStr: "[[1,1]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -738,6 +788,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[ne(test.t.a, 1)]", filterConds: "[]", resultStr: "[[-inf,1) (1,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -745,6 +796,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[ne(1, test.t.a)]", filterConds: "[]", resultStr: "[[-inf,1) (1,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -752,6 +804,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[gt(test.t.a, 1)]", filterConds: "[]", resultStr: "[(1,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -759,6 +812,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[lt(1, test.t.a)]", filterConds: "[]", resultStr: "[(1,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -766,6 +820,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[ge(test.t.a, 1)]", filterConds: "[]", resultStr: "[[1,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -773,6 +828,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[le(1, test.t.a)]", filterConds: "[]", resultStr: "[[1,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -780,6 +836,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[lt(test.t.a, 1)]", filterConds: "[]", resultStr: "[[-inf,1)]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -787,6 +844,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[gt(1, test.t.a)]", filterConds: "[]", resultStr: "[[-inf,1)]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -794,6 +852,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[le(test.t.a, 1)]", filterConds: "[]", resultStr: "[[-inf,1]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -801,6 +860,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[ge(1, test.t.a)]", filterConds: "[]", resultStr: "[[-inf,1]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -808,6 +868,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[test.t.a]", filterConds: "[]", resultStr: "[[-inf,0) (0,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -815,6 +876,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[in(test.t.a, 1, 3, , 2)]", filterConds: "[]", resultStr: "[[1,1] [2,2] [3,3]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -822,6 +884,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[in(test.t.a, 8, 8, 81, 45)]", filterConds: "[]", resultStr: `[[8,8] [45,45] [81,81]]`, + length: types.UnspecifiedLength, }, { colPos: 0, @@ -829,6 +892,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[ge(test.t.a, 1) le(test.t.a, 2)]", filterConds: "[]", resultStr: "[[1,2]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -836,6 +900,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[or(lt(test.t.a, 1), gt(test.t.a, 2))]", filterConds: "[]", resultStr: "[[-inf,1) (2,+inf]]", + length: types.UnspecifiedLength, }, //{ // `a > null` will be converted to `castAsString(a) > null` which can not be extracted as access condition. @@ -848,6 +913,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[ge(test.t.a, 2) le(test.t.a, 1)]", filterConds: "[]", resultStr: "[]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -855,6 +921,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[or(lt(test.t.a, 2), gt(test.t.a, 1))]", filterConds: "[]", resultStr: "[[-inf,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -862,6 +929,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[isnull(test.t.a)]", filterConds: "[]", resultStr: "[[NULL,NULL]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -869,6 +937,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[not(isnull(test.t.a))]", filterConds: "[]", resultStr: "[[-inf,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -876,6 +945,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[istrue(test.t.a)]", filterConds: "[]", resultStr: "[[-inf,0) (0,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -883,6 +953,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[not(istrue(test.t.a))]", filterConds: "[]", resultStr: "[[NULL,NULL] [0,0]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -890,6 +961,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[isfalse(test.t.a)]", filterConds: "[]", resultStr: "[[0,0]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -897,6 +969,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[not(isfalse(test.t.a))]", filterConds: "[]", resultStr: "[[NULL,0) (0,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 1, @@ -904,6 +977,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[in(test.t.b, 1, 2.1)]", filterConds: "[]", resultStr: "[[1,1] [2.1,2.1]]", + length: types.UnspecifiedLength, }, { colPos: 0, @@ -911,6 +985,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[gt(test.t.a, 9223372036854775807)]", filterConds: "[]", resultStr: "[(9223372036854775807,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 2, @@ -918,6 +993,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[gt(test.t.c, 111.11111111)]", filterConds: "[]", resultStr: "[[111.111115,+inf]]", + length: types.UnspecifiedLength, }, { colPos: 3, @@ -925,6 +1001,7 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[gt(test.t.d, aaaaaaaaaaaaaa)]", filterConds: "[]", resultStr: "[(\"aaaaaaaaaaaaaa\",+inf]]", + length: types.UnspecifiedLength, }, { colPos: 4, @@ -932,32 +1009,59 @@ func (s *testRangerSuite) TestColumnRange(c *C) { accessConds: "[gt(test.t.e, 18446744073709500000)]", filterConds: "[]", resultStr: "[(18446744073709500000,+inf]]", + length: types.UnspecifiedLength, + }, + { + colPos: 4, + exprStr: `e > -2147483648`, + accessConds: "[gt(test.t.e, -2147483648)]", + filterConds: "[]", + resultStr: "[[0,+inf]]", + length: types.UnspecifiedLength, + }, + { + colPos: 3, + exprStr: "d = 'aab' or d = 'aac'", + accessConds: "[or(eq(test.t.d, aab), eq(test.t.d, aac))]", + filterConds: "[]", + resultStr: "[[\"a\",\"a\"]]", + length: 1, + }, + // This test case cannot be simplified to [1, 3] otherwise the index join will executes wrongly. + { + colPos: 0, + exprStr: "a in (1, 2, 3)", + accessConds: "[in(test.t.a, 1, 2, 3)]", + filterConds: "", + resultStr: "[[1,1] [2,2] [3,3]]", + length: types.UnspecifiedLength, }, } + ctx := context.Background() for _, tt := range tests { sql := "select * from t where " + tt.exprStr - ctx := testKit.Se.(sessionctx.Context) - stmts, err := session.Parse(ctx, sql) + sctx := testKit.Se.(sessionctx.Context) + stmts, err := session.Parse(sctx, sql) c.Assert(err, IsNil, Commentf("error %v, for expr %s", err, tt.exprStr)) c.Assert(stmts, HasLen, 1) - is := domain.GetDomain(ctx).InfoSchema() - err = plannercore.Preprocess(ctx, stmts[0], is) + is := domain.GetDomain(sctx).InfoSchema() + err = plannercore.Preprocess(sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for resolve name, expr %s", err, tt.exprStr)) - p, err := plannercore.BuildLogicalPlan(ctx, stmts[0], is) + p, err := plannercore.BuildLogicalPlan(ctx, sctx, stmts[0], is) c.Assert(err, IsNil, Commentf("error %v, for build plan, expr %s", err, tt.exprStr)) sel := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection) ds, ok := sel.Children()[0].(*plannercore.DataSource) c.Assert(ok, IsTrue, Commentf("expr:%v", tt.exprStr)) conds := make([]expression.Expression, 0, len(sel.Conditions)) for _, cond := range sel.Conditions { - conds = append(conds, expression.PushDownNot(ctx, cond, false)) + conds = append(conds, expression.PushDownNot(sctx, cond, false)) } col := expression.ColInfo2Col(sel.Schema().Columns, ds.TableInfo().Columns[tt.colPos]) c.Assert(col, NotNil) conds = ranger.ExtractAccessConditionsForColumn(conds, col.UniqueID) c.Assert(fmt.Sprintf("%s", conds), Equals, tt.accessConds, Commentf("wrong access conditions for expr: %s", tt.exprStr)) - result, err := ranger.BuildColumnRange(conds, new(stmtctx.StatementContext), col.RetType, types.UnspecifiedLength) + result, err := ranger.BuildColumnRange(conds, new(stmtctx.StatementContext), col.RetType, tt.length) c.Assert(err, IsNil) got := fmt.Sprintf("%v", result) c.Assert(got, Equals, tt.resultStr, Commentf("different for expr %s, col: %v", tt.exprStr, col)) @@ -1005,20 +1109,18 @@ func (s *testRangerSuite) TestCompIndexInExprCorrCol(c *C) { testKit.MustExec("create table t(a int primary key, b int, c int, d int, e int, index idx(b,c,d))") testKit.MustExec("insert into t values(1,1,1,1,2),(2,1,2,1,0)") testKit.MustExec("analyze table t") - testKit.MustQuery("explain select t.e in (select count(*) from t s use index(idx), t t1 where s.b = 1 and s.c in (1, 2) and s.d = t.a and s.a = t1.a) from t").Check(testkit.Rows( - "Projection_11 2.00 root 9_aux_0", - "└─Apply_13 2.00 root left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.e, 7_col_0)", - " ├─TableReader_15 2.00 root data:TableScan_14", - " │ └─TableScan_14 2.00 cop table:t, range:[-inf,+inf], keep order:false", - " └─StreamAgg_20 1.00 root funcs:count(1)", - " └─IndexJoin_32 2.00 root inner join, inner:TableReader_31, outer key:test.s.a, inner key:test.t1.a", - " ├─IndexReader_27 2.00 root index:IndexScan_26", - " │ └─IndexScan_26 2.00 cop table:s, index:b, c, d, range: decided by [eq(test.s.b, 1) in(test.s.c, 1, 2) eq(test.s.d, test.t.a)], keep order:false", - " └─TableReader_31 1.00 root data:TableScan_30", - " └─TableScan_30 1.00 cop table:t1, range: decided by [test.s.a], keep order:false", - )) - testKit.MustQuery("select t.e in (select count(*) from t s use index(idx), t t1 where s.b = 1 and s.c in (1, 2) and s.d = t.a and s.a = t1.a) from t").Check(testkit.Rows( - "1", - "1", - )) + + var input []string + var output []struct { + SQL string + Result []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Result = s.testData.ConvertRowsToStrings(testKit.MustQuery(tt).Rows()) + }) + testKit.MustQuery(tt).Check(testkit.Rows(output[i].Result...)) + } } diff --git a/util/ranger/testdata/ranger_suite_in.json b/util/ranger/testdata/ranger_suite_in.json new file mode 100644 index 0000000000000..a83b5d41d7f0b --- /dev/null +++ b/util/ranger/testdata/ranger_suite_in.json @@ -0,0 +1,9 @@ +[ + { + "name": "TestCompIndexInExprCorrCol", + "cases": [ + "explain select t.e in (select count(*) from t s use index(idx), t t1 where s.b = 1 and s.c in (1, 2) and s.d = t.a and s.a = t1.a) from t", + "select t.e in (select count(*) from t s use index(idx), t t1 where s.b = 1 and s.c in (1, 2) and s.d = t.a and s.a = t1.a) from t" + ] + } +] diff --git a/util/ranger/testdata/ranger_suite_out.json b/util/ranger/testdata/ranger_suite_out.json new file mode 100644 index 0000000000000..ba316e33c7d9c --- /dev/null +++ b/util/ranger/testdata/ranger_suite_out.json @@ -0,0 +1,29 @@ +[ + { + "Name": "TestCompIndexInExprCorrCol", + "Cases": [ + { + "SQL": "explain select t.e in (select count(*) from t s use index(idx), t t1 where s.b = 1 and s.c in (1, 2) and s.d = t.a and s.a = t1.a) from t", + "Result": [ + "Projection_11 2.00 root 9_aux_0", + "└─Apply_13 2.00 root CARTESIAN left outer semi join, inner:StreamAgg_20, other cond:eq(test.t.e, 7_col_0)", + " ├─TableReader_15 2.00 root data:TableScan_14", + " │ └─TableScan_14 2.00 cop table:t, range:[-inf,+inf], keep order:false", + " └─StreamAgg_20 1.00 root funcs:count(1)", + " └─IndexJoin_32 2.00 root inner join, inner:TableReader_31, outer key:test.s.a, inner key:test.t1.a", + " ├─IndexReader_27 2.00 root index:IndexScan_26", + " │ └─IndexScan_26 2.00 cop table:s, index:b, c, d, range: decided by [eq(test.s.b, 1) in(test.s.c, 1, 2) eq(test.s.d, test.t.a)], keep order:false", + " └─TableReader_31 1.00 root data:TableScan_30", + " └─TableScan_30 1.00 cop table:t1, range: decided by [test.s.a], keep order:false" + ] + }, + { + "SQL": "select t.e in (select count(*) from t s use index(idx), t t1 where s.b = 1 and s.c in (1, 2) and s.d = t.a and s.a = t1.a) from t", + "Result": [ + "1", + "1" + ] + } + ] + } +] diff --git a/util/rowDecoder/decoder.go b/util/rowDecoder/decoder.go index e6c91425a2778..68ef00c9961ae 100644 --- a/util/rowDecoder/decoder.go +++ b/util/rowDecoder/decoder.go @@ -14,8 +14,10 @@ package decoder import ( + "sort" "time" + "github.com/pingcap/errors" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" @@ -134,3 +136,114 @@ func (rd *RowDecoder) DecodeAndEvalRowWithMap(ctx sessionctx.Context, handle int } return row, nil } + +// BuildFullDecodeColMap build a map that contains [columnID -> struct{*table.Column, expression.Expression}] from +// indexed columns and all of its depending columns. `genExprProducer` is used to produce a generated expression based on a table.Column. +func BuildFullDecodeColMap(indexedCols []*table.Column, t table.Table, genExprProducer func(*table.Column) (expression.Expression, error)) (map[int64]Column, error) { + pendingCols := make([]*table.Column, len(indexedCols)) + copy(pendingCols, indexedCols) + decodeColMap := make(map[int64]Column, len(pendingCols)) + + for i := 0; i < len(pendingCols); i++ { + col := pendingCols[i] + if _, ok := decodeColMap[col.ID]; ok { + continue // already discovered + } + + if col.IsGenerated() && !col.GeneratedStored { + // Find depended columns and put them into pendingCols. For example, idx(c) with column definition `c int as (a + b)`, + // depended columns of `c` is `a` and `b`, and both of them will be put into the pendingCols, waiting for next traversal. + for _, c := range t.Cols() { + if _, ok := col.Dependences[c.Name.L]; ok { + pendingCols = append(pendingCols, c) + } + } + + e, err := genExprProducer(col) + if err != nil { + return nil, errors.Trace(err) + } + decodeColMap[col.ID] = Column{ + Col: col, + GenExpr: e, + } + } else { + decodeColMap[col.ID] = Column{ + Col: col, + } + } + } + return decodeColMap, nil +} + +// SubstituteGenColsInDecodeColMap substitutes generated columns in every expression +// with non-generated one by looking up decodeColMap. +func SubstituteGenColsInDecodeColMap(decodeColMap map[int64]Column) { + // Sort columns by table.Column.Offset in ascending order. + type Pair struct { + colID int64 + colOffset int + } + var orderedCols []Pair + for colID, col := range decodeColMap { + orderedCols = append(orderedCols, Pair{colID, col.Col.Offset}) + } + sort.Slice(orderedCols, func(i, j int) bool { return orderedCols[i].colOffset < orderedCols[j].colOffset }) + + // Iterate over decodeColMap, the substitution only happens once for each virtual column because + // columns with smaller offset can not refer to those with larger ones. https://dev.mysql.com/doc/refman/5.7/en/create-table-generated-columns.html. + for _, pair := range orderedCols { + colID := pair.colID + decCol := decodeColMap[colID] + if decCol.GenExpr != nil { + decodeColMap[colID] = Column{ + Col: decCol.Col, + GenExpr: substituteGeneratedColumn(decCol.GenExpr, decodeColMap), + } + } else { + decodeColMap[colID] = Column{ + Col: decCol.Col, + } + } + } +} + +// substituteGeneratedColumn substitutes generated columns in an expression with non-generated one by looking up decodeColMap. +func substituteGeneratedColumn(expr expression.Expression, decodeColMap map[int64]Column) expression.Expression { + switch v := expr.(type) { + case *expression.Column: + if c, ok := decodeColMap[v.ID]; c.GenExpr != nil && ok { + return c.GenExpr + } + return v + case *expression.ScalarFunction: + newArgs := make([]expression.Expression, 0, len(v.GetArgs())) + for _, arg := range v.GetArgs() { + newArgs = append(newArgs, substituteGeneratedColumn(arg, decodeColMap)) + } + return expression.NewFunctionInternal(v.GetCtx(), v.FuncName.L, v.RetType, newArgs...) + } + return expr +} + +// RemoveUnusedVirtualCols removes all virtual columns in decodeColMap that cannot found in indexedCols. +func RemoveUnusedVirtualCols(decodeColMap map[int64]Column, indexedCols []*table.Column) { + for colID, decCol := range decodeColMap { + col := decCol.Col + if !col.IsGenerated() || col.GeneratedStored { + continue + } + + found := false + for _, v := range indexedCols { + if v.Offset == col.Offset { + found = true + break + } + } + + if !found { + delete(decodeColMap, colID) + } + } +} diff --git a/util/sqlexec/restricted_sql_executor.go b/util/sqlexec/restricted_sql_executor.go index 70923acc09288..cfbe37e4603cb 100644 --- a/util/sqlexec/restricted_sql_executor.go +++ b/util/sqlexec/restricted_sql_executor.go @@ -78,7 +78,7 @@ type Statement interface { IsReadOnly(vars *variable.SessionVars) bool // RebuildPlan rebuilds the plan of the statement. - RebuildPlan() (schemaVersion int64, err error) + RebuildPlan(ctx context.Context) (schemaVersion int64, err error) } // RecordSet is an abstract result set interface to help get data from Plan. @@ -87,12 +87,26 @@ type RecordSet interface { Fields() []*ast.ResultField // Next reads records into chunk. - Next(ctx context.Context, req *chunk.RecordBatch) error + Next(ctx context.Context, req *chunk.Chunk) error - //NewRecordBatch create a recordBatch. - NewRecordBatch() *chunk.RecordBatch + //NewChunk create a chunk. + NewChunk() *chunk.Chunk // Close closes the underlying iterator, call Next after Close will // restart the iteration. Close() error } + +// MultiQueryNoDelayResult is an interface for one no-delay result for one statement in multi-queries. +type MultiQueryNoDelayResult interface { + // AffectedRows return affected row for one statement in multi-queries. + AffectedRows() uint64 + // LastMessage return last message for one statement in multi-queries. + LastMessage() string + // WarnCount return warn count for one statement in multi-queries. + WarnCount() uint16 + // Status return status when executing one statement in multi-queries. + Status() uint16 + // LastInsertID return last insert id for one statement in multi-queries. + LastInsertID() uint64 +} diff --git a/util/stmtsummary/statement_summary.go b/util/stmtsummary/statement_summary.go new file mode 100644 index 0000000000000..c00b42fe9be05 --- /dev/null +++ b/util/stmtsummary/statement_summary.go @@ -0,0 +1,293 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package stmtsummary + +import ( + "strings" + "sync" + "time" + + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/hack" + "github.com/pingcap/tidb/util/kvcache" +) + +// There're many types of statement summary tables in MySQL, but we have +// only implemented events_statement_summary_by_digest for now. + +// stmtSummaryByDigestKey defines key for stmtSummaryByDigestMap.summaryMap +type stmtSummaryByDigestKey struct { + // Same statements may appear in different schema, but they refer to different tables. + schemaName string + digest string + // TODO: add plan digest + // `hash` is the hash value of this object + hash []byte +} + +// Hash implements SimpleLRUCache.Key +func (key *stmtSummaryByDigestKey) Hash() []byte { + if len(key.hash) == 0 { + key.hash = make([]byte, 0, len(key.schemaName)+len(key.digest)) + key.hash = append(key.hash, hack.Slice(key.digest)...) + key.hash = append(key.hash, hack.Slice(strings.ToLower(key.schemaName))...) + } + return key.hash +} + +// stmtSummaryByDigestMap is a LRU cache that stores statement summaries. +type stmtSummaryByDigestMap struct { + // It's rare to read concurrently, so RWMutex is not needed. + sync.Mutex + summaryMap *kvcache.SimpleLRUCache + + // enabledWrapper encapsulates variables needed to judge whether statement summary is enabled. + enabledWrapper struct { + sync.RWMutex + // enabled indicates whether statement summary is enabled in current server. + sessionEnabled string + // setInSession indicates whether statement summary has been set in any session. + globalEnabled string + } +} + +// StmtSummaryByDigestMap is a global map containing all statement summaries. +var StmtSummaryByDigestMap = newStmtSummaryByDigestMap() + +// stmtSummaryByDigest is the summary for each type of statements. +type stmtSummaryByDigest struct { + // It's rare to read concurrently, so RWMutex is not needed. + sync.Mutex + schemaName string + digest string + normalizedSQL string + sampleSQL string + execCount uint64 + sumLatency uint64 + maxLatency uint64 + minLatency uint64 + sumAffectedRows uint64 + // Number of rows sent to client. + sumSentRows uint64 + // The first time this type of SQL executes. + firstSeen time.Time + // The last time this type of SQL executes. + lastSeen time.Time +} + +// StmtExecInfo records execution information of each statement. +type StmtExecInfo struct { + SchemaName string + OriginalSQL string + NormalizedSQL string + Digest string + TotalLatency uint64 + AffectedRows uint64 + // Number of rows sent to client. + SentRows uint64 + StartTime time.Time +} + +// newStmtSummaryByDigestMap creates an empty stmtSummaryByDigestMap. +func newStmtSummaryByDigestMap() *stmtSummaryByDigestMap { + maxStmtCount := config.GetGlobalConfig().StmtSummary.MaxStmtCount + ssMap := &stmtSummaryByDigestMap{ + summaryMap: kvcache.NewSimpleLRUCache(maxStmtCount, 0, 0), + } + // enabledWrapper.defaultEnabled will be initialized in package variable. + ssMap.enabledWrapper.sessionEnabled = "" + ssMap.enabledWrapper.globalEnabled = "" + return ssMap +} + +// newStmtSummaryByDigest creates a stmtSummaryByDigest from StmtExecInfo +func newStmtSummaryByDigest(sei *StmtExecInfo) *stmtSummaryByDigest { + // Trim SQL to size MaxSQLLength + maxSQLLength := config.GetGlobalConfig().StmtSummary.MaxSQLLength + normalizedSQL := sei.NormalizedSQL + if len(normalizedSQL) > int(maxSQLLength) { + normalizedSQL = normalizedSQL[:maxSQLLength] + } + sampleSQL := sei.OriginalSQL + if len(sampleSQL) > int(maxSQLLength) { + sampleSQL = sampleSQL[:maxSQLLength] + } + + return &stmtSummaryByDigest{ + schemaName: sei.SchemaName, + digest: sei.Digest, + normalizedSQL: normalizedSQL, + sampleSQL: sampleSQL, + execCount: 1, + sumLatency: sei.TotalLatency, + maxLatency: sei.TotalLatency, + minLatency: sei.TotalLatency, + sumAffectedRows: sei.AffectedRows, + sumSentRows: sei.SentRows, + firstSeen: sei.StartTime, + lastSeen: sei.StartTime, + } +} + +// Add a StmtExecInfo to stmtSummary +func (ssbd *stmtSummaryByDigest) add(sei *StmtExecInfo) { + ssbd.Lock() + + ssbd.sumLatency += sei.TotalLatency + ssbd.execCount++ + if sei.TotalLatency > ssbd.maxLatency { + ssbd.maxLatency = sei.TotalLatency + } + if sei.TotalLatency < ssbd.minLatency { + ssbd.minLatency = sei.TotalLatency + } + ssbd.sumAffectedRows += sei.AffectedRows + ssbd.sumSentRows += sei.SentRows + if sei.StartTime.Before(ssbd.firstSeen) { + ssbd.firstSeen = sei.StartTime + } + if ssbd.lastSeen.Before(sei.StartTime) { + ssbd.lastSeen = sei.StartTime + } + + ssbd.Unlock() +} + +// AddStatement adds a statement to StmtSummaryByDigestMap. +func (ssMap *stmtSummaryByDigestMap) AddStatement(sei *StmtExecInfo) { + key := &stmtSummaryByDigestKey{ + schemaName: sei.SchemaName, + digest: sei.Digest, + } + + ssMap.Lock() + // Check again. Statements could be added before disabling the flag and after Clear() + if !ssMap.Enabled() { + ssMap.Unlock() + return + } + value, ok := ssMap.summaryMap.Get(key) + if !ok { + newSummary := newStmtSummaryByDigest(sei) + ssMap.summaryMap.Put(key, newSummary) + } + ssMap.Unlock() + + // Lock a single entry, not the whole cache. + if ok { + value.(*stmtSummaryByDigest).add(sei) + } +} + +// Clear removes all statement summaries. +func (ssMap *stmtSummaryByDigestMap) Clear() { + ssMap.Lock() + ssMap.summaryMap.DeleteAll() + ssMap.Unlock() +} + +// ToDatum converts statement summary to Datum +func (ssMap *stmtSummaryByDigestMap) ToDatum() [][]types.Datum { + ssMap.Lock() + values := ssMap.summaryMap.Values() + ssMap.Unlock() + + rows := make([][]types.Datum, 0, len(values)) + for _, value := range values { + summary := value.(*stmtSummaryByDigest) + summary.Lock() + record := types.MakeDatums( + summary.schemaName, + summary.digest, + summary.normalizedSQL, + summary.execCount, + summary.sumLatency, + summary.maxLatency, + summary.minLatency, + summary.sumLatency/summary.execCount, // AVG_LATENCY + summary.sumAffectedRows, + types.Time{Time: types.FromGoTime(summary.firstSeen), Type: mysql.TypeTimestamp}, + types.Time{Time: types.FromGoTime(summary.lastSeen), Type: mysql.TypeTimestamp}, + summary.sampleSQL, + ) + summary.Unlock() + rows = append(rows, record) + } + + return rows +} + +// SetEnabled enables or disables statement summary in global(cluster) or session(server) scope. +func (ssMap *stmtSummaryByDigestMap) SetEnabled(value string, inSession bool) { + value = ssMap.normalizeEnableValue(value) + + ssMap.enabledWrapper.Lock() + if inSession { + ssMap.enabledWrapper.sessionEnabled = value + } else { + ssMap.enabledWrapper.globalEnabled = value + } + sessionEnabled := ssMap.enabledWrapper.sessionEnabled + globalEnabled := ssMap.enabledWrapper.globalEnabled + ssMap.enabledWrapper.Unlock() + + // Clear all summaries once statement summary is disabled. + var needClear bool + if ssMap.isSet(sessionEnabled) { + needClear = !ssMap.isEnabled(sessionEnabled) + } else { + needClear = !ssMap.isEnabled(globalEnabled) + } + if needClear { + ssMap.Clear() + } +} + +// Enabled returns whether statement summary is enabled. +func (ssMap *stmtSummaryByDigestMap) Enabled() bool { + ssMap.enabledWrapper.RLock() + var enabled bool + if ssMap.isSet(ssMap.enabledWrapper.sessionEnabled) { + enabled = ssMap.isEnabled(ssMap.enabledWrapper.sessionEnabled) + } else { + enabled = ssMap.isEnabled(ssMap.enabledWrapper.globalEnabled) + } + ssMap.enabledWrapper.RUnlock() + return enabled +} + +// normalizeEnableValue converts 'ON' to '1' and 'OFF' to '0' +func (ssMap *stmtSummaryByDigestMap) normalizeEnableValue(value string) string { + switch { + case strings.EqualFold(value, "ON"): + return "1" + case strings.EqualFold(value, "OFF"): + return "0" + default: + return value + } +} + +// isEnabled converts a string value to bool. +// 1 indicates true, 0 or '' indicates false. +func (ssMap *stmtSummaryByDigestMap) isEnabled(value string) bool { + return value == "1" +} + +// isSet judges whether the variable is set. +func (ssMap *stmtSummaryByDigestMap) isSet(value string) bool { + return value != "" +} diff --git a/util/stmtsummary/statement_summary_test.go b/util/stmtsummary/statement_summary_test.go new file mode 100644 index 0000000000000..5d742b1d08396 --- /dev/null +++ b/util/stmtsummary/statement_summary_test.go @@ -0,0 +1,378 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package stmtsummary + +import ( + "fmt" + "strings" + "sync" + "testing" + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/config" + "github.com/pingcap/tidb/types" +) + +var _ = Suite(&testStmtSummarySuite{}) + +type testStmtSummarySuite struct { + ssMap *stmtSummaryByDigestMap +} + +func (s *testStmtSummarySuite) SetUpSuite(c *C) { + s.ssMap = newStmtSummaryByDigestMap() + s.ssMap.SetEnabled("1", false) +} + +func TestT(t *testing.T) { + CustomVerboseFlag = true + TestingT(t) +} + +// Test stmtSummaryByDigest.AddStatement +func (s *testStmtSummarySuite) TestAddStatement(c *C) { + s.ssMap.Clear() + + // First statement + stmtExecInfo1 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: "original_sql1", + NormalizedSQL: "normalized_sql", + Digest: "digest", + TotalLatency: 10000, + AffectedRows: 100, + SentRows: 100, + StartTime: time.Date(2019, 1, 1, 10, 10, 10, 10, time.UTC), + } + key := &stmtSummaryByDigestKey{ + schemaName: stmtExecInfo1.SchemaName, + digest: stmtExecInfo1.Digest, + } + expectedSummary := stmtSummaryByDigest{ + schemaName: stmtExecInfo1.SchemaName, + digest: stmtExecInfo1.Digest, + normalizedSQL: stmtExecInfo1.NormalizedSQL, + sampleSQL: stmtExecInfo1.OriginalSQL, + execCount: 1, + sumLatency: stmtExecInfo1.TotalLatency, + maxLatency: stmtExecInfo1.TotalLatency, + minLatency: stmtExecInfo1.TotalLatency, + sumAffectedRows: stmtExecInfo1.AffectedRows, + sumSentRows: stmtExecInfo1.SentRows, + firstSeen: stmtExecInfo1.StartTime, + lastSeen: stmtExecInfo1.StartTime, + } + + s.ssMap.AddStatement(stmtExecInfo1) + summary, ok := s.ssMap.summaryMap.Get(key) + c.Assert(ok, IsTrue) + c.Assert(*summary.(*stmtSummaryByDigest) == expectedSummary, IsTrue) + + // Second statement + stmtExecInfo2 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: "original_sql2", + NormalizedSQL: "normalized_sql", + Digest: "digest", + TotalLatency: 50000, + AffectedRows: 500, + SentRows: 500, + StartTime: time.Date(2019, 1, 1, 10, 10, 20, 10, time.UTC), + } + expectedSummary.execCount++ + expectedSummary.sumLatency += stmtExecInfo2.TotalLatency + expectedSummary.maxLatency = stmtExecInfo2.TotalLatency + expectedSummary.sumAffectedRows += stmtExecInfo2.AffectedRows + expectedSummary.sumSentRows += stmtExecInfo2.SentRows + expectedSummary.lastSeen = stmtExecInfo2.StartTime + + s.ssMap.AddStatement(stmtExecInfo2) + summary, ok = s.ssMap.summaryMap.Get(key) + c.Assert(ok, IsTrue) + c.Assert(*summary.(*stmtSummaryByDigest) == expectedSummary, IsTrue) + + // Third statement + stmtExecInfo3 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: "original_sql3", + NormalizedSQL: "normalized_sql", + Digest: "digest", + TotalLatency: 1000, + AffectedRows: 10, + SentRows: 10, + StartTime: time.Date(2019, 1, 1, 10, 10, 0, 10, time.UTC), + } + expectedSummary.execCount++ + expectedSummary.sumLatency += stmtExecInfo3.TotalLatency + expectedSummary.minLatency = stmtExecInfo3.TotalLatency + expectedSummary.sumAffectedRows += stmtExecInfo3.AffectedRows + expectedSummary.sumSentRows += stmtExecInfo3.SentRows + expectedSummary.firstSeen = stmtExecInfo3.StartTime + + s.ssMap.AddStatement(stmtExecInfo3) + summary, ok = s.ssMap.summaryMap.Get(key) + c.Assert(ok, IsTrue) + c.Assert(*summary.(*stmtSummaryByDigest) == expectedSummary, IsTrue) + + // Fourth statement that in a different schema + stmtExecInfo4 := &StmtExecInfo{ + SchemaName: "schema_name2", + OriginalSQL: "original_sql1", + NormalizedSQL: "normalized_sql", + Digest: "digest", + TotalLatency: 1000, + AffectedRows: 10, + SentRows: 10, + StartTime: time.Date(2019, 1, 1, 10, 10, 0, 10, time.UTC), + } + key = &stmtSummaryByDigestKey{ + schemaName: stmtExecInfo4.SchemaName, + digest: stmtExecInfo4.Digest, + } + + s.ssMap.AddStatement(stmtExecInfo4) + c.Assert(s.ssMap.summaryMap.Size(), Equals, 2) + _, ok = s.ssMap.summaryMap.Get(key) + c.Assert(ok, IsTrue) + + // Fifth statement that has a different digest + stmtExecInfo5 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: "original_sql1", + NormalizedSQL: "normalized_sql2", + Digest: "digest2", + TotalLatency: 1000, + AffectedRows: 10, + SentRows: 10, + StartTime: time.Date(2019, 1, 1, 10, 10, 0, 10, time.UTC), + } + key = &stmtSummaryByDigestKey{ + schemaName: stmtExecInfo5.SchemaName, + digest: stmtExecInfo5.Digest, + } + + s.ssMap.AddStatement(stmtExecInfo5) + c.Assert(s.ssMap.summaryMap.Size(), Equals, 3) + _, ok = s.ssMap.summaryMap.Get(key) + c.Assert(ok, IsTrue) +} + +func match(c *C, row []types.Datum, expected ...interface{}) { + c.Assert(len(row), Equals, len(expected)) + for i := range row { + got := fmt.Sprintf("%v", row[i].GetValue()) + need := fmt.Sprintf("%v", expected[i]) + c.Assert(got, Equals, need) + } +} + +// Test stmtSummaryByDigest.ToDatum +func (s *testStmtSummarySuite) TestToDatum(c *C) { + s.ssMap.Clear() + + stmtExecInfo1 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: "original_sql1", + NormalizedSQL: "normalized_sql", + Digest: "digest", + TotalLatency: 10000, + AffectedRows: 100, + SentRows: 100, + StartTime: time.Date(2019, 1, 1, 10, 10, 10, 10, time.UTC), + } + s.ssMap.AddStatement(stmtExecInfo1) + datums := s.ssMap.ToDatum() + c.Assert(len(datums), Equals, 1) + t := types.Time{Time: types.FromGoTime(stmtExecInfo1.StartTime), Type: mysql.TypeTimestamp} + match(c, datums[0], stmtExecInfo1.SchemaName, stmtExecInfo1.Digest, stmtExecInfo1.NormalizedSQL, + 1, stmtExecInfo1.TotalLatency, stmtExecInfo1.TotalLatency, stmtExecInfo1.TotalLatency, stmtExecInfo1.TotalLatency, + stmtExecInfo1.AffectedRows, t, t, stmtExecInfo1.OriginalSQL) +} + +// Test AddStatement and ToDatum parallel +func (s *testStmtSummarySuite) TestAddStatementParallel(c *C) { + s.ssMap.Clear() + + threads := 8 + loops := 32 + wg := sync.WaitGroup{} + wg.Add(threads) + + addStmtFunc := func() { + defer wg.Done() + stmtExecInfo1 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: "original_sql1", + NormalizedSQL: "normalized_sql", + Digest: "digest", + TotalLatency: 10000, + AffectedRows: 100, + SentRows: 100, + StartTime: time.Date(2019, 1, 1, 10, 10, 10, 10, time.UTC), + } + + // Add 32 times with different digest + for i := 0; i < loops; i++ { + stmtExecInfo1.Digest = fmt.Sprintf("digest%d", i) + s.ssMap.AddStatement(stmtExecInfo1) + } + + // There would be 32 summaries + datums := s.ssMap.ToDatum() + c.Assert(len(datums), Equals, loops) + } + + for i := 0; i < threads; i++ { + go addStmtFunc() + } + wg.Wait() + + datums := s.ssMap.ToDatum() + c.Assert(len(datums), Equals, loops) +} + +// Test max number of statement count. +func (s *testStmtSummarySuite) TestMaxStmtCount(c *C) { + s.ssMap.Clear() + + stmtExecInfo1 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: "original_sql1", + NormalizedSQL: "normalized_sql", + Digest: "digest", + TotalLatency: 10000, + AffectedRows: 100, + SentRows: 100, + StartTime: time.Date(2019, 1, 1, 10, 10, 10, 10, time.UTC), + } + + maxStmtCount := config.GetGlobalConfig().StmtSummary.MaxStmtCount + + // 1000 digests + loops := int(maxStmtCount) * 10 + for i := 0; i < loops; i++ { + stmtExecInfo1.Digest = fmt.Sprintf("digest%d", i) + s.ssMap.AddStatement(stmtExecInfo1) + } + + // Summary count should be MaxStmtCount + sm := s.ssMap.summaryMap + c.Assert(sm.Size(), Equals, int(maxStmtCount)) + + // LRU cache should work + for i := loops - int(maxStmtCount); i < loops; i++ { + key := &stmtSummaryByDigestKey{ + schemaName: stmtExecInfo1.SchemaName, + digest: fmt.Sprintf("digest%d", i), + } + _, ok := sm.Get(key) + c.Assert(ok, IsTrue) + } +} + +// Test max length of normalized and sample SQL. +func (s *testStmtSummarySuite) TestMaxSQLLength(c *C) { + s.ssMap.Clear() + + // Create a long SQL + maxSQLLength := config.GetGlobalConfig().StmtSummary.MaxSQLLength + length := int(maxSQLLength) * 10 + str := strings.Repeat("a", length) + + stmtExecInfo1 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: str, + NormalizedSQL: str, + Digest: "digest", + TotalLatency: 10000, + AffectedRows: 100, + SentRows: 100, + StartTime: time.Date(2019, 1, 1, 10, 10, 10, 10, time.UTC), + } + + s.ssMap.AddStatement(stmtExecInfo1) + key := &stmtSummaryByDigestKey{ + schemaName: stmtExecInfo1.SchemaName, + digest: stmtExecInfo1.Digest, + } + value, ok := s.ssMap.summaryMap.Get(key) + c.Assert(ok, IsTrue) + // Length of normalizedSQL and sampleSQL should be maxSQLLength + summary := value.(*stmtSummaryByDigest) + c.Assert(len(summary.normalizedSQL), Equals, int(maxSQLLength)) + c.Assert(len(summary.sampleSQL), Equals, int(maxSQLLength)) +} + +// Test setting EnableStmtSummary to 0 +func (s *testStmtSummarySuite) TestDisableStmtSummary(c *C) { + s.ssMap.Clear() + // Set false in global scope, it should work. + s.ssMap.SetEnabled("0", false) + + stmtExecInfo1 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: "original_sql1", + NormalizedSQL: "normalized_sql", + Digest: "digest", + TotalLatency: 10000, + AffectedRows: 100, + SentRows: 100, + StartTime: time.Date(2019, 1, 1, 10, 10, 10, 10, time.UTC), + } + + s.ssMap.AddStatement(stmtExecInfo1) + datums := s.ssMap.ToDatum() + c.Assert(len(datums), Equals, 0) + + // Set true in session scope, it will overwrite global scope. + s.ssMap.SetEnabled("1", true) + + s.ssMap.AddStatement(stmtExecInfo1) + datums = s.ssMap.ToDatum() + c.Assert(len(datums), Equals, 1) + + // Set false in global scope, it shouldn't work. + s.ssMap.SetEnabled("0", false) + + stmtExecInfo2 := &StmtExecInfo{ + SchemaName: "schema_name", + OriginalSQL: "original_sql2", + NormalizedSQL: "normalized_sql2", + Digest: "digest2", + TotalLatency: 50000, + AffectedRows: 500, + SentRows: 500, + StartTime: time.Date(2019, 1, 1, 10, 10, 20, 10, time.UTC), + } + s.ssMap.AddStatement(stmtExecInfo2) + datums = s.ssMap.ToDatum() + c.Assert(len(datums), Equals, 2) + + // Unset in session scope + s.ssMap.SetEnabled("", true) + s.ssMap.AddStatement(stmtExecInfo2) + datums = s.ssMap.ToDatum() + c.Assert(len(datums), Equals, 0) + + // Unset in global scope + s.ssMap.SetEnabled("", false) + s.ssMap.AddStatement(stmtExecInfo1) + datums = s.ssMap.ToDatum() + c.Assert(len(datums), Equals, 0) + + // Set back + s.ssMap.SetEnabled("1", false) +} diff --git a/util/stringutil/string_util.go b/util/stringutil/string_util.go index 6e69f61d2bd52..b6e472b5a40e4 100644 --- a/util/stringutil/string_util.go +++ b/util/stringutil/string_util.go @@ -252,23 +252,18 @@ func Copy(src string) string { return string(hack.Slice(src)) } -// stringerFunc defines string func implement fmt.Stringer. -type stringerFunc func() string +// StringerFunc defines string func implement fmt.Stringer. +type StringerFunc func() string // String implements fmt.Stringer -func (l stringerFunc) String() string { +func (l StringerFunc) String() string { return l() } // MemoizeStr returns memoized version of stringFunc. func MemoizeStr(l func() string) fmt.Stringer { - var result string - return stringerFunc(func() string { - if result != "" { - return result - } - result = l() - return result + return StringerFunc(func() string { + return l() }) } diff --git a/util/testkit/testkit.go b/util/testkit/testkit.go index c9e58d0cf9f58..cf9aa5b0dd0fe 100644 --- a/util/testkit/testkit.go +++ b/util/testkit/testkit.go @@ -23,6 +23,8 @@ import ( "github.com/pingcap/check" "github.com/pingcap/errors" + "github.com/pingcap/parser/model" + "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/util/sqlexec" @@ -190,6 +192,20 @@ func (tk *TestKit) MustIndexLookup(sql string, args ...interface{}) *Result { return tk.MustQuery(sql, args...) } +// MustTableDual checks whether the plan for the sql is TableDual. +func (tk *TestKit) MustTableDual(sql string, args ...interface{}) *Result { + rs := tk.MustQuery("explain "+sql, args...) + hasTableDual := false + for i := range rs.rows { + if strings.Contains(rs.rows[i][0], "TableDual") { + hasTableDual = true + break + } + } + tk.c.Assert(hasTableDual, check.IsTrue) + return tk.MustQuery(sql, args...) +} + // MustPointGet checks whether the plan for the sql is Point_Get. func (tk *TestKit) MustPointGet(sql string, args ...interface{}) *Result { rs := tk.MustQuery("explain "+sql, args...) @@ -234,12 +250,16 @@ func (tk *TestKit) ResultSetToResult(rs sqlexec.RecordSet, comment check.Comment return tk.ResultSetToResultWithCtx(context.Background(), rs, comment) } -// ResultSetToResultWithCtx converts sqlexec.RecordSet to testkit.Result. -func (tk *TestKit) ResultSetToResultWithCtx(ctx context.Context, rs sqlexec.RecordSet, comment check.CommentInterface) *Result { - rows, err := session.GetRows4Test(ctx, tk.Se, rs) - tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment) +// ResultSetToStringSlice changes the RecordSet to [][]string. +func ResultSetToStringSlice(ctx context.Context, s session.Session, rs sqlexec.RecordSet) ([][]string, error) { + rows, err := session.GetRows4Test(ctx, s, rs) + if err != nil { + return nil, err + } err = rs.Close() - tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment) + if err != nil { + return nil, err + } sRows := make([][]string, len(rows)) for i := range rows { row := rows[i] @@ -250,11 +270,20 @@ func (tk *TestKit) ResultSetToResultWithCtx(ctx context.Context, rs sqlexec.Reco } else { d := row.GetDatum(j, &rs.Fields()[j].Column.FieldType) iRow[j], err = d.ToString() - tk.c.Assert(err, check.IsNil) + if err != nil { + return nil, err + } } } sRows[i] = iRow } + return sRows, nil +} + +// ResultSetToResultWithCtx converts sqlexec.RecordSet to testkit.Result. +func (tk *TestKit) ResultSetToResultWithCtx(ctx context.Context, rs sqlexec.RecordSet, comment check.CommentInterface) *Result { + sRows, err := ResultSetToStringSlice(ctx, tk.Se, rs) + tk.c.Check(err, check.IsNil, comment) return &Result{rows: sRows, c: tk.c, comment: comment} } @@ -262,3 +291,12 @@ func (tk *TestKit) ResultSetToResultWithCtx(ctx context.Context, rs sqlexec.Reco func Rows(args ...string) [][]interface{} { return testutil.RowsWithSep(" ", args...) } + +// GetTableID gets table ID by name. +func (tk *TestKit) GetTableID(tableName string) int64 { + dom := domain.GetDomain(tk.Se) + is := dom.InfoSchema() + tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr(tableName)) + tk.c.Assert(err, check.IsNil) + return tbl.Meta().ID +} diff --git a/util/testutil/testutil.go b/util/testutil/testutil.go index 003756b0fd405..9e4928c6df696 100644 --- a/util/testutil/testutil.go +++ b/util/testutil/testutil.go @@ -14,10 +14,20 @@ package testutil import ( + "bytes" + "encoding/json" + "flag" "fmt" + "io/ioutil" + "os" + "path/filepath" + "reflect" + "regexp" + "runtime" "strings" "github.com/pingcap/check" + "github.com/pingcap/errors" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" ) @@ -108,3 +118,181 @@ func RowsWithSep(sep string, args ...string) [][]interface{} { } return rows } + +// record is a flag used for generate test result. +var record bool + +func init() { + flag.BoolVar(&record, "record", false, "to generate test result") +} + +type testCases struct { + Name string + Cases *json.RawMessage // For delayed parse. + decodedOut interface{} // For generate output. +} + +// TestData stores all the data of a test suite. +type TestData struct { + input []testCases + output []testCases + filePathPrefix string + funcMap map[string]int +} + +// LoadTestSuiteData loads test suite data from file. +func LoadTestSuiteData(dir, suiteName string) (res TestData, err error) { + res.filePathPrefix = filepath.Join(dir, suiteName) + res.input, err = loadTestSuiteCases(fmt.Sprintf("%s_in.json", res.filePathPrefix)) + if err != nil { + return res, err + } + if record { + res.output = make([]testCases, len(res.input), len(res.input)) + for i := range res.input { + res.output[i].Name = res.input[i].Name + } + } else { + res.output, err = loadTestSuiteCases(fmt.Sprintf("%s_out.json", res.filePathPrefix)) + if err != nil { + return res, err + } + if len(res.input) != len(res.output) { + return res, errors.New(fmt.Sprintf("Number of test input cases %d does not match test output cases %d", len(res.input), len(res.output))) + } + } + res.funcMap = make(map[string]int, len(res.input)) + for i, test := range res.input { + res.funcMap[test.Name] = i + if test.Name != res.output[i].Name { + return res, errors.New(fmt.Sprintf("Input name of the %d-case %s does not match output %s", i, test.Name, res.output[i].Name)) + } + } + return res, nil +} + +func loadTestSuiteCases(filePath string) (res []testCases, err error) { + var jsonFile *os.File + jsonFile, err = os.Open(filePath) + if err != nil { + return res, err + } + defer func() { + err1 := jsonFile.Close() + if err == nil { + err = err1 + } + }() + byteValue, err := ioutil.ReadAll(jsonFile) + if err != nil { + return res, err + } + // Remove comments, since they are not allowed in json. + re := regexp.MustCompile("(?s)//.*?\n") + err = json.Unmarshal(re.ReplaceAll(byteValue, nil), &res) + return res, err +} + +// GetTestCasesByName gets the test cases for a test function by its name. +func (t *TestData) GetTestCasesByName(caseName string, c *check.C, in interface{}, out interface{}) { + casesIdx, ok := t.funcMap[caseName] + c.Assert(ok, check.IsTrue, check.Commentf("Must get test %s", caseName)) + err := json.Unmarshal(*t.input[casesIdx].Cases, in) + c.Assert(err, check.IsNil) + if !record { + err = json.Unmarshal(*t.output[casesIdx].Cases, out) + c.Assert(err, check.IsNil) + } else { + // Init for generate output file. + inputLen := reflect.ValueOf(in).Elem().Len() + v := reflect.ValueOf(out).Elem() + if v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen)) + } + } + t.output[casesIdx].decodedOut = out +} + +// GetTestCases gets the test cases for a test function. +func (t *TestData) GetTestCases(c *check.C, in interface{}, out interface{}) { + // Extract caller's name. + pc, _, _, ok := runtime.Caller(1) + c.Assert(ok, check.IsTrue) + details := runtime.FuncForPC(pc) + funcNameIdx := strings.LastIndex(details.Name(), ".") + funcName := details.Name()[funcNameIdx+1:] + + casesIdx, ok := t.funcMap[funcName] + c.Assert(ok, check.IsTrue, check.Commentf("Must get test %s", funcName)) + err := json.Unmarshal(*t.input[casesIdx].Cases, in) + c.Assert(err, check.IsNil) + if !record { + err = json.Unmarshal(*t.output[casesIdx].Cases, out) + c.Assert(err, check.IsNil) + } else { + // Init for generate output file. + inputLen := reflect.ValueOf(in).Elem().Len() + v := reflect.ValueOf(out).Elem() + if v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen)) + } + } + t.output[casesIdx].decodedOut = out +} + +// OnRecord execute the function to update result. +func (t *TestData) OnRecord(updateFunc func()) { + if record { + updateFunc() + } +} + +// ConvertRowsToStrings converts [][]interface{} to []string. +func (t *TestData) ConvertRowsToStrings(rows [][]interface{}) (rs []string) { + for _, row := range rows { + s := fmt.Sprintf("%v", row) + // Trim the leftmost `[` and rightmost `]`. + s = s[1 : len(s)-1] + rs = append(rs, s) + } + return rs +} + +// GenerateOutputIfNeeded generate the output file. +func (t *TestData) GenerateOutputIfNeeded() (err error) { + if !record { + return nil + } + + buf := new(bytes.Buffer) + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) + enc.SetIndent("", " ") + for i, test := range t.output { + err := enc.Encode(test.decodedOut) + if err != nil { + return err + } + res := make([]byte, len(buf.Bytes()), len(buf.Bytes())) + copy(res, buf.Bytes()) + buf.Reset() + rm := json.RawMessage(res) + t.output[i].Cases = &rm + } + err = enc.Encode(t.output) + if err != nil { + return err + } + file, err := os.Create(fmt.Sprintf("%s_out.json", t.filePathPrefix)) + if err != nil { + return err + } + defer func() { + err1 := file.Close() + if err == nil { + err = err1 + } + }() + _, err = file.Write(buf.Bytes()) + return err +} diff --git a/util/texttree/texttree.go b/util/texttree/texttree.go new file mode 100644 index 0000000000000..0d910fb6f2e06 --- /dev/null +++ b/util/texttree/texttree.go @@ -0,0 +1,80 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package texttree + +const ( + // TreeBody indicates the current operator sub-tree is not finished, still + // has child operators to be attached on. + TreeBody = '│' + // TreeMiddleNode indicates this operator is not the last child of the + // current sub-tree rooted by its parent. + TreeMiddleNode = '├' + // TreeLastNode indicates this operator is the last child of the current + // sub-tree rooted by its parent. + TreeLastNode = '└' + // TreeGap is used to represent the gap between the branches of the tree. + TreeGap = ' ' + // TreeNodeIdentifier is used to replace the treeGap once we need to attach + // a node to a sub-tree. + TreeNodeIdentifier = '─' +) + +// Indent4Child appends more blank to the `indent` string +func Indent4Child(indent string, isLastChild bool) string { + if !isLastChild { + return string(append([]rune(indent), TreeBody, TreeGap)) + } + + // If the current node is the last node of the current operator tree, we + // need to end this sub-tree by changing the closest treeBody to a treeGap. + indentBytes := []rune(indent) + for i := len(indentBytes) - 1; i >= 0; i-- { + if indentBytes[i] == TreeBody { + indentBytes[i] = TreeGap + break + } + } + + return string(append(indentBytes, TreeBody, TreeGap)) +} + +// PrettyIdentifier returns a pretty identifier which contains indent and tree node hierarchy indicator +func PrettyIdentifier(id, indent string, isLastChild bool) string { + if len(indent) == 0 { + return id + } + + indentBytes := []rune(indent) + for i := len(indentBytes) - 1; i >= 0; i-- { + if indentBytes[i] != TreeBody { + continue + } + + // Here we attach a new node to the current sub-tree by changing + // the closest treeBody to a: + // 1. treeLastNode, if this operator is the last child. + // 2. treeMiddleNode, if this operator is not the last child.. + if isLastChild { + indentBytes[i] = TreeLastNode + } else { + indentBytes[i] = TreeMiddleNode + } + break + } + + // Replace the treeGap between the treeBody and the node to a + // treeNodeIdentifier. + indentBytes[len(indentBytes)-1] = TreeNodeIdentifier + return string(indentBytes) + id +} diff --git a/util/texttree/texttree_test.go b/util/texttree/texttree_test.go new file mode 100644 index 0000000000000..b3e591efd85b0 --- /dev/null +++ b/util/texttree/texttree_test.go @@ -0,0 +1,44 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package texttree_test + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/texttree" +) + +type texttreeSuite struct{} + +var _ = Suite(&texttreeSuite{}) + +func TestT(t *testing.T) { + CustomVerboseFlag = true + TestingT(t) +} + +func (s *texttreeSuite) TestPrettyIdentifier(c *C) { + c.Assert(texttree.PrettyIdentifier("test", "", false), Equals, "test") + c.Assert(texttree.PrettyIdentifier("test", " │ ", false), Equals, " ├ ─test") + c.Assert(texttree.PrettyIdentifier("test", "\t\t│\t\t", false), Equals, "\t\t├\t─test") + c.Assert(texttree.PrettyIdentifier("test", " │ ", true), Equals, " └ ─test") + c.Assert(texttree.PrettyIdentifier("test", "\t\t│\t\t", true), Equals, "\t\t└\t─test") +} + +func (s *texttreeSuite) TestIndent4Child(c *C) { + c.Assert(texttree.Indent4Child(" ", false), Equals, " │ ") + c.Assert(texttree.Indent4Child(" ", true), Equals, " │ ") + c.Assert(texttree.Indent4Child(" │ ", true), Equals, " │ ") +}