diff --git a/src/CachedKeySet.php b/src/CachedKeySet.php index 87f470d7..baf801f1 100644 --- a/src/CachedKeySet.php +++ b/src/CachedKeySet.php @@ -3,6 +3,7 @@ namespace Firebase\JWT; use ArrayAccess; +use InvalidArgumentException; use LogicException; use OutOfBoundsException; use Psr\Cache\CacheItemInterface; @@ -10,6 +11,7 @@ use Psr\Http\Client\ClientInterface; use Psr\Http\Message\RequestFactoryInterface; use RuntimeException; +use UnexpectedValueException; /** * @implements ArrayAccess @@ -41,7 +43,7 @@ class CachedKeySet implements ArrayAccess */ private $cacheItem; /** - * @var array + * @var array> */ private $keySet; /** @@ -101,7 +103,7 @@ public function offsetGet($keyId): Key if (!$this->keyIdExists($keyId)) { throw new OutOfBoundsException('Key ID not found'); } - return $this->keySet[$keyId]; + return JWK::parseKey($this->keySet[$keyId], $this->defaultAlg); } /** @@ -130,15 +132,43 @@ public function offsetUnset($offset): void throw new LogicException('Method not implemented'); } + /** + * @return array + */ + private function formatJwksForCache(string $jwks): array + { + $jwks = json_decode($jwks, true); + + if (!isset($jwks['keys'])) { + throw new UnexpectedValueException('"keys" member must exist in the JWK Set'); + } + + if (empty($jwks['keys'])) { + throw new InvalidArgumentException('JWK Set did not contain any keys'); + } + + $keys = []; + foreach ($jwks['keys'] as $k => $v) { + $kid = isset($v['kid']) ? $v['kid'] : $k; + $keys[(string) $kid] = $v; + } + + return $keys; + } + private function keyIdExists(string $keyId): bool { if (null === $this->keySet) { $item = $this->getCacheItem(); // Try to load keys from cache if ($item->isHit()) { - // item found! Return it - $jwks = $item->get(); - $this->keySet = JWK::parseKeySet(json_decode($jwks, true), $this->defaultAlg); + // item found! retrieve it + $this->keySet = $item->get(); + // If the cached item is a string, the JWKS response was cached (previous behavior). + // Parse this into expected format array instead. + if (\is_string($this->keySet)) { + $this->keySet = $this->formatJwksForCache($this->keySet); + } } } @@ -148,15 +178,14 @@ private function keyIdExists(string $keyId): bool } $request = $this->httpFactory->createRequest('GET', $this->jwksUri); $jwksResponse = $this->httpClient->sendRequest($request); - $jwks = (string) $jwksResponse->getBody(); - $this->keySet = JWK::parseKeySet(json_decode($jwks, true), $this->defaultAlg); + $this->keySet = $this->formatJwksForCache((string) $jwksResponse->getBody()); if (!isset($this->keySet[$keyId])) { return false; } $item = $this->getCacheItem(); - $item->set($jwks); + $item->set($this->keySet); if ($this->expiresAfter) { $item->expiresAfter($this->expiresAfter); } diff --git a/tests/CachedKeySetTest.php b/tests/CachedKeySetTest.php index 91d4a27c..04f7994c 100644 --- a/tests/CachedKeySetTest.php +++ b/tests/CachedKeySetTest.php @@ -17,11 +17,12 @@ class CachedKeySetTest extends TestCase private $testJwksUri = 'https://jwk.uri'; private $testJwksUriKey = 'jwkshttpsjwk.uri'; private $testJwks1 = '{"keys": [{"kid":"foo","kty":"RSA","alg":"foo","n":"","e":""}]}'; + private $testCachedJwks1 = ['foo' => ['kid' => 'foo', 'kty' => 'RSA', 'alg' => 'foo', 'n' => '', 'e' => '']]; private $testJwks2 = '{"keys": [{"kid":"bar","kty":"RSA","alg":"bar","n":"","e":""}]}'; private $testJwks3 = '{"keys": [{"kid":"baz","kty":"RSA","n":"","e":""}]}'; private $googleRsaUri = 'https://www.googleapis.com/oauth2/v3/certs'; - // private $googleEcUri = 'https://www.gstatic.com/iap/verify/public_key-jwk'; + private $googleEcUri = 'https://www.gstatic.com/iap/verify/public_key-jwk'; public function testEmptyUriThrowsException() { @@ -117,7 +118,7 @@ public function testKeyIdIsCached() $cacheItem->isHit() ->willReturn(true); $cacheItem->get() - ->willReturn($this->testJwks1); + ->willReturn($this->testCachedJwks1); $cache = $this->prophesize(CacheItemPoolInterface::class); $cache->getItem($this->testJwksUriKey) @@ -136,6 +137,66 @@ public function testKeyIdIsCached() } public function testCachedKeyIdRefresh() + { + $cacheItem = $this->prophesize(CacheItemInterface::class); + $cacheItem->isHit() + ->shouldBeCalledOnce() + ->willReturn(true); + $cacheItem->get() + ->shouldBeCalledOnce() + ->willReturn($this->testCachedJwks1); + $cacheItem->set(Argument::any()) + ->shouldBeCalledOnce() + ->will(function () { + return $this; + }); + + $cache = $this->prophesize(CacheItemPoolInterface::class); + $cache->getItem($this->testJwksUriKey) + ->shouldBeCalledOnce() + ->willReturn($cacheItem->reveal()); + $cache->save(Argument::any()) + ->shouldBeCalledOnce() + ->willReturn(true); + + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->getMockHttpClient($this->testJwks2), // updated JWK + $this->getMockHttpFactory(), + $cache->reveal() + ); + $this->assertInstanceOf(Key::class, $cachedKeySet['foo']); + $this->assertSame('foo', $cachedKeySet['foo']->getAlgorithm()); + + $this->assertInstanceOf(Key::class, $cachedKeySet['bar']); + $this->assertSame('bar', $cachedKeySet['bar']->getAlgorithm()); + } + + public function testKeyIdIsCachedFromPreviousFormat() + { + $cacheItem = $this->prophesize(CacheItemInterface::class); + $cacheItem->isHit() + ->willReturn(true); + $cacheItem->get() + ->willReturn($this->testJwks1); + + $cache = $this->prophesize(CacheItemPoolInterface::class); + $cache->getItem($this->testJwksUriKey) + ->willReturn($cacheItem->reveal()); + $cache->save(Argument::any()) + ->willReturn(true); + + $cachedKeySet = new CachedKeySet( + $this->testJwksUri, + $this->prophesize(ClientInterface::class)->reveal(), + $this->prophesize(RequestFactoryInterface::class)->reveal(), + $cache->reveal() + ); + $this->assertInstanceOf(Key::class, $cachedKeySet['foo']); + $this->assertSame('foo', $cachedKeySet['foo']->getAlgorithm()); + } + + public function testCachedKeyIdRefreshFromPreviousFormat() { $cacheItem = $this->prophesize(CacheItemInterface::class); $cacheItem->isHit() @@ -213,12 +274,18 @@ public function testJwtVerify() $payload = ['sub' => 'foo', 'exp' => strtotime('+10 seconds')]; $msg = JWT::encode($payload, $privKey1, 'RS256', 'jwk1'); + // format the cached value to match the expected format + $cachedJwks = []; + $rsaKeySet = file_get_contents(__DIR__ . '/data/rsa-jwkset.json'); + foreach (json_decode($rsaKeySet, true)['keys'] as $k => $v) { + $cachedJwks[$v['kid']] = $v; + } + $cacheItem = $this->prophesize(CacheItemInterface::class); $cacheItem->isHit() ->willReturn(true); $cacheItem->get() - ->willReturn(file_get_contents(__DIR__ . '/data/rsa-jwkset.json') - ); + ->willReturn($cachedJwks); $cache = $this->prophesize(CacheItemPoolInterface::class); $cache->getItem($this->testJwksUriKey) @@ -297,7 +364,7 @@ public function provideFullIntegration() { return [ [$this->googleRsaUri], - // [$this->googleEcUri, 'LYyP2g'] + [$this->googleEcUri, 'LYyP2g'] ]; }