diff --git a/Attribute/AsMiddleware.php b/Attribute/AsMiddleware.php new file mode 100644 index 000000000..3a0646171 --- /dev/null +++ b/Attribute/AsMiddleware.php @@ -0,0 +1,15 @@ +findTaggedServiceIds('doctrine.middleware') as $id => $tags) { + $middlewareAbstractDefs[$id] = $container->getDefinition($id); + // When a def has doctrine.middleware tags with connection attributes equal to connection names + // registration of this middleware is limited to the connections with these names + foreach ($tags as $tag) { + if (! isset($tag['connection'])) { + continue; + } + + $middlewareConnections[$id][] = $tag['connection']; + } + } + + foreach (array_keys($container->getParameter('doctrine.connections')) as $name) { + $middlewareDefs = []; + foreach ($middlewareAbstractDefs as $id => $abstractDef) { + if (isset($middlewareConnections[$id]) && ! in_array($name, $middlewareConnections[$id], true)) { + continue; + } + + $middlewareDefs[] = $childDef = $container->setDefinition( + sprintf('%s.%s', $id, $name), + new ChildDefinition($id) + ); + + if (! is_subclass_of($abstractDef->getClass(), ConnectionNameAwareInterface::class)) { + continue; + } + + $childDef->addMethodCall('setConnectionName', [$name]); + } + + $container + ->getDefinition(sprintf('doctrine.dbal.%s_connection.configuration', $name)) + ->addMethodCall('setMiddlewares', [$middlewareDefs]); + } + } +} diff --git a/DependencyInjection/DoctrineExtension.php b/DependencyInjection/DoctrineExtension.php index 6790ed0ec..a95cde667 100644 --- a/DependencyInjection/DoctrineExtension.php +++ b/DependencyInjection/DoctrineExtension.php @@ -3,6 +3,7 @@ namespace Doctrine\Bundle\DoctrineBundle\DependencyInjection; use Doctrine\Bundle\DoctrineBundle\Attribute\AsEntityListener; +use Doctrine\Bundle\DoctrineBundle\Attribute\AsMiddleware; use Doctrine\Bundle\DoctrineBundle\CacheWarmer\DoctrineMetadataCacheWarmer; use Doctrine\Bundle\DoctrineBundle\Command\Proxy\ImportDoctrineCommand; use Doctrine\Bundle\DoctrineBundle\Dbal\ManagerRegistryAwareConnectionProvider; @@ -13,8 +14,9 @@ use Doctrine\Bundle\DoctrineBundle\Repository\ServiceEntityRepositoryInterface; use Doctrine\DBAL\Connection; use Doctrine\DBAL\Connections\PrimaryReadReplicaConnection; -use Doctrine\DBAL\Driver\Middleware; +use Doctrine\DBAL\Driver\Middleware as MiddlewareInterface; use Doctrine\DBAL\Logging\LoggerChain; +use Doctrine\DBAL\Logging\Middleware; use Doctrine\DBAL\Sharding\PoolingShardConnection; use Doctrine\DBAL\Sharding\PoolingShardManager; use Doctrine\DBAL\Tools\Console\Command\ImportCommand; @@ -44,7 +46,6 @@ use Symfony\Component\DependencyInjection\Alias; use Symfony\Component\DependencyInjection\ChildDefinition; use Symfony\Component\DependencyInjection\ContainerBuilder; -use Symfony\Component\DependencyInjection\ContainerInterface; use Symfony\Component\DependencyInjection\Definition; use Symfony\Component\DependencyInjection\Exception\InvalidArgumentException; use Symfony\Component\DependencyInjection\Loader\XmlFileLoader; @@ -65,6 +66,8 @@ use function sprintf; use function str_replace; +use const PHP_VERSION_ID; + /** * DoctrineExtension is an extension for the Doctrine DBAL and ORM library. */ @@ -143,9 +146,33 @@ protected function dbalLoad(array $config, ContainerBuilder $container) $container->setParameter('doctrine.connections', $connections); $container->setParameter('doctrine.default_connection', $this->defaultConnection); + $connWithLogging = []; foreach ($config['connections'] as $name => $connection) { + if ($connection['logging']) { + $connWithLogging[] = $name; + } + $this->loadDbalConnection($name, $connection, $container); } + + /** @psalm-suppress UndefinedClass */ + $container->registerForAutoconfiguration(MiddlewareInterface::class)->addTag('doctrine.middleware'); + + if (PHP_VERSION_ID >= 80000 && method_exists(ContainerBuilder::class, 'registerAttributeForAutoconfiguration')) { + $container->registerAttributeForAutoconfiguration(AsMiddleware::class, static function (ChildDefinition $definition, AsMiddleware $attribute) { + if ($attribute->connections === []) { + $definition->addTag('doctrine.middleware'); + + return; + } + + foreach ($attribute->connections as $connName) { + $definition->addTag('doctrine.middleware', ['connection' => $connName]); + } + }); + } + + $this->useMiddlewaresIfAvailable($container, $connWithLogging); } /** @@ -160,7 +187,6 @@ protected function loadDbalConnection($name, array $connection, ContainerBuilder $configuration = $container->setDefinition(sprintf('doctrine.dbal.%s_connection.configuration', $name), new ChildDefinition('doctrine.dbal.connection.configuration')); $logger = null; if ($connection['logging']) { - $this->useMiddlewaresIfAvailable($connection, $container, $name, $configuration); $logger = new Reference('doctrine.dbal.logger'); } @@ -1073,11 +1099,11 @@ private function createArrayAdapterCachePool(ContainerBuilder $container, string return $id; } - /** @param array $connection */ - protected function useMiddlewaresIfAvailable(array $connection, ContainerBuilder $container, string $name, Definition $configuration): void + /** @param string[] $connWithLogging */ + private function useMiddlewaresIfAvailable(ContainerBuilder $container, array $connWithLogging): void { /** @psalm-suppress UndefinedClass */ - if (! interface_exists(Middleware::class)) { + if (! class_exists(Middleware::class)) { return; } @@ -1085,13 +1111,12 @@ protected function useMiddlewaresIfAvailable(array $connection, ContainerBuilder ->getDefinition('doctrine.dbal.logger') ->replaceArgument(0, null); - $loggingMiddlewareDef = $container->setDefinition( - sprintf('doctrine.dbal.%s_connection.logging_middleware', $name), - new ChildDefinition('doctrine.dbal.logging_middleware') - ); - $loggingMiddlewareDef->addArgument(new Reference('logger', ContainerInterface::NULL_ON_INVALID_REFERENCE)); - $loggingMiddlewareDef->addTag('monolog.logger', ['channel' => 'doctrine']); + $loader = new XmlFileLoader($container, new FileLocator(__DIR__ . '/../Resources/config')); + $loader->load('middlewares.xml'); - $configuration->addMethodCall('setMiddlewares', [[$loggingMiddlewareDef]]); + $loggingMiddlewareAbstractDef = $container->getDefinition('doctrine.dbal.logging_middleware'); + foreach ($connWithLogging as $connName) { + $loggingMiddlewareAbstractDef->addTag('doctrine.middleware', ['connection' => $connName]); + } } } diff --git a/DoctrineBundle.php b/DoctrineBundle.php index 01daf7d12..9ec763a9f 100644 --- a/DoctrineBundle.php +++ b/DoctrineBundle.php @@ -7,10 +7,12 @@ use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\DbalSchemaFilterPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\EntityListenerPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\IdGeneratorPass; +use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\MiddlewaresPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\RemoveProfilerControllerPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\ServiceRepositoryCompilerPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\WellKnownSchemaFilterPass; use Doctrine\Common\Util\ClassUtils; +use Doctrine\DBAL\Driver\Middleware; use Doctrine\ORM\EntityManagerInterface; use Doctrine\ORM\Proxy\Autoloader; use Symfony\Bridge\Doctrine\DependencyInjection\CompilerPass\DoctrineValidationPass; @@ -26,6 +28,7 @@ use function assert; use function class_exists; use function clearstatcache; +use function interface_exists; use function spl_autoload_unregister; class DoctrineBundle extends Bundle @@ -60,6 +63,11 @@ public function build(ContainerBuilder $container) $container->addCompilerPass(new CacheSchemaSubscriberPass(), PassConfig::TYPE_BEFORE_REMOVING, -10); $container->addCompilerPass(new RemoveProfilerControllerPass()); + /** @psalm-suppress UndefinedClass */ + if (interface_exists(Middleware::class)) { + $container->addCompilerPass(new MiddlewaresPass()); + } + if (! class_exists(RegisterUidTypePass::class)) { return; } diff --git a/Middleware/ConnectionNameAwareInterface.php b/Middleware/ConnectionNameAwareInterface.php new file mode 100644 index 000000000..5bb2e41da --- /dev/null +++ b/Middleware/ConnectionNameAwareInterface.php @@ -0,0 +1,8 @@ + - - - diff --git a/Resources/config/middlewares.xml b/Resources/config/middlewares.xml new file mode 100644 index 000000000..9fa2d797f --- /dev/null +++ b/Resources/config/middlewares.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + diff --git a/Tests/DependencyInjection/Compiler/MiddlewarePassTest.php b/Tests/DependencyInjection/Compiler/MiddlewarePassTest.php new file mode 100644 index 000000000..b2397af3d --- /dev/null +++ b/Tests/DependencyInjection/Compiler/MiddlewarePassTest.php @@ -0,0 +1,292 @@ + */ + public function provideAddMiddleware(): array + { + return [ + 'not connection name aware' => [PHP7Middleware::class, false], + 'connection name aware' => [ConnectionAwarePHP7Middleware::class, true], + ]; + } + + /** @dataProvider provideAddMiddleware */ + public function testAddMiddlewareWithExplicitTag(string $middlewareClass, bool $connectionNameAware): void + { + /** @psalm-suppress UndefinedClass */ + if (! interface_exists(Middleware::class)) { + $this->markTestSkipped(sprintf('%s needed to run this test', Middleware::class)); + } + + $container = $this->createContainer(static function (ContainerBuilder $container) use ($middlewareClass) { + $container + ->register('middleware', $middlewareClass) + ->setAbstract(true) + ->addTag('doctrine.middleware'); + + $container + ->setAlias('conf_conn1', 'doctrine.dbal.conn1_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + + $container + ->setAlias('conf_conn2', 'doctrine.dbal.conn2_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + }); + + $this->assertMiddlewareInjected($container, 'conn1', $middlewareClass, $connectionNameAware); + $this->assertMiddlewareInjected($container, 'conn2', $middlewareClass, $connectionNameAware); + } + + public function testAddMiddlewareWithExplicitTagsOnSpecificConnections(): void + { + /** @psalm-suppress UndefinedClass */ + if (! interface_exists(Middleware::class)) { + $this->markTestSkipped(sprintf('%s needed to run this test', Middleware::class)); + } + + $container = $this->createContainer(static function (ContainerBuilder $container) { + $container + ->register('middleware', PHP7Middleware::class) + ->setAbstract(true) + ->addTag('doctrine.middleware', ['connection' => 'conn1']); + + $container + ->setAlias('conf_conn1', 'doctrine.dbal.conn1_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + + $container + ->setAlias('conf_conn2', 'doctrine.dbal.conn2_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + }); + + $this->assertMiddlewareInjected($container, 'conn1', PHP7Middleware::class); + $this->assertMiddlewareNotInjected($container, 'conn2', PHP7Middleware::class); + } + + public function testAddMiddlewareWithAutoconfigure(): void + { + /** @psalm-suppress UndefinedClass */ + if (! interface_exists(Middleware::class)) { + $this->markTestSkipped(sprintf('%s needed to run this test', Middleware::class)); + } + + $container = $this->createContainer(static function (ContainerBuilder $container) { + /** @psalm-suppress UndefinedClass */ + $container + ->register('middleware', AutoconfiguredPHP7Middleware::class) + ->setAutoconfigured(true); + + $container + ->setAlias('conf_conn1', 'doctrine.dbal.conn1_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + + $container + ->setAlias('conf_conn2', 'doctrine.dbal.conn2_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + }); + + /** @psalm-suppress UndefinedClass */ + $this->assertMiddlewareInjected($container, 'conn1', AutoconfiguredPHP7Middleware::class); + /** @psalm-suppress UndefinedClass */ + $this->assertMiddlewareInjected($container, 'conn2', AutoconfiguredPHP7Middleware::class); + } + + /** @return array */ + public function provideAddMiddlewareWithAttributeForAutoconfiguration(): array + { + /** @psalm-suppress UndefinedClass */ + return [ + 'without specifying connection' => [AutoconfiguredMiddleware::class, true], + 'specifying connection' => [AutoconfiguredMiddlewareWithConnection::class, false], + ]; + } + + /** + * @param class-string $className + * + * @dataProvider provideAddMiddlewareWithAttributeForAutoconfiguration + */ + public function testAddMiddlewareWithAttributeForAutoconfiguration(string $className, bool $registeredOnConn1): void + { + /** @psalm-suppress UndefinedClass */ + if (! interface_exists(Middleware::class)) { + $this->markTestSkipped(sprintf('%s needed to run this test', Middleware::class)); + } + + if (PHP_VERSION_ID < 80000 || ! method_exists(ContainerBuilder::class, 'registerAttributeForAutoconfiguration')) { + $this->markTestSkipped(sprintf( + 'Testing attribute for autoconfiguration requires PHP 8 and %s::registerAttributeForAutoconfiguration', + ContainerBuilder::class + )); + } + + $container = $this->createContainer(static function (ContainerBuilder $container) use ($className) { + /** @psalm-suppress UndefinedClass */ + $container + ->register('middleware', $className) + ->setAutoconfigured(true); + + $container + ->setAlias('conf_conn1', 'doctrine.dbal.conn1_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + + $container + ->setAlias('conf_conn2', 'doctrine.dbal.conn2_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + }); + + if ($registeredOnConn1) { + /** @psalm-suppress UndefinedClass */ + $this->assertMiddlewareInjected($container, 'conn1', $className); + } else { + $this->assertMiddlewareNotInjected($container, 'conn1', $className); + } + + /** @psalm-suppress UndefinedClass */ + $this->assertMiddlewareInjected($container, 'conn2', $className); + } + + private function createContainer(callable $func): ContainerBuilder + { + $container = new ContainerBuilder(new ParameterBag(['kernel.debug' => false])); + + $container->registerExtension(new DoctrineExtension()); + $container->loadFromExtension('doctrine', [ + 'dbal' => [ + 'connections' => [ + 'conn1' => ['url' => 'mysql://user:pass@server1.tld:3306/db1'], + 'conn2' => ['url' => 'mysql://user:pass@server2.tld:3306/db2'], + ], + ], + ]); + + $container->addCompilerPass(new MiddlewaresPass()); + + $func($container); + + $container->compile(); + + return $container; + } + + private function assertMiddlewareInjected( + ContainerBuilder $container, + string $connName, + string $middlewareClass, + bool $connectionNameAware = false + ): void { + $middlewareFound = $this->getMiddlewaresForConn($container, $connName, $middlewareClass); + + $this->assertCount(1, $middlewareFound, sprintf( + 'Middleware %s not injected in doctrine.dbal.%s_connection.configuration', + $middlewareClass, + $connName + )); + + $callsFound = []; + foreach ($middlewareFound[0]->getMethodCalls() as $call) { + if ($call[0] !== 'setConnectionName') { + continue; + } + + $callsFound[] = $call; + } + + if ($connectionNameAware) { + $this->assertCount(1, $callsFound); + $this->assertSame($callsFound[0][1][0] ?? null, $connName); + } else { + $this->assertCount(0, $callsFound); + } + } + + private function assertMiddlewareNotInjected( + ContainerBuilder $container, + string $connName, + string $middlewareClass + ): void { + $middlewareFound = $this->getMiddlewaresForConn($container, $connName, $middlewareClass); + + $this->assertCount(0, $middlewareFound, sprintf( + 'Middleware %s injected in doctrine.dbal.%s_connection.configuration', + $middlewareClass, + $connName + )); + } + + /** @return Definition[] */ + private function getMiddlewaresForConn(ContainerBuilder $container, string $connName, string $middlewareClass): array + { + $calls = $container->getDefinition('conf_' . $connName)->getMethodCalls(); + $middlewaresFound = []; + foreach ($calls as $call) { + if ($call[0] !== 'setMiddlewares' || ! isset($call[1][0])) { + continue; + } + + foreach ($call[1][0] as $middlewareDef) { + if ($middlewareDef->getClass() !== $middlewareClass) { + continue; + } + + $middlewaresFound[] = $middlewareDef; + } + } + + return $middlewaresFound; + } +} + +class PHP7Middleware +{ +} + +class ConnectionAwarePHP7Middleware implements ConnectionNameAwareInterface +{ + public function setConnectionName(string $name): void + { + } +} + +/** @psalm-suppress UndefinedClass */ +if (interface_exists(Middleware::class)) { + class AutoconfiguredPHP7Middleware implements Middleware + { + public function wrap(Driver $driver): Driver + { + return $driver; + } + } + + if (PHP_VERSION_ID >= 80000) { + #[AsMiddleware] + class AutoconfiguredMiddleware + { + } + + #[AsMiddleware(connections: ['conn2'])] + class AutoconfiguredMiddlewareWithConnection + { + } + } +} diff --git a/Tests/DependencyInjection/DoctrineExtensionTest.php b/Tests/DependencyInjection/DoctrineExtensionTest.php index a641a08a7..e9da37289 100644 --- a/Tests/DependencyInjection/DoctrineExtensionTest.php +++ b/Tests/DependencyInjection/DoctrineExtensionTest.php @@ -16,7 +16,7 @@ use Doctrine\Common\Cache\XcacheCache; use Doctrine\DBAL\Connection; use Doctrine\DBAL\Driver\Connection as DriverConnection; -use Doctrine\DBAL\Driver\Middleware; +use Doctrine\DBAL\Logging\Middleware; use Doctrine\DBAL\Sharding\PoolingShardManager; use Doctrine\DBAL\Sharding\SQLAzure\SQLAzureShardManager; use Doctrine\ORM\Cache\CacheConfiguration; @@ -50,9 +50,9 @@ use Symfony\Component\HttpKernel\Kernel; use Symfony\Component\Messenger\MessageBusInterface; -use function array_filter; use function array_values; use function class_exists; +use function in_array; use function interface_exists; use function is_dir; use function method_exists; @@ -1158,10 +1158,24 @@ public function testAsEntityListenerAttribute() $this->assertSame([$expected], $definition->getTag('doctrine.orm.entity_listener')); } - public function testMiddlewaresAreNotAvailable(): void + /** + * @return array + */ + public function provideDefinitionsToLogQueries(): array + { + return [ + 'with middlewares' => [true, false, true], + 'without middlewares' => [false, true, false], + ]; + } + + /** + * @dataProvider provideDefinitionsToLogQueries + */ + public function testDefinitionsToLogQueries(bool $withMiddleware, bool $loggerInjected, bool $middlewareRegistered): void { /** @psalm-suppress UndefinedClass */ - if (interface_exists(Middleware::class)) { + if ($withMiddleware !== class_exists(Middleware::class)) { $this->markTestSkipped(sprintf('%s needs %s to not exist', __METHOD__, Middleware::class)); } @@ -1171,10 +1185,14 @@ public function testMiddlewaresAreNotAvailable(): void $config = BundleConfigurationBuilder::createBuilderWithBaseValues() ->addConnection([ 'connections' => [ - 'default' => [ + 'conn1' => [ 'password' => 'foo', 'logging' => true, ], + 'conn2' => [ + 'password' => 'bar', + 'logging' => false, + ], ], ]) ->addBaseEntityManager() @@ -1182,64 +1200,27 @@ public function testMiddlewaresAreNotAvailable(): void $extension->load([$config], $container); - $loggerDef = $container->getDefinition('doctrine.dbal.logger'); - $tags = $loggerDef->getTag('monolog.logger'); - $doctrineLoggerTags = array_filter($tags, static function (array $tag): bool { - return ($tag['channel'] ?? null) === 'doctrine'; - }); - $this->assertCount(1, $doctrineLoggerTags); - - $this->assertInstanceOf(Reference::class, $loggerDef->getArgument(0)); - $this->assertSame('logger', (string) $loggerDef->getArgument(0)); + $loggerDef = $container->getDefinition('doctrine.dbal.logger'); + $this->assertSame($loggerInjected, $loggerDef->getArgument(0) !== null); - $this->assertFalse($container->hasDefinition('doctrine.dbal.default_connection.logging_middleware')); - } + $this->assertSame($middlewareRegistered, $container->hasDefinition('doctrine.dbal.logging_middleware')); - public function testMiddlewaresAreAvailable(): void - { - /** @psalm-suppress UndefinedClass */ - if (! interface_exists(Middleware::class)) { - $this->markTestSkipped(sprintf('%s needs %s to exist', __METHOD__, Middleware::class)); + if (! $withMiddleware) { + return; } - $container = $this->getContainer(); - $extension = new DoctrineExtension(); - - $config = BundleConfigurationBuilder::createBuilderWithBaseValues() - ->addConnection([ - 'connections' => [ - 'default' => [ - 'password' => 'foo', - 'logging' => true, - ], - ], - ]) - ->addBaseEntityManager() - ->build(); + $abstractMiddlewareDefTags = $container->getDefinition('doctrine.dbal.logging_middleware')->getTags(); + $middleWareTagAttributes = []; + foreach ($abstractMiddlewareDefTags as $tag => $attributes) { + if ($tag !== 'doctrine.middleware') { + continue; + } - $extension->load([$config], $container); + $middleWareTagAttributes = $attributes; + } - $loggerDef = $container->getDefinition('doctrine.dbal.logger'); - $this->assertNull($loggerDef->getArgument(0)); - - $loggingMiddlewareDef = $container->getDefinition('doctrine.dbal.default_connection.logging_middleware'); - $this->assertInstanceOf(Reference::class, $loggingMiddlewareDef->getArgument(0)); - $this->assertSame('logger', (string) $loggingMiddlewareDef->getArgument(0)); - $tags = $loggingMiddlewareDef->getTag('monolog.logger'); - $doctrineLoggerTags = array_filter($tags, static function (array $tag): bool { - return ($tag['channel'] ?? null) === 'doctrine'; - }); - $this->assertCount(1, $doctrineLoggerTags); - - $connectionConfiguration = $container->getDefinition('doctrine.dbal.default_connection.configuration'); - $setMiddlewareCalls = array_filter($connectionConfiguration->getMethodCalls(), static function (array $call) { - return $call[0] === 'setMiddlewares'; - }); - $this->assertCount(1, $setMiddlewareCalls); - $callArgs = $setMiddlewareCalls[0][1]; - $this->assertCount(1, $callArgs); - $this->assertCount(1, $callArgs[0]); - $this->assertInstanceOf(Definition::class, $callArgs[0][0]); + $this->assertTrue(in_array(['connection' => 'conn1'], $middleWareTagAttributes, true), 'Tag with connection conn1 not found'); + $this->assertFalse(in_array(['connection' => 'conn2'], $middleWareTagAttributes, true), 'Tag with connection conn2 found'); } // phpcs:enable