From bca0e58ff85243c4d0918fcc1ea7d2fe533dd166 Mon Sep 17 00:00:00 2001 From: Laurent VOULLEMIER Date: Tue, 8 Feb 2022 21:25:42 +0100 Subject: [PATCH] Allow userland middlewares --- .../Compiler/MiddlewaresPass.php | 57 ++++++ DependencyInjection/DoctrineExtension.php | 30 ++- DoctrineBundle.php | 8 + Middleware/ConnectionNameAwareInterface.php | 8 + Resources/config/dbal.xml | 3 - Resources/config/middlewares.xml | 14 ++ .../Compiler/MiddlewarePassTest.php | 171 ++++++++++++++++++ .../DoctrineExtensionTest.php | 75 ++------ 8 files changed, 289 insertions(+), 77 deletions(-) create mode 100644 DependencyInjection/Compiler/MiddlewaresPass.php create mode 100644 Middleware/ConnectionNameAwareInterface.php create mode 100644 Resources/config/middlewares.xml create mode 100644 Tests/DependencyInjection/Compiler/MiddlewarePassTest.php diff --git a/DependencyInjection/Compiler/MiddlewaresPass.php b/DependencyInjection/Compiler/MiddlewaresPass.php new file mode 100644 index 000000000..6a7af59dd --- /dev/null +++ b/DependencyInjection/Compiler/MiddlewaresPass.php @@ -0,0 +1,57 @@ +connectionDefsParam = $connectionDefsParam; + $this->middlewareTag = $middlewareTag; + } + + public function process(ContainerBuilder $container): void + { + $middlewareAbstractDefs = []; + foreach (array_keys($container->findTaggedServiceIds($this->middlewareTag)) as $id) { + $middlewareAbstractDefs[$id] = $container->getDefinition($id); + } + + foreach ($container->getParameter($this->connectionDefsParam) as $name => $id) { + $middlewareDefs = []; + foreach ($middlewareAbstractDefs as $id => $abstractDef) { + $middlewareDefs[] = $childDef = $container->setDefinition( + sprintf('%s.%s', $id, $name), + new ChildDefinition($id) + ); + + if (! is_subclass_of($abstractDef->getClass(), ConnectionNameAwareInterface::class)) { + continue; + } + + $childDef->addMethodCall('setConnectionName', [$name]); + } + + $container + ->getDefinition(sprintf('doctrine.dbal.%s_connection.configuration', $name)) + ->addMethodCall('setMiddlewares', [$middlewareDefs]); + } + } +} diff --git a/DependencyInjection/DoctrineExtension.php b/DependencyInjection/DoctrineExtension.php index 6790ed0ec..0a54f7d53 100644 --- a/DependencyInjection/DoctrineExtension.php +++ b/DependencyInjection/DoctrineExtension.php @@ -44,7 +44,6 @@ use Symfony\Component\DependencyInjection\Alias; use Symfony\Component\DependencyInjection\ChildDefinition; use Symfony\Component\DependencyInjection\ContainerBuilder; -use Symfony\Component\DependencyInjection\ContainerInterface; use Symfony\Component\DependencyInjection\Definition; use Symfony\Component\DependencyInjection\Exception\InvalidArgumentException; use Symfony\Component\DependencyInjection\Loader\XmlFileLoader; @@ -85,6 +84,8 @@ public function load(array $configs, ContainerBuilder $container) $this->dbalLoad($config['dbal'], $container); $this->loadMessengerServices($container); + + $this->loadMiddlewares($container); } if (empty($config['orm'])) { @@ -143,9 +144,18 @@ protected function dbalLoad(array $config, ContainerBuilder $container) $container->setParameter('doctrine.connections', $connections); $container->setParameter('doctrine.default_connection', $this->defaultConnection); + /** @psalm-suppress UndefinedClass */ + if (interface_exists(Middleware::class)) { + $container + ->getDefinition('doctrine.dbal.logger') + ->replaceArgument(0, null); + } + foreach ($config['connections'] as $name => $connection) { $this->loadDbalConnection($name, $connection, $container); } + + $container->registerForAutoconfiguration(Middleware::class)->addTag('doctrine.middleware'); } /** @@ -160,7 +170,6 @@ protected function loadDbalConnection($name, array $connection, ContainerBuilder $configuration = $container->setDefinition(sprintf('doctrine.dbal.%s_connection.configuration', $name), new ChildDefinition('doctrine.dbal.connection.configuration')); $logger = null; if ($connection['logging']) { - $this->useMiddlewaresIfAvailable($connection, $container, $name, $configuration); $logger = new Reference('doctrine.dbal.logger'); } @@ -1073,25 +1082,14 @@ private function createArrayAdapterCachePool(ContainerBuilder $container, string return $id; } - /** @param array $connection */ - protected function useMiddlewaresIfAvailable(array $connection, ContainerBuilder $container, string $name, Definition $configuration): void + private function loadMiddlewares(ContainerBuilder $container): void { /** @psalm-suppress UndefinedClass */ if (! interface_exists(Middleware::class)) { return; } - $container - ->getDefinition('doctrine.dbal.logger') - ->replaceArgument(0, null); - - $loggingMiddlewareDef = $container->setDefinition( - sprintf('doctrine.dbal.%s_connection.logging_middleware', $name), - new ChildDefinition('doctrine.dbal.logging_middleware') - ); - $loggingMiddlewareDef->addArgument(new Reference('logger', ContainerInterface::NULL_ON_INVALID_REFERENCE)); - $loggingMiddlewareDef->addTag('monolog.logger', ['channel' => 'doctrine']); - - $configuration->addMethodCall('setMiddlewares', [[$loggingMiddlewareDef]]); + $loader = new XmlFileLoader($container, new FileLocator(__DIR__ . '/../Resources/config')); + $loader->load('middlewares.xml'); } } diff --git a/DoctrineBundle.php b/DoctrineBundle.php index 01daf7d12..9ec763a9f 100644 --- a/DoctrineBundle.php +++ b/DoctrineBundle.php @@ -7,10 +7,12 @@ use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\DbalSchemaFilterPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\EntityListenerPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\IdGeneratorPass; +use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\MiddlewaresPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\RemoveProfilerControllerPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\ServiceRepositoryCompilerPass; use Doctrine\Bundle\DoctrineBundle\DependencyInjection\Compiler\WellKnownSchemaFilterPass; use Doctrine\Common\Util\ClassUtils; +use Doctrine\DBAL\Driver\Middleware; use Doctrine\ORM\EntityManagerInterface; use Doctrine\ORM\Proxy\Autoloader; use Symfony\Bridge\Doctrine\DependencyInjection\CompilerPass\DoctrineValidationPass; @@ -26,6 +28,7 @@ use function assert; use function class_exists; use function clearstatcache; +use function interface_exists; use function spl_autoload_unregister; class DoctrineBundle extends Bundle @@ -60,6 +63,11 @@ public function build(ContainerBuilder $container) $container->addCompilerPass(new CacheSchemaSubscriberPass(), PassConfig::TYPE_BEFORE_REMOVING, -10); $container->addCompilerPass(new RemoveProfilerControllerPass()); + /** @psalm-suppress UndefinedClass */ + if (interface_exists(Middleware::class)) { + $container->addCompilerPass(new MiddlewaresPass()); + } + if (! class_exists(RegisterUidTypePass::class)) { return; } diff --git a/Middleware/ConnectionNameAwareInterface.php b/Middleware/ConnectionNameAwareInterface.php new file mode 100644 index 000000000..5bb2e41da --- /dev/null +++ b/Middleware/ConnectionNameAwareInterface.php @@ -0,0 +1,8 @@ + - - - diff --git a/Resources/config/middlewares.xml b/Resources/config/middlewares.xml new file mode 100644 index 000000000..9fa2d797f --- /dev/null +++ b/Resources/config/middlewares.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + diff --git a/Tests/DependencyInjection/Compiler/MiddlewarePassTest.php b/Tests/DependencyInjection/Compiler/MiddlewarePassTest.php new file mode 100644 index 000000000..1ddee2ae0 --- /dev/null +++ b/Tests/DependencyInjection/Compiler/MiddlewarePassTest.php @@ -0,0 +1,171 @@ + */ + public function provideAddMiddleware(): array + { + return [ + 'not connection name aware' => [Middleware1::class, false], + 'connection name aware' => [Middleware2::class, true], + ]; + } + + /** @dataProvider provideAddMiddleware */ + public function testAddMiddleware(string $middlewareClass, bool $connectionNameAware): void + { + /** @psalm-suppress UndefinedClass */ + if (! interface_exists(Middleware::class)) { + $this->markTestSkipped(sprintf('%s needed to run this test', Middleware::class)); + } + + $container = $this->createContainer(static function (ContainerBuilder $container) use ($middlewareClass) { + $container + ->register('middleware', $middlewareClass) + ->setAbstract(true) + ->addTag('doctrine.middleware'); + + $container + ->setAlias('conf_conn1', 'doctrine.dbal.conn1_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + + $container + ->setAlias('conf_conn2', 'doctrine.dbal.conn2_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + }); + + $this->assertMiddlewareInjected('conn1', $middlewareClass, $connectionNameAware, $container); + $this->assertMiddlewareInjected('conn2', $middlewareClass, $connectionNameAware, $container); + } + + public function testAddMiddlewareWithAutoconfigure(): void + { + /** @psalm-suppress UndefinedClass */ + if (! interface_exists(Middleware::class)) { + $this->markTestSkipped(sprintf('%s needed to run this test', Middleware::class)); + } + + $container = $this->createContainer(static function (ContainerBuilder $container) { + /** @psalm-suppress UndefinedClass */ + $container + ->register('middleware', Middleware3::class) + ->setAutoconfigured(true); + + $container + ->setAlias('conf_conn1', 'doctrine.dbal.conn1_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + + $container + ->setAlias('conf_conn2', 'doctrine.dbal.conn2_connection.configuration') + ->setPublic(true); // Avoid removal and inlining + }); + + /** @psalm-suppress UndefinedClass */ + $this->assertMiddlewareInjected('conn1', Middleware3::class, false, $container); + /** @psalm-suppress UndefinedClass */ + $this->assertMiddlewareInjected('conn2', Middleware3::class, false, $container); + } + + private function createContainer(callable $func): ContainerBuilder + { + $container = new ContainerBuilder(new ParameterBag(['kernel.debug' => false])); + + $container->registerExtension(new DoctrineExtension()); + $container->loadFromExtension('doctrine', [ + 'dbal' => [ + 'connections' => [ + 'conn1' => ['url' => 'mysql://user:pass@server1.tld:3306/db1'], + 'conn2' => ['url' => 'mysql://user:pass@server2.tld:3306/db2'], + ], + ], + ]); + + $container->addCompilerPass(new MiddlewaresPass()); + + $func($container); + + $container->compile(); + + return $container; + } + + private function assertMiddlewareInjected( + string $connName, + string $middlewareClass, + bool $connectionNameAware, + ContainerBuilder $container + ): void { + $calls = $container->getDefinition('conf_' . $connName)->getMethodCalls(); + $middlewareFound = []; + foreach ($calls as $call) { + if ($call[0] !== 'setMiddlewares' || ! isset($call[1][0])) { + continue; + } + + foreach ($call[1][0] as $middlewareDefs) { + if ($middlewareDefs->getClass() !== $middlewareClass) { + continue; + } + + $middlewareFound[] = $middlewareDefs; + } + } + + $this->assertCount(1, $middlewareFound, sprintf( + 'Middleware not injected in doctrine.dbal.%s_connection.configuration', + $connName + )); + + $callsFound = []; + foreach ($middlewareFound[0]->getMethodCalls() as $call) { + if ($call[0] !== 'setConnectionName') { + continue; + } + + $callsFound[] = $call; + } + + $this->assertCount($connectionNameAware ? 1 : 0, $callsFound); + if (! $connectionNameAware) { + return; + } + + $this->assertSame($call[1][0] ?? null, $connName); + } +} + +class Middleware1 +{ +} + +class Middleware2 implements ConnectionNameAwareInterface +{ + public function setConnectionName(string $name): void + { + } +} + +/** @psalm-suppress UndefinedClass */ +if (interface_exists(Middleware::class)) { + class Middleware3 implements Middleware + { + public function wrap(Driver $driver): Driver + { + return $driver; + } + } +} diff --git a/Tests/DependencyInjection/DoctrineExtensionTest.php b/Tests/DependencyInjection/DoctrineExtensionTest.php index a641a08a7..42a92689d 100644 --- a/Tests/DependencyInjection/DoctrineExtensionTest.php +++ b/Tests/DependencyInjection/DoctrineExtensionTest.php @@ -50,7 +50,6 @@ use Symfony\Component\HttpKernel\Kernel; use Symfony\Component\Messenger\MessageBusInterface; -use function array_filter; use function array_values; use function class_exists; use function interface_exists; @@ -1158,48 +1157,25 @@ public function testAsEntityListenerAttribute() $this->assertSame([$expected], $definition->getTag('doctrine.orm.entity_listener')); } - public function testMiddlewaresAreNotAvailable(): void + /** + * @return array + */ + public function provideLoggingMiddleware(): array { - /** @psalm-suppress UndefinedClass */ - if (interface_exists(Middleware::class)) { - $this->markTestSkipped(sprintf('%s needs %s to not exist', __METHOD__, Middleware::class)); - } - - $container = $this->getContainer(); - $extension = new DoctrineExtension(); - - $config = BundleConfigurationBuilder::createBuilderWithBaseValues() - ->addConnection([ - 'connections' => [ - 'default' => [ - 'password' => 'foo', - 'logging' => true, - ], - ], - ]) - ->addBaseEntityManager() - ->build(); - - $extension->load([$config], $container); - - $loggerDef = $container->getDefinition('doctrine.dbal.logger'); - $tags = $loggerDef->getTag('monolog.logger'); - $doctrineLoggerTags = array_filter($tags, static function (array $tag): bool { - return ($tag['channel'] ?? null) === 'doctrine'; - }); - $this->assertCount(1, $doctrineLoggerTags); - - $this->assertInstanceOf(Reference::class, $loggerDef->getArgument(0)); - $this->assertSame('logger', (string) $loggerDef->getArgument(0)); - - $this->assertFalse($container->hasDefinition('doctrine.dbal.default_connection.logging_middleware')); + return [ + 'with middlewares' => [true, false, true], + 'without middlewares' => [false, true, false], + ]; } - public function testMiddlewaresAreAvailable(): void + /** + * @dataProvider provideLoggingMiddleware + */ + public function testLoggingMiddleware(bool $withMiddleware, bool $loggerInjected, bool $middlewareRegistered): void { /** @psalm-suppress UndefinedClass */ - if (! interface_exists(Middleware::class)) { - $this->markTestSkipped(sprintf('%s needs %s to exist', __METHOD__, Middleware::class)); + if ($withMiddleware !== interface_exists(Middleware::class)) { + $this->markTestSkipped(sprintf('%s needs %s to not exist', __METHOD__, Middleware::class)); } $container = $this->getContainer(); @@ -1220,26 +1196,9 @@ public function testMiddlewaresAreAvailable(): void $extension->load([$config], $container); $loggerDef = $container->getDefinition('doctrine.dbal.logger'); - $this->assertNull($loggerDef->getArgument(0)); - - $loggingMiddlewareDef = $container->getDefinition('doctrine.dbal.default_connection.logging_middleware'); - $this->assertInstanceOf(Reference::class, $loggingMiddlewareDef->getArgument(0)); - $this->assertSame('logger', (string) $loggingMiddlewareDef->getArgument(0)); - $tags = $loggingMiddlewareDef->getTag('monolog.logger'); - $doctrineLoggerTags = array_filter($tags, static function (array $tag): bool { - return ($tag['channel'] ?? null) === 'doctrine'; - }); - $this->assertCount(1, $doctrineLoggerTags); - - $connectionConfiguration = $container->getDefinition('doctrine.dbal.default_connection.configuration'); - $setMiddlewareCalls = array_filter($connectionConfiguration->getMethodCalls(), static function (array $call) { - return $call[0] === 'setMiddlewares'; - }); - $this->assertCount(1, $setMiddlewareCalls); - $callArgs = $setMiddlewareCalls[0][1]; - $this->assertCount(1, $callArgs); - $this->assertCount(1, $callArgs[0]); - $this->assertInstanceOf(Definition::class, $callArgs[0][0]); + $this->assertSame($loggerInjected, $loggerDef->getArgument(0) !== null); + + $this->assertSame($middlewareRegistered, $container->hasDefinition('doctrine.dbal.logging_middleware')); } // phpcs:enable