diff --git a/CHANGELOG.md b/CHANGELOG.md index d14767c4..dcee8393 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,8 @@ - New #348: Realize `Schema::loadResultColumn()` method (@Tigrov) - New #354: Add `FOR` clause to query (@vjik) - New #355: Use `DateTimeColumn` class for datetime column types (@Tigrov) +- New #356: Implement `DMLQueryBuilder::upsertWithReturningPks()` method (@Tigrov) +- Enh #356: Refactor `Command::insertWithReturningPks()` and `DMLQueryBuilder::upsert()` methods (@Tigrov) ## 1.2.0 March 21, 2024 diff --git a/src/Command.php b/src/Command.php index 152c5554..5efe7b16 100644 --- a/src/Command.php +++ b/src/Command.php @@ -8,10 +8,14 @@ use Yiisoft\Db\Driver\Pdo\AbstractPdoCommand; use Yiisoft\Db\Exception\Exception; use Yiisoft\Db\Exception\InvalidArgumentException; +use Yiisoft\Db\Exception\NotSupportedException; +use Yiisoft\Db\Query\QueryInterface; +use Yiisoft\Db\Schema\Column\ColumnInterface; use function array_pop; use function count; use function ltrim; +use function mb_strlen; use function preg_match_all; use function strpos; @@ -21,27 +25,45 @@ */ final class Command extends AbstractPdoCommand { - public function insertWithReturningPks(string $table, array $columns): array|false + public function insertWithReturningPks(string $table, array|QueryInterface $columns): array|false { $params = []; $sql = $this->db->getQueryBuilder()->insert($table, $columns, $params); $this->setSql($sql)->bindValues($params); - if (!$this->execute()) { + if ($this->execute() === 0) { return false; } $tableSchema = $this->db->getSchema()->getTableSchema($table); $tablePrimaryKeys = $tableSchema?->getPrimaryKey() ?? []; + if (empty($tablePrimaryKeys)) { + return []; + } + + if ($columns instanceof QueryInterface) { + throw new NotSupportedException(__METHOD__ . '() not supported for QueryInterface by SQLite.'); + } + $result = []; + + /** @var TableSchema $tableSchema */ foreach ($tablePrimaryKeys as $name) { - if ($tableSchema?->getColumn($name)?->isAutoIncrement()) { - $result[$name] = $this->db->getLastInsertId((string) $tableSchema?->getSequenceName()); - continue; + /** @var ColumnInterface $column */ + $column = $tableSchema->getColumn($name); + + if ($column->isAutoIncrement()) { + $value = $this->db->getLastInsertId(); + } else { + $value = $columns[$name] ?? $column->getDefaultValue(); + } + + if ($this->phpTypecasting) { + $value = $column->phpTypecast($value); } - $result[$name] = $columns[$name] ?? $tableSchema?->getColumn($name)?->getDefaultValue(); + $result[$name] = $value; } return $result; @@ -139,11 +161,11 @@ protected function queryInternal(int $queryMode): mixed * * @throws InvalidArgumentException * - * @return array|bool List of SQL statements or `false` if there's a single statement. + * @return array|false List of SQL statements or `false` if there's a single statement. * * @psalm-return false|list */ - private function splitStatements(string $sql, array $params): bool|array + private function splitStatements(string $sql, array $params): array|false { $semicolonIndex = strpos($sql, ';'); diff --git a/src/DMLQueryBuilder.php b/src/DMLQueryBuilder.php index 129092b3..cdeb8730 100644 --- a/src/DMLQueryBuilder.php +++ b/src/DMLQueryBuilder.php @@ -4,13 +4,13 @@ namespace Yiisoft\Db\Sqlite; -use Yiisoft\Db\Constraint\Constraint; use Yiisoft\Db\Exception\InvalidArgumentException; use Yiisoft\Db\Exception\NotSupportedException; use Yiisoft\Db\Expression\Expression; use Yiisoft\Db\Query\QueryInterface; use Yiisoft\Db\QueryBuilder\AbstractDMLQueryBuilder; +use function array_map; use function implode; /** @@ -18,7 +18,7 @@ */ final class DMLQueryBuilder extends AbstractDMLQueryBuilder { - public function insertWithReturningPks(string $table, QueryInterface|array $columns, array &$params = []): string + public function insertWithReturningPks(string $table, array|QueryInterface $columns, array &$params = []): string { throw new NotSupportedException(__METHOD__ . '() is not supported by SQLite.'); } @@ -52,65 +52,55 @@ public function resetSequence(string $table, int|string|null $value = null): str public function upsert( string $table, - QueryInterface|array $insertColumns, - bool|array $updateColumns, - array &$params + array|QueryInterface $insertColumns, + array|bool $updateColumns = true, + array &$params = [], ): string { - /** @var Constraint[] $constraints */ - $constraints = []; + $insertSql = $this->insert($table, $insertColumns, $params); - [$uniqueNames, $insertNames, $updateNames] = $this->prepareUpsertColumns( - $table, - $insertColumns, - $updateColumns, - $constraints - ); + [$uniqueNames, , $updateNames] = $this->prepareUpsertColumns($table, $insertColumns, $updateColumns); if (empty($uniqueNames)) { - return $this->insert($table, $insertColumns, $params); - } - - [, $placeholders, $values, $params] = $this->prepareInsertValues($table, $insertColumns, $params); - - $quotedTableName = $this->quoter->quoteTableName($table); - - $insertSql = 'INSERT OR IGNORE INTO ' . $quotedTableName - . (!empty($insertNames) ? ' (' . implode(', ', $insertNames) . ')' : '') - . (!empty($placeholders) ? ' VALUES (' . implode(', ', $placeholders) . ')' : ' ' . $values); - - if ($updateColumns === false) { return $insertSql; } - $updateCondition = ['or']; - - foreach ($constraints as $constraint) { - $constraintCondition = ['and']; - /** @psalm-var string[] $columnNames */ - $columnNames = $constraint->getColumnNames(); - foreach ($columnNames as $name) { - $quotedName = $this->quoter->quoteColumnName($name); - $constraintCondition[] = "$quotedTableName.$quotedName=(SELECT $quotedName FROM `EXCLUDED`)"; - } - $updateCondition[] = $constraintCondition; + if ($updateColumns === false || $updateNames === []) { + /** there are no columns to update */ + return "$insertSql ON CONFLICT DO NOTHING"; } if ($updateColumns === true) { $updateColumns = []; + /** @psalm-var string[] $updateNames */ - foreach ($updateNames as $quotedName) { - $updateColumns[$quotedName] = new Expression("(SELECT $quotedName FROM `EXCLUDED`)"); + foreach ($updateNames as $name) { + $updateColumns[$name] = new Expression( + 'EXCLUDED.' . $this->quoter->quoteColumnName($name) + ); } } - if ($updateColumns === []) { - return $insertSql; - } + [$updates, $params] = $this->prepareUpdateSets($table, $updateColumns, $params); + + return $insertSql + . ' ON CONFLICT (' . implode(', ', $uniqueNames) . ') DO UPDATE SET ' . implode(', ', $updates); + } + + public function upsertWithReturningPks( + string $table, + array|QueryInterface $insertColumns, + array|bool $updateColumns = true, + array &$params = [], + ): string { + $sql = $this->upsert($table, $insertColumns, $updateColumns, $params); + $returnColumns = $this->schema->getTableSchema($table)?->getPrimaryKey(); - $updateSql = 'WITH "EXCLUDED" (' . implode(', ', $insertNames) . ') AS (' - . (!empty($placeholders) ? 'VALUES (' . implode(', ', $placeholders) . ')' : $values) - . ') ' . $this->update($table, $updateColumns, $updateCondition, $params); + if (!empty($returnColumns)) { + $returnColumns = array_map($this->quoter->quoteColumnName(...), $returnColumns); + + $sql .= ' RETURNING ' . implode(', ', $returnColumns); + } - return "$updateSql; $insertSql;"; + return $sql; } } diff --git a/tests/Provider/QueryBuilderProvider.php b/tests/Provider/QueryBuilderProvider.php index 40bf0878..4b00f52f 100644 --- a/tests/Provider/QueryBuilderProvider.php +++ b/tests/Provider/QueryBuilderProvider.php @@ -115,72 +115,72 @@ public static function upsert(): array $concreteData = [ 'regular values' => [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ 3 => << [ @@ -199,6 +199,36 @@ public static function upsert(): array return $upsert; } + public static function upsertWithReturningPks(): array + { + $upsert = self::upsert(); + + foreach ($upsert as &$data) { + $data[3] .= ' RETURNING `id`'; + } + + $upsert['no columns to update'][3] = 'INSERT INTO `T_upsert_1` (`a`) VALUES (:qp0) ON CONFLICT DO NOTHING RETURNING `a`'; + + return [ + ...$upsert, + 'composite primary key' => [ + 'notauto_pk', + ['id_1' => 1, 'id_2' => 2.5, 'type' => 'Test'], + true, + 'INSERT INTO `notauto_pk` (`id_1`, `id_2`, `type`) VALUES (:qp0, :qp1, :qp2)' + . ' ON CONFLICT (`id_1`, `id_2`) DO UPDATE SET `type`=EXCLUDED.`type` RETURNING `id_1`, `id_2`', + [':qp0' => 1, ':qp1' => 2.5, ':qp2' => 'Test'], + ], + 'no primary key' => [ + 'type', + ['int_col' => 3, 'char_col' => 'a', 'float_col' => 1.2, 'bool_col' => true], + true, + 'INSERT INTO `type` (`int_col`, `char_col`, `float_col`, `bool_col`) VALUES (:qp0, :qp1, :qp2, :qp3)', + [':qp0' => 3, ':qp1' => 'a', ':qp2' => 1.2, ':qp3' => true], + ], + ]; + } + public static function buildColumnDefinition(): array { $values = parent::buildColumnDefinition(); diff --git a/tests/QueryBuilderTest.php b/tests/QueryBuilderTest.php index 503df6c3..f9819412 100644 --- a/tests/QueryBuilderTest.php +++ b/tests/QueryBuilderTest.php @@ -617,26 +617,21 @@ public function testUpsert( string $table, array|QueryInterface $insertColumns, array|bool $updateColumns, - string $expectedSQL, + string $expectedSql, array $expectedParams ): void { - $db = $this->getConnection(true); - - $actualParams = []; - $actualSQL = $db->getQueryBuilder()->upsert($table, $insertColumns, $updateColumns, $actualParams); - - $this->assertSame($expectedSQL, $actualSQL); - - $this->assertSame($expectedParams, $actualParams); + parent::testUpsert($table, $insertColumns, $updateColumns, $expectedSql, $expectedParams); } - #[DataProviderExternal(QueryBuilderProvider::class, 'upsert')] - public function testUpsertExecute( + #[DataProviderExternal(QueryBuilderProvider::class, 'upsertWithReturningPks')] + public function testUpsertWithReturningPks( string $table, array|QueryInterface $insertColumns, - array|bool $updateColumns + array|bool $updateColumns, + string $expectedSql, + array $expectedParams ): void { - parent::testUpsertExecute($table, $insertColumns, $updateColumns); + parent::testUpsertWithReturningPks($table, $insertColumns, $updateColumns, $expectedSql, $expectedParams); } #[DataProviderExternal(QueryBuilderProvider::class, 'selectScalar')] diff --git a/tests/Support/Fixture/sqlite.sql b/tests/Support/Fixture/sqlite.sql index de47c723..3c7cd431 100644 --- a/tests/Support/Fixture/sqlite.sql +++ b/tests/Support/Fixture/sqlite.sql @@ -165,7 +165,7 @@ CREATE TABLE "default_pk" ( CREATE TABLE "notauto_pk" ( id_1 INTEGER, - id_2 INTEGER, + id_2 DECIMAL(5,2), type VARCHAR(255) NOT NULL, PRIMARY KEY (id_1, id_2) );