Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example with custom authorizer lambda #154

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/apigateway-auth/Pulumi.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
name: apigateway-auth
runtime: nodejs
description: An API Gateway with a custom lambda authorizer
67 changes: 67 additions & 0 deletions examples/apigateway-auth/auth-lambda.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import * as awslambda from "aws-lambda";

type AuthorizerLambda = (event: awslambda.APIGatewayAuthorizerEvent) => Promise<awslambda.APIGatewayAuthorizerResult>

export function authorizerLambda(): AuthorizerLambda {
return async (event: awslambda.APIGatewayAuthorizerEvent) => {
try {
return await authenticate(event);
}
catch (err) {
console.log(err);
// Tells API Gateway to return a 401 Unauthorized response
throw new Error("Unauthorized");
}
}
}

// Extract and return the Bearer Token from the Lambda event parameters
function getToken(event: awslambda.APIGatewayAuthorizerEvent): string | undefined {
if (!event.type || event.type !== "TOKEN") {
throw new Error('Expected "event.type" parameter to have value "TOKEN"');
}

const tokenString = event.authorizationToken;
if (!tokenString) {
return undefined;
}

const match = tokenString.match(/^Bearer (.*)$/);
if (!match) {
// Invalid Authorization token - does not match "Bearer .*"
return undefined;
}
return match[1];
}

// Check the Token is valid
async function authenticate(event: awslambda.APIGatewayAuthorizerEvent): Promise<awslambda.APIGatewayAuthorizerResult> {
console.log(event);
const token = getToken(event);

// Dummy check for token, in a real-world scenario, you would verify the token
const effect = token ? "Allow" : "Deny";

const methodArn = getMethodArn(event);
console.log(`Method ARN: ${methodArn}`);
return {
principalId: "me",
policyDocument: {
Version: "2012-10-17",
Statement: [{
Action: "execute-api:Invoke",
Effect: effect,
Resource: methodArn,
}],
},
};
}

function getMethodArn(event: awslambda.APIGatewayAuthorizerEvent): string {
if (!event.methodArn) {
throw new Error('Expected "event.methodArn" parameter to be set');
}

const arnPartials = event.methodArn.split("/");
return arnPartials.slice(0, 2).join("/") + "/*";
}
38 changes: 38 additions & 0 deletions examples/apigateway-auth/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import * as aws from "@pulumi/aws";
import { authorizerLambda } from "./auth-lambda";
import * as apigateway from "@pulumi/aws-apigateway";

const f = new aws.lambda.CallbackFunction("f", {
callback: async (ev, ctx) => {
console.log(JSON.stringify(ev));
return {
statusCode: 200,
body: "Hello, World!",
};
},
});

const authorizer = {
authType: "custom",
authorizerName: "jwt-rsa-custom-authorizer",
parameterName: "Authorization",
identityValidationExpression: "^Bearer [-0-9a-zA-Z\._]*$",
type: "token",
parameterLocation: "header",
authorizerResultTtlInSeconds: 300,
handler: new aws.lambda.CallbackFunction("authorizer", {
callback: authorizerLambda(),
}),
}

const api = new apigateway.RestAPI("my-api", {
routes: [{
path: "/{proxy+}",
method: "ANY",
eventHandler: f,
authorizers: [authorizer]
}],
binaryMediaTypes: ["application/json"],
});

export const url = api.url;
14 changes: 14 additions & 0 deletions examples/apigateway-auth/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"name": "apigateway-auth",
"main": "index.ts",
"devDependencies": {
"@types/node": "^18",
"typescript": "^5.0.0",
"@types/aws-lambda": "^8.10.0"
},
"dependencies": {
"@pulumi/aws": "^6.0.0",
"@pulumi/pulumi": "^3.113.0",
"@pulumi/aws-apigateway": "^2.5.0"
}
}
18 changes: 18 additions & 0 deletions examples/apigateway-auth/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"compilerOptions": {
"strict": true,
"outDir": "bin",
"target": "es2020",
"module": "commonjs",
"moduleResolution": "node",
"sourceMap": true,
"experimentalDecorators": true,
"pretty": true,
"noFallthroughCasesInSwitch": true,
"noImplicitReturns": true,
"forceConsistentCasingInFileNames": true
},
"files": [
"index.ts"
]
}
25 changes: 24 additions & 1 deletion examples/examples_nodejs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"path"
"path/filepath"
"testing"
"time"

