diff --git a/README.md b/README.md index 2ea80fa0..0eafc1e2 100644 --- a/README.md +++ b/README.md @@ -4,15 +4,25 @@ A continuation-local storage module compatible with [NestJS](https://nestjs.com/ _Continuous-local storage allows to store state and propagate it throughout callbacks and promise chains. It allows storing data throughout the lifetime of a web request or any other asynchronous duration. It is similar to thread-local storage in other languages._ -> Note: For versions < 1.2, this package used [cls-hooked](https://www.npmjs.com/package/cls-hooked) as a peer dependency, now it uses [AsyncLocalStorage](https://nodejs.org/api/async_context.html#async_context_class_asynclocalstorage) from Node's `async_hooks` directly. The API stays the same for now but I'll consider making it more friendly for version 2. +Some common use cases for CLS include: + +- Request ID tracing for logging purposes +- Making the Tenant ID available everywhere in multi-tenant apps +- Globally setting the authentication level for the request + +Most of these are theoretically solvable using _request-scoped_ providers or passing the context as a parameter, but these solutions are often clunky and come with a whole lot of other issues. Thus this package was born. + +> **Note**: For versions < 1.2, this package used [cls-hooked](https://www.npmjs.com/package/cls-hooked) as a peer dependency, now it uses [AsyncLocalStorage](https://nodejs.org/api/async_context.html#async_context_class_asynclocalstorage) from Node's `async_hooks` directly. The API stays the same for now but I'll consider making it more friendly for version 2. # Outline - [Install](#install) - [Quick Start](#quick-start) - [How it works](#how-it-works) - - [HTTP](#http) - - [Non-HTTP](#non-http) +- [Setting up the CLS context](#setting-up-the-cls-context) + - [Using a Middleware](#using-a-middleware-http-only) + - [Using a Guard](#using-a-guard) + - [Using an Interceptor](#using-an-interceptor) - [API](#api) - [Options](#options) - [Request ID](#request-id) @@ -33,13 +43,13 @@ npm install nestjs-cls yarn add nestjs-cls ``` -> Note: This module requires additional peer deps, like the nestjs core and common libraries, but it is assumed those are already installed. +> **Note**: This module requires additional peer deps, like the nestjs core and common libraries, but it is assumed those are already installed. # Quick Start Below is an example of storing the client's IP address in an interceptor and retrieving it in a service without explicitly passing it along. -> Note: This example assumes you are using HTTP and therefore can use middleware. For usage with non-HTTP transports, keep reading. +> **Note**: This example assumes you are using HTTP and therefore can use middleware. For usage with non-HTTP transports, keep reading. ```ts // app.module.ts @@ -113,9 +123,15 @@ export class AppService { Continuation-local storage provides a common space for storing and retrieving data throughout the life of a function/callback call chain. In NestJS, this allows for sharing request data across the lifetime of a single request - without the need for request-scoped providers. It also makes it easy to track and log request ids throughout the whole application. -To make CLS work, it is required to set up the CLS context first. This is done by calling `cls.run()` (or `cls.enter()`) somewhere in the app. Once that is set up, anything that is called within the same callback chain has access to the same storage with `cls.set()` and `cls.get()`. +To make CLS work, it is required to set up the CLS context first. This is done by calling `cls.run()` (or `cls.enter()`, see [Security considerations](#security-considerations) for more info) somewhere in the app. Once that is set up, anything that is called within the same callback chain has access to the same storage with `cls.set()` and `cls.get()`. + +# Setting up the CLS context + +This package provides **three** methods of setting up the CLS context for incoming requests. This is mainly due to the fact that different underlying platforms are compatible with some of these methods - see [Compatibility considerations](#compatibility-considerations). + +For HTTP transports, the context can be preferably set up in a `ClsMiddleware`. For all other platforms, or cases where the `ClsMiddleware` is not applicable, this package also provides a `ClsGuard` and `ClsInterceptor`. While both of these also work with HTTP, they come with some caveats, see below. -## HTTP +## Using a Middleware (HTTP Only) Since in NestJS, HTTP **middleware** is the first thing to run when a request arrives, it is an ideal place to initialise the cls context. This package provides `ClsMidmidleware` that can be mounted to all (or selected) routes inside which the context is set up before the `next()` call. @@ -150,11 +166,11 @@ function bootstrap() { } ``` -> Please note: If you bind the middleware using `app.use()`, it will not respect middleware settings passed to `ClsModule.forRoot()`, so you will have to provide them yourself in the constructor. +> **Please note**: If you bind the middleware using `app.use()`, it will not respect middleware settings passed to `ClsModule.register()`, so you will have to provide them yourself in the constructor. -## Non-HTTP +## Using a Guard -For all other transports that don't use middleware, this package provides a `ClsGuard` to set up the CLS context. While it is not a "guard" per-se, it's the second best place to set up the CLS context, since it would be too late to do it in an interceptor. +The `ClsGuard` can be also used set up the CLS context. While it is not a "guard" per-se, it's the second best place to set up the CLS context, since after a middleware, it is the first piece of code that the request hits. To use it, pass its configuration to the `guard` property to the `ClsModule.register` options: @@ -179,9 +195,23 @@ If you need any other guards to use the `ClsService`, it's preferable mount `Cls export class AppModule {} ``` -> Please note: using the `ClsGuard` comes with some [security considerations](#security-considerations)! +> **Please note**: since the `ClsGuard` uses the `AsyncLocalStorage#enterWith` method, using the `ClsGuard` comes with some [security considerations](#security-considerations)! + +## Using an Interceptor + +Another place to initiate the CLS context is an `ClsInterceptor`, which, unlike the `ClsGuard` uses `AsyncLocalStorage#run` method to wrap the following code, which is considered safer than `enterWith`. + +To use it, pass its configuration to the `interceptor` property to the `ClsModule.register` options: + +```ts +ClsModule.register({ + interceptor: { generateId: true, mount: true } +}), +``` + +Or mount it manually as `APP_INTERCEPTOR`, should you need it. -> Note: A guard might not be the best place to initiate the CLS context for all transports. I'm looking into providing alternative options for specific platforms. +> **Please note**: Since Nest's _Interceptors_ run after _Guards_, that means using this method makes CLS **unavailable in Guards** (and in case of REST Controllers, also in **Exception Filters**). # API @@ -197,63 +227,57 @@ The injectable `ClsService` provides the following API to manipulate the cls con Retrieve the object containing all properties of the current CLS context. - **_`enter`_**`(): void;` Run any following code in a shared CLS context. +- **_`enterWith`_**`(store: any): void;` + Run any following code in a shared CLS context (while supplying the default contents). - **_`run`_**`(callback: () => T): T;` Run the callback in a shared CLS context. +- **_`runWith`_**`(store: any, callback: () => T): T;` + Run the callback in a shared CLS context (while supplying the default contents). - **_`isActive`_**`(): boolean` Whether the current code runs within an active CLS context. # Options -The `ClsModule.register()` method takes the following options: - -- **`ClsModuleOptions`** - - - **_`namespaceName`_: `string`** - The name of the cls namespace. This is the namespace that will be used by the ClsService and ClsMiddleware (most of the time you will not need to touch this setting) - - **_`global:`_ `boolean`** (default _`false`_) - Whether to make the module global, so you do not have to import `ClsModule.forFeature()` in other modules. - - **_`middleware:`_ `ClsMiddlewareOptions`** - An object with additional ClsMiddleware options, see below - - **_`guard:`_ `ClsGuardOptions`** - An object with additional ClsGuard options, see below (do not use together with ClsMiddleware) - -The `ClsMiddleware` takes the following options (either set up in `ClsModuleOptions` or directly when instantiating it manually): - -- **`ClsMiddlewareOptions`** - - - **_`mount`_: `boolean`** (default _`false`_) - Whether to automatically mount the middleware to every route (not applicable when instantiating manually) - - **_`generateId`_: `bolean`** (default _`false`_) - Whether to automatically generate request IDs. - - **_`idGenerator`_: `(req: Request) => string | Promise`** - An optional function for generating the request ID. It takes the `Request` object as an argument and (synchronously or asynchronously) returns a string. The default implementation uses `Math.random()` to generate a string of 8 characters. - - **_`setup`_: `(cls: ClsService, req: Request) => void | Promise;`** - Function that executes after the CLS context has been initialised. It can be used to put additional variables in the CLS context. - - **_`saveReq`_: `boolean`** (default _`true`_) - Whether to store the _Request_ object to the context. It will be available under the `CLS_REQ` key. - - **_`saveRes`_: `boolean`** (default _`false`_) - Whether to store the _Response_ object to the context. It will be available under the `CLS_RES` key - - **_`useEnterWith`_: `boolean`** (default _`false`_) - Set to `true` to set up the context using a call to [`AsyncLocalStorage#enterWith`](https://nodejs.org/api/async_context.html#async_context_asynclocalstorage_enterwith_store) instead of wrapping the `next()` call with the safer [`AsyncLocalStorage#run`](https://nodejs.org/api/async_context.html#async_context_asynclocalstorage_run_store_callback_args). Most of the time this should not be necessary, but [some frameworks](#graphql) are known to lose the context with `run`. - -The `ClsGuard` takes the following options: - -- **`ClsGuardOptions`** - - - **_`mount`_: `boolean`** (default _`false`_) - Whether to automatically mount the guard as APP_GUARD - - **_`generateId`_: `bolean`** (default _`false`_) - Whether to automatically generate request IDs. - - **_`idGenerator`_: `(context: ExecutionContext) => string | Promise`** - An optional function for generating the request ID. It takes the `ExecutionContext` object as an argument and (synchronously or asynchronously) returns a string. The default implementation uses `Math.random()` to generate a string of 8 characters. - - **_`setup`_: `(cls: ClsService, context: ExecutionContext) => void | Promise;`** - Function that executes after the CLS context has been initialised. It can be used to put additional variables in the CLS context. +The `ClsModule.register()` method takes the following `ClsModuleOptions`: + +- **_`namespaceName`_: `string`** + The name of the cls namespace. This is the namespace that will be used by the ClsService and ClsMiddleware (most of the time you will not need to touch this setting) +- **_`global:`_ `boolean`** (default _`false`_) + Whether to make the module global, so you do not have to import `ClsModule.forFeature()` in other modules. +- **_`middleware:`_ `ClsMiddlewareOptions`** + An object with additional options for the ClsMiddleware, see below +- **_`guard:`_ `ClsGuardOptions`** + An object with additional options for the ClsGuard, see below +- **_`interceptor:`_ `ClsInterceptorOptions`** + An object with additional options for the ClsInterceptor, see below + +> Important: the `middleware`, `guard` and `interceptor` options are _mutually exclusive_ - do not use more than one of them, otherwise the context will get overridden with the one that runs after. + +All of the `Cls{Middleware,Guard,Interceptor}Options` take the following parameters (either in `ClsModuleOptions` or directly when instantiating them manually): + +- **_`mount`_: `boolean`** (default _`false`_) + Whether to automatically mount the middleware/guard/interceptor to every route (not applicable when instantiating manually) +- **_`generateId`_: `bolean`** (default _`false`_) + Whether to automatically generate request IDs. +- **_`idGenerator`_: `(req: Request | ExecutionContext) => string | Promise`** + An optional function for generating the request ID. It takes the `Request` object (or the `ExecutionContext` in case of a Guard or Interceptor) as an argument and (synchronously or asynchronously) returns a string. The default implementation uses `Math.random()` to generate a string of 8 characters. +- **_`setup`_: `(cls: ClsService, req: Request) => void | Promise;`** + Function that executes after the CLS context has been initialised. It can be used to put additional variables in the CLS context. + +The `ClsMiddlewareOptions` additionally takes the following parameters: + +- **_`saveReq`_: `boolean`** (default _`true`_) + Whether to store the _Request_ object to the context. It will be available under the `CLS_REQ` key. +- **_`saveRes`_: `boolean`** (default _`false`_) + Whether to store the _Response_ object to the context. It will be available under the `CLS_RES` key +- **_`useEnterWith`_: `boolean`** (default _`false`_) + Set to `true` to set up the context using a call to [`AsyncLocalStorage#enterWith`](https://nodejs.org/api/async_context.html#async_context_asynclocalstorage_enterwith_store) instead of wrapping the `next()` call with the safer [`AsyncLocalStorage#run`](https://nodejs.org/api/async_context.html#async_context_asynclocalstorage_run_store_callback_args). Most of the time this should not be necessary, but [some frameworks](#graphql) are known to lose the context with `run`. # Request ID -Because of a shared storage, CLS is an ideal tool for tracking request (correlation) IDs for the purpose of logging. This package provides an option to automatically generate request IDs in the middleware/guard, if you pass `{ generateId: true }` to the middleware/guard options. By default, the generated ID is a string based on `Math.random()`, but you can provide a custom function in the `idGenerator` option. +Because of a shared storage, CLS is an ideal tool for tracking request (correlation) IDs for the purpose of logging. This package provides an option to automatically generate request IDs in the middleware/guard/interceptor, if you pass `{ generateId: true }` to its options. By default, the generated ID is a string based on `Math.random()`, but you can provide a custom function in the `idGenerator` option. -This function receives the `Request` (or `ExecutionContext` in case a `ClsGuard` is used) as the first parameter, which can be used in the generation process and should return a string ID that will be stored in the CLS for later use. +This function receives the `Request` (or `ExecutionContext` in case a `ClsGuard` is used) as the first parameter, which can be used in the generation process and should return (or resolve with) a string ID that will be stored in the CLS for later use. Below is an example of retrieving the request ID from the request header with a fallback to an autogenerated one. @@ -297,9 +321,9 @@ class MyService { # Additional CLS Setup -The CLS middleware/guard provide some default functionality, but sometimes you might want to store more things in the context by default. This can be of course done in a custom enhancer bound after, but for this scenario the `ClsMiddleware/ClsGuard` options expose the `setup` function, which will be executed in the middleware/guard after the CLS context is set up. +The CLS middleware/guard/interceptor provide some default functionality, but sometimes you might want to store more things in the context by default. This can be of course done in a custom enhancer bound after, but for this scenario the options expose the `setup` function, which will be executed in the middleware/guard after the CLS context is set up. -The function receives the `ClsService` and the `Request` (or `ExecutionContext`) object, and can be asynchronous. +The function receives the `ClsService` instance and the `Request` (or `ExecutionContext`) object, and can be asynchronous. ```ts ClsModule.register({ @@ -307,6 +331,7 @@ ClsModule.register({ mount: true, setup: (cls, req) => { // put some additional default info in the CLS + cls.set('TENANT_ID', req.params('tenant_id')); cls.set('AUTH', { authenticated: false }); }, }, @@ -325,7 +350,7 @@ function helper() { } ``` -> Please note: Only use this feature where absolutely necessary. Using this technique instead of dependency injection will make it difficult to mock the ClsService and your code will become harder to test. +> **Please note**: Only use this feature where absolutely necessary. Using this technique instead of dependency injection will make it difficult to mock the ClsService and your code will become harder to test. # Security considerations @@ -333,35 +358,51 @@ It is often discussed whether [`AsyncLocalStorage`](https://nodejs.org/api/async The `ClsMiddleware` by default uses the safe `run()` method, so it should not leak context, however, that only works for REST `Controllers`. -GraphQL `Resolvers`, cause the context to be lost and therefore require using the less safe `enterWith()` method. The same applies to using `ClsGuard` to set up the context, since there's no callback to wrap with the `run()` call (so the context would be not available outside of the guard otherwise). +GraphQL `Resolvers`, cause the context to be lost and therefore require using the less safe `enterWith()` method. The same applies to using `ClsGuard` to set up the context, since there's no callback to wrap with the `run()` call, the only way to set up context in a guard is to use `enterWith()` (the context would be not available outside of the guard otherwise). -**This has one consequence that should be taken into account:** +**This has a consequence that should be taken into account:** > When the `enterWith` method is used, any consequent requests _get access_ to the context of the previous one _until the request hits the `enterWith` call_. That means, when using `ClsMiddleware` with the `useEnterWith` option, or `ClsGuard` to set up context, be sure to mount them as early in the request lifetime as possible and do not use any other enhancers that rely on `ClsService` before them. For `ClsGuard`, that means you should probably manually mount it in `AppModule` if you require any other guard to run _after_ it. +The `ClsInterceptor` only uses the safe `run()` method. + # Compatibility considerations +The table below outlines the compatibility with some platforms: + +| | REST | GQL | Others | +| :----------------------------------------------------------: | :-------------------------------------------------: | :--------------------------------------------------------: | :----: | +| **ClsMiddleware** | ✔ | must be _mounted manually_
and use `useEnterWith: true` | ✖ | +| **ClsGuard**
(uses `enterWith`) | ✔ | ✔ | ? | +| **ClsInterceptor**
(context inaccessible
in _Guards_) | context also inaccessible
in _Exception Filters_ | ✔ | ? | + ## REST -This package is 100% compatible with Nest-supported REST controllers when you use the `ClsMiddleware` with the `mount` option. +This package is 100% compatible with Nest-supported REST controllers and the preferred way is to use the `ClsMiddleware` with the `mount` option. + +Tested with: - ✔ Express - ✔ Fastify ## GraphQL -For GraphQL, the `ClsMiddleware` needs to be [mounted manually](#manually-mounting-the-middleware) with `app.use(...)` in order to correctly set up the context for resolvers. Additionally, you have to pass `useEnterWith: true` to the `ClsMiddleware` options, because the context gets lost otherwise. +For GraphQL, the `ClsMiddleware` needs to be [mounted manually](#manually-mounting-the-middleware) with `app.use(...)` in order to correctly set up the context for resolvers. Additionally, you have to pass `useEnterWith: true` to the `ClsMiddleware` options, because the context gets lost otherwise due to [an issue with CLS and Apollo](https://github.com/apollographql/apollo-server/issues/2042) (sadly, the same is true for [Mercurius](https://github.com/Papooch/nestjs-cls/issues/1)). This method is functionally identical to just using the `ClsGuard`. + +Alternatively, you can use the `ClsInterceptor`, which uses the safer `AsyncLocalStorage#run` (thanks to [andreialecu](https://github.com/Papooch/nestjs-cls/issues/5)), but remember that using it makes CLS unavailable in _Guards_. -- ⚠ Apollo (Express) - - There's an [issue with CLS and Apollo](https://github.com/apollographql/apollo-server/issues/2042) talking about the context loss. -- ⚠ Mercurius (Fastify) - - The [same problem](https://github.com/Papooch/nestjs-cls/issues/1) applies here. +Tested with: + +- ✔ Apollo (Express) +- ✔ Mercurius (Fastify) ## Others -Use the `ClsGuard` to set up context with any other platform. This is still **experimental**, as there are no test and I can't guarantee it will work with your platform of choice. +Use the `ClsGuard` or `ClsInterceptor` to set up context with any other platform. This is still **experimental**, as there are no test and I can't guarantee it will work with your platform of choice. + +> If you decide to try this package with a platform that is not listed here, **please let me know** so I can add the compatibility notice. # Namespaces (experimental) @@ -369,7 +410,7 @@ Use the `ClsGuard` to set up context with any other platform. This is still **ex The default CLS namespace that the `ClsService` provides should be enough for most application, but should you need it, this package provides a way to use multiple CLS namespaces in order to be fully compatible with `cls-hooked`. -> Note: Since cls-hooked was ditched in version 1.2, it is no longer necessary to strive for compatibility with it. Still, the namespace support was there and there's no reason to remove it. +> **Note**: Since cls-hooked was ditched in version 1.2, it is no longer necessary to strive for compatibility with it. Still, the namespace support was there and there's no reason to remove it. To use custom namespace provider, use `ClsModule.forFeature('my-namespace')`. @@ -398,7 +439,7 @@ class HelloService { } ``` -> Note: `@InjectCls('x')` is equivalent to `@Inject(getNamespaceToken('x'))`. If you don't pass an argument to `@InjectCls()`, the default ClsService will be injected and is equivalent to omitting the decorator altogether. +> **Note**: `@InjectCls('x')` is equivalent to `@Inject(getNamespaceToken('x'))`. If you don't pass an argument to `@InjectCls()`, the default ClsService will be injected and is equivalent to omitting the decorator altogether. ```ts @Injectable() diff --git a/package.json b/package.json index 666157c5..c7259c0b 100644 --- a/package.json +++ b/package.json @@ -16,6 +16,7 @@ "cls", "continuation-local-storage", "als", + "AsyncLocalStorage", "async_hooks", "request context" ], diff --git a/src/index.ts b/src/index.ts index 0fd694e9..41a773e3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,7 @@ export * from './lib/cls-service-manager'; export * from './lib/cls.constants'; export * from './lib/cls.middleware'; +export * from './lib/cls.interceptor'; export * from './lib/cls.module'; export * from './lib/cls.service'; export * from './lib/cls.decorators'; diff --git a/src/lib/cls.constants.ts b/src/lib/cls.constants.ts index 6de1c869..19f66f05 100644 --- a/src/lib/cls.constants.ts +++ b/src/lib/cls.constants.ts @@ -4,3 +4,4 @@ export const CLS_ID = 'CLS_ID'; export const CLS_DEFAULT_NAMESPACE = 'CLS_DEFAULT_NAMESPACE'; export const CLS_MIDDLEWARE_OPTIONS = 'ClsMiddlewareOptions'; export const CLS_GUARD_OPTIONS = 'ClsGuardOptions'; +export const CLS_INTERCEPTOR_OPTIONS = 'ClsInterceptorOptions'; diff --git a/src/lib/cls.interceptor.ts b/src/lib/cls.interceptor.ts new file mode 100644 index 00000000..2380daa8 --- /dev/null +++ b/src/lib/cls.interceptor.ts @@ -0,0 +1,44 @@ +import { + CallHandler, + ExecutionContext, + Inject, + Injectable, + NestInterceptor, +} from '@nestjs/common'; +import { Observable } from 'rxjs'; +import { CLS_ID } from '..'; +import { ClsServiceManager } from './cls-service-manager'; +import { CLS_INTERCEPTOR_OPTIONS } from './cls.constants'; +import { ClsInterceptorOptions } from './cls.interfaces'; + +@Injectable() +export class ClsInterceptor implements NestInterceptor { + constructor( + @Inject(CLS_INTERCEPTOR_OPTIONS) + private readonly options?: Omit, + ) { + this.options = { ...new ClsInterceptorOptions(), ...options }; + } + + intercept(context: ExecutionContext, next: CallHandler): Observable { + const cls = ClsServiceManager.getClsService(this.options.namespaceName); + return new Observable((subscriber) => { + cls.run(async () => { + if (this.options.generateId) { + const id = await this.options.idGenerator(context); + cls.set(CLS_ID, id); + } + if (this.options.setup) { + await this.options.setup(cls, context); + } + next.handle() + .pipe() + .subscribe({ + next: (res) => subscriber.next(res), + error: (err) => subscriber.error(err), + complete: () => subscriber.complete(), + }); + }); + }); + } +} diff --git a/src/lib/cls.interfaces.ts b/src/lib/cls.interfaces.ts index cf3c161c..c78a4294 100644 --- a/src/lib/cls.interfaces.ts +++ b/src/lib/cls.interfaces.ts @@ -25,6 +25,11 @@ export class ClsModuleOptions { * Cls guard options */ guard?: ClsGuardOptions = null; + + /** + * Cls interceptor options + */ + interceptor?: ClsInterceptorOptions = null; } export class ClsMiddlewareOptions { @@ -103,3 +108,32 @@ export class ClsGuardOptions { readonly namespaceName?: string; } + +export class ClsInterceptorOptions { + /** + * whether to mount the interceptor globally + */ + mount?: boolean; // default false + + /** + * whether to automatically generate request ids + */ + generateId?: boolean; // default false + + /** + * the function to generate request ids inside the interceptor + */ + idGenerator?: (context: ExecutionContext) => string | Promise = + () => Math.random().toString(36).slice(-8); + + /** + * Function that executes after the CLS context has been initialised. + * It can be used to put additional variables in the CLS context. + */ + setup?: ( + cls: ClsService, + context: ExecutionContext, + ) => void | Promise; + + readonly namespaceName?: string; +} diff --git a/src/lib/cls.module.ts b/src/lib/cls.module.ts index ea21d851..913b30ef 100644 --- a/src/lib/cls.module.ts +++ b/src/lib/cls.module.ts @@ -6,12 +6,23 @@ import { NestModule, Provider, } from '@nestjs/common'; -import { APP_GUARD, HttpAdapterHost, ModuleRef } from '@nestjs/core'; +import { + APP_GUARD, + APP_INTERCEPTOR, + HttpAdapterHost, + ModuleRef, +} from '@nestjs/core'; +import { ClsInterceptor } from '..'; import { ClsServiceManager, getClsServiceToken } from './cls-service-manager'; -import { CLS_GUARD_OPTIONS, CLS_MIDDLEWARE_OPTIONS } from './cls.constants'; +import { + CLS_GUARD_OPTIONS, + CLS_INTERCEPTOR_OPTIONS, + CLS_MIDDLEWARE_OPTIONS, +} from './cls.constants'; import { ClsGuard } from './cls.guard'; import { ClsGuardOptions, + ClsInterceptorOptions, ClsMiddlewareOptions, ClsModuleOptions, } from './cls.interfaces'; @@ -80,6 +91,11 @@ export class ClsModule implements NestModule { ...options.guard, namespaceName: options.namespaceName, }; + const clsInterceptorOptions = { + ...new ClsInterceptorOptions(), + ...options.interceptor, + namespaceName: options.namespaceName, + }; const providers: Provider[] = [ ...ClsServiceManager.getClsServicesAsProviders(), { @@ -90,18 +106,28 @@ export class ClsModule implements NestModule { provide: CLS_GUARD_OPTIONS, useValue: clsGuardOptions, }, + { + provide: CLS_INTERCEPTOR_OPTIONS, + useValue: clsInterceptorOptions, + }, ]; - const guardArr = []; + const enhancerArr = []; if (clsGuardOptions.mount) { - guardArr.push({ + enhancerArr.push({ provide: APP_GUARD, useClass: ClsGuard, }); } + if (clsInterceptorOptions.mount) { + enhancerArr.push({ + provide: APP_INTERCEPTOR, + useClass: ClsInterceptor, + }); + } return { module: ClsModule, - providers: providers.concat(...guardArr), + providers: providers.concat(...enhancerArr), exports: providers, global: options.global, }; diff --git a/src/lib/cls.service.spec.ts b/src/lib/cls.service.spec.ts index 9d9733c5..cd953143 100644 --- a/src/lib/cls.service.spec.ts +++ b/src/lib/cls.service.spec.ts @@ -1,5 +1,6 @@ import { Test, TestingModule } from '@nestjs/testing'; -import { ClsServiceManager, CLS_DEFAULT_NAMESPACE } from '..'; +import { ClsServiceManager } from './cls-service-manager'; +import { CLS_DEFAULT_NAMESPACE } from './cls.constants'; import { ClsService } from './cls.service'; describe('ClsService', () => { diff --git a/src/lib/cls.service.ts b/src/lib/cls.service.ts index d008ca27..1ffa5daa 100644 --- a/src/lib/cls.service.ts +++ b/src/lib/cls.service.ts @@ -50,7 +50,7 @@ export class ClsService> { } /** - * Run the callback in a shared CLS context. + * Run the callback with a shared CLS context. * @param callback function to run * @returns whatever the callback returns */ @@ -59,12 +59,30 @@ export class ClsService> { } /** - * Run any following code in a shared CLS context. + * Run the callbacks with a shared CLS context. + * @param store the default context contents + * @param callback function to run + * @returns whatever the callback returns + */ + runWith(store: any, callback: () => T) { + return this.namespace.run(store ?? {}, callback); + } + + /** + * Run any following code with a shared CLS context. */ enter() { return this.namespace.enterWith({}); } + /** + * Run any following code with a shared ClS context + * @param store the default context contents + */ + enterWith(store: any = {}) { + return this.namespace.enterWith(store); + } + /** * Run the callback outside of a shared CLS context * @param callback function to run diff --git a/test/common/test.exception.ts b/test/common/test.exception.ts new file mode 100644 index 00000000..46eefc19 --- /dev/null +++ b/test/common/test.exception.ts @@ -0,0 +1,8 @@ +export class TestException extends Error { + public response: any; + public extensions: { a: 1 }; + constructor(response: any) { + super('TestException'); + this.response = response; + } +} diff --git a/test/common/test.guard.ts b/test/common/test.guard.ts index b3467b44..4668c1ff 100644 --- a/test/common/test.guard.ts +++ b/test/common/test.guard.ts @@ -9,7 +9,7 @@ export class TestGuard implements CanActivate { canActivate( context: ExecutionContext, ): boolean | Promise | Observable { - this.cls.set('FROM_GUARD', this.cls.getId()); - return this.cls.isActive(); + if (this.cls.isActive()) this.cls.set('FROM_GUARD', this.cls.getId()); + return true; } } diff --git a/test/gql/expect-ids-gql.ts b/test/gql/expect-ids-gql.ts index f6053bb0..8aec6c91 100644 --- a/test/gql/expect-ids-gql.ts +++ b/test/gql/expect-ids-gql.ts @@ -1,7 +1,10 @@ import { INestApplication } from '@nestjs/common'; import request from 'supertest'; -export const expectIdsGql = (app: INestApplication) => +export const expectOkIdsGql = ( + app: INestApplication, + options = { skipGuard: false }, +) => request(app.getHttpServer()) .post('/graphql') .send({ @@ -19,8 +22,36 @@ export const expectIdsGql = (app: INestApplication) => .expect(200) .then((r) => { const body = r.body.data?.items[0]; - const id = body.id; - expect(body.fromGuard).toEqual(id); + const id = body.id ?? 'no-id'; + if (!options.skipGuard) expect(body.fromGuard).toEqual(id); + expect(body.fromInterceptor).toEqual(id); + expect(body.fromInterceptorAfter).toEqual(id); + expect(body.fromResolver).toEqual(id); + expect(body.fromService).toEqual(id); + }); + +export const expectErrorIdsGql = ( + app: INestApplication, + options = { skipGuard: false }, +) => + request(app.getHttpServer()) + .post('/graphql') + .send({ + query: `query { + error { + id + fromGuard + fromInterceptor + fromInterceptorAfter + fromResolver + fromService + } + }`, + }) + .then((r) => { + const body = r.body.errors?.[0].extensions.exception?.response; + const id = body.id ?? 'no-id'; + if (!options.skipGuard) expect(body.fromGuard).toEqual(id); expect(body.fromInterceptor).toEqual(id); expect(body.fromInterceptorAfter).toEqual(id); expect(body.fromResolver).toEqual(id); diff --git a/test/gql/gql-apollo.spec.ts b/test/gql/gql-apollo.spec.ts index 2f067405..a8267771 100644 --- a/test/gql/gql-apollo.spec.ts +++ b/test/gql/gql-apollo.spec.ts @@ -1,7 +1,7 @@ import { INestApplication, Module } from '@nestjs/common'; import { Test, TestingModule } from '@nestjs/testing'; import { ClsMiddleware, ClsModule } from '../../src'; -import { expectIdsGql } from './expect-ids-gql'; +import { expectErrorIdsGql, expectOkIdsGql } from './expect-ids-gql'; import { ItemModule } from './item/item.module'; import { GraphQLModule } from '@nestjs/graphql'; @@ -29,18 +29,17 @@ describe('GQL Apollo App - Manually bound Middleware in Bootstrap', () => { await app.init(); }); - it('works with middleware', () => { - return expectIdsGql(app); + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('works with %s response', (_, func: any) => { + return func(app); }); - - it('does not leak context', () => { - return Promise.all([ - expectIdsGql(app), - expectIdsGql(app), - expectIdsGql(app), - expectIdsGql(app), - expectIdsGql(app), - ]); + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('does not leak context with %s response', (_, func: any) => { + return Promise.all(Array(10).fill(app).map(func)); }); }); @@ -67,17 +66,59 @@ describe('GQL Apollo App - Auto bound Guard', () => { await app.init(); }); - it('works with guard', () => { - return expectIdsGql(app); + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('works with %s response', (_, func: any) => { + return func(app); + }); + + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('does not leak context with %s response', (_, func: any) => { + return Promise.all(Array(10).fill(app).map(func)); + }); +}); + +describe('GQL Apollo App - Auto bound Interceptor', () => { + @Module({ + imports: [ + ClsModule.register({ + global: true, + interceptor: { mount: true, generateId: true }, + }), + ItemModule, + GraphQLModule.forRoot({ + autoSchemaFile: __dirname + 'schema.gql', + }), + ], + }) + class AppModule {} + + beforeAll(async () => { + const moduleFixture: TestingModule = await Test.createTestingModule({ + imports: [AppModule], + }).compile(); + app = moduleFixture.createNestApplication(); + await app.init(); + }); + + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('works with %s response', (_, func: any) => { + return func(app, { skipGuard: true }); }); - it('does not leak context', () => { - return Promise.all([ - expectIdsGql(app), - expectIdsGql(app), - expectIdsGql(app), - expectIdsGql(app), - expectIdsGql(app), - ]); + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('does not leak context with % response', (_, func: any) => { + return Promise.all( + Array(10) + .fill(0) + .map(() => func(app, { skipGuard: true })), + ); }); }); diff --git a/test/gql/gql-mercurius.spec.ts b/test/gql/gql-mercurius.spec.ts index ea8e3137..827f1d9b 100644 --- a/test/gql/gql-mercurius.spec.ts +++ b/test/gql/gql-mercurius.spec.ts @@ -4,7 +4,7 @@ import { NestFastifyApplication, } from '@nestjs/platform-fastify'; import { NestFactory } from '@nestjs/core'; -import { expectIdsGql } from './expect-ids-gql'; +import { expectErrorIdsGql, expectOkIdsGql } from './expect-ids-gql'; import { Module } from '@nestjs/common'; import { ItemModule } from './item/item.module'; import { MercuriusModule } from 'nestjs-mercurius'; @@ -35,17 +35,100 @@ describe('GQL Mercurius App - Manually bound Middleware in Bootstrap', () => { await app.getHttpAdapter().getInstance().ready(); }); - it('works with middleware', async () => { - return expectIdsGql(app); + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('works with %s response', (_, func: any) => { + return func(app); }); + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('does not leak context with %s response', (_, func: any) => { + return Promise.all(Array(10).fill(app).map(func)); + }); +}); + +describe('GQL Mercurius App - Auto bound Guard', () => { + @Module({ + imports: [ + ClsModule.register({ + global: true, + guard: { mount: true, generateId: true }, + }), + ItemModule, + MercuriusModule.forRoot({ + autoSchemaFile: __dirname + 'schema.gql', + }), + ], + }) + class AppModule {} + + beforeAll(async () => { + app = await NestFactory.create( + AppModule, + new FastifyAdapter(), + { logger: false }, + ); + await app.init(); + await app.getHttpAdapter().getInstance().ready(); + }); + + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('works with %s response', (_, func: any) => { + return func(app); + }); + + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('does not leak context with %s response', (_, func: any) => { + return Promise.all(Array(10).fill(app).map(func)); + }); +}); + +describe('GQL Mercurius App - Auto bound Interceptor', () => { + @Module({ + imports: [ + ClsModule.register({ + global: true, + interceptor: { mount: true, generateId: true }, + }), + ItemModule, + MercuriusModule.forRoot({ + autoSchemaFile: __dirname + 'schema.gql', + }), + ], + }) + class AppModule {} - it('does not leak context', () => { - return Promise.all([ - expectIdsGql(app), - expectIdsGql(app), - expectIdsGql(app), - expectIdsGql(app), - expectIdsGql(app), - ]); + beforeAll(async () => { + app = await NestFactory.create( + AppModule, + new FastifyAdapter(), + { logger: false }, + ); + await app.init(); + await app.getHttpAdapter().getInstance().ready(); + }); + + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('works with %s response', (_, func: any) => { + return func(app, { skipGuard: true }); + }); + + it.each([ + ['OK', expectOkIdsGql], + ['ERROR', expectErrorIdsGql], + ])('does not leak context with % response', (_, func: any) => { + return Promise.all( + Array(10) + .fill(0) + .map(() => func(app, { skipGuard: true })), + ); }); }); diff --git a/test/gql/item/item.model.ts b/test/gql/item/item.model.ts index aae9d15d..5f2fa336 100644 --- a/test/gql/item/item.model.ts +++ b/test/gql/item/item.model.ts @@ -2,21 +2,21 @@ import { Field, ID, ObjectType } from '@nestjs/graphql'; @ObjectType({ description: 'item ' }) export class Item { - @Field((type) => ID) + @Field(() => ID) id: string; - @Field() + @Field({ nullable: true }) fromGuard?: string; - @Field() + @Field({ nullable: true }) fromInterceptor?: string; - @Field() + @Field({ nullable: true }) fromInterceptorAfter?: string; - @Field() + @Field({ nullable: true }) fromResolver?: string; - @Field() + @Field({ nullable: true }) fromService?: string; } diff --git a/test/gql/item/item.resolver.ts b/test/gql/item/item.resolver.ts index f4efc05f..16f1b052 100644 --- a/test/gql/item/item.resolver.ts +++ b/test/gql/item/item.resolver.ts @@ -1,14 +1,17 @@ -import { NotFoundException, UseGuards, UseInterceptors } from '@nestjs/common'; -import { Args, Mutation, Query, Resolver, Subscription } from '@nestjs/graphql'; +import { UseFilters, UseGuards, UseInterceptors } from '@nestjs/common'; +import { Args, Query, Resolver } from '@nestjs/graphql'; import { ClsService } from '../../../src'; +import { TestException } from '../../common/test.exception'; import { TestGuard } from '../../common/test.guard'; import { TestInterceptor } from '../../common/test.interceptor'; import { RecipesArgs } from './dto/recipes.args'; import { Item } from './item.model'; import { ItemService } from './item.service'; +import { TestGqlExceptionFilter } from '../test-gql.filter'; +@UseFilters(TestGqlExceptionFilter) @UseGuards(TestGuard) -@Resolver((of) => Item) +@Resolver(() => Item) export class ItemResolver { constructor( private readonly recipesService: ItemService, @@ -21,4 +24,12 @@ export class ItemResolver { this.cls.set('FROM_RESOLVER', this.cls.getId()); return this.recipesService.findAll(recipesArgs); } + + @UseInterceptors(TestInterceptor) + @Query(() => [Item]) + async error(@Args() recipesArgs: RecipesArgs): Promise { + this.cls.set('FROM_RESOLVER', this.cls.getId()); + const response = await this.recipesService.findAll(recipesArgs); + throw new TestException(response[0]); + } } diff --git a/test/gql/item/item.service.ts b/test/gql/item/item.service.ts index 3ac63a40..772cfff4 100644 --- a/test/gql/item/item.service.ts +++ b/test/gql/item/item.service.ts @@ -8,7 +8,7 @@ export class ItemService { constructor(private readonly cls: ClsService) {} async findAll(recipesArgs: RecipesArgs): Promise { - return [ + const payload = [ { id: this.cls.getId(), fromGuard: this.cls.get('FROM_GUARD'), @@ -18,5 +18,6 @@ export class ItemService { fromService: this.cls.getId(), }, ]; + return payload; } } diff --git a/test/gql/test-gql.filter.ts b/test/gql/test-gql.filter.ts new file mode 100644 index 00000000..9746cc22 --- /dev/null +++ b/test/gql/test-gql.filter.ts @@ -0,0 +1,31 @@ +import { ArgumentsHost, Catch, ExceptionFilter } from '@nestjs/common'; +import { HttpAdapterHost } from '@nestjs/core'; +import mercurius from 'mercurius'; +import { ClsService } from '../../src'; +import { TestException } from '../common/test.exception'; + +@Catch(TestException) +export class TestGqlExceptionFilter implements ExceptionFilter { + constructor( + private readonly adapterHost: HttpAdapterHost, + private readonly cls: ClsService, + ) {} + + catch(exception: TestException, host: ArgumentsHost) { + const adapter = this.adapterHost.httpAdapter; + + exception.response.fromFilter = this.cls.getId(); + + if (adapter.constructor.name === 'FastifyAdapter') { + throw new mercurius.ErrorWithProps('AAA', { + exception: { + response: { + ...exception.response, + }, + }, + }); + } + + return exception; + } +} diff --git a/test/rest/expect-ids-rest.ts b/test/rest/expect-ids-rest.ts index 0a902e88..22903907 100644 --- a/test/rest/expect-ids-rest.ts +++ b/test/rest/expect-ids-rest.ts @@ -1,15 +1,28 @@ import { INestApplication } from '@nestjs/common'; import request from 'supertest'; -export const expectIdsRest = (app: INestApplication) => +export const expectOkIdsRest = (app: INestApplication) => request(app.getHttpServer()) .get('/hello') .expect(200) .then((r) => { const body = r.body; - const id = body.fromGuard; + const id = body.fromGuard ?? body.fromInterceptor; expect(body.fromInterceptor).toEqual(id); expect(body.fromInterceptorAfter).toEqual(id); expect(body.fromController).toEqual(id); expect(body.fromService).toEqual(id); }); + +export const expectErrorIdsRest = (app: INestApplication) => + request(app.getHttpServer()) + .get('/error') + .expect(500) + .then((r) => { + const body = r.body; + const id = body.fromGuard ?? body.fromInterceptor; + expect(body.fromInterceptor).toEqual(id); + expect(body.fromController).toEqual(id); + expect(body.fromService).toEqual(id); + expect(body.fromFilter).toEqual(id); + }); diff --git a/test/rest/http-express.spec.ts b/test/rest/http-express.spec.ts index 3823d981..41e53169 100644 --- a/test/rest/http-express.spec.ts +++ b/test/rest/http-express.spec.ts @@ -6,7 +6,7 @@ import { } from '@nestjs/common'; import { Test, TestingModule } from '@nestjs/testing'; import { ClsMiddleware, ClsModule } from '../../src'; -import { expectIdsRest } from './expect-ids-rest'; +import { expectErrorIdsRest, expectOkIdsRest } from './expect-ids-rest'; import { TestHttpController, TestHttpService } from './http.app'; let app: INestApplication; @@ -30,8 +30,11 @@ describe('Http Express App - Auto bound Middleware', () => { await app.init(); }); - it('works with middleware', () => { - return expectIdsRest(app); + it.each([ + ['OK', expectOkIdsRest], + ['ERROR', expectErrorIdsRest], + ])('works with %s response', (_, func: any) => { + return func(app); }); }); @@ -55,8 +58,11 @@ describe('Http Express App - Manually bound Middleware in AppModule', () => { await app.init(); }); - it('works with middleware', () => { - return expectIdsRest(app); + it.each([ + ['OK', expectOkIdsRest], + ['ERROR', expectErrorIdsRest], + ])('works with %s response', (_, func: any) => { + return func(app); }); }); describe('Http Express App - Manually bound Middleware in Bootstrap', () => { @@ -76,8 +82,11 @@ describe('Http Express App - Manually bound Middleware in Bootstrap', () => { await app.init(); }); - it('works with middleware', () => { - return expectIdsRest(app); + it.each([ + ['OK', expectOkIdsRest], + ['ERROR', expectErrorIdsRest], + ])('works with %s response', (_, func: any) => { + return func(app); }); }); @@ -101,17 +110,45 @@ describe('Http Express App - Auto bound Guard', () => { await app.init(); }); - it('works with guard', () => { - return expectIdsRest(app); + it.each([ + ['OK', expectOkIdsRest], + ['ERROR', expectErrorIdsRest], + ])('works with %s response', (_, func: any) => { + return func(app); + }); + + it.each([ + ['OK', expectOkIdsRest], + ['ERROR', expectErrorIdsRest], + ])('does not leak context with %s response', (_, func: any) => { + return Promise.all(Array(10).fill(app).map(func)); + }); +}); +describe('Http Express App - Auto bound Interceptor', () => { + @Module({ + imports: [ + ClsModule.register({ + interceptor: { mount: true, generateId: true }, + }), + ], + providers: [TestHttpService], + controllers: [TestHttpController], + }) + class TestAppWithAutoBoundInterceptor {} + + beforeAll(async () => { + const moduleFixture: TestingModule = await Test.createTestingModule({ + imports: [TestAppWithAutoBoundInterceptor], + }).compile(); + app = moduleFixture.createNestApplication(); + await app.init(); + }); + + it('works with OK response', () => { + return expectOkIdsRest(app); }); it('does not leak context', () => { - return Promise.all([ - expectIdsRest(app), - expectIdsRest(app), - expectIdsRest(app), - expectIdsRest(app), - expectIdsRest(app), - ]); + return Promise.all(Array(10).fill(app).map(expectOkIdsRest)); }); }); diff --git a/test/rest/http-fastify.spec.ts b/test/rest/http-fastify.spec.ts index 5988911f..115322ea 100644 --- a/test/rest/http-fastify.spec.ts +++ b/test/rest/http-fastify.spec.ts @@ -5,7 +5,7 @@ import { } from '@nestjs/platform-fastify'; import { Test, TestingModule } from '@nestjs/testing'; import { ClsModule } from '../../src'; -import { expectIdsRest } from './expect-ids-rest'; +import { expectOkIdsRest } from './expect-ids-rest'; import { TestHttpController, TestHttpService } from './http.app'; @Module({ @@ -35,6 +35,6 @@ describe('Http Fastify App', () => { }); it('works with Fastify', async () => { - return expectIdsRest(app); + return expectOkIdsRest(app); }); }); diff --git a/test/rest/http.app.ts b/test/rest/http.app.ts index 97c2541f..e76e19f9 100644 --- a/test/rest/http.app.ts +++ b/test/rest/http.app.ts @@ -2,11 +2,13 @@ import { Controller, Get, Injectable, + UseFilters, UseGuards, UseInterceptors, } from '@nestjs/common'; -import { identity } from 'rxjs'; import { ClsService, CLS_ID } from '../../src'; +import { TestException } from '../common/test.exception'; +import { TestRestExceptionFilter } from './test-rest.filter'; import { TestGuard } from '../common/test.guard'; import { TestInterceptor } from '../common/test.interceptor'; @@ -26,6 +28,7 @@ export class TestHttpService { } @UseGuards(TestGuard) +@UseFilters(TestRestExceptionFilter) @Controller('/') export class TestHttpController { constructor( @@ -39,4 +42,12 @@ export class TestHttpController { this.cls.set('FROM_CONTROLLER', this.cls.getId()); return this.service.hello(); } + + @UseInterceptors(TestInterceptor) + @Get('error') + async error() { + this.cls.set('FROM_CONTROLLER', this.cls.getId()); + const response = await this.service.hello(); + throw new TestException(response); + } } diff --git a/test/rest/main-express.ts b/test/rest/main-express.ts index a36224de..9e4128d1 100644 --- a/test/rest/main-express.ts +++ b/test/rest/main-express.ts @@ -6,7 +6,7 @@ import { TestHttpController, TestHttpService } from './http.app'; @Module({ imports: [ ClsModule.register({ - middleware: { mount: true, generateId: true, useEnterWith: true }, + middleware: { mount: true, generateId: true }, }), ], providers: [TestHttpService], diff --git a/test/rest/test-rest.filter.ts b/test/rest/test-rest.filter.ts new file mode 100644 index 00000000..e0f4380f --- /dev/null +++ b/test/rest/test-rest.filter.ts @@ -0,0 +1,19 @@ +import { ArgumentsHost, Catch, ExceptionFilter } from '@nestjs/common'; +import { Response } from 'express'; +import { ClsService } from '../../src'; +import { TestException } from '../common/test.exception'; + +@Catch(TestException) +export class TestRestExceptionFilter implements ExceptionFilter { + constructor(private readonly cls: ClsService) {} + + catch(exception: TestException, host: ArgumentsHost) { + const ctx = host.switchToHttp(); + const response = ctx.getResponse(); + + response.status(500).json({ + ...exception.response, + fromFilter: this.cls.getId(), + }); + } +}