diff --git a/providers/aws/dynamodb/tables.go b/providers/aws/dynamodb/tables.go index 284c5b262..b94a528b5 100644 --- a/providers/aws/dynamodb/tables.go +++ b/providers/aws/dynamodb/tables.go @@ -7,17 +7,54 @@ import ( log "github.com/sirupsen/logrus" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/pricing" + "github.com/aws/aws-sdk-go-v2/service/pricing/types" "github.com/aws/aws-sdk-go-v2/service/sts" + awsUtils "github.com/tailwarden/komiser/providers/aws/utils" . "github.com/tailwarden/komiser/models" . "github.com/tailwarden/komiser/providers" ) + func Tables(ctx context.Context, client ProviderClient) ([]Resource, error) { resources := make([]Resource, 0) var config dynamodb.ListTablesInput dynamodbClient := dynamodb.NewFromConfig(*client.AWSClient) + + var monthlyCost float64 = 0.0 + // there is something strange going on when using pricing client with regions other than us-east-1 + // https://discord.com/channels/932683789384183808/1117721764957536318/1162338171435090032 + oldRegion := client.AWSClient.Region + client.AWSClient.Region = "us-east-1" + pricingClient := pricing.NewFromConfig(*client.AWSClient) + client.AWSClient.Region = oldRegion + + pricingOutput, err := pricingClient.GetProducts(ctx, &pricing.GetProductsInput{ + ServiceCode: aws.String("AmazonDynamoDB"), + Filters: []types.Filter{ + { + Field: aws.String("regionCode"), + Value: aws.String(client.AWSClient.Region), + Type: types.FilterTypeTermMatch, + }, + }, + }) + + if err != nil { + log.Errorf("ERROR: Couldn't fetch pricing info for AWS DynamoDB: %v", err) + } + + priceMap, err := awsUtils.GetPriceMap(pricingOutput, "group") + + if err != nil { + log.Errorf("ERROR: Failed to fetch pricing map: %v", err) + } + + output, err := dynamodbClient.ListTables(ctx, &config) + if err != nil { return resources, err } @@ -47,6 +84,24 @@ func Tables(ctx context.Context, client ProviderClient) ([]Resource, error) { } } + tableDetails, err := dynamodbClient.DescribeTable(ctx, &dynamodb.DescribeTableInput{ + TableName: aws.String(table), + }) + + if err != nil { + log.Errorf("ERROR: Failed to query DynamoDB table details: %v", err) + } + + if tableDetails.Table != nil && tableDetails.Table.ProvisionedThroughput != nil { + provisionedRCUs := tableDetails.Table.ProvisionedThroughput.ReadCapacityUnits + provisionedWCUs := tableDetails.Table.ProvisionedThroughput.WriteCapacityUnits + + RCUCharges := awsUtils.GetCost(priceMap["DDB-ReadUnits"], awsUtils.Int64PtrToFloat64(provisionedRCUs)) + WCUCharges := awsUtils.GetCost(priceMap["DDB-WriteUnits"], awsUtils.Int64PtrToFloat64(provisionedWCUs)) + + monthlyCost = RCUCharges + WCUCharges + } + resources = append(resources, Resource{ Provider: "AWS", Account: client.Name, @@ -54,7 +109,7 @@ func Tables(ctx context.Context, client ProviderClient) ([]Resource, error) { ResourceId: resourceArn, Region: client.AWSClient.Region, Name: table, - Cost: 0, + Cost: monthlyCost, Tags: tags, FetchedAt: time.Now(), Link: fmt.Sprintf("https://%s.console.aws.amazon.com/dynamodbv2/home?region=%s#table?initialTagKey=&name=%s", client.AWSClient.Region, client.AWSClient.Region, table), diff --git a/providers/aws/utils/utils.go b/providers/aws/utils/utils.go index 529c00bcb..5ed566701 100644 --- a/providers/aws/utils/utils.go +++ b/providers/aws/utils/utils.go @@ -91,3 +91,10 @@ func GetPriceMap(pricingOutput *pricing.GetProductsOutput, field string) (map[st return priceMap, nil } + +func Int64PtrToFloat64(i *int64) float64 { + if i == nil { + return 0.0 // or any default value you prefer + } + return float64(*i) +} diff --git a/providers/aws/utils/utils_test.go b/providers/aws/utils/utils_test.go index bf2f6fb39..3b53d3268 100644 --- a/providers/aws/utils/utils_test.go +++ b/providers/aws/utils/utils_test.go @@ -234,3 +234,24 @@ func TestGetPriceMap_NoPricingOutput(t *testing.T) { t.Errorf("Expected an empty priceMap, but got %v", priceMap) } } + +func TestInt64PtrToFloat64_ValidInput(t *testing.T) { + var number int64 = 1 + pointer := &number + + returnValue := Int64PtrToFloat64(pointer) + var expected float64 = 1.0 + if returnValue != expected { + t.Errorf("Expected return value: %f, but got: %f", expected, returnValue) + } +} + +func TestInt64PtrToFloat64_NilInput(t *testing.T) { + // nil input + returnValue := Int64PtrToFloat64(nil) + var expected float64 = 0.0 + if returnValue != expected { + t.Errorf("Expected return value: %f, but got: %f", expected, returnValue) + } +} +