"github.com/pulumi/providertest/pulumitest"
"github.com/pulumi/pulumi/pkg/v3/testing/integration"
Expand Down Expand Up @@ -73,7 +74,7 @@ func TestTagging(t *testing.T) {
ExtraRuntimeValidation: func(t *testing.T, stackInfo integration.RuntimeValidationStackInfo) {
expectedTags := map[string]interface{}{
"environment": "development",
"test": "test-tag",
"test": "test-tag",
}
assert.Equal(t, expectedTags, stackInfo.Outputs["apiTags"])
assert.Equal(t, expectedTags, stackInfo.Outputs["stageTags"])
Expand All @@ -83,6 +84,28 @@ func TestTagging(t *testing.T) {
integration.ProgramTest(t, &test)
}

func TestAuth(t *testing.T) {
test := getJSBaseOptions(t).
With(integration.ProgramTestOptions{
Dir: filepath.Join(getCwd(t), "apigateway-auth"),
ExtraRuntimeValidation: func(t *testing.T, stackInfo integration.RuntimeValidationStackInfo) {
url := stackInfo.Outputs["url"].(string) + "test"

validAuthHeaders := map[string]string{"Authorization": "Bearer DUMMY_TOKEN"}

// Make a request to the API Gateway endpoint with an auth token to verify it's working
integration.AssertHTTPResultWithRetry(t, url, validAuthHeaders, 60*time.Second, func(body string) bool {
return assert.Equal(t, "Hello, World!", body, "Body should equal 'Hello, World!', got %s", body)
})

// Make a request to the API Gateway endpoint without an auth token and expect a 401 to verify the authorizer is working
retryGETRequestUntil(t, url, nil, 401, 60*time.Second)
},
})

integration.ProgramTest(t, &test)
}

func getJSBaseOptions(t *testing.T) integration.ProgramTestOptions {
base := getBaseOptions(t)
baseJS := base.With(integration.ProgramTestOptions{
Expand Down
43 changes: 43 additions & 0 deletions examples/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
package examples

import (
"context"
"net/http"
"os"
"strings"
"testing"
"time"

"github.com/pulumi/pulumi/pkg/v3/testing/integration"
"github.com/pulumi/pulumi/sdk/v3/go/common/util/retry"
"github.com/stretchr/testify/assert"
)

func getRegion(t *testing.T) string {
Expand Down Expand Up @@ -46,3 +52,40 @@ func skipIfShort(t *testing.T) {
t.Skip("skipping long-running test in short mode")
}
}

func retryGETRequestUntil(t *testing.T, url string, headers map[string]string, expectedStatusCode int, timeout time.Duration) {
_, finalStatusCode, err := retry.UntilTimeout(context.TODO(), retry.Acceptor{
Accept: func(try int, delay time.Duration) (bool, interface{}, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return false, nil, err
}

for k, v := range headers {
// Host header cannot be set via req.Header.Set(), and must be set
// directly.
if strings.ToLower(k) == "host" {
req.Host = v
continue
}
req.Header.Set(k, v)
}

client := &http.Client{Timeout: time.Second * 10}
resp, err := client.Do(req)
assert.NoError(t, err, "error reading response: %v", err)
if resp.Body != nil {
defer resp.Body.Close()
}

if err != nil {
t.Logf("Http Error: %v\n", err)
return false, nil, nil
}

return resp.StatusCode == expectedStatusCode, resp.StatusCode, nil
},
}, timeout)
assert.NoError(t, err, "error retrying request: %v", err)
assert.Equal(t, expectedStatusCode, finalStatusCode, "expected status code %d, got %d", expectedStatusCode, finalStatusCode)
}
Loading