+ You have been added to the team {{ org.name }} on Gooey.AI.
+ Visit the teams page to see your team.
+
+
+
+ Your invite was automatically accepted because your email domain matches the organization's configured email domain.
+ If you think this shouldn't have happened, you can leave this organization from the
+ teams page.
+
+
+
+ Cheers,
+ Gooey.AI team
+
diff --git a/templates/org_invitation_email.html b/templates/org_invitation_email.html
new file mode 100644
index 000000000..c8e12dc87
--- /dev/null
+++ b/templates/org_invitation_email.html
@@ -0,0 +1,25 @@
+
+ Hi!
+
+
+
+ {{ invitation.inviter.display_name or invitation.inviter.first_name() }} has invited
+ you to join their team {{ invitation.org.name }} on Gooey.AI.
+
+
+
+ {% set invitation_url = invitation.get_url() %}
+ Visit this link to view the invitation:
+ {{ invitation_url }}.
+
+
+
+ The link will expire in {{ settings.ORG_INVITATION_EXPIRY_DAYS }} days.
+
+
+
+ Cheers,
+ The Gooey.AI team
+
+
+{{ "{{{ pm:unsubscribe }}}" }}
From 0821d22b821b1fead201ac1feaf70aad9bcfc3f9 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Tue, 23 Jul 2024 17:15:57 +0530
Subject: [PATCH 035/110] Use UniqueConstraint instead of unique_together for
membership
---
...e_domain_name_when_not_deleted_and_more.py | 36 +++++++++++++++++++
orgs/models.py | 8 ++++-
2 files changed, 43 insertions(+), 1 deletion(-)
create mode 100644 orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py
diff --git a/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py b/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py
new file mode 100644
index 000000000..6047919f1
--- /dev/null
+++ b/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py
@@ -0,0 +1,36 @@
+# Generated by Django 4.2.7 on 2024-07-23 11:45
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('app_users', '0019_alter_appusertransaction_reason'),
+ ('orgs', '0002_alter_org_unique_together_and_more'),
+ ]
+
+ operations = [
+ migrations.RemoveConstraint(
+ model_name='org',
+ name='unique_domain_name_when_not_deleted',
+ ),
+ migrations.AlterUniqueTogether(
+ name='orgmembership',
+ unique_together=set(),
+ ),
+ migrations.AlterField(
+ model_name='orginvitation',
+ name='status_changed_by',
+ field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='received_invitations', to='app_users.appuser'),
+ ),
+ migrations.AddConstraint(
+ model_name='org',
+ constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('domain_name',), name='unique_domain_name_when_not_deleted', violation_error_message='This domain name is already in use by another team. Contact Gooey.AI Support if you think this is a mistake.'),
+ ),
+ migrations.AddConstraint(
+ model_name='orgmembership',
+ constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('org', 'user'), name='unique_org_user'),
+ ),
+ ]
diff --git a/orgs/models.py b/orgs/models.py
index df98064ae..33219ead2 100644
--- a/orgs/models.py
+++ b/orgs/models.py
@@ -170,7 +170,13 @@ class OrgMembership(SafeDeleteModel):
objects = SafeDeleteManager()
class Meta:
- unique_together = ("org", "user", "deleted")
+ constraints = [
+ models.UniqueConstraint(
+ fields=["org", "user"],
+ condition=Q(deleted__isnull=True),
+ name="unique_org_user",
+ )
+ ]
def __str__(self):
return f"{self.get_role_display()} - {self.user} ({self.org})"
From bf4b4ffc3ddf203c47cdcb5e9400f15603ffa5d6 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Tue, 23 Jul 2024 17:30:34 +0530
Subject: [PATCH 036/110] rename get_route_url -> get_app_route_url in orgs/
---
orgs/models.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/orgs/models.py b/orgs/models.py
index 33219ead2..5a19dad78 100644
--- a/orgs/models.py
+++ b/orgs/models.py
@@ -12,7 +12,7 @@
from app_users.models import AppUser
from daras_ai_v2 import settings
-from daras_ai_v2.fastapi_tricks import get_route_url
+from daras_ai_v2.fastapi_tricks import get_app_route_url
from daras_ai_v2.crypto import get_random_doc_id
from orgs.tasks import send_auto_accepted_email, send_invitation_email
@@ -272,9 +272,9 @@ def auto_accept(self):
def get_url(self):
from routers.account import invitation_route
- return get_route_url(
+ return get_app_route_url(
invitation_route,
- params={"invite_id": self.invite_id, "org_slug": self.org.get_slug()},
+ path_params={"invite_id": self.invite_id, "org_slug": self.org.get_slug()},
)
def send_email(self):
From 49dff365492975654cf58d7bbdbc76c65bcbb2f4 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Tue, 23 Jul 2024 17:40:21 +0530
Subject: [PATCH 037/110] Add orgs/tasks.py
---
orgs/tasks.py | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 68 insertions(+)
create mode 100644 orgs/tasks.py
diff --git a/orgs/tasks.py b/orgs/tasks.py
new file mode 100644
index 000000000..09258c9ec
--- /dev/null
+++ b/orgs/tasks.py
@@ -0,0 +1,68 @@
+from django.utils import timezone
+from loguru import logger
+
+from celeryapp.celeryconfig import app
+from daras_ai_v2 import settings
+from daras_ai_v2.fastapi_tricks import get_app_route_url
+from daras_ai_v2.send_email import send_email_via_postmark
+from daras_ai_v2.settings import templates
+
+
+@app.task
+def send_invitation_email(invitation_pk: int):
+ from orgs.models import OrgInvitation
+
+ invitation = OrgInvitation.objects.get(pk=invitation_pk)
+
+ assert invitation.status == invitation.Status.PENDING
+
+ logger.info(
+ f"Sending inviation email to {invitation.invitee_email} for org {invitation.org}..."
+ )
+ send_email_via_postmark(
+ to_address=invitation.invitee_email,
+ from_address=settings.SUPPORT_EMAIL,
+ subject=f"[Gooey.AI] Invitation to join {invitation.org.name}",
+ html_body=templates.get_template("org_invitation_email.html").render(
+ settings=settings,
+ invitation=invitation,
+ ),
+ message_stream="outbound",
+ )
+
+ invitation.last_email_sent_at = timezone.now()
+ invitation.save()
+ logger.info("Invitation sent. Saved to DB")
+
+
+@app.task
+def send_auto_accepted_email(invitation_pk: int):
+ from orgs.models import OrgInvitation
+ from routers.account import orgs_route
+
+ invitation = OrgInvitation.objects.get(pk=invitation_pk)
+ assert invitation.auto_accepted and invitation.status == invitation.Status.ACCEPTED
+ assert invitation.status_changed_by
+
+ user = invitation.status_changed_by
+ if not user.email:
+ logger.warning(f"User {user} has no email. Skipping auto-accepted email.")
+ return
+
+ logger.info(
+ f"Sending auto-accepted email to {user.email} for org {invitation.org}..."
+ )
+ send_email_via_postmark(
+ to_address=user.email,
+ from_address=settings.SUPPORT_EMAIL,
+ subject=f"[Gooey.AI] You've been added to a new team!",
+ html_body=templates.get_template(
+ "org_invitation_auto_accepted_email.html"
+ ).render(
+ settings=settings,
+ user=user,
+ org=invitation.org,
+ orgs_url=get_app_route_url(orgs_route),
+ ),
+ message_stream="outbound",
+ )
From 10f28ec94cf6c1c7011bf99bc315d9fcd4d6f6bb Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 7 Aug 2024 20:36:45 +0530
Subject: [PATCH 038/110] gooey gui renaming
---
orgs/views.py | 27 +++++++++++++++------------
1 file changed, 15 insertions(+), 12 deletions(-)
diff --git a/orgs/views.py b/orgs/views.py
index 78a38bfc4..f324bb140 100644
--- a/orgs/views.py
+++ b/orgs/views.py
@@ -2,11 +2,10 @@
import html as html_lib
+import gooey_gui as gui
from django.core.exceptions import ValidationError
-import gooey_ui as st
from app_users.models import AppUser
-from gooey_ui.components.modal import Modal
from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole
from daras_ai_v2 import icons
from daras_ai_v2.fastapi_tricks import get_route_path
@@ -85,7 +84,7 @@ def render_org_by_membership(membership: OrgMembership):
):
with st.div(className="d-flex justify-content-center align-items-center"):
if membership.can_edit_org_metadata():
- org_edit_modal = Modal("Edit Org", key="edit-org-modal")
+ org_edit_modal = gui.Modal("Edit Org", key="edit-org-modal")
if org_edit_modal.is_open():
with org_edit_modal.container():
render_org_edit_view_by_membership(
@@ -113,7 +112,7 @@ def render_org_by_membership(membership: OrgMembership):
st.write("## Members")
if membership.can_invite():
- invite_modal = Modal("Invite Member", key="invite-member-modal")
+ invite_modal = gui.Modal("Invite Member", key="invite-member-modal")
if st.button(f"{icons.add_user} Invite"):
invite_modal.open()
@@ -129,7 +128,7 @@ def render_org_by_membership(membership: OrgMembership):
render_pending_invitations_list(org=org, current_member=membership)
with st.div(className="mt-4"):
- org_leave_modal = Modal("Leave Org", key="leave-org-modal")
+ org_leave_modal = gui.Modal("Leave Org", key="leave-org-modal")
if org_leave_modal.is_open():
with org_leave_modal.container():
render_org_leave_view_by_membership(membership, modal=org_leave_modal)
@@ -159,7 +158,7 @@ def render_org_creation_view(user: AppUser):
st.experimental_rerun()
-def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: Modal):
+def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: gui.Modal):
org = membership.org
render_org_create_or_edit_form(org=org)
@@ -182,7 +181,7 @@ def render_danger_zone_by_membership(membership: OrgMembership):
st.write("### Danger Zone", className="d-block my-2")
if membership.can_delete_org():
- org_deletion_modal = Modal("Delete Organization", key="delete-org-modal")
+ org_deletion_modal = gui.Modal("Delete Organization", key="delete-org-modal")
if org_deletion_modal.is_open():
with org_deletion_modal.container():
render_org_deletion_view_by_membership(
@@ -198,7 +197,9 @@ def render_danger_zone_by_membership(membership: OrgMembership):
org_deletion_modal.open()
-def render_org_deletion_view_by_membership(membership: OrgMembership, *, modal: Modal):
+def render_org_deletion_view_by_membership(
+ membership: OrgMembership, *, modal: gui.Modal
+):
st.write(
f"Are you sure you want to delete **{membership.org.name}**? This action is irreversible."
)
@@ -216,7 +217,9 @@ def render_org_deletion_view_by_membership(membership: OrgMembership, *, modal:
modal.close()
-def render_org_leave_view_by_membership(current_member: OrgMembership, *, modal: Modal):
+def render_org_leave_view_by_membership(
+ current_member: OrgMembership, *, modal: gui.Modal
+):
org = current_member.org
st.write("Are you sure you want to leave this organization?")
@@ -345,12 +348,12 @@ def button_with_confirmation_modal(
modal_key: str | None = None,
modal_className: str = "",
**btn_props,
-) -> tuple[Modal, bool]:
+) -> tuple[gui.Modal, bool]:
"""
Returns boolean for whether user confirmed the action or not.
"""
- modal = Modal(modal_title or btn_label, key=modal_key)
+ modal = gui.Modal(modal_title or btn_label, key=modal_key)
btn_classes = "btn btn-theme btn-sm my-0 py-0 " + btn_props.pop("className", "")
if st.button(btn_label, className=btn_classes, **btn_props):
@@ -447,7 +450,7 @@ def render_invitation_actions(invitation: OrgInvitation, current_member: OrgMemb
modal.close()
-def render_invite_creation_view(org: Org, inviter: AppUser, modal: Modal):
+def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal):
email = st.text_input("Email")
if org.domain_name:
st.caption(
From 0eab8eef13b6bf462f329a29ae94c489ff1ea8b5 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 8 Aug 2024 14:04:16 +0530
Subject: [PATCH 039/110] procfile: use && instead of ; between cd and npm run
---
Procfile | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Procfile b/Procfile
index 8711211c2..984315504 100644
--- a/Procfile
+++ b/Procfile
@@ -19,4 +19,4 @@ dashboard: poetry run streamlit run Home.py --server.port 8501 --server.headless
celery: poetry run celery -A celeryapp worker -P threads -c 16 -l DEBUG
-ui: cd ../gooey-gui/; PORT=3000 npm run dev
+ui: cd ../gooey-gui/ && env PORT=3000 npm run dev
From a349eeaad55b9f6d73d5d5b080ce6c1074288f8e Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Thu, 8 Aug 2024 14:05:41 +0530
Subject: [PATCH 040/110] rename st->gui in account.py
---
routers/account.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/routers/account.py b/routers/account.py
index e79cc75dc..3c3e9c881 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -143,7 +143,7 @@ def api_keys_route(request: Request):
@app.post("/orgs/")
-@st.route
+@gui.route
def orgs_route(request: Request):
with account_page_wrapper(request, AccountTabs.orgs):
orgs_tab(request)
@@ -161,7 +161,7 @@ def orgs_route(request: Request):
@app.post("/invitation/{org_slug}/{invite_id}/")
-@st.route
+@gui.route
def invitation_route(request: Request, org_slug: str, invite_id: str):
from routers.root import login
From 437025503b80a30b48cb1a28e083099f847e5e3c Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 12 Aug 2024 18:49:57 +0530
Subject: [PATCH 041/110] rename st -> gui
---
orgs/views.py | 206 ++++++++++++++++++++++-----------------------
routers/account.py | 8 +-
2 files changed, 106 insertions(+), 108 deletions(-)
diff --git a/orgs/views.py b/orgs/views.py
index f324bb140..ed864cb94 100644
--- a/orgs/views.py
+++ b/orgs/views.py
@@ -19,41 +19,41 @@ def invitation_page(user: AppUser, invitation: OrgInvitation):
orgs_page_path = get_route_path(orgs_route)
- with st.div(className="text-center my-5"):
- st.write(
+ with gui.div(className="text-center my-5"):
+ gui.write(
f"# Invitation to join {invitation.org.name}", className="d-block mb-5"
)
if invitation.org.memberships.filter(user=user).exists():
# redirect to org page
- raise st.RedirectException(orgs_page_path)
+ raise gui.RedirectException(orgs_page_path)
if invitation.status != OrgInvitation.Status.PENDING:
- st.write(f"This invitation has been {invitation.get_status_display()}.")
+ gui.write(f"This invitation has been {invitation.get_status_display()}.")
return
- st.write(
+ gui.write(
f"**{format_user_name(invitation.inviter)}** has invited you to join **{invitation.org.name}**."
)
if other_m := user.org_memberships.first():
- st.caption(
+ gui.caption(
f"You are currently a member of [{other_m.org.name}]({orgs_page_path}). You will be removed from that team if you accept this invitation."
)
accept_label = "Leave and Accept"
else:
accept_label = "Accept"
- with st.div(
+ with gui.div(
className="d-flex justify-content-center align-items-center mx-auto",
style={"max-width": "600px"},
):
- accept_button = st.button(accept_label, type="primary", className="w-50")
- reject_button = st.button("Decline", type="secondary", className="w-50")
+ accept_button = gui.button(accept_label, type="primary", className="w-50")
+ reject_button = gui.button("Decline", type="secondary", className="w-50")
if accept_button:
invitation.accept(user=user)
- raise st.RedirectException(orgs_page_path)
+ raise gui.RedirectException(orgs_page_path)
if reject_button:
invitation.reject(user=user)
@@ -61,7 +61,7 @@ def invitation_page(user: AppUser, invitation: OrgInvitation):
def orgs_page(user: AppUser):
memberships = user.org_memberships.all()
if not memberships:
- st.write("*You're not part of an organization yet... Create one?*")
+ gui.write("*You're not part of an organization yet... Create one?*")
render_org_creation_view(user)
else:
@@ -79,10 +79,10 @@ def render_org_by_membership(membership: OrgMembership):
org = membership.org
current_user = membership.user
- with st.div(
+ with gui.div(
className="d-xs-block d-sm-flex flex-row-reverse justify-content-between"
):
- with st.div(className="d-flex justify-content-center align-items-center"):
+ with gui.div(className="d-flex justify-content-center align-items-center"):
if membership.can_edit_org_metadata():
org_edit_modal = gui.Modal("Edit Org", key="edit-org-modal")
if org_edit_modal.is_open():
@@ -91,29 +91,29 @@ def render_org_by_membership(membership: OrgMembership):
membership, modal=org_edit_modal
)
- if st.button(f"{icons.edit} Edit", type="secondary"):
+ if gui.button(f"{icons.edit} Edit", type="secondary"):
org_edit_modal.open()
- with st.div(className="d-flex align-items-center"):
- st.image(
+ with gui.div(className="d-flex align-items-center"):
+ gui.image(
org.logo or DEFAULT_ORG_LOGO,
className="my-0 me-4 rounded",
style={"width": "128px", "height": "128px", "object-fit": "contain"},
)
- with st.div(className="d-flex flex-column justify-content-center"):
- st.write(f"# {org.name}")
+ with gui.div(className="d-flex flex-column justify-content-center"):
+ gui.write(f"# {org.name}")
if org.domain_name:
- st.write(
+ gui.write(
f"Org Domain: `@{org.domain_name}`", className="text-muted"
)
- with st.div(className="mt-4"):
- with st.div(className="d-flex justify-content-between align-items-center"):
- st.write("## Members")
+ with gui.div(className="mt-4"):
+ with gui.div(className="d-flex justify-content-between align-items-center"):
+ gui.write("## Members")
if membership.can_invite():
invite_modal = gui.Modal("Invite Member", key="invite-member-modal")
- if st.button(f"{icons.add_user} Invite"):
+ if gui.button(f"{icons.add_user} Invite"):
invite_modal.open()
if invite_modal.is_open():
@@ -124,17 +124,17 @@ def render_org_by_membership(membership: OrgMembership):
render_members_list(org=org, current_member=membership)
- with st.div(className="mt-4"):
+ with gui.div(className="mt-4"):
render_pending_invitations_list(org=org, current_member=membership)
- with st.div(className="mt-4"):
+ with gui.div(className="mt-4"):
org_leave_modal = gui.Modal("Leave Org", key="leave-org-modal")
if org_leave_modal.is_open():
with org_leave_modal.container():
render_org_leave_view_by_membership(membership, modal=org_leave_modal)
- with st.div(className="text-end"):
- leave_org = st.button(
+ with gui.div(className="text-end"):
+ leave_org = gui.button(
"Leave",
className="btn btn-theme bg-danger border-danger text-white",
)
@@ -143,42 +143,42 @@ def render_org_by_membership(membership: OrgMembership):
def render_org_creation_view(user: AppUser):
- st.write(f"# {icons.company} Create an Org", unsafe_allow_html=True)
+ gui.write(f"# {icons.company} Create an Org", unsafe_allow_html=True)
org_fields = render_org_create_or_edit_form()
- if st.button("Create"):
+ if gui.button("Create"):
try:
Org.objects.create_org(
created_by=user,
**org_fields,
)
except ValidationError as e:
- st.write(", ".join(e.messages), className="text-danger")
+ gui.write(", ".join(e.messages), className="text-danger")
else:
- st.experimental_rerun()
+ gui.experimental_rerun()
def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: gui.Modal):
org = membership.org
render_org_create_or_edit_form(org=org)
- if st.button("Save", className="w-100", type="primary"):
+ if gui.button("Save", className="w-100", type="primary"):
try:
org.full_clean()
except ValidationError as e:
# newlines in markdown
- st.write(" \n".join(e.messages), className="text-danger")
+ gui.write(" \n".join(e.messages), className="text-danger")
else:
org.save()
modal.close()
if membership.can_delete_org() or membership.can_transfer_ownership():
- st.write("---")
+ gui.write("---")
render_danger_zone_by_membership(membership)
def render_danger_zone_by_membership(membership: OrgMembership):
- st.write("### Danger Zone", className="d-block my-2")
+ gui.write("### Danger Zone", className="d-block my-2")
if membership.can_delete_org():
org_deletion_modal = gui.Modal("Delete Organization", key="delete-org-modal")
@@ -188,9 +188,9 @@ def render_danger_zone_by_membership(membership: OrgMembership):
membership, modal=org_deletion_modal
)
- with st.div(className="d-flex justify-content-between align-items-center"):
- st.write("Delete Organization")
- if st.button(
+ with gui.div(className="d-flex justify-content-between align-items-center"):
+ gui.write("Delete Organization")
+ if gui.button(
f"{icons.delete} Delete",
className="btn btn-theme py-2 bg-danger border-danger text-white",
):
@@ -200,17 +200,17 @@ def render_danger_zone_by_membership(membership: OrgMembership):
def render_org_deletion_view_by_membership(
membership: OrgMembership, *, modal: gui.Modal
):
- st.write(
+ gui.write(
f"Are you sure you want to delete **{membership.org.name}**? This action is irreversible."
)
- with st.div(className="d-flex"):
- if st.button(
+ with gui.div(className="d-flex"):
+ if gui.button(
"Cancel", type="secondary", className="border-danger text-danger w-50"
):
modal.close()
- if st.button(
+ if gui.button(
"Delete", className="btn btn-theme bg-danger border-danger text-light w-50"
):
membership.org.delete()
@@ -222,11 +222,11 @@ def render_org_leave_view_by_membership(
):
org = current_member.org
- st.write("Are you sure you want to leave this organization?")
+ gui.write("Are you sure you want to leave this organization?")
new_owner = None
if current_member.role == OrgRole.OWNER and org.memberships.count() == 1:
- st.caption(
+ gui.caption(
"You are the only member. You will lose access to this team if you leave."
)
elif (
@@ -239,23 +239,23 @@ def render_org_leave_view_by_membership(
if m != current_member
}
- st.caption(
+ gui.caption(
"You are the only owner of this organization. Please choose another member to promote to owner."
)
- new_owner_uid = st.selectbox(
+ new_owner_uid = gui.selectbox(
"New Owner",
options=list(members_by_uid),
format_func=lambda uid: format_user_name(members_by_uid[uid].user),
)
new_owner = members_by_uid[new_owner_uid]
- with st.div(className="d-flex"):
- if st.button(
+ with gui.div(className="d-flex"):
+ if gui.button(
"Cancel", type="secondary", className="border-danger text-danger w-50"
):
modal.close()
- if st.button(
+ if gui.button(
"Leave", className="btn btn-theme bg-danger border-danger text-light w-50"
):
if new_owner:
@@ -266,34 +266,34 @@ def render_org_leave_view_by_membership(
def render_members_list(org: Org, current_member: OrgMembership):
- with st.tag("table", className="table table-responsive"):
- with st.tag("thead"), st.tag("tr"):
- with st.tag("th", scope="col"):
- st.html("Name")
- with st.tag("th", scope="col"):
- st.html("Role")
- with st.tag("th", scope="col"):
- st.html(f"{icons.time} Since")
- with st.tag("th", scope="col"):
- st.html("")
-
- with st.tag("tbody"):
+ with gui.tag("table", className="table table-responsive"):
+ with gui.tag("thead"), gui.tag("tr"):
+ with gui.tag("th", scope="col"):
+ gui.html("Name")
+ with gui.tag("th", scope="col"):
+ gui.html("Role")
+ with gui.tag("th", scope="col"):
+ gui.html(f"{icons.time} Since")
+ with gui.tag("th", scope="col"):
+ gui.html("")
+
+ with gui.tag("tbody"):
for m in org.memberships.all().order_by("created_at"):
- with st.tag("tr"):
- with st.tag("td"):
+ with gui.tag("tr"):
+ with gui.tag("td"):
name = format_user_name(
m.user, current_user=current_member.user
)
if m.user.handle_id:
- with st.link(to=m.user.handle.get_app_url()):
- st.html(html_lib.escape(name))
+ with gui.link(to=m.user.handle.get_app_url()):
+ gui.html(html_lib.escape(name))
else:
- st.html(html_lib.escape(name))
- with st.tag("td"):
- st.html(m.get_role_display())
- with st.tag("td"):
- st.html(m.created_at.strftime("%b %d, %Y"))
- with st.tag("td", className="text-end"):
+ gui.html(html_lib.escape(name))
+ with gui.tag("td"):
+ gui.html(m.get_role_display())
+ with gui.tag("td"):
+ gui.html(m.created_at.strftime("%b %d, %Y"))
+ with gui.tag("td", className="text-end"):
render_membership_actions(m, current_member=current_member)
@@ -356,21 +356,21 @@ def button_with_confirmation_modal(
modal = gui.Modal(modal_title or btn_label, key=modal_key)
btn_classes = "btn btn-theme btn-sm my-0 py-0 " + btn_props.pop("className", "")
- if st.button(btn_label, className=btn_classes, **btn_props):
+ if gui.button(btn_label, className=btn_classes, **btn_props):
modal.open()
if modal.is_open():
with modal.container(className=modal_className):
- st.write(confirmation_text)
- with st.div(className="d-flex"):
- if st.button(
+ gui.write(confirmation_text)
+ with gui.div(className="d-flex"):
+ if gui.button(
"Cancel",
type="secondary",
className="border-danger text-danger w-50",
):
modal.close()
- confirmed = st.button(
+ confirmed = gui.button(
"Confirm",
className="btn btn-theme bg-danger border-danger text-light w-50",
)
@@ -384,35 +384,35 @@ def render_pending_invitations_list(org: Org, *, current_member: OrgMembership):
if not pending_invitations:
return
- st.write("## Pending")
- with st.tag("table", className="table table-responsive"):
- with st.tag("thead"), st.tag("tr"):
- with st.tag("th", scope="col"):
- st.html("Email")
- with st.tag("th", scope="col"):
- st.html("Invited By")
- with st.tag("th", scope="col"):
- st.html(f"{icons.time} Last invited on")
- with st.tag("th", scope="col"):
+ gui.write("## Pending")
+ with gui.tag("table", className="table table-responsive"):
+ with gui.tag("thead"), gui.tag("tr"):
+ with gui.tag("th", scope="col"):
+ gui.html("Email")
+ with gui.tag("th", scope="col"):
+ gui.html("Invited By")
+ with gui.tag("th", scope="col"):
+ gui.html(f"{icons.time} Last invited on")
+ with gui.tag("th", scope="col"):
pass
- with st.tag("tbody"):
+ with gui.tag("tbody"):
for invite in pending_invitations:
- with st.tag("tr", className="text-break"):
- with st.tag("td"):
- st.html(html_lib.escape(invite.invitee_email))
- with st.tag("td"):
- st.html(
+ with gui.tag("tr", className="text-break"):
+ with gui.tag("td"):
+ gui.html(html_lib.escape(invite.invitee_email))
+ with gui.tag("td"):
+ gui.html(
html_lib.escape(
format_user_name(
invite.inviter, current_user=current_member.user
)
)
)
- with st.tag("td"):
+ with gui.tag("td"):
last_invited_at = invite.last_email_sent_at or invite.created_at
- st.html(last_invited_at.strftime("%b %d, %Y"))
- with st.tag("td", className="text-end"):
+ gui.html(last_invited_at.strftime("%b %d, %Y"))
+ with gui.tag("td", className="text-end"):
render_invitation_actions(invite, current_member=current_member)
@@ -451,13 +451,13 @@ def render_invitation_actions(invitation: OrgInvitation, current_member: OrgMemb
def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal):
- email = st.text_input("Email")
+ email = gui.text_input("Email")
if org.domain_name:
- st.caption(
+ gui.caption(
f"Users with `@{org.domain_name}` email will be added automatically."
)
- if st.button(f"{icons.add_user} Invite", type="primary", unsafe_allow_html=True):
+ if gui.button(f"{icons.add_user} Invite", type="primary", unsafe_allow_html=True):
try:
org.invite_user(
invitee_email=email,
@@ -466,7 +466,7 @@ def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal):
auto_accept=org.domain_name.lower() == email.split("@")[1].lower(),
)
except ValidationError as e:
- st.write(", ".join(e.messages), className="text-danger")
+ gui.write(", ".join(e.messages), className="text-danger")
else:
modal.close()
@@ -474,17 +474,17 @@ def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal):
def render_org_create_or_edit_form(org: Org | None = None) -> AttrDict | Org:
org_proxy = org or AttrDict()
- org_proxy.name = st.text_input("Team Name", value=org and org.name or "")
- org_proxy.logo = st.file_uploader(
+ org_proxy.name = gui.text_input("Team Name", value=org and org.name or "")
+ org_proxy.logo = gui.file_uploader(
"Logo", accept=["image/*"], value=org and org.logo or ""
)
- org_proxy.domain_name = st.text_input(
+ org_proxy.domain_name = gui.text_input(
"Domain Name (Optional)",
placeholder="e.g. gooey.ai",
value=org and org.domain_name or "",
)
if org_proxy.domain_name:
- st.caption(
+ gui.caption(
f"Invite any user with `@{org_proxy.domain_name}` email to this organization."
)
diff --git a/routers/account.py b/routers/account.py
index 3c3e9c881..a898501db 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -3,9 +3,9 @@
from enum import Enum
import gooey_gui as gui
-from fastapi import APIRouter
from fastapi.requests import Request
from furl import furl
+from gooey_gui.core import RedirectException
from loguru import logger
from requests.models import HTTPError
from starlette.responses import Response
@@ -142,8 +142,7 @@ def api_keys_route(request: Request):
)
-@app.post("/orgs/")
-@gui.route
+@gui.route(app, "/orgs/")
def orgs_route(request: Request):
with account_page_wrapper(request, AccountTabs.orgs):
orgs_tab(request)
@@ -160,8 +159,7 @@ def orgs_route(request: Request):
)
-@app.post("/invitation/{org_slug}/{invite_id}/")
-@gui.route
+@gui.route(app, "/invitation/{org_slug}/{invite_id}/")
def invitation_route(request: Request, org_slug: str, invite_id: str):
from routers.root import login
From 3c06688ff79f925e09f3214e7fb4eab222027d64 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Fri, 30 Aug 2024 13:30:03 +0530
Subject: [PATCH 042/110] make org page only accessible to admins
---
routers/account.py | 23 ++++++++++++++++++++++-
1 file changed, 22 insertions(+), 1 deletion(-)
diff --git a/routers/account.py b/routers/account.py
index a898501db..b52239b2b 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -256,9 +256,30 @@ def api_keys_tab(request: Request):
def orgs_tab(request: Request):
+ """only accessible to admins"""
+ from daras_ai_v2.base import BasePage
+
+ if not BasePage.is_user_admin(request.user):
+ raise RedirectException(get_route_path(account_route))
+
orgs_page(request.user)
+def get_tabs(request: Request) -> list[AccountTabs]:
+ from daras_ai_v2.base import BasePage
+
+ tab_list = [
+ AccountTabs.billing,
+ AccountTabs.profile,
+ AccountTabs.saved,
+ AccountTabs.api_keys,
+ ]
+ if BasePage.is_user_admin(request.user):
+ tab_list.append(AccountTabs.orgs)
+
+ return tab_list
+
+
@contextmanager
def account_page_wrapper(request: Request, current_tab: TabData):
if not request.user or request.user.is_anonymous:
@@ -269,7 +290,7 @@ def account_page_wrapper(request: Request, current_tab: TabData):
with page_wrapper(request):
gui.div(className="mt-5")
with gui.nav_tabs():
- for tab in AccountTabs:
+ for tab in get_tabs(request):
with gui.nav_item(tab.url_path, active=tab == current_tab):
gui.html(tab.title)
From e5ddad713cdbde3fc3a83348e6b41cf0de0e988d Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Fri, 30 Aug 2024 13:35:26 +0530
Subject: [PATCH 043/110] Add org support with role and UI view
---
daras_ai_v2/icons.py | 1 -
gooey_ui/components/__init__.py | 1009 +++++++++++++++++++++++++++++++
2 files changed, 1009 insertions(+), 1 deletion(-)
create mode 100644 gooey_ui/components/__init__.py
diff --git a/daras_ai_v2/icons.py b/daras_ai_v2/icons.py
index 30dcc1e01..90334bbc3 100644
--- a/daras_ai_v2/icons.py
+++ b/daras_ai_v2/icons.py
@@ -24,7 +24,6 @@
admin = ''
remove_user = ''
add_user = ''
-transfer = ''
# brands
github = ''
diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py
new file mode 100644
index 000000000..2c27edd1d
--- /dev/null
+++ b/gooey_ui/components/__init__.py
@@ -0,0 +1,1009 @@
+import base64
+import html as html_lib
+import math
+import textwrap
+import typing
+from datetime import datetime, timezone
+
+import numpy as np
+from furl import furl
+
+from daras_ai.image_input import resize_img_scale
+from gooey_ui import state
+from gooey_ui.pubsub import md5_values
+
+T = typing.TypeVar("T")
+LabelVisibility = typing.Literal["visible", "collapsed"]
+
+BLANK_OPTION = "———"
+
+
+def _default_format(value: typing.Any) -> str:
+ if value is None:
+ return BLANK_OPTION
+ return str(value)
+
+
+def dummy(*args, **kwargs):
+ return state.NestingCtx()
+
+
+spinner = dummy
+set_page_config = dummy
+form = dummy
+dataframe = dummy
+
+
+def countdown_timer(
+ end_time: datetime,
+ delay_text: str,
+) -> state.NestingCtx:
+ return _node(
+ "countdown-timer",
+ endTime=end_time.astimezone(timezone.utc).isoformat(),
+ delayText=delay_text,
+ )
+
+
+def nav_tabs():
+ return _node("nav-tabs")
+
+
+def nav_item(href: str, *, active: bool):
+ return _node("nav-item", to=href, active="true" if active else None)
+
+
+def nav_tab_content():
+ return _node("nav-tab-content")
+
+
+def div(**props) -> state.NestingCtx:
+ return tag("div", **props)
+
+
+def link(*, to: str, **props) -> state.NestingCtx:
+ return _node("Link", to=to, **props)
+
+
+def tag(tag_name: str, **props) -> state.NestingCtx:
+ props["__reactjsxelement"] = tag_name
+ return _node("tag", **props)
+
+
+def html(body: str, **props):
+ props["className"] = props.get("className", "") + " gui-html-container"
+ return _node("html", body=body, **props)
+
+
+def write(*objs: typing.Any, line_clamp: int = None, unsafe_allow_html=False, **props):
+ for obj in objs:
+ markdown(
+ obj if isinstance(obj, str) else repr(obj),
+ line_clamp=line_clamp,
+ unsafe_allow_html=unsafe_allow_html,
+ **props,
+ )
+
+
+def center(direction="flex-column", className="") -> state.NestingCtx:
+ return div(
+ className=f"d-flex justify-content-center align-items-center text-center {direction} {className}"
+ )
+
+
+def newline():
+ html(" ")
+
+
+def markdown(
+ body: str | None, *, line_clamp: int = None, unsafe_allow_html=False, **props
+):
+ if body is None:
+ return _node("markdown", body="", **props)
+ if not unsafe_allow_html:
+ body = html_lib.escape(body)
+ props["className"] = (
+ props.get("className", "") + " gui-html-container gui-md-container"
+ )
+ return _node("markdown", body=dedent(body).strip(), lineClamp=line_clamp, **props)
+
+
+def _node(name: str, **props):
+ node = state.RenderTreeNode(name=name, props=props)
+ node.mount()
+ return state.NestingCtx(node)
+
+
+def text(body: str, **props):
+ state.RenderTreeNode(
+ name="pre",
+ props=dict(body=dedent(body), **props),
+ ).mount()
+
+
+def error(
+ body: str,
+ icon: str = "🔥",
+ *,
+ unsafe_allow_html=False,
+ color="rgba(255, 108, 108, 0.2)",
+ **props,
+):
+ if not isinstance(body, str):
+ body = repr(body)
+ with div(
+ style=dict(
+ backgroundColor=color,
+ padding="1rem",
+ paddingBottom="0",
+ marginBottom="0.5rem",
+ borderRadius="0.25rem",
+ display="flex",
+ gap="0.5rem",
+ )
+ ):
+ markdown(icon)
+ with div():
+ markdown(dedent(body), unsafe_allow_html=unsafe_allow_html, **props)
+
+
+def success(body: str, icon: str = "✅", *, unsafe_allow_html=False):
+ if not isinstance(body, str):
+ body = repr(body)
+ with div(
+ style=dict(
+ backgroundColor="rgba(108, 255, 108, 0.2)",
+ padding="1rem",
+ paddingBottom="0",
+ marginBottom="0.5rem",
+ borderRadius="0.25rem",
+ display="flex",
+ gap="0.5rem",
+ )
+ ):
+ markdown(icon)
+ markdown(dedent(body), unsafe_allow_html=unsafe_allow_html)
+
+
+def caption(body: str, className: str = None, **props):
+ className = className or "text-muted"
+ markdown(body, className=className, **props)
+
+
+def tabs(labels: list[str]) -> list[state.NestingCtx]:
+ parent = state.RenderTreeNode(
+ name="tabs",
+ children=[
+ state.RenderTreeNode(
+ name="tab",
+ props=dict(label=dedent(label)),
+ )
+ for label in labels
+ ],
+ ).mount()
+ return [state.NestingCtx(tab) for tab in parent.children]
+
+
+def controllable_tabs(
+ labels: list[str], key: str
+) -> tuple[list[state.NestingCtx], int]:
+ index = state.session_state.get(key, 0)
+ for i, label in enumerate(labels):
+ if button(
+ label,
+ key=f"tab-{i}",
+ type="primary",
+ className="replicate-nav",
+ style={
+ "background": "black" if i == index else "white",
+ "color": "white" if i == index else "black",
+ },
+ ):
+ state.session_state[key] = index = i
+ state.experimental_rerun()
+ ctxs = []
+ for i, label in enumerate(labels):
+ if i == index:
+ ctxs += [div(className="tab-content")]
+ else:
+ ctxs += [div(className="tab-content", style={"display": "none"})]
+ return ctxs, index
+
+
+def columns(
+ spec,
+ *,
+ gap: str = None,
+ responsive: bool = True,
+ column_props: dict = {},
+ **props,
+) -> tuple[state.NestingCtx, ...]:
+ if isinstance(spec, int):
+ spec = [1] * spec
+ total_weight = sum(spec)
+ props.setdefault("className", "row")
+ with div(**props):
+ return tuple(
+ div(
+ className=f"col-lg-{p} {'col-12' if responsive else f'col-{p}'}",
+ **column_props,
+ )
+ for w in spec
+ if (p := f"{round(w / total_weight * 12)}")
+ )
+
+
+def image(
+ src: str | np.ndarray,
+ caption: str = None,
+ alt: str = None,
+ href: str = None,
+ show_download_button: bool = False,
+ **props,
+):
+ if isinstance(src, np.ndarray):
+ from daras_ai.image_input import cv2_img_to_bytes
+
+ if not src.shape:
+ return
+ # ensure image is not too large
+ data = resize_img_scale(cv2_img_to_bytes(src), (128, 128))
+ # convert to base64
+ b64 = base64.b64encode(data).decode("utf-8")
+ src = "data:image/png;base64," + b64
+ if not src:
+ return
+ state.RenderTreeNode(
+ name="img",
+ props=dict(
+ src=src,
+ caption=dedent(caption),
+ alt=alt or caption,
+ href=href,
+ **props,
+ ),
+ ).mount()
+ if show_download_button:
+ download_button(
+ label=' Download', url=src
+ )
+
+
+def video(
+ src: str,
+ caption: str = None,
+ autoplay: bool = False,
+ show_download_button: bool = False,
+):
+ autoplay_props = {}
+ if autoplay:
+ autoplay_props = {
+ "preload": "auto",
+ "controls": True,
+ "autoPlay": True,
+ "loop": True,
+ "muted": True,
+ "playsInline": True,
+ }
+
+ if not src:
+ return
+ if isinstance(src, str):
+ # https://muffinman.io/blog/hack-for-ios-safari-to-display-html-video-thumbnail/
+ f = furl(src)
+ f.fragment.args["t"] = "0.001"
+ src = f.url
+ state.RenderTreeNode(
+ name="video",
+ props=dict(src=src, caption=dedent(caption), **autoplay_props),
+ ).mount()
+ if show_download_button:
+ download_button(
+ label=' Download', url=src
+ )
+
+
+def audio(src: str, caption: str = None, show_download_button: bool = False):
+ if not src:
+ return
+ state.RenderTreeNode(
+ name="audio",
+ props=dict(src=src, caption=dedent(caption)),
+ ).mount()
+ if show_download_button:
+ download_button(
+ label=' Download', url=src
+ )
+
+
+def text_area(
+ label: str,
+ value: str = "",
+ height: int = 500,
+ key: str = None,
+ help: str = None,
+ placeholder: str = None,
+ disabled: bool = False,
+ label_visibility: LabelVisibility = "visible",
+ **props,
+) -> str:
+ style = props.setdefault("style", {})
+ # if key:
+ # assert not value, "only one of value or key can be provided"
+ # else:
+ if not key:
+ key = md5_values(
+ "textarea",
+ label,
+ height,
+ help,
+ placeholder,
+ label_visibility,
+ not disabled or value,
+ )
+ value = str(state.session_state.setdefault(key, value) or "")
+ if label_visibility != "visible":
+ label = None
+ if disabled:
+ max_height = f"{height}px"
+ rows = nrows_for_text(value, height)
+ else:
+ max_height = "50vh"
+ rows = nrows_for_text(value, height)
+ style.setdefault("maxHeight", max_height)
+ props.setdefault("rows", rows)
+ state.RenderTreeNode(
+ name="textarea",
+ props=dict(
+ name=key,
+ label=dedent(label),
+ defaultValue=value,
+ help=help,
+ placeholder=placeholder,
+ disabled=disabled,
+ **props,
+ ),
+ ).mount()
+ return value or ""
+
+
+def nrows_for_text(
+ text: str,
+ max_height_px: int,
+ min_rows: int = 1,
+ row_height_px: int = 30,
+ row_width_px: int = 70,
+) -> int:
+ max_rows = max_height_px // row_height_px
+ nrows = math.ceil(
+ sum(
+ math.ceil(len(line) / row_width_px)
+ for line in (text or "").splitlines(keepends=True)
+ )
+ )
+ nrows = min(max(nrows, min_rows), max_rows)
+ return nrows
+
+
+def multiselect(
+ label: str,
+ options: typing.Sequence[T],
+ format_func: typing.Callable[[T], typing.Any] = _default_format,
+ key: str = None,
+ help: str = None,
+ allow_none: bool = False,
+ *,
+ disabled: bool = False,
+) -> list[T]:
+ if not options:
+ return []
+ options = list(options)
+ if not key:
+ key = md5_values("multiselect", label, options, help)
+ value = state.session_state.get(key) or []
+ if not isinstance(value, list):
+ value = [value]
+ value = [o for o in value if o in options]
+ if not allow_none and not value:
+ value = [options[0]]
+ state.session_state[key] = value
+ state.RenderTreeNode(
+ name="select",
+ props=dict(
+ name=key,
+ label=dedent(label),
+ help=help,
+ isDisabled=disabled,
+ isMulti=True,
+ defaultValue=value,
+ allow_none=allow_none,
+ options=[
+ {"value": option, "label": str(format_func(option))}
+ for option in options
+ ],
+ ),
+ ).mount()
+ return value
+
+
+def selectbox(
+ label: str,
+ options: typing.Iterable[T],
+ format_func: typing.Callable[[T], typing.Any] = _default_format,
+ key: str = None,
+ help: str = None,
+ *,
+ disabled: bool = False,
+ label_visibility: LabelVisibility = "visible",
+ value: T = None,
+ allow_none: bool = False,
+ **props,
+) -> T | None:
+ if not options:
+ return None
+ if label_visibility != "visible":
+ label = None
+ options = list(options)
+ if allow_none:
+ options.insert(0, None)
+ if not key:
+ key = md5_values("select", label, options, help, label_visibility)
+ value = state.session_state.setdefault(key, value)
+ if value not in options:
+ value = state.session_state[key] = options[0]
+ state.RenderTreeNode(
+ name="select",
+ props=dict(
+ name=key,
+ label=dedent(label),
+ help=help,
+ isDisabled=disabled,
+ defaultValue=value,
+ options=[
+ {"value": option, "label": str(format_func(option))}
+ for option in options
+ ],
+ **props,
+ ),
+ ).mount()
+ return value
+
+
+def download_button(
+ label: str,
+ url: str,
+ key: str = None,
+ help: str = None,
+ *,
+ type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary",
+ disabled: bool = False,
+ **props,
+) -> bool:
+ url = furl(url).remove(fragment=True).url
+ return button(
+ component="download-button",
+ url=url,
+ label=label,
+ key=key,
+ help=help,
+ type=type,
+ disabled=disabled,
+ **props,
+ )
+
+
+def button(
+ label: str,
+ key: str = None,
+ help: str = None,
+ *,
+ type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary",
+ disabled: bool = False,
+ component: typing.Literal["download-button", "gui-button"] = "gui-button",
+ **props,
+) -> bool:
+ """
+ Example:
+ st.button("Primary", key="test0", type="primary")
+ st.button("Secondary", key="test1")
+ st.button("Tertiary", key="test3", type="tertiary")
+ st.button("Link Button", key="test3", type="link")
+ """
+ if not key:
+ key = md5_values("button", label, help, type, props)
+ className = f"btn-{type} " + props.pop("className", "")
+ state.RenderTreeNode(
+ name=component,
+ props=dict(
+ type="submit",
+ value="yes",
+ name=key,
+ label=dedent(label),
+ help=help,
+ disabled=disabled,
+ className=className,
+ **props,
+ ),
+ ).mount()
+ return bool(state.session_state.pop(key, False))
+
+
+def anchor(
+ label: str,
+ href: str,
+ *,
+ type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary",
+ disabled: bool = False,
+ unsafe_allow_html: bool = False,
+ new_tab: bool = False,
+ **props,
+):
+ className = f"btn btn-theme btn-{type} " + props.pop("className", "")
+ style = props.pop("style", {})
+ if disabled:
+ style["pointerEvents"] = "none"
+ if new_tab:
+ props["target"] = "_blank"
+ with tag("a", href=href, className=className, style=style, **props):
+ markdown(dedent(label), unsafe_allow_html=unsafe_allow_html)
+
+
+form_submit_button = button
+
+
+def expander(label: str, *, expanded: bool = False, key: str = None, **props):
+ node = state.RenderTreeNode(
+ name="expander",
+ props=dict(
+ label=dedent(label),
+ open=expanded,
+ name=key or md5_values(label, expanded, props),
+ **props,
+ ),
+ )
+ node.mount()
+ return state.NestingCtx(node)
+
+
+def file_uploader(
+ label: str,
+ accept: list[str] = None,
+ accept_multiple_files=False,
+ key: str = None,
+ value: str | list[str] = None,
+ upload_key: str = None,
+ help: str = None,
+ *,
+ disabled: bool = False,
+ label_visibility: LabelVisibility = "visible",
+ upload_meta: dict = None,
+ optional: bool = False,
+) -> str | list[str] | None:
+ if label_visibility != "visible":
+ label = None
+ key = upload_key or key
+ if not key:
+ key = md5_values(
+ "file_uploader",
+ label,
+ accept,
+ accept_multiple_files,
+ help,
+ label_visibility,
+ )
+ if optional:
+ if not checkbox(
+ label, value=bool(state.session_state.get(key, value)), disabled=disabled
+ ):
+ state.session_state.pop(key, None)
+ return None
+ label = None
+ value = state.session_state.setdefault(key, value)
+ if not value:
+ if accept_multiple_files:
+ value = []
+ else:
+ value = None
+ state.session_state[key] = value
+ state.RenderTreeNode(
+ name="input",
+ props=dict(
+ type="file",
+ name=key,
+ label=dedent(label),
+ help=help,
+ disabled=disabled,
+ accept=accept,
+ multiple=accept_multiple_files,
+ defaultValue=value,
+ uploadMeta=upload_meta,
+ ),
+ ).mount()
+ return value
+
+
+def json(value: typing.Any, expanded: bool = False, depth: int = 1):
+ state.RenderTreeNode(
+ name="json",
+ props=dict(
+ value=value,
+ expanded=expanded,
+ defaultInspectDepth=3 if expanded else depth,
+ ),
+ ).mount()
+
+
+def data_table(file_url_or_cells: str | list):
+ if isinstance(file_url_or_cells, str):
+ file_url = file_url_or_cells
+ return _node("data-table", fileUrl=file_url)
+ else:
+ cells = file_url_or_cells
+ return _node("data-table-raw", cells=cells)
+
+
+def table(df: "pd.DataFrame"):
+ with tag("table", className="table table-striped table-sm"):
+ with tag("thead"):
+ with tag("tr"):
+ for col in df.columns:
+ with tag("th", scope="col"):
+ html(dedent(col))
+ with tag("tbody"):
+ for row in df.itertuples(index=False):
+ with tag("tr"):
+ for value in row:
+ with tag("td"):
+ html(dedent(str(value)))
+
+
+def raw_table(header: list[str], className: str = "", **props) -> state.NestingCtx:
+ className = "table " + className
+ with tag("table", className=className, **props):
+ if header:
+ with tag("thead"), tag("tr"):
+ for col in header:
+ with tag("th", scope="col"):
+ html(dedent(col))
+
+ return tag("tbody")
+
+
+def table_row(values: list[str], **props):
+ row = tag("tr", **props)
+ with row:
+ for v in values:
+ with tag("td"):
+ html(html_lib.escape(v))
+ return row
+
+
+def horizontal_radio(
+ label: str,
+ options: typing.Sequence[T],
+ format_func: typing.Callable[[T], typing.Any] = _default_format,
+ *,
+ key: str = None,
+ help: str = None,
+ value: T = None,
+ disabled: bool = False,
+ checked_by_default: bool = True,
+ label_visibility: LabelVisibility = "visible",
+ **button_props,
+) -> T | None:
+ if not options:
+ return None
+ options = list(options)
+ if not key:
+ key = md5_values("horizontal_radio", label, options, help, label_visibility)
+ value = state.session_state.setdefault(key, value)
+ if value not in options and checked_by_default:
+ value = state.session_state[key] = options[0]
+ if label_visibility != "visible":
+ label = None
+ markdown(label)
+ for option in options:
+ if button(
+ format_func(option),
+ key=f"tab-{key}-{option}",
+ type="primary",
+ className="replicate-nav " + ("active" if value == option else ""),
+ disabled=disabled,
+ **button_props,
+ ):
+ state.session_state[key] = value = option
+ state.experimental_rerun()
+ return value
+
+
+def radio(
+ label: str,
+ options: typing.Sequence[T],
+ format_func: typing.Callable[[T], typing.Any] = _default_format,
+ key: str = None,
+ value: T = None,
+ help: str = None,
+ *,
+ disabled: bool = False,
+ checked_by_default: bool = True,
+ label_visibility: LabelVisibility = "visible",
+) -> T | None:
+ if not options:
+ return None
+ options = list(options)
+ if not key:
+ key = md5_values("radio", label, options, help, label_visibility)
+ value = state.session_state.setdefault(key, value)
+ if value not in options and checked_by_default:
+ value = state.session_state[key] = options[0]
+ if label_visibility != "visible":
+ label = None
+ markdown(label)
+ for option in options:
+ state.RenderTreeNode(
+ name="input",
+ props=dict(
+ type="radio",
+ name=key,
+ label=dedent(str(format_func(option))),
+ value=option,
+ defaultChecked=bool(value == option),
+ help=help,
+ disabled=disabled,
+ ),
+ ).mount()
+ return value
+
+
+def text_input(
+ label: str,
+ value: str = "",
+ max_chars: str = None,
+ key: str = None,
+ help: str = None,
+ *,
+ placeholder: str = None,
+ disabled: bool = False,
+ label_visibility: LabelVisibility = "visible",
+ **props,
+) -> str:
+ value = _input_widget(
+ input_type="text",
+ label=label,
+ value=value,
+ key=key,
+ help=help,
+ disabled=disabled,
+ label_visibility=label_visibility,
+ maxLength=max_chars,
+ placeholder=placeholder,
+ **props,
+ )
+ return value or ""
+
+
+def date_input(
+ label: str,
+ value: str | None = None,
+ key: str = None,
+ help: str = None,
+ *,
+ disabled: bool = False,
+ label_visibility: LabelVisibility = "visible",
+ **props,
+) -> datetime | None:
+ value = _input_widget(
+ input_type="date",
+ label=label,
+ value=value,
+ key=key,
+ help=help,
+ disabled=disabled,
+ label_visibility=label_visibility,
+ style=dict(
+ border="1px solid hsl(0, 0%, 80%)",
+ padding="0.375rem 0.75rem",
+ borderRadius="0.25rem",
+ margin="0 0.5rem 0 0.5rem",
+ ),
+ **props,
+ )
+ try:
+ return datetime.strptime(value, "%Y-%m-%d") if value else None
+ except ValueError:
+ return None
+
+
+def password_input(
+ label: str,
+ value: str = "",
+ max_chars: str = None,
+ key: str = None,
+ help: str = None,
+ *,
+ placeholder: str = None,
+ disabled: bool = False,
+ label_visibility: LabelVisibility = "visible",
+ **props,
+) -> str:
+ value = _input_widget(
+ input_type="password",
+ label=label,
+ value=value,
+ key=key,
+ help=help,
+ disabled=disabled,
+ label_visibility=label_visibility,
+ maxLength=max_chars,
+ placeholder=placeholder,
+ **props,
+ )
+ return value or ""
+
+
+def slider(
+ label: str,
+ min_value: float = None,
+ max_value: float = None,
+ value: float = None,
+ step: float = None,
+ key: str = None,
+ help: str = None,
+ *,
+ disabled: bool = False,
+) -> float:
+ value = _input_widget(
+ input_type="range",
+ label=label,
+ value=value,
+ key=key,
+ help=help,
+ disabled=disabled,
+ min=min_value,
+ max=max_value,
+ step=_step_value(min_value, max_value, step),
+ )
+ return value or 0
+
+
+def number_input(
+ label: str,
+ min_value: float = None,
+ max_value: float = None,
+ value: float = None,
+ step: float = None,
+ key: str = None,
+ help: str = None,
+ *,
+ disabled: bool = False,
+) -> float:
+ value = _input_widget(
+ input_type="number",
+ inputMode="decimal",
+ label=label,
+ value=value,
+ key=key,
+ help=help,
+ disabled=disabled,
+ min=min_value,
+ max=max_value,
+ step=_step_value(min_value, max_value, step),
+ )
+ return value or 0
+
+
+def _step_value(
+ min_value: float | None, max_value: float | None, step: float | None
+) -> float:
+ if step:
+ return step
+ elif isinstance(min_value, float) or isinstance(max_value, float):
+ return 0.1
+ else:
+ return 1
+
+
+def checkbox(
+ label: str,
+ value: bool = False,
+ key: str = None,
+ help: str = None,
+ *,
+ disabled: bool = False,
+ label_visibility: LabelVisibility = "visible",
+ **props,
+) -> bool:
+ value = _input_widget(
+ input_type="checkbox",
+ label=label,
+ value=value,
+ key=key,
+ help=help,
+ disabled=disabled,
+ label_visibility=label_visibility,
+ default_value_attr="defaultChecked",
+ **props,
+ )
+ return bool(value)
+
+
+def _input_widget(
+ *,
+ input_type: str,
+ label: str,
+ value: typing.Any = None,
+ key: str = None,
+ help: str = None,
+ disabled: bool = False,
+ label_visibility: LabelVisibility = "visible",
+ default_value_attr: str = "defaultValue",
+ **kwargs,
+) -> typing.Any:
+ # if key:
+ # assert not value, "only one of value or key can be provided"
+ # else:
+ if not key:
+ key = md5_values("input", input_type, label, help, label_visibility)
+ value = state.session_state.setdefault(key, value)
+ if label_visibility != "visible":
+ label = None
+ state.RenderTreeNode(
+ name="input",
+ props={
+ "type": input_type,
+ "name": key,
+ "label": dedent(label),
+ default_value_attr: value,
+ "help": help,
+ "disabled": disabled,
+ **kwargs,
+ },
+ ).mount()
+ return value
+
+
+def breadcrumbs(divider: str = "/", **props) -> state.NestingCtx:
+ style = props.pop("style", {}) | {"--bs-breadcrumb-divider": f"'{divider}'"}
+ with tag("nav", style=style, **props):
+ return tag("ol", className="breadcrumb mb-0")
+
+
+def breadcrumb_item(inner_html: str, link_to: str | None = None, **props):
+ className = "breadcrumb-item " + props.pop("className", "")
+ with tag("li", className=className, **props):
+ if link_to:
+ with tag("a", href=link_to):
+ html(inner_html)
+ else:
+ html(inner_html)
+
+
+def plotly_chart(figure_or_data, **kwargs):
+ data = (
+ figure_or_data.to_plotly_json()
+ if hasattr(figure_or_data, "to_plotly_json")
+ else figure_or_data
+ )
+ state.RenderTreeNode(
+ name="plotly-chart",
+ props=dict(
+ chart=data,
+ args=kwargs,
+ ),
+ ).mount()
+
+
+def dedent(text: str | None) -> str | None:
+ if not text:
+ return text
+ return textwrap.dedent(text)
+
+
+def js(src: str, **kwargs):
+ state.RenderTreeNode(
+ name="script",
+ props=dict(
+ src=src,
+ args=kwargs,
+ ),
+ ).mount()
From 64fb8bda877980899ee114eed0939b3b76ea18bd Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 17 Jul 2024 18:45:30 +0530
Subject: [PATCH 044/110] Make modals rounded
---
gooey_ui/components/modal.py | 97 ++++++++++++++++++++++++++++++++++++
1 file changed, 97 insertions(+)
create mode 100644 gooey_ui/components/modal.py
diff --git a/gooey_ui/components/modal.py b/gooey_ui/components/modal.py
new file mode 100644
index 000000000..72e951fc8
--- /dev/null
+++ b/gooey_ui/components/modal.py
@@ -0,0 +1,97 @@
+from contextlib import contextmanager
+
+import gooey_ui as st
+from gooey_ui import experimental_rerun as rerun
+
+
+class Modal:
+ def __init__(self, title, key, padding=20, max_width=744):
+ """
+ :param title: title of the Modal shown in the h1
+ :param key: unique key identifying this modal instance
+ :param padding: padding of the content within the modal
+ :param max_width: maximum width this modal should use
+ """
+ self.title = title
+ self.padding = padding
+ self.max_width = str(max_width) + "px"
+ self.key = key
+
+ self._container = None
+
+ def is_open(self):
+ return st.session_state.get(f"{self.key}-opened", False)
+
+ def open(self):
+ st.session_state[f"{self.key}-opened"] = True
+ rerun()
+
+ def close(self, rerun_condition=True):
+ st.session_state[f"{self.key}-opened"] = False
+ if rerun_condition:
+ rerun()
+
+ def empty(self):
+ if self._container:
+ self._container.empty()
+
+ @contextmanager
+ def container(self, **props):
+ st.html(
+ f"""
+
+ """
+ )
+
+ with st.div(className="blur-background"):
+ with st.div(className="modal-parent"):
+ container_class = "modal-container " + props.pop("className", "")
+ self._container = st.div(className=container_class, **props)
+
+ with self._container:
+ with st.div(className="d-flex justify-content-between align-items-center"):
+ if self.title:
+ st.markdown(f"### {self.title}")
+ else:
+ st.div()
+
+ close_ = st.button(
+ "✖",
+ type="tertiary",
+ key=f"{self.key}-close",
+ style={"padding": "0.375rem 0.75rem"},
+ )
+ if close_:
+ self.close()
+ yield self._container
From 4c72b4690aef1784a683bbf5789d1da126523767 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Tue, 23 Jul 2024 17:08:16 +0530
Subject: [PATCH 045/110] Add invitation page
---
gooey_ui/components/__init__.py | 1009 -------------------------------
gooey_ui/components/modal.py | 97 ---
2 files changed, 1106 deletions(-)
delete mode 100644 gooey_ui/components/__init__.py
delete mode 100644 gooey_ui/components/modal.py
diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py
deleted file mode 100644
index 2c27edd1d..000000000
--- a/gooey_ui/components/__init__.py
+++ /dev/null
@@ -1,1009 +0,0 @@
-import base64
-import html as html_lib
-import math
-import textwrap
-import typing
-from datetime import datetime, timezone
-
-import numpy as np
-from furl import furl
-
-from daras_ai.image_input import resize_img_scale
-from gooey_ui import state
-from gooey_ui.pubsub import md5_values
-
-T = typing.TypeVar("T")
-LabelVisibility = typing.Literal["visible", "collapsed"]
-
-BLANK_OPTION = "———"
-
-
-def _default_format(value: typing.Any) -> str:
- if value is None:
- return BLANK_OPTION
- return str(value)
-
-
-def dummy(*args, **kwargs):
- return state.NestingCtx()
-
-
-spinner = dummy
-set_page_config = dummy
-form = dummy
-dataframe = dummy
-
-
-def countdown_timer(
- end_time: datetime,
- delay_text: str,
-) -> state.NestingCtx:
- return _node(
- "countdown-timer",
- endTime=end_time.astimezone(timezone.utc).isoformat(),
- delayText=delay_text,
- )
-
-
-def nav_tabs():
- return _node("nav-tabs")
-
-
-def nav_item(href: str, *, active: bool):
- return _node("nav-item", to=href, active="true" if active else None)
-
-
-def nav_tab_content():
- return _node("nav-tab-content")
-
-
-def div(**props) -> state.NestingCtx:
- return tag("div", **props)
-
-
-def link(*, to: str, **props) -> state.NestingCtx:
- return _node("Link", to=to, **props)
-
-
-def tag(tag_name: str, **props) -> state.NestingCtx:
- props["__reactjsxelement"] = tag_name
- return _node("tag", **props)
-
-
-def html(body: str, **props):
- props["className"] = props.get("className", "") + " gui-html-container"
- return _node("html", body=body, **props)
-
-
-def write(*objs: typing.Any, line_clamp: int = None, unsafe_allow_html=False, **props):
- for obj in objs:
- markdown(
- obj if isinstance(obj, str) else repr(obj),
- line_clamp=line_clamp,
- unsafe_allow_html=unsafe_allow_html,
- **props,
- )
-
-
-def center(direction="flex-column", className="") -> state.NestingCtx:
- return div(
- className=f"d-flex justify-content-center align-items-center text-center {direction} {className}"
- )
-
-
-def newline():
- html(" ")
-
-
-def markdown(
- body: str | None, *, line_clamp: int = None, unsafe_allow_html=False, **props
-):
- if body is None:
- return _node("markdown", body="", **props)
- if not unsafe_allow_html:
- body = html_lib.escape(body)
- props["className"] = (
- props.get("className", "") + " gui-html-container gui-md-container"
- )
- return _node("markdown", body=dedent(body).strip(), lineClamp=line_clamp, **props)
-
-
-def _node(name: str, **props):
- node = state.RenderTreeNode(name=name, props=props)
- node.mount()
- return state.NestingCtx(node)
-
-
-def text(body: str, **props):
- state.RenderTreeNode(
- name="pre",
- props=dict(body=dedent(body), **props),
- ).mount()
-
-
-def error(
- body: str,
- icon: str = "🔥",
- *,
- unsafe_allow_html=False,
- color="rgba(255, 108, 108, 0.2)",
- **props,
-):
- if not isinstance(body, str):
- body = repr(body)
- with div(
- style=dict(
- backgroundColor=color,
- padding="1rem",
- paddingBottom="0",
- marginBottom="0.5rem",
- borderRadius="0.25rem",
- display="flex",
- gap="0.5rem",
- )
- ):
- markdown(icon)
- with div():
- markdown(dedent(body), unsafe_allow_html=unsafe_allow_html, **props)
-
-
-def success(body: str, icon: str = "✅", *, unsafe_allow_html=False):
- if not isinstance(body, str):
- body = repr(body)
- with div(
- style=dict(
- backgroundColor="rgba(108, 255, 108, 0.2)",
- padding="1rem",
- paddingBottom="0",
- marginBottom="0.5rem",
- borderRadius="0.25rem",
- display="flex",
- gap="0.5rem",
- )
- ):
- markdown(icon)
- markdown(dedent(body), unsafe_allow_html=unsafe_allow_html)
-
-
-def caption(body: str, className: str = None, **props):
- className = className or "text-muted"
- markdown(body, className=className, **props)
-
-
-def tabs(labels: list[str]) -> list[state.NestingCtx]:
- parent = state.RenderTreeNode(
- name="tabs",
- children=[
- state.RenderTreeNode(
- name="tab",
- props=dict(label=dedent(label)),
- )
- for label in labels
- ],
- ).mount()
- return [state.NestingCtx(tab) for tab in parent.children]
-
-
-def controllable_tabs(
- labels: list[str], key: str
-) -> tuple[list[state.NestingCtx], int]:
- index = state.session_state.get(key, 0)
- for i, label in enumerate(labels):
- if button(
- label,
- key=f"tab-{i}",
- type="primary",
- className="replicate-nav",
- style={
- "background": "black" if i == index else "white",
- "color": "white" if i == index else "black",
- },
- ):
- state.session_state[key] = index = i
- state.experimental_rerun()
- ctxs = []
- for i, label in enumerate(labels):
- if i == index:
- ctxs += [div(className="tab-content")]
- else:
- ctxs += [div(className="tab-content", style={"display": "none"})]
- return ctxs, index
-
-
-def columns(
- spec,
- *,
- gap: str = None,
- responsive: bool = True,
- column_props: dict = {},
- **props,
-) -> tuple[state.NestingCtx, ...]:
- if isinstance(spec, int):
- spec = [1] * spec
- total_weight = sum(spec)
- props.setdefault("className", "row")
- with div(**props):
- return tuple(
- div(
- className=f"col-lg-{p} {'col-12' if responsive else f'col-{p}'}",
- **column_props,
- )
- for w in spec
- if (p := f"{round(w / total_weight * 12)}")
- )
-
-
-def image(
- src: str | np.ndarray,
- caption: str = None,
- alt: str = None,
- href: str = None,
- show_download_button: bool = False,
- **props,
-):
- if isinstance(src, np.ndarray):
- from daras_ai.image_input import cv2_img_to_bytes
-
- if not src.shape:
- return
- # ensure image is not too large
- data = resize_img_scale(cv2_img_to_bytes(src), (128, 128))
- # convert to base64
- b64 = base64.b64encode(data).decode("utf-8")
- src = "data:image/png;base64," + b64
- if not src:
- return
- state.RenderTreeNode(
- name="img",
- props=dict(
- src=src,
- caption=dedent(caption),
- alt=alt or caption,
- href=href,
- **props,
- ),
- ).mount()
- if show_download_button:
- download_button(
- label=' Download', url=src
- )
-
-
-def video(
- src: str,
- caption: str = None,
- autoplay: bool = False,
- show_download_button: bool = False,
-):
- autoplay_props = {}
- if autoplay:
- autoplay_props = {
- "preload": "auto",
- "controls": True,
- "autoPlay": True,
- "loop": True,
- "muted": True,
- "playsInline": True,
- }
-
- if not src:
- return
- if isinstance(src, str):
- # https://muffinman.io/blog/hack-for-ios-safari-to-display-html-video-thumbnail/
- f = furl(src)
- f.fragment.args["t"] = "0.001"
- src = f.url
- state.RenderTreeNode(
- name="video",
- props=dict(src=src, caption=dedent(caption), **autoplay_props),
- ).mount()
- if show_download_button:
- download_button(
- label=' Download', url=src
- )
-
-
-def audio(src: str, caption: str = None, show_download_button: bool = False):
- if not src:
- return
- state.RenderTreeNode(
- name="audio",
- props=dict(src=src, caption=dedent(caption)),
- ).mount()
- if show_download_button:
- download_button(
- label=' Download', url=src
- )
-
-
-def text_area(
- label: str,
- value: str = "",
- height: int = 500,
- key: str = None,
- help: str = None,
- placeholder: str = None,
- disabled: bool = False,
- label_visibility: LabelVisibility = "visible",
- **props,
-) -> str:
- style = props.setdefault("style", {})
- # if key:
- # assert not value, "only one of value or key can be provided"
- # else:
- if not key:
- key = md5_values(
- "textarea",
- label,
- height,
- help,
- placeholder,
- label_visibility,
- not disabled or value,
- )
- value = str(state.session_state.setdefault(key, value) or "")
- if label_visibility != "visible":
- label = None
- if disabled:
- max_height = f"{height}px"
- rows = nrows_for_text(value, height)
- else:
- max_height = "50vh"
- rows = nrows_for_text(value, height)
- style.setdefault("maxHeight", max_height)
- props.setdefault("rows", rows)
- state.RenderTreeNode(
- name="textarea",
- props=dict(
- name=key,
- label=dedent(label),
- defaultValue=value,
- help=help,
- placeholder=placeholder,
- disabled=disabled,
- **props,
- ),
- ).mount()
- return value or ""
-
-
-def nrows_for_text(
- text: str,
- max_height_px: int,
- min_rows: int = 1,
- row_height_px: int = 30,
- row_width_px: int = 70,
-) -> int:
- max_rows = max_height_px // row_height_px
- nrows = math.ceil(
- sum(
- math.ceil(len(line) / row_width_px)
- for line in (text or "").splitlines(keepends=True)
- )
- )
- nrows = min(max(nrows, min_rows), max_rows)
- return nrows
-
-
-def multiselect(
- label: str,
- options: typing.Sequence[T],
- format_func: typing.Callable[[T], typing.Any] = _default_format,
- key: str = None,
- help: str = None,
- allow_none: bool = False,
- *,
- disabled: bool = False,
-) -> list[T]:
- if not options:
- return []
- options = list(options)
- if not key:
- key = md5_values("multiselect", label, options, help)
- value = state.session_state.get(key) or []
- if not isinstance(value, list):
- value = [value]
- value = [o for o in value if o in options]
- if not allow_none and not value:
- value = [options[0]]
- state.session_state[key] = value
- state.RenderTreeNode(
- name="select",
- props=dict(
- name=key,
- label=dedent(label),
- help=help,
- isDisabled=disabled,
- isMulti=True,
- defaultValue=value,
- allow_none=allow_none,
- options=[
- {"value": option, "label": str(format_func(option))}
- for option in options
- ],
- ),
- ).mount()
- return value
-
-
-def selectbox(
- label: str,
- options: typing.Iterable[T],
- format_func: typing.Callable[[T], typing.Any] = _default_format,
- key: str = None,
- help: str = None,
- *,
- disabled: bool = False,
- label_visibility: LabelVisibility = "visible",
- value: T = None,
- allow_none: bool = False,
- **props,
-) -> T | None:
- if not options:
- return None
- if label_visibility != "visible":
- label = None
- options = list(options)
- if allow_none:
- options.insert(0, None)
- if not key:
- key = md5_values("select", label, options, help, label_visibility)
- value = state.session_state.setdefault(key, value)
- if value not in options:
- value = state.session_state[key] = options[0]
- state.RenderTreeNode(
- name="select",
- props=dict(
- name=key,
- label=dedent(label),
- help=help,
- isDisabled=disabled,
- defaultValue=value,
- options=[
- {"value": option, "label": str(format_func(option))}
- for option in options
- ],
- **props,
- ),
- ).mount()
- return value
-
-
-def download_button(
- label: str,
- url: str,
- key: str = None,
- help: str = None,
- *,
- type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary",
- disabled: bool = False,
- **props,
-) -> bool:
- url = furl(url).remove(fragment=True).url
- return button(
- component="download-button",
- url=url,
- label=label,
- key=key,
- help=help,
- type=type,
- disabled=disabled,
- **props,
- )
-
-
-def button(
- label: str,
- key: str = None,
- help: str = None,
- *,
- type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary",
- disabled: bool = False,
- component: typing.Literal["download-button", "gui-button"] = "gui-button",
- **props,
-) -> bool:
- """
- Example:
- st.button("Primary", key="test0", type="primary")
- st.button("Secondary", key="test1")
- st.button("Tertiary", key="test3", type="tertiary")
- st.button("Link Button", key="test3", type="link")
- """
- if not key:
- key = md5_values("button", label, help, type, props)
- className = f"btn-{type} " + props.pop("className", "")
- state.RenderTreeNode(
- name=component,
- props=dict(
- type="submit",
- value="yes",
- name=key,
- label=dedent(label),
- help=help,
- disabled=disabled,
- className=className,
- **props,
- ),
- ).mount()
- return bool(state.session_state.pop(key, False))
-
-
-def anchor(
- label: str,
- href: str,
- *,
- type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary",
- disabled: bool = False,
- unsafe_allow_html: bool = False,
- new_tab: bool = False,
- **props,
-):
- className = f"btn btn-theme btn-{type} " + props.pop("className", "")
- style = props.pop("style", {})
- if disabled:
- style["pointerEvents"] = "none"
- if new_tab:
- props["target"] = "_blank"
- with tag("a", href=href, className=className, style=style, **props):
- markdown(dedent(label), unsafe_allow_html=unsafe_allow_html)
-
-
-form_submit_button = button
-
-
-def expander(label: str, *, expanded: bool = False, key: str = None, **props):
- node = state.RenderTreeNode(
- name="expander",
- props=dict(
- label=dedent(label),
- open=expanded,
- name=key or md5_values(label, expanded, props),
- **props,
- ),
- )
- node.mount()
- return state.NestingCtx(node)
-
-
-def file_uploader(
- label: str,
- accept: list[str] = None,
- accept_multiple_files=False,
- key: str = None,
- value: str | list[str] = None,
- upload_key: str = None,
- help: str = None,
- *,
- disabled: bool = False,
- label_visibility: LabelVisibility = "visible",
- upload_meta: dict = None,
- optional: bool = False,
-) -> str | list[str] | None:
- if label_visibility != "visible":
- label = None
- key = upload_key or key
- if not key:
- key = md5_values(
- "file_uploader",
- label,
- accept,
- accept_multiple_files,
- help,
- label_visibility,
- )
- if optional:
- if not checkbox(
- label, value=bool(state.session_state.get(key, value)), disabled=disabled
- ):
- state.session_state.pop(key, None)
- return None
- label = None
- value = state.session_state.setdefault(key, value)
- if not value:
- if accept_multiple_files:
- value = []
- else:
- value = None
- state.session_state[key] = value
- state.RenderTreeNode(
- name="input",
- props=dict(
- type="file",
- name=key,
- label=dedent(label),
- help=help,
- disabled=disabled,
- accept=accept,
- multiple=accept_multiple_files,
- defaultValue=value,
- uploadMeta=upload_meta,
- ),
- ).mount()
- return value
-
-
-def json(value: typing.Any, expanded: bool = False, depth: int = 1):
- state.RenderTreeNode(
- name="json",
- props=dict(
- value=value,
- expanded=expanded,
- defaultInspectDepth=3 if expanded else depth,
- ),
- ).mount()
-
-
-def data_table(file_url_or_cells: str | list):
- if isinstance(file_url_or_cells, str):
- file_url = file_url_or_cells
- return _node("data-table", fileUrl=file_url)
- else:
- cells = file_url_or_cells
- return _node("data-table-raw", cells=cells)
-
-
-def table(df: "pd.DataFrame"):
- with tag("table", className="table table-striped table-sm"):
- with tag("thead"):
- with tag("tr"):
- for col in df.columns:
- with tag("th", scope="col"):
- html(dedent(col))
- with tag("tbody"):
- for row in df.itertuples(index=False):
- with tag("tr"):
- for value in row:
- with tag("td"):
- html(dedent(str(value)))
-
-
-def raw_table(header: list[str], className: str = "", **props) -> state.NestingCtx:
- className = "table " + className
- with tag("table", className=className, **props):
- if header:
- with tag("thead"), tag("tr"):
- for col in header:
- with tag("th", scope="col"):
- html(dedent(col))
-
- return tag("tbody")
-
-
-def table_row(values: list[str], **props):
- row = tag("tr", **props)
- with row:
- for v in values:
- with tag("td"):
- html(html_lib.escape(v))
- return row
-
-
-def horizontal_radio(
- label: str,
- options: typing.Sequence[T],
- format_func: typing.Callable[[T], typing.Any] = _default_format,
- *,
- key: str = None,
- help: str = None,
- value: T = None,
- disabled: bool = False,
- checked_by_default: bool = True,
- label_visibility: LabelVisibility = "visible",
- **button_props,
-) -> T | None:
- if not options:
- return None
- options = list(options)
- if not key:
- key = md5_values("horizontal_radio", label, options, help, label_visibility)
- value = state.session_state.setdefault(key, value)
- if value not in options and checked_by_default:
- value = state.session_state[key] = options[0]
- if label_visibility != "visible":
- label = None
- markdown(label)
- for option in options:
- if button(
- format_func(option),
- key=f"tab-{key}-{option}",
- type="primary",
- className="replicate-nav " + ("active" if value == option else ""),
- disabled=disabled,
- **button_props,
- ):
- state.session_state[key] = value = option
- state.experimental_rerun()
- return value
-
-
-def radio(
- label: str,
- options: typing.Sequence[T],
- format_func: typing.Callable[[T], typing.Any] = _default_format,
- key: str = None,
- value: T = None,
- help: str = None,
- *,
- disabled: bool = False,
- checked_by_default: bool = True,
- label_visibility: LabelVisibility = "visible",
-) -> T | None:
- if not options:
- return None
- options = list(options)
- if not key:
- key = md5_values("radio", label, options, help, label_visibility)
- value = state.session_state.setdefault(key, value)
- if value not in options and checked_by_default:
- value = state.session_state[key] = options[0]
- if label_visibility != "visible":
- label = None
- markdown(label)
- for option in options:
- state.RenderTreeNode(
- name="input",
- props=dict(
- type="radio",
- name=key,
- label=dedent(str(format_func(option))),
- value=option,
- defaultChecked=bool(value == option),
- help=help,
- disabled=disabled,
- ),
- ).mount()
- return value
-
-
-def text_input(
- label: str,
- value: str = "",
- max_chars: str = None,
- key: str = None,
- help: str = None,
- *,
- placeholder: str = None,
- disabled: bool = False,
- label_visibility: LabelVisibility = "visible",
- **props,
-) -> str:
- value = _input_widget(
- input_type="text",
- label=label,
- value=value,
- key=key,
- help=help,
- disabled=disabled,
- label_visibility=label_visibility,
- maxLength=max_chars,
- placeholder=placeholder,
- **props,
- )
- return value or ""
-
-
-def date_input(
- label: str,
- value: str | None = None,
- key: str = None,
- help: str = None,
- *,
- disabled: bool = False,
- label_visibility: LabelVisibility = "visible",
- **props,
-) -> datetime | None:
- value = _input_widget(
- input_type="date",
- label=label,
- value=value,
- key=key,
- help=help,
- disabled=disabled,
- label_visibility=label_visibility,
- style=dict(
- border="1px solid hsl(0, 0%, 80%)",
- padding="0.375rem 0.75rem",
- borderRadius="0.25rem",
- margin="0 0.5rem 0 0.5rem",
- ),
- **props,
- )
- try:
- return datetime.strptime(value, "%Y-%m-%d") if value else None
- except ValueError:
- return None
-
-
-def password_input(
- label: str,
- value: str = "",
- max_chars: str = None,
- key: str = None,
- help: str = None,
- *,
- placeholder: str = None,
- disabled: bool = False,
- label_visibility: LabelVisibility = "visible",
- **props,
-) -> str:
- value = _input_widget(
- input_type="password",
- label=label,
- value=value,
- key=key,
- help=help,
- disabled=disabled,
- label_visibility=label_visibility,
- maxLength=max_chars,
- placeholder=placeholder,
- **props,
- )
- return value or ""
-
-
-def slider(
- label: str,
- min_value: float = None,
- max_value: float = None,
- value: float = None,
- step: float = None,
- key: str = None,
- help: str = None,
- *,
- disabled: bool = False,
-) -> float:
- value = _input_widget(
- input_type="range",
- label=label,
- value=value,
- key=key,
- help=help,
- disabled=disabled,
- min=min_value,
- max=max_value,
- step=_step_value(min_value, max_value, step),
- )
- return value or 0
-
-
-def number_input(
- label: str,
- min_value: float = None,
- max_value: float = None,
- value: float = None,
- step: float = None,
- key: str = None,
- help: str = None,
- *,
- disabled: bool = False,
-) -> float:
- value = _input_widget(
- input_type="number",
- inputMode="decimal",
- label=label,
- value=value,
- key=key,
- help=help,
- disabled=disabled,
- min=min_value,
- max=max_value,
- step=_step_value(min_value, max_value, step),
- )
- return value or 0
-
-
-def _step_value(
- min_value: float | None, max_value: float | None, step: float | None
-) -> float:
- if step:
- return step
- elif isinstance(min_value, float) or isinstance(max_value, float):
- return 0.1
- else:
- return 1
-
-
-def checkbox(
- label: str,
- value: bool = False,
- key: str = None,
- help: str = None,
- *,
- disabled: bool = False,
- label_visibility: LabelVisibility = "visible",
- **props,
-) -> bool:
- value = _input_widget(
- input_type="checkbox",
- label=label,
- value=value,
- key=key,
- help=help,
- disabled=disabled,
- label_visibility=label_visibility,
- default_value_attr="defaultChecked",
- **props,
- )
- return bool(value)
-
-
-def _input_widget(
- *,
- input_type: str,
- label: str,
- value: typing.Any = None,
- key: str = None,
- help: str = None,
- disabled: bool = False,
- label_visibility: LabelVisibility = "visible",
- default_value_attr: str = "defaultValue",
- **kwargs,
-) -> typing.Any:
- # if key:
- # assert not value, "only one of value or key can be provided"
- # else:
- if not key:
- key = md5_values("input", input_type, label, help, label_visibility)
- value = state.session_state.setdefault(key, value)
- if label_visibility != "visible":
- label = None
- state.RenderTreeNode(
- name="input",
- props={
- "type": input_type,
- "name": key,
- "label": dedent(label),
- default_value_attr: value,
- "help": help,
- "disabled": disabled,
- **kwargs,
- },
- ).mount()
- return value
-
-
-def breadcrumbs(divider: str = "/", **props) -> state.NestingCtx:
- style = props.pop("style", {}) | {"--bs-breadcrumb-divider": f"'{divider}'"}
- with tag("nav", style=style, **props):
- return tag("ol", className="breadcrumb mb-0")
-
-
-def breadcrumb_item(inner_html: str, link_to: str | None = None, **props):
- className = "breadcrumb-item " + props.pop("className", "")
- with tag("li", className=className, **props):
- if link_to:
- with tag("a", href=link_to):
- html(inner_html)
- else:
- html(inner_html)
-
-
-def plotly_chart(figure_or_data, **kwargs):
- data = (
- figure_or_data.to_plotly_json()
- if hasattr(figure_or_data, "to_plotly_json")
- else figure_or_data
- )
- state.RenderTreeNode(
- name="plotly-chart",
- props=dict(
- chart=data,
- args=kwargs,
- ),
- ).mount()
-
-
-def dedent(text: str | None) -> str | None:
- if not text:
- return text
- return textwrap.dedent(text)
-
-
-def js(src: str, **kwargs):
- state.RenderTreeNode(
- name="script",
- props=dict(
- src=src,
- args=kwargs,
- ),
- ).mount()
diff --git a/gooey_ui/components/modal.py b/gooey_ui/components/modal.py
deleted file mode 100644
index 72e951fc8..000000000
--- a/gooey_ui/components/modal.py
+++ /dev/null
@@ -1,97 +0,0 @@
-from contextlib import contextmanager
-
-import gooey_ui as st
-from gooey_ui import experimental_rerun as rerun
-
-
-class Modal:
- def __init__(self, title, key, padding=20, max_width=744):
- """
- :param title: title of the Modal shown in the h1
- :param key: unique key identifying this modal instance
- :param padding: padding of the content within the modal
- :param max_width: maximum width this modal should use
- """
- self.title = title
- self.padding = padding
- self.max_width = str(max_width) + "px"
- self.key = key
-
- self._container = None
-
- def is_open(self):
- return st.session_state.get(f"{self.key}-opened", False)
-
- def open(self):
- st.session_state[f"{self.key}-opened"] = True
- rerun()
-
- def close(self, rerun_condition=True):
- st.session_state[f"{self.key}-opened"] = False
- if rerun_condition:
- rerun()
-
- def empty(self):
- if self._container:
- self._container.empty()
-
- @contextmanager
- def container(self, **props):
- st.html(
- f"""
-
- """
- )
-
- with st.div(className="blur-background"):
- with st.div(className="modal-parent"):
- container_class = "modal-container " + props.pop("className", "")
- self._container = st.div(className=container_class, **props)
-
- with self._container:
- with st.div(className="d-flex justify-content-between align-items-center"):
- if self.title:
- st.markdown(f"### {self.title}")
- else:
- st.div()
-
- close_ = st.button(
- "✖",
- type="tertiary",
- key=f"{self.key}-close",
- style={"padding": "0.375rem 0.75rem"},
- )
- if close_:
- self.close()
- yield self._container
From c748d70a160f6dc1e631c33fb1be4bfb868cd008 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 28 Aug 2024 14:56:50 +0530
Subject: [PATCH 046/110] feat: add billing support for orgs (db + ux)
---
Procfile | 2 +-
...ction_org_alter_appusertransaction_user.py | 25 ++
app_users/models.py | 50 ++-
app_users/tasks.py | 10 +-
bots/models.py | 7 +
daras_ai_v2/base.py | 3 +-
daras_ai_v2/billing.py | 7 +-
daras_ai_v2/send_email.py | 15 +-
orgs/admin.py | 11 +-
..._org_is_paying_org_is_personal_and_more.py | 45 +++
.../0005_org_unique_personal_org_per_user.py | 17 +
orgs/models.py | 146 ++++++-
orgs/views.py | 382 +++++++++++++++++-
payments/models.py | 19 +-
payments/tasks.py | 51 +--
payments/webhooks.py | 68 ++--
scripts/migrate_orgs_from_appusers.py | 26 ++
17 files changed, 780 insertions(+), 104 deletions(-)
create mode 100644 app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py
create mode 100644 orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py
create mode 100644 orgs/migrations/0005_org_unique_personal_org_per_user.py
create mode 100644 scripts/migrate_orgs_from_appusers.py
diff --git a/Procfile b/Procfile
index 984315504..1766991c6 100644
--- a/Procfile
+++ b/Procfile
@@ -19,4 +19,4 @@ dashboard: poetry run streamlit run Home.py --server.port 8501 --server.headless
celery: poetry run celery -A celeryapp worker -P threads -c 16 -l DEBUG
-ui: cd ../gooey-gui/ && env PORT=3000 npm run dev
+ui: cd ../gooey-gui/ && env PORT=3000 REDIS_URL=redis://localhost:6379 pnpm run dev
diff --git a/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py b/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py
new file mode 100644
index 000000000..b3e80c708
--- /dev/null
+++ b/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py
@@ -0,0 +1,25 @@
+# Generated by Django 4.2.7 on 2024-08-13 14:34
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('orgs', '0005_org_unique_personal_org_per_user'),
+ ('app_users', '0019_alter_appusertransaction_reason'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='appusertransaction',
+ name='org',
+ field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='transactions', to='orgs.org'),
+ ),
+ migrations.AlterField(
+ model_name='appusertransaction',
+ name='user',
+ field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='transactions', to='app_users.appuser'),
+ ),
+ ]
diff --git a/app_users/models.py b/app_users/models.py
index 09832cebc..739ab3bd3 100644
--- a/app_users/models.py
+++ b/app_users/models.py
@@ -90,23 +90,10 @@ class AppUser(models.Model):
display_name = models.TextField("name", blank=True)
email = models.EmailField(null=True, blank=True)
phone_number = PhoneNumberField(null=True, blank=True)
- balance = models.IntegerField("bal")
is_anonymous = models.BooleanField()
is_disabled = models.BooleanField(default=False)
photo_url = CustomURLField(default="", blank=True)
- stripe_customer_id = models.CharField(max_length=255, default="", blank=True)
- is_paying = models.BooleanField("paid", default=False)
-
- low_balance_email_sent_at = models.DateTimeField(null=True, blank=True)
- subscription = models.OneToOneField(
- "payments.Subscription",
- on_delete=models.SET_NULL,
- related_name="user",
- null=True,
- blank=True,
- )
-
created_at = models.DateTimeField(
"created", editable=False, blank=True, default=timezone.now
)
@@ -129,6 +116,18 @@ class AppUser(models.Model):
github_username = models.CharField(max_length=255, blank=True, default="")
website_url = CustomURLField(blank=True, default="")
+ balance = models.IntegerField("bal")
+ is_paying = models.BooleanField("paid", default=False)
+ stripe_customer_id = models.CharField(max_length=255, default="", blank=True)
+ subscription = models.OneToOneField(
+ "payments.Subscription",
+ on_delete=models.SET_NULL,
+ related_name="user",
+ null=True,
+ blank=True,
+ )
+ low_balance_email_sent_at = models.DateTimeField(null=True, blank=True)
+
disable_rate_limits = models.BooleanField(default=False)
objects = AppUserQuerySet.as_manager()
@@ -159,6 +158,9 @@ def first_name_possesive(self) -> str:
else:
return name + "'s"
+ def get_personal_org(self) -> "Org | None":
+ return self.orgs.filter(is_personal=True).first()
+
@db_middleware
@transaction.atomic
def add_balance(
@@ -246,6 +248,17 @@ def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser":
return self
+ def get_or_create_personal_org(self) -> tuple["Org", bool]:
+ from orgs.models import Org
+
+ org_membership = self.org_memberships.filter(
+ org__is_personal=True, org__created_by=self
+ ).first()
+ if org_membership:
+ return org_membership, False
+ else:
+ return Org.objects.migrate_from_appuser(self), True
+
def get_or_create_stripe_customer(self) -> stripe.Customer:
customer = self.search_stripe_customer()
if not customer:
@@ -303,7 +316,16 @@ class TransactionReason(models.IntegerChoices):
class AppUserTransaction(models.Model):
user = models.ForeignKey(
- "AppUser", on_delete=models.CASCADE, related_name="transactions"
+ "AppUser",
+ on_delete=models.SET_NULL,
+ related_name="transactions",
+ null=True,
+ )
+ org = models.ForeignKey(
+ "orgs.Org",
+ on_delete=models.SET_NULL,
+ related_name="transactions",
+ null=True,
)
invoice_id = models.CharField(
max_length=255,
diff --git a/app_users/tasks.py b/app_users/tasks.py
index 0327ac423..b1d893196 100644
--- a/app_users/tasks.py
+++ b/app_users/tasks.py
@@ -5,14 +5,14 @@
from celeryapp.celeryconfig import app
from payments.models import Subscription
from payments.plans import PricingPlan
-from payments.webhooks import set_user_subscription
+from payments.webhooks import set_org_subscription
@app.task
def save_stripe_default_payment_method(
*,
payment_intent_id: str,
- uid: str,
+ org_id: str,
amount: int,
charged_amount: int,
reason: TransactionReason,
@@ -41,11 +41,11 @@ def save_stripe_default_payment_method(
if (
reason == TransactionReason.ADDON
and not Subscription.objects.filter(
- user__uid=uid, payment_provider__isnull=False
+ org__org_id=org_id, payment_provider__isnull=False
).exists()
):
- set_user_subscription(
- uid=uid,
+ set_org_subscription(
+ org_id=org_id,
plan=PricingPlan.STARTER,
provider=PaymentProvider.STRIPE,
external_id=None,
diff --git a/bots/models.py b/bots/models.py
index e997e8f8a..a6163ee1c 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -212,6 +212,13 @@ class SavedRun(models.Model):
)
run_id = models.CharField(max_length=128, default=None, null=True, blank=True)
uid = models.CharField(max_length=128, default=None, null=True, blank=True)
+ billed_org = models.ForeignKey(
+ "orgs.Org",
+ on_delete=models.SET_NULL,
+ null=True,
+ blank=True,
+ related_name="billed_runs",
+ )
state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder)
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index 2233a0803..f37a284bb 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -2106,7 +2106,8 @@ def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]:
), "request.user must be set to deduct credits"
amount = self.get_price_roundoff(state)
- txn = self.request.user.add_balance(-amount, f"gooey_in_{uuid.uuid1()}")
+ org, _ = self.request.user.get_or_create_personal_org()
+ txn = org.add_balance(-amount, f"gooey_in_{uuid.uuid1()}")
return txn, amount
def get_price_roundoff(self, state: dict) -> int:
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index 639412464..adc500015 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -9,9 +9,10 @@
from daras_ai_v2.gui_confirm import confirm_modal
from daras_ai_v2.settings import templates
from daras_ai_v2.user_date_widgets import render_local_date_attrs
+from orgs.models import Org
from payments.models import PaymentMethodSummary
from payments.plans import PricingPlan
-from payments.webhooks import StripeWebhookHandler, set_user_subscription
+from payments.webhooks import StripeWebhookHandler, set_org_subscription
from scripts.migrate_existing_subscriptions import available_subscriptions
rounded_border = "w-100 border shadow-sm rounded py-4 px-3"
@@ -635,8 +636,8 @@ def render_payment_information(user: AppUser):
):
modal.open()
if confirmed:
- set_user_subscription(
- uid=user.uid,
+ set_org_subscription(
+ org_id=user.get_personal_org().org_id,
plan=PricingPlan.STARTER,
provider=None,
external_id=None,
diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py
index a9ff1934d..3c679c6fb 100644
--- a/daras_ai_v2/send_email.py
+++ b/daras_ai_v2/send_email.py
@@ -3,16 +3,19 @@
import requests
-from app_users.models import AppUser
from daras_ai_v2 import settings
from daras_ai_v2.exceptions import raise_for_status
from daras_ai_v2.fastapi_tricks import get_app_route_url
from daras_ai_v2.settings import templates
+if typing.TYPE_CHECKING:
+ from app_users.models import AppUser
+
+
def send_reported_run_email(
*,
- user: AppUser,
+ user: "AppUser",
run_uid: str,
url: str,
recipe_name: str,
@@ -41,7 +44,7 @@ def send_reported_run_email(
def send_low_balance_email(
*,
- user: AppUser,
+ user: "AppUser",
total_credits_consumed: int,
):
from routers.account import account_route
@@ -70,8 +73,8 @@ def send_email_via_postmark(
*,
from_address: str,
to_address: str,
- cc: str = None,
- bcc: str = None,
+ cc: str | None = None,
+ bcc: str | None = None,
subject: str = "",
html_body: str = "",
text_body: str = "",
@@ -79,7 +82,7 @@ def send_email_via_postmark(
"outbound", "gooey-ai-workflows", "announcements"
] = "outbound",
):
- if is_running_pytest:
+ if is_running_pytest or not settings.POSTMARK_API_TOKEN:
pytest_outbox.append(
dict(
from_address=from_address,
diff --git a/orgs/admin.py b/orgs/admin.py
index 969866f41..370ca4c4e 100644
--- a/orgs/admin.py
+++ b/orgs/admin.py
@@ -43,9 +43,16 @@ class OrgAdmin(SafeDeleteAdmin):
"updated_at",
] + list(SafeDeleteAdmin.list_display)
list_filter = [SafeDeleteAdminFilter] + list(SafeDeleteAdmin.list_filter)
- fields = ["name", "domain_name", "created_by", "created_at", "updated_at"]
+ fields = [
+ "name",
+ "domain_name",
+ "created_by",
+ "is_personal",
+ "created_at",
+ "updated_at",
+ ]
search_fields = ["name", "domain_name"]
- readonly_fields = ["created_at", "updated_at"]
+ readonly_fields = ["is_personal", "created_at", "updated_at"]
inlines = [OrgMembershipInline, OrgInvitationInline]
ordering = ["-created_at"]
diff --git a/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py b/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py
new file mode 100644
index 000000000..9d9fdfc5d
--- /dev/null
+++ b/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py
@@ -0,0 +1,45 @@
+# Generated by Django 4.2.7 on 2024-08-12 14:23
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('payments', '0005_alter_subscription_plan'),
+ ('orgs', '0003_remove_org_unique_domain_name_when_not_deleted_and_more'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='org',
+ name='balance',
+ field=models.IntegerField(default=0, verbose_name='bal'),
+ ),
+ migrations.AddField(
+ model_name='org',
+ name='is_paying',
+ field=models.BooleanField(default=False, verbose_name='paid'),
+ ),
+ migrations.AddField(
+ model_name='org',
+ name='is_personal',
+ field=models.BooleanField(default=False),
+ ),
+ migrations.AddField(
+ model_name='org',
+ name='low_balance_email_sent_at',
+ field=models.DateTimeField(blank=True, null=True),
+ ),
+ migrations.AddField(
+ model_name='org',
+ name='stripe_customer_id',
+ field=models.CharField(blank=True, default='', max_length=255),
+ ),
+ migrations.AddField(
+ model_name='org',
+ name='subscription',
+ field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='org', to='payments.subscription'),
+ ),
+ ]
diff --git a/orgs/migrations/0005_org_unique_personal_org_per_user.py b/orgs/migrations/0005_org_unique_personal_org_per_user.py
new file mode 100644
index 000000000..aaaa1cc4d
--- /dev/null
+++ b/orgs/migrations/0005_org_unique_personal_org_per_user.py
@@ -0,0 +1,17 @@
+# Generated by Django 4.2.7 on 2024-08-13 14:34
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('orgs', '0004_org_balance_org_is_paying_org_is_personal_and_more'),
+ ]
+
+ operations = [
+ migrations.AddConstraint(
+ model_name='org',
+ constraint=models.UniqueConstraint(models.F('created_by'), condition=models.Q(('deleted__isnull', True), ('is_personal', True)), name='unique_personal_org_per_user'),
+ ),
+ ]
diff --git a/orgs/models.py b/orgs/models.py
index 5a19dad78..0c39312c0 100644
--- a/orgs/models.py
+++ b/orgs/models.py
@@ -1,6 +1,10 @@
+from __future__ import annotations
+
import re
from datetime import timedelta
+from django.db.models.aggregates import Sum
+import stripe
from django.db import models, transaction
from django.core.exceptions import ValidationError
from django.db.backends.base.schema import logger
@@ -10,10 +14,10 @@
from safedelete.managers import SafeDeleteManager
from safedelete.models import SafeDeleteModel, SOFT_DELETE_CASCADE
-from app_users.models import AppUser
from daras_ai_v2 import settings
from daras_ai_v2.fastapi_tricks import get_app_route_url
from daras_ai_v2.crypto import get_random_doc_id
+from gooeysite.bg_db_conn import db_middleware
from orgs.tasks import send_auto_accepted_email, send_invitation_email
@@ -37,7 +41,9 @@ class OrgRole(models.IntegerChoices):
class OrgManager(SafeDeleteManager):
- def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwargs):
+ def create_org(
+ self, *, created_by: "AppUser", org_id: str | None = None, **kwargs
+ ) -> Org:
org = self.model(
org_id=org_id or get_random_doc_id(), created_by=created_by, **kwargs
)
@@ -49,6 +55,28 @@ def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwar
)
return org
+ def get_or_create_from_org_id(self, org_id: str) -> tuple[Org, bool]:
+ from app_users.models import AppUser
+
+ try:
+ return self.get(org_id=org_id), False
+ except self.model.DoesNotExist:
+ user = AppUser.objects.get_or_create_from_uid(org_id)[0]
+ return self.migrate_from_appuser(user), True
+
+ def migrate_from_appuser(self, user: "AppUser") -> Org:
+ return self.create_org(
+ name=f"{user.first_name()}'s Personal Workspace",
+ org_id=user.uid or get_random_doc_id(),
+ created_by=user,
+ is_personal=True,
+ balance=user.balance,
+ stripe_customer_id=user.stripe_customer_id,
+ subscription=user.subscription,
+ low_balance_email_sent_at=user.low_balance_email_sent_at,
+ is_paying=user.is_paying,
+ )
+
class Org(SafeDeleteModel):
_safedelete_policy = SOFT_DELETE_CASCADE
@@ -71,6 +99,21 @@ class Org(SafeDeleteModel):
],
)
+ # billing
+ balance = models.IntegerField("bal", default=0)
+ is_paying = models.BooleanField("paid", default=False)
+ stripe_customer_id = models.CharField(max_length=255, default="", blank=True)
+ subscription = models.OneToOneField(
+ "payments.Subscription",
+ on_delete=models.SET_NULL,
+ related_name="org",
+ null=True,
+ blank=True,
+ )
+ low_balance_email_sent_at = models.DateTimeField(null=True, blank=True)
+
+ is_personal = models.BooleanField(default=False)
+
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
@@ -83,7 +126,12 @@ class Meta:
condition=Q(deleted__isnull=True),
name="unique_domain_name_when_not_deleted",
violation_error_message=f"This domain name is already in use by another team. Contact {settings.SUPPORT_EMAIL} if you think this is a mistake.",
- )
+ ),
+ models.UniqueConstraint(
+ "created_by",
+ condition=Q(deleted__isnull=True, is_personal=True),
+ name="unique_personal_org_per_user",
+ ),
]
def __str__(self):
@@ -147,6 +195,90 @@ def invite_user(
return invitation
+ def get_owners(self) -> list[OrgMembership]:
+ return self.memberships.filter(role=OrgRole.OWNER)
+
+ @db_middleware
+ @transaction.atomic
+ def add_balance(
+ self, amount: int, invoice_id: str, **kwargs
+ ) -> "AppUserTransaction":
+ """
+ Used to add/deduct credits when they are bought or consumed.
+
+ When credits are bought with stripe -- invoice_id is the stripe
+ invoice ID.
+ When credits are deducted due to a run -- invoice_id is of the
+ form "gooey_in_{uuid}"
+ """
+ from app_users.models import AppUserTransaction
+
+ # if an invoice entry exists
+ try:
+ # avoid updating twice for same invoice
+ return AppUserTransaction.objects.get(invoice_id=invoice_id)
+ except AppUserTransaction.DoesNotExist:
+ pass
+
+ # select_for_update() is very important here
+ # transaction.atomic alone is not enough!
+ # It won't lock this row for reads, and multiple threads can update the same row leading incorrect balance
+ #
+ # Also we're not using .update() here because it won't give back the updated end balance
+ org: Org = Org.objects.select_for_update().get(pk=self.pk)
+ org.balance += amount
+ org.save(update_fields=["balance"])
+ kwargs.setdefault("plan", org.subscription and org.subscription.plan)
+ return AppUserTransaction.objects.create(
+ org=org,
+ invoice_id=invoice_id,
+ amount=amount,
+ end_balance=org.balance,
+ **kwargs,
+ )
+
+ def get_or_create_stripe_customer(self) -> stripe.Customer:
+ customer = self.search_stripe_customer()
+ if not customer:
+ customer = stripe.Customer.create(
+ name=self.created_by.display_name,
+ email=self.created_by.email,
+ phone=self.created_by.phone,
+ metadata={"uid": self.org_id, "org_id": self.org_id, "id": self.pk},
+ )
+ self.stripe_customer_id = customer.id
+ self.save()
+ return customer
+
+ def search_stripe_customer(self) -> stripe.Customer | None:
+ if not self.org_id:
+ return None
+ if self.stripe_customer_id:
+ try:
+ return stripe.Customer.retrieve(self.stripe_customer_id)
+ except stripe.error.InvalidRequestError as e:
+ if e.http_status != 404:
+ raise
+ try:
+ customer = stripe.Customer.search(
+ query=f'metadata["uid"]:"{self.org_id}"'
+ ).data[0]
+ except IndexError:
+ return None
+ else:
+ self.stripe_customer_id = customer.id
+ self.save()
+ return customer
+
+ def get_dollars_spent_this_month(self) -> float:
+ today = timezone.now()
+ cents_spent = self.transactions.filter(
+ created_at__month=today.month,
+ created_at__year=today.year,
+ amount__gt=0,
+ ).aggregate(total=Sum("charged_amount"))["total"]
+ return (cents_spent or 0) / 100
+
class OrgMembership(SafeDeleteModel):
org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="memberships")
@@ -260,6 +392,8 @@ def auto_accept(self):
Raises: ValidationError
"""
+ from app_users.models import AppUser
+
assert self.status == self.Status.PENDING
invitee = AppUser.objects.get(email=self.invitee_email)
@@ -287,7 +421,7 @@ def send_email(self):
send_invitation_email.delay(invitation_pk=self.pk)
- def accept(self, user: AppUser, *, auto_accepted: bool = False):
+ def accept(self, user: "AppUser", *, auto_accepted: bool = False):
"""
Raises: ValidationError
"""
@@ -323,13 +457,13 @@ def accept(self, user: AppUser, *, auto_accepted: bool = False):
)
self.save()
- def reject(self, user: AppUser):
+ def reject(self, user: "AppUser"):
self.status = self.Status.REJECTED
self.status_changed_at = timezone.now()
self.status_changed_by = user
self.save()
- def cancel(self, user: AppUser):
+ def cancel(self, user: "AppUser"):
self.status = self.Status.CANCELED
self.status_changed_at = timezone.now()
self.status_changed_by = user
diff --git a/orgs/views.py b/orgs/views.py
index ed864cb94..2d6f3c27c 100644
--- a/orgs/views.py
+++ b/orgs/views.py
@@ -2,18 +2,29 @@
import html as html_lib
+import stripe
import gooey_gui as gui
from django.core.exceptions import ValidationError
-from app_users.models import AppUser
+from app_users.models import AppUser, PaymentProvider
+from daras_ai_v2.billing import format_card_brand, payment_provider_radio
+from daras_ai_v2.grid_layout_widget import grid_layout
from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole
-from daras_ai_v2 import icons
-from daras_ai_v2.fastapi_tricks import get_route_path
+from daras_ai_v2 import icons, settings
+from daras_ai_v2.fastapi_tricks import get_route_path, get_app_route_url
+from daras_ai_v2.settings import templates
+from daras_ai_v2.user_date_widgets import render_local_date_attrs
+from payments.models import PaymentMethodSummary
+from payments.plans import PricingPlan
+from scripts.migrate_existing_subscriptions import available_subscriptions
DEFAULT_ORG_LOGO = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/74a37c52-8260-11ee-a297-02420a0001ee/gooey.ai%20-%20A%20pop%20art%20illustration%20of%20robots%20taki...y%20Liechtenstein%20mint%20colour%20is%20main%20city%20Seattle.png"
+rounded_border = "w-100 border shadow-sm rounded py-4 px-3"
+
+
def invitation_page(user: AppUser, invitation: OrgInvitation):
from routers.account import orgs_route
@@ -107,6 +118,10 @@ def render_org_by_membership(membership: OrgMembership):
f"Org Domain: `@{org.domain_name}`", className="text-muted"
)
+ with gui.div(className="mt-4"):
+ gui.write("# Billing")
+ billing_section(org=org, current_member=membership)
+
with gui.div(className="mt-4"):
with gui.div(className="d-flex justify-content-between align-items-center"):
gui.write("## Members")
@@ -142,6 +157,361 @@ def render_org_by_membership(membership: OrgMembership):
org_leave_modal.open()
+def billing_section(*, org: Org, current_member: OrgMembership):
+ render_payments_setup()
+
+ if org.subscription and org.subscription.external_id:
+ render_current_plan(org)
+
+ with gui.div(className="my-5"):
+ render_credit_balance(org)
+
+ with gui.div(className="my-5"):
+ selected_payment_provider = render_all_plans(org)
+
+ with gui.div(className="my-5"):
+ render_addon_section(org, selected_payment_provider)
+
+ if org.subscription and org.subscription.external_id:
+ # if org.subscription.payment_provider == PaymentProvider.STRIPE:
+ # with gui.div(className="my-5"):
+ # render_auto_recharge_section(user)
+ with gui.div(className="my-5"):
+ render_payment_information(org)
+
+ with gui.div(className="my-5"):
+ render_billing_history(org)
+
+
+def render_payments_setup():
+ from routers.account import payment_processing_route
+
+ gui.html(
+ templates.get_template("payment_setup.html").render(
+ settings=settings,
+ payment_processing_url=get_app_route_url(payment_processing_route),
+ )
+ )
+
+
+def render_current_plan(org: Org):
+ plan = PricingPlan.from_sub(org.subscription)
+ provider = (
+ PaymentProvider(org.subscription.payment_provider)
+ if org.subscription.payment_provider
+ else None
+ )
+
+ with gui.div(className=f"{rounded_border} border-dark"):
+ # ROW 1: Plan title and next invoice date
+ left, right = left_and_right()
+ with left:
+ gui.write(f"#### Gooey.AI {plan.title}")
+
+ if provider:
+ gui.write(
+ f"[{icons.edit} Manage Subscription](#payment-information)",
+ unsafe_allow_html=True,
+ )
+ with right, gui.div(className="d-flex align-items-center gap-1"):
+ if provider and (
+ next_invoice_ts := gui.run_in_thread(
+ org.subscription.get_next_invoice_timestamp, cache=True
+ )
+ ):
+ gui.html("Next invoice on ")
+ gui.pill(
+ "...",
+ text_bg="dark",
+ **render_local_date_attrs(
+ next_invoice_ts,
+ date_options={"day": "numeric", "month": "long"},
+ ),
+ )
+
+ if plan is PricingPlan.ENTERPRISE:
+ # charge details are not relevant for Enterprise customers
+ return
+
+ # ROW 2: Plan pricing details
+ left, right = left_and_right(className="mt-5")
+ with left:
+ gui.write(f"# {plan.pricing_title()}", className="no-margin")
+ if plan.monthly_charge:
+ provider_text = f" **via {provider.label}**" if provider else ""
+ gui.caption("per month" + provider_text)
+
+ with right, gui.div(className="text-end"):
+ gui.write(f"# {plan.credits:,} credits", className="no-margin")
+ if plan.monthly_charge:
+ gui.write(
+ f"**${plan.monthly_charge:,}** monthly renewal for {plan.credits:,} credits"
+ )
+
+
+def render_credit_balance(org: Org):
+ gui.write(f"## Credit Balance: {org.balance:,}")
+ gui.caption(
+ "Every time you submit a workflow or make an API call, we deduct credits from your account."
+ )
+
+
+def render_all_plans(org: Org) -> PaymentProvider | None:
+ current_plan = (
+ PricingPlan.from_sub(org.subscription)
+ if org.subscription
+ else PricingPlan.STARTER
+ )
+ all_plans = [plan for plan in PricingPlan if not plan.deprecated]
+
+ gui.write("## All Plans")
+ plans_div = gui.div(className="mb-1")
+
+ if org.subscription and org.subscription.payment_provider:
+ selected_payment_provider = None
+ else:
+ with gui.div():
+ selected_payment_provider = PaymentProvider[
+ payment_provider_radio() or PaymentProvider.STRIPE.name
+ ]
+
+ def _render_plan(plan: PricingPlan):
+ if plan == current_plan:
+ extra_class = "border-dark"
+ else:
+ extra_class = "bg-light"
+ with gui.div(className="d-flex flex-column h-100"):
+ with gui.div(
+ className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}"
+ ):
+ _render_plan_details(plan)
+ # _render_plan_action_button(
+ # user, plan, current_plan, selected_payment_provider
+ # )
+
+ with plans_div:
+ grid_layout(4, all_plans, _render_plan, separator=False)
+
+ with gui.div(className="my-2 d-flex justify-content-center"):
+ gui.caption(
+ f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**"
+ )
+
+ return selected_payment_provider
+
+
+def _render_plan_details(plan: PricingPlan):
+ with gui.div(className="flex-grow-1"):
+ with gui.div(className="mb-4"):
+ with gui.tag("h4", className="mb-0"):
+ gui.html(plan.title)
+ gui.caption(
+ plan.description,
+ style={
+ "minHeight": "calc(var(--bs-body-line-height) * 2em)",
+ "display": "block",
+ },
+ )
+ with gui.div(className="my-3 w-100"):
+ with gui.tag("h4", className="my-0 d-inline me-2"):
+ gui.html(plan.pricing_title())
+ with gui.tag("span", className="text-muted my-0"):
+ gui.html(plan.pricing_caption())
+ gui.write(plan.long_description, unsafe_allow_html=True)
+
+
+def render_payment_information(org: Org):
+ assert org.subscription
+
+ gui.write("## Payment Information", id="payment-information", className="d-block")
+ col1, col2, col3 = gui.columns(3, responsive=False)
+ with col1:
+ gui.write("**Pay via**")
+ with col2:
+ provider = PaymentProvider(org.subscription.payment_provider)
+ gui.write(provider.label)
+ with col3:
+ if gui.button(f"{icons.edit} Edit", type="link", key="manage-payment-provider"):
+ raise gui.RedirectException(org.subscription.get_external_management_url())
+
+ pm_summary = gui.run_in_thread(
+ org.subscription.get_payment_method_summary, cache=True
+ )
+ if not pm_summary:
+ return
+ pm_summary = PaymentMethodSummary(*pm_summary)
+ if pm_summary.card_brand and pm_summary.card_last4:
+ col1, col2, col3 = gui.columns(3, responsive=False)
+ with col1:
+ gui.write("**Payment Method**")
+ with col2:
+ gui.write(
+ f"{format_card_brand(pm_summary.card_brand)} ending in {pm_summary.card_last4}",
+ unsafe_allow_html=True,
+ )
+ with col3:
+ if gui.button(f"{icons.edit} Edit", type="link", key="edit-payment-method"):
+ change_payment_method(org)
+
+ if pm_summary.billing_email:
+ col1, col2, _ = gui.columns(3, responsive=False)
+ with col1:
+ gui.write("**Billing Email**")
+ with col2:
+ gui.html(pm_summary.billing_email)
+
+
+def change_payment_method(org: Org):
+ from routers.account import payment_processing_route
+ from routers.account import account_route
+
+ match org.subscription.payment_provider:
+ case PaymentProvider.STRIPE:
+ session = stripe.checkout.Session.create(
+ mode="setup",
+ currency="usd",
+ customer=org.get_or_create_stripe_customer().id,
+ setup_intent_data={
+ "metadata": {"subscription_id": org.subscription.external_id},
+ },
+ success_url=get_app_route_url(payment_processing_route),
+ cancel_url=get_app_route_url(account_route),
+ )
+ raise gui.RedirectException(session.url, status_code=303)
+ case _:
+ gui.error("Not implemented for this payment provider")
+
+
+def render_billing_history(org: Org, limit: int = 50):
+ import pandas as pd
+
+ txns = org.transactions.filter(amount__gt=0).order_by("-created_at")
+ if not txns:
+ return
+
+ gui.write("## Billing History", className="d-block")
+ gui.table(
+ pd.DataFrame.from_records(
+ [
+ {
+ "Date": txn.created_at.strftime("%m/%d/%Y"),
+ "Description": txn.reason_note(),
+ "Amount": f"-${txn.charged_amount / 100:,.2f}",
+ "Credits": f"+{txn.amount:,}",
+ "Balance": f"{txn.end_balance:,}",
+ }
+ for txn in txns[:limit]
+ ]
+ ),
+ )
+ if txns.count() > limit:
+ gui.caption(f"Showing only the most recent {limit} transactions.")
+
+
+def render_addon_section(org: Org, selected_payment_provider: PaymentProvider):
+ if org.subscription:
+ gui.write("# Purchase More Credits")
+ else:
+ gui.write("# Purchase Credits")
+ gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits")
+
+ if org.subscription and org.subscription.payment_provider:
+ provider = PaymentProvider(org.subscription.payment_provider)
+ else:
+ provider = selected_payment_provider
+ match provider:
+ case PaymentProvider.STRIPE | None:
+ render_stripe_addon_buttons(org)
+ case PaymentProvider.PAYPAL:
+ render_paypal_addon_buttons()
+
+
+def render_paypal_addon_buttons():
+ selected_amt = gui.horizontal_radio(
+ "",
+ settings.ADDON_AMOUNT_CHOICES,
+ format_func=lambda amt: f"${amt:,}",
+ checked_by_default=False,
+ )
+ if selected_amt:
+ gui.js(
+ f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})"
+ )
+ gui.div(
+ id="paypal-addon-buttons",
+ className="mt-2",
+ style={"width": "fit-content"},
+ )
+ gui.div(id="paypal-result-message")
+
+
+def render_stripe_addon_buttons(org: Org):
+ for dollar_amt in settings.ADDON_AMOUNT_CHOICES:
+ render_stripe_addon_button(dollar_amt, org)
+
+
+def render_stripe_addon_button(dollar_amt: int, org: Org):
+ confirm_purchase_modal = gui.Modal(
+ "Confirm Purchase", key=f"confirm-purchase-{dollar_amt}"
+ )
+ if gui.button(f"${dollar_amt:,}", type="primary"):
+ if org.subscription and org.subscription.external_id:
+ confirm_purchase_modal.open()
+ else:
+ stripe_addon_checkout_redirect(org, dollar_amt)
+
+ if not confirm_purchase_modal.is_open():
+ return
+ with confirm_purchase_modal.container():
+ gui.write(
+ f"""
+ Please confirm your purchase:
+ **{dollar_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollar_amt}**.
+ """,
+ className="py-4 d-block text-center",
+ )
+ with gui.div(className="d-flex w-100 justify-content-end"):
+ if gui.session_state.get("--confirm-purchase"):
+ success = gui.run_in_thread(
+ org.subscription.stripe_attempt_addon_purchase,
+ args=[dollar_amt],
+ placeholder="Processing payment...",
+ )
+ if success is None:
+ return
+ gui.session_state.pop("--confirm-purchase")
+ if success:
+ confirm_purchase_modal.close()
+ else:
+ gui.error("Payment failed... Please try again.")
+ return
+
+ if gui.button("Cancel", className="border border-danger text-danger me-2"):
+ confirm_purchase_modal.close()
+ gui.button("Buy", type="primary", key="--confirm-purchase")
+
+
+def stripe_addon_checkout_redirect(org: Org, dollar_amt: int):
+ from routers.account import account_route
+ from routers.account import payment_processing_route
+
+ line_item = available_subscriptions["addon"]["stripe"].copy()
+ line_item["quantity"] = dollar_amt * settings.ADDON_CREDITS_PER_DOLLAR
+ checkout_session = stripe.checkout.Session.create(
+ line_items=[line_item],
+ mode="payment",
+ success_url=get_app_route_url(payment_processing_route),
+ cancel_url=get_app_route_url(account_route),
+ customer=org.get_or_create_stripe_customer().id,
+ invoice_creation={"enabled": True},
+ allow_promotion_codes=True,
+ saved_payment_method_options={
+ "payment_method_save": "enabled",
+ },
+ )
+ raise gui.RedirectException(checkout_session.url, status_code=303)
+
+
def render_org_creation_view(user: AppUser):
gui.write(f"# {icons.company} Create an Org", unsafe_allow_html=True)
org_fields = render_org_create_or_edit_form()
@@ -502,3 +872,9 @@ class AttrDict(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__dict__ = self
+
+
+def left_and_right(*, className: str = "", **props):
+ className += " d-flex flex-row justify-content-between align-items-center"
+ with gui.div(className=className, **props):
+ return gui.div(), gui.div()
diff --git a/payments/models.py b/payments/models.py
index fe280247d..f647bd5a6 100644
--- a/payments/models.py
+++ b/payments/models.py
@@ -80,8 +80,10 @@ class Meta:
def __str__(self):
ret = f"{self.get_plan_display()} | {self.get_payment_provider_display()}"
- if self.has_user:
- ret = f"{ret} | {self.user}"
+ # if self.has_user:
+ # ret = f"{ret} | {self.user}"
+ if self.has_org:
+ ret = f"{ret} | {self.org}"
if self.auto_recharge_enabled:
ret = f"Auto | {ret}"
return ret
@@ -131,6 +133,15 @@ def has_user(self) -> bool:
def is_paid(self) -> bool:
return PricingPlan.from_sub(self).monthly_charge > 0 and self.external_id
+ @property
+ def has_org(self) -> bool:
+ try:
+ self.org
+ except Subscription.org.RelatedObjectDoesNotExist:
+ return False
+ else:
+ return True
+
def cancel(self):
from payments.webhooks import StripeWebhookHandler
@@ -361,12 +372,12 @@ def has_sent_monthly_budget_email_this_month(self) -> bool:
)
def should_send_monthly_spending_notification(self) -> bool:
- assert self.has_user
+ assert self.has_org
return bool(
self.monthly_spending_notification_threshold
and not self.has_sent_monthly_spending_notification_this_month()
- and self.user.get_dollars_spent_this_month()
+ and self.org.get_dollars_spent_this_month()
>= self.monthly_spending_notification_threshold
)
diff --git a/payments/tasks.py b/payments/tasks.py
index 252064541..2070db714 100644
--- a/payments/tasks.py
+++ b/payments/tasks.py
@@ -2,6 +2,7 @@
from loguru import logger
from app_users.models import AppUser
+from orgs.models import Org
from celeryapp import app
from daras_ai_v2 import settings
from daras_ai_v2.fastapi_tricks import get_app_route_url
@@ -10,33 +11,33 @@
@app.task
-def send_monthly_spending_notification_email(user_id: int):
+def send_monthly_spending_notification_email(id: int):
from routers.account import account_route
- user = AppUser.objects.get(id=user_id)
- if not user.email:
- logger.error(f"User doesn't have an email: {user=}")
- return
-
- threshold = user.subscription.monthly_spending_notification_threshold
-
- send_email_via_postmark(
- from_address=settings.SUPPORT_EMAIL,
- to_address=user.email,
- subject=f"[Gooey.AI] Monthly spending has exceeded ${threshold}",
- html_body=templates.get_template(
- "monthly_spending_notification_threshold_email.html"
- ).render(
- user=user,
- account_url=get_app_route_url(account_route),
- ),
- )
-
- # IMPORTANT: always use update_fields=... / select_for_update when updating
- # subscription info. We don't want to overwrite other changes made to
- # subscription during the same time
- user.subscription.monthly_spending_notification_sent_at = timezone.now()
- user.subscription.save(update_fields=["monthly_spending_notification_sent_at"])
+ org = Org.objects.get(id=id)
+ threshold = org.subscription.monthly_spending_notification_threshold
+ for owner in org.get_owners():
+ if not owner.user.email:
+ logger.error(f"Org Owner doesn't have an email: {owner=}")
+ return
+
+ send_email_via_postmark(
+ from_address=settings.SUPPORT_EMAIL,
+ to_address=owner.user.email,
+ subject=f"[Gooey.AI] Monthly spending has exceeded ${threshold}",
+ html_body=templates.get_template(
+ "monthly_spending_notification_threshold_email.html"
+ ).render(
+ user=owner.user,
+ account_url=get_app_route_url(account_route),
+ ),
+ )
+
+ # IMPORTANT: always use update_fields=... / select_for_update when updating
+ # subscription info. We don't want to overwrite other changes made to
+ # subscription during the same time
+ org.subscription.monthly_spending_notification_sent_at = timezone.now()
+ org.subscription.save(update_fields=["monthly_spending_notification_sent_at"])
def send_monthly_budget_reached_email(user: AppUser):
diff --git a/payments/webhooks.py b/payments/webhooks.py
index 0b822cfe7..c280e129f 100644
--- a/payments/webhooks.py
+++ b/payments/webhooks.py
@@ -10,6 +10,7 @@
TransactionReason,
)
from daras_ai_v2 import paypal
+from orgs.models import Org
from .models import Subscription
from .plans import PricingPlan
from .tasks import send_monthly_spending_notification_email
@@ -25,7 +26,7 @@ def handle_sale_completed(cls, sale: paypal.Sale):
return
pp_sub = paypal.Subscription.retrieve(sale.billing_agreement_id)
- assert pp_sub.custom_id, "pp_sub is missing uid"
+ assert pp_sub.custom_id, "pp_sub is missing org_id"
assert pp_sub.plan_id, "pp_sub is missing plan ID"
plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id)
@@ -38,9 +39,9 @@ def handle_sale_completed(cls, sale: paypal.Sale):
f"paypal: charged amount ${charged_dollars} does not match plan's monthly charge ${plan.monthly_charge}"
)
- uid = pp_sub.custom_id
+ org_id = pp_sub.custom_id
add_balance_for_payment(
- uid=uid,
+ org_id=org_id,
amount=plan.credits,
invoice_id=sale.id,
payment_provider=cls.PROVIDER,
@@ -53,7 +54,7 @@ def handle_sale_completed(cls, sale: paypal.Sale):
def handle_subscription_updated(cls, pp_sub: paypal.Subscription):
logger.info(f"Paypal subscription updated {pp_sub.id}")
- assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid"
+ assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing org_id"
assert pp_sub.plan_id, f"PayPal subscription {pp_sub.id} is missing plan ID"
plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id)
@@ -65,17 +66,17 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription):
)
return
- set_user_subscription(
+ set_org_subscription(
provider=cls.PROVIDER,
plan=plan,
- uid=pp_sub.custom_id,
+ org_id=pp_sub.custom_id,
external_id=pp_sub.id,
)
@classmethod
def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription):
assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid"
- set_user_subscription(
+ set_org_subscription(
uid=pp_sub.custom_id,
plan=PricingPlan.STARTER,
provider=None,
@@ -87,11 +88,9 @@ class StripeWebhookHandler:
PROVIDER = PaymentProvider.STRIPE
@classmethod
- def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice):
- from app_users.tasks import save_stripe_default_payment_method
-
+ def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice):
kwargs = {}
- if invoice.subscription:
+ if invoice.subscription and invoice.subscription_details:
kwargs["plan"] = PricingPlan.get_by_key(
invoice.subscription_details.metadata.get("subscription_key")
).db_value
@@ -112,7 +111,7 @@ def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice):
amount = invoice.lines.data[0].quantity
charged_amount = invoice.lines.data[0].amount
add_balance_for_payment(
- uid=uid,
+ org_id=org_id,
amount=amount,
invoice_id=invoice.id,
payment_provider=cls.PROVIDER,
@@ -130,7 +129,7 @@ def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice):
)
@classmethod
- def handle_checkout_session_completed(cls, uid: str, session_data):
+ def handle_checkout_session_completed(cls, org_id: str, session_data):
setup_intent_id = session_data.get("setup_intent")
if not setup_intent_id:
# not a setup mode checkout -- do nothing
@@ -152,7 +151,7 @@ def handle_checkout_session_completed(cls, uid: str, session_data):
)
@classmethod
- def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription):
+ def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscription):
logger.info(f"Stripe subscription updated: {stripe_sub.id}")
assert stripe_sub.plan, f"Stripe subscription {stripe_sub.id} is missing plan"
@@ -173,17 +172,18 @@ def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription):
)
return
- set_user_subscription(
+ set_org_subscription(
provider=cls.PROVIDER,
plan=plan,
- uid=uid,
+ org_id=org_id,
external_id=stripe_sub.id,
)
@classmethod
- def handle_subscription_cancelled(cls, uid: str):
- set_user_subscription(
- uid=uid,
+ def handle_subscription_cancelled(cls, org_id: str):
+ logger.info(f"Stripe subscription cancelled: {stripe_sub.id}")
+ set_org_subscription(
+ org_id=org_id,
plan=PricingPlan.STARTER,
provider=PaymentProvider.STRIPE,
external_id=None,
@@ -192,15 +192,15 @@ def handle_subscription_cancelled(cls, uid: str):
def add_balance_for_payment(
*,
- uid: str,
+ org_id: str,
amount: int,
invoice_id: str,
payment_provider: PaymentProvider,
charged_amount: int,
**kwargs,
):
- user = AppUser.objects.get_or_create_from_uid(uid)[0]
- user.add_balance(
+ org = Org.objects.get_or_create_from_org_id(org_id)[0]
+ org.add_balance(
amount=amount,
invoice_id=invoice_id,
charged_amount=charged_amount,
@@ -208,20 +208,20 @@ def add_balance_for_payment(
**kwargs,
)
- if not user.is_paying:
- user.is_paying = True
- user.save(update_fields=["is_paying"])
+ if not org.is_paying:
+ org.is_paying = True
+ org.save(update_fields=["is_paying"])
if (
- user.subscription
- and user.subscription.should_send_monthly_spending_notification()
+ org.subscription
+ and org.subscription.should_send_monthly_spending_notification()
):
- send_monthly_spending_notification_email.delay(user.id)
+ send_monthly_spending_notification_email.delay(org.id)
-def set_user_subscription(
+def set_org_subscription(
*,
- uid: str,
+ org_id: str,
plan: PricingPlan,
provider: PaymentProvider | None,
external_id: str | None,
@@ -229,9 +229,9 @@ def set_user_subscription(
charged_amount: int = None,
) -> Subscription:
with transaction.atomic():
- user = AppUser.objects.get_or_create_from_uid(uid)[0]
+ org = Org.objects.get_or_create_from_org_id(org_id)[0]
- old_sub = user.subscription
+ old_sub = org.subscription
if old_sub:
new_sub = copy(old_sub)
else:
@@ -245,8 +245,8 @@ def set_user_subscription(
new_sub.save()
if not old_sub:
- user.subscription = new_sub
- user.save(update_fields=["subscription"])
+ org.subscription = new_sub
+ org.save(update_fields=["subscription"])
# cancel previous subscription if it's not the same as the new one
if old_sub and old_sub.external_id != external_id:
diff --git a/scripts/migrate_orgs_from_appusers.py b/scripts/migrate_orgs_from_appusers.py
new file mode 100644
index 000000000..d4e868e30
--- /dev/null
+++ b/scripts/migrate_orgs_from_appusers.py
@@ -0,0 +1,26 @@
+from django.db import IntegrityError
+from loguru import logger
+
+from app_users.models import AppUser
+from orgs.models import Org
+
+
+def run():
+ users_without_personal_org = AppUser.objects.exclude(
+ id__in=Org.objects.filter(is_personal=True).values_list("created_by", flat=True)
+ )
+
+ done_count = 0
+
+ for appuser in users_without_personal_org:
+ try:
+ Org.objects.migrate_from_appuser(appuser)
+ except IntegrityError as e:
+ logger.warning(f"IntegrityError: {e}")
+ else:
+ done_count += 1
+
+ if done_count % 100 == 0:
+ logger.info(f"Running... {done_count} migrated")
+
+ logger.info(f"Done... {done_count} migrated")
From 0a3593fa86fb0a7d6bcf4311a61e90d1ea55f0f7 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 28 Aug 2024 15:17:34 +0530
Subject: [PATCH 047/110] feat: set initial credit balance for first org
created by user
---
daras_ai_v2/billing.py | 1 -
daras_ai_v2/settings.py | 3 +--
orgs/models.py | 20 ++++++++++++++++++--
payments/webhooks.py | 1 -
4 files changed, 19 insertions(+), 6 deletions(-)
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index adc500015..7722fa5bf 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -9,7 +9,6 @@
from daras_ai_v2.gui_confirm import confirm_modal
from daras_ai_v2.settings import templates
from daras_ai_v2.user_date_widgets import render_local_date_attrs
-from orgs.models import Org
from payments.models import PaymentMethodSummary
from payments.plans import PricingPlan
from payments.webhooks import StripeWebhookHandler, set_org_subscription
diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py
index 3cdd88dc8..05a79d4c8 100644
--- a/daras_ai_v2/settings.py
+++ b/daras_ai_v2/settings.py
@@ -11,7 +11,6 @@
"""
import os
-import json
from pathlib import Path
import sentry_sdk
@@ -289,9 +288,9 @@
EMAIL_USER_FREE_CREDITS = config("EMAIL_USER_FREE_CREDITS", 0, cast=int)
ANON_USER_FREE_CREDITS = config("ANON_USER_FREE_CREDITS", 25, cast=int)
LOGIN_USER_FREE_CREDITS = config("LOGIN_USER_FREE_CREDITS", 500, cast=int)
+FIRST_ORG_FREE_CREDITS = config("ORG_FREE_CREDITS", 500, cast=int)
ADDON_CREDITS_PER_DOLLAR = config("ADDON_CREDITS_PER_DOLLAR", 100, cast=int)
-
ADDON_AMOUNT_CHOICES = [10, 30, 50, 100, 300, 500, 1000] # USD
AUTO_RECHARGE_BALANCE_THRESHOLD_CHOICES = [300, 1000, 3000, 10000] # Credit balance
AUTO_RECHARGE_COOLDOWN_SECONDS = config("AUTO_RECHARGE_COOLDOWN_SECONDS", 60, cast=int)
diff --git a/orgs/models.py b/orgs/models.py
index 0c39312c0..fa1b471b9 100644
--- a/orgs/models.py
+++ b/orgs/models.py
@@ -42,11 +42,27 @@ class OrgRole(models.IntegerChoices):
class OrgManager(SafeDeleteManager):
def create_org(
- self, *, created_by: "AppUser", org_id: str | None = None, **kwargs
+ self,
+ *,
+ created_by: "AppUser",
+ org_id: str | None = None,
+ balance: int | None = None,
+ **kwargs,
) -> Org:
org = self.model(
- org_id=org_id or get_random_doc_id(), created_by=created_by, **kwargs
+ org_id=org_id or get_random_doc_id(),
+ created_by=created_by,
+ balance=balance,
+ **kwargs,
)
+ if (
+ balance is None
+ and Org.all_objects.filter(created_by=created_by).count() <= 1
+ ):
+ # set some balance for first team created by user
+ # Org.all_objects is important to include deleted orgs
+ org.balance = settings.FIRST_ORG_FREE_CREDITS
+
org.full_clean()
org.save()
org.add_member(
diff --git a/payments/webhooks.py b/payments/webhooks.py
index c280e129f..a00466bbc 100644
--- a/payments/webhooks.py
+++ b/payments/webhooks.py
@@ -181,7 +181,6 @@ def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscriptio
@classmethod
def handle_subscription_cancelled(cls, org_id: str):
- logger.info(f"Stripe subscription cancelled: {stripe_sub.id}")
set_org_subscription(
org_id=org_id,
plan=PricingPlan.STARTER,
From d6b5bc1de6566019082694d226bad5393206dd72 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 28 Aug 2024 15:33:41 +0530
Subject: [PATCH 048/110] feat: add billed org to saved run & script to migrate
org_id for existing saved runs
---
bots/migrations/0082_savedrun_billed_org.py | 20 ++++++++++++++++++++
1 file changed, 20 insertions(+)
create mode 100644 bots/migrations/0082_savedrun_billed_org.py
diff --git a/bots/migrations/0082_savedrun_billed_org.py b/bots/migrations/0082_savedrun_billed_org.py
new file mode 100644
index 000000000..9dbe6170d
--- /dev/null
+++ b/bots/migrations/0082_savedrun_billed_org.py
@@ -0,0 +1,20 @@
+# Generated by Django 4.2.7 on 2024-08-28 09:49
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('orgs', '0005_org_unique_personal_org_per_user'),
+ ('bots', '0081_alter_botintegration_streaming_enabled'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='savedrun',
+ name='billed_org',
+ field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='billed_runs', to='orgs.org'),
+ ),
+ ]
From e0b94cb34a0d9193e1442183924a9d1e9dbb314d Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 28 Aug 2024 15:34:37 +0530
Subject: [PATCH 049/110] fix: add filter condition in billed_org migration
script to only run on historical data
---
scripts/migrate_billed_org_for_saved_runs.py | 18 ++++++++++++++++++
1 file changed, 18 insertions(+)
create mode 100644 scripts/migrate_billed_org_for_saved_runs.py
diff --git a/scripts/migrate_billed_org_for_saved_runs.py b/scripts/migrate_billed_org_for_saved_runs.py
new file mode 100644
index 000000000..52b86e932
--- /dev/null
+++ b/scripts/migrate_billed_org_for_saved_runs.py
@@ -0,0 +1,18 @@
+from django.db.models import F, Subquery, OuterRef
+from django.db import transaction
+
+from bots.models import SavedRun
+from orgs.models import Org
+
+
+def run():
+ # Start a transaction to ensure atomicity
+ with transaction.atomic():
+ # Perform the update where 'uid' matches a valid 'org_id' in the 'Org' table
+ SavedRun.objects.filter(
+ billed_org_id__isnull=True, uid__in=Org.objects.values("org_id")
+ ).update(
+ billed_org_id=Subquery(
+ Org.objects.filter(org_id=OuterRef("uid")).values("id")[:1]
+ )
+ )
From 026bc27ae5baf8a4b6b8020a98384cf917fed200 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sun, 1 Sep 2024 19:01:36 +0530
Subject: [PATCH 050/110] fix: sync migrations in bots app with master
---
bots/migrations/0082_savedrun_billed_org.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/bots/migrations/0082_savedrun_billed_org.py b/bots/migrations/0082_savedrun_billed_org.py
index 9dbe6170d..208f46dcc 100644
--- a/bots/migrations/0082_savedrun_billed_org.py
+++ b/bots/migrations/0082_savedrun_billed_org.py
@@ -1,4 +1,4 @@
-# Generated by Django 4.2.7 on 2024-08-28 09:49
+# Generated by Django 4.2.7 on 2024-08-30 08:10
from django.db import migrations, models
import django.db.models.deletion
@@ -8,7 +8,7 @@ class Migration(migrations.Migration):
dependencies = [
('orgs', '0005_org_unique_personal_org_per_user'),
- ('bots', '0081_alter_botintegration_streaming_enabled'),
+ ('bots', '0081_remove_conversation_bots_conver_bot_int_73ac7b_idx_and_more'),
]
operations = [
From 322dbbd8840c0d022e8d50988d963fa5afacbee5 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sun, 1 Sep 2024 19:02:12 +0530
Subject: [PATCH 051/110] fix: type check for user.get_or_create_personal_org
---
app_users/models.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/app_users/models.py b/app_users/models.py
index 739ab3bd3..46803c1a8 100644
--- a/app_users/models.py
+++ b/app_users/models.py
@@ -1,5 +1,6 @@
import requests
import stripe
+import typing
from django.db import models, IntegrityError, transaction
from django.db.models import Sum
from django.utils import timezone
@@ -14,6 +15,9 @@
from handles.models import Handle
from payments.plans import PricingPlan
+if typing.TYPE_CHECKING:
+ from orgs.models import Org
+
class AppUserQuerySet(models.QuerySet):
def get_or_create_from_uid(
@@ -249,13 +253,13 @@ def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser":
return self
def get_or_create_personal_org(self) -> tuple["Org", bool]:
- from orgs.models import Org
+ from orgs.models import Org, OrgMembership
- org_membership = self.org_memberships.filter(
+ org_membership: OrgMembership | None = self.org_memberships.filter(
org__is_personal=True, org__created_by=self
).first()
if org_membership:
- return org_membership, False
+ return org_membership.org, False
else:
return Org.objects.migrate_from_appuser(self), True
From 0bf1ee98578985c6e1b489458ca5e74e30d6185e Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sun, 1 Sep 2024 19:05:16 +0530
Subject: [PATCH 052/110] add: make billing tab work with org instead of
AppUser
---
daras_ai_v2/billing.py | 264 +++++++++++++++++++++--------------------
routers/account.py | 3 +-
2 files changed, 137 insertions(+), 130 deletions(-)
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index 7722fa5bf..e5a5fa27e 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -1,8 +1,10 @@
+import typing
+
import gooey_gui as gui
import stripe
from django.core.exceptions import ValidationError
-from app_users.models import AppUser, PaymentProvider
+from app_users.models import AppUserTransaction, PaymentProvider
from daras_ai_v2 import icons, settings, paypal
from daras_ai_v2.fastapi_tricks import get_app_route_url
from daras_ai_v2.grid_layout_widget import grid_layout
@@ -14,34 +16,38 @@
from payments.webhooks import StripeWebhookHandler, set_org_subscription
from scripts.migrate_existing_subscriptions import available_subscriptions
+if typing.TYPE_CHECKING:
+ from orgs.models import Org
+
+
rounded_border = "w-100 border shadow-sm rounded py-4 px-3"
-def billing_page(user: AppUser):
+def billing_page(org: "Org"):
render_payments_setup()
- if user.subscription and user.subscription.is_paid():
- render_current_plan(user)
+ if org.subscription and org.subscription.is_paid():
+ render_current_plan(org)
with gui.div(className="my-5"):
- render_credit_balance(user)
+ render_credit_balance(org)
with gui.div(className="my-5"):
- selected_payment_provider = render_all_plans(user)
+ selected_payment_provider = render_all_plans(org)
with gui.div(className="my-5"):
- render_addon_section(user, selected_payment_provider)
+ render_addon_section(org, selected_payment_provider)
- if user.subscription:
- if user.subscription.payment_provider == PaymentProvider.STRIPE:
+ if org.subscription:
+ if org.subscription.payment_provider == PaymentProvider.STRIPE:
with gui.div(className="my-5"):
- render_auto_recharge_section(user)
+ render_auto_recharge_section(org)
with gui.div(className="my-5"):
- render_payment_information(user)
+ render_payment_information(org)
with gui.div(className="my-5"):
- render_billing_history(user)
+ render_billing_history(org)
def render_payments_setup():
@@ -55,10 +61,10 @@ def render_payments_setup():
)
-def render_current_plan(user: AppUser):
- plan = PricingPlan.from_sub(user.subscription)
- if user.subscription.payment_provider:
- provider = PaymentProvider(user.subscription.payment_provider)
+def render_current_plan(org: "Org"):
+ plan = PricingPlan.from_sub(org.subscription)
+ if org.subscription.payment_provider:
+ provider = PaymentProvider(org.subscription.payment_provider)
else:
provider = None
@@ -76,7 +82,7 @@ def render_current_plan(user: AppUser):
with right, gui.div(className="d-flex align-items-center gap-1"):
if provider and (
next_invoice_ts := gui.run_in_thread(
- user.subscription.get_next_invoice_timestamp, cache=True
+ org.subscription.get_next_invoice_timestamp, cache=True
)
):
gui.html("Next invoice on ")
@@ -112,17 +118,17 @@ def render_current_plan(user: AppUser):
)
-def render_credit_balance(user: AppUser):
- gui.write(f"## Credit Balance: {user.balance:,}")
+def render_credit_balance(org: "Org"):
+ gui.write(f"## Credit Balance: {org.balance:,}")
gui.caption(
"Every time you submit a workflow or make an API call, we deduct credits from your account."
)
-def render_all_plans(user: AppUser) -> PaymentProvider:
+def render_all_plans(org: "Org") -> PaymentProvider:
current_plan = (
- PricingPlan.from_sub(user.subscription)
- if user.subscription
+ PricingPlan.from_sub(org.subscription)
+ if org.subscription
else PricingPlan.STARTER
)
all_plans = [plan for plan in PricingPlan if not plan.deprecated]
@@ -130,8 +136,8 @@ def render_all_plans(user: AppUser) -> PaymentProvider:
gui.write("## All Plans")
plans_div = gui.div(className="mb-1")
- if user.subscription and user.subscription.payment_provider:
- selected_payment_provider = user.subscription.payment_provider
+ if org.subscription and org.subscription.payment_provider:
+ selected_payment_provider = org.subscription.payment_provider
else:
with gui.div():
selected_payment_provider = PaymentProvider[
@@ -149,7 +155,7 @@ def _render_plan(plan: PricingPlan):
):
_render_plan_details(plan)
_render_plan_action_button(
- user=user,
+ org=org,
plan=plan,
current_plan=current_plan,
payment_provider=selected_payment_provider,
@@ -187,7 +193,7 @@ def _render_plan_details(plan: PricingPlan):
def _render_plan_action_button(
- user: AppUser,
+ org: "Org",
plan: PricingPlan,
current_plan: PricingPlan,
payment_provider: PaymentProvider | None,
@@ -201,75 +207,72 @@ def _render_plan_action_button(
className=btn_classes + " btn btn-theme btn-primary",
):
gui.html("Contact Us")
- elif (
- user.subscription and user.subscription.plan == PricingPlan.ENTERPRISE.db_value
- ):
+ elif org.subscription and org.subscription.plan == PricingPlan.ENTERPRISE.db_value:
# don't show upgrade/downgrade buttons for enterprise customers
return
- else:
- if user.subscription and user.subscription.is_paid():
- # subscription exists, show upgrade/downgrade button
- if plan.credits > current_plan.credits:
- modal, confirmed = confirm_modal(
- title="Upgrade Plan",
- key=f"--modal-{plan.key}",
- text=f"""
+ elif org.subscription and org.subscription.is_paid():
+ # subscription exists, show upgrade/downgrade button
+ if plan.credits > current_plan.credits:
+ modal, confirmed = confirm_modal(
+ title="Upgrade Plan",
+ key=f"--modal-{plan.key}",
+ text=f"""
Are you sure you want to upgrade from **{current_plan.title} @ {fmt_price(current_plan)}** to **{plan.title} @ {fmt_price(plan)}**?
Your payment method will be charged ${plan.monthly_charge:,} today and again every month until you cancel.
**{plan.credits:,} Credits** will be added to your account today and with subsequent payments, your account balance
will be refreshed to {plan.credits:,} Credits.
- """,
- button_label="Upgrade",
+ """,
+ button_label="Upgrade",
+ )
+ if gui.button(
+ "Upgrade", className="primary", key=f"--change-sub-{plan.key}"
+ ):
+ modal.open()
+ if confirmed:
+ change_subscription(
+ org,
+ plan,
+ # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time
+ billing_cycle_anchor="now",
)
- if gui.button(
- "Upgrade", className="primary", key=f"--change-sub-{plan.key}"
- ):
- modal.open()
- if confirmed:
- change_subscription(
- user,
- plan,
- # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time
- billing_cycle_anchor="now",
- )
- else:
- modal, confirmed = confirm_modal(
- title="Downgrade Plan",
- key=f"--modal-{plan.key}",
- text=f"""
+ else:
+ modal, confirmed = confirm_modal(
+ title="Downgrade Plan",
+ key=f"--modal-{plan.key}",
+ text=f"""
Are you sure you want to downgrade from: **{current_plan.title} @ {fmt_price(current_plan)}** to **{plan.title} @ {fmt_price(plan)}**?
This will take effect from the next billing cycle.
- """,
- button_label="Downgrade",
- button_class="border-danger bg-danger text-white",
- )
- if gui.button(
- "Downgrade", className="secondary", key=f"--change-sub-{plan.key}"
- ):
- modal.open()
- if confirmed:
- change_subscription(user, plan)
- else:
- assert payment_provider is not None # for sanity
- _render_create_subscription_button(
- user=user,
- plan=plan,
- payment_provider=payment_provider,
+ """,
+ button_label="Downgrade",
+ button_class="border-danger bg-danger text-white",
)
+ if gui.button(
+ "Downgrade", className="secondary", key=f"--change-sub-{plan.key}"
+ ):
+ modal.open()
+ if confirmed:
+ change_subscription(org, plan)
+ else:
+ assert payment_provider is not None # for sanity
+ _render_create_subscription_button(
+ org=org,
+ plan=plan,
+ payment_provider=payment_provider,
+ )
def _render_create_subscription_button(
*,
- user: AppUser,
+ org: "Org",
plan: PricingPlan,
payment_provider: PaymentProvider,
):
match payment_provider:
case PaymentProvider.STRIPE:
- render_stripe_subscription_button(user=user, plan=plan)
+ render_stripe_subscription_button(org=org, plan=plan)
case PaymentProvider.PAYPAL:
render_paypal_subscription_button(plan=plan)
@@ -281,27 +284,27 @@ def fmt_price(plan: PricingPlan) -> str:
return "Free"
-def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs):
+def change_subscription(org: "Org", new_plan: PricingPlan, **kwargs):
from routers.account import account_route
from routers.account import payment_processing_route
- current_plan = PricingPlan.from_sub(user.subscription)
+ current_plan = PricingPlan.from_sub(org.subscription)
if new_plan == current_plan:
raise gui.RedirectException(get_app_route_url(account_route), status_code=303)
if new_plan == PricingPlan.STARTER:
- user.subscription.cancel()
+ org.subscription.cancel()
raise gui.RedirectException(
get_app_route_url(payment_processing_route), status_code=303
)
- match user.subscription.payment_provider:
+ match org.subscription.payment_provider:
case PaymentProvider.STRIPE:
if not new_plan.supports_stripe():
gui.error(f"Stripe subscription not available for {new_plan}")
- subscription = stripe.Subscription.retrieve(user.subscription.external_id)
+ subscription = stripe.Subscription.retrieve(org.subscription.external_id)
stripe.Subscription.modify(
subscription.id,
items=[
@@ -345,20 +348,20 @@ def payment_provider_radio(**props) -> str | None:
)
-def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvider):
- if user.subscription:
+def render_addon_section(org: "Org", selected_payment_provider: PaymentProvider):
+ if org.subscription:
gui.write("# Purchase More Credits")
else:
gui.write("# Purchase Credits")
gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits")
- if user.subscription and user.subscription.payment_provider:
- provider = PaymentProvider(user.subscription.payment_provider)
+ if org.subscription and org.subscription.payment_provider:
+ provider = PaymentProvider(org.subscription.payment_provider)
else:
provider = selected_payment_provider
match provider:
case PaymentProvider.STRIPE:
- render_stripe_addon_buttons(user)
+ render_stripe_addon_buttons(org)
case PaymentProvider.PAYPAL:
render_paypal_addon_buttons()
@@ -382,8 +385,8 @@ def render_paypal_addon_buttons():
gui.div(id="paypal-result-message")
-def render_stripe_addon_buttons(user: AppUser):
- if not (user.subscription and user.subscription.payment_provider):
+def render_stripe_addon_buttons(org: "Org"):
+ if not (org.subscription and org.subscription.payment_provider):
save_pm = gui.checkbox(
"Save payment method for future purchases & auto-recharge", value=True
)
@@ -391,10 +394,10 @@ def render_stripe_addon_buttons(user: AppUser):
save_pm = True
for dollat_amt in settings.ADDON_AMOUNT_CHOICES:
- render_stripe_addon_button(dollat_amt, user, save_pm)
+ render_stripe_addon_button(dollat_amt, org, save_pm)
-def render_stripe_addon_button(dollat_amt: int, user: AppUser, save_pm: bool):
+def render_stripe_addon_button(dollat_amt: int, org: "Org", save_pm: bool):
modal, confirmed = confirm_modal(
title="Purchase Credits",
key=f"--addon-modal-{dollat_amt}",
@@ -408,14 +411,14 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser, save_pm: bool):
)
if gui.button(f"${dollat_amt:,}", type="primary"):
- if user.subscription and user.subscription.stripe_get_default_payment_method():
+ if org.subscription and org.subscription.stripe_get_default_payment_method():
modal.open()
else:
- stripe_addon_checkout_redirect(user, dollat_amt, save_pm)
+ stripe_addon_checkout_redirect(org, dollat_amt, save_pm)
if confirmed:
success = gui.run_in_thread(
- user.subscription.stripe_attempt_addon_purchase,
+ org.subscription.stripe_attempt_addon_purchase,
args=[dollat_amt],
placeholder="",
)
@@ -426,10 +429,10 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser, save_pm: bool):
modal.close()
else:
# fallback to stripe checkout flow if the auto payment failed
- stripe_addon_checkout_redirect(user, dollat_amt, save_pm)
+ stripe_addon_checkout_redirect(org, dollat_amt, save_pm)
-def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int, save_pm: bool):
+def stripe_addon_checkout_redirect(org: "Org", dollat_amt: int, save_pm: bool):
from routers.account import account_route
from routers.account import payment_processing_route
@@ -445,7 +448,7 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int, save_pm: bool
mode="payment",
success_url=get_app_route_url(payment_processing_route),
cancel_url=get_app_route_url(account_route),
- customer=user.get_or_create_stripe_customer(),
+ customer=org.get_or_create_stripe_customer(),
invoice_creation={"enabled": True},
allow_promotion_codes=True,
**kwargs,
@@ -455,7 +458,7 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int, save_pm: bool
def render_stripe_subscription_button(
*,
- user: AppUser,
+ org: "Org",
plan: PricingPlan,
):
if not plan.supports_stripe():
@@ -483,36 +486,38 @@ def render_stripe_subscription_button(
key=f"--change-sub-{plan.key}",
type="primary",
):
- if user.subscription and user.subscription.stripe_get_default_payment_method():
+ if org.subscription and org.subscription.stripe_get_default_payment_method():
modal.open()
else:
- stripe_subscription_create(user=user, plan=plan)
+ stripe_subscription_create(org=org, plan=plan)
if confirmed:
- stripe_subscription_create(user=user, plan=plan)
+ stripe_subscription_create(org=org, plan=plan)
-def stripe_subscription_create(user: AppUser, plan: PricingPlan):
+def stripe_subscription_create(org: "Org", plan: PricingPlan):
from routers.account import account_route
from routers.account import payment_processing_route
- if user.subscription and user.subscription.plan == plan.db_value:
+ if org.subscription and org.subscription.is_paid():
# sanity check: already subscribed to some plan
- return
+ gui.rerun()
# check for existing subscriptions on stripe
- customer = user.get_or_create_stripe_customer()
+ customer = org.get_or_create_stripe_customer()
for sub in stripe.Subscription.list(
customer=customer, status="active", limit=1
).data:
- StripeWebhookHandler.handle_subscription_updated(uid=user.uid, stripe_sub=sub)
+ StripeWebhookHandler.handle_subscription_updated(
+ org_id=org.org_id, stripe_sub=sub
+ )
raise gui.RedirectException(
get_app_route_url(payment_processing_route), status_code=303
)
# try to directly create the subscription without checkout
- pm = user.subscription and user.subscription.stripe_get_default_payment_method()
metadata = {settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: plan.key}
+ pm = org.subscription and org.subscription.stripe_get_default_payment_method()
line_items = [plan.get_stripe_line_item()]
if pm:
sub = stripe.Subscription.create(
@@ -562,12 +567,12 @@ def render_paypal_subscription_button(
)
-def render_payment_information(user: AppUser):
- if not user.subscription:
+def render_payment_information(org: "Org"):
+ if not org.subscription:
return
pm_summary = gui.run_in_thread(
- user.subscription.get_payment_method_summary, cache=True
+ org.subscription.get_payment_method_summary, cache=True
)
if not pm_summary:
return
@@ -579,7 +584,7 @@ def render_payment_information(user: AppUser):
gui.write("**Pay via**")
with col2:
provider = PaymentProvider(
- user.subscription.payment_provider or PaymentProvider.STRIPE
+ org.subscription.payment_provider or PaymentProvider.STRIPE
)
gui.write(provider.label)
with col3:
@@ -587,7 +592,7 @@ def render_payment_information(user: AppUser):
f"{icons.edit} Edit", type="link", key="manage-payment-provider"
):
raise gui.RedirectException(
- user.subscription.get_external_management_url()
+ org.subscription.get_external_management_url()
)
pm_summary = PaymentMethodSummary(*pm_summary)
@@ -607,7 +612,7 @@ def render_payment_information(user: AppUser):
if gui.button(
f"{icons.edit} Edit", type="link", key="edit-payment-method"
):
- change_payment_method(user)
+ change_payment_method(org)
if pm_summary.billing_email:
col1, col2, _ = gui.columns(3, responsive=False)
@@ -636,12 +641,12 @@ def render_payment_information(user: AppUser):
modal.open()
if confirmed:
set_org_subscription(
- org_id=user.get_personal_org().org_id,
+ org_id=org.org_id,
plan=PricingPlan.STARTER,
provider=None,
external_id=None,
)
- pm = user.subscription and user.subscription.stripe_get_default_payment_method()
+ pm = org.subscription and org.subscription.stripe_get_default_payment_method()
if pm:
pm.detach()
raise gui.RedirectException(
@@ -649,18 +654,18 @@ def render_payment_information(user: AppUser):
)
-def change_payment_method(user: AppUser):
+def change_payment_method(org: "Org"):
from routers.account import payment_processing_route
from routers.account import account_route
- match user.subscription.payment_provider:
+ match org.subscription.payment_provider:
case PaymentProvider.STRIPE:
session = stripe.checkout.Session.create(
mode="setup",
currency="usd",
- customer=user.get_or_create_stripe_customer(),
+ customer=org.get_or_create_stripe_customer(),
setup_intent_data={
- "metadata": {"subscription_id": user.subscription.external_id},
+ "metadata": {"subscription_id": org.subscription.external_id},
},
success_url=get_app_route_url(payment_processing_route),
cancel_url=get_app_route_url(account_route),
@@ -674,10 +679,13 @@ def format_card_brand(brand: str) -> str:
return icons.card_icons.get(brand.lower(), brand.capitalize())
-def render_billing_history(user: AppUser, limit: int = 50):
+def render_billing_history(org: "Org", limit: int = 50):
import pandas as pd
- txns = user.transactions.filter(amount__gt=0).order_by("-created_at")
+ txns = AppUserTransaction.objects.filter(
+ org=org,
+ amount__gt=0,
+ ).order_by("-created_at")
if not txns:
return
@@ -700,9 +708,9 @@ def render_billing_history(user: AppUser, limit: int = 50):
gui.caption(f"Showing only the most recent {limit} transactions.")
-def render_auto_recharge_section(user: AppUser):
- assert user.subscription
- subscription = user.subscription
+def render_auto_recharge_section(org: "Org"):
+ assert org.subscription
+ subscription = org.subscription
gui.write("## Auto Recharge & Limits")
with gui.div(className="h4"):
@@ -746,10 +754,10 @@ def render_auto_recharge_section(user: AppUser):
""",
)
with gui.div(className="d-flex align-items-center"):
- user.subscription.monthly_spending_budget = gui.number_input(
+ subscription.monthly_spending_budget = gui.number_input(
"",
min_value=10,
- value=user.subscription.monthly_spending_budget,
+ value=subscription.monthly_spending_budget,
key="monthly-spending-budget",
)
gui.write("USD", className="d-block ms-2")
@@ -762,13 +770,11 @@ def render_auto_recharge_section(user: AppUser):
"""
)
with gui.div(className="d-flex align-items-center"):
- user.subscription.monthly_spending_notification_threshold = (
- gui.number_input(
- "",
- min_value=10,
- value=user.subscription.monthly_spending_notification_threshold,
- key="monthly-spending-notification-threshold",
- )
+ subscription.monthly_spending_notification_threshold = gui.number_input(
+ "",
+ min_value=10,
+ value=subscription.monthly_spending_notification_threshold,
+ key="monthly-spending-notification-threshold",
)
gui.write("USD", className="d-block ms-2")
diff --git a/routers/account.py b/routers/account.py
index b52239b2b..f9194589b 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -203,7 +203,8 @@ def url_path(self) -> str:
def billing_tab(request: Request):
- return billing_page(request.user)
+ org, _ = request.user.get_or_create_personal_org()
+ return billing_page(org)
def profile_tab(request: Request):
From 33359c53589b39e822ebeba293f148ff5763e68b Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sun, 1 Sep 2024 19:05:50 +0530
Subject: [PATCH 053/110] fix: remove billing from org page
---
orgs/views.py | 377 +-------------------------------------------------
1 file changed, 5 insertions(+), 372 deletions(-)
diff --git a/orgs/views.py b/orgs/views.py
index 2d6f3c27c..494bac72a 100644
--- a/orgs/views.py
+++ b/orgs/views.py
@@ -2,21 +2,13 @@
import html as html_lib
-import stripe
import gooey_gui as gui
from django.core.exceptions import ValidationError
-from app_users.models import AppUser, PaymentProvider
-from daras_ai_v2.billing import format_card_brand, payment_provider_radio
-from daras_ai_v2.grid_layout_widget import grid_layout
+from app_users.models import AppUser
from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole
-from daras_ai_v2 import icons, settings
-from daras_ai_v2.fastapi_tricks import get_route_path, get_app_route_url
-from daras_ai_v2.settings import templates
-from daras_ai_v2.user_date_widgets import render_local_date_attrs
-from payments.models import PaymentMethodSummary
-from payments.plans import PricingPlan
-from scripts.migrate_existing_subscriptions import available_subscriptions
+from daras_ai_v2 import icons
+from daras_ai_v2.fastapi_tricks import get_route_path
DEFAULT_ORG_LOGO = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/74a37c52-8260-11ee-a297-02420a0001ee/gooey.ai%20-%20A%20pop%20art%20illustration%20of%20robots%20taki...y%20Liechtenstein%20mint%20colour%20is%20main%20city%20Seattle.png"
@@ -70,7 +62,7 @@ def invitation_page(user: AppUser, invitation: OrgInvitation):
def orgs_page(user: AppUser):
- memberships = user.org_memberships.all()
+ memberships = user.org_memberships.filter()
if not memberships:
gui.write("*You're not part of an organization yet... Create one?*")
@@ -118,10 +110,6 @@ def render_org_by_membership(membership: OrgMembership):
f"Org Domain: `@{org.domain_name}`", className="text-muted"
)
- with gui.div(className="mt-4"):
- gui.write("# Billing")
- billing_section(org=org, current_member=membership)
-
with gui.div(className="mt-4"):
with gui.div(className="d-flex justify-content-between align-items-center"):
gui.write("## Members")
@@ -157,361 +145,6 @@ def render_org_by_membership(membership: OrgMembership):
org_leave_modal.open()
-def billing_section(*, org: Org, current_member: OrgMembership):
- render_payments_setup()
-
- if org.subscription and org.subscription.external_id:
- render_current_plan(org)
-
- with gui.div(className="my-5"):
- render_credit_balance(org)
-
- with gui.div(className="my-5"):
- selected_payment_provider = render_all_plans(org)
-
- with gui.div(className="my-5"):
- render_addon_section(org, selected_payment_provider)
-
- if org.subscription and org.subscription.external_id:
- # if org.subscription.payment_provider == PaymentProvider.STRIPE:
- # with gui.div(className="my-5"):
- # render_auto_recharge_section(user)
- with gui.div(className="my-5"):
- render_payment_information(org)
-
- with gui.div(className="my-5"):
- render_billing_history(org)
-
-
-def render_payments_setup():
- from routers.account import payment_processing_route
-
- gui.html(
- templates.get_template("payment_setup.html").render(
- settings=settings,
- payment_processing_url=get_app_route_url(payment_processing_route),
- )
- )
-
-
-def render_current_plan(org: Org):
- plan = PricingPlan.from_sub(org.subscription)
- provider = (
- PaymentProvider(org.subscription.payment_provider)
- if org.subscription.payment_provider
- else None
- )
-
- with gui.div(className=f"{rounded_border} border-dark"):
- # ROW 1: Plan title and next invoice date
- left, right = left_and_right()
- with left:
- gui.write(f"#### Gooey.AI {plan.title}")
-
- if provider:
- gui.write(
- f"[{icons.edit} Manage Subscription](#payment-information)",
- unsafe_allow_html=True,
- )
- with right, gui.div(className="d-flex align-items-center gap-1"):
- if provider and (
- next_invoice_ts := gui.run_in_thread(
- org.subscription.get_next_invoice_timestamp, cache=True
- )
- ):
- gui.html("Next invoice on ")
- gui.pill(
- "...",
- text_bg="dark",
- **render_local_date_attrs(
- next_invoice_ts,
- date_options={"day": "numeric", "month": "long"},
- ),
- )
-
- if plan is PricingPlan.ENTERPRISE:
- # charge details are not relevant for Enterprise customers
- return
-
- # ROW 2: Plan pricing details
- left, right = left_and_right(className="mt-5")
- with left:
- gui.write(f"# {plan.pricing_title()}", className="no-margin")
- if plan.monthly_charge:
- provider_text = f" **via {provider.label}**" if provider else ""
- gui.caption("per month" + provider_text)
-
- with right, gui.div(className="text-end"):
- gui.write(f"# {plan.credits:,} credits", className="no-margin")
- if plan.monthly_charge:
- gui.write(
- f"**${plan.monthly_charge:,}** monthly renewal for {plan.credits:,} credits"
- )
-
-
-def render_credit_balance(org: Org):
- gui.write(f"## Credit Balance: {org.balance:,}")
- gui.caption(
- "Every time you submit a workflow or make an API call, we deduct credits from your account."
- )
-
-
-def render_all_plans(org: Org) -> PaymentProvider | None:
- current_plan = (
- PricingPlan.from_sub(org.subscription)
- if org.subscription
- else PricingPlan.STARTER
- )
- all_plans = [plan for plan in PricingPlan if not plan.deprecated]
-
- gui.write("## All Plans")
- plans_div = gui.div(className="mb-1")
-
- if org.subscription and org.subscription.payment_provider:
- selected_payment_provider = None
- else:
- with gui.div():
- selected_payment_provider = PaymentProvider[
- payment_provider_radio() or PaymentProvider.STRIPE.name
- ]
-
- def _render_plan(plan: PricingPlan):
- if plan == current_plan:
- extra_class = "border-dark"
- else:
- extra_class = "bg-light"
- with gui.div(className="d-flex flex-column h-100"):
- with gui.div(
- className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}"
- ):
- _render_plan_details(plan)
- # _render_plan_action_button(
- # user, plan, current_plan, selected_payment_provider
- # )
-
- with plans_div:
- grid_layout(4, all_plans, _render_plan, separator=False)
-
- with gui.div(className="my-2 d-flex justify-content-center"):
- gui.caption(
- f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**"
- )
-
- return selected_payment_provider
-
-
-def _render_plan_details(plan: PricingPlan):
- with gui.div(className="flex-grow-1"):
- with gui.div(className="mb-4"):
- with gui.tag("h4", className="mb-0"):
- gui.html(plan.title)
- gui.caption(
- plan.description,
- style={
- "minHeight": "calc(var(--bs-body-line-height) * 2em)",
- "display": "block",
- },
- )
- with gui.div(className="my-3 w-100"):
- with gui.tag("h4", className="my-0 d-inline me-2"):
- gui.html(plan.pricing_title())
- with gui.tag("span", className="text-muted my-0"):
- gui.html(plan.pricing_caption())
- gui.write(plan.long_description, unsafe_allow_html=True)
-
-
-def render_payment_information(org: Org):
- assert org.subscription
-
- gui.write("## Payment Information", id="payment-information", className="d-block")
- col1, col2, col3 = gui.columns(3, responsive=False)
- with col1:
- gui.write("**Pay via**")
- with col2:
- provider = PaymentProvider(org.subscription.payment_provider)
- gui.write(provider.label)
- with col3:
- if gui.button(f"{icons.edit} Edit", type="link", key="manage-payment-provider"):
- raise gui.RedirectException(org.subscription.get_external_management_url())
-
- pm_summary = gui.run_in_thread(
- org.subscription.get_payment_method_summary, cache=True
- )
- if not pm_summary:
- return
- pm_summary = PaymentMethodSummary(*pm_summary)
- if pm_summary.card_brand and pm_summary.card_last4:
- col1, col2, col3 = gui.columns(3, responsive=False)
- with col1:
- gui.write("**Payment Method**")
- with col2:
- gui.write(
- f"{format_card_brand(pm_summary.card_brand)} ending in {pm_summary.card_last4}",
- unsafe_allow_html=True,
- )
- with col3:
- if gui.button(f"{icons.edit} Edit", type="link", key="edit-payment-method"):
- change_payment_method(org)
-
- if pm_summary.billing_email:
- col1, col2, _ = gui.columns(3, responsive=False)
- with col1:
- gui.write("**Billing Email**")
- with col2:
- gui.html(pm_summary.billing_email)
-
-
-def change_payment_method(org: Org):
- from routers.account import payment_processing_route
- from routers.account import account_route
-
- match org.subscription.payment_provider:
- case PaymentProvider.STRIPE:
- session = stripe.checkout.Session.create(
- mode="setup",
- currency="usd",
- customer=org.get_or_create_stripe_customer().id,
- setup_intent_data={
- "metadata": {"subscription_id": org.subscription.external_id},
- },
- success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(account_route),
- )
- raise gui.RedirectException(session.url, status_code=303)
- case _:
- gui.error("Not implemented for this payment provider")
-
-
-def render_billing_history(org: Org, limit: int = 50):
- import pandas as pd
-
- txns = org.transactions.filter(amount__gt=0).order_by("-created_at")
- if not txns:
- return
-
- gui.write("## Billing History", className="d-block")
- gui.table(
- pd.DataFrame.from_records(
- [
- {
- "Date": txn.created_at.strftime("%m/%d/%Y"),
- "Description": txn.reason_note(),
- "Amount": f"-${txn.charged_amount / 100:,.2f}",
- "Credits": f"+{txn.amount:,}",
- "Balance": f"{txn.end_balance:,}",
- }
- for txn in txns[:limit]
- ]
- ),
- )
- if txns.count() > limit:
- gui.caption(f"Showing only the most recent {limit} transactions.")
-
-
-def render_addon_section(org: Org, selected_payment_provider: PaymentProvider):
- if org.subscription:
- gui.write("# Purchase More Credits")
- else:
- gui.write("# Purchase Credits")
- gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits")
-
- if org.subscription and org.subscription.payment_provider:
- provider = PaymentProvider(org.subscription.payment_provider)
- else:
- provider = selected_payment_provider
- match provider:
- case PaymentProvider.STRIPE | None:
- render_stripe_addon_buttons(org)
- case PaymentProvider.PAYPAL:
- render_paypal_addon_buttons()
-
-
-def render_paypal_addon_buttons():
- selected_amt = gui.horizontal_radio(
- "",
- settings.ADDON_AMOUNT_CHOICES,
- format_func=lambda amt: f"${amt:,}",
- checked_by_default=False,
- )
- if selected_amt:
- gui.js(
- f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})"
- )
- gui.div(
- id="paypal-addon-buttons",
- className="mt-2",
- style={"width": "fit-content"},
- )
- gui.div(id="paypal-result-message")
-
-
-def render_stripe_addon_buttons(org: Org):
- for dollar_amt in settings.ADDON_AMOUNT_CHOICES:
- render_stripe_addon_button(dollar_amt, org)
-
-
-def render_stripe_addon_button(dollar_amt: int, org: Org):
- confirm_purchase_modal = gui.Modal(
- "Confirm Purchase", key=f"confirm-purchase-{dollar_amt}"
- )
- if gui.button(f"${dollar_amt:,}", type="primary"):
- if org.subscription and org.subscription.external_id:
- confirm_purchase_modal.open()
- else:
- stripe_addon_checkout_redirect(org, dollar_amt)
-
- if not confirm_purchase_modal.is_open():
- return
- with confirm_purchase_modal.container():
- gui.write(
- f"""
- Please confirm your purchase:
- **{dollar_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollar_amt}**.
- """,
- className="py-4 d-block text-center",
- )
- with gui.div(className="d-flex w-100 justify-content-end"):
- if gui.session_state.get("--confirm-purchase"):
- success = gui.run_in_thread(
- org.subscription.stripe_attempt_addon_purchase,
- args=[dollar_amt],
- placeholder="Processing payment...",
- )
- if success is None:
- return
- gui.session_state.pop("--confirm-purchase")
- if success:
- confirm_purchase_modal.close()
- else:
- gui.error("Payment failed... Please try again.")
- return
-
- if gui.button("Cancel", className="border border-danger text-danger me-2"):
- confirm_purchase_modal.close()
- gui.button("Buy", type="primary", key="--confirm-purchase")
-
-
-def stripe_addon_checkout_redirect(org: Org, dollar_amt: int):
- from routers.account import account_route
- from routers.account import payment_processing_route
-
- line_item = available_subscriptions["addon"]["stripe"].copy()
- line_item["quantity"] = dollar_amt * settings.ADDON_CREDITS_PER_DOLLAR
- checkout_session = stripe.checkout.Session.create(
- line_items=[line_item],
- mode="payment",
- success_url=get_app_route_url(payment_processing_route),
- cancel_url=get_app_route_url(account_route),
- customer=org.get_or_create_stripe_customer().id,
- invoice_creation={"enabled": True},
- allow_promotion_codes=True,
- saved_payment_method_options={
- "payment_method_save": "enabled",
- },
- )
- raise gui.RedirectException(checkout_session.url, status_code=303)
-
-
def render_org_creation_view(user: AppUser):
gui.write(f"# {icons.company} Create an Org", unsafe_allow_html=True)
org_fields = render_org_create_or_edit_form()
@@ -525,7 +158,7 @@ def render_org_creation_view(user: AppUser):
except ValidationError as e:
gui.write(", ".join(e.messages), className="text-danger")
else:
- gui.experimental_rerun()
+ gui.rerun()
def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: gui.Modal):
From 19a24efb2a5c89d0e52252acfa7d409c18828cd5 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sun, 1 Sep 2024 19:06:26 +0530
Subject: [PATCH 054/110] feat: use set_org_subscription instead of
set_user_subscription
---
payments/webhooks.py | 20 +++++++++-----------
1 file changed, 9 insertions(+), 11 deletions(-)
diff --git a/payments/webhooks.py b/payments/webhooks.py
index a00466bbc..cedd2b0b3 100644
--- a/payments/webhooks.py
+++ b/payments/webhooks.py
@@ -4,11 +4,7 @@
from django.db import transaction
from loguru import logger
-from app_users.models import (
- AppUser,
- PaymentProvider,
- TransactionReason,
-)
+from app_users.models import PaymentProvider, TransactionReason
from daras_ai_v2 import paypal
from orgs.models import Org
from .models import Subscription
@@ -67,9 +63,9 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription):
return
set_org_subscription(
- provider=cls.PROVIDER,
- plan=plan,
org_id=pp_sub.custom_id,
+ plan=plan,
+ provider=cls.PROVIDER,
external_id=pp_sub.id,
)
@@ -77,7 +73,7 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription):
def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription):
assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid"
set_org_subscription(
- uid=pp_sub.custom_id,
+ org_id=pp_sub.custom_id,
plan=PricingPlan.STARTER,
provider=None,
external_id=None,
@@ -89,6 +85,8 @@ class StripeWebhookHandler:
@classmethod
def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice):
+ from app_users.tasks import save_stripe_default_payment_method
+
kwargs = {}
if invoice.subscription and invoice.subscription_details:
kwargs["plan"] = PricingPlan.get_by_key(
@@ -122,7 +120,7 @@ def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice):
save_stripe_default_payment_method.delay(
payment_intent_id=invoice.payment_intent,
- uid=uid,
+ org_id=org_id,
amount=amount,
charged_amount=charged_amount,
reason=reason,
@@ -173,9 +171,9 @@ def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscriptio
return
set_org_subscription(
- provider=cls.PROVIDER,
- plan=plan,
org_id=org_id,
+ plan=plan,
+ provider=cls.PROVIDER,
external_id=stripe_sub.id,
)
From 984602985c4d088df1c6b2eccf7f59224a513ae2 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sun, 1 Sep 2024 23:43:31 +0530
Subject: [PATCH 055/110] fix: phone number field in
org.get_or_create_stripe_customer
---
orgs/models.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/orgs/models.py b/orgs/models.py
index fa1b471b9..6038d99c9 100644
--- a/orgs/models.py
+++ b/orgs/models.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import re
+import typing
from datetime import timedelta
from django.db.models.aggregates import Sum
@@ -20,6 +21,9 @@
from gooeysite.bg_db_conn import db_middleware
from orgs.tasks import send_auto_accepted_email, send_invitation_email
+if typing.TYPE_CHECKING:
+ from app_users.models import AppUser
+
ORG_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$")
@@ -259,7 +263,7 @@ def get_or_create_stripe_customer(self) -> stripe.Customer:
customer = stripe.Customer.create(
name=self.created_by.display_name,
email=self.created_by.email,
- phone=self.created_by.phone,
+ phone=self.created_by.phone_number,
metadata={"uid": self.org_id, "org_id": self.org_id, "id": self.pk},
)
self.stripe_customer_id = customer.id
From f99e7231bb32711ddb455324e60b3f771c29751f Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sun, 1 Sep 2024 23:44:16 +0530
Subject: [PATCH 056/110] add org to list view in transactions admin
---
app_users/admin.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/app_users/admin.py b/app_users/admin.py
index 56f325fca..f433f86d2 100644
--- a/app_users/admin.py
+++ b/app_users/admin.py
@@ -216,6 +216,7 @@ class AppUserTransactionAdmin(admin.ModelAdmin):
autocomplete_fields = ["user"]
list_display = [
"invoice_id",
+ "org",
"user",
"amount",
"dollar_amount",
From 761d5e618f7031212ef22cff630a31f9320ffc30 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sun, 1 Sep 2024 23:44:45 +0530
Subject: [PATCH 057/110] fix: types in orgs.models
---
orgs/models.py | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/orgs/models.py b/orgs/models.py
index 6038d99c9..0b0362503 100644
--- a/orgs/models.py
+++ b/orgs/models.py
@@ -4,8 +4,8 @@
import typing
from datetime import timedelta
-from django.db.models.aggregates import Sum
import stripe
+from django.db.models.aggregates import Sum
from django.db import models, transaction
from django.core.exceptions import ValidationError
from django.db.backends.base.schema import logger
@@ -21,8 +21,9 @@
from gooeysite.bg_db_conn import db_middleware
from orgs.tasks import send_auto_accepted_email, send_invitation_email
+
if typing.TYPE_CHECKING:
- from app_users.models import AppUser
+ from app_users.models import AppUser, AppUserTransaction
ORG_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$")
@@ -164,7 +165,7 @@ def get_slug(self):
return slugify(self.name)
def add_member(
- self, user: AppUser, role: OrgRole, invitation: "OrgInvitation | None" = None
+ self, user: "AppUser", role: OrgRole, invitation: "OrgInvitation | None" = None
):
OrgMembership(
org=self,
@@ -177,7 +178,7 @@ def invite_user(
self,
*,
invitee_email: str,
- inviter: AppUser,
+ inviter: "AppUser",
role: OrgRole,
auto_accept: bool = False,
) -> "OrgInvitation":
From b62a128ac2feffb30ce5bd02e284cccca1de46e2 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Sun, 1 Sep 2024 23:45:16 +0530
Subject: [PATCH 058/110] add transaction migration to org billing migration
script
---
scripts/migrate_orgs_from_appusers.py | 28 +++++++++++++++++++++++++--
1 file changed, 26 insertions(+), 2 deletions(-)
diff --git a/scripts/migrate_orgs_from_appusers.py b/scripts/migrate_orgs_from_appusers.py
index d4e868e30..f4cbc7ec9 100644
--- a/scripts/migrate_orgs_from_appusers.py
+++ b/scripts/migrate_orgs_from_appusers.py
@@ -1,4 +1,4 @@
-from django.db import IntegrityError
+from django.db import IntegrityError, connection
from loguru import logger
from app_users.models import AppUser
@@ -6,12 +6,18 @@
def run():
+ migrate_personal_orgs()
+ migrate_txns()
+
+
+def migrate_personal_orgs():
users_without_personal_org = AppUser.objects.exclude(
id__in=Org.objects.filter(is_personal=True).values_list("created_by", flat=True)
)
done_count = 0
+ logger.info("Creating personal orgs...")
for appuser in users_without_personal_org:
try:
Org.objects.migrate_from_appuser(appuser)
@@ -23,4 +29,22 @@ def run():
if done_count % 100 == 0:
logger.info(f"Running... {done_count} migrated")
- logger.info(f"Done... {done_count} migrated")
+ logger.info(f"Migrated {done_count} personal orgs...")
+
+
+def migrate_txns():
+ with connection.cursor() as cursor:
+ cursor.execute(
+ """
+ UPDATE app_users_appusertransaction AS txn
+ SET org_id = orgs_org.id
+ FROM
+ app_users_appuser
+ INNER JOIN orgs_org ON app_users_appuser.id = orgs_org.created_by_id
+ WHERE
+ txn.user_id = app_users_appuser.id
+ AND txn.org_id IS NULL
+ AND orgs_org.is_personal = true
+ """
+ )
+ logger.info(f"Updated {cursor.rowcount} txns with personal orgs")
From 60f69380ac418091fbc29a9343bd3e923bd8d66c Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 2 Sep 2024 12:54:59 +0530
Subject: [PATCH 059/110] revert accidental changes to Procfile
---
Procfile | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Procfile b/Procfile
index 1766991c6..8711211c2 100644
--- a/Procfile
+++ b/Procfile
@@ -19,4 +19,4 @@ dashboard: poetry run streamlit run Home.py --server.port 8501 --server.headless
celery: poetry run celery -A celeryapp worker -P threads -c 16 -l DEBUG
-ui: cd ../gooey-gui/ && env PORT=3000 REDIS_URL=redis://localhost:6379 pnpm run dev
+ui: cd ../gooey-gui/; PORT=3000 npm run dev
From 782859540e7ccd9cd27319b32a68e8db95553dbd Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 2 Sep 2024 12:59:06 +0530
Subject: [PATCH 060/110] remove unused appuser.get_personal_org
---
app_users/models.py | 3 ---
1 file changed, 3 deletions(-)
diff --git a/app_users/models.py b/app_users/models.py
index 46803c1a8..68193c441 100644
--- a/app_users/models.py
+++ b/app_users/models.py
@@ -162,9 +162,6 @@ def first_name_possesive(self) -> str:
else:
return name + "'s"
- def get_personal_org(self) -> "Org | None":
- return self.orgs.filter(is_personal=True).first()
-
@db_middleware
@transaction.atomic
def add_balance(
From 68df010dda8fb67a57e734a3441bf23a0c58102d Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 2 Sep 2024 13:01:01 +0530
Subject: [PATCH 061/110] set user on txn if org is personal
works for now, until we introduce team members for all
---
orgs/models.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/orgs/models.py b/orgs/models.py
index 0b0362503..33f0b35de 100644
--- a/orgs/models.py
+++ b/orgs/models.py
@@ -252,6 +252,7 @@ def add_balance(
kwargs.setdefault("plan", org.subscription and org.subscription.plan)
return AppUserTransaction.objects.create(
org=org,
+ user=org.created_by if org.is_personal else None,
invoice_id=invoice_id,
amount=amount,
end_balance=org.balance,
From 811854b9d063f076553cfb220635c09e5505dc9b Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 2 Sep 2024 13:02:34 +0530
Subject: [PATCH 062/110] remove debug change
---
daras_ai_v2/send_email.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py
index 3c679c6fb..2262624e7 100644
--- a/daras_ai_v2/send_email.py
+++ b/daras_ai_v2/send_email.py
@@ -82,7 +82,7 @@ def send_email_via_postmark(
"outbound", "gooey-ai-workflows", "announcements"
] = "outbound",
):
- if is_running_pytest or not settings.POSTMARK_API_TOKEN:
+ if is_running_pytest:
pytest_outbox.append(
dict(
from_address=from_address,
From 87252748f4659db395457cfd06b7c9dd9ae01d89 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 2 Sep 2024 13:33:16 +0530
Subject: [PATCH 063/110] use org.subscription instead of user.subscription
---
orgs/models.py | 9 +++
payments/auto_recharge.py | 59 +++++++++----------
payments/tasks.py | 51 ++++++----------
routers/paypal.py | 5 +-
templates/auto_recharge_failed_email.html | 15 -----
templates/monthly_budget_reached_email.html | 8 +--
...spending_notification_threshold_email.html | 6 +-
7 files changed, 66 insertions(+), 87 deletions(-)
delete mode 100644 templates/auto_recharge_failed_email.html
diff --git a/orgs/models.py b/orgs/models.py
index 33f0b35de..4c9b2c8e2 100644
--- a/orgs/models.py
+++ b/orgs/models.py
@@ -98,6 +98,15 @@ def migrate_from_appuser(self, user: "AppUser") -> Org:
is_paying=user.is_paying,
)
+ def get_dollars_spent_this_month(self) -> float:
+ today = timezone.now()
+ cents_spent = self.transactions.filter(
+ created_at__month=today.month,
+ created_at__year=today.year,
+ amount__gt=0,
+ ).aggregate(total=Sum("charged_amount"))["total"]
+ return (cents_spent or 0) / 100
+
class Org(SafeDeleteModel):
_safedelete_policy = SOFT_DELETE_CASCADE
diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py
index 14d6ba49d..3d07493b5 100644
--- a/payments/auto_recharge.py
+++ b/payments/auto_recharge.py
@@ -3,12 +3,10 @@
import sentry_sdk
from loguru import logger
-from app_users.models import AppUser, PaymentProvider
+from app_users.models import PaymentProvider
from daras_ai_v2.redis_cache import redis_lock
-from payments.tasks import (
- send_monthly_budget_reached_email,
- send_auto_recharge_failed_email,
-)
+from orgs.models import Org
+from payments.tasks import send_monthly_budget_reached_email
class AutoRechargeException(Exception):
@@ -30,18 +28,18 @@ class AutoRechargeCooldownException(AutoRechargeException):
pass
-def should_attempt_auto_recharge(user: AppUser):
+def should_attempt_auto_recharge(org: Org):
return (
- user.subscription
- and user.subscription.auto_recharge_enabled
- and user.subscription.payment_provider
- and user.balance < user.subscription.auto_recharge_balance_threshold
+ org.subscription
+ and org.subscription.auto_recharge_enabled
+ and org.subscription.payment_provider
+ and org.balance < org.subscription.auto_recharge_balance_threshold
)
-def run_auto_recharge_gracefully(user: AppUser):
+def run_auto_recharge_gracefully(org: Org):
"""
- Wrapper over _auto_recharge_user, that handles exceptions so that it can:
+ Wrapper over _auto_recharge_org, that handles exceptions so that it can:
- log exceptions
- send emails when auto-recharge fails
- not retry if this is run as a background task
@@ -49,50 +47,49 @@ def run_auto_recharge_gracefully(user: AppUser):
Meant to be used in conjunction with should_attempt_auto_recharge
"""
try:
- with redis_lock(f"gooey/auto_recharge_user/v1/{user.uid}"):
- _auto_recharge_user(user)
+ with redis_lock(f"gooey/auto_recharge_user/v1/{org.org_id}"):
+ _auto_recharge_org(org)
except AutoRechargeCooldownException as e:
logger.info(
- f"Rejected auto-recharge because auto-recharge is in cooldown period for user"
- f"{user=}, {e=}"
+ f"Rejected auto-recharge because auto-recharge is in cooldown period for org"
+ f"{org=}, {e=}"
)
except MonthlyBudgetReachedException as e:
- send_monthly_budget_reached_email(user)
+ send_monthly_budget_reached_email(org)
logger.info(
f"Rejected auto-recharge because user has reached monthly budget"
- f"{user=}, spending=${e.spending}, budget=${e.budget}"
+ f"{org=}, spending=${e.spending}, budget=${e.budget}"
)
except Exception as e:
traceback.print_exc()
sentry_sdk.capture_exception(e)
- send_auto_recharge_failed_email(user)
-def _auto_recharge_user(user: AppUser):
+def _auto_recharge_org(org: Org):
"""
Returns whether a charge was attempted
"""
from payments.webhooks import StripeWebhookHandler
assert (
- user.subscription.payment_provider == PaymentProvider.STRIPE
+ org.subscription.payment_provider == PaymentProvider.STRIPE
), "Auto recharge is only supported with Stripe"
# check for monthly budget
- dollars_spent = user.get_dollars_spent_this_month()
+ dollars_spent = org.get_dollars_spent_this_month()
if (
- dollars_spent + user.subscription.auto_recharge_topup_amount
- > user.subscription.monthly_spending_budget
+ dollars_spent + org.subscription.auto_recharge_topup_amount
+ > org.subscription.monthly_spending_budget
):
raise MonthlyBudgetReachedException(
"Performing this top-up would exceed your monthly recharge budget",
- budget=user.subscription.monthly_spending_budget,
+ budget=org.subscription.monthly_spending_budget,
spending=dollars_spent,
)
try:
- invoice = user.subscription.stripe_get_or_create_auto_invoice(
- amount_in_dollars=user.subscription.auto_recharge_topup_amount,
+ invoice = org.subscription.stripe_get_or_create_auto_invoice(
+ amount_in_dollars=org.subscription.auto_recharge_topup_amount,
metadata_key="auto_recharge",
)
except Exception as e:
@@ -106,9 +103,9 @@ def _auto_recharge_user(user: AppUser):
# get default payment method and attempt payment
assert invoice.status == "open" # sanity check
- pm = user.subscription.stripe_get_default_payment_method()
+ pm = org.subscription.stripe_get_default_payment_method()
if not pm:
- logger.warning(f"{user} has no default payment method, cannot auto-recharge")
+ logger.warning(f"{org} has no default payment method, cannot auto-recharge")
return
try:
@@ -119,4 +116,6 @@ def _auto_recharge_user(user: AppUser):
) from e
else:
assert invoice_data.paid
- StripeWebhookHandler.handle_invoice_paid(uid=user.uid, invoice=invoice_data)
+ StripeWebhookHandler.handle_invoice_paid(
+ org_id=org.org_id, invoice=invoice_data
+ )
diff --git a/payments/tasks.py b/payments/tasks.py
index 2070db714..c98b8c12e 100644
--- a/payments/tasks.py
+++ b/payments/tasks.py
@@ -29,6 +29,7 @@ def send_monthly_spending_notification_email(id: int):
"monthly_spending_notification_threshold_email.html"
).render(
user=owner.user,
+ org=org,
account_url=get_app_route_url(account_route),
),
)
@@ -40,43 +41,27 @@ def send_monthly_spending_notification_email(id: int):
org.subscription.save(update_fields=["monthly_spending_notification_sent_at"])
-def send_monthly_budget_reached_email(user: AppUser):
+def send_monthly_budget_reached_email(org: Org):
from routers.account import account_route
- if not user.email:
- return
+ for owner in org.get_owners():
+ if not owner.user.email:
+ continue
- email_body = templates.get_template("monthly_budget_reached_email.html").render(
- user=user,
- account_url=get_app_route_url(account_route),
- )
- send_email_via_postmark(
- from_address=settings.SUPPORT_EMAIL,
- to_address=user.email,
- subject="[Gooey.AI] Monthly Budget Reached",
- html_body=email_body,
- )
+ email_body = templates.get_template("monthly_budget_reached_email.html").render(
+ user=owner.user,
+ org=org,
+ account_url=get_app_route_url(account_route),
+ )
+ send_email_via_postmark(
+ from_address=settings.SUPPORT_EMAIL,
+ to_address=owner.user.email,
+ subject="[Gooey.AI] Monthly Budget Reached",
+ html_body=email_body,
+ )
# IMPORTANT: always use update_fields=... when updating subscription
# info. We don't want to overwrite other changes made to subscription
# during the same time
- user.subscription.monthly_budget_email_sent_at = timezone.now()
- user.subscription.save(update_fields=["monthly_budget_email_sent_at"])
-
-
-def send_auto_recharge_failed_email(user: AppUser):
- from routers.account import account_route
-
- if not user.email:
- return
-
- email_body = templates.get_template("auto_recharge_failed_email.html").render(
- user=user,
- account_url=get_app_route_url(account_route),
- )
- send_email_via_postmark(
- from_address=settings.SUPPORT_EMAIL,
- to_address=user.email,
- subject="[Gooey.AI] Auto-Recharge failed",
- html_body=email_body,
- )
+ org.subscription.monthly_budget_email_sent_at = timezone.now()
+ org.subscription.save(update_fields=["monthly_budget_email_sent_at"])
diff --git a/routers/paypal.py b/routers/paypal.py
index 48e65a623..86f93ce48 100644
--- a/routers/paypal.py
+++ b/routers/paypal.py
@@ -126,7 +126,8 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json):
if plan.deprecated:
return JSONResponse({"error": "Deprecated plan"}, status_code=400)
- if request.user.subscription and request.user.subscription.is_paid():
+ org, _ = request.user.get_or_create_personal_org()
+ if org.subscription and org.subscription.is_paid():
return JSONResponse(
{"error": "User already has an active subscription"}, status_code=400
)
@@ -134,7 +135,7 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json):
paypal_plan_info = plan.get_paypal_plan()
pp_subscription = paypal.Subscription.create(
plan_id=paypal_plan_info["plan_id"],
- custom_id=request.user.uid,
+ custom_id=org.org_id,
plan=paypal_plan_info.get("plan", {}),
application_context={
"brand_name": "Gooey.AI",
diff --git a/templates/auto_recharge_failed_email.html b/templates/auto_recharge_failed_email.html
deleted file mode 100644
index 601fab5d8..000000000
--- a/templates/auto_recharge_failed_email.html
+++ /dev/null
@@ -1,15 +0,0 @@
-
- Hey, {{ user.first_name() }}!
-
-
-
- Your Gooey.AI account balance is below your threshold.
- An auto-recharge was attempted but failed because {{ reason }}.
- Please visit your billing settings.
-
-
-
- Best Wishes,
-
- Gooey.AI Team
-
diff --git a/templates/monthly_budget_reached_email.html b/templates/monthly_budget_reached_email.html
index 0171e320d..6e467a086 100644
--- a/templates/monthly_budget_reached_email.html
+++ b/templates/monthly_budget_reached_email.html
@@ -1,6 +1,6 @@
-{% set dollars_spent = user.get_dollars_spent_this_month() %}
-{% set monthly_budget = user.subscription.monthly_spending_budget %}
-{% set threshold = user.subscription.auto_recharge_balance_threshold %}
+{% set dollars_spent = org.get_dollars_spent_this_month() %}
+{% set monthly_budget = org.subscription.monthly_spending_budget %}
+{% set threshold = org.subscription.auto_recharge_balance_threshold %}
Hey, {{ user.first_name() }}!
@@ -18,7 +18,7 @@
-
Credit Balance: {{ user.balance }} credits
+
Credit Balance: {{ org.balance }} credits
Monthly Budget: ${{ monthly_budget }}
Spending this month: ${{ dollars_spent }}
diff --git a/templates/monthly_spending_notification_threshold_email.html b/templates/monthly_spending_notification_threshold_email.html
index ddf54e223..13be0fae5 100644
--- a/templates/monthly_spending_notification_threshold_email.html
+++ b/templates/monthly_spending_notification_threshold_email.html
@@ -1,4 +1,4 @@
-{% set dollars_spent = user.get_dollars_spent_this_month() %}
+{% set dollars_spent = org.get_dollars_spent_this_month() %}
Hi, {{ user.first_name() }}!
@@ -6,11 +6,11 @@
Your spend on Gooey.AI so far this month is ${{ dollars_spent }}, exceeding your notification threshold
- of ${{ user.subscription.monthly_spending_notification_threshold }}.
+ of ${{ org.subscription.monthly_spending_notification_threshold }}.
- Your monthly budget is ${{ user.subscription.monthly_spending_budget }}, after which auto-recharge will be
+ Your monthly budget is ${{ org.subscription.monthly_spending_budget }}, after which auto-recharge will be
paused and all runs / API calls will be rejected.
From 4a598a1cb3028641d509f4bc6e0dc06af18efd00 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 2 Sep 2024 18:37:56 +0530
Subject: [PATCH 064/110] fix: s/user.subscription/org.subscription in
billing.py
---
daras_ai_v2/billing.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index e5a5fa27e..7c99a53bf 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -325,7 +325,7 @@ def change_subscription(org: "Org", new_plan: PricingPlan, **kwargs):
if not new_plan.supports_paypal():
gui.error(f"Paypal subscription not available for {new_plan}")
- subscription = paypal.Subscription.retrieve(user.subscription.external_id)
+ subscription = paypal.Subscription.retrieve(org.subscription.external_id)
paypal_plan_info = new_plan.get_paypal_plan()
approval_url = subscription.update_plan(
plan_id=paypal_plan_info["plan_id"],
From 5039bea0a1854f75ad78fae815fd05b7ecb20df2 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 2 Sep 2024 19:08:19 +0530
Subject: [PATCH 065/110] fix type for set_org_subscription
---
payments/models.py | 6 +++++-
payments/webhooks.py | 4 ++--
2 files changed, 7 insertions(+), 3 deletions(-)
diff --git a/payments/models.py b/payments/models.py
index f647bd5a6..ff5be4f69 100644
--- a/payments/models.py
+++ b/payments/models.py
@@ -89,7 +89,11 @@ def __str__(self):
return ret
def full_clean(
- self, amount: int = None, charged_amount: int = None, *args, **kwargs
+ self,
+ amount: int | None = None,
+ charged_amount: int | None = None,
+ *args,
+ **kwargs,
):
if self.auto_recharge_enabled:
if amount is None:
diff --git a/payments/webhooks.py b/payments/webhooks.py
index cedd2b0b3..36f0499c7 100644
--- a/payments/webhooks.py
+++ b/payments/webhooks.py
@@ -222,8 +222,8 @@ def set_org_subscription(
plan: PricingPlan,
provider: PaymentProvider | None,
external_id: str | None,
- amount: int = None,
- charged_amount: int = None,
+ amount: int | None = None,
+ charged_amount: int | None = None,
) -> Subscription:
with transaction.atomic():
org = Org.objects.get_or_create_from_org_id(org_id)[0]
From e5385742fd8d5c7f08ee98e6ac0e5605b06395c5 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Mon, 2 Sep 2024 19:14:47 +0530
Subject: [PATCH 066/110] fix paypal handle_invoice_paid: uid -> org_id
---
routers/paypal.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/routers/paypal.py b/routers/paypal.py
index 86f93ce48..3771481cf 100644
--- a/routers/paypal.py
+++ b/routers/paypal.py
@@ -177,7 +177,7 @@ def _handle_invoice_paid(order_id: str):
purchase_unit = order["purchase_units"][0]
payment_capture = purchase_unit["payments"]["captures"][0]
add_balance_for_payment(
- uid=payment_capture["custom_id"],
+ org_id=payment_capture["custom_id"],
amount=int(purchase_unit["items"][0]["quantity"]),
invoice_id=payment_capture["id"],
payment_provider=PaymentProvider.PAYPAL,
From c12f3e8a9187d509307cab1c2fbd4ebf4ee7f5ec Mon Sep 17 00:00:00 2001
From: Dev Aggarwal
Date: Tue, 3 Sep 2024 01:10:18 +0530
Subject: [PATCH 067/110] update yt-dlp
---
poetry.lock | 12 ++++++------
pyproject.toml | 2 +-
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/poetry.lock b/poetry.lock
index 9950129f5..a7eb2ca4c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -6491,13 +6491,13 @@ multidict = ">=4.0"
[[package]]
name = "yt-dlp"
-version = "2024.7.2"
+version = "2024.8.6"
description = "A feature-rich command-line audio/video downloader"
optional = false
python-versions = ">=3.8"
files = [
- {file = "yt_dlp-2024.7.2-py3-none-any.whl", hash = "sha256:4f76b48244c783e6ac06e8d7627bcf62cbeb4f6d79ba7e3cfc8249e680d4e691"},
- {file = "yt_dlp-2024.7.2.tar.gz", hash = "sha256:2b0c86b579d4a044eaf3c4b00e3d7b24d82e6e26869fa11c288ea4395b387f41"},
+ {file = "yt_dlp-2024.8.6-py3-none-any.whl", hash = "sha256:ab507ff600bd9269ad4d654e309646976778f0e243eaa2f6c3c3214278bb2922"},
+ {file = "yt_dlp-2024.8.6.tar.gz", hash = "sha256:e8551f26bc8bf67b99c12373cc87ed2073436c3437e53290878d0f4b4bb1f663"},
]
[package.dependencies]
@@ -6511,8 +6511,8 @@ urllib3 = ">=1.26.17,<3"
websockets = ">=12.0"
[package.extras]
-build = ["build", "hatchling", "pip", "setuptools", "wheel"]
-curl-cffi = ["curl-cffi (==0.5.10)"]
+build = ["build", "hatchling", "pip", "setuptools (>=71.0.2)", "wheel"]
+curl-cffi = ["curl-cffi (==0.5.10)", "curl-cffi (>=0.5.10,<0.6.dev0 || ==0.7.*)"]
dev = ["autopep8 (>=2.0,<3.0)", "pre-commit", "pytest (>=8.1,<9.0)", "ruff (>=0.5.0,<0.6.0)"]
py2exe = ["py2exe (>=0.12)"]
pyinstaller = ["pyinstaller (>=6.7.0)"]
@@ -6538,4 +6538,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
-content-hash = "5834cb5e676e83b492e8aec5d9efa15bed653848f6d356c139917e1a1b01e872"
+content-hash = "ac4c7f52c5bb619909f5c1ed8c653aeeb3aa0275e542d716724c5e6ebada2f37"
diff --git a/pyproject.toml b/pyproject.toml
index 90859e035..1d66681df 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,7 +47,7 @@ oauth2client = "^4.1.3"
tiktoken = "^0.7.0"
google-cloud-translate = "^3.12.0"
google-cloud-speech = "^2.21.0"
-yt-dlp = "^2024.7.2"
+yt-dlp = "^2024.8.6"
Jinja2 = "^3.1.2"
Django = "^4.2"
django-phonenumber-field = { extras = ["phonenumberslite"], version = "^7.0.2" }
From 34cb49775e8d6684d1b1c5a5a5f5d50f40763733 Mon Sep 17 00:00:00 2001
From: Dev Aggarwal
Date: Wed, 4 Sep 2024 17:55:12 +0530
Subject: [PATCH 068/110] Add logging for InvalidRequestError in billing module
---
daras_ai_v2/billing.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index 0254a89fc..c96bff8c7 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -2,6 +2,7 @@
import sentry_sdk
import stripe
from django.core.exceptions import ValidationError
+from loguru import logger
from app_users.models import AppUser, PaymentProvider
from daras_ai_v2 import icons, settings, paypal
@@ -239,6 +240,7 @@ def _render_plan_action_button(
except (stripe.CardError, stripe.InvalidRequestError) as e:
if isinstance(e, stripe.InvalidRequestError):
sentry_sdk.capture_exception(e)
+ logger.warning(e)
# only handle error if it's related to mandates
# cancel current subscription & redirect user to new subscription page
From ed05dad3bbe5050474f90fd04ef92e621a16f9b6 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 4 Sep 2024 18:41:18 +0530
Subject: [PATCH 069/110] Rename org -> workspace, remove workspace.id
---
app_users/admin.py | 2 +-
..._appusertransaction_workspace_and_more.py} | 8 +-
app_users/models.py | 21 +-
app_users/tasks.py | 34 ++-
...g.py => 0082_savedrun_billed_workspace.py} | 8 +-
bots/models.py | 4 +-
celeryapp/tasks.py | 39 ++-
daras_ai_v2/base.py | 36 ++-
daras_ai_v2/billing.py | 187 ++++++++------
daras_ai_v2/send_email.py | 30 ++-
daras_ai_v2/settings.py | 12 +-
orgs/admin.py | 111 --------
...0002_alter_org_unique_together_and_more.py | 35 ---
...e_domain_name_when_not_deleted_and_more.py | 36 ---
..._org_is_paying_org_is_personal_and_more.py | 45 ----
.../0005_org_unique_personal_org_per_user.py | 17 --
orgs/signals.py | 49 ----
payments/auto_recharge.py | 54 ++--
payments/models.py | 14 +-
payments/tasks.py | 29 ++-
payments/webhooks.py | 82 +++---
routers/account.py | 36 +--
routers/api.py | 1 +
routers/paypal.py | 8 +-
scripts/migrate_billed_org_for_saved_runs.py | 18 --
...migrate_billed_workspace_for_saved_runs.py | 23 ++
scripts/migrate_orgs_from_appusers.py | 50 ----
scripts/migrate_workspace_from_appusers.py | 52 ++++
templates/monthly_budget_reached_email.html | 8 +-
...spending_notification_threshold_email.html | 6 +-
.../org_invitation_auto_accepted_email.html | 10 +-
templates/org_invitation_email.html | 4 +-
{orgs => workspaces}/__init__.py | 0
workspaces/admin.py | 155 +++++++++++
{orgs => workspaces}/apps.py | 4 +-
.../migrations/0001_initial.py | 55 ++--
.../migrations/0002_alter_workspace_logo.py | 19 ++
{orgs => workspaces}/migrations/__init__.py | 0
{orgs => workspaces}/models.py | 223 +++++++++-------
workspaces/signals.py | 50 ++++
{orgs => workspaces}/tasks.py | 24 +-
{orgs => workspaces}/tests.py | 0
{orgs => workspaces}/views.py | 244 ++++++++++--------
43 files changed, 958 insertions(+), 885 deletions(-)
rename app_users/migrations/{0020_appusertransaction_org_alter_appusertransaction_user.py => 0020_appusertransaction_workspace_and_more.py} (75%)
rename bots/migrations/{0082_savedrun_billed_org.py => 0082_savedrun_billed_workspace.py} (74%)
delete mode 100644 orgs/admin.py
delete mode 100644 orgs/migrations/0002_alter_org_unique_together_and_more.py
delete mode 100644 orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py
delete mode 100644 orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py
delete mode 100644 orgs/migrations/0005_org_unique_personal_org_per_user.py
delete mode 100644 orgs/signals.py
delete mode 100644 scripts/migrate_billed_org_for_saved_runs.py
create mode 100644 scripts/migrate_billed_workspace_for_saved_runs.py
delete mode 100644 scripts/migrate_orgs_from_appusers.py
create mode 100644 scripts/migrate_workspace_from_appusers.py
rename {orgs => workspaces}/__init__.py (100%)
create mode 100644 workspaces/admin.py
rename {orgs => workspaces}/apps.py (74%)
rename {orgs => workspaces}/migrations/0001_initial.py (56%)
create mode 100644 workspaces/migrations/0002_alter_workspace_logo.py
rename {orgs => workspaces}/migrations/__init__.py (100%)
rename {orgs => workspaces}/models.py (69%)
create mode 100644 workspaces/signals.py
rename {orgs => workspaces}/tasks.py (67%)
rename {orgs => workspaces}/tests.py (100%)
rename {orgs => workspaces}/views.py (64%)
diff --git a/app_users/admin.py b/app_users/admin.py
index f433f86d2..ba05b10e1 100644
--- a/app_users/admin.py
+++ b/app_users/admin.py
@@ -216,7 +216,7 @@ class AppUserTransactionAdmin(admin.ModelAdmin):
autocomplete_fields = ["user"]
list_display = [
"invoice_id",
- "org",
+ "workspace",
"user",
"amount",
"dollar_amount",
diff --git a/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py b/app_users/migrations/0020_appusertransaction_workspace_and_more.py
similarity index 75%
rename from app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py
rename to app_users/migrations/0020_appusertransaction_workspace_and_more.py
index b3e80c708..43b2d32d2 100644
--- a/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py
+++ b/app_users/migrations/0020_appusertransaction_workspace_and_more.py
@@ -1,4 +1,4 @@
-# Generated by Django 4.2.7 on 2024-08-13 14:34
+# Generated by Django 4.2.7 on 2024-09-02 14:07
from django.db import migrations, models
import django.db.models.deletion
@@ -7,15 +7,15 @@
class Migration(migrations.Migration):
dependencies = [
- ('orgs', '0005_org_unique_personal_org_per_user'),
+ ('workspaces', '0001_initial'),
('app_users', '0019_alter_appusertransaction_reason'),
]
operations = [
migrations.AddField(
model_name='appusertransaction',
- name='org',
- field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='transactions', to='orgs.org'),
+ name='workspace',
+ field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='transactions', to='workspaces.workspace'),
),
migrations.AlterField(
model_name='appusertransaction',
diff --git a/app_users/models.py b/app_users/models.py
index 68193c441..66ce232f3 100644
--- a/app_users/models.py
+++ b/app_users/models.py
@@ -16,7 +16,7 @@
from payments.plans import PricingPlan
if typing.TYPE_CHECKING:
- from orgs.models import Org
+ from workspaces.models import Workspace
class AppUserQuerySet(models.QuerySet):
@@ -249,16 +249,13 @@ def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser":
return self
- def get_or_create_personal_org(self) -> tuple["Org", bool]:
- from orgs.models import Org, OrgMembership
+ def get_or_create_personal_workspace(self) -> tuple["Workspace", bool]:
+ from workspaces.models import Workspace
- org_membership: OrgMembership | None = self.org_memberships.filter(
- org__is_personal=True, org__created_by=self
- ).first()
- if org_membership:
- return org_membership.org, False
- else:
- return Org.objects.migrate_from_appuser(self), True
+ try:
+ return Workspace.objects.get(is_personal=True, created_by=self), False
+ except Workspace.DoesNotExist:
+ return Workspace.objects.migrate_from_appuser(self), True
def get_or_create_stripe_customer(self) -> stripe.Customer:
customer = self.search_stripe_customer()
@@ -322,8 +319,8 @@ class AppUserTransaction(models.Model):
related_name="transactions",
null=True,
)
- org = models.ForeignKey(
- "orgs.Org",
+ workspace = models.ForeignKey(
+ "workspaces.Workspace",
on_delete=models.SET_NULL,
related_name="transactions",
null=True,
diff --git a/app_users/tasks.py b/app_users/tasks.py
index b1d893196..9bd85eb7a 100644
--- a/app_users/tasks.py
+++ b/app_users/tasks.py
@@ -3,16 +3,16 @@
from app_users.models import PaymentProvider, TransactionReason
from celeryapp.celeryconfig import app
-from payments.models import Subscription
from payments.plans import PricingPlan
-from payments.webhooks import set_org_subscription
+from payments.webhooks import set_workspace_subscription
+from workspaces.models import Workspace
@app.task
def save_stripe_default_payment_method(
*,
+ workspace_id_or_uid: int | str,
payment_intent_id: str,
- org_id: str,
amount: int,
charged_amount: int,
reason: TransactionReason,
@@ -36,16 +36,24 @@ def save_stripe_default_payment_method(
invoice_settings=dict(default_payment_method=pm),
)
- # if user doesn't already have a active billing/autorecharge info, so we don't need to do anything
- # set user's subscription to the free plan
- if (
- reason == TransactionReason.ADDON
- and not Subscription.objects.filter(
- org__org_id=org_id, payment_provider__isnull=False
- ).exists()
- ):
- set_org_subscription(
- org_id=org_id,
+ # if user already has a subscription with payment info, we do nothing
+ # otherwise, we set the user's subscription to the free plan
+ if reason == TransactionReason.ADDON:
+ try:
+ workspace = Workspace.objects.select_related("subscription").get(
+ int(workspace_id_or_uid)
+ )
+ except (ValueError, Workspace.DoesNotExist):
+ workspace, _ = Workspace.objects.get_or_create_from_uid(workspace_id_or_uid)
+
+ if workspace.subscription and (
+ workspace.subscription.is_paid() or workspace.subscription.payment_provider
+ ):
+ # already has a subscription
+ return
+
+ set_workspace_subscription(
+ workspace_id_or_uid=workspace.id,
plan=PricingPlan.STARTER,
provider=PaymentProvider.STRIPE,
external_id=None,
diff --git a/bots/migrations/0082_savedrun_billed_org.py b/bots/migrations/0082_savedrun_billed_workspace.py
similarity index 74%
rename from bots/migrations/0082_savedrun_billed_org.py
rename to bots/migrations/0082_savedrun_billed_workspace.py
index 208f46dcc..502c15269 100644
--- a/bots/migrations/0082_savedrun_billed_org.py
+++ b/bots/migrations/0082_savedrun_billed_workspace.py
@@ -1,4 +1,4 @@
-# Generated by Django 4.2.7 on 2024-08-30 08:10
+# Generated by Django 4.2.7 on 2024-09-02 14:08
from django.db import migrations, models
import django.db.models.deletion
@@ -7,14 +7,14 @@
class Migration(migrations.Migration):
dependencies = [
- ('orgs', '0005_org_unique_personal_org_per_user'),
+ ('workspaces', '0001_initial'),
('bots', '0081_remove_conversation_bots_conver_bot_int_73ac7b_idx_and_more'),
]
operations = [
migrations.AddField(
model_name='savedrun',
- name='billed_org',
- field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='billed_runs', to='orgs.org'),
+ name='billed_workspace',
+ field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='billed_runs', to='workspaces.workspace'),
),
]
diff --git a/bots/models.py b/bots/models.py
index a6163ee1c..fcdd345cc 100644
--- a/bots/models.py
+++ b/bots/models.py
@@ -212,8 +212,8 @@ class SavedRun(models.Model):
)
run_id = models.CharField(max_length=128, default=None, null=True, blank=True)
uid = models.CharField(max_length=128, default=None, null=True, blank=True)
- billed_org = models.ForeignKey(
- "orgs.Org",
+ billed_workspace = models.ForeignKey(
+ "workspaces.Workspace",
on_delete=models.SET_NULL,
null=True,
blank=True,
diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py
index c3ae75b52..d3a54549b 100644
--- a/celeryapp/tasks.py
+++ b/celeryapp/tasks.py
@@ -29,6 +29,10 @@
run_auto_recharge_gracefully,
)
+if typing.TYPE_CHECKING:
+ from workspaces.models import Workspace
+
+
DEFAULT_RUN_STATUS = "Running..."
@@ -121,15 +125,14 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False
@app.task
def post_runner_tasks(saved_run_id: int):
sr = SavedRun.objects.get(id=saved_run_id)
- user = AppUser.objects.get(uid=sr.uid)
if not sr.is_api_call:
send_email_on_completion(sr)
- if should_attempt_auto_recharge(user):
- run_auto_recharge_gracefully(user)
+ if should_attempt_auto_recharge(sr.billed_workspace):
+ run_auto_recharge_gracefully(sr.billed_workspace)
- run_low_balance_email_check(user)
+ run_low_balance_email_check(sr.billed_workspace)
def err_msg_for_exc(e: Exception):
@@ -158,15 +161,18 @@ def err_msg_for_exc(e: Exception):
return f"{type(e).__name__}: {e}"
-def run_low_balance_email_check(user: AppUser):
+def run_low_balance_email_check(workspace: Workspace):
# don't send email if feature is disabled
if not settings.LOW_BALANCE_EMAIL_ENABLED:
return
# don't send email if user is not paying or has enough balance
- if not user.is_paying or user.balance > settings.LOW_BALANCE_EMAIL_CREDITS:
+ if (
+ not workspace.is_paying
+ or workspace.balance > settings.LOW_BALANCE_EMAIL_CREDITS
+ ):
return
last_purchase = (
- AppUserTransaction.objects.filter(user=user, amount__gt=0)
+ AppUserTransaction.objects.filter(workspace=workspace, amount__gt=0)
.order_by("-created_at")
.first()
)
@@ -176,22 +182,27 @@ def run_low_balance_email_check(user: AppUser):
# send email if user has not been sent email in last X days or last purchase was after last email sent
if (
# user has not been sent any email
- not user.low_balance_email_sent_at
+ not workspace.low_balance_email_sent_at
# user was sent email before X days
- or (user.low_balance_email_sent_at < email_date_cutoff)
+ or (workspace.low_balance_email_sent_at < email_date_cutoff)
# user has made a purchase after last email sent
- or (last_purchase and last_purchase.created_at > user.low_balance_email_sent_at)
+ or (
+ last_purchase
+ and last_purchase.created_at > workspace.low_balance_email_sent_at
+ )
):
# calculate total credits consumed in last X days
total_credits_consumed = abs(
AppUserTransaction.objects.filter(
- user=user, amount__lt=0, created_at__gte=email_date_cutoff
+ workspace=workspace, amount__lt=0, created_at__gte=email_date_cutoff
).aggregate(Sum("amount"))["amount__sum"]
or 0
)
- send_low_balance_email(user=user, total_credits_consumed=total_credits_consumed)
- user.low_balance_email_sent_at = timezone.now()
- user.save(update_fields=["low_balance_email_sent_at"])
+ send_low_balance_email(
+ workspace=workspace, total_credits_consumed=total_credits_consumed
+ )
+ workspace.low_balance_email_sent_at = timezone.now()
+ workspace.save(update_fields=["low_balance_email_sent_at"])
def send_email_on_completion(sr: SavedRun):
diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py
index f37a284bb..a7c61367d 100644
--- a/daras_ai_v2/base.py
+++ b/daras_ai_v2/base.py
@@ -83,6 +83,10 @@
from routers.account import AccountTabs
from routers.root import RecipeTabs
+if typing.TYPE_CHECKING:
+ from workspaces.models import Workspace
+
+
DEFAULT_META_IMG = (
# Small
"https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ec2100aa-1f6e-11ef-ba0b-02420a000159/Main.jpg"
@@ -571,7 +575,7 @@ def _render_publish_modal(
gui.radio(
"",
options=[
- 'Anyone at my org (coming soon)'
+ 'Anyone at my workspace (coming soon)'
],
disabled=True,
checked_by_default=False,
@@ -1204,6 +1208,12 @@ def create_published_run(
visibility=visibility,
)
+ def get_current_workspace(self) -> "Workspace":
+ assert self.request.user
+
+ workspace, _ = self.request.user.get_or_create_personal_workspace()
+ return workspace
+
def duplicate_published_run(
self,
published_run: PublishedRun,
@@ -1599,7 +1609,9 @@ def submit_and_redirect(self):
def on_submit(self):
try:
- sr = self.create_new_run(enable_rate_limits=True)
+ sr = self.create_new_run(
+ enable_rate_limits=True, billed_workspace=self.get_current_workspace()
+ )
except ValidationError as e:
gui.session_state[StateKeys.run_status] = None
gui.session_state[StateKeys.error_msg] = str(e)
@@ -1612,7 +1624,7 @@ def on_submit(self):
return sr
def should_submit_after_login(self) -> bool:
- return (
+ return bool(
gui.get_query_params().get(SUBMIT_AFTER_LOGIN_Q)
and self.request
and self.request.user
@@ -2084,18 +2096,18 @@ def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict):
assert self.request, "request must be set to check credits"
assert self.request.user, "request.user must be set to check credits"
- user = self.request.user
price = self.get_price_roundoff(state)
+ workspace, _ = self.request.user.get_or_create_personal_workspace()
- if user.balance >= price:
+ if workspace.balance >= price:
return
- if should_attempt_auto_recharge(user):
+ if should_attempt_auto_recharge(workspace):
yield "Low balance detected. Recharging..."
- run_auto_recharge_gracefully(user)
- user.refresh_from_db()
+ run_auto_recharge_gracefully(workspace)
+ workspace.refresh_from_db()
- if user.balance >= price:
+ if workspace.balance >= price:
return
raise InsufficientCredits(self.request.user, sr)
@@ -2106,8 +2118,8 @@ def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]:
), "request.user must be set to deduct credits"
amount = self.get_price_roundoff(state)
- org, _ = self.request.user.get_or_create_personal_org()
- txn = org.add_balance(-amount, f"gooey_in_{uuid.uuid1()}")
+ workspace, _ = self.request.user.get_or_create_personal_workspace()
+ txn = workspace.add_balance(-amount, f"gooey_in_{uuid.uuid1()}")
return txn, amount
def get_price_roundoff(self, state: dict) -> int:
@@ -2204,7 +2216,7 @@ def get_cost_note(self) -> str | None:
@classmethod
def is_user_admin(cls, user: AppUser) -> bool:
email = user.email
- return email and email in settings.ADMIN_EMAILS
+ return bool(email and email in settings.ADMIN_EMAILS)
def is_current_user_admin(self) -> bool:
if not self.request or not self.request.user:
diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py
index 7c99a53bf..d8b3bd44b 100644
--- a/daras_ai_v2/billing.py
+++ b/daras_ai_v2/billing.py
@@ -13,41 +13,41 @@
from daras_ai_v2.user_date_widgets import render_local_date_attrs
from payments.models import PaymentMethodSummary
from payments.plans import PricingPlan
-from payments.webhooks import StripeWebhookHandler, set_org_subscription
+from payments.webhooks import StripeWebhookHandler, set_workspace_subscription
from scripts.migrate_existing_subscriptions import available_subscriptions
if typing.TYPE_CHECKING:
- from orgs.models import Org
+ from workspaces.models import Workspace
rounded_border = "w-100 border shadow-sm rounded py-4 px-3"
-def billing_page(org: "Org"):
+def billing_page(workspace: "Workspace"):
render_payments_setup()
- if org.subscription and org.subscription.is_paid():
- render_current_plan(org)
+ if workspace.subscription and workspace.subscription.is_paid():
+ render_current_plan(workspace)
with gui.div(className="my-5"):
- render_credit_balance(org)
+ render_credit_balance(workspace)
with gui.div(className="my-5"):
- selected_payment_provider = render_all_plans(org)
+ selected_payment_provider = render_all_plans(workspace)
with gui.div(className="my-5"):
- render_addon_section(org, selected_payment_provider)
+ render_addon_section(workspace, selected_payment_provider)
- if org.subscription:
- if org.subscription.payment_provider == PaymentProvider.STRIPE:
+ if workspace.subscription:
+ if workspace.subscription.payment_provider == PaymentProvider.STRIPE:
with gui.div(className="my-5"):
- render_auto_recharge_section(org)
+ render_auto_recharge_section(workspace)
with gui.div(className="my-5"):
- render_payment_information(org)
+ render_payment_information(workspace)
with gui.div(className="my-5"):
- render_billing_history(org)
+ render_billing_history(workspace)
def render_payments_setup():
@@ -61,10 +61,10 @@ def render_payments_setup():
)
-def render_current_plan(org: "Org"):
- plan = PricingPlan.from_sub(org.subscription)
- if org.subscription.payment_provider:
- provider = PaymentProvider(org.subscription.payment_provider)
+def render_current_plan(workspace: "Workspace"):
+ plan = PricingPlan.from_sub(workspace.subscription)
+ if workspace.subscription.payment_provider:
+ provider = PaymentProvider(workspace.subscription.payment_provider)
else:
provider = None
@@ -82,7 +82,7 @@ def render_current_plan(org: "Org"):
with right, gui.div(className="d-flex align-items-center gap-1"):
if provider and (
next_invoice_ts := gui.run_in_thread(
- org.subscription.get_next_invoice_timestamp, cache=True
+ workspace.subscription.get_next_invoice_timestamp, cache=True
)
):
gui.html("Next invoice on ")
@@ -118,17 +118,17 @@ def render_current_plan(org: "Org"):
)
-def render_credit_balance(org: "Org"):
- gui.write(f"## Credit Balance: {org.balance:,}")
+def render_credit_balance(workspace: "Workspace"):
+ gui.write(f"## Credit Balance: {workspace.balance:,}")
gui.caption(
"Every time you submit a workflow or make an API call, we deduct credits from your account."
)
-def render_all_plans(org: "Org") -> PaymentProvider:
+def render_all_plans(workspace: "Workspace") -> PaymentProvider:
current_plan = (
- PricingPlan.from_sub(org.subscription)
- if org.subscription
+ PricingPlan.from_sub(workspace.subscription)
+ if workspace.subscription
else PricingPlan.STARTER
)
all_plans = [plan for plan in PricingPlan if not plan.deprecated]
@@ -136,8 +136,8 @@ def render_all_plans(org: "Org") -> PaymentProvider:
gui.write("## All Plans")
plans_div = gui.div(className="mb-1")
- if org.subscription and org.subscription.payment_provider:
- selected_payment_provider = org.subscription.payment_provider
+ if workspace.subscription and workspace.subscription.payment_provider:
+ selected_payment_provider = workspace.subscription.payment_provider
else:
with gui.div():
selected_payment_provider = PaymentProvider[
@@ -155,7 +155,7 @@ def _render_plan(plan: PricingPlan):
):
_render_plan_details(plan)
_render_plan_action_button(
- org=org,
+ workspace=workspace,
plan=plan,
current_plan=current_plan,
payment_provider=selected_payment_provider,
@@ -193,7 +193,7 @@ def _render_plan_details(plan: PricingPlan):
def _render_plan_action_button(
- org: "Org",
+ workspace: "Workspace",
plan: PricingPlan,
current_plan: PricingPlan,
payment_provider: PaymentProvider | None,
@@ -207,10 +207,13 @@ def _render_plan_action_button(
className=btn_classes + " btn btn-theme btn-primary",
):
gui.html("Contact Us")
- elif org.subscription and org.subscription.plan == PricingPlan.ENTERPRISE.db_value:
+ elif (
+ workspace.subscription
+ and workspace.subscription.plan == PricingPlan.ENTERPRISE.db_value
+ ):
# don't show upgrade/downgrade buttons for enterprise customers
return
- elif org.subscription and org.subscription.is_paid():
+ elif workspace.subscription and workspace.subscription.is_paid():
# subscription exists, show upgrade/downgrade button
if plan.credits > current_plan.credits:
modal, confirmed = confirm_modal(
@@ -232,7 +235,7 @@ def _render_plan_action_button(
modal.open()
if confirmed:
change_subscription(
- org,
+ workspace,
plan,
# when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time
billing_cycle_anchor="now",
@@ -254,11 +257,11 @@ def _render_plan_action_button(
):
modal.open()
if confirmed:
- change_subscription(org, plan)
+ change_subscription(workspace, plan)
else:
assert payment_provider is not None # for sanity
_render_create_subscription_button(
- org=org,
+ workspace=workspace,
plan=plan,
payment_provider=payment_provider,
)
@@ -266,13 +269,13 @@ def _render_plan_action_button(
def _render_create_subscription_button(
*,
- org: "Org",
+ workspace: "Workspace",
plan: PricingPlan,
payment_provider: PaymentProvider,
):
match payment_provider:
case PaymentProvider.STRIPE:
- render_stripe_subscription_button(org=org, plan=plan)
+ render_stripe_subscription_button(workspace=workspace, plan=plan)
case PaymentProvider.PAYPAL:
render_paypal_subscription_button(plan=plan)
@@ -284,27 +287,29 @@ def fmt_price(plan: PricingPlan) -> str:
return "Free"
-def change_subscription(org: "Org", new_plan: PricingPlan, **kwargs):
+def change_subscription(workspace: "Workspace", new_plan: PricingPlan, **kwargs):
from routers.account import account_route
from routers.account import payment_processing_route
- current_plan = PricingPlan.from_sub(org.subscription)
+ current_plan = PricingPlan.from_sub(workspace.subscription)
if new_plan == current_plan:
raise gui.RedirectException(get_app_route_url(account_route), status_code=303)
if new_plan == PricingPlan.STARTER:
- org.subscription.cancel()
+ workspace.subscription.cancel()
raise gui.RedirectException(
get_app_route_url(payment_processing_route), status_code=303
)
- match org.subscription.payment_provider:
+ match workspace.subscription.payment_provider:
case PaymentProvider.STRIPE:
if not new_plan.supports_stripe():
gui.error(f"Stripe subscription not available for {new_plan}")
- subscription = stripe.Subscription.retrieve(org.subscription.external_id)
+ subscription = stripe.Subscription.retrieve(
+ workspace.subscription.external_id
+ )
stripe.Subscription.modify(
subscription.id,
items=[
@@ -325,7 +330,9 @@ def change_subscription(org: "Org", new_plan: PricingPlan, **kwargs):
if not new_plan.supports_paypal():
gui.error(f"Paypal subscription not available for {new_plan}")
- subscription = paypal.Subscription.retrieve(org.subscription.external_id)
+ subscription = paypal.Subscription.retrieve(
+ workspace.subscription.external_id
+ )
paypal_plan_info = new_plan.get_paypal_plan()
approval_url = subscription.update_plan(
plan_id=paypal_plan_info["plan_id"],
@@ -348,20 +355,22 @@ def payment_provider_radio(**props) -> str | None:
)
-def render_addon_section(org: "Org", selected_payment_provider: PaymentProvider):
- if org.subscription:
+def render_addon_section(
+ workspace: "Workspace", selected_payment_provider: PaymentProvider
+):
+ if workspace.subscription:
gui.write("# Purchase More Credits")
else:
gui.write("# Purchase Credits")
gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits")
- if org.subscription and org.subscription.payment_provider:
- provider = PaymentProvider(org.subscription.payment_provider)
+ if workspace.subscription and workspace.subscription.payment_provider:
+ provider = PaymentProvider(workspace.subscription.payment_provider)
else:
provider = selected_payment_provider
match provider:
case PaymentProvider.STRIPE:
- render_stripe_addon_buttons(org)
+ render_stripe_addon_buttons(workspace)
case PaymentProvider.PAYPAL:
render_paypal_addon_buttons()
@@ -385,8 +394,8 @@ def render_paypal_addon_buttons():
gui.div(id="paypal-result-message")
-def render_stripe_addon_buttons(org: "Org"):
- if not (org.subscription and org.subscription.payment_provider):
+def render_stripe_addon_buttons(workspace: "Workspace"):
+ if not (workspace.subscription and workspace.subscription.payment_provider):
save_pm = gui.checkbox(
"Save payment method for future purchases & auto-recharge", value=True
)
@@ -394,10 +403,10 @@ def render_stripe_addon_buttons(org: "Org"):
save_pm = True
for dollat_amt in settings.ADDON_AMOUNT_CHOICES:
- render_stripe_addon_button(dollat_amt, org, save_pm)
+ render_stripe_addon_button(dollat_amt, workspace, save_pm)
-def render_stripe_addon_button(dollat_amt: int, org: "Org", save_pm: bool):
+def render_stripe_addon_button(dollat_amt: int, workspace: "Workspace", save_pm: bool):
modal, confirmed = confirm_modal(
title="Purchase Credits",
key=f"--addon-modal-{dollat_amt}",
@@ -411,14 +420,17 @@ def render_stripe_addon_button(dollat_amt: int, org: "Org", save_pm: bool):
)
if gui.button(f"${dollat_amt:,}", type="primary"):
- if org.subscription and org.subscription.stripe_get_default_payment_method():
+ if (
+ workspace.subscription
+ and workspace.subscription.stripe_get_default_payment_method()
+ ):
modal.open()
else:
- stripe_addon_checkout_redirect(org, dollat_amt, save_pm)
+ stripe_addon_checkout_redirect(workspace, dollat_amt, save_pm)
if confirmed:
success = gui.run_in_thread(
- org.subscription.stripe_attempt_addon_purchase,
+ workspace.subscription.stripe_attempt_addon_purchase,
args=[dollat_amt],
placeholder="",
)
@@ -429,10 +441,12 @@ def render_stripe_addon_button(dollat_amt: int, org: "Org", save_pm: bool):
modal.close()
else:
# fallback to stripe checkout flow if the auto payment failed
- stripe_addon_checkout_redirect(org, dollat_amt, save_pm)
+ stripe_addon_checkout_redirect(workspace, dollat_amt, save_pm)
-def stripe_addon_checkout_redirect(org: "Org", dollat_amt: int, save_pm: bool):
+def stripe_addon_checkout_redirect(
+ workspace: "Workspace", dollat_amt: int, save_pm: bool
+):
from routers.account import account_route
from routers.account import payment_processing_route
@@ -448,7 +462,7 @@ def stripe_addon_checkout_redirect(org: "Org", dollat_amt: int, save_pm: bool):
mode="payment",
success_url=get_app_route_url(payment_processing_route),
cancel_url=get_app_route_url(account_route),
- customer=org.get_or_create_stripe_customer(),
+ customer=workspace.get_or_create_stripe_customer(),
invoice_creation={"enabled": True},
allow_promotion_codes=True,
**kwargs,
@@ -458,7 +472,7 @@ def stripe_addon_checkout_redirect(org: "Org", dollat_amt: int, save_pm: bool):
def render_stripe_subscription_button(
*,
- org: "Org",
+ workspace: "Workspace",
plan: PricingPlan,
):
if not plan.supports_stripe():
@@ -486,30 +500,33 @@ def render_stripe_subscription_button(
key=f"--change-sub-{plan.key}",
type="primary",
):
- if org.subscription and org.subscription.stripe_get_default_payment_method():
+ if (
+ workspace.subscription
+ and workspace.subscription.stripe_get_default_payment_method()
+ ):
modal.open()
else:
- stripe_subscription_create(org=org, plan=plan)
+ stripe_subscription_create(workspace=workspace, plan=plan)
if confirmed:
- stripe_subscription_create(org=org, plan=plan)
+ stripe_subscription_create(workspace=workspace, plan=plan)
-def stripe_subscription_create(org: "Org", plan: PricingPlan):
+def stripe_subscription_create(workspace: "Workspace", plan: PricingPlan):
from routers.account import account_route
from routers.account import payment_processing_route
- if org.subscription and org.subscription.is_paid():
+ if workspace.subscription and workspace.subscription.is_paid():
# sanity check: already subscribed to some plan
gui.rerun()
# check for existing subscriptions on stripe
- customer = org.get_or_create_stripe_customer()
+ customer = workspace.get_or_create_stripe_customer()
for sub in stripe.Subscription.list(
customer=customer, status="active", limit=1
).data:
StripeWebhookHandler.handle_subscription_updated(
- org_id=org.org_id, stripe_sub=sub
+ workspace_id_or_uid=workspace.id, stripe_sub=sub
)
raise gui.RedirectException(
get_app_route_url(payment_processing_route), status_code=303
@@ -517,7 +534,10 @@ def stripe_subscription_create(org: "Org", plan: PricingPlan):
# try to directly create the subscription without checkout
metadata = {settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: plan.key}
- pm = org.subscription and org.subscription.stripe_get_default_payment_method()
+ pm = (
+ workspace.subscription
+ and workspace.subscription.stripe_get_default_payment_method()
+ )
line_items = [plan.get_stripe_line_item()]
if pm:
sub = stripe.Subscription.create(
@@ -567,12 +587,12 @@ def render_paypal_subscription_button(
)
-def render_payment_information(org: "Org"):
- if not org.subscription:
+def render_payment_information(workspace: "Workspace"):
+ if not workspace.subscription:
return
pm_summary = gui.run_in_thread(
- org.subscription.get_payment_method_summary, cache=True
+ workspace.subscription.get_payment_method_summary, cache=True
)
if not pm_summary:
return
@@ -584,7 +604,7 @@ def render_payment_information(org: "Org"):
gui.write("**Pay via**")
with col2:
provider = PaymentProvider(
- org.subscription.payment_provider or PaymentProvider.STRIPE
+ workspace.subscription.payment_provider or PaymentProvider.STRIPE
)
gui.write(provider.label)
with col3:
@@ -592,7 +612,7 @@ def render_payment_information(org: "Org"):
f"{icons.edit} Edit", type="link", key="manage-payment-provider"
):
raise gui.RedirectException(
- org.subscription.get_external_management_url()
+ workspace.subscription.get_external_management_url()
)
pm_summary = PaymentMethodSummary(*pm_summary)
@@ -612,7 +632,7 @@ def render_payment_information(org: "Org"):
if gui.button(
f"{icons.edit} Edit", type="link", key="edit-payment-method"
):
- change_payment_method(org)
+ change_payment_method(workspace)
if pm_summary.billing_email:
col1, col2, _ = gui.columns(3, responsive=False)
@@ -640,13 +660,16 @@ def render_payment_information(org: "Org"):
):
modal.open()
if confirmed:
- set_org_subscription(
- org_id=org.org_id,
+ set_workspace_subscription(
+ workspace_id_or_uid=workspace.id,
plan=PricingPlan.STARTER,
provider=None,
external_id=None,
)
- pm = org.subscription and org.subscription.stripe_get_default_payment_method()
+ pm = (
+ workspace.subscription
+ and workspace.subscription.stripe_get_default_payment_method()
+ )
if pm:
pm.detach()
raise gui.RedirectException(
@@ -654,18 +677,18 @@ def render_payment_information(org: "Org"):
)
-def change_payment_method(org: "Org"):
+def change_payment_method(workspace: "Workspace"):
from routers.account import payment_processing_route
from routers.account import account_route
- match org.subscription.payment_provider:
+ match workspace.subscription.payment_provider:
case PaymentProvider.STRIPE:
session = stripe.checkout.Session.create(
mode="setup",
currency="usd",
- customer=org.get_or_create_stripe_customer(),
+ customer=workspace.get_or_create_stripe_customer(),
setup_intent_data={
- "metadata": {"subscription_id": org.subscription.external_id},
+ "metadata": {"subscription_id": workspace.subscription.external_id},
},
success_url=get_app_route_url(payment_processing_route),
cancel_url=get_app_route_url(account_route),
@@ -679,11 +702,11 @@ def format_card_brand(brand: str) -> str:
return icons.card_icons.get(brand.lower(), brand.capitalize())
-def render_billing_history(org: "Org", limit: int = 50):
+def render_billing_history(workspace: "Workspace", limit: int = 50):
import pandas as pd
txns = AppUserTransaction.objects.filter(
- org=org,
+ workspace=workspace,
amount__gt=0,
).order_by("-created_at")
if not txns:
@@ -708,9 +731,9 @@ def render_billing_history(org: "Org", limit: int = 50):
gui.caption(f"Showing only the most recent {limit} transactions.")
-def render_auto_recharge_section(org: "Org"):
- assert org.subscription
- subscription = org.subscription
+def render_auto_recharge_section(workspace: "Workspace"):
+ assert workspace.subscription
+ subscription = workspace.subscription
gui.write("## Auto Recharge & Limits")
with gui.div(className="h4"):
diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py
index 2262624e7..11799f86a 100644
--- a/daras_ai_v2/send_email.py
+++ b/daras_ai_v2/send_email.py
@@ -11,6 +11,7 @@
if typing.TYPE_CHECKING:
from app_users.models import AppUser
+ from workspaces.models import Workspace
def send_reported_run_email(
@@ -44,25 +45,26 @@ def send_reported_run_email(
def send_low_balance_email(
*,
- user: "AppUser",
+ workspace: "Workspace",
total_credits_consumed: int,
):
from routers.account import account_route
recipeints = "support@gooey.ai, devs@gooey.ai"
- html_body = templates.get_template("low_balance_email.html").render(
- user=user,
- url=get_app_route_url(account_route),
- total_credits_consumed=total_credits_consumed,
- settings=settings,
- )
- send_email_via_postmark(
- from_address=settings.SUPPORT_EMAIL,
- to_address=user.email or recipeints,
- bcc=recipeints,
- subject="Your Gooey.AI credit balance is low",
- html_body=html_body,
- )
+ for owner in workspace.get_owners():
+ html_body = templates.get_template("low_balance_email.html").render(
+ user=owner.user,
+ url=get_app_route_url(account_route),
+ total_credits_consumed=total_credits_consumed,
+ settings=settings,
+ )
+ send_email_via_postmark(
+ from_address=settings.SUPPORT_EMAIL,
+ to_address=owner.user.email or recipeints,
+ bcc=recipeints,
+ subject="Your Gooey.AI credit balance is low",
+ html_body=html_body,
+ )
is_running_pytest = "pytest" in sys.modules
diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py
index c73de7223..9bd1d419d 100644
--- a/daras_ai_v2/settings.py
+++ b/daras_ai_v2/settings.py
@@ -63,7 +63,7 @@
"handles",
"payments",
"functions",
- "orgs",
+ "workspaces",
]
MIDDLEWARE = [
@@ -288,7 +288,7 @@
EMAIL_USER_FREE_CREDITS = config("EMAIL_USER_FREE_CREDITS", 0, cast=int)
ANON_USER_FREE_CREDITS = config("ANON_USER_FREE_CREDITS", 25, cast=int)
LOGIN_USER_FREE_CREDITS = config("LOGIN_USER_FREE_CREDITS", 500, cast=int)
-FIRST_ORG_FREE_CREDITS = config("ORG_FREE_CREDITS", 500, cast=int)
+FIRST_WORKSPACE_FREE_CREDITS = config("WORKSPACE_FREE_CREDITS", 500, cast=int)
ADDON_CREDITS_PER_DOLLAR = config("ADDON_CREDITS_PER_DOLLAR", 100, cast=int)
ADDON_AMOUNT_CHOICES = [10, 30, 50, 100, 300, 500, 1000] # USD
@@ -399,9 +399,11 @@
TWILIO_API_KEY_SID = config("TWILIO_API_KEY_SID", "")
TWILIO_API_KEY_SECRET = config("TWILIO_API_KEY_SECRET", "")
-ORG_INVITATION_EXPIRY_DAYS = config("ORG_INVITATIONS_EXPIRY_IN_DAYS", 10, cast=int)
-ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL = config(
- "ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL", 60 * 60 * 24, cast=int # 24 hours
+WORKSPACE_INVITATION_EXPIRY_DAYS = config(
+ "WORKSPACE_INVITATIONS_EXPIRY_IN_DAYS", 10, cast=int
+)
+WORKSPACE_INVITATION_EMAIL_COOLDOWN_INTERVAL = config(
+ "WORKSPACE_INVITATION_EMAIL_COOLDOWN_INTERVAL", 60 * 60 * 24, cast=int # 24 hours
)
SCRAPING_PROXY_HOST = config("SCRAPING_PROXY_HOST", "")
diff --git a/orgs/admin.py b/orgs/admin.py
deleted file mode 100644
index 370ca4c4e..000000000
--- a/orgs/admin.py
+++ /dev/null
@@ -1,111 +0,0 @@
-from django.contrib import admin
-from safedelete.admin import SafeDeleteAdmin, SafeDeleteAdminFilter
-
-from bots.admin_links import change_obj_url
-from orgs.models import Org, OrgMembership, OrgInvitation
-
-
-class OrgMembershipInline(admin.TabularInline):
- model = OrgMembership
- extra = 0
- show_change_link = True
- fields = ["user", "role", "created_at", "updated_at"]
- readonly_fields = ["created_at", "updated_at"]
- ordering = ["-created_at"]
- can_delete = False
- show_change_link = True
-
-
-class OrgInvitationInline(admin.TabularInline):
- model = OrgInvitation
- extra = 0
- show_change_link = True
- fields = [
- "invitee_email",
- "inviter",
- "status",
- "auto_accepted",
- "created_at",
- "updated_at",
- ]
- readonly_fields = ["auto_accepted", "created_at", "updated_at"]
- ordering = ["status", "-created_at"]
- can_delete = False
- show_change_link = True
-
-
-@admin.register(Org)
-class OrgAdmin(SafeDeleteAdmin):
- list_display = [
- "name",
- "domain_name",
- "created_at",
- "updated_at",
- ] + list(SafeDeleteAdmin.list_display)
- list_filter = [SafeDeleteAdminFilter] + list(SafeDeleteAdmin.list_filter)
- fields = [
- "name",
- "domain_name",
- "created_by",
- "is_personal",
- "created_at",
- "updated_at",
- ]
- search_fields = ["name", "domain_name"]
- readonly_fields = ["is_personal", "created_at", "updated_at"]
- inlines = [OrgMembershipInline, OrgInvitationInline]
- ordering = ["-created_at"]
-
-
-@admin.register(OrgMembership)
-class OrgMembershipAdmin(SafeDeleteAdmin):
- list_display = [
- "user",
- "org",
- "role",
- "created_at",
- "updated_at",
- ] + list(SafeDeleteAdmin.list_display)
- list_filter = ["org", "role", SafeDeleteAdminFilter] + list(
- SafeDeleteAdmin.list_filter
- )
-
- def get_readonly_fields(
- self, request: "HttpRequest", obj: OrgMembership | None = None
- ) -> list[str]:
- readonly_fields = list(super().get_readonly_fields(request, obj))
- if obj and obj.org and obj.org.deleted:
- return readonly_fields + ["deleted_org"]
- else:
- return readonly_fields
-
- @admin.display
- def deleted_org(self, obj):
- org = Org.deleted_objects.get(pk=obj.org_id)
- return change_obj_url(org)
-
-
-@admin.register(OrgInvitation)
-class OrgInvitationAdmin(SafeDeleteAdmin):
- fields = [
- "org",
- "invitee_email",
- "inviter",
- "role",
- "status",
- "auto_accepted",
- "created_at",
- "updated_at",
- ]
- list_display = [
- "org",
- "invitee_email",
- "inviter",
- "status",
- "created_at",
- "updated_at",
- ] + list(SafeDeleteAdmin.list_display)
- list_filter = ["org", "inviter", "role", SafeDeleteAdminFilter] + list(
- SafeDeleteAdmin.list_filter
- )
- readonly_fields = ["auto_accepted"]
diff --git a/orgs/migrations/0002_alter_org_unique_together_and_more.py b/orgs/migrations/0002_alter_org_unique_together_and_more.py
deleted file mode 100644
index 2c5384d67..000000000
--- a/orgs/migrations/0002_alter_org_unique_together_and_more.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Generated by Django 4.2.7 on 2024-07-22 14:45
-
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
- dependencies = [
- ('orgs', '0001_initial'),
- ]
-
- operations = [
- migrations.AlterUniqueTogether(
- name='org',
- unique_together=set(),
- ),
- migrations.AlterField(
- model_name='orginvitation',
- name='last_email_sent_at',
- field=models.DateTimeField(blank=True, default=None, null=True),
- ),
- migrations.AlterField(
- model_name='orginvitation',
- name='status_changed_at',
- field=models.DateTimeField(blank=True, default=None, null=True),
- ),
- migrations.AddConstraint(
- model_name='org',
- constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('domain_name',), name='unique_domain_name_when_not_deleted'),
- ),
- migrations.RemoveField(
- model_name='org',
- name='members',
- ),
- ]
diff --git a/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py b/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py
deleted file mode 100644
index 6047919f1..000000000
--- a/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Generated by Django 4.2.7 on 2024-07-23 11:45
-
-from django.db import migrations, models
-import django.db.models.deletion
-
-
-class Migration(migrations.Migration):
-
- dependencies = [
- ('app_users', '0019_alter_appusertransaction_reason'),
- ('orgs', '0002_alter_org_unique_together_and_more'),
- ]
-
- operations = [
- migrations.RemoveConstraint(
- model_name='org',
- name='unique_domain_name_when_not_deleted',
- ),
- migrations.AlterUniqueTogether(
- name='orgmembership',
- unique_together=set(),
- ),
- migrations.AlterField(
- model_name='orginvitation',
- name='status_changed_by',
- field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='received_invitations', to='app_users.appuser'),
- ),
- migrations.AddConstraint(
- model_name='org',
- constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('domain_name',), name='unique_domain_name_when_not_deleted', violation_error_message='This domain name is already in use by another team. Contact Gooey.AI Support if you think this is a mistake.'),
- ),
- migrations.AddConstraint(
- model_name='orgmembership',
- constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('org', 'user'), name='unique_org_user'),
- ),
- ]
diff --git a/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py b/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py
deleted file mode 100644
index 9d9fdfc5d..000000000
--- a/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Generated by Django 4.2.7 on 2024-08-12 14:23
-
-from django.db import migrations, models
-import django.db.models.deletion
-
-
-class Migration(migrations.Migration):
-
- dependencies = [
- ('payments', '0005_alter_subscription_plan'),
- ('orgs', '0003_remove_org_unique_domain_name_when_not_deleted_and_more'),
- ]
-
- operations = [
- migrations.AddField(
- model_name='org',
- name='balance',
- field=models.IntegerField(default=0, verbose_name='bal'),
- ),
- migrations.AddField(
- model_name='org',
- name='is_paying',
- field=models.BooleanField(default=False, verbose_name='paid'),
- ),
- migrations.AddField(
- model_name='org',
- name='is_personal',
- field=models.BooleanField(default=False),
- ),
- migrations.AddField(
- model_name='org',
- name='low_balance_email_sent_at',
- field=models.DateTimeField(blank=True, null=True),
- ),
- migrations.AddField(
- model_name='org',
- name='stripe_customer_id',
- field=models.CharField(blank=True, default='', max_length=255),
- ),
- migrations.AddField(
- model_name='org',
- name='subscription',
- field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='org', to='payments.subscription'),
- ),
- ]
diff --git a/orgs/migrations/0005_org_unique_personal_org_per_user.py b/orgs/migrations/0005_org_unique_personal_org_per_user.py
deleted file mode 100644
index aaaa1cc4d..000000000
--- a/orgs/migrations/0005_org_unique_personal_org_per_user.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Generated by Django 4.2.7 on 2024-08-13 14:34
-
-from django.db import migrations, models
-
-
-class Migration(migrations.Migration):
-
- dependencies = [
- ('orgs', '0004_org_balance_org_is_paying_org_is_personal_and_more'),
- ]
-
- operations = [
- migrations.AddConstraint(
- model_name='org',
- constraint=models.UniqueConstraint(models.F('created_by'), condition=models.Q(('deleted__isnull', True), ('is_personal', True)), name='unique_personal_org_per_user'),
- ),
- ]
diff --git a/orgs/signals.py b/orgs/signals.py
deleted file mode 100644
index bb23b7e06..000000000
--- a/orgs/signals.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from django.db.models.signals import post_save
-from django.dispatch import receiver
-from loguru import logger
-from safedelete.signals import post_softdelete
-
-from app_users.models import AppUser
-from orgs.models import Org, OrgMembership, OrgRole
-from orgs.tasks import send_auto_accepted_email
-
-
-@receiver(post_save, sender=AppUser)
-def add_user_existing_org(instance: AppUser, **kwargs):
- """
- if the domain name matches
- """
- if not instance.email:
- return
-
- email_domain = instance.email.split("@")[1]
- org = Org.objects.filter(domain_name=email_domain).first()
- if not org:
- return
-
- if instance.received_invitations.exists():
- # user has some existing invitations
- return
-
- org_owner = org.memberships.filter(role=OrgRole.OWNER).first()
- if not org_owner:
- logger.warning(
- f"Org {org} has no owner. Skipping auto-accept for user {instance}"
- )
- return
-
- invitation = org.invite_user(
- invitee_email=instance.email,
- inviter=org_owner.user,
- role=OrgRole.MEMBER,
- auto_accept=not instance.org_memberships.exists(), # auto-accept only if user has no existing memberships
- )
-
-
-@receiver(post_softdelete, sender=OrgMembership)
-def delete_org_if_no_members_left(instance: OrgMembership, **kwargs):
- if instance.org.memberships.exists():
- return
-
- logger.info(f"Deleting org {instance.org} because it has no members left")
- instance.org.delete()
diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py
index 3d07493b5..bc7934311 100644
--- a/payments/auto_recharge.py
+++ b/payments/auto_recharge.py
@@ -5,7 +5,7 @@
from app_users.models import PaymentProvider
from daras_ai_v2.redis_cache import redis_lock
-from orgs.models import Org
+from workspaces.models import Workspace
from payments.tasks import send_monthly_budget_reached_email
@@ -28,18 +28,18 @@ class AutoRechargeCooldownException(AutoRechargeException):
pass
-def should_attempt_auto_recharge(org: Org):
- return (
- org.subscription
- and org.subscription.auto_recharge_enabled
- and org.subscription.payment_provider
- and org.balance < org.subscription.auto_recharge_balance_threshold
+def should_attempt_auto_recharge(workspace: Workspace) -> bool:
+ return bool(
+ workspace.subscription
+ and workspace.subscription.auto_recharge_enabled
+ and workspace.subscription.payment_provider
+ and workspace.balance < workspace.subscription.auto_recharge_balance_threshold
)
-def run_auto_recharge_gracefully(org: Org):
+def run_auto_recharge_gracefully(workspace: Workspace):
"""
- Wrapper over _auto_recharge_org, that handles exceptions so that it can:
+ Wrapper over _auto_recharge_workspace, that handles exceptions so that it can:
- log exceptions
- send emails when auto-recharge fails
- not retry if this is run as a background task
@@ -47,49 +47,49 @@ def run_auto_recharge_gracefully(org: Org):
Meant to be used in conjunction with should_attempt_auto_recharge
"""
try:
- with redis_lock(f"gooey/auto_recharge_user/v1/{org.org_id}"):
- _auto_recharge_org(org)
+ with redis_lock(f"gooey/auto_recharge_user/v1/{workspace.id}"):
+ _auto_recharge_workspace(workspace)
except AutoRechargeCooldownException as e:
logger.info(
- f"Rejected auto-recharge because auto-recharge is in cooldown period for org"
- f"{org=}, {e=}"
+ f"Rejected auto-recharge because auto-recharge is in cooldown period for workspace"
+ f"{workspace=}, {e=}"
)
except MonthlyBudgetReachedException as e:
- send_monthly_budget_reached_email(org)
+ send_monthly_budget_reached_email(workspace)
logger.info(
f"Rejected auto-recharge because user has reached monthly budget"
- f"{org=}, spending=${e.spending}, budget=${e.budget}"
+ f"{workspace=}, spending=${e.spending}, budget=${e.budget}"
)
except Exception as e:
traceback.print_exc()
sentry_sdk.capture_exception(e)
-def _auto_recharge_org(org: Org):
+def _auto_recharge_workspace(workspace: Workspace):
"""
Returns whether a charge was attempted
"""
from payments.webhooks import StripeWebhookHandler
assert (
- org.subscription.payment_provider == PaymentProvider.STRIPE
+ workspace.subscription.payment_provider == PaymentProvider.STRIPE
), "Auto recharge is only supported with Stripe"
# check for monthly budget
- dollars_spent = org.get_dollars_spent_this_month()
+ dollars_spent = workspace.get_dollars_spent_this_month()
if (
- dollars_spent + org.subscription.auto_recharge_topup_amount
- > org.subscription.monthly_spending_budget
+ dollars_spent + workspace.subscription.auto_recharge_topup_amount
+ > workspace.subscription.monthly_spending_budget
):
raise MonthlyBudgetReachedException(
"Performing this top-up would exceed your monthly recharge budget",
- budget=org.subscription.monthly_spending_budget,
+ budget=workspace.subscription.monthly_spending_budget,
spending=dollars_spent,
)
try:
- invoice = org.subscription.stripe_get_or_create_auto_invoice(
- amount_in_dollars=org.subscription.auto_recharge_topup_amount,
+ invoice = workspace.subscription.stripe_get_or_create_auto_invoice(
+ amount_in_dollars=workspace.subscription.auto_recharge_topup_amount,
metadata_key="auto_recharge",
)
except Exception as e:
@@ -103,9 +103,11 @@ def _auto_recharge_org(org: Org):
# get default payment method and attempt payment
assert invoice.status == "open" # sanity check
- pm = org.subscription.stripe_get_default_payment_method()
+ pm = workspace.subscription.stripe_get_default_payment_method()
if not pm:
- logger.warning(f"{org} has no default payment method, cannot auto-recharge")
+ logger.warning(
+ f"{workspace} has no default payment method, cannot auto-recharge"
+ )
return
try:
@@ -117,5 +119,5 @@ def _auto_recharge_org(org: Org):
else:
assert invoice_data.paid
StripeWebhookHandler.handle_invoice_paid(
- org_id=org.org_id, invoice=invoice_data
+ workspace_id_or_uid=workspace.id, invoice=invoice_data
)
diff --git a/payments/models.py b/payments/models.py
index ff5be4f69..cebfeda70 100644
--- a/payments/models.py
+++ b/payments/models.py
@@ -82,8 +82,8 @@ def __str__(self):
ret = f"{self.get_plan_display()} | {self.get_payment_provider_display()}"
# if self.has_user:
# ret = f"{ret} | {self.user}"
- if self.has_org:
- ret = f"{ret} | {self.org}"
+ if self.has_workspace:
+ ret = f"{ret} | {self.workspace}"
if self.auto_recharge_enabled:
ret = f"Auto | {ret}"
return ret
@@ -138,10 +138,10 @@ def is_paid(self) -> bool:
return PricingPlan.from_sub(self).monthly_charge > 0 and self.external_id
@property
- def has_org(self) -> bool:
+ def has_workspace(self) -> bool:
try:
- self.org
- except Subscription.org.RelatedObjectDoesNotExist:
+ self.workspace
+ except Subscription.workspace.RelatedObjectDoesNotExist:
return False
else:
return True
@@ -376,12 +376,12 @@ def has_sent_monthly_budget_email_this_month(self) -> bool:
)
def should_send_monthly_spending_notification(self) -> bool:
- assert self.has_org
+ assert self.has_workspace
return bool(
self.monthly_spending_notification_threshold
and not self.has_sent_monthly_spending_notification_this_month()
- and self.org.get_dollars_spent_this_month()
+ and self.workspace.get_dollars_spent_this_month()
>= self.monthly_spending_notification_threshold
)
diff --git a/payments/tasks.py b/payments/tasks.py
index c98b8c12e..d84f1c748 100644
--- a/payments/tasks.py
+++ b/payments/tasks.py
@@ -1,8 +1,7 @@
from django.utils import timezone
from loguru import logger
-from app_users.models import AppUser
-from orgs.models import Org
+from workspaces.models import Workspace
from celeryapp import app
from daras_ai_v2 import settings
from daras_ai_v2.fastapi_tricks import get_app_route_url
@@ -14,11 +13,11 @@
def send_monthly_spending_notification_email(id: int):
from routers.account import account_route
- org = Org.objects.get(id=id)
- threshold = org.subscription.monthly_spending_notification_threshold
- for owner in org.get_owners():
+ workspace = Workspace.objects.get(id=id)
+ threshold = workspace.subscription.monthly_spending_notification_threshold
+ for owner in workspace.get_owners():
if not owner.user.email:
- logger.error(f"Org Owner doesn't have an email: {owner=}")
+ logger.error(f"Workspace Owner doesn't have an email: {owner=}")
return
send_email_via_postmark(
@@ -29,7 +28,7 @@ def send_monthly_spending_notification_email(id: int):
"monthly_spending_notification_threshold_email.html"
).render(
user=owner.user,
- org=org,
+ workspace=workspace,
account_url=get_app_route_url(account_route),
),
)
@@ -37,20 +36,22 @@ def send_monthly_spending_notification_email(id: int):
# IMPORTANT: always use update_fields=... / select_for_update when updating
# subscription info. We don't want to overwrite other changes made to
# subscription during the same time
- org.subscription.monthly_spending_notification_sent_at = timezone.now()
- org.subscription.save(update_fields=["monthly_spending_notification_sent_at"])
+ workspace.subscription.monthly_spending_notification_sent_at = timezone.now()
+ workspace.subscription.save(
+ update_fields=["monthly_spending_notification_sent_at"]
+ )
-def send_monthly_budget_reached_email(org: Org):
+def send_monthly_budget_reached_email(workspace: Workspace):
from routers.account import account_route
- for owner in org.get_owners():
+ for owner in workspace.get_owners():
if not owner.user.email:
continue
email_body = templates.get_template("monthly_budget_reached_email.html").render(
user=owner.user,
- org=org,
+ workspace=workspace,
account_url=get_app_route_url(account_route),
)
send_email_via_postmark(
@@ -63,5 +64,5 @@ def send_monthly_budget_reached_email(org: Org):
# IMPORTANT: always use update_fields=... when updating subscription
# info. We don't want to overwrite other changes made to subscription
# during the same time
- org.subscription.monthly_budget_email_sent_at = timezone.now()
- org.subscription.save(update_fields=["monthly_budget_email_sent_at"])
+ workspace.subscription.monthly_budget_email_sent_at = timezone.now()
+ workspace.subscription.save(update_fields=["monthly_budget_email_sent_at"])
diff --git a/payments/webhooks.py b/payments/webhooks.py
index 36f0499c7..3d7c3f202 100644
--- a/payments/webhooks.py
+++ b/payments/webhooks.py
@@ -6,7 +6,7 @@
from app_users.models import PaymentProvider, TransactionReason
from daras_ai_v2 import paypal
-from orgs.models import Org
+from workspaces.models import Workspace
from .models import Subscription
from .plans import PricingPlan
from .tasks import send_monthly_spending_notification_email
@@ -22,7 +22,7 @@ def handle_sale_completed(cls, sale: paypal.Sale):
return
pp_sub = paypal.Subscription.retrieve(sale.billing_agreement_id)
- assert pp_sub.custom_id, "pp_sub is missing org_id"
+ assert pp_sub.custom_id, "pp_sub is missing workspace_id"
assert pp_sub.plan_id, "pp_sub is missing plan ID"
plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id)
@@ -35,9 +35,8 @@ def handle_sale_completed(cls, sale: paypal.Sale):
f"paypal: charged amount ${charged_dollars} does not match plan's monthly charge ${plan.monthly_charge}"
)
- org_id = pp_sub.custom_id
add_balance_for_payment(
- org_id=org_id,
+ workspace_id_or_uid=pp_sub.custom_id,
amount=plan.credits,
invoice_id=sale.id,
payment_provider=cls.PROVIDER,
@@ -50,7 +49,9 @@ def handle_sale_completed(cls, sale: paypal.Sale):
def handle_subscription_updated(cls, pp_sub: paypal.Subscription):
logger.info(f"Paypal subscription updated {pp_sub.id}")
- assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing org_id"
+ assert (
+ pp_sub.custom_id
+ ), f"PayPal subscription {pp_sub.id} is missing workspace_id"
assert pp_sub.plan_id, f"PayPal subscription {pp_sub.id} is missing plan ID"
plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id)
@@ -62,8 +63,8 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription):
)
return
- set_org_subscription(
- org_id=pp_sub.custom_id,
+ set_workspace_subscription(
+ workspace_id_or_uid=pp_sub.custom_id,
plan=plan,
provider=cls.PROVIDER,
external_id=pp_sub.id,
@@ -72,8 +73,8 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription):
@classmethod
def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription):
assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid"
- set_org_subscription(
- org_id=pp_sub.custom_id,
+ set_workspace_subscription(
+ workspace_id_or_uid=pp_sub.custom_id,
plan=PricingPlan.STARTER,
provider=None,
external_id=None,
@@ -84,7 +85,9 @@ class StripeWebhookHandler:
PROVIDER = PaymentProvider.STRIPE
@classmethod
- def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice):
+ def handle_invoice_paid(
+ cls, workspace_id_or_uid: str | int, invoice: stripe.Invoice
+ ):
from app_users.tasks import save_stripe_default_payment_method
kwargs = {}
@@ -109,7 +112,7 @@ def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice):
amount = invoice.lines.data[0].quantity
charged_amount = invoice.lines.data[0].amount
add_balance_for_payment(
- org_id=org_id,
+ workspace_id_or_uid=workspace_id_or_uid,
amount=amount,
invoice_id=invoice.id,
payment_provider=cls.PROVIDER,
@@ -119,15 +122,15 @@ def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice):
)
save_stripe_default_payment_method.delay(
+ workspace_id_or_uid=workspace_id_or_uid,
payment_intent_id=invoice.payment_intent,
- org_id=org_id,
amount=amount,
charged_amount=charged_amount,
reason=reason,
)
@classmethod
- def handle_checkout_session_completed(cls, org_id: str, session_data):
+ def handle_checkout_session_completed(cls, workspace_id_or_uid: str, session_data):
setup_intent_id = session_data.get("setup_intent")
if not setup_intent_id:
# not a setup mode checkout -- do nothing
@@ -149,7 +152,9 @@ def handle_checkout_session_completed(cls, org_id: str, session_data):
)
@classmethod
- def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscription):
+ def handle_subscription_updated(
+ cls, workspace_id_or_uid: int | str, stripe_sub: stripe.Subscription
+ ):
logger.info(f"Stripe subscription updated: {stripe_sub.id}")
assert stripe_sub.plan, f"Stripe subscription {stripe_sub.id} is missing plan"
@@ -170,17 +175,17 @@ def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscriptio
)
return
- set_org_subscription(
- org_id=org_id,
+ set_workspace_subscription(
+ workspace_id_or_uid=workspace_id_or_uid,
plan=plan,
provider=cls.PROVIDER,
external_id=stripe_sub.id,
)
@classmethod
- def handle_subscription_cancelled(cls, org_id: str):
- set_org_subscription(
- org_id=org_id,
+ def handle_subscription_cancelled(cls, workspace_id_or_uid: int | str):
+ set_workspace_subscription(
+ workspace_id_or_uid=workspace_id_or_uid,
plan=PricingPlan.STARTER,
provider=PaymentProvider.STRIPE,
external_id=None,
@@ -189,15 +194,19 @@ def handle_subscription_cancelled(cls, org_id: str):
def add_balance_for_payment(
*,
- org_id: str,
+ workspace_id_or_uid: int | str,
amount: int,
invoice_id: str,
payment_provider: PaymentProvider,
charged_amount: int,
**kwargs,
):
- org = Org.objects.get_or_create_from_org_id(org_id)[0]
- org.add_balance(
+ try:
+ workspace = Workspace.objects.get(id=int(workspace_id_or_uid))
+ except (ValueError, Workspace.DoesNotExist):
+ workspace, _ = Workspace.objects.get_or_create_from_uid(workspace_id_or_uid)
+
+ workspace.add_balance(
amount=amount,
invoice_id=invoice_id,
charged_amount=charged_amount,
@@ -205,30 +214,33 @@ def add_balance_for_payment(
**kwargs,
)
- if not org.is_paying:
- org.is_paying = True
- org.save(update_fields=["is_paying"])
+ if not workspace.is_paying:
+ workspace.is_paying = True
+ workspace.save(update_fields=["is_paying"])
if (
- org.subscription
- and org.subscription.should_send_monthly_spending_notification()
+ workspace.subscription
+ and workspace.subscription.should_send_monthly_spending_notification()
):
- send_monthly_spending_notification_email.delay(org.id)
+ send_monthly_spending_notification_email.delay(workspace.id)
-def set_org_subscription(
+def set_workspace_subscription(
*,
- org_id: str,
+ workspace_id_or_uid: int | str,
plan: PricingPlan,
provider: PaymentProvider | None,
external_id: str | None,
amount: int | None = None,
charged_amount: int | None = None,
) -> Subscription:
- with transaction.atomic():
- org = Org.objects.get_or_create_from_org_id(org_id)[0]
+ try:
+ workspace = Workspace.objects.get(id=int(workspace_id_or_uid))
+ except (ValueError, Workspace.DoesNotExist):
+ workspace, _ = Workspace.objects.get_or_create_from_uid(workspace_id_or_uid)
- old_sub = org.subscription
+ with transaction.atomic():
+ old_sub = workspace.subscription
if old_sub:
new_sub = copy(old_sub)
else:
@@ -242,8 +254,8 @@ def set_org_subscription(
new_sub.save()
if not old_sub:
- org.subscription = new_sub
- org.save(update_fields=["subscription"])
+ workspace.subscription = new_sub
+ workspace.save(update_fields=["subscription"])
# cancel previous subscription if it's not the same as the new one
if old_sub and old_sub.external_id != external_id:
diff --git a/routers/account.py b/routers/account.py
index f9194589b..b2612d55c 100644
--- a/routers/account.py
+++ b/routers/account.py
@@ -18,10 +18,10 @@
from daras_ai_v2.manage_api_keys_widget import manage_api_keys
from daras_ai_v2.meta_content import raw_build_meta_tags
from daras_ai_v2.profiles import edit_user_profile_page
-from orgs.models import OrgInvitation
+from workspaces.models import WorkspaceInvitation
from payments.webhooks import PaypalWebhookHandler
from routers.root import page_wrapper, get_og_url_path
-from orgs.views import invitation_page, orgs_page
+from workspaces.views import invitation_page, workspaces_page
from routers.custom_api_router import CustomAPIRouter
@@ -142,10 +142,10 @@ def api_keys_route(request: Request):
)
-@gui.route(app, "/orgs/")
-def orgs_route(request: Request):
- with account_page_wrapper(request, AccountTabs.orgs):
- orgs_tab(request)
+@gui.route(app, "/workspaces/")
+def workspaces_route(request: Request):
+ with account_page_wrapper(request, AccountTabs.workspaces):
+ workspaces_tab(request)
url = get_og_url_path(request)
return dict(
@@ -159,8 +159,8 @@ def orgs_route(request: Request):
)
-@gui.route(app, "/invitation/{org_slug}/{invite_id}/")
-def invitation_route(request: Request, org_slug: str, invite_id: str):
+@gui.route(app, "/invitation/{workspace_slug}/{invite_id}/")
+def invitation_route(request: Request, workspace_slug: str, invite_id: str):
from routers.root import login
if not request.user or request.user.is_anonymous:
@@ -169,8 +169,8 @@ def invitation_route(request: Request, org_slug: str, invite_id: str):
raise RedirectException(redirect_url)
try:
- invitation = OrgInvitation.objects.get(invite_id=invite_id)
- except OrgInvitation.DoesNotExist:
+ invitation = WorkspaceInvitation.objects.get(invite_id=invite_id)
+ except WorkspaceInvitation.DoesNotExist:
return Response(status_code=404)
with page_wrapper(request):
@@ -178,8 +178,8 @@ def invitation_route(request: Request, org_slug: str, invite_id: str):
return dict(
meta=raw_build_meta_tags(
url=str(request.url),
- title=f"Join {invitation.org.name} • Gooey.AI",
- description=f"Invitation to join {invitation.org.name}",
+ title=f"Join {invitation.workspace.name} • Gooey.AI",
+ description=f"Invitation to join {invitation.workspace.name}",
robots="noindex,nofollow",
)
)
@@ -195,7 +195,7 @@ class AccountTabs(TabData, Enum):
profile = TabData(title=f"{icons.profile} Profile", route=profile_route)
saved = TabData(title=f"{icons.save} Saved", route=saved_route)
api_keys = TabData(title=f"{icons.api} API Keys", route=api_keys_route)
- orgs = TabData(title=f"{icons.company} Teams", route=orgs_route)
+ workspaces = TabData(title=f"{icons.company} Teams", route=workspaces_route)
@property
def url_path(self) -> str:
@@ -203,8 +203,8 @@ def url_path(self) -> str:
def billing_tab(request: Request):
- org, _ = request.user.get_or_create_personal_org()
- return billing_page(org)
+ workspace, _ = request.user.get_or_create_personal_workspace()
+ return billing_page(workspace)
def profile_tab(request: Request):
@@ -256,14 +256,14 @@ def api_keys_tab(request: Request):
manage_api_keys(request.user)
-def orgs_tab(request: Request):
+def workspaces_tab(request: Request):
"""only accessible to admins"""
from daras_ai_v2.base import BasePage
if not BasePage.is_user_admin(request.user):
raise RedirectException(get_route_path(account_route))
- orgs_page(request.user)
+ workspaces_page(request.user)
def get_tabs(request: Request) -> list[AccountTabs]:
@@ -276,7 +276,7 @@ def get_tabs(request: Request) -> list[AccountTabs]:
AccountTabs.api_keys,
]
if BasePage.is_user_admin(request.user):
- tab_list.append(AccountTabs.orgs)
+ tab_list.append(AccountTabs.workspaces)
return tab_list
diff --git a/routers/api.py b/routers/api.py
index 9b795d426..5d2b4e42d 100644
--- a/routers/api.py
+++ b/routers/api.py
@@ -354,6 +354,7 @@ def submit_api_call(
enable_rate_limits=enable_rate_limits,
is_api_call=True,
retention_policy=retention_policy or RetentionPolicy.keep,
+ billed_workspace=self.get_current_workspace(),
)
except ValidationError as e:
raise RequestValidationError(e.raw_errors, body=gui.session_state) from e
diff --git a/routers/paypal.py b/routers/paypal.py
index 3771481cf..84d8a3b83 100644
--- a/routers/paypal.py
+++ b/routers/paypal.py
@@ -126,8 +126,8 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json):
if plan.deprecated:
return JSONResponse({"error": "Deprecated plan"}, status_code=400)
- org, _ = request.user.get_or_create_personal_org()
- if org.subscription and org.subscription.is_paid():
+ workspace, _ = request.user.get_or_create_personal_worksace()
+ if workspace.subscription and workspace.subscription.is_paid():
return JSONResponse(
{"error": "User already has an active subscription"}, status_code=400
)
@@ -135,7 +135,7 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json):
paypal_plan_info = plan.get_paypal_plan()
pp_subscription = paypal.Subscription.create(
plan_id=paypal_plan_info["plan_id"],
- custom_id=org.org_id,
+ custom_id=str(workspace.id),
plan=paypal_plan_info.get("plan", {}),
application_context={
"brand_name": "Gooey.AI",
@@ -177,7 +177,7 @@ def _handle_invoice_paid(order_id: str):
purchase_unit = order["purchase_units"][0]
payment_capture = purchase_unit["payments"]["captures"][0]
add_balance_for_payment(
- org_id=payment_capture["custom_id"],
+ workspace_id_or_uid=payment_capture["custom_id"],
amount=int(purchase_unit["items"][0]["quantity"]),
invoice_id=payment_capture["id"],
payment_provider=PaymentProvider.PAYPAL,
diff --git a/scripts/migrate_billed_org_for_saved_runs.py b/scripts/migrate_billed_org_for_saved_runs.py
deleted file mode 100644
index 52b86e932..000000000
--- a/scripts/migrate_billed_org_for_saved_runs.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from django.db.models import F, Subquery, OuterRef
-from django.db import transaction
-
-from bots.models import SavedRun
-from orgs.models import Org
-
-
-def run():
- # Start a transaction to ensure atomicity
- with transaction.atomic():
- # Perform the update where 'uid' matches a valid 'org_id' in the 'Org' table
- SavedRun.objects.filter(
- billed_org_id__isnull=True, uid__in=Org.objects.values("org_id")
- ).update(
- billed_org_id=Subquery(
- Org.objects.filter(org_id=OuterRef("uid")).values("id")[:1]
- )
- )
diff --git a/scripts/migrate_billed_workspace_for_saved_runs.py b/scripts/migrate_billed_workspace_for_saved_runs.py
new file mode 100644
index 000000000..39f19ae4a
--- /dev/null
+++ b/scripts/migrate_billed_workspace_for_saved_runs.py
@@ -0,0 +1,23 @@
+from django.db import connection
+from loguru import logger
+
+
+def run():
+ with connection.cursor() as cursor:
+ cursor.execute(
+ """
+ UPDATE bots_savedrun
+ SET billed_workspace_id = workspaces_workspace.id
+ FROM
+ workspaces_workspace INNER JOIN
+ app_users_appuser ON workspaces_workspace.created_by_id = app_users_appuser.id
+ WHERE
+ bots_savedrun.billed_workspace_id IS NULL AND
+ bots_savedrun.uid IS NOT NULL AND
+ bots_savedrun.uid = app_users_appuser.uid AND
+ workspaces_workspace.is_personal = true
+ """
+ )
+ rows_updated = cursor.rowcount
+
+ logger.info(f"Updated {rows_updated} saved runs with billed workspace")
diff --git a/scripts/migrate_orgs_from_appusers.py b/scripts/migrate_orgs_from_appusers.py
deleted file mode 100644
index f4cbc7ec9..000000000
--- a/scripts/migrate_orgs_from_appusers.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from django.db import IntegrityError, connection
-from loguru import logger
-
-from app_users.models import AppUser
-from orgs.models import Org
-
-
-def run():
- migrate_personal_orgs()
- migrate_txns()
-
-
-def migrate_personal_orgs():
- users_without_personal_org = AppUser.objects.exclude(
- id__in=Org.objects.filter(is_personal=True).values_list("created_by", flat=True)
- )
-
- done_count = 0
-
- logger.info("Creating personal orgs...")
- for appuser in users_without_personal_org:
- try:
- Org.objects.migrate_from_appuser(appuser)
- except IntegrityError as e:
- logger.warning(f"IntegrityError: {e}")
- else:
- done_count += 1
-
- if done_count % 100 == 0:
- logger.info(f"Running... {done_count} migrated")
-
- logger.info(f"Migrated {done_count} personal orgs...")
-
-
-def migrate_txns():
- with connection.cursor() as cursor:
- cursor.execute(
- """
- UPDATE app_users_appusertransaction AS txn
- SET org_id = orgs_org.id
- FROM
- app_users_appuser
- INNER JOIN orgs_org ON app_users_appuser.id = orgs_org.created_by_id
- WHERE
- txn.user_id = app_users_appuser.id
- AND txn.org_id IS NULL
- AND orgs_org.is_personal = true
- """
- )
- logger.info(f"Updated {cursor.rowcount} txns with personal orgs")
diff --git a/scripts/migrate_workspace_from_appusers.py b/scripts/migrate_workspace_from_appusers.py
new file mode 100644
index 000000000..f58c0935d
--- /dev/null
+++ b/scripts/migrate_workspace_from_appusers.py
@@ -0,0 +1,52 @@
+from django.db import IntegrityError, connection
+from loguru import logger
+
+from app_users.models import AppUser
+from workspaces.models import Workspace
+
+
+def run():
+ migrate_personal_workspaces()
+ migrate_txns()
+
+
+def migrate_personal_workspaces():
+ users_without_personal_workspace = AppUser.objects.exclude(
+ id__in=Workspace.objects.filter(is_personal=True).values_list(
+ "created_by", flat=True
+ )
+ )
+
+ done_count = 0
+
+ logger.info("Creating personal workspaces...")
+ for appuser in users_without_personal_workspace:
+ try:
+ Workspace.objects.migrate_from_appuser(appuser)
+ except IntegrityError as e:
+ logger.warning(f"IntegrityError: {e}")
+ else:
+ done_count += 1
+
+ if done_count % 100 == 0:
+ logger.info(f"Running... {done_count} migrated")
+
+ logger.info(f"Migrated {done_count} personal workspaces...")
+
+
+def migrate_txns():
+ with connection.cursor() as cursor:
+ cursor.execute(
+ """
+ UPDATE app_users_appusertransaction AS txn
+ SET workspace_id = workspaces_workspace.id
+ FROM
+ app_users_appuser
+ INNER JOIN workspaces_workspace ON app_users_appuser.id = workspaces_workspace.created_by_id
+ WHERE
+ txn.user_id = app_users_appuser.id
+ AND txn.workspace_id IS NULL
+ AND workspaces_workspace.is_personal = true
+ """
+ )
+ logger.info(f"Updated {cursor.rowcount} txns with personal workspaces")
diff --git a/templates/monthly_budget_reached_email.html b/templates/monthly_budget_reached_email.html
index 6e467a086..861a3c3c4 100644
--- a/templates/monthly_budget_reached_email.html
+++ b/templates/monthly_budget_reached_email.html
@@ -1,6 +1,6 @@
-{% set dollars_spent = org.get_dollars_spent_this_month() %}
-{% set monthly_budget = org.subscription.monthly_spending_budget %}
-{% set threshold = org.subscription.auto_recharge_balance_threshold %}
+{% set dollars_spent = workspace.get_dollars_spent_this_month() %}
+{% set monthly_budget = workspace.subscription.monthly_spending_budget %}
+{% set threshold = workspace.subscription.auto_recharge_balance_threshold %}
Hey, {{ user.first_name() }}!
@@ -18,7 +18,7 @@
-
Credit Balance: {{ org.balance }} credits
+
Credit Balance: {{ workspace.balance }} credits
Monthly Budget: ${{ monthly_budget }}
Spending this month: ${{ dollars_spent }}
diff --git a/templates/monthly_spending_notification_threshold_email.html b/templates/monthly_spending_notification_threshold_email.html
index 13be0fae5..c8a6394d1 100644
--- a/templates/monthly_spending_notification_threshold_email.html
+++ b/templates/monthly_spending_notification_threshold_email.html
@@ -1,4 +1,4 @@
-{% set dollars_spent = org.get_dollars_spent_this_month() %}
+{% set dollars_spent = workspace.get_dollars_spent_this_month() %}
Hi, {{ user.first_name() }}!
@@ -6,11 +6,11 @@
Your spend on Gooey.AI so far this month is ${{ dollars_spent }}, exceeding your notification threshold
- of ${{ org.subscription.monthly_spending_notification_threshold }}.
+ of ${{ workspace.subscription.monthly_spending_notification_threshold }}.
- Your monthly budget is ${{ org.subscription.monthly_spending_budget }}, after which auto-recharge will be
+ Your monthly budget is ${{ workspace.subscription.monthly_spending_budget }}, after which auto-recharge will be
paused and all runs / API calls will be rejected.
- You have been added to the team {{ org.name }} on Gooey.AI.
- Visit the teams page to see your team.
+ You have been added to the team {{ workspace.name }} on Gooey.AI.
+ Visit the teams page to see your team.
- Your invite was automatically accepted because your email domain matches the organization's configured email domain.
- If you think this shouldn't have happened, you can leave this organization from the
- teams page.
+ Your invite was automatically accepted because your email domain matches the workspaceanization's configured email domain.
+ If you think this shouldn't have happened, you can leave this workspaceanization from the
+ teams page.
{{ invitation.inviter.display_name or invitation.inviter.first_name() }} has invited
- you to join their team {{ invitation.org.name }} on Gooey.AI.
+ you to join their team {{ invitation.workspace.name }} on Gooey.AI.
@@ -14,7 +14,7 @@
- The link will expire in {{ settings.ORG_INVITATION_EXPIRY_DAYS }} days.
+ The link will expire in {{ settings.WORKSPACE_INVITATION_EXPIRY_DAYS }} days.
diff --git a/orgs/__init__.py b/workspaces/__init__.py
similarity index 100%
rename from orgs/__init__.py
rename to workspaces/__init__.py
diff --git a/workspaces/admin.py b/workspaces/admin.py
new file mode 100644
index 000000000..3c0e74de7
--- /dev/null
+++ b/workspaces/admin.py
@@ -0,0 +1,155 @@
+from django.contrib import admin
+from django.db.models import Sum
+from safedelete.admin import SafeDeleteAdmin, SafeDeleteAdminFilter
+
+from bots.admin_links import change_obj_url
+from usage_costs.models import UsageCost
+from .models import Workspace, WorkspaceMembership, WorkspaceInvitation
+
+
+class WorkspaceMembershipInline(admin.TabularInline):
+ model = WorkspaceMembership
+ extra = 0
+ show_change_link = True
+ fields = ["user", "role", "created_at", "updated_at"]
+ readonly_fields = ["created_at", "updated_at"]
+ ordering = ["-created_at"]
+ can_delete = False
+ show_change_link = True
+
+
+class WorkspaceInvitationInline(admin.TabularInline):
+ model = WorkspaceInvitation
+ extra = 0
+ show_change_link = True
+ fields = [
+ "invitee_email",
+ "inviter",
+ "status",
+ "auto_accepted",
+ "created_at",
+ "updated_at",
+ ]
+ readonly_fields = ["auto_accepted", "created_at", "updated_at"]
+ ordering = ["status", "-created_at"]
+ can_delete = False
+ show_change_link = True
+
+
+@admin.register(Workspace)
+class WorkspaceAdmin(SafeDeleteAdmin):
+ list_display = [
+ "name",
+ "domain_name",
+ "created_at",
+ "updated_at",
+ ] + list(SafeDeleteAdmin.list_display)
+ list_filter = [SafeDeleteAdminFilter] + list(SafeDeleteAdmin.list_filter)
+ fields = [
+ "name",
+ "domain_name",
+ "created_by",
+ "is_personal",
+ "is_paying",
+ ("balance", "subscription"),
+ ("total_payments", "total_charged", "total_usage_cost"),
+ "created_at",
+ "updated_at",
+ ]
+ search_fields = ["name", "domain_name"]
+ readonly_fields = [
+ "is_personal",
+ "created_at",
+ "updated_at",
+ "total_payments",
+ "total_charged",
+ "total_usage_cost",
+ ]
+ inlines = [WorkspaceMembershipInline, WorkspaceInvitationInline]
+ ordering = ["-created_at"]
+
+ @admin.display(description="Total Payments")
+ def total_payments(self, workspace: Workspace):
+ return "$" + str(
+ (
+ workspace.transactions.aggregate(Sum("charged_amount"))[
+ "charged_amount__sum"
+ ]
+ or 0
+ )
+ / 100
+ )
+
+ @admin.display(description="Total Charged")
+ def total_charged(self, workspace: Workspace):
+ credits_charged = -1 * (
+ workspace.transactions.filter(amount__lt=0).aggregate(Sum("amount"))[
+ "amount__sum"
+ ]
+ or 0
+ )
+ return f"{credits_charged} Credits"
+
+ @admin.display(description="Total Usage Cost")
+ def total_usage_cost(self, workspace: Workspace):
+ total_cost = (
+ UsageCost.objects.filter(
+ saved_run__billed_workspace_id=workspace.id
+ ).aggregate(Sum("dollar_amount"))["dollar_amount__sum"]
+ or 0
+ )
+ return round(total_cost, 2)
+
+
+@admin.register(WorkspaceMembership)
+class WorkspaceMembershipAdmin(SafeDeleteAdmin):
+ list_display = [
+ "user",
+ "workspace",
+ "role",
+ "created_at",
+ "updated_at",
+ ] + list(SafeDeleteAdmin.list_display)
+ list_filter = ["workspace", "role", SafeDeleteAdminFilter] + list(
+ SafeDeleteAdmin.list_filter
+ )
+
+ def get_readonly_fields(
+ self, request: "HttpRequest", obj: WorkspaceMembership | None = None
+ ) -> list[str]:
+ readonly_fields = list(super().get_readonly_fields(request, obj))
+ if obj and obj.workspace and obj.workspace.deleted:
+ return readonly_fields + ["deleted_workspace"]
+ else:
+ return readonly_fields
+
+ @admin.display
+ def deleted_workspace(self, obj):
+ workspace = Workspace.deleted_objects.get(pk=obj.workspace_id)
+ return change_obj_url(workspace)
+
+
+@admin.register(WorkspaceInvitation)
+class WorkspaceInvitationAdmin(SafeDeleteAdmin):
+ fields = [
+ "workspace",
+ "invitee_email",
+ "inviter",
+ "role",
+ "status",
+ "auto_accepted",
+ "created_at",
+ "updated_at",
+ ]
+ list_display = [
+ "workspace",
+ "invitee_email",
+ "inviter",
+ "status",
+ "created_at",
+ "updated_at",
+ ] + list(SafeDeleteAdmin.list_display)
+ list_filter = ["workspace", "inviter", "role", SafeDeleteAdminFilter] + list(
+ SafeDeleteAdmin.list_filter
+ )
+ readonly_fields = ["auto_accepted"]
diff --git a/orgs/apps.py b/workspaces/apps.py
similarity index 74%
rename from orgs/apps.py
rename to workspaces/apps.py
index a75310666..dfc799939 100644
--- a/orgs/apps.py
+++ b/workspaces/apps.py
@@ -1,9 +1,9 @@
from django.apps import AppConfig
-class OrgsConfig(AppConfig):
+class WorkspacesConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
- name = "orgs"
+ name = "workspaces"
def ready(self):
from . import signals
diff --git a/orgs/migrations/0001_initial.py b/workspaces/migrations/0001_initial.py
similarity index 56%
rename from orgs/migrations/0001_initial.py
rename to workspaces/migrations/0001_initial.py
index 7de84737d..b7183be45 100644
--- a/orgs/migrations/0001_initial.py
+++ b/workspaces/migrations/0001_initial.py
@@ -1,8 +1,8 @@
-# Generated by Django 4.2.7 on 2024-07-18 15:41
+# Generated by Django 4.2.7 on 2024-09-02 14:07
from django.db import migrations, models
import django.db.models.deletion
-import orgs.models
+import workspaces.models
class Migration(migrations.Migration):
@@ -10,27 +10,34 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
+ ('payments', '0005_alter_subscription_plan'),
('app_users', '0019_alter_appusertransaction_reason'),
]
operations = [
migrations.CreateModel(
- name='Org',
+ name='Workspace',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('deleted', models.DateTimeField(db_index=True, editable=False, null=True)),
('deleted_by_cascade', models.BooleanField(default=False, editable=False)),
- ('org_id', models.CharField(blank=True, max_length=100, null=True, unique=True)),
+ ('workspace_id', models.CharField(blank=True, max_length=100, null=True, unique=True)),
('name', models.CharField(max_length=100)),
('logo', models.URLField(blank=True, null=True)),
- ('domain_name', models.CharField(blank=True, max_length=30, null=True, validators=[orgs.models.validate_org_domain_name])),
+ ('domain_name', models.CharField(blank=True, max_length=30, null=True, validators=[workspaces.models.validate_workspace_domain_name])),
+ ('balance', models.IntegerField(default=0, verbose_name='bal')),
+ ('is_paying', models.BooleanField(default=False, verbose_name='paid')),
+ ('stripe_customer_id', models.CharField(blank=True, default='', max_length=255)),
+ ('low_balance_email_sent_at', models.DateTimeField(blank=True, null=True)),
+ ('is_personal', models.BooleanField(default=False)),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='app_users.appuser')),
+ ('subscription', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='workspace', to='payments.subscription')),
],
),
migrations.CreateModel(
- name='OrgInvitation',
+ name='WorkspaceInvitation',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('deleted', models.DateTimeField(db_index=True, editable=False, null=True)),
@@ -40,20 +47,20 @@ class Migration(migrations.Migration):
('status', models.IntegerField(choices=[(1, 'Pending'), (2, 'Accepted'), (3, 'Rejected'), (4, 'Canceled'), (5, 'Expired')], default=1)),
('auto_accepted', models.BooleanField(default=False)),
('role', models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3)),
- ('last_email_sent_at', models.DateTimeField(blank=True, default=False, null=True)),
- ('status_changed_at', models.DateTimeField(blank=True, default=False, null=True)),
+ ('last_email_sent_at', models.DateTimeField(blank=True, default=None, null=True)),
+ ('status_changed_at', models.DateTimeField(blank=True, default=None, null=True)),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
('inviter', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='sent_invitations', to='app_users.appuser')),
- ('org', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='invitations', to='orgs.org')),
- ('status_changed_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='app_users.appuser')),
+ ('status_changed_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='received_invitations', to='app_users.appuser')),
+ ('workspace', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='invitations', to='workspaces.workspace')),
],
options={
'abstract': False,
},
),
migrations.CreateModel(
- name='OrgMembership',
+ name='WorkspaceMembership',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('deleted', models.DateTimeField(db_index=True, editable=False, null=True)),
@@ -61,21 +68,21 @@ class Migration(migrations.Migration):
('role', models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3)),
('created_at', models.DateTimeField(auto_now_add=True)),
('updated_at', models.DateTimeField(auto_now=True)),
- ('invitation', models.OneToOneField(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='membership', to='orgs.orginvitation')),
- ('org', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='memberships', to='orgs.org')),
- ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='org_memberships', to='app_users.appuser')),
+ ('invitation', models.OneToOneField(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='membership', to='workspaces.workspaceinvitation')),
+ ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='workspace_memberships', to='app_users.appuser')),
+ ('workspace', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='memberships', to='workspaces.workspace')),
],
- options={
- 'unique_together': {('org', 'user', 'deleted')},
- },
),
- migrations.AddField(
- model_name='org',
- name='members',
- field=models.ManyToManyField(related_name='orgs', through='orgs.OrgMembership', to='app_users.appuser'),
+ migrations.AddConstraint(
+ model_name='workspacemembership',
+ constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('workspace', 'user'), name='unique_workspace_user'),
+ ),
+ migrations.AddConstraint(
+ model_name='workspace',
+ constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('domain_name',), name='unique_domain_name_when_not_deleted', violation_error_message='This domain name is already in use by another team. Contact Gooey.AI Support if you think this is a mistake.'),
),
- migrations.AlterUniqueTogether(
- name='org',
- unique_together={('domain_name', 'deleted')},
+ migrations.AddConstraint(
+ model_name='workspace',
+ constraint=models.UniqueConstraint(models.F('created_by'), condition=models.Q(('deleted__isnull', True), ('is_personal', True)), name='unique_personal_workspace_per_user'),
),
]
diff --git a/workspaces/migrations/0002_alter_workspace_logo.py b/workspaces/migrations/0002_alter_workspace_logo.py
new file mode 100644
index 000000000..c28aba367
--- /dev/null
+++ b/workspaces/migrations/0002_alter_workspace_logo.py
@@ -0,0 +1,19 @@
+# Generated by Django 4.2.7 on 2024-09-03 12:59
+
+import bots.custom_fields
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('workspaces', '0001_initial'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='workspace',
+ name='logo',
+ field=bots.custom_fields.CustomURLField(blank=True, max_length=2048, null=True),
+ ),
+ ]
diff --git a/orgs/migrations/__init__.py b/workspaces/migrations/__init__.py
similarity index 100%
rename from orgs/migrations/__init__.py
rename to workspaces/migrations/__init__.py
diff --git a/orgs/models.py b/workspaces/models.py
similarity index 69%
rename from orgs/models.py
rename to workspaces/models.py
index 4c9b2c8e2..56020aef4 100644
--- a/orgs/models.py
+++ b/workspaces/models.py
@@ -15,80 +15,80 @@
from safedelete.managers import SafeDeleteManager
from safedelete.models import SafeDeleteModel, SOFT_DELETE_CASCADE
+from bots.custom_fields import CustomURLField
from daras_ai_v2 import settings
from daras_ai_v2.fastapi_tricks import get_app_route_url
from daras_ai_v2.crypto import get_random_doc_id
from gooeysite.bg_db_conn import db_middleware
-from orgs.tasks import send_auto_accepted_email, send_invitation_email
+from .tasks import send_auto_accepted_email, send_invitation_email
if typing.TYPE_CHECKING:
from app_users.models import AppUser, AppUserTransaction
-ORG_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$")
+WORKSPACE_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$")
-def validate_org_domain_name(value):
+def validate_workspace_domain_name(value):
from handles.models import COMMON_EMAIL_DOMAINS
- if not ORG_DOMAIN_NAME_RE.fullmatch(value):
+ if not WORKSPACE_DOMAIN_NAME_RE.fullmatch(value):
raise ValidationError("Invalid domain name")
if value in COMMON_EMAIL_DOMAINS:
raise ValidationError("This domain name is reserved")
-class OrgRole(models.IntegerChoices):
+class WorkspaceRole(models.IntegerChoices):
OWNER = 1
ADMIN = 2
MEMBER = 3
-class OrgManager(SafeDeleteManager):
- def create_org(
+class WorkspaceManager(SafeDeleteManager):
+ def create_workspace(
self,
*,
created_by: "AppUser",
- org_id: str | None = None,
balance: int | None = None,
**kwargs,
- ) -> Org:
- org = self.model(
- org_id=org_id or get_random_doc_id(),
+ ) -> Workspace:
+ workspace = self.model(
created_by=created_by,
balance=balance,
**kwargs,
)
if (
balance is None
- and Org.all_objects.filter(created_by=created_by).count() <= 1
+ and Workspace.all_objects.filter(created_by=created_by).count() <= 1
):
# set some balance for first team created by user
- # Org.all_objects is important to include deleted orgs
- org.balance = settings.FIRST_ORG_FREE_CREDITS
+ # Workspace.all_objects is important to include deleted workspaces
+ workspace.balance = settings.FIRST_WORKSPACE_FREE_CREDITS
- org.full_clean()
- org.save()
- org.add_member(
+ workspace.full_clean()
+ workspace.save()
+ workspace.add_member(
created_by,
- role=OrgRole.OWNER,
+ role=WorkspaceRole.OWNER,
)
- return org
+ return workspace
- def get_or_create_from_org_id(self, org_id: str) -> tuple[Org, bool]:
- from app_users.models import AppUser
+ def get_or_create_from_uid(self, uid: str) -> tuple[Workspace, bool]:
+ workspace = Workspace.objects.filter(
+ is_personal=True, created_by__uid=uid
+ ).first()
+ if workspace:
+ return workspace, False
- try:
- return self.get(org_id=org_id), False
- except self.model.DoesNotExist:
- user = AppUser.objects.get_or_create_from_uid(org_id)[0]
- return self.migrate_from_appuser(user), True
+ user, _ = AppUser.objects.get_or_create_from_uid(uid)
+ workspace = self.migrate_from_appuser(user)
+ return workspace, True
- def migrate_from_appuser(self, user: "AppUser") -> Org:
- return self.create_org(
+ def migrate_from_appuser(self, user: "AppUser") -> Workspace:
+ return self.create_workspace(
name=f"{user.first_name()}'s Personal Workspace",
- org_id=user.uid or get_random_doc_id(),
created_by=user,
is_personal=True,
balance=user.balance,
@@ -108,24 +108,22 @@ def get_dollars_spent_this_month(self) -> float:
return (cents_spent or 0) / 100
-class Org(SafeDeleteModel):
+class Workspace(SafeDeleteModel):
_safedelete_policy = SOFT_DELETE_CASCADE
- org_id = models.CharField(max_length=100, null=True, blank=True, unique=True)
-
name = models.CharField(max_length=100)
created_by = models.ForeignKey(
"app_users.appuser",
on_delete=models.CASCADE,
)
- logo = models.URLField(null=True, blank=True)
+ logo = CustomURLField(null=True, blank=True)
domain_name = models.CharField(
max_length=30,
blank=True,
null=True,
validators=[
- validate_org_domain_name,
+ validate_workspace_domain_name,
],
)
@@ -136,7 +134,7 @@ class Org(SafeDeleteModel):
subscription = models.OneToOneField(
"payments.Subscription",
on_delete=models.SET_NULL,
- related_name="org",
+ related_name="workspace",
null=True,
blank=True,
)
@@ -147,7 +145,7 @@ class Org(SafeDeleteModel):
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
- objects = OrgManager()
+ objects = WorkspaceManager()
class Meta:
constraints = [
@@ -160,7 +158,7 @@ class Meta:
models.UniqueConstraint(
"created_by",
condition=Q(deleted__isnull=True, is_personal=True),
- name="unique_personal_org_per_user",
+ name="unique_personal_workspace_per_user",
),
]
@@ -174,10 +172,13 @@ def get_slug(self):
return slugify(self.name)
def add_member(
- self, user: "AppUser", role: OrgRole, invitation: "OrgInvitation | None" = None
+ self,
+ user: "AppUser",
+ role: WorkspaceRole,
+ invitation: "WorkspaceInvitation | None" = None,
):
- OrgMembership(
- org=self,
+ WorkspaceMembership(
+ workspace=self,
user=user,
role=role,
invitation=invitation,
@@ -188,9 +189,9 @@ def invite_user(
*,
invitee_email: str,
inviter: "AppUser",
- role: OrgRole,
+ role: WorkspaceRole,
auto_accept: bool = False,
- ) -> "OrgInvitation":
+ ) -> "WorkspaceInvitation":
"""
auto_accept: If True, the user will be automatically added if they have an account
"""
@@ -198,15 +199,17 @@ def invite_user(
if member.user.email == invitee_email:
raise ValidationError(f"{member.user} is already a member of this team")
- for invitation in self.invitations.filter(status=OrgInvitation.Status.PENDING):
+ for invitation in self.invitations.filter(
+ status=WorkspaceInvitation.Status.PENDING
+ ):
if invitation.invitee_email == invitee_email:
raise ValidationError(
f"{invitee_email} was already invited to this team"
)
- invitation = OrgInvitation(
+ invitation = WorkspaceInvitation(
invite_id=get_random_doc_id(),
- org=self,
+ workspace=self,
invitee_email=invitee_email,
inviter=inviter,
role=role,
@@ -225,8 +228,8 @@ def invite_user(
return invitation
- def get_owners(self) -> list[OrgMembership]:
- return self.memberships.filter(role=OrgRole.OWNER)
+ def get_owners(self) -> models.QuerySet[WorkspaceMembership]:
+ return self.memberships.filter(role=WorkspaceRole.OWNER)
@db_middleware
@transaction.atomic
@@ -255,51 +258,58 @@ def add_balance(
# It won't lock this row for reads, and multiple threads can update the same row leading incorrect balance
#
# Also we're not using .update() here because it won't give back the updated end balance
- org: Org = Org.objects.select_for_update().get(pk=self.pk)
- org.balance += amount
- org.save(update_fields=["balance"])
- kwargs.setdefault("plan", org.subscription and org.subscription.plan)
+ workspace: Workspace = Workspace.objects.select_for_update().get(pk=self.pk)
+ workspace.balance += amount
+ workspace.save(update_fields=["balance"])
+ kwargs.setdefault(
+ "plan", workspace.subscription and workspace.subscription.plan
+ )
return AppUserTransaction.objects.create(
- org=org,
- user=org.created_by if org.is_personal else None,
+ workspace=workspace,
+ user=workspace.created_by if workspace.is_personal else None,
invoice_id=invoice_id,
amount=amount,
- end_balance=org.balance,
+ end_balance=workspace.balance,
**kwargs,
)
def get_or_create_stripe_customer(self) -> stripe.Customer:
customer = self.search_stripe_customer()
if not customer:
+ metadata = {"workspace_id": self.id}
+ if self.is_personal:
+ metadata["uid"] = self.created_by.uid
+
customer = stripe.Customer.create(
name=self.created_by.display_name,
email=self.created_by.email,
phone=self.created_by.phone_number,
- metadata={"uid": self.org_id, "org_id": self.org_id, "id": self.pk},
+ metadata=metadata,
)
self.stripe_customer_id = customer.id
self.save()
return customer
def search_stripe_customer(self) -> stripe.Customer | None:
- if not self.org_id:
- return None
if self.stripe_customer_id:
try:
return stripe.Customer.retrieve(self.stripe_customer_id)
- except stripe.error.InvalidRequestError as e:
+ except stripe.InvalidRequestError as e:
if e.http_status != 404:
raise
+
try:
customer = stripe.Customer.search(
- query=f'metadata["uid"]:"{self.org_id}"'
+ query=f'metadata["workspace_id"]:"{self.id}"'
).data[0]
except IndexError:
- return None
- else:
- self.stripe_customer_id = customer.id
- self.save()
- return customer
+ customer = self.is_personal and self.created_by.search_stripe_customer()
+ if not customer:
+ return None
+
+ self.stripe_customer_id = customer.id
+ self.save()
+ return customer
def get_dollars_spent_this_month(self) -> float:
today = timezone.now()
@@ -311,13 +321,17 @@ def get_dollars_spent_this_month(self) -> float:
return (cents_spent or 0) / 100
-class OrgMembership(SafeDeleteModel):
- org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="memberships")
+class WorkspaceMembership(SafeDeleteModel):
+ workspace = models.ForeignKey(
+ Workspace, on_delete=models.CASCADE, related_name="memberships"
+ )
user = models.ForeignKey(
- "app_users.AppUser", on_delete=models.CASCADE, related_name="org_memberships"
+ "app_users.AppUser",
+ on_delete=models.CASCADE,
+ related_name="workspace_memberships",
)
invitation = models.OneToOneField(
- "OrgInvitation",
+ "WorkspaceInvitation",
on_delete=models.SET_NULL,
blank=True,
null=True,
@@ -325,7 +339,9 @@ class OrgMembership(SafeDeleteModel):
related_name="membership",
)
- role = models.IntegerField(choices=OrgRole.choices, default=OrgRole.MEMBER)
+ role = models.IntegerField(
+ choices=WorkspaceRole.choices, default=WorkspaceRole.MEMBER
+ )
created_at = models.DateTimeField(auto_now_add=True) # same as joining date
updated_at = models.DateTimeField(auto_now=True)
@@ -335,45 +351,45 @@ class OrgMembership(SafeDeleteModel):
class Meta:
constraints = [
models.UniqueConstraint(
- fields=["org", "user"],
+ fields=["workspace", "user"],
condition=Q(deleted__isnull=True),
- name="unique_org_user",
+ name="unique_workspace_user",
)
]
def __str__(self):
- return f"{self.get_role_display()} - {self.user} ({self.org})"
+ return f"{self.get_role_display()} - {self.user} ({self.workspace})"
- def can_edit_org_metadata(self):
- return self.role in (OrgRole.OWNER, OrgRole.ADMIN)
+ def can_edit_workspace_metadata(self):
+ return self.role in (WorkspaceRole.OWNER, WorkspaceRole.ADMIN)
- def can_delete_org(self):
- return self.role == OrgRole.OWNER
+ def can_delete_workspace(self):
+ return self.role == WorkspaceRole.OWNER
- def has_higher_role_than(self, other: "OrgMembership"):
+ def has_higher_role_than(self, other: "WorkspaceMembership"):
# creator > owner > admin > member
match other.role:
- case OrgRole.OWNER:
- return self.org.created_by == OrgRole.OWNER
- case OrgRole.ADMIN:
- return self.role == OrgRole.OWNER
- case OrgRole.MEMBER:
- return self.role in (OrgRole.OWNER, OrgRole.ADMIN)
-
- def can_change_role(self, other: "OrgMembership"):
+ case WorkspaceRole.OWNER:
+ return self.workspace.created_by == WorkspaceRole.OWNER
+ case WorkspaceRole.ADMIN:
+ return self.role == WorkspaceRole.OWNER
+ case WorkspaceRole.MEMBER:
+ return self.role in (WorkspaceRole.OWNER, WorkspaceRole.ADMIN)
+
+ def can_change_role(self, other: "WorkspaceMembership"):
return self.has_higher_role_than(other)
- def can_kick(self, other: "OrgMembership"):
+ def can_kick(self, other: "WorkspaceMembership"):
return self.has_higher_role_than(other)
def can_transfer_ownership(self):
- return self.role == OrgRole.OWNER
+ return self.role == WorkspaceRole.OWNER
def can_invite(self):
- return self.role in (OrgRole.OWNER, OrgRole.ADMIN)
+ return self.role in (WorkspaceRole.OWNER, WorkspaceRole.ADMIN)
-class OrgInvitation(SafeDeleteModel):
+class WorkspaceInvitation(SafeDeleteModel):
class Status(models.IntegerChoices):
PENDING = 1
ACCEPTED = 2
@@ -384,14 +400,18 @@ class Status(models.IntegerChoices):
invite_id = models.CharField(max_length=100, unique=True)
invitee_email = models.EmailField()
- org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="invitations")
+ workspace = models.ForeignKey(
+ Workspace, on_delete=models.CASCADE, related_name="invitations"
+ )
inviter = models.ForeignKey(
"app_users.AppUser", on_delete=models.CASCADE, related_name="sent_invitations"
)
status = models.IntegerField(choices=Status.choices, default=Status.PENDING)
auto_accepted = models.BooleanField(default=False)
- role = models.IntegerField(choices=OrgRole.choices, default=OrgRole.MEMBER)
+ role = models.IntegerField(
+ choices=WorkspaceRole.choices, default=WorkspaceRole.MEMBER
+ )
last_email_sent_at = models.DateTimeField(null=True, blank=True, default=None)
status_changed_at = models.DateTimeField(null=True, blank=True, default=None)
@@ -407,12 +427,12 @@ class Status(models.IntegerChoices):
updated_at = models.DateTimeField(auto_now=True)
def __str__(self):
- return f"{self.invitee_email} - {self.org} ({self.get_status_display()})"
+ return f"{self.invitee_email} - {self.workspace} ({self.get_status_display()})"
def has_expired(self):
return self.status == self.Status.EXPIRED or (
timezone.now() - (self.last_email_sent_at or self.created_at)
- > timedelta(days=settings.ORG_INVITATION_EXPIRY_DAYS)
+ > timedelta(days=settings.WORKSPACE_INVITATION_EXPIRY_DAYS)
)
def auto_accept(self):
@@ -431,7 +451,9 @@ def auto_accept(self):
self.accept(invitee, auto_accepted=True)
if self.auto_accepted:
- logger.info(f"User {invitee} auto-accepted invitation to org {self.org}")
+ logger.info(
+ f"User {invitee} auto-accepted invitation to workspace {self.workspace}"
+ )
send_auto_accepted_email.delay(self.pk)
def get_url(self):
@@ -439,7 +461,10 @@ def get_url(self):
return get_app_route_url(
invitation_route,
- path_params={"invite_id": self.invite_id, "org_slug": self.org.get_slug()},
+ path_params={
+ "invite_id": self.invite_id,
+ "workspace_slug": self.workspace.get_slug(),
+ },
)
def send_email(self):
@@ -469,7 +494,7 @@ def accept(self, user: "AppUser", *, auto_accepted: bool = False):
"This invitation has expired. Please ask your team admin to send a new one."
)
- if self.org.memberships.filter(user_id=user.pk).exists():
+ if self.workspace.memberships.filter(user_id=user.pk).exists():
raise ValidationError(f"User is already a member of this team.")
self.status = self.Status.ACCEPTED
@@ -480,8 +505,8 @@ def accept(self, user: "AppUser", *, auto_accepted: bool = False):
self.full_clean()
with transaction.atomic():
- user.org_memberships.all().delete() # delete current memberships
- self.org.add_member(
+ user.workspace_memberships.all().delete() # delete current memberships
+ self.workspace.add_member(
user,
role=self.role,
invitation=self,
@@ -505,5 +530,5 @@ def can_resend_email(self):
return True
return timezone.now() - self.last_email_sent_at > timedelta(
- seconds=settings.ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL
+ seconds=settings.WORKSPACE_INVITATION_EMAIL_COOLDOWN_INTERVAL
)
diff --git a/workspaces/signals.py b/workspaces/signals.py
new file mode 100644
index 000000000..962ffe794
--- /dev/null
+++ b/workspaces/signals.py
@@ -0,0 +1,50 @@
+from django.db.models.signals import post_save
+from django.dispatch import receiver
+from loguru import logger
+from safedelete.signals import post_softdelete
+
+from app_users.models import AppUser
+from .models import Workspace, WorkspaceMembership, WorkspaceRole
+
+
+@receiver(post_save, sender=AppUser)
+def add_user_existing_workspace(instance: AppUser, **kwargs):
+ """
+ if the domain name matches
+ """
+ if not instance.email:
+ return
+
+ email_domain = instance.email.split("@")[1]
+ workspace = Workspace.objects.filter(domain_name=email_domain).first()
+ if not workspace:
+ return
+
+ if instance.received_invitations.exists():
+ # user has some existing invitations
+ return
+
+ workspace_owner = workspace.memberships.filter(role=WorkspaceRole.OWNER).first()
+ if not workspace_owner:
+ logger.warning(
+ f"Workspace {workspace} has no owner. Skipping auto-accept for user {instance}"
+ )
+ return
+
+ workspace.invite_user(
+ invitee_email=instance.email,
+ inviter=workspace_owner.user,
+ role=WorkspaceRole.MEMBER,
+ auto_accept=not instance.workspace_memberships.exists(), # auto-accept only if user has no existing memberships
+ )
+
+
+@receiver(post_softdelete, sender=WorkspaceMembership)
+def delete_workspace_if_no_members_left(instance: WorkspaceMembership, **kwargs):
+ if instance.workspace.memberships.exists():
+ return
+
+ logger.info(
+ f"Deleting workspace {instance.workspace} because it has no members left"
+ )
+ instance.workspace.delete()
diff --git a/orgs/tasks.py b/workspaces/tasks.py
similarity index 67%
rename from orgs/tasks.py
rename to workspaces/tasks.py
index 09258c9ec..bdfd416ed 100644
--- a/orgs/tasks.py
+++ b/workspaces/tasks.py
@@ -10,20 +10,20 @@
@app.task
def send_invitation_email(invitation_pk: int):
- from orgs.models import OrgInvitation
+ from workspaces.models import WorkspaceInvitation
- invitation = OrgInvitation.objects.get(pk=invitation_pk)
+ invitation = WorkspaceInvitation.objects.get(pk=invitation_pk)
assert invitation.status == invitation.Status.PENDING
logger.info(
- f"Sending inviation email to {invitation.invitee_email} for org {invitation.org}..."
+ f"Sending inviation email to {invitation.invitee_email} for workspace {invitation.workspace}..."
)
send_email_via_postmark(
to_address=invitation.invitee_email,
from_address=settings.SUPPORT_EMAIL,
- subject=f"[Gooey.AI] Invitation to join {invitation.org.name}",
- html_body=templates.get_template("org_invitation_email.html").render(
+ subject=f"[Gooey.AI] Invitation to join {invitation.workspace.name}",
+ html_body=templates.get_template("workspace_invitation_email.html").render(
settings=settings,
invitation=invitation,
),
@@ -37,10 +37,10 @@ def send_invitation_email(invitation_pk: int):
@app.task
def send_auto_accepted_email(invitation_pk: int):
- from orgs.models import OrgInvitation
- from routers.account import orgs_route
+ from workspaces.models import WorkspaceInvitation
+ from routers.account import workspaces_route
- invitation = OrgInvitation.objects.get(pk=invitation_pk)
+ invitation = WorkspaceInvitation.objects.get(pk=invitation_pk)
assert invitation.auto_accepted and invitation.status == invitation.Status.ACCEPTED
assert invitation.status_changed_by
@@ -50,19 +50,19 @@ def send_auto_accepted_email(invitation_pk: int):
return
logger.info(
- f"Sending auto-accepted email to {user.email} for org {invitation.org}..."
+ f"Sending auto-accepted email to {user.email} for workspace {invitation.workspace}..."
)
send_email_via_postmark(
to_address=user.email,
from_address=settings.SUPPORT_EMAIL,
subject=f"[Gooey.AI] You've been added to a new team!",
html_body=templates.get_template(
- "org_invitation_auto_accepted_email.html"
+ "workspace_invitation_auto_accepted_email.html"
).render(
settings=settings,
user=user,
- org=invitation.org,
- orgs_url=get_app_route_url(orgs_route),
+ workspace=invitation.workspace,
+ workspaces_url=get_app_route_url(workspaces_route),
),
message_stream="outbound",
)
diff --git a/orgs/tests.py b/workspaces/tests.py
similarity index 100%
rename from orgs/tests.py
rename to workspaces/tests.py
diff --git a/orgs/views.py b/workspaces/views.py
similarity index 64%
rename from orgs/views.py
rename to workspaces/views.py
index 494bac72a..75121d1cd 100644
--- a/orgs/views.py
+++ b/workspaces/views.py
@@ -5,43 +5,44 @@
import gooey_gui as gui
from django.core.exceptions import ValidationError
+from .models import Workspace, WorkspaceInvitation, WorkspaceMembership, WorkspaceRole
from app_users.models import AppUser
-from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole
from daras_ai_v2 import icons
from daras_ai_v2.fastapi_tricks import get_route_path
-DEFAULT_ORG_LOGO = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/74a37c52-8260-11ee-a297-02420a0001ee/gooey.ai%20-%20A%20pop%20art%20illustration%20of%20robots%20taki...y%20Liechtenstein%20mint%20colour%20is%20main%20city%20Seattle.png"
+DEFAULT_WORKSPACE_LOGO = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/74a37c52-8260-11ee-a297-02420a0001ee/gooey.ai%20-%20A%20pop%20art%20illustration%20of%20robots%20taki...y%20Liechtenstein%20mint%20colour%20is%20main%20city%20Seattle.png"
rounded_border = "w-100 border shadow-sm rounded py-4 px-3"
-def invitation_page(user: AppUser, invitation: OrgInvitation):
- from routers.account import orgs_route
+def invitation_page(user: AppUser, invitation: WorkspaceInvitation):
+ from routers.account import workspaces_route
- orgs_page_path = get_route_path(orgs_route)
+ workspaces_page_path = get_route_path(workspaces_route)
with gui.div(className="text-center my-5"):
gui.write(
- f"# Invitation to join {invitation.org.name}", className="d-block mb-5"
+ f"# Invitation to join {invitation.workspace.name}",
+ className="d-block mb-5",
)
- if invitation.org.memberships.filter(user=user).exists():
- # redirect to org page
- raise gui.RedirectException(orgs_page_path)
+ if invitation.workspace.memberships.filter(user=user).exists():
+ # redirect to workspace page
+ raise gui.RedirectException(workspaces_page_path)
- if invitation.status != OrgInvitation.Status.PENDING:
+ if invitation.status != WorkspaceInvitation.Status.PENDING:
gui.write(f"This invitation has been {invitation.get_status_display()}.")
return
gui.write(
- f"**{format_user_name(invitation.inviter)}** has invited you to join **{invitation.org.name}**."
+ f"**{format_user_name(invitation.inviter)}** has invited you to join **{invitation.workspace.name}**."
)
- if other_m := user.org_memberships.first():
+ if other_m := user.workspace_memberships.first():
gui.caption(
- f"You are currently a member of [{other_m.org.name}]({orgs_page_path}). You will be removed from that team if you accept this invitation."
+ f"You are currently a member of [{other_m.workspace.name}]({workspaces_page_path}). You will be removed from that team if you accept this invitation."
)
accept_label = "Leave and Accept"
else:
@@ -56,58 +57,61 @@ def invitation_page(user: AppUser, invitation: OrgInvitation):
if accept_button:
invitation.accept(user=user)
- raise gui.RedirectException(orgs_page_path)
+ raise gui.RedirectException(workspaces_page_path)
if reject_button:
invitation.reject(user=user)
-def orgs_page(user: AppUser):
- memberships = user.org_memberships.filter()
+def workspaces_page(user: AppUser):
+ memberships = user.workspace_memberships.filter()
if not memberships:
- gui.write("*You're not part of an organization yet... Create one?*")
+ gui.write("*You're not part of an workspaceanization yet... Create one?*")
- render_org_creation_view(user)
+ render_workspace_creation_view(user)
else:
- # only support one org for now
- render_org_by_membership(memberships.first())
+ # only support one workspace for now
+ render_workspace_by_membership(memberships.first())
-def render_org_by_membership(membership: OrgMembership):
+def render_workspace_by_membership(membership: WorkspaceMembership):
"""
membership object has all the information we need:
- - org
+ - workspace
- current user
- - current user's role in the org (and other metadata)
+ - current user's role in the workspace (and other metadata)
"""
- org = membership.org
+ workspace = membership.workspace
current_user = membership.user
with gui.div(
className="d-xs-block d-sm-flex flex-row-reverse justify-content-between"
):
with gui.div(className="d-flex justify-content-center align-items-center"):
- if membership.can_edit_org_metadata():
- org_edit_modal = gui.Modal("Edit Org", key="edit-org-modal")
- if org_edit_modal.is_open():
- with org_edit_modal.container():
- render_org_edit_view_by_membership(
- membership, modal=org_edit_modal
+ if membership.can_edit_workspace_metadata():
+ workspace_edit_modal = gui.Modal(
+ "Edit Workspace", key="edit-workspace-modal"
+ )
+ if workspace_edit_modal.is_open():
+ with workspace_edit_modal.container():
+ render_workspace_edit_view_by_membership(
+ membership, modal=workspace_edit_modal
)
if gui.button(f"{icons.edit} Edit", type="secondary"):
- org_edit_modal.open()
+ workspace_edit_modal.open()
with gui.div(className="d-flex align-items-center"):
gui.image(
- org.logo or DEFAULT_ORG_LOGO,
+ workspace.logo or DEFAULT_WORKSPACE_LOGO,
className="my-0 me-4 rounded",
style={"width": "128px", "height": "128px", "object-fit": "contain"},
)
with gui.div(className="d-flex flex-column justify-content-center"):
- gui.write(f"# {org.name}")
- if org.domain_name:
+ gui.write(f"# {workspace.name}")
+ if workspace.domain_name:
gui.write(
- f"Org Domain: `@{org.domain_name}`", className="text-muted"
+ f"Workspace Domain: `@{workspace.domain_name}`",
+ className="text-muted",
)
with gui.div(className="mt-4"):
@@ -122,38 +126,44 @@ def render_org_by_membership(membership: OrgMembership):
if invite_modal.is_open():
with invite_modal.container():
render_invite_creation_view(
- org=org, inviter=current_user, modal=invite_modal
+ workspace=workspace,
+ inviter=current_user,
+ modal=invite_modal,
)
- render_members_list(org=org, current_member=membership)
+ render_members_list(workspace=workspace, current_member=membership)
with gui.div(className="mt-4"):
- render_pending_invitations_list(org=org, current_member=membership)
+ render_pending_invitations_list(workspace=workspace, current_member=membership)
with gui.div(className="mt-4"):
- org_leave_modal = gui.Modal("Leave Org", key="leave-org-modal")
- if org_leave_modal.is_open():
- with org_leave_modal.container():
- render_org_leave_view_by_membership(membership, modal=org_leave_modal)
+ workspace_leave_modal = gui.Modal(
+ "Leave Workspace", key="leave-workspace-modal"
+ )
+ if workspace_leave_modal.is_open():
+ with workspace_leave_modal.container():
+ render_workspace_leave_view_by_membership(
+ membership, modal=workspace_leave_modal
+ )
with gui.div(className="text-end"):
- leave_org = gui.button(
+ leave_workspace = gui.button(
"Leave",
className="btn btn-theme bg-danger border-danger text-white",
)
- if leave_org:
- org_leave_modal.open()
+ if leave_workspace:
+ workspace_leave_modal.open()
-def render_org_creation_view(user: AppUser):
- gui.write(f"# {icons.company} Create an Org", unsafe_allow_html=True)
- org_fields = render_org_create_or_edit_form()
+def render_workspace_creation_view(user: AppUser):
+ gui.write(f"# {icons.company} Create an Workspace", unsafe_allow_html=True)
+ workspace_fields = render_workspace_create_or_edit_form()
if gui.button("Create"):
try:
- Org.objects.create_org(
+ Workspace.objects.create_workspace(
created_by=user,
- **org_fields,
+ **workspace_fields,
)
except ValidationError as e:
gui.write(", ".join(e.messages), className="text-danger")
@@ -161,50 +171,54 @@ def render_org_creation_view(user: AppUser):
gui.rerun()
-def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: gui.Modal):
- org = membership.org
- render_org_create_or_edit_form(org=org)
+def render_workspace_edit_view_by_membership(
+ membership: WorkspaceMembership, *, modal: gui.Modal
+):
+ workspace = membership.workspace
+ render_workspace_create_or_edit_form(workspace=workspace)
if gui.button("Save", className="w-100", type="primary"):
try:
- org.full_clean()
+ workspace.full_clean()
except ValidationError as e:
# newlines in markdown
gui.write(" \n".join(e.messages), className="text-danger")
else:
- org.save()
+ workspace.save()
modal.close()
- if membership.can_delete_org() or membership.can_transfer_ownership():
+ if membership.can_delete_workspace() or membership.can_transfer_ownership():
gui.write("---")
render_danger_zone_by_membership(membership)
-def render_danger_zone_by_membership(membership: OrgMembership):
+def render_danger_zone_by_membership(membership: WorkspaceMembership):
gui.write("### Danger Zone", className="d-block my-2")
- if membership.can_delete_org():
- org_deletion_modal = gui.Modal("Delete Organization", key="delete-org-modal")
- if org_deletion_modal.is_open():
- with org_deletion_modal.container():
- render_org_deletion_view_by_membership(
- membership, modal=org_deletion_modal
+ if membership.can_delete_workspace():
+ workspace_deletion_modal = gui.Modal(
+ "Delete Workspaceanization", key="delete-workspace-modal"
+ )
+ if workspace_deletion_modal.is_open():
+ with workspace_deletion_modal.container():
+ render_workspace_deletion_view_by_membership(
+ membership, modal=workspace_deletion_modal
)
with gui.div(className="d-flex justify-content-between align-items-center"):
- gui.write("Delete Organization")
+ gui.write("Delete Workspaceanization")
if gui.button(
f"{icons.delete} Delete",
className="btn btn-theme py-2 bg-danger border-danger text-white",
):
- org_deletion_modal.open()
+ workspace_deletion_modal.open()
-def render_org_deletion_view_by_membership(
- membership: OrgMembership, *, modal: gui.Modal
+def render_workspace_deletion_view_by_membership(
+ membership: WorkspaceMembership, *, modal: gui.Modal
):
gui.write(
- f"Are you sure you want to delete **{membership.org.name}**? This action is irreversible."
+ f"Are you sure you want to delete **{membership.workspace.name}**? This action is irreversible."
)
with gui.div(className="d-flex"):
@@ -216,34 +230,37 @@ def render_org_deletion_view_by_membership(
if gui.button(
"Delete", className="btn btn-theme bg-danger border-danger text-light w-50"
):
- membership.org.delete()
+ membership.workspace.delete()
modal.close()
-def render_org_leave_view_by_membership(
- current_member: OrgMembership, *, modal: gui.Modal
+def render_workspace_leave_view_by_membership(
+ current_member: WorkspaceMembership, *, modal: gui.Modal
):
- org = current_member.org
+ workspace = current_member.workspace
- gui.write("Are you sure you want to leave this organization?")
+ gui.write("Are you sure you want to leave this workspaceanization?")
new_owner = None
- if current_member.role == OrgRole.OWNER and org.memberships.count() == 1:
+ if (
+ current_member.role == WorkspaceRole.OWNER
+ and workspace.memberships.count() == 1
+ ):
gui.caption(
"You are the only member. You will lose access to this team if you leave."
)
elif (
- current_member.role == OrgRole.OWNER
- and org.memberships.filter(role=OrgRole.OWNER).count() == 1
+ current_member.role == WorkspaceRole.OWNER
+ and workspace.memberships.filter(role=WorkspaceRole.OWNER).count() == 1
):
members_by_uid = {
m.user.uid: m
- for m in org.memberships.all().select_related("user")
+ for m in workspace.memberships.all().select_related("user")
if m != current_member
}
gui.caption(
- "You are the only owner of this organization. Please choose another member to promote to owner."
+ "You are the only owner of this workspaceanization. Please choose another member to promote to owner."
)
new_owner_uid = gui.selectbox(
"New Owner",
@@ -262,13 +279,13 @@ def render_org_leave_view_by_membership(
"Leave", className="btn btn-theme bg-danger border-danger text-light w-50"
):
if new_owner:
- new_owner.role = OrgRole.OWNER
+ new_owner.role = WorkspaceRole.OWNER
new_owner.save()
current_member.delete()
modal.close()
-def render_members_list(org: Org, current_member: OrgMembership):
+def render_members_list(workspace: Workspace, current_member: WorkspaceMembership):
with gui.tag("table", className="table table-responsive"):
with gui.tag("thead"), gui.tag("tr"):
with gui.tag("th", scope="col"):
@@ -281,7 +298,7 @@ def render_members_list(org: Org, current_member: OrgMembership):
gui.html("")
with gui.tag("tbody"):
- for m in org.memberships.all().order_by("created_at"):
+ for m in workspace.memberships.all().order_by("created_at"):
with gui.tag("tr"):
with gui.tag("td"):
name = format_user_name(
@@ -300,9 +317,11 @@ def render_members_list(org: Org, current_member: OrgMembership):
render_membership_actions(m, current_member=current_member)
-def render_membership_actions(m: OrgMembership, current_member: OrgMembership):
+def render_membership_actions(
+ m: WorkspaceMembership, current_member: WorkspaceMembership
+):
if current_member.can_change_role(m):
- if m.role == OrgRole.MEMBER:
+ if m.role == WorkspaceRole.MEMBER:
modal, confirmed = button_with_confirmation_modal(
f"{icons.admin} Make Admin",
key=f"promote-member-{m.pk}",
@@ -312,10 +331,10 @@ def render_membership_actions(m: OrgMembership, current_member: OrgMembership):
modal_key=f"promote-member-{m.pk}-modal",
)
if confirmed:
- m.role = OrgRole.ADMIN
+ m.role = WorkspaceRole.ADMIN
m.save()
modal.close()
- elif m.role == OrgRole.ADMIN:
+ elif m.role == WorkspaceRole.ADMIN:
modal, confirmed = button_with_confirmation_modal(
f"{icons.remove_user} Revoke Admin",
key=f"demote-member-{m.pk}",
@@ -325,7 +344,7 @@ def render_membership_actions(m: OrgMembership, current_member: OrgMembership):
modal_key=f"demote-member-{m.pk}-modal",
)
if confirmed:
- m.role = OrgRole.MEMBER
+ m.role = WorkspaceRole.MEMBER
m.save()
modal.close()
@@ -334,7 +353,7 @@ def render_membership_actions(m: OrgMembership, current_member: OrgMembership):
f"{icons.remove_user} Remove",
key=f"remove-member-{m.pk}",
unsafe_allow_html=True,
- confirmation_text=f"Are you sure you want to remove **{format_user_name(m.user)}** from **{m.org.name}**?",
+ confirmation_text=f"Are you sure you want to remove **{format_user_name(m.user)}** from **{m.workspace.name}**?",
modal_title="Remove Member",
modal_key=f"remove-member-{m.pk}-modal",
className="bg-danger border-danger text-light",
@@ -382,8 +401,12 @@ def button_with_confirmation_modal(
return modal, False
-def render_pending_invitations_list(org: Org, *, current_member: OrgMembership):
- pending_invitations = org.invitations.filter(status=OrgInvitation.Status.PENDING)
+def render_pending_invitations_list(
+ workspace: Workspace, *, current_member: WorkspaceMembership
+):
+ pending_invitations = workspace.invitations.filter(
+ status=WorkspaceInvitation.Status.PENDING
+ )
if not pending_invitations:
return
@@ -419,7 +442,9 @@ def render_pending_invitations_list(org: Org, *, current_member: OrgMembership):
render_invitation_actions(invite, current_member=current_member)
-def render_invitation_actions(invitation: OrgInvitation, current_member: OrgMembership):
+def render_invitation_actions(
+ invitation: WorkspaceInvitation, current_member: WorkspaceMembership
+):
if current_member.can_invite() and invitation.can_resend_email():
modal, confirmed = button_with_confirmation_modal(
f"{icons.email} Resend",
@@ -453,20 +478,23 @@ def render_invitation_actions(invitation: OrgInvitation, current_member: OrgMemb
modal.close()
-def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal):
+def render_invite_creation_view(
+ workspace: Workspace, inviter: AppUser, modal: gui.Modal
+):
email = gui.text_input("Email")
- if org.domain_name:
+ if workspace.domain_name:
gui.caption(
- f"Users with `@{org.domain_name}` email will be added automatically."
+ f"Users with `@{workspace.domain_name}` email will be added automatically."
)
if gui.button(f"{icons.add_user} Invite", type="primary", unsafe_allow_html=True):
try:
- org.invite_user(
+ workspace.invite_user(
invitee_email=email,
inviter=inviter,
- role=OrgRole.MEMBER,
- auto_accept=org.domain_name.lower() == email.split("@")[1].lower(),
+ role=WorkspaceRole.MEMBER,
+ auto_accept=workspace.domain_name.lower()
+ == email.split("@")[1].lower(),
)
except ValidationError as e:
gui.write(", ".join(e.messages), className="text-danger")
@@ -474,24 +502,28 @@ def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal):
modal.close()
-def render_org_create_or_edit_form(org: Org | None = None) -> AttrDict | Org:
- org_proxy = org or AttrDict()
+def render_workspace_create_or_edit_form(
+ workspace: Workspace | None = None,
+) -> AttrDict | Workspace:
+ workspace_proxy = workspace or AttrDict()
- org_proxy.name = gui.text_input("Team Name", value=org and org.name or "")
- org_proxy.logo = gui.file_uploader(
- "Logo", accept=["image/*"], value=org and org.logo or ""
+ workspace_proxy.name = gui.text_input(
+ "Team Name", value=workspace and workspace.name or ""
+ )
+ workspace_proxy.logo = gui.file_uploader(
+ "Logo", accept=["image/*"], value=workspace and workspace.logo or ""
)
- org_proxy.domain_name = gui.text_input(
+ workspace_proxy.domain_name = gui.text_input(
"Domain Name (Optional)",
placeholder="e.g. gooey.ai",
- value=org and org.domain_name or "",
+ value=workspace and workspace.domain_name or "",
)
- if org_proxy.domain_name:
+ if workspace_proxy.domain_name:
gui.caption(
- f"Invite any user with `@{org_proxy.domain_name}` email to this organization."
+ f"Invite any user with `@{workspace_proxy.domain_name}` email to this workspaceanization."
)
- return org_proxy
+ return workspace_proxy
def format_user_name(user: AppUser, current_user: AppUser | None = None):
From 555cfafc29cf2e06378b695a0f07b9345cb3f7cc Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 4 Sep 2024 18:43:37 +0530
Subject: [PATCH 070/110] fix: /v1/balance API should return balance from
personal workspace
---
routers/api.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/routers/api.py b/routers/api.py
index 5d2b4e42d..59d3e95b7 100644
--- a/routers/api.py
+++ b/routers/api.py
@@ -434,7 +434,8 @@ class BalanceResponse(BaseModel):
@app.get("/v1/balance/", response_model=BalanceResponse, tags=["Misc"])
def get_balance(user: AppUser = Depends(api_auth_header)):
- return BalanceResponse(balance=user.balance)
+ workspace, _ = user.get_or_create_personal_workspace()
+ return BalanceResponse(balance=workspace.balance)
@app.get("/status")
From 2edea041e84993328c95bb055aba5a7db367cd94 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 4 Sep 2024 19:01:41 +0530
Subject: [PATCH 071/110] remove useless debug logging
---
payments/webhooks.py | 8 --------
1 file changed, 8 deletions(-)
diff --git a/payments/webhooks.py b/payments/webhooks.py
index 2c1820065..bff97390f 100644
--- a/payments/webhooks.py
+++ b/payments/webhooks.py
@@ -194,15 +194,12 @@ def handle_subscription_cancelled(cls, uid: str):
@classmethod
def handle_invoice_failed(cls, uid: str, data: dict):
- logger.info(f"Invoice failed: {data}")
-
if stripe.Charge.list(payment_intent=data["payment_intent"], limit=1).has_more:
# we must have already sent an invoice for this to the user. so we should just ignore this event
logger.info("Charge already exists for this payment intent")
return
if data.get("metadata", {}).get("auto_recharge"):
- logger.info("auto recharge failed... sending invoice email")
send_payment_failed_email_with_invoice.delay(
uid=uid,
invoice_url=data["hosted_invoice_url"],
@@ -210,17 +207,12 @@ def handle_invoice_failed(cls, uid: str, data: dict):
subject="Payment failure on your Gooey.AI auto-recharge",
)
elif data.get("subscription_details", {}):
- print("subscription failed")
send_payment_failed_email_with_invoice.delay(
uid=uid,
invoice_url=data["hosted_invoice_url"],
dollar_amt=data["amount_due"] / 100,
subject="Payment failure on your Gooey.AI subscription",
)
- else:
- print("not auto recharge or subscription")
- print(f"{data.get('metadata')=}")
- return
def add_balance_for_payment(
From 8d29766dd969df4b0b837b1de4d910667f6651f3 Mon Sep 17 00:00:00 2001
From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com>
Date: Wed, 4 Sep 2024 19:05:37 +0530
Subject: [PATCH 072/110] cleanup: remove base_email.html template
---
templates/base_email.html | 22 --------
.../off_session_payment_failed_email.html | 51 ++++++++++++-------
2 files changed, 32 insertions(+), 41 deletions(-)
delete mode 100644 templates/base_email.html
diff --git a/templates/base_email.html b/templates/base_email.html
deleted file mode 100644
index 63ab8b012..000000000
--- a/templates/base_email.html
+++ /dev/null
@@ -1,22 +0,0 @@
-
-
-
-
-
-
-
- {% block title %}{% endblock title %}
-
- {% block head %}{% endblock head %}
-
-
-
-
-
-