Skip to content

Commit

Permalink
feat: Implement basic support for websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
kkoomen authored and jmcdo29 committed Jun 12, 2020
1 parent 671abbd commit 3a0cf2e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
7 changes: 7 additions & 0 deletions src/throttler.exception.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import { HttpException, HttpStatus } from '@nestjs/common';
import { WsException } from '@nestjs/websockets';

export class ThrottlerException extends HttpException {
constructor() {
super('ThrottlerException: Too Many Requests', HttpStatus.TOO_MANY_REQUESTS);
}
}

export class ThrottlerWsException extends WsException {
constructor() {
super('ThrottlerWsException: Too Many Requests');
}
}
37 changes: 32 additions & 5 deletions src/throttler.guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import {
THROTTLER_LIMIT,
THROTTLER_OPTIONS,
THROTTLER_SKIP,
THROTTLER_TTL,
THROTTLER_TTL
} from './throttler.constants';
import { ThrottlerException } from './throttler.exception';
import { ThrottlerException, ThrottlerWsException } from './throttler.exception';
import { ThrottlerOptions } from './throttler.interface';

@Injectable()
Expand All @@ -23,7 +23,6 @@ export class ThrottlerGuard implements CanActivate {
const req = context.switchToHttp().getRequest();
const handler = context.getHandler();
const classRef = context.getClass();
const headerPrefix = 'X-RateLimit';

// Return early if the current route should be skipped.
if (this.reflector.getAllAndOverride<boolean>(THROTTLER_SKIP, [handler, classRef])) {
Expand Down Expand Up @@ -53,11 +52,21 @@ export class ThrottlerGuard implements CanActivate {
const limit = routeOrClassLimit || this.options.limit;
const ttl = routeOrClassTtl || this.options.ttl;

switch (context.getType()) {
case 'http': return this.httpHandler(context, limit, ttl);
case 'ws': return this.websocketHandler(context, limit, ttl);
}
}

private httpHandler(context: ExecutionContext, limit: number, ttl: number): boolean {
const headerPrefix = 'X-RateLimit';

// Here we start to check the amount of requests being done against the ttl.
const res = context.switchToHttp().getResponse();
const key = md5(`${req.ip}-${classRef.name}-${handler.name}`);
const key = this.generateKey(context, req.ip);
const ttls = await this.storageService.getRecord(key);
const nearestExpiryTime = ttls.length > 0 ? Math.ceil((ttls[0] - Date.now()) / 1000) : 0;
const nearestExpiryTime =
ttls.length > 0 ? Math.ceil((ttls[0].getTime() - new Date().getTime()) / 1000) : 0;

// Throw an error when the user reached their limit.
if (ttls.length >= limit) {
Expand All @@ -74,4 +83,22 @@ export class ThrottlerGuard implements CanActivate {
await this.storageService.addRecord(key, ttl);
return true;
}

private websocketHandler(context: ExecutionContext, limit: number, ttl: number): boolean {
const client = context.switchToWs().getClient();
const key = this.generateKey(context, client.conn.remoteAddress);
const ttls = this.storageService.getRecord(key);

if (ttls.length >= limit) {
throw new ThrottlerWsException();
}

this.storageService.addRecord(key, ttl);
return true;
}

private generateKey(context: ExecutionContext, prefix: string): string {
const suffix = `${context.getClass().name}-${context.getHandler().name}`;
return md5(`${prefix}-${suffix}`)
}
}
9 changes: 7 additions & 2 deletions test/app/gateways/app.gateway.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
import { WebSocketGateway, SubscribeMessage } from '@nestjs/websockets';
import { UseGuards } from '@nestjs/common';
import { SubscribeMessage, WebSocketGateway } from '@nestjs/websockets';
import { SkipThrottle, Throttle, ThrottlerGuard } from '../../../src';
import { AppService } from '../app.service';
import { SkipThrottle, Throttle } from '../../../src';

@Throttle(2, 10)
@WebSocketGateway({ path: '/' })
export class AppGateway {
constructor(private readonly appService: AppService) {}

@UseGuards(ThrottlerGuard)
@SubscribeMessage('throttle-regular')
pass() {
return this.appService.success();
}

@SkipThrottle()
@UseGuards(ThrottlerGuard)
@SubscribeMessage('ignore')
ignore() {
return this.appService.ignored();
}

@Throttle(5, 20)
@UseGuards(ThrottlerGuard)
@SubscribeMessage('throttle-override')
throttleOverride() {
return this.appService.success();
Expand Down
7 changes: 4 additions & 3 deletions test/app/main.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { NestFactory } from '@nestjs/core';
import { FastifyAdapter } from '@nestjs/platform-fastify';
import { ExpressAdapter } from '@nestjs/platform-express';
// import { FastifyAdapter } from '@nestjs/platform-fastify';
import { WsAdapter } from '@nestjs/platform-ws';
import { AppModule } from './app.module';

async function bootstrap() {
const app = await NestFactory.create(
AppModule,
// new ExpressAdapter(),
new FastifyAdapter(),
new ExpressAdapter(),
// new FastifyAdapter(),
);
app.useWebSocketAdapter(new WsAdapter(app));
await app.listen(3000);
Expand Down

0 comments on commit 3a0cf2e

Please sign in to comment.