diff --git a/apps/sim/lib/billing/webhooks/invoices.ts b/apps/sim/lib/billing/webhooks/invoices.ts index 37c17f7e31..5c95b4be81 100644 --- a/apps/sim/lib/billing/webhooks/invoices.ts +++ b/apps/sim/lib/billing/webhooks/invoices.ts @@ -1,7 +1,7 @@ import { render } from '@react-email/components' import { db } from '@sim/db' import { member, subscription as subscriptionTable, user, userStats } from '@sim/db/schema' -import { eq, inArray } from 'drizzle-orm' +import { and, eq, inArray } from 'drizzle-orm' import type Stripe from 'stripe' import PaymentFailedEmail from '@/components/emails/billing/payment-failed-email' import { calculateSubscriptionOverage } from '@/lib/billing/core/billing' @@ -226,42 +226,38 @@ export async function getBilledOverageForSubscription(sub: { plan: string | null referenceId: string }): Promise { - let billedOverage = 0 - if (sub.plan === 'team') { - const members = await db + const ownerRows = await db .select({ userId: member.userId }) .from(member) - .where(eq(member.organizationId, sub.referenceId)) - - const memberIds = members.map((m) => m.userId) + .where(and(eq(member.organizationId, sub.referenceId), eq(member.role, 'owner'))) + .limit(1) - if (memberIds.length > 0) { - const memberStatsRows = await db - .select({ - userId: userStats.userId, - billedOverageThisPeriod: userStats.billedOverageThisPeriod, - }) - .from(userStats) - .where(inArray(userStats.userId, memberIds)) + const ownerId = ownerRows[0]?.userId - for (const stats of memberStatsRows) { - billedOverage += parseDecimal(stats.billedOverageThisPeriod) - } + if (!ownerId) { + logger.warn('Organization has no owner when fetching billed overage', { + organizationId: sub.referenceId, + }) + return 0 } - } else { - const userStatsRecords = await db + + const ownerStats = await db .select({ billedOverageThisPeriod: userStats.billedOverageThisPeriod }) .from(userStats) - .where(eq(userStats.userId, sub.referenceId)) + .where(eq(userStats.userId, ownerId)) .limit(1) - if (userStatsRecords.length > 0) { - billedOverage = parseDecimal(userStatsRecords[0].billedOverageThisPeriod) - } + return ownerStats.length > 0 ? parseDecimal(ownerStats[0].billedOverageThisPeriod) : 0 } - return billedOverage + const userStatsRecords = await db + .select({ billedOverageThisPeriod: userStats.billedOverageThisPeriod }) + .from(userStats) + .where(eq(userStats.userId, sub.referenceId)) + .limit(1) + + return userStatsRecords.length > 0 ? parseDecimal(userStatsRecords[0].billedOverageThisPeriod) : 0 } export async function resetUsageForSubscription(sub: { plan: string | null; referenceId: string }) {