From 4ea8759140025be285dab4f4daaedbeb3d3d3e8c Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Tue, 28 May 2024 16:38:51 -0700 Subject: [PATCH 001/110] per second pricing and 10 second free tier limit --- daras_ai_v2/lipsync_api.py | 4 +-- recipes/Lipsync.py | 50 ++++++++++++++++++++++++-------------- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/daras_ai_v2/lipsync_api.py b/daras_ai_v2/lipsync_api.py index 8269f092f..a502fcb12 100644 --- a/daras_ai_v2/lipsync_api.py +++ b/daras_ai_v2/lipsync_api.py @@ -10,8 +10,8 @@ class LipsyncModel(Enum): - Wav2Lip = "Rudrabha/Wav2Lip" - SadTalker = "OpenTalker/SadTalker" + Wav2Lip = "SD (Rudrabha/Wav2Lip)" + SadTalker = "HD (OpenTalker/SadTalker)" class SadTalkerSettings(BaseModel): diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 1efd41fe6..763342bae 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -3,6 +3,7 @@ import requests from pydantic import BaseModel +from daras_ai_v2 import settings import gooey_ui as st from bots.models import Workflow from daras_ai_v2.base import BasePage @@ -11,12 +12,26 @@ from daras_ai_v2.lipsync_settings_widgets import lipsync_settings, LipsyncModel from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.pydantic_validation import FieldHttpUrl +from daras_ai_v2.redis_cache import redis_cache_decorator -CREDITS_PER_MB = 2 +CREDITS_PER_MINUTE = 36 DEFAULT_LIPSYNC_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7fc4d302-9402-11ee-98dc-02420a0001ca/Lip%20Sync.jpg.png" +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def get_audio_duration(audio_url: str) -> float: + import soundfile as sf + import tempfile + + with tempfile.NamedTemporaryFile(suffix=audio_url.split(".")[-1]) as tfile: + tfile.write(requests.get(audio_url).content) + tfile.flush() + f = sf.SoundFile(tfile.name) + seconds = len(f) / f.samplerate + return seconds + + class LipsyncPage(BasePage): title = "Lip Syncing" explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f33e6332-88d8-11ee-89f9-02420a000169/Lipsync%20TTS.png.png" @@ -62,9 +77,16 @@ def render_form_v2(self): ) def validate_form_v2(self): - assert st.session_state.get("input_audio"), "Please provide an Audio file" + input_audio = st.session_state.get("input_audio") + assert input_audio, "Please provide an Audio file" assert st.session_state.get("input_face"), "Please provide an Input Face" + # free users can only use <10 seconds of audio + if not self.is_current_user_paying() and not self.is_current_user_admin(): + assert ( + get_audio_duration(input_audio) < 10 + ), "Free users can only use audio files less than 10 seconds long" + def render_settings(self): lipsync_settings(st.session_state.get("selected_model")) @@ -119,27 +141,19 @@ def preview_description(self, state: dict) -> str: def get_cost_note(self) -> str | None: multiplier = ( - 3 - if st.session_state.get("lipsync_model") == LipsyncModel.SadTalker.name + 2 + if st.session_state.get("selected_model") == LipsyncModel.SadTalker.name else 1 ) - return f"{CREDITS_PER_MB * multiplier} credits per MB" + return f"{CREDITS_PER_MINUTE * multiplier}/minute" def get_raw_price(self, state: dict) -> float: - total_bytes = 0 + from math import ceil input_audio = state.get("input_audio") - if input_audio: - r = requests.head(input_audio) - total_bytes += float(r.headers.get("Content-length") or "1") - - input_face = state.get("input_face") - if input_face: - r = requests.head(input_face) - total_bytes += float(r.headers.get("Content-length") or "1") - - total_mb = total_bytes / 1024 / 1024 + seconds = get_audio_duration(input_audio) if input_audio else 0 + seconds = ceil(seconds / 5) * 5 # round up to nearest 5 seconds multiplier = ( - 3 if state.get("lipsync_model") == LipsyncModel.SadTalker.name else 1 + 2 if state.get("selected_model") == LipsyncModel.SadTalker.name else 1 ) - return total_mb * CREDITS_PER_MB * multiplier + return seconds * CREDITS_PER_MINUTE * multiplier / 60 From ab0735701d03f7b826a6d91953b020c63602ec27 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Tue, 28 May 2024 16:44:56 -0700 Subject: [PATCH 002/110] rename --- daras_ai_v2/lipsync_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/daras_ai_v2/lipsync_api.py b/daras_ai_v2/lipsync_api.py index a502fcb12..ab3bacbdb 100644 --- a/daras_ai_v2/lipsync_api.py +++ b/daras_ai_v2/lipsync_api.py @@ -10,8 +10,8 @@ class LipsyncModel(Enum): - Wav2Lip = "SD (Rudrabha/Wav2Lip)" - SadTalker = "HD (OpenTalker/SadTalker)" + Wav2Lip = "SD" + SadTalker = "HD: SadTalker" class SadTalkerSettings(BaseModel): From b8832849d3dcf0b52e4753c3ee255ecb4e98986b Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Tue, 28 May 2024 17:07:54 -0700 Subject: [PATCH 003/110] clip audio instead of disallowing it --- recipes/Lipsync.py | 50 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 763342bae..2d28f57b5 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -29,6 +29,7 @@ def get_audio_duration(audio_url: str) -> float: tfile.flush() f = sf.SoundFile(tfile.name) seconds = len(f) / f.samplerate + f.close() return seconds @@ -68,6 +69,18 @@ def render_form_v2(self): """, key="input_audio", ) + input_audio = st.session_state.get("input_audio") + if ( + input_audio + and not self.is_current_user_paying() + and not self.is_current_user_admin() + and get_audio_duration(input_audio) > 10 + ): + st.error( + "Audio duration is greater than 10 seconds and will be clipped. Please upgrade to process longer audio files.", + icon="⚠️", + color="orange", + ) enum_selector( LipsyncModel, @@ -81,11 +94,38 @@ def validate_form_v2(self): assert input_audio, "Please provide an Audio file" assert st.session_state.get("input_face"), "Please provide an Input Face" - # free users can only use <10 seconds of audio - if not self.is_current_user_paying() and not self.is_current_user_admin(): - assert ( - get_audio_duration(input_audio) < 10 - ), "Free users can only use audio files less than 10 seconds long" + # cut the audio to <=10 seconds if user is not paying + if ( + not self.is_current_user_paying() + and not self.is_current_user_admin() + and get_audio_duration(input_audio) > 10 + ): + import soundfile as sf + import tempfile + from daras_ai.image_input import upload_file_from_bytes + + with tempfile.NamedTemporaryFile( + suffix="." + input_audio.split(".")[-1] + ) as src: + src.write(requests.get(input_audio).content) + src.flush() + f = sf.SoundFile(src.name) + with tempfile.NamedTemporaryFile( + suffix="." + input_audio.split(".")[-1] + ) as dst: + clip = sf.SoundFile( + dst.name, mode="w", samplerate=f.samplerate, channels=f.channels + ) + clip.write(f.read(10 * f.samplerate)) + clip.flush() + clip.close() + input_audio = upload_file_from_bytes( + filename="clipped_audio.wav", + data=dst.read(), + content_type="audio/wav", + ) + st.session_state["input_audio"] = input_audio + f.close() def render_settings(self): lipsync_settings(st.session_state.get("selected_model")) From 5faaf7eabe63f0e469bf58ce8e87827bb6d27b72 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Tue, 28 May 2024 17:08:53 -0700 Subject: [PATCH 004/110] poetry lock --- poetry.lock | 46 +++++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 1 + 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 9382b7274..3e1d8ac0c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2917,6 +2917,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -4464,6 +4474,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4471,8 +4482,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4489,6 +4508,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4496,6 +4516,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -5302,6 +5323,29 @@ files = [ {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, ] +[[package]] +name = "soundfile" +version = "0.12.1" +description = "An audio library based on libsndfile, CFFI and NumPy" +optional = false +python-versions = "*" +files = [ + {file = "soundfile-0.12.1-py2.py3-none-any.whl", hash = "sha256:828a79c2e75abab5359f780c81dccd4953c45a2c4cd4f05ba3e233ddf984b882"}, + {file = "soundfile-0.12.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d922be1563ce17a69582a352a86f28ed8c9f6a8bc951df63476ffc310c064bfa"}, + {file = "soundfile-0.12.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:bceaab5c4febb11ea0554566784bcf4bc2e3977b53946dda2b12804b4fe524a8"}, + {file = "soundfile-0.12.1-py2.py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:2dc3685bed7187c072a46ab4ffddd38cef7de9ae5eb05c03df2ad569cf4dacbc"}, + {file = "soundfile-0.12.1-py2.py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:074247b771a181859d2bc1f98b5ebf6d5153d2c397b86ee9e29ba602a8dfe2a6"}, + {file = "soundfile-0.12.1-py2.py3-none-win32.whl", hash = "sha256:59dfd88c79b48f441bbf6994142a19ab1de3b9bb7c12863402c2bc621e49091a"}, + {file = "soundfile-0.12.1-py2.py3-none-win_amd64.whl", hash = "sha256:0d86924c00b62552b650ddd28af426e3ff2d4dc2e9047dae5b3d8452e0a49a77"}, + {file = "soundfile-0.12.1.tar.gz", hash = "sha256:e8e1017b2cf1dda767aef19d2fd9ee5ebe07e050d430f77a0a7c66ba08b8cdae"}, +] + +[package.dependencies] +cffi = ">=1.0" + +[package.extras] +numpy = ["numpy"] + [[package]] name = "soupsieve" version = "2.5" @@ -6384,4 +6428,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "a4efb36ab8d78f27caa79189d06a8b977f20990654f9d5b4a096fe654465e3c5" +content-hash = "9efe4bf5265c71ecdaf899750132aee05acdc0f48f06a0d38545723749f17742" diff --git a/pyproject.toml b/pyproject.toml index 0e08956d2..bc4bdd455 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ emoji = "^2.10.1" pyvespa = "^0.39.0" anthropic = "^0.25.5" azure-cognitiveservices-speech = "^1.37.0" +soundfile = "^0.12.1" [tool.poetry.group.dev.dependencies] watchdog = "^2.1.9" From a4447592899148a24e684483d520c5d5d015f2da Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Tue, 2 Jul 2024 10:05:51 -0700 Subject: [PATCH 005/110] rename lipsync models --- daras_ai_v2/lipsync_api.py | 4 ++-- pyproject.toml | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/daras_ai_v2/lipsync_api.py b/daras_ai_v2/lipsync_api.py index ab3bacbdb..72443bcca 100644 --- a/daras_ai_v2/lipsync_api.py +++ b/daras_ai_v2/lipsync_api.py @@ -10,8 +10,8 @@ class LipsyncModel(Enum): - Wav2Lip = "SD" - SadTalker = "HD: SadTalker" + Wav2Lip = "SD: Fast but low-res" + SadTalker = "HD (SadTalker): Hi-res but slow" class SadTalkerSettings(BaseModel): diff --git a/pyproject.toml b/pyproject.toml index 9762105b6..d0b85af5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,6 @@ emoji = "^2.10.1" pyvespa = "^0.39.0" anthropic = "^0.25.5" azure-cognitiveservices-speech = "^1.37.0" -soundfile = "^0.12.1" [tool.poetry.group.dev.dependencies] watchdog = "^2.1.9" From 50066ac5dd62f6ecb38221eef0991e42bcb7bd66 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Tue, 2 Jul 2024 10:08:56 -0700 Subject: [PATCH 006/110] revert poetry lock --- poetry.lock | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/poetry.lock b/poetry.lock index 69dc42aa8..a9a6f950a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2928,16 +2928,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -4489,7 +4479,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4497,16 +4486,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4523,7 +4504,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4531,7 +4511,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, From 008a2a56421513129d50f81093140b4c7c4c62d1 Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Tue, 2 Jul 2024 11:18:50 -0700 Subject: [PATCH 007/110] truncate audio on gpu --- daras_ai_v2/gpu_server.py | 24 ++++++++++- daras_ai_v2/lipsync_api.py | 36 ++++++++++++---- recipes/Lipsync.py | 88 ++++++++++---------------------------- 3 files changed, 73 insertions(+), 75 deletions(-) diff --git a/daras_ai_v2/gpu_server.py b/daras_ai_v2/gpu_server.py index 1a59b12a5..d99bab574 100644 --- a/daras_ai_v2/gpu_server.py +++ b/daras_ai_v2/gpu_server.py @@ -42,6 +42,26 @@ def call_celery_task_outfile( content_type: str | None, filename: str, num_outputs: int = 1, +): + links, _ = call_celery_task_outfile_with_ret( + task_name, + pipeline=pipeline, + inputs=inputs, + content_type=content_type, + filename=filename, + num_outputs=num_outputs, + ) + return links + + +def call_celery_task_outfile_with_ret( + task_name: str, + *, + pipeline: dict, + inputs: dict, + content_type: str | None, + filename: str, + num_outputs: int = 1, ): blobs = [storage_blob_for(filename) for i in range(num_outputs)] pipeline["upload_urls"] = [ @@ -55,8 +75,8 @@ def call_celery_task_outfile( ) for blob in blobs ] - call_celery_task(task_name, pipeline=pipeline, inputs=inputs) - return [blob.public_url for blob in blobs] + ret = call_celery_task(task_name, pipeline=pipeline, inputs=inputs) + return [blob.public_url for blob in blobs], ret _app = None diff --git a/daras_ai_v2/lipsync_api.py b/daras_ai_v2/lipsync_api.py index 72443bcca..0971c9b52 100644 --- a/daras_ai_v2/lipsync_api.py +++ b/daras_ai_v2/lipsync_api.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from daras_ai_v2.exceptions import UserError, GPUError -from daras_ai_v2.gpu_server import call_celery_task_outfile +from daras_ai_v2.gpu_server import call_celery_task_outfile_with_ret from daras_ai_v2.pydantic_validation import FieldHttpUrl @@ -44,22 +44,40 @@ class LipsyncSettings(BaseModel): sadtalker_settings: SadTalkerSettings = None -def run_sadtalker(settings: SadTalkerSettings, face: str, audio: str): - return call_celery_task_outfile( +def run_sadtalker( + settings: SadTalkerSettings, + face: str, + audio: str, + truncate_to_seconds: float | None = None, +) -> tuple[str, float]: + links, metadata = call_celery_task_outfile_with_ret( "lipsync.sadtalker", pipeline=dict( model_id="SadTalker_V0.0.2_512.safetensors", preprocess=settings.preprocess, ), - inputs=settings.dict() | dict(source_image=face, driven_audio=audio), + inputs=settings.dict() + | dict( + source_image=face, + driven_audio=audio, + truncate_to_seconds=truncate_to_seconds, + ), content_type="video/mp4", filename=f"gooey.ai lipsync.mp4", - )[0] + ) + + return links[0], metadata["output"]["duration_sec"] -def run_wav2lip(*, face: str, audio: str, pads: tuple[int, int, int, int]) -> bytes: +def run_wav2lip( + *, + face: str, + audio: str, + pads: tuple[int, int, int, int], + truncate_to_seconds: float | None = None, +) -> tuple[str, float]: try: - return call_celery_task_outfile( + links, metadata = call_celery_task_outfile_with_ret( "wav2lip", pipeline=dict( model_id="wav2lip_gan.pth", @@ -72,10 +90,12 @@ def run_wav2lip(*, face: str, audio: str, pads: tuple[int, int, int, int]) -> by # "out_height": 480, # "smooth": True, # "fps": 25, + truncate_to_seconds=truncate_to_seconds, ), content_type="video/mp4", filename=f"gooey.ai lipsync.mp4", - )[0] + ) + return links[0], metadata["output"]["duration_sec"] except ValueError as e: msg = "\n\n".join(e.args).lower() if "unsupported" in msg: diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 2d28f57b5..31f5c6b49 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -3,7 +3,6 @@ import requests from pydantic import BaseModel -from daras_ai_v2 import settings import gooey_ui as st from bots.models import Workflow from daras_ai_v2.base import BasePage @@ -12,27 +11,12 @@ from daras_ai_v2.lipsync_settings_widgets import lipsync_settings, LipsyncModel from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.pydantic_validation import FieldHttpUrl -from daras_ai_v2.redis_cache import redis_cache_decorator CREDITS_PER_MINUTE = 36 DEFAULT_LIPSYNC_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7fc4d302-9402-11ee-98dc-02420a0001ca/Lip%20Sync.jpg.png" -@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) -def get_audio_duration(audio_url: str) -> float: - import soundfile as sf - import tempfile - - with tempfile.NamedTemporaryFile(suffix=audio_url.split(".")[-1]) as tfile: - tfile.write(requests.get(audio_url).content) - tfile.flush() - f = sf.SoundFile(tfile.name) - seconds = len(f) / f.samplerate - f.close() - return seconds - - class LipsyncPage(BasePage): title = "Lip Syncing" explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f33e6332-88d8-11ee-89f9-02420a000169/Lipsync%20TTS.png.png" @@ -47,6 +31,8 @@ class RequestModel(LipsyncSettings, BaseModel): class ResponseModel(BaseModel): output_video: FieldHttpUrl + seconds: float + truncated: bool = False def preview_image(self, state: dict) -> str | None: return DEFAULT_LIPSYNC_META_IMG @@ -69,18 +55,6 @@ def render_form_v2(self): """, key="input_audio", ) - input_audio = st.session_state.get("input_audio") - if ( - input_audio - and not self.is_current_user_paying() - and not self.is_current_user_admin() - and get_audio_duration(input_audio) > 10 - ): - st.error( - "Audio duration is greater than 10 seconds and will be clipped. Please upgrade to process longer audio files.", - icon="⚠️", - color="orange", - ) enum_selector( LipsyncModel, @@ -94,50 +68,24 @@ def validate_form_v2(self): assert input_audio, "Please provide an Audio file" assert st.session_state.get("input_face"), "Please provide an Input Face" - # cut the audio to <=10 seconds if user is not paying - if ( - not self.is_current_user_paying() - and not self.is_current_user_admin() - and get_audio_duration(input_audio) > 10 - ): - import soundfile as sf - import tempfile - from daras_ai.image_input import upload_file_from_bytes - - with tempfile.NamedTemporaryFile( - suffix="." + input_audio.split(".")[-1] - ) as src: - src.write(requests.get(input_audio).content) - src.flush() - f = sf.SoundFile(src.name) - with tempfile.NamedTemporaryFile( - suffix="." + input_audio.split(".")[-1] - ) as dst: - clip = sf.SoundFile( - dst.name, mode="w", samplerate=f.samplerate, channels=f.channels - ) - clip.write(f.read(10 * f.samplerate)) - clip.flush() - clip.close() - input_audio = upload_file_from_bytes( - filename="clipped_audio.wav", - data=dst.read(), - content_type="audio/wav", - ) - st.session_state["input_audio"] = input_audio - f.close() - def render_settings(self): lipsync_settings(st.session_state.get("selected_model")) def run(self, state: dict) -> typing.Iterator[str | None]: request = self.RequestModel.parse_obj(state) + if self.is_current_user_paying() or self.is_current_user_admin(): + truncate_to_seconds = None + state["truncated"] = False + else: + truncate_to_seconds = 10 + state["truncated"] = True + model = LipsyncModel[request.selected_model] yield f"Running {model.value}..." match model: case LipsyncModel.Wav2Lip: - state["output_video"] = run_wav2lip( + state["output_video"], state["seconds"] = run_wav2lip( face=request.input_face, audio=request.input_audio, pads=( @@ -146,12 +94,14 @@ def run(self, state: dict) -> typing.Iterator[str | None]: request.face_padding_left or 0, request.face_padding_right or 0, ), + truncate_to_seconds=truncate_to_seconds, ) case LipsyncModel.SadTalker: - state["output_video"] = run_sadtalker( + state["output_video"], state["seconds"] = run_sadtalker( request.sadtalker_settings, face=request.input_face, audio=request.input_audio, + truncate_to_seconds=truncate_to_seconds, ) def render_example(self, state: dict): @@ -161,6 +111,12 @@ def render_example(self, state: dict): st.video(output_video, autoplay=True, show_download_button=True) else: st.div() + if state.get("truncated"): + st.error( + "Audio durations greater than 10 seconds will be truncated for free users. Please upgrade to process longer audio files.", + icon="⚠️", + color="orange", + ) def render_output(self): self.render_example(st.session_state) @@ -190,10 +146,12 @@ def get_cost_note(self) -> str | None: def get_raw_price(self, state: dict) -> float: from math import ceil - input_audio = state.get("input_audio") - seconds = get_audio_duration(input_audio) if input_audio else 0 + seconds = self.get_duration(state) seconds = ceil(seconds / 5) * 5 # round up to nearest 5 seconds multiplier = ( 2 if state.get("selected_model") == LipsyncModel.SadTalker.name else 1 ) return seconds * CREDITS_PER_MINUTE * multiplier / 60 + + def get_duration(self, state: dict) -> float: + return state.get("seconds", 0.0) From eaca3a3a816767b2b8997838fab68d85daf7967e Mon Sep 17 00:00:00 2001 From: Alexander Metzger Date: Tue, 2 Jul 2024 13:09:28 -0700 Subject: [PATCH 008/110] default value for seconds --- recipes/Lipsync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 31f5c6b49..471dbfe9f 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -31,7 +31,7 @@ class RequestModel(LipsyncSettings, BaseModel): class ResponseModel(BaseModel): output_video: FieldHttpUrl - seconds: float + seconds: float = 0 truncated: bool = False def preview_image(self, state: dict) -> str | None: From 0a0eec8bf6c40570eda28366608a6e9734269974 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 27 Aug 2024 19:59:35 +0530 Subject: [PATCH 009/110] fix subscription upgrade on indian cards: cancel & create new subscription on failure --- daras_ai_v2/billing.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 0eb6fa37a..0254a89fc 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -1,4 +1,5 @@ import gooey_gui as gui +import sentry_sdk import stripe from django.core.exceptions import ValidationError @@ -227,12 +228,22 @@ def _render_plan_action_button( ): modal.open() if confirmed: - change_subscription( - user, - plan, - # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time - billing_cycle_anchor="now", - ) + try: + change_subscription( + user, + plan, + # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time + billing_cycle_anchor="now", + payment_behavior="error_if_incomplete", + ) + except (stripe.CardError, stripe.InvalidRequestError) as e: + if isinstance(e, stripe.InvalidRequestError): + sentry_sdk.capture_exception(e) + + # only handle error if it's related to mandates + # cancel current subscription & redirect user to new subscription page + user.subscription.cancel() + stripe_subscription_create(user=user, plan=plan) else: modal, confirmed = confirm_modal( title="Downgrade Plan", From bd6e730e8eff49ffcd7b1abc54da38ee8d1cfc39 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 28 Aug 2024 09:56:39 +0530 Subject: [PATCH 010/110] feat: send notification when payment of auto-recharge/susbcription invoice fails --- daras_ai_v2/send_email.py | 2 +- daras_ai_v2/settings.py | 2 +- payments/tasks.py | 33 +++++++++++++++++ payments/webhooks.py | 35 ++++++++++++++++++- routers/stripe.py | 2 ++ templates/base_email.html | 22 ++++++++++++ .../off_session_payment_failed_email.html | 25 +++++++++++++ 7 files changed, 118 insertions(+), 3 deletions(-) create mode 100644 templates/base_email.html create mode 100644 templates/off_session_payment_failed_email.html diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py index a9ff1934d..84741c2dc 100644 --- a/daras_ai_v2/send_email.py +++ b/daras_ai_v2/send_email.py @@ -76,7 +76,7 @@ def send_email_via_postmark( html_body: str = "", text_body: str = "", message_stream: typing.Literal[ - "outbound", "gooey-ai-workflows", "announcements" + "outbound", "gooey-ai-workflows", "announcements", "billing" ] = "outbound", ): if is_running_pytest: diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 88815ea5f..501a132bd 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -11,7 +11,6 @@ """ import os -import json from pathlib import Path import sentry_sdk @@ -265,6 +264,7 @@ ADMIN_EMAILS = config("ADMIN_EMAILS", cast=Csv(), default="") SUPPORT_EMAIL = "Gooey.AI Support " SALES_EMAIL = "Gooey.AI Sales " +PAYMENT_EMAIL = "Gooey.AI Payments " SEND_RUN_EMAIL_AFTER_SEC = config("SEND_RUN_EMAIL_AFTER_SEC", 5) DISALLOWED_TITLE_SLUGS = config("DISALLOWED_TITLE_SLUGS", cast=Csv(), default="") + [ diff --git a/payments/tasks.py b/payments/tasks.py index 252064541..7627acfef 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -1,3 +1,6 @@ +from typing import Literal + +import stripe from django.utils import timezone from loguru import logger @@ -39,6 +42,36 @@ def send_monthly_spending_notification_email(user_id: int): user.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) +@app.task +def send_payment_failed_email_with_invoice( + uid: str, + invoice_url: str, + dollar_amt: float, + kind: Literal["subscription", "auto recharge"], +): + from routers.account import account_route + + user = AppUser.objects.get(uid=uid) + if not user.email: + logger.error(f"User doesn't have an email: {user=}") + return + + send_email_via_postmark( + from_address=settings.PAYMENT_EMAIL, + to_address=user.email, + subject=f"Payment failure on your Gooey.AI {kind}", + html_body=templates.get_template( + "off_session_payment_failed_email.html" + ).render( + user=user, + dollar_amt=f"{dollar_amt:.2f}", + invoice_url=invoice_url, + account_url=get_app_route_url(account_route), + ), + message_stream="billing", + ) + + def send_monthly_budget_reached_email(user: AppUser): from routers.account import account_route diff --git a/payments/webhooks.py b/payments/webhooks.py index 0b822cfe7..79c788f19 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -12,7 +12,10 @@ from daras_ai_v2 import paypal from .models import Subscription from .plans import PricingPlan -from .tasks import send_monthly_spending_notification_email +from .tasks import ( + send_monthly_spending_notification_email, + send_payment_failed_email_with_invoice, +) class PaypalWebhookHandler: @@ -189,6 +192,36 @@ def handle_subscription_cancelled(cls, uid: str): external_id=None, ) + @classmethod + def handle_invoice_failed(cls, uid: str, data: dict): + logger.info(f"Invoice failed: {data}") + + if stripe.Charge.list(payment_intent=data["payment_intent"], limit=1).has_more: + # we must have already sent an invoice for this to the user. so we should just ignore this event + logger.info("Charge already exists for this payment intent") + return + + if data.get("metadata", {}).get("auto_recharge"): + logger.info("auto recharge failed... sending invoice email") + send_payment_failed_email_with_invoice.delay( + uid=uid, + invoice_url=data["hosted_invoice_url"], + dollar_amt=data["amount_due"] / 100, + kind="auto recharge", + ) + elif data.get("subscription_details", {}): + print("subscription failed") + send_payment_failed_email_with_invoice.delay( + uid=uid, + invoice_url=data["hosted_invoice_url"], + dollar_amt=data["amount_due"] / 100, + kind="subscription", + ) + else: + print("not auto recharge or subscription") + print(f"{data.get('metadata')=}") + return + def add_balance_for_payment( *, diff --git a/routers/stripe.py b/routers/stripe.py index aa948ba98..7b1534ecc 100644 --- a/routers/stripe.py +++ b/routers/stripe.py @@ -42,6 +42,8 @@ def webhook_received(request: Request, payload: bytes = fastapi_request_body): match event["type"]: case "invoice.paid": StripeWebhookHandler.handle_invoice_paid(uid, data) + case "invoice.payment_failed": + StripeWebhookHandler.handle_invoice_failed(uid, data) case "checkout.session.completed": StripeWebhookHandler.handle_checkout_session_completed(uid, data) case "customer.subscription.created" | "customer.subscription.updated": diff --git a/templates/base_email.html b/templates/base_email.html new file mode 100644 index 000000000..63ab8b012 --- /dev/null +++ b/templates/base_email.html @@ -0,0 +1,22 @@ + + + + + + + + {% block title %}{% endblock title %} + + {% block head %}{% endblock head %} + + + + + +
+ {% block content %}{% endblock content %} +
+ + + + diff --git a/templates/off_session_payment_failed_email.html b/templates/off_session_payment_failed_email.html new file mode 100644 index 000000000..c8624fec5 --- /dev/null +++ b/templates/off_session_payment_failed_email.html @@ -0,0 +1,25 @@ +{% extends 'base_email.html' %} + +{% block title %}Payment failed{% endblock title %} + +{% block content %} +

Hi {{ user.first_name() }},

+ +

We attempted to process your payment for ${{ dollar_amt }} but your payment method was declined.

+ +

+ Please make a payment on Gooey.AI for continued service or update + your payment method on your account. +

+ +

+ + + +

+ +

+ Cheers,
+ The Gooey.AI team +

+{% endblock content %} From c539af542adda4fe82d0f5688b8598131f82716d Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 28 Aug 2024 10:00:35 +0530 Subject: [PATCH 011/110] make subject logic for payment failure notification email simpler --- payments/tasks.py | 7 ++----- payments/webhooks.py | 4 ++-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/payments/tasks.py b/payments/tasks.py index 7627acfef..f0ee87fce 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -1,6 +1,3 @@ -from typing import Literal - -import stripe from django.utils import timezone from loguru import logger @@ -47,7 +44,7 @@ def send_payment_failed_email_with_invoice( uid: str, invoice_url: str, dollar_amt: float, - kind: Literal["subscription", "auto recharge"], + subject: str, ): from routers.account import account_route @@ -59,7 +56,7 @@ def send_payment_failed_email_with_invoice( send_email_via_postmark( from_address=settings.PAYMENT_EMAIL, to_address=user.email, - subject=f"Payment failure on your Gooey.AI {kind}", + subject=subject, html_body=templates.get_template( "off_session_payment_failed_email.html" ).render( diff --git a/payments/webhooks.py b/payments/webhooks.py index 79c788f19..2c1820065 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -207,7 +207,7 @@ def handle_invoice_failed(cls, uid: str, data: dict): uid=uid, invoice_url=data["hosted_invoice_url"], dollar_amt=data["amount_due"] / 100, - kind="auto recharge", + subject="Payment failure on your Gooey.AI auto-recharge", ) elif data.get("subscription_details", {}): print("subscription failed") @@ -215,7 +215,7 @@ def handle_invoice_failed(cls, uid: str, data: dict): uid=uid, invoice_url=data["hosted_invoice_url"], dollar_amt=data["amount_due"] / 100, - kind="subscription", + subject="Payment failure on your Gooey.AI subscription", ) else: print("not auto recharge or subscription") From d456f0cc99dc61db523fca6d697055b877a43223 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 30 Aug 2024 11:29:33 +0530 Subject: [PATCH 012/110] fix google translate test --- tests/test_translation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_translation.py b/tests/test_translation.py index dcf7a2568..3581f69eb 100644 --- a/tests/test_translation.py +++ b/tests/test_translation.py @@ -7,7 +7,7 @@ ( "hi", "Hi Sir Mera khet me mircha ke ped me fal gal Kar gir hai to iske liye ham kon sa dawa de please help me", - "hi sir in my field the fruits of chilli tree are rotting and falling so which medicine should i give for this please help", + "hi sir the fruits on the chilli tree in my field have fallen down what medicine should i give you? please help me", ), ( "hi", From aa6c3dc98fbf50e2171c50702b5541053af2268d Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 30 Aug 2024 16:56:09 +0530 Subject: [PATCH 013/110] include all example urls in sitemap --- routers/root.py | 47 +++++++++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/routers/root.py b/routers/root.py index 93d889b36..b3a636207 100644 --- a/routers/root.py +++ b/routers/root.py @@ -22,7 +22,7 @@ ) from app_users.models import AppUser -from bots.models import Workflow, BotIntegration +from bots.models import Workflow, BotIntegration, PublishedRun from daras_ai.image_input import upload_file_from_bytes, safe_filename from daras_ai_v2 import settings, icons from daras_ai_v2.api_examples_widget import api_example_generator @@ -52,24 +52,39 @@ @app.get("/sitemap.xml/") -async def get_sitemap(): - from daras_ai_v2.all_pages import all_api_pages - +def get_sitemap(): my_sitemap = """ - """ - - all_paths = ["/", "/faq", "/pricing", "/privacy", "/terms", "/team/"] + [ - page.slug_versions[-1] for page in all_api_pages + """ + + all_urls = [ + furl(settings.APP_BASE_URL) / path + for path in [ + "/", + "/faq", + "/pricing", + "/privacy", + "/terms", + "/team", + "/jobs", + "/farmerchat", + "/contact", + "/impact", + "/explore", + "/api", + ] + ] + [ + pr.get_app_url() + for pr in ( + PublishedRun.objects.filter(is_approved_example=True).order_by("workflow") + ) ] - - for path in all_paths: - url = furl(settings.APP_BASE_URL) / path + for url in all_urls: my_sitemap += f""" - {url} - 2022-12-26 - daily - 1.0 - """ + {url} + {datetime.datetime.today().strftime("%Y-%m-%d")} + daily + 1.0 + """ my_sitemap += """""" From 7dfbe6c52110310470b90d78299a202244ac234f Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Sat, 31 Aug 2024 14:00:11 +0530 Subject: [PATCH 014/110] fix: AttributeError: 'NoneType' object has no attribute 'saved_run' -- use root run if saved run has no parent version --- daras_ai_v2/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 787f4ce8b..2233a0803 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1100,7 +1100,7 @@ def get_pr_from_query_params( ) -> PublishedRun | None: if run_id and uid: sr = cls.get_sr_from_query_params(example_id, run_id, uid) - return sr.parent_published_run() + return sr.parent_published_run() or cls.get_root_published_run() elif example_id: return cls.get_published_run(published_run_id=example_id) else: From 8274b3d4513ce4fcc63a125f386fd582d24b029c Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 2 Sep 2024 16:44:27 +0530 Subject: [PATCH 015/110] Add ChatGPT-4o, latest GPT-4o and Gemini 1.5 Flash models, support JSON mode on Gemini and Claude --- daras_ai_v2/language_model.py | 86 +++++++++++++++-- poetry.lock | 94 ++++++++++++++++--- pyproject.toml | 4 +- scripts/init_llm_pricing.py | 37 +++++++- .../0018_alter_modelpricing_model_name.py | 18 ++++ 5 files changed, 211 insertions(+), 28 deletions(-) create mode 100644 usage_costs/migrations/0018_alter_modelpricing_model_name.py diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 533be6410..4e1088bba 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -1,4 +1,5 @@ import base64 +import json import mimetypes import re import typing @@ -75,7 +76,7 @@ class LargeLanguageModels(Enum): # https://platform.openai.com/docs/models/gpt-4o gpt_4_o = LLMSpec( label="GPT-4o (openai)", - model_id=("openai-gpt-4o-prod-eastus2-1", "gpt-4o"), + model_id="gpt-4o-2024-08-06", llm_api=LLMApis.openai, context_window=128_000, price=10, @@ -92,6 +93,14 @@ class LargeLanguageModels(Enum): is_vision_model=True, supports_json=True, ) + chatgpt_4_o = LLMSpec( + label="ChatGPT-4o (openai) 🧪", + model_id="chatgpt-4o-latest", + llm_api=LLMApis.openai, + context_window=128_000, + price=10, + is_vision_model=True, + ) # https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4 gpt_4_turbo_vision = LLMSpec( label="GPT-4 Turbo with Vision (openai)", @@ -232,13 +241,23 @@ class LargeLanguageModels(Enum): ) # https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models + gemini_1_5_flash = LLMSpec( + label="Gemini 1.5 Flash (Google)", + model_id="gemini-1.5-flash", + llm_api=LLMApis.gemini, + context_window=1_048_576, + price=15, + is_vision_model=True, + supports_json=True, + ) gemini_1_5_pro = LLMSpec( label="Gemini 1.5 Pro (Google)", - model_id="gemini-1.5-pro-preview-0409", + model_id="gemini-1.5-pro", llm_api=LLMApis.gemini, - context_window=1_000_000, + context_window=2_097_152, price=15, is_vision_model=True, + supports_json=True, ) gemini_1_pro_vision = LLMSpec( label="Gemini 1.0 Pro Vision (Google)", @@ -280,6 +299,7 @@ class LargeLanguageModels(Enum): context_window=200_000, price=15, is_vision_model=True, + supports_json=True, ) claude_3_opus = LLMSpec( label="Claude 3 Opus [L] (Anthropic)", @@ -288,6 +308,7 @@ class LargeLanguageModels(Enum): context_window=200_000, price=75, is_vision_model=True, + supports_json=True, ) claude_3_sonnet = LLMSpec( label="Claude 3 Sonnet [M] (Anthropic)", @@ -296,6 +317,7 @@ class LargeLanguageModels(Enum): context_window=200_000, price=15, is_vision_model=True, + supports_json=True, ) claude_3_haiku = LLMSpec( label="Claude 3 Haiku [S] (Anthropic)", @@ -304,6 +326,7 @@ class LargeLanguageModels(Enum): context_window=200_000, price=2, is_vision_model=True, + supports_json=True, ) sea_lion_7b_instruct = LLMSpec( @@ -666,6 +689,7 @@ def _run_chat_model( messages=messages, max_output_tokens=min(max_tokens, 1024), # because of Vertex AI limits temperature=temperature, + response_format_type=response_format_type, ) case LLMApis.palm2: if tools: @@ -696,6 +720,7 @@ def _run_chat_model( max_tokens=max_tokens, temperature=temperature, stop=stop, + response_format_type=response_format_type, ) case LLMApis.self_hosted: return [ @@ -785,6 +810,7 @@ def _run_anthropic_chat( max_tokens: int, temperature: float, stop: list[str] | None, + response_format_type: ResponseFormatType | None, ): import anthropic from usage_costs.cost_utils import record_cost_auto @@ -818,6 +844,27 @@ def _run_anthropic_chat( content = get_entry_text(msg) anthropic_msgs.append({"role": role, "content": content}) + if response_format_type == "json_object": + kwargs = dict( + tools=[ + { + "name": "json_output", + "input_schema": { + "type": "object", + "properties": { + "response": { + "type": "object", + "description": "The response to the user's prompt as a JSON object.", + }, + }, + }, + } + ], + tool_choice={"type": "tool", "name": "json_output"}, + ) + else: + kwargs = {} + client = anthropic.Anthropic() response = client.messages.create( model=model, @@ -826,6 +873,7 @@ def _run_anthropic_chat( messages=anthropic_msgs, stop_sequences=stop, temperature=temperature, + **kwargs, ) record_cost_auto( @@ -839,9 +887,21 @@ def _run_anthropic_chat( quantity=response.usage.output_tokens, ) + if response_format_type == "json_object": + for entry in response.content: + if entry.type == "tool_use": + response = entry.input + if isinstance(response, dict): + response = response.get("response", {}) + return [ + { + "role": CHATML_ROLE_ASSISTANT, + "content": json.dumps(response), + } + ] return [ { - "role": CHATML_ROLE_USER, + "role": CHATML_ROLE_ASSISTANT, "content": "".join(entry.text for entry in response.content), } ] @@ -1212,6 +1272,7 @@ def _run_gemini_pro( messages: list[ConversationEntry], max_output_tokens: int, temperature: float, + response_format_type: ResponseFormatType | None, ): contents = [] for entry in messages: @@ -1244,6 +1305,7 @@ def _run_gemini_pro( contents=contents, max_output_tokens=max_output_tokens, temperature=temperature, + response_format_type=response_format_type, ) return [{"role": CHATML_ROLE_ASSISTANT, "content": msg}] @@ -1292,18 +1354,22 @@ def _call_gemini_api( contents: list[dict], max_output_tokens: int, temperature: float, - stop: list[str] = None, + stop: list[str] | None = None, + response_format_type: ResponseFormatType | None = None, ) -> str: session, project = get_google_auth_session() + generation_config = { + "temperature": temperature, + "maxOutputTokens": max_output_tokens, + "stopSequences": stop or [], + } + if response_format_type == "json_object": + generation_config["response_mime_type"] = "application/json" r = session.post( f"https://{settings.GCP_REGION}-aiplatform.googleapis.com/v1/projects/{project}/locations/{settings.GCP_REGION}/publishers/google/models/{model_id}:generateContent", json={ "contents": contents, - "generation_config": { - "temperature": temperature, - "maxOutputTokens": max_output_tokens, - "stopSequences": stop or [], - }, + "generation_config": generation_config, }, ) raise_for_status(r) diff --git a/poetry.lock b/poetry.lock index 73fc4dd20..9950129f5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -205,19 +205,20 @@ vine = ">=5.0.0,<6.0.0" [[package]] name = "anthropic" -version = "0.25.5" +version = "0.34.1" description = "The official Python library for the anthropic API" optional = false python-versions = ">=3.7" files = [ - {file = "anthropic-0.25.5-py3-none-any.whl", hash = "sha256:8665a8aee45be6a1f0664b2f8fd740f5b60d5a88fab62f0e647105d769a5c9dd"}, - {file = "anthropic-0.25.5.tar.gz", hash = "sha256:bc64a17f18a967fae4254bd7464f5c4d39951dacceff22e823434216c4731e38"}, + {file = "anthropic-0.34.1-py3-none-any.whl", hash = "sha256:2fa26710809d0960d970f26cd0be3686437250a481edb95c33d837aa5fa24158"}, + {file = "anthropic-0.34.1.tar.gz", hash = "sha256:69e822bd7a31ec11c2edb85f2147e8f0ee0cfd3288fea70b0ca8808b2f9bf91d"}, ] [package.dependencies] anyio = ">=3.5.0,<5" distro = ">=1.7.0,<2" httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" pydantic = ">=1.9.0,<3" sniffio = "*" tokenizers = ">=0.13.0" @@ -2497,6 +2498,76 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "jiter" +version = "0.5.0" +description = "Fast iterable JSON parser." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jiter-0.5.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b599f4e89b3def9a94091e6ee52e1d7ad7bc33e238ebb9c4c63f211d74822c3f"}, + {file = "jiter-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a063f71c4b06225543dddadbe09d203dc0c95ba352d8b85f1221173480a71d5"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:acc0d5b8b3dd12e91dd184b87273f864b363dfabc90ef29a1092d269f18c7e28"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c22541f0b672f4d741382a97c65609332a783501551445ab2df137ada01e019e"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63314832e302cc10d8dfbda0333a384bf4bcfce80d65fe99b0f3c0da8945a91a"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a25fbd8a5a58061e433d6fae6d5298777c0814a8bcefa1e5ecfff20c594bd749"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:503b2c27d87dfff5ab717a8200fbbcf4714516c9d85558048b1fc14d2de7d8dc"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6d1f3d27cce923713933a844872d213d244e09b53ec99b7a7fdf73d543529d6d"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c95980207b3998f2c3b3098f357994d3fd7661121f30669ca7cb945f09510a87"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:afa66939d834b0ce063f57d9895e8036ffc41c4bd90e4a99631e5f261d9b518e"}, + {file = "jiter-0.5.0-cp310-none-win32.whl", hash = "sha256:f16ca8f10e62f25fd81d5310e852df6649af17824146ca74647a018424ddeccf"}, + {file = "jiter-0.5.0-cp310-none-win_amd64.whl", hash = "sha256:b2950e4798e82dd9176935ef6a55cf6a448b5c71515a556da3f6b811a7844f1e"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d4c8e1ed0ef31ad29cae5ea16b9e41529eb50a7fba70600008e9f8de6376d553"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c6f16e21276074a12d8421692515b3fd6d2ea9c94fd0734c39a12960a20e85f3"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5280e68e7740c8c128d3ae5ab63335ce6d1fb6603d3b809637b11713487af9e6"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:583c57fc30cc1fec360e66323aadd7fc3edeec01289bfafc35d3b9dcb29495e4"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26351cc14507bdf466b5f99aba3df3143a59da75799bf64a53a3ad3155ecded9"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4829df14d656b3fb87e50ae8b48253a8851c707da9f30d45aacab2aa2ba2d614"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a42a4bdcf7307b86cb863b2fb9bb55029b422d8f86276a50487982d99eed7c6e"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04d461ad0aebf696f8da13c99bc1b3e06f66ecf6cfd56254cc402f6385231c06"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e6375923c5f19888c9226582a124b77b622f8fd0018b843c45eeb19d9701c403"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2cec323a853c24fd0472517113768c92ae0be8f8c384ef4441d3632da8baa646"}, + {file = "jiter-0.5.0-cp311-none-win32.whl", hash = "sha256:aa1db0967130b5cab63dfe4d6ff547c88b2a394c3410db64744d491df7f069bb"}, + {file = "jiter-0.5.0-cp311-none-win_amd64.whl", hash = "sha256:aa9d2b85b2ed7dc7697597dcfaac66e63c1b3028652f751c81c65a9f220899ae"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9f664e7351604f91dcdd557603c57fc0d551bc65cc0a732fdacbf73ad335049a"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:044f2f1148b5248ad2c8c3afb43430dccf676c5a5834d2f5089a4e6c5bbd64df"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:702e3520384c88b6e270c55c772d4bd6d7b150608dcc94dea87ceba1b6391248"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:528d742dcde73fad9d63e8242c036ab4a84389a56e04efd854062b660f559544"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8cf80e5fe6ab582c82f0c3331df27a7e1565e2dcf06265afd5173d809cdbf9ba"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:44dfc9ddfb9b51a5626568ef4e55ada462b7328996294fe4d36de02fce42721f"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c451f7922992751a936b96c5f5b9bb9312243d9b754c34b33d0cb72c84669f4e"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:308fce789a2f093dca1ff91ac391f11a9f99c35369117ad5a5c6c4903e1b3e3a"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7f5ad4a7c6b0d90776fdefa294f662e8a86871e601309643de30bf94bb93a64e"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ea189db75f8eca08807d02ae27929e890c7d47599ce3d0a6a5d41f2419ecf338"}, + {file = "jiter-0.5.0-cp312-none-win32.whl", hash = "sha256:e3bbe3910c724b877846186c25fe3c802e105a2c1fc2b57d6688b9f8772026e4"}, + {file = "jiter-0.5.0-cp312-none-win_amd64.whl", hash = "sha256:a586832f70c3f1481732919215f36d41c59ca080fa27a65cf23d9490e75b2ef5"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f04bc2fc50dc77be9d10f73fcc4e39346402ffe21726ff41028f36e179b587e6"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6f433a4169ad22fcb550b11179bb2b4fd405de9b982601914ef448390b2954f3"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad4a6398c85d3a20067e6c69890ca01f68659da94d74c800298581724e426c7e"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6baa88334e7af3f4d7a5c66c3a63808e5efbc3698a1c57626541ddd22f8e4fbf"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ece0a115c05efca597c6d938f88c9357c843f8c245dbbb53361a1c01afd7148"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:335942557162ad372cc367ffaf93217117401bf930483b4b3ebdb1223dbddfa7"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649b0ee97a6e6da174bffcb3c8c051a5935d7d4f2f52ea1583b5b3e7822fbf14"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4be354c5de82157886ca7f5925dbda369b77344b4b4adf2723079715f823989"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5206144578831a6de278a38896864ded4ed96af66e1e63ec5dd7f4a1fce38a3a"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8120c60f8121ac3d6f072b97ef0e71770cc72b3c23084c72c4189428b1b1d3b6"}, + {file = "jiter-0.5.0-cp38-none-win32.whl", hash = "sha256:6f1223f88b6d76b519cb033a4d3687ca157c272ec5d6015c322fc5b3074d8a5e"}, + {file = "jiter-0.5.0-cp38-none-win_amd64.whl", hash = "sha256:c59614b225d9f434ea8fc0d0bec51ef5fa8c83679afedc0433905994fb36d631"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0af3838cfb7e6afee3f00dc66fa24695199e20ba87df26e942820345b0afc566"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:550b11d669600dbc342364fd4adbe987f14d0bbedaf06feb1b983383dcc4b961"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:489875bf1a0ffb3cb38a727b01e6673f0f2e395b2aad3c9387f94187cb214bbf"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b250ca2594f5599ca82ba7e68785a669b352156260c5362ea1b4e04a0f3e2389"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ea18e01f785c6667ca15407cd6dabbe029d77474d53595a189bdc813347218e"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:462a52be85b53cd9bffd94e2d788a09984274fe6cebb893d6287e1c296d50653"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92cc68b48d50fa472c79c93965e19bd48f40f207cb557a8346daa020d6ba973b"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1c834133e59a8521bc87ebcad773608c6fa6ab5c7a022df24a45030826cf10bc"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab3a71ff31cf2d45cb216dc37af522d335211f3a972d2fe14ea99073de6cb104"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cccd3af9c48ac500c95e1bcbc498020c87e1781ff0345dd371462d67b76643eb"}, + {file = "jiter-0.5.0-cp39-none-win32.whl", hash = "sha256:368084d8d5c4fc40ff7c3cc513c4f73e02c85f6009217922d0823a48ee7adf61"}, + {file = "jiter-0.5.0-cp39-none-win_amd64.whl", hash = "sha256:ce03f7b4129eb72f1687fa11300fbf677b02990618428934662406d2a76742a1"}, + {file = "jiter-0.5.0.tar.gz", hash = "sha256:1d916ba875bcab5c5f7d927df998c4cb694d27dceddf3392e58beaf10563368a"}, +] + [[package]] name = "jsonschema" version = "4.19.2" @@ -3377,23 +3448,24 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] [[package]] name = "openai" -version = "1.35.7" +version = "1.43.0" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.35.7-py3-none-any.whl", hash = "sha256:3d1e0b0aac9b0db69a972d36dc7efa7563f8e8d65550b27a48f2a0c2ec207e80"}, - {file = "openai-1.35.7.tar.gz", hash = "sha256:009bfa1504c9c7ef64d87be55936d142325656bbc6d98c68b669d6472e4beb09"}, + {file = "openai-1.43.0-py3-none-any.whl", hash = "sha256:1a748c2728edd3a738a72a0212ba866f4fdbe39c9ae03813508b267d45104abe"}, + {file = "openai-1.43.0.tar.gz", hash = "sha256:e607aff9fc3e28eade107e5edd8ca95a910a4b12589336d3cbb6bfe2ac306b3c"}, ] [package.dependencies] anyio = ">=3.5.0,<5" distro = ">=1.7.0,<2" httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" pydantic = ">=1.9.0,<3" sniffio = "*" tqdm = ">4" -typing-extensions = ">=4.7,<5" +typing-extensions = ">=4.11,<5" [package.extras] datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] @@ -5845,13 +5917,13 @@ requests = ">=2.0.0" [[package]] name = "typing-extensions" -version = "4.8.0" +version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, - {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -6466,4 +6538,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "3955eb5901ce23cc6e25cf4d45c9a742d830ea2a63a60000e1cfc1d93c6299a6" +content-hash = "5834cb5e676e83b492e8aec5d9efa15bed653848f6d356c139917e1a1b01e872" diff --git a/pyproject.toml b/pyproject.toml index d44685d2f..90859e035 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ package-mode = false [tool.poetry.dependencies] python = ">=3.10,<3.13" streamlit = "^1.15.2" -openai = "^1.35.7" +openai = "^1.43.0" python-decouple = "^3.6" requests = "^2.28.1" glom = "^22.1.0" @@ -82,7 +82,7 @@ aifail = "^0.3.0" pytest-playwright = "^0.4.3" emoji = "^2.10.1" pyvespa = "^0.39.0" -anthropic = "^0.25.5" +anthropic = "^0.34.1" azure-cognitiveservices-speech = "^1.37.0" twilio = "^9.2.3" sentry-sdk = {version = "1.45.0", extras = ["loguru"]} diff --git a/scripts/init_llm_pricing.py b/scripts/init_llm_pricing.py index 7c22b2c65..a8be03a63 100644 --- a/scripts/init_llm_pricing.py +++ b/scripts/init_llm_pricing.py @@ -19,6 +19,24 @@ def run(): # GPT-4o + llm_pricing_create( + model_id="chatgpt-4o-latest", + model_name=LargeLanguageModels.chatgpt_4_o.name, + unit_cost_input=5, + unit_cost_output=15, + unit_quantity=10**6, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) + llm_pricing_create( + model_id="gpt-4o-2024-08-06", + model_name=LargeLanguageModels.gpt_4_o.name, + unit_cost_input=2.5, + unit_cost_output=10, + unit_quantity=10**6, + provider=ModelProvider.openai, + pricing_url="https://openai.com/pricing", + ) llm_pricing_create( model_id="gpt-4o", model_name=LargeLanguageModels.gpt_4_o.name, @@ -410,13 +428,22 @@ def run(): # Gemini llm_pricing_create( - model_id="gemini-1.5-pro-preview-0409", + model_id="gemini-1.5-flash", + model_name=LargeLanguageModels.gemini_1_5_flash.name, + unit_cost_input=0.075, + unit_cost_output=0.30, + unit_quantity=10**6, + provider=ModelProvider.google, + pricing_url="https://ai.google.dev/pricing", + ) + llm_pricing_create( + model_id="gemini-1.5-pro", model_name=LargeLanguageModels.gemini_1_5_pro.name, - unit_cost_input=0.0025, - unit_cost_output=0.0075, - unit_quantity=1000, + unit_cost_input=3.50, + unit_cost_output=10.50, + unit_quantity=10**6, provider=ModelProvider.google, - pricing_url="https://cloud.google.com/vertex-ai/docs/generative-ai/pricing#text_generation", + pricing_url="https://ai.google.dev/pricing", ) ModelPricing.objects.get_or_create( diff --git a/usage_costs/migrations/0018_alter_modelpricing_model_name.py b/usage_costs/migrations/0018_alter_modelpricing_model_name.py new file mode 100644 index 000000000..58aa73ce4 --- /dev/null +++ b/usage_costs/migrations/0018_alter_modelpricing_model_name.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-09-02 11:14 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('usage_costs', '0017_alter_modelpricing_model_name'), + ] + + operations = [ + migrations.AlterField( + model_name='modelpricing', + name='model_name', + field=models.CharField(choices=[('gpt_4_o', 'GPT-4o (openai)'), ('gpt_4_o_mini', 'GPT-4o-mini (openai)'), ('chatgpt_4_o', 'ChatGPT-4o (openai) 🧪'), ('gpt_4_turbo_vision', 'GPT-4 Turbo with Vision (openai)'), ('gpt_4_vision', 'GPT-4 Vision (openai) 🔻'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai) 🔻'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('gpt_3_5_turbo_instruct', 'GPT-3.5 Instruct (openai) 🔻'), ('llama3_70b', 'Llama 3 70b (Meta AI)'), ('llama_3_groq_70b_tool_use', 'Llama 3 Groq 70b Tool Use'), ('llama3_8b', 'Llama 3 8b (Meta AI)'), ('llama_3_groq_8b_tool_use', 'Llama 3 Groq 8b Tool Use'), ('llama2_70b_chat', 'Llama 2 70b Chat [Deprecated] (Meta AI)'), ('mixtral_8x7b_instruct_0_1', 'Mixtral 8x7b Instruct v0.1 (Mistral)'), ('gemma_2_9b_it', 'Gemma 2 9B (Google)'), ('gemma_7b_it', 'Gemma 7B (Google)'), ('gemini_1_5_flash', 'Gemini 1.5 Flash (Google)'), ('gemini_1_5_pro', 'Gemini 1.5 Pro (Google)'), ('gemini_1_pro_vision', 'Gemini 1.0 Pro Vision (Google)'), ('gemini_1_pro', 'Gemini 1.0 Pro (Google)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('claude_3_5_sonnet', 'Claude 3.5 Sonnet (Anthropic)'), ('claude_3_opus', 'Claude 3 Opus [L] (Anthropic)'), ('claude_3_sonnet', 'Claude 3 Sonnet [M] (Anthropic)'), ('claude_3_haiku', 'Claude 3 Haiku [S] (Anthropic)'), ('sea_lion_7b_instruct', 'SEA-LION-7B-Instruct [Deprecated] (aisingapore)'), ('llama3_8b_cpt_sea_lion_v2_instruct', 'Llama3 8B CPT SEA-LIONv2 Instruct (aisingapore)'), ('sarvam_2b', 'Sarvam 2B (sarvamai)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 [Deprecated] (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 [Deprecated] (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('text_curie_001', 'Curie [Deprecated] (openai)'), ('text_babbage_001', 'Babbage [Deprecated] (openai)'), ('text_ada_001', 'Ada [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'DALL·E 2 (OpenAI)'), ('dall_e_3', 'DALL·E 3 (OpenAI)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero)'), ('openjourney', 'Open Journey (PromptHero)'), ('analog_diffusion', 'Analog Diffusion (wavymulder)'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('deepfloyd_if', 'DeepFloyd IF [Deprecated] (stability.ai)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('instruct_pix2pix', '✨ InstructPix2Pix (Tim Brooks)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐢'), ('openjourney', 'Open Journey (PromptHero) 🐢'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐢'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐢'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('runway_ml', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('wav2lip', 'LipSync (wav2lip)')], help_text='The name of the model. Only used for Display purposes.', max_length=255), + ), + ] From 813af40c18c75993d8c433cbe8f66c827d5b6fd5 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 2 Sep 2024 17:04:48 +0530 Subject: [PATCH 016/110] better error handling of claude's tool use json output mode --- daras_ai_v2/language_model.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 4e1088bba..d6eb526d2 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -888,17 +888,31 @@ def _run_anthropic_chat( ) if response_format_type == "json_object": + if response.stop_reason == "max_tokens": + raise UserError( + "Claude’s response got cut off due to hitting the max_tokens limit, and the truncated response contains an incomplete tool use block. " + "Please retry the request with a higher max_tokens value to get the full tool use. " + ) from anthropic.AnthropicError( + f"Hit {response.stop_reason=} when generating JSON: {response.content=}" + ) + if response.stop_reason != "tool_use": + raise UserError( + f"Claude was unable to generate a JSON response. Please retry the request with a different prompt, or try a different model." + ) from anthropic.AnthropicError( + f"Failed to generate JSON response: {response.stop_reason=} {response.content}" + ) for entry in response.content: - if entry.type == "tool_use": - response = entry.input - if isinstance(response, dict): - response = response.get("response", {}) - return [ - { - "role": CHATML_ROLE_ASSISTANT, - "content": json.dumps(response), - } - ] + if entry.type != "tool_use": + continue + response = entry.input + if isinstance(response, dict): + response = response.get("response", {}) + return [ + { + "role": CHATML_ROLE_ASSISTANT, + "content": json.dumps(response), + } + ] return [ { "role": CHATML_ROLE_ASSISTANT, From abddbff347c470b226b9ad182538cdcff61d352b Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 2 Sep 2024 17:39:03 +0530 Subject: [PATCH 017/110] Add scraping proxy support for scraping and youtube - Implement `scraping_proxy.py` to configure and manage scraping proxies and fake user agents. - Update `settings.py` to add scraping proxy configuration variables. - Modify `vector_search.py`, `DocExtract.py`, and `SEOSummary.py` to use the new `requests_scraping_kwargs()` function. - Update yt-dlp config with proxy --- celeryapp/tasks.py | 2 +- daras_ai_v2/asr.py | 11 ++- daras_ai_v2/scraping_proxy.py | 44 ++++++++++ daras_ai_v2/settings.py | 5 ++ daras_ai_v2/vector_search.py | 42 +++++---- recipes/DocExtract.py | 159 ++++++++++++++++++++++------------ recipes/DocSearch.py | 17 ++-- recipes/SEOSummary.py | 7 +- 8 files changed, 205 insertions(+), 82 deletions(-) create mode 100644 daras_ai_v2/scraping_proxy.py diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index 221754ddb..c3ae75b52 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -97,7 +97,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False except Exception as e: if isinstance(e, UserError): sentry_level = e.sentry_level - logger.warning(e) + logger.warning("\n".join(map(str, [e, e.__cause__]))) else: sentry_level = "error" traceback.print_exc() diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 96b5f2e70..699e9fe2a 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -5,12 +5,12 @@ import typing from enum import Enum +import gooey_gui as gui import requests import typing_extensions from django.db.models import F from furl import furl -import gooey_gui as gui from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings from daras_ai_v2.azure_asr import azure_asr @@ -31,6 +31,7 @@ from daras_ai_v2.google_asr import gcp_asr_v1 from daras_ai_v2.gpu_server import call_celery_task from daras_ai_v2.redis_cache import redis_cache_decorator +from daras_ai_v2.scraping_proxy import SCRAPING_PROXIES, get_scraping_proxy_cert_path from daras_ai_v2.text_splitter import text_splitter TRANSLATE_BATCH_SIZE = 8 @@ -988,13 +989,19 @@ def download_youtube_to_wav(youtube_url: str) -> bytes: with _yt_dlp_lock, tempfile.TemporaryDirectory() as tmpdir: infile = os.path.join(tmpdir, "infile") outfile = os.path.join(tmpdir, "outfile.wav") + proxy_args = [] + if proxy := SCRAPING_PROXIES.get("https"): + proxy_args += ["--proxy", proxy] + if cert := get_scraping_proxy_cert_path(): + proxy_args += ["--client-certificate-key", cert] # run yt-dlp to download audio call_cmd( - "yt-dlp", + "yt-dlp", "-v", "--no-playlist", "--max-downloads", "1", "--format", "bestaudio", "--output", infile, + *proxy_args, youtube_url, # ignore MaxDownloadsReached - https://github.com/ytdl-org/youtube-dl/blob/a452f9437c8a3048f75fc12f75bcfd3eed78430f/youtube_dl/__init__.py#L468 ok_returncodes={101}, diff --git a/daras_ai_v2/scraping_proxy.py b/daras_ai_v2/scraping_proxy.py new file mode 100644 index 000000000..c6b847b1a --- /dev/null +++ b/daras_ai_v2/scraping_proxy.py @@ -0,0 +1,44 @@ +import random + +import requests +from furl import furl + +from daras_ai_v2 import settings +from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS + +if settings.SCRAPING_PROXY_HOST: + SCRAPING_PROXIES = { + scheme: str( + furl( + scheme="http", + origin=settings.SCRAPING_PROXY_HOST, + username=settings.SCRAPING_PROXY_USERNAME, + password=settings.SCRAPING_PROXY_PASSWORD, + ), + ) + for scheme in ["http", "https"] + } +else: + SCRAPING_PROXIES = {} + + +def get_scraping_proxy_cert_path() -> str | None: + if not settings.SCRAPING_PROXY_CERT_URL: + return None + + path = settings.BASE_DIR / "proxy_ca_crt.pem" + if not path.exists(): + settings.logger.info(f"Downloading proxy cert to {path}") + path.write_bytes(requests.get(settings.SCRAPING_PROXY_CERT_URL).content) + + return str(path) + + +def requests_scraping_kwargs() -> dict: + """Return kwargs for requests library to use scraping proxy and fake user agent.""" + return dict( + headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, + proxies=SCRAPING_PROXIES, + verify=get_scraping_proxy_cert_path(), + # verify=False, + ) diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 97baa9019..e1aca9769 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -397,3 +397,8 @@ TWILIO_ACCOUNT_SID = config("TWILIO_ACCOUNT_SID", "") TWILIO_API_KEY_SID = config("TWILIO_API_KEY_SID", "") TWILIO_API_KEY_SECRET = config("TWILIO_API_KEY_SECRET", "") + +SCRAPING_PROXY_HOST = config("SCRAPING_PROXY_HOST", "") +SCRAPING_PROXY_USERNAME = config("SCRAPING_PROXY_USERNAME", "") +SCRAPING_PROXY_PASSWORD = config("SCRAPING_PROXY_PASSWORD", "") +SCRAPING_PROXY_CERT_URL = config("SCRAPING_PROXY_CERT_URL", "") diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 8c97d7d1d..9fc137413 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -5,13 +5,13 @@ import io import mimetypes import multiprocessing -import random import re import tempfile import typing from functools import partial from time import time +import gooey_gui as gui import numpy as np import requests from django.db import transaction @@ -21,7 +21,6 @@ from loguru import logger from pydantic import BaseModel, Field -import gooey_gui as gui from app_users.models import AppUser from daras_ai.image_input import ( upload_file_from_bytes, @@ -45,7 +44,6 @@ ) from daras_ai_v2.embedding_model import create_embeddings_cached, EmbeddingModels from daras_ai_v2.exceptions import raise_for_status, call_cmd, UserError -from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS from daras_ai_v2.functional import ( flatmap_parallel, map_parallel, @@ -58,6 +56,11 @@ gdrive_metadata, ) from daras_ai_v2.redis_cache import redis_lock +from daras_ai_v2.scraping_proxy import ( + get_scraping_proxy_cert_path, + requests_scraping_kwargs, + SCRAPING_PROXIES, +) from daras_ai_v2.search_ref import ( SearchReference, remove_quotes, @@ -312,11 +315,14 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: total_bytes = int(meta.get("size") or 0) else: try: - r = requests.head( - f_url, - headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, - timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, - ) + if is_user_uploaded_url(f_url): + r = requests.head(f_url) + else: + r = requests.head( + f_url, + timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, + **requests_scraping_kwargs(), + ) raise_for_status(r) except requests.RequestException as e: logger.warning(f"ignore error while downloading {f_url}: {e}") @@ -337,7 +343,7 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: total_bytes = int(r.headers.get("content-length") or 0) # extract filename from url as a fallback if not name: - if is_user_uploaded_url(str(f)): + if is_user_uploaded_url(f_url): name = f.path.segments[-1] else: name = f"{f.host}{f.path}" @@ -359,7 +365,12 @@ def yt_dlp_extract_info(url: str) -> dict: import yt_dlp # https://github.com/yt-dlp/yt-dlp/blob/master/yt_dlp/options.py - params = dict(ignoreerrors=True, check_formats=False) + params = dict( + ignoreerrors=True, + check_formats=False, + proxy=SCRAPING_PROXIES.get("https"), + client_certificate=get_scraping_proxy_cert_path(), + ) with yt_dlp.YoutubeDL(params) as ydl: data = ydl.extract_info(url, download=False) if not data: @@ -666,10 +677,10 @@ def download_content_bytes( return gdrive_download(f, mime_type) try: # download from url - r = requests.get( - f_url, - headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, - ) + if is_user_uploaded_url(f_url): + r = requests.get(f_url) + else: + r = requests.get(f_url, **requests_scraping_kwargs()) raise_for_status(r, is_user_url=is_user_url) except requests.RequestException as e: logger.warning(f"ignore error while downloading {f_url}: {e}") @@ -730,7 +741,8 @@ def any_bytes_to_text_pages_or_df( def is_yt_url(url: str) -> bool: - return "youtube.com" in url or "youtu.be" in url + origin = furl(url).origin + return "youtube.com" in origin or "youtu.be" in origin def pdf_or_tabular_bytes_to_text_pages_or_df( diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 010ed8ad0..80e942c1a 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -1,16 +1,15 @@ -import random +import json import threading import typing -from daras_ai_v2.field_render import field_title_desc -from daras_ai_v2.pydantic_validation import FieldHttpUrl +import gooey_gui as gui +import pandas as pd import requests from aifail import retry_if from django.db.models import IntegerChoices from furl import furl from pydantic import BaseModel, Field -import gooey_gui as gui from bots.models import Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings @@ -29,10 +28,11 @@ from daras_ai_v2.doc_search_settings_widgets import ( bulk_documents_uploader, SUPPORTED_SPREADSHEET_TYPES, + is_user_uploaded_url, ) from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.exceptions import raise_for_status -from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS +from daras_ai_v2.field_render import field_title_desc from daras_ai_v2.functional import ( apply_parallel, flatapply_parallel, @@ -48,13 +48,18 @@ LanguageModelSettings, ) from daras_ai_v2.loom_video_widget import youtube_video +from daras_ai_v2.pydantic_validation import FieldHttpUrl +from daras_ai_v2.scraping_proxy import requests_scraping_kwargs from daras_ai_v2.settings import service_account_key_path from daras_ai_v2.vector_search import ( add_page_number_to_pdf, yt_dlp_get_video_entries, doc_url_to_file_metadata, get_pdf_num_pages, + doc_url_to_text_pages, + doc_or_yt_url_to_metadatas, ) +from files.models import FileMetadata from recipes.DocSearch import render_documents DEFAULT_YOUTUBE_BOT_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ddc8ffac-93fb-11ee-89fb-02420a0001cb/Youtube%20transcripts.jpg.png" @@ -107,32 +112,36 @@ class RequestModel(LanguageModelSettings, RequestModelBase): pass class ResponseModel(BaseModel): - pass + output_documents: list[FieldHttpUrl] | None def preview_image(self, state: dict) -> str | None: return DEFAULT_YOUTUBE_BOT_META_IMG def render_form_v2(self): bulk_documents_uploader( - "#### 🤖 Youtube/PDF/Drive URLs", + "#### 🤖 Youtube/PDF/Drive/Web URLs", accept=("audio/*", "application/pdf", "video/*"), ) gui.text_input( - "#### 📊 Google Sheets URL", + "📊 Google Sheets URL _(optional)_", key="sheet_url", ) def validate_form_v2(self): - assert gui.session_state.get("documents"), "Please enter Youtube/PDF/Drive URLs" - assert gui.session_state.get("sheet_url"), "Please enter a Google Sheet URL" + assert gui.session_state.get("documents"), "Please provide input documents" def preview_description(self, state: dict) -> str: return "Transcribe YouTube videos in any language with Whisper, Google Chirp & more, run your own GPT4 prompt on each transcript and save it all to a Google Sheet. Perfect for making a YouTube-based dataset to create your own chatbot or enterprise copilot (ie. just add the finished Google sheet url to the doc section in https://gooey.ai/copilot)." def render_example(self, state: dict): - render_documents(state) - gui.write("**Google Sheets URL**") - gui.write(state.get("sheet_url")) + if sheet_url := state.get("sheet_url"): + render_documents(state, label="**Input Documents**") + gui.write("**Google Sheets URL**") + gui.write(sheet_url) + else: + render_documents( + state, label="**Output Documents**", key="output_documents" + ) def render_usage_guide(self): youtube_video("p7ZLb-loR_4") @@ -146,7 +155,12 @@ def render_settings(self): selected_model = language_model_selector() language_model_settings(selected_model) - enum_selector(AsrModels, label="##### ASR Model", key="selected_asr_model") + enum_selector( + AsrModels, + label="##### ASR Model", + key="selected_asr_model", + use_selectbox=True, + ) gui.write("---") google_translate_language_selector() @@ -165,34 +179,80 @@ def related_workflows(self) -> list: return [VideoBotsPage, AsrPage, CompareLLMPage, DocSearchPage] - def run(self, state: dict) -> typing.Iterator[str | None]: + def run_v2( + self, + request: "DocExtractPage.RequestModel", + response: "DocExtractPage.ResponseModel", + ): import gspread.utils - request: DocExtractPage.RequestModel = self.RequestModel.parse_obj(state) + if request.sheet_url: + entries = yield from flatapply_parallel( + extract_info, + request.documents, + message="Extracting metadata...", + max_workers=50, + ) - entries = yield from flatapply_parallel( - extract_info, - request.documents, - message="Extracting metadata...", - max_workers=50, - ) + yield "Preparing sheet..." + spreadsheet_id = gspread.utils.extract_id_from_url(request.sheet_url) + ensure_header(spreadsheet_id) + row_numbers = init_sheet(spreadsheet_id, entries) + + yield from apply_parallel( + lambda entry, row: process_entry( + spreadsheet_id=spreadsheet_id, + entry=entry, + row=row, + request=request, + ), + entries, + row_numbers, + max_workers=4, + message="Updating sheet...", + ) + else: + file_url_metas = yield from flatapply_parallel( + doc_or_yt_url_to_metadatas, + request.documents, + message="Extracting metadata...", + ) + file_urls, file_metas = zip(*file_url_metas) + output_documents = yield from apply_parallel( + lambda *args: _doc_extract_and_upload(request, *args), + file_urls, + file_metas, + max_workers=4, + message="Processing documents...", + ) + response.output_documents = list(filter(None, output_documents)) - yield "Preparing sheet..." - spreadsheet_id = gspread.utils.extract_id_from_url(request.sheet_url) - ensure_header(spreadsheet_id) - row_numbers = init_sheet(spreadsheet_id, entries) - - yield from apply_parallel( - lambda entry, row: process_entry( - spreadsheet_id=spreadsheet_id, - entry=entry, - row=row, - request=request, - ), - entries, - row_numbers, - max_workers=4, - message="Updating sheet...", + +def _doc_extract_and_upload( + request: DocExtractPage.RequestModel, f_url: str, file_meta: FileMetadata +) -> str | None: + pages = doc_url_to_text_pages( + f_url=f_url, + file_meta=file_meta, + selected_asr_model=request.selected_asr_model, + ) + if isinstance(pages, pd.DataFrame): + return upload_file_from_bytes( + file_meta.name + ".csv", + pages.to_csv(index=False).encode(), + content_type="text/csv", + ) + elif len(pages) <= 1: + return upload_file_from_bytes( + file_meta.name + ".txt", + "".join(pages).encode(), + content_type="text/plain", + ) + else: + return upload_file_from_bytes( + file_meta.name + ".json", + json.dumps(pages).encode(), + content_type="application/json", ) @@ -294,16 +354,6 @@ def extract_info(url: str) -> list[dict | None]: if is_yt_url(url): return yt_dlp_get_video_entries(url) - # https://github.com/yt-dlp/yt-dlp/blob/master/yt_dlp/options.py - params = dict(ignoreerrors=True, check_formats=False) - with yt_dlp.YoutubeDL(params) as ydl: - data = ydl.extract_info(url, download=False) - if data: - entries = data.get("entries", [data]) - return [e for e in entries if e] - else: - return [{"webpage_url": url, "title": "Youtube Video"}] - # assume it's a direct link file_meta = doc_url_to_file_metadata(url) assert file_meta.mime_type, f"Could not determine mime type for {url}" @@ -316,11 +366,14 @@ def extract_info(url: str) -> list[dict | None]: file_meta.name, f_bytes, content_type=file_meta.mime_type ) else: - r = requests.get( - url, - headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, - timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, - ) + if is_user_uploaded_url(url): + r = requests.get(url) + else: + r = requests.get( + url, + timeout=settings.EXTERNAL_REQUEST_TIMEOUT_SEC, + **requests_scraping_kwargs(), + ) raise_for_status(r, is_user_url=True) f_bytes = r.content content_url = url diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 3bbccef25..04e034dde 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -235,13 +235,16 @@ def render_documents(state, label="**Documents**", *, key="documents"): if not documents: return gui.write(label) - for doc in documents: - if is_user_uploaded_url(doc): - f = furl(doc) - filename = f.path.segments[-1] - else: - filename = doc - gui.write(f"🔗[*{filename}*]({doc})") + with gui.div( + className="overflow-auto bg-light p-2 mb-3", style=dict(maxHeight="200px") + ): + for doc in documents: + if is_user_uploaded_url(doc): + f = furl(doc) + filename = f.path.segments[-1] + else: + filename = doc + gui.write(f"🔗[*{filename}*]({doc})") def render_doc_search_step(state: dict): diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py index 62af754a5..1b4e65365 100644 --- a/recipes/SEOSummary.py +++ b/recipes/SEOSummary.py @@ -1,7 +1,7 @@ -import random import re import typing +import gooey_gui as gui import readability import requests from furl import furl @@ -9,11 +9,9 @@ from loguru import logger from pydantic import BaseModel -import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.exceptions import raise_for_status -from daras_ai_v2.fake_user_agents import FAKE_USER_AGENTS from daras_ai_v2.functional import map_parallel from daras_ai_v2.language_model import ( run_language_model, @@ -26,6 +24,7 @@ LanguageModelSettings, ) from daras_ai_v2.loom_video_widget import youtube_video +from daras_ai_v2.scraping_proxy import requests_scraping_kwargs from daras_ai_v2.scrollable_html_widget import scrollable_html from daras_ai_v2.serp_search import get_links_from_serp_api from daras_ai_v2.serp_search_locations import ( @@ -452,8 +451,8 @@ def html_to_text(text): def _call_summarize_url(url: str) -> tuple[str | None, str | None]: r = requests.get( url, - headers={"User-Agent": random.choice(FAKE_USER_AGENTS)}, timeout=EXTERNAL_REQUEST_TIMEOUT_SEC, + **requests_scraping_kwargs(), ) raise_for_status(r) # we only support HTML for now From f5dc5361db420703ad12299db69a6ef4339f88f8 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Mon, 2 Sep 2024 17:56:20 +0530 Subject: [PATCH 018/110] Add --wrap none to pandoc command to prevent text wrapping Remove accept parameter from bulk_documents_uploader function in DocExtract to allow uploading everything --- daras_ai_v2/vector_search.py | 1 + recipes/DocExtract.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 9fc137413..36de65f29 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -893,6 +893,7 @@ def pandoc_to_text(f_name: str, f_bytes: bytes, to="plain") -> str: "+RTS", f"-M{MAX_PANDOC_MEM_MB}M", "-RTS", "--sandbox", "--standalone", infile.name, + "--wrap", "none", "--to", to, "--output", outfile.name, diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 80e942c1a..323e3eab9 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -120,7 +120,6 @@ def preview_image(self, state: dict) -> str | None: def render_form_v2(self): bulk_documents_uploader( "#### 🤖 Youtube/PDF/Drive/Web URLs", - accept=("audio/*", "application/pdf", "video/*"), ) gui.text_input( "📊 Google Sheets URL _(optional)_", From 6984d1a4704b5d00075be62e02d8bb4bc137864e Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 15 Jul 2024 21:13:01 +0530 Subject: [PATCH 019/110] Add org support with role and UI view --- daras_ai_v2/icons.py | 5 +++++ daras_ai_v2/settings.py | 1 + routers/account.py | 24 ++++++++++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/daras_ai_v2/icons.py b/daras_ai_v2/icons.py index 6ce628f16..2196c2542 100644 --- a/daras_ai_v2/icons.py +++ b/daras_ai_v2/icons.py @@ -15,9 +15,14 @@ copy = '' preview = '' add = '' +time = '' code = '' chat = '' +admin = '' +remove_user = '' +add_user = '' +transfer_ownership = '' # brands github = '' diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 97baa9019..1d1cd28ce 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -63,6 +63,7 @@ "handles", "payments", "functions", + "orgs", ] MIDDLEWARE = [ diff --git a/routers/account.py b/routers/account.py index 89fc1aff2..36dde5d78 100644 --- a/routers/account.py +++ b/routers/account.py @@ -19,6 +19,7 @@ from daras_ai_v2.profiles import edit_user_profile_page from payments.webhooks import PaypalWebhookHandler from routers.root import page_wrapper, get_og_url_path +from orgs.views import orgs_page from routers.custom_api_router import CustomAPIRouter @@ -139,6 +140,24 @@ def api_keys_route(request: Request): ) +@app.post("/orgs/") +@st.route +def orgs_route(request: Request): + with account_page_wrapper(request, AccountTabs.orgs): + orgs_tab(request) + + url = get_og_url_path(request) + return dict( + meta=raw_build_meta_tags( + url=url, + canonical_url=url, + title="Teams • Gooey.AI", + description="Your teams.", + robots="noindex,nofollow", + ) + ) + + class TabData(typing.NamedTuple): title: str route: typing.Callable @@ -149,6 +168,7 @@ class AccountTabs(TabData, Enum): profile = TabData(title=f"{icons.profile} Profile", route=profile_route) saved = TabData(title=f"{icons.save} Saved", route=saved_route) api_keys = TabData(title=f"{icons.api} API Keys", route=api_keys_route) + orgs = TabData(title=f"{icons.company} Teams", route=orgs_route) @property def url_path(self) -> str: @@ -208,6 +228,10 @@ def api_keys_tab(request: Request): manage_api_keys(request.user) +def orgs_tab(request: Request): + orgs_page(request.user) + + @contextmanager def account_page_wrapper(request: Request, current_tab: TabData): if not request.user or request.user.is_anonymous: From 9bb3c9ac38b37b6f80b6b2fb366084b078d5b9ce Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 15 Jul 2024 21:15:32 +0530 Subject: [PATCH 020/110] Add uncommitted orgs files --- orgs/__init__.py | 0 orgs/admin.py | 19 ++ orgs/apps.py | 6 + orgs/migrations/0001_initial.py | 60 +++++ .../0002_org_logo_alter_org_created_by.py | 25 ++ orgs/migrations/0003_orginvitation_role.py | 18 ++ orgs/migrations/0004_org_deleted_at.py | 18 ++ orgs/migrations/__init__.py | 0 orgs/models.py | 211 +++++++++++++++++ orgs/tests.py | 3 + orgs/views.py | 218 ++++++++++++++++++ 11 files changed, 578 insertions(+) create mode 100644 orgs/__init__.py create mode 100644 orgs/admin.py create mode 100644 orgs/apps.py create mode 100644 orgs/migrations/0001_initial.py create mode 100644 orgs/migrations/0002_org_logo_alter_org_created_by.py create mode 100644 orgs/migrations/0003_orginvitation_role.py create mode 100644 orgs/migrations/0004_org_deleted_at.py create mode 100644 orgs/migrations/__init__.py create mode 100644 orgs/models.py create mode 100644 orgs/tests.py create mode 100644 orgs/views.py diff --git a/orgs/__init__.py b/orgs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/orgs/admin.py b/orgs/admin.py new file mode 100644 index 000000000..fb445abd3 --- /dev/null +++ b/orgs/admin.py @@ -0,0 +1,19 @@ +from django.contrib import admin + +from .models import Org, OrgQuerySet, OrgInvitation, OrgMembership + + +@admin.register(Org) +class OrgAdmin(admin.ModelAdmin): + def get_queryset(self, request): + return OrgQuerySet(self.model).all() + + +@admin.register(OrgMembership) +class OrgMembershipAdmin(admin.ModelAdmin): + pass + + +@admin.register(OrgInvitation) +class OrgInvitationAdmin(admin.ModelAdmin): + pass diff --git a/orgs/apps.py b/orgs/apps.py new file mode 100644 index 000000000..70c7fa169 --- /dev/null +++ b/orgs/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class OrgsConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "orgs" diff --git a/orgs/migrations/0001_initial.py b/orgs/migrations/0001_initial.py new file mode 100644 index 000000000..1ed2139eb --- /dev/null +++ b/orgs/migrations/0001_initial.py @@ -0,0 +1,60 @@ +# Generated by Django 4.2.7 on 2024-07-11 17:41 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('app_users', '0017_alter_appuser_subscription'), + ] + + operations = [ + migrations.CreateModel( + name='Org', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('org_id', models.CharField(blank=True, max_length=100, null=True, unique=True)), + ('name', models.CharField(max_length=100)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to='app_users.appuser')), + ], + ), + migrations.CreateModel( + name='OrgInvitation', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('invitee_email', models.EmailField(max_length=254)), + ('status', models.IntegerField(choices=[(1, 'Pending'), (2, 'Accepted'), (3, 'Rejected'), (4, 'Canceled')], default=1)), + ('auto_accepted', models.BooleanField(default=False)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('inviter', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='app_users.appuser')), + ('org', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='invitations', to='orgs.org')), + ], + ), + migrations.CreateModel( + name='OrgMembership', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('role', models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(auto_now=True)), + ('invitation', models.OneToOneField(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='membership', to='orgs.orginvitation')), + ('org', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='memberships', to='orgs.org')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='org_memberships', to='app_users.appuser')), + ], + options={ + 'unique_together': {('org', 'user')}, + }, + ), + migrations.AddField( + model_name='org', + name='members', + field=models.ManyToManyField(related_name='orgs', through='orgs.OrgMembership', to='app_users.appuser'), + ), + ] diff --git a/orgs/migrations/0002_org_logo_alter_org_created_by.py b/orgs/migrations/0002_org_logo_alter_org_created_by.py new file mode 100644 index 000000000..78084ae6a --- /dev/null +++ b/orgs/migrations/0002_org_logo_alter_org_created_by.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.7 on 2024-07-11 17:44 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('app_users', '0017_alter_appuser_subscription'), + ('orgs', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='org', + name='logo', + field=models.URLField(blank=True, null=True), + ), + migrations.AlterField( + model_name='org', + name='created_by', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='app_users.appuser'), + ), + ] diff --git a/orgs/migrations/0003_orginvitation_role.py b/orgs/migrations/0003_orginvitation_role.py new file mode 100644 index 000000000..03ccc44a4 --- /dev/null +++ b/orgs/migrations/0003_orginvitation_role.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-07-11 18:36 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('orgs', '0002_org_logo_alter_org_created_by'), + ] + + operations = [ + migrations.AddField( + model_name='orginvitation', + name='role', + field=models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3), + ), + ] diff --git a/orgs/migrations/0004_org_deleted_at.py b/orgs/migrations/0004_org_deleted_at.py new file mode 100644 index 000000000..a7127a6f7 --- /dev/null +++ b/orgs/migrations/0004_org_deleted_at.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-07-15 13:57 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('orgs', '0003_orginvitation_role'), + ] + + operations = [ + migrations.AddField( + model_name='org', + name='deleted_at', + field=models.DateTimeField(blank=True, default=None, null=True), + ), + ] diff --git a/orgs/migrations/__init__.py b/orgs/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/orgs/models.py b/orgs/models.py new file mode 100644 index 000000000..68105a48f --- /dev/null +++ b/orgs/models.py @@ -0,0 +1,211 @@ +from django.db import models, transaction +from django.core.exceptions import ValidationError +from django.utils import timezone +from django.utils.text import slugify + +from app_users.models import AppUser +from daras_ai_v2.crypto import get_random_doc_id + + +class OrgRole(models.IntegerChoices): + OWNER = 1 + ADMIN = 2 + MEMBER = 3 + + +class OrgQuerySet(models.QuerySet): + def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwargs): + org = self.create( + org_id=org_id or get_random_doc_id(), created_by=created_by, **kwargs + ) + org.members.add( + created_by, + through_defaults={ + "role": OrgRole.OWNER, + }, + ) + return org + + +class OrgManager(models.Manager): + def get_queryset(self): + return OrgQuerySet(self.model, using=self._db).filter(deleted_at__isnull=True) + + +class Org(models.Model): + org_id = models.CharField(max_length=100, null=True, blank=True, unique=True) + + name = models.CharField(max_length=100) + members = models.ManyToManyField( + "app_users.AppUser", + through="OrgMembership", + related_name="orgs", + ) + created_by = models.ForeignKey( + "app_users.appuser", + on_delete=models.CASCADE, + ) + + logo = models.URLField(null=True, blank=True) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + deleted_at = models.DateTimeField(null=True, blank=True, default=None) + + objects = OrgManager() + + def __str__(self): + return self.name + + def get_slug(self): + return slugify(self.name) + + def is_deleted(self): + return self.deleted_at is not None + + def soft_delete(self): + with transaction.atomic(): + for m in self.memberships.all(): + m.delete() + self.deleted_at = timezone.now() + self.save() + + def invite_user( + self, + *, + invitee_email: str, + inviter: AppUser, + role: OrgRole, + auto_accept: bool = True, + ): + """ + auto_accept: If True, the user will be automatically added if they have an account + """ + for member in self.members.all(): + if member.email == invitee_email: + raise ValidationError(f"{member} is already a member of this org") + + for invitation in self.invitations.filter(status=OrgInvitation.Status.PENDING): + if invitation.invitee_email == invitee_email: + raise ValidationError(f"{invitee_email} was already invited") + + invitation = OrgInvitation( + org=self, + invitee_email=invitee_email, + inviter=inviter, + role=role, + ) + invitation.full_clean() + invitation.save() + + if auto_accept: + try: + invitation.accept(auto_accepted=True) + except AppUser.DoesNotExist: + pass + + +class OrgMembership(models.Model): + org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="memberships") + user = models.ForeignKey( + "app_users.AppUser", on_delete=models.CASCADE, related_name="org_memberships" + ) + invitation = models.OneToOneField( + "OrgInvitation", + on_delete=models.SET_NULL, + blank=True, + null=True, + default=None, + related_name="membership", + ) + + role = models.IntegerField(choices=OrgRole.choices, default=OrgRole.MEMBER) + + created_at = models.DateTimeField(auto_now_add=True) # same as joining date + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + unique_together = ("org", "user") + + def __str__(self): + return f"{self.get_role_display()} - {self.user} ({self.org})" + + def can_edit_org_metadata(self): + return self.role in (OrgRole.OWNER, OrgRole.ADMIN) + + def can_delete_org(self): + return self.role == OrgRole.OWNER + + def can_invite(self): + return self.role in (OrgRole.OWNER, OrgRole.ADMIN) + + def has_higher_role_than(self, other: "OrgMembership"): + # creator > owner > admin > member + match other.role: + case OrgRole.OWNER: + return self.org.created_by == OrgRole.OWNER + case OrgRole.ADMIN: + return self.role == OrgRole.OWNER + case OrgRole.MEMBER: + return self.role in (OrgRole.OWNER, OrgRole.ADMIN) + + def can_change_role(self, other: "OrgMembership"): + return self.has_higher_role_than(other) + + def can_kick(self, other: "OrgMembership"): + return self.has_higher_role_than(other) + + def can_transfer_ownership(self, other: "OrgMembership"): + return self.role == OrgRole.OWNER and other.role == OrgRole.ADMIN + + +class OrgInvitation(models.Model): + class Status(models.IntegerChoices): + PENDING = 1 + ACCEPTED = 2 + REJECTED = 3 + CANCELED = 4 + + org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="invitations") + invitee_email = models.EmailField() + inviter = models.ForeignKey("app_users.AppUser", on_delete=models.CASCADE) + + status = models.IntegerField(choices=Status.choices, default=Status.PENDING) + auto_accepted = models.BooleanField(default=False) + role = models.IntegerField(choices=OrgRole.choices, default=OrgRole.MEMBER) + + # TODO: don't spam invitees! + # invitation_email_count = models.IntegerField(default=0) + # last_invitation_sent_at = models.DateTimeField(null=True, blank=True) + + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(auto_now=True) + + def __str__(self): + return f"{self.invitee_email} - {self.org} ({self.get_status_display()})" + + def accept(self, *, auto_accepted: bool = False): + assert self.status == self.Status.PENDING + + invitee = AppUser.objects.get(email=self.invitee_email) + + self.status = self.Status.ACCEPTED + self.auto_accepted = auto_accepted + + with transaction.atomic(): + self.org.members.add( + invitee, + through_defaults={ + "role": self.role, + "invitation": self, + }, + ) + self.save() + + def reject(self): + self.status = self.Status.REJECTED + self.save() + + def cancel(self): + self.status = self.Status.CANCELED + self.save() diff --git a/orgs/tests.py b/orgs/tests.py new file mode 100644 index 000000000..7ce503c2d --- /dev/null +++ b/orgs/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/orgs/views.py b/orgs/views.py new file mode 100644 index 000000000..5164b5828 --- /dev/null +++ b/orgs/views.py @@ -0,0 +1,218 @@ +import html as html_lib +from django.core.exceptions import ValidationError +from django.db import transaction + +import gooey_ui as st +from app_users.models import AppUser +from gooey_ui.components.modal import Modal +from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole +from daras_ai_v2 import icons + + +DEFAULT_ORG_LOGO = "https://seccdn.libravatar.org/avatar/40f8d096a3777232204cb3f796c577b7?s=80&forcedefault=y&default=monsterid" + + +def orgs_page(user: AppUser): + memberships = user.org_memberships.all() + if not memberships: + st.write("*You're not part of an organization yet... Create one?*") + + render_org_creation_view(user) + else: + # only support one org for now + render_org_by_membership(memberships.first()) + + +def render_org_by_membership(membership: OrgMembership): + """ + membership object has all the information we need: + - org + - current user + - current user's role in the org (and other metadata) + """ + org = membership.org + current_user = membership.user + + with st.div(className="d-flex justify-content-between"): + with st.div(className="d-flex align-items-center"): + st.image( + org.logo or DEFAULT_ORG_LOGO, + className="my-0 me-2", + style={"width": "128px", "height": "128px", "object-fit": "contain"}, + ) + st.write(f"# {org.name}") + + if membership.can_edit_org_metadata(): + org_edit_modal = Modal("Edit Org", key="edit-org-modal") + if org_edit_modal.is_open(): + with org_edit_modal.container(): + render_org_edit_view(membership.org, modal=org_edit_modal) + + with st.div(className="d-flex align-items-center"): + if st.button(f"{icons.edit} Edit", className="btn btn-theme"): + org_edit_modal.open() + + with st.div(className="mt-4"): + with st.div(className="d-flex justify-content-between align-items-center"): + st.write("## Members") + + if membership.can_invite(): + invite_modal = Modal("Invite Member", key="invite-member-modal") + if st.button(f"{icons.add_user} Invite Member", type="primary"): + invite_modal.open() + + if invite_modal.is_open(): + with invite_modal.container(): + render_invite_creation_view( + org=org, inviter=current_user, modal=invite_modal + ) + + render_members_list(org=org, current_membership=membership) + + with st.div(className="mt-4"): + render_pending_invitations_list(org=org, current_user=current_user) + + with st.div(className="mt-4"): + st.write("# Danger Zone") + + if membership.role == OrgRole.OWNER: + # Owner can't leave! Only delete + if st.button( + "Delete Org", + className="btn btn-theme bg-danger border-danger text-white", + ): + org.soft_delete() + st.experimental_rerun() + else: + if st.button( + "Leave Org", + className="btn btn-theme bg-danger border-danger text-white", + ): + org.members.remove(current_user) + st.experimental_rerun() + + +def render_org_creation_view(user: AppUser): + st.write(f"# {icons.company} Create an Org", unsafe_allow_html=True) + name = st.text_input("Team Name") + logo = st.file_uploader("Org Logo", accept=["image/*"]) + + if st.button("Create"): + Org.objects.create_org(created_by=user, name=name, logo=logo) + st.experimental_rerun() + + +def render_org_edit_view(org: Org, *, modal: Modal): + org.name = st.text_input("Org Name", value=org.name) + org.logo = st.file_uploader("Org Logo", accept=["image/*"], value=org.logo) + + if st.button("Save"): + try: + org.full_clean() + except ValidationError as e: + # newlines in markdown + st.write(" \n".join(e.messages), className="text-danger") + else: + org.save() + modal.close() + + +def render_members_list(org: Org, current_membership: OrgMembership): + with st.raw_table(["Name", "Role", f"{icons.time} Since", ""]): + for m in org.memberships.all().order_by("created_at"): + name_val = m.user.display_name or m.user.first_name() + if current_membership == m: + name_val += " (You)" + with st.tag("tr"): + with st.tag("td"): + if m.user.handle_id: + with st.link(to=m.user.handle.get_app_url()): + st.html(html_lib.escape(name_val)) + else: + st.html(html_lib.escape(name_val)) + with st.tag("td"): + st.html(html_lib.escape(m.get_role_display())) + with st.tag("td"): + st.html(m.created_at.strftime("%b %d, %Y")) + with st.tag("td", className="text-end"): + render_role_change_buttons(m, current_membership=current_membership) + render_membership_buttons(m, current_membership=current_membership) + + +def render_role_change_buttons(m: OrgMembership, current_membership: OrgMembership): + if current_membership.can_change_role(m): + if m.role == OrgRole.MEMBER: + if st.button( + f"{icons.admin} Make Admin", + className="btn btn-theme btn-sm my-0 py-0", + unsafe_allow_html=True, + ): + m.role = OrgRole.ADMIN + m.save() + st.experimental_rerun() + if m.role == OrgRole.ADMIN: + if st.button( + f"{icons.admin} Remove Admin", + className="btn btn-theme btn-sm my-0 py-0", + unsafe_allow_html=True, + ): + m.role = OrgRole.MEMBER + m.save() + st.experimental_rerun() + + if current_membership.can_transfer_ownership(m): + if st.button( + f"{icons.transfer_ownership} Transfer Ownership", + className="btn btn-theme btn-sm my-0 py-0 bg-danger border-danger text-light", + unsafe_allow_html=True, + ): + m.role = OrgRole.OWNER + current_membership.role = OrgRole.ADMIN + with transaction.atomic(): + m.save() + current_membership.save() + st.experimental_rerun() + + +def render_membership_buttons(m: OrgMembership, current_membership: OrgMembership): + if current_membership.can_kick(m): + if st.button( + f"{icons.remove_user} Remove", + className="btn btn-theme btn-sm my-0 py-0 bg-danger border-danger text-light", + unsafe_allow_html=True, + ): + # TODO: soft-delete + m.delete() + st.experimental_rerun() + + +def render_pending_invitations_list(org: Org, current_user: AppUser): + pending_invitations = org.invitations.filter(status=OrgInvitation.Status.PENDING) + if not pending_invitations: + return + + st.write("## Pending") + table = st.raw_table(["Email", "Invited By", f"{icons.time} Invited on"]) + with table: + for invite in pending_invitations: + inviter_name = invite.inviter.display_name or invite.inviter.first_name() + if current_user == invite.inviter: + inviter_name += " (You)" + st.table_row( + [ + invite.invitee_email, + inviter_name, + invite.created_at.strftime("%b %d, %Y"), + ] + ) + + +def render_invite_creation_view(org: Org, inviter: AppUser, modal: Modal): + email = st.text_input("Email") + + if st.button(f"{icons.add_user} Invite", type="primary", unsafe_allow_html=True): + try: + org.invite_user(invitee_email=email, inviter=inviter, role=OrgRole.MEMBER) + modal.close() + except ValidationError as e: + st.write(", ".join(e.messages), className="text-danger") From 674d92309e087f28c4bda3627b08d369d9990e28 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 17 Jul 2024 18:14:28 +0530 Subject: [PATCH 021/110] Add support for accepting org domain name --- ...ions_alter_org_managers_org_domain_name.py | 30 +++++++++++ orgs/models.py | 52 +++++++++++++++---- 2 files changed, 73 insertions(+), 9 deletions(-) create mode 100644 orgs/migrations/0005_alter_org_options_alter_org_managers_org_domain_name.py diff --git a/orgs/migrations/0005_alter_org_options_alter_org_managers_org_domain_name.py b/orgs/migrations/0005_alter_org_options_alter_org_managers_org_domain_name.py new file mode 100644 index 000000000..0af9f3056 --- /dev/null +++ b/orgs/migrations/0005_alter_org_options_alter_org_managers_org_domain_name.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.7 on 2024-07-16 15:24 + +from django.db import migrations, models +import django.db.models.manager +import orgs.models + + +class Migration(migrations.Migration): + + dependencies = [ + ('orgs', '0004_org_deleted_at'), + ] + + operations = [ + migrations.AlterModelOptions( + name='org', + options={'default_manager_name': 'all_objects'}, + ), + migrations.AlterModelManagers( + name='org', + managers=[ + ('all_objects', django.db.models.manager.Manager()), + ], + ), + migrations.AddField( + model_name='org', + name='domain_name', + field=models.CharField(blank=True, max_length=30, null=True, unique=True, validators=[orgs.models.validate_org_domain_name]), + ), + ] diff --git a/orgs/models.py b/orgs/models.py index 68105a48f..f0a1fee67 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -1,3 +1,5 @@ +import re + from django.db import models, transaction from django.core.exceptions import ValidationError from django.utils import timezone @@ -7,6 +9,19 @@ from daras_ai_v2.crypto import get_random_doc_id +ORG_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$") + + +def validate_org_domain_name(value): + from handles.models import COMMON_EMAIL_DOMAINS + + if not ORG_DOMAIN_NAME_RE.fullmatch(value): + raise ValidationError("Invalid domain name") + + if value in COMMON_EMAIL_DOMAINS: + raise ValidationError("This domain name is reserved") + + class OrgRole(models.IntegerChoices): OWNER = 1 ADMIN = 2 @@ -14,10 +29,19 @@ class OrgRole(models.IntegerChoices): class OrgQuerySet(models.QuerySet): + pass + + +class OrgManager(models.Manager): + def get_queryset(self): + return OrgQuerySet(self.model, using=self._db).filter(deleted_at__isnull=True) + def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwargs): - org = self.create( + org = self.model( org_id=org_id or get_random_doc_id(), created_by=created_by, **kwargs ) + org.full_clean() + org.save() org.members.add( created_by, through_defaults={ @@ -27,11 +51,6 @@ def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwar return org -class OrgManager(models.Manager): - def get_queryset(self): - return OrgQuerySet(self.model, using=self._db).filter(deleted_at__isnull=True) - - class Org(models.Model): org_id = models.CharField(max_length=100, null=True, blank=True, unique=True) @@ -47,14 +66,29 @@ class Org(models.Model): ) logo = models.URLField(null=True, blank=True) + domain_name = models.CharField( + max_length=30, + blank=True, + null=True, + unique=True, + validators=[ + validate_org_domain_name, + ], + ) created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) deleted_at = models.DateTimeField(null=True, blank=True, default=None) - objects = OrgManager() + objects = OrgManager() # only active orgs + all_objects = OrgQuerySet.as_manager() # for internal & admin use + + class Meta: + default_manager_name = "all_objects" def __str__(self): + if self.is_deleted(): + return f"[Deleted] {self.name}" return self.name def get_slug(self): @@ -155,8 +189,8 @@ def can_change_role(self, other: "OrgMembership"): def can_kick(self, other: "OrgMembership"): return self.has_higher_role_than(other) - def can_transfer_ownership(self, other: "OrgMembership"): - return self.role == OrgRole.OWNER and other.role == OrgRole.ADMIN + def can_transfer_ownership(self): + return self.role == OrgRole.OWNER class OrgInvitation(models.Model): From bc0cb85aeebe22f90fbb923d35c1cc973b4dc6df Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 17 Jul 2024 18:45:05 +0530 Subject: [PATCH 022/110] Add delete icon --- daras_ai_v2/icons.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/daras_ai_v2/icons.py b/daras_ai_v2/icons.py index 2196c2542..c5f84193b 100644 --- a/daras_ai_v2/icons.py +++ b/daras_ai_v2/icons.py @@ -10,6 +10,7 @@ camera = '' cancel = '' edit = '' +delete = '' link = '' company = '' copy = '' @@ -22,7 +23,7 @@ admin = '' remove_user = '' add_user = '' -transfer_ownership = '' +transfer = '' # brands github = '' From c07a2f5f332c5b3365d089861cbd901158b6dc95 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 17 Jul 2024 18:46:15 +0530 Subject: [PATCH 023/110] Add confirmation modals for removing members, move danger zone to edit --- orgs/views.py | 299 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 204 insertions(+), 95 deletions(-) diff --git a/orgs/views.py b/orgs/views.py index 5164b5828..5b0d86475 100644 --- a/orgs/views.py +++ b/orgs/views.py @@ -33,24 +33,31 @@ def render_org_by_membership(membership: OrgMembership): org = membership.org current_user = membership.user - with st.div(className="d-flex justify-content-between"): + with st.div( + className="d-xs-block d-sm-flex flex-row-reverse justify-content-between" + ): + with st.div(className="d-flex justify-content-center align-items-center"): + if membership.can_edit_org_metadata(): + org_edit_modal = Modal("Edit Org", key="edit-org-modal") + if org_edit_modal.is_open(): + with org_edit_modal.container(): + render_org_edit_view_by_membership( + membership, modal=org_edit_modal + ) + + if st.button(f"{icons.edit} Edit", type="secondary"): + org_edit_modal.open() + with st.div(className="d-flex align-items-center"): st.image( org.logo or DEFAULT_ORG_LOGO, className="my-0 me-2", style={"width": "128px", "height": "128px", "object-fit": "contain"}, ) - st.write(f"# {org.name}") - - if membership.can_edit_org_metadata(): - org_edit_modal = Modal("Edit Org", key="edit-org-modal") - if org_edit_modal.is_open(): - with org_edit_modal.container(): - render_org_edit_view(membership.org, modal=org_edit_modal) - - with st.div(className="d-flex align-items-center"): - if st.button(f"{icons.edit} Edit", className="btn btn-theme"): - org_edit_modal.open() + with st.div(className="d-flex flex-column justify-content-center"): + st.write(f"# {org.name}") + if org.domain_name: + st.write(f"Domain: `@{org.domain_name}`", className="text-muted") with st.div(className="mt-4"): with st.div(className="d-flex justify-content-between align-items-center"): @@ -73,40 +80,55 @@ def render_org_by_membership(membership: OrgMembership): render_pending_invitations_list(org=org, current_user=current_user) with st.div(className="mt-4"): - st.write("# Danger Zone") + if membership.role != OrgRole.OWNER: + # Owners can't leave! They can only delete + org_leave_modal = Modal("Leave Org", key="leave-org-modal") + if org_leave_modal.is_open(): + with org_leave_modal.container(): + render_org_leave_view_by_membership( + membership, modal=org_leave_modal + ) - if membership.role == OrgRole.OWNER: - # Owner can't leave! Only delete - if st.button( - "Delete Org", - className="btn btn-theme bg-danger border-danger text-white", - ): - org.soft_delete() - st.experimental_rerun() - else: if st.button( "Leave Org", className="btn btn-theme bg-danger border-danger text-white", ): - org.members.remove(current_user) - st.experimental_rerun() + org_leave_modal.open() def render_org_creation_view(user: AppUser): st.write(f"# {icons.company} Create an Org", unsafe_allow_html=True) name = st.text_input("Team Name") - logo = st.file_uploader("Org Logo", accept=["image/*"]) + logo = st.file_uploader("Logo", accept=["image/*"]) + domain_name = st.text_input("Domain Name (Optional)", placeholder="e.g. gooey.ai") + if domain_name: + st.caption(f"Add any user with `@{domain_name}` email to this organization.") if st.button("Create"): - Org.objects.create_org(created_by=user, name=name, logo=logo) - st.experimental_rerun() + try: + Org.objects.create_org( + created_by=user, name=name, logo=logo, domain_name=domain_name + ) + except ValidationError as e: + st.write(", ".join(e.messages), className="text-danger") + else: + st.experimental_rerun() -def render_org_edit_view(org: Org, *, modal: Modal): - org.name = st.text_input("Org Name", value=org.name) - org.logo = st.file_uploader("Org Logo", accept=["image/*"], value=org.logo) +def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: Modal): + org = membership.org - if st.button("Save"): + org.name = st.text_input("Team Name", value=org.name) + org.logo = st.file_uploader("Logo", accept=["image/*"], value=org.logo) + org.domain_name = st.text_input( + "Domain Name (Optional)", placeholder="e.g. gooey.ai", value=org.domain_name + ) + if org.domain_name: + st.caption( + f"Add any user with `@{org.domain_name}` email to this organization." + ) + + if st.button("Save", className="w-100", type="primary"): try: org.full_clean() except ValidationError as e: @@ -116,74 +138,142 @@ def render_org_edit_view(org: Org, *, modal: Modal): org.save() modal.close() + if membership.can_delete_org() or membership.can_transfer_ownership(): + st.write("---") + render_danger_zone_by_membership(membership) -def render_members_list(org: Org, current_membership: OrgMembership): - with st.raw_table(["Name", "Role", f"{icons.time} Since", ""]): - for m in org.memberships.all().order_by("created_at"): - name_val = m.user.display_name or m.user.first_name() - if current_membership == m: - name_val += " (You)" - with st.tag("tr"): - with st.tag("td"): - if m.user.handle_id: - with st.link(to=m.user.handle.get_app_url()): - st.html(html_lib.escape(name_val)) - else: - st.html(html_lib.escape(name_val)) - with st.tag("td"): - st.html(html_lib.escape(m.get_role_display())) - with st.tag("td"): - st.html(m.created_at.strftime("%b %d, %Y")) - with st.tag("td", className="text-end"): - render_role_change_buttons(m, current_membership=current_membership) - render_membership_buttons(m, current_membership=current_membership) - - -def render_role_change_buttons(m: OrgMembership, current_membership: OrgMembership): - if current_membership.can_change_role(m): - if m.role == OrgRole.MEMBER: + +def render_danger_zone_by_membership(membership: OrgMembership): + st.write("### Danger Zone", className="d-block my-2") + + if membership.can_delete_org(): + org_deletion_modal = Modal("Delete Organization", key="delete-org-modal") + if org_deletion_modal.is_open(): + with org_deletion_modal.container(): + render_org_deletion_view_by_membership( + membership, modal=org_deletion_modal + ) + + with st.div(className="d-flex justify-content-between align-items-center"): + st.write("Delete Organization") if st.button( - f"{icons.admin} Make Admin", - className="btn btn-theme btn-sm my-0 py-0", - unsafe_allow_html=True, + f"{icons.delete} Delete", + className="btn btn-theme py-2 bg-danger border-danger text-white", ): - m.role = OrgRole.ADMIN - m.save() - st.experimental_rerun() - if m.role == OrgRole.ADMIN: + org_deletion_modal.open() + + if membership.can_transfer_ownership(): + with st.div(className="d-flex justify-content-between align-items-center"): + st.write("Transfer Ownership") if st.button( - f"{icons.admin} Remove Admin", - className="btn btn-theme btn-sm my-0 py-0", + f"{icons.transfer} Transfer", + className="btn btn-theme py-2 bg-danger border-danger text-light", unsafe_allow_html=True, ): - m.role = OrgRole.MEMBER - m.save() + m.role = OrgRole.OWNER + membership.role = OrgRole.ADMIN + with transaction.atomic(): + m.save() + membership.save() st.experimental_rerun() - if current_membership.can_transfer_ownership(m): + +def render_org_deletion_view_by_membership(membership: OrgMembership, *, modal: Modal): + st.write( + f"Are you sure you want to delete **{membership.org.name}**? This action is irreversible." + ) + + with st.div(className="d-flex"): if st.button( - f"{icons.transfer_ownership} Transfer Ownership", - className="btn btn-theme btn-sm my-0 py-0 bg-danger border-danger text-light", - unsafe_allow_html=True, + "Cancel", type="secondary", className="border-danger text-danger w-50" ): - m.role = OrgRole.OWNER - current_membership.role = OrgRole.ADMIN - with transaction.atomic(): - m.save() - current_membership.save() - st.experimental_rerun() + modal.close() + + if st.button( + "Delete", className="btn btn-theme bg-danger border-danger text-light w-50" + ): + membership.org.soft_delete() + modal.close() + + +def render_org_leave_view_by_membership(membership: OrgMembership, *, modal: Modal): + st.write("Are you sure you want to leave this organization?") + + if st.button("Cancel", type="secondary", className="border-danger text-danger"): + modal.close() + + if st.button("Leave", className="btn btn-theme bg-danger border-danger text-light"): + membership.org.members.remove(membership.user) + modal.close() -def render_membership_buttons(m: OrgMembership, current_membership: OrgMembership): +def render_members_list(org: Org, current_membership: OrgMembership): + with st.tag("table", className="table table-responsive"): + with st.tag("thead"), st.tag("tr"): + with st.tag("th", scope="col"): + st.html("Name") + with st.tag("th", scope="col"): + st.html("Role") + with st.tag("th", scope="col"): + st.html(f"{icons.time} Since") + with st.tag("th", scope="col"): + st.html("") + + with st.tag("tbody"): + for m in org.memberships.all().order_by("created_at"): + with st.tag("tr"): + with st.tag("td"): + name = format_user_name( + m.user, current_user=current_membership.user + ) + if m.user.handle_id: + with st.link(to=m.user.handle.get_app_url()): + st.html(html_lib.escape(name)) + else: + st.html(html_lib.escape(name)) + with st.tag("td"): + st.html(m.get_role_display()) + with st.tag("td"): + st.html(m.created_at.strftime("%b %d, %Y")) + with st.tag("td", className="text-end"): + render_membership_actions( + m, current_membership=current_membership + ) + + +def render_membership_actions(m: OrgMembership, current_membership: OrgMembership): if current_membership.can_kick(m): + member_deletion_modal = Modal( + "Remove Member", key=f"remove-member-{m.pk}-modal" + ) + if member_deletion_modal.is_open(): + with member_deletion_modal.container(): + render_member_deletion_view(m, modal=member_deletion_modal) + if st.button( f"{icons.remove_user} Remove", className="btn btn-theme btn-sm my-0 py-0 bg-danger border-danger text-light", unsafe_allow_html=True, ): - # TODO: soft-delete - m.delete() - st.experimental_rerun() + member_deletion_modal.open() + + +def render_member_deletion_view(membership: OrgMembership, modal: Modal): + st.write( + f"Are you sure you want to remove **{format_user_name(membership.user)}** from **{membership.org.name}**?" + ) + + with st.div(className="d-flex"): + if st.button( + "Cancel", type="secondary", className="border-danger text-danger w-50" + ): + modal.close() + + if st.button( + "Remove", className="btn btn-theme bg-danger border-danger text-light w-50" + ): + membership.delete() + modal.close() def render_pending_invitations_list(org: Org, current_user: AppUser): @@ -192,19 +282,30 @@ def render_pending_invitations_list(org: Org, current_user: AppUser): return st.write("## Pending") - table = st.raw_table(["Email", "Invited By", f"{icons.time} Invited on"]) - with table: - for invite in pending_invitations: - inviter_name = invite.inviter.display_name or invite.inviter.first_name() - if current_user == invite.inviter: - inviter_name += " (You)" - st.table_row( - [ - invite.invitee_email, - inviter_name, - invite.created_at.strftime("%b %d, %Y"), - ] - ) + with st.tag("table", className="table table-responsive"): + with st.tag("thead"), st.tag("tr"): + with st.tag("th", scope="col"): + st.html("Email") + with st.tag("th", scope="col"): + st.html("Invited By") + with st.tag("th", scope="col"): + st.html(f"{icons.time} Invited on") + + with st.tag("tbody"): + for invite in pending_invitations: + with st.tag("tr", className="text-break"): + with st.tag("td"): + st.html(html_lib.escape(invite.invitee_email)) + with st.tag("td"): + st.html( + html_lib.escape( + format_user_name( + invite.inviter, current_user=current_user + ) + ) + ) + with st.tag("td"): + st.html(invite.created_at.strftime("%b %d, %Y")) def render_invite_creation_view(org: Org, inviter: AppUser, modal: Modal): @@ -213,6 +314,14 @@ def render_invite_creation_view(org: Org, inviter: AppUser, modal: Modal): if st.button(f"{icons.add_user} Invite", type="primary", unsafe_allow_html=True): try: org.invite_user(invitee_email=email, inviter=inviter, role=OrgRole.MEMBER) - modal.close() except ValidationError as e: st.write(", ".join(e.messages), className="text-danger") + else: + modal.close() + + +def format_user_name(user: AppUser, current_user: AppUser | None = None): + name = user.display_name or user.first_name() + if current_user and user == current_user: + name += " (You)" + return name From b7751932cff10edd51802fe4d11c4e916df21ef8 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 18 Jul 2024 16:36:08 +0530 Subject: [PATCH 024/110] make use of django-safedelete library --- daras_ai_v2/settings.py | 1 + orgs/admin.py | 102 ++++++++++++++++-- orgs/migrations/0001_initial.py | 25 ++++- .../0002_org_logo_alter_org_created_by.py | 25 ----- orgs/migrations/0003_orginvitation_role.py | 18 ---- orgs/migrations/0004_org_deleted_at.py | 18 ---- ...ions_alter_org_managers_org_domain_name.py | 30 ------ orgs/models.py | 54 ++++------ 8 files changed, 134 insertions(+), 139 deletions(-) delete mode 100644 orgs/migrations/0002_org_logo_alter_org_created_by.py delete mode 100644 orgs/migrations/0003_orginvitation_role.py delete mode 100644 orgs/migrations/0004_org_deleted_at.py delete mode 100644 orgs/migrations/0005_alter_org_options_alter_org_managers_org_domain_name.py diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 1d1cd28ce..693462c26 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -54,6 +54,7 @@ # the order matters, since we want to override the admin templates "django.forms", # needed to override admin forms "django.contrib.admin", + "safedelete", "app_users", "files", "url_shortener", diff --git a/orgs/admin.py b/orgs/admin.py index fb445abd3..976e8d38e 100644 --- a/orgs/admin.py +++ b/orgs/admin.py @@ -1,19 +1,105 @@ from django.contrib import admin +from safedelete.admin import SafeDeleteAdmin, SafeDeleteAdminFilter -from .models import Org, OrgQuerySet, OrgInvitation, OrgMembership +from bots.admin_links import change_obj_url +from orgs.models import Org, OrgMembership, OrgInvitation + + +class OrgMembershipInline(admin.TabularInline): + model = OrgMembership + extra = 0 + show_change_link = True + fields = ["user", "role", "created_at", "updated_at"] + readonly_fields = ["created_at", "updated_at"] + ordering = ["-created_at"] + can_delete = False + show_change_link = True + + +class OrgInvitationInline(admin.TabularInline): + model = OrgInvitation + extra = 0 + show_change_link = True + fields = [ + "invitee_email", + "inviter", + "role", + "status", + "auto_accepted", + "created_at", + "updated_at", + ] + readonly_fields = ["auto_accepted", "created_at", "updated_at"] + ordering = ["status", "-created_at"] + can_delete = False + show_change_link = True @admin.register(Org) -class OrgAdmin(admin.ModelAdmin): - def get_queryset(self, request): - return OrgQuerySet(self.model).all() +class OrgAdmin(SafeDeleteAdmin): + list_display = [ + "name", + "domain_name", + "created_at", + "updated_at", + ] + list(SafeDeleteAdmin.list_display) + list_filter = [SafeDeleteAdminFilter] + list(SafeDeleteAdmin.list_filter) + fields = ["name", "domain_name", "created_by", "created_at", "updated_at"] + search_fields = ["name", "domain_name"] + readonly_fields = ["created_at", "updated_at"] + inlines = [OrgMembershipInline, OrgInvitationInline] + ordering = ["-created_at"] @admin.register(OrgMembership) -class OrgMembershipAdmin(admin.ModelAdmin): - pass +class OrgMembershipAdmin(SafeDeleteAdmin): + list_display = [ + "user", + "org", + "role", + "created_at", + "updated_at", + ] + list(SafeDeleteAdmin.list_display) + list_filter = ["org", "role", SafeDeleteAdminFilter] + list( + SafeDeleteAdmin.list_filter + ) + + def get_readonly_fields( + self, request: "HttpRequest", obj: OrgMembership | None = None + ) -> list[str]: + readonly_fields = list(super().get_readonly_fields(request, obj)) + if obj and obj.org and obj.org.deleted: + return readonly_fields + ["deleted_org"] + else: + return readonly_fields + + @admin.display + def deleted_org(self, obj): + org = Org.deleted_objects.get(pk=obj.org_id) + return change_obj_url(org) @admin.register(OrgInvitation) -class OrgInvitationAdmin(admin.ModelAdmin): - pass +class OrgInvitationAdmin(SafeDeleteAdmin): + fields = [ + "org", + "invitee_email", + "inviter", + "role", + "status", + "auto_accepted", + "created_at", + "updated_at", + ] + list_display = [ + "org", + "invitee_email", + "inviter", + "status", + "created_at", + "updated_at", + ] + list(SafeDeleteAdmin.list_display) + list_filter = ["org", "inviter", "role", SafeDeleteAdminFilter] + list( + SafeDeleteAdmin.list_filter + ) + readonly_fields = ["auto_accepted"] diff --git a/orgs/migrations/0001_initial.py b/orgs/migrations/0001_initial.py index 1ed2139eb..b8d7747a2 100644 --- a/orgs/migrations/0001_initial.py +++ b/orgs/migrations/0001_initial.py @@ -1,7 +1,8 @@ -# Generated by Django 4.2.7 on 2024-07-11 17:41 +# Generated by Django 4.2.7 on 2024-07-18 10:18 from django.db import migrations, models import django.db.models.deletion +import orgs.models class Migration(migrations.Migration): @@ -9,7 +10,7 @@ class Migration(migrations.Migration): initial = True dependencies = [ - ('app_users', '0017_alter_appuser_subscription'), + ('app_users', '0019_alter_appusertransaction_reason'), ] operations = [ @@ -17,30 +18,42 @@ class Migration(migrations.Migration): name='Org', fields=[ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('deleted', models.DateTimeField(db_index=True, editable=False, null=True)), + ('deleted_by_cascade', models.BooleanField(default=False, editable=False)), ('org_id', models.CharField(blank=True, max_length=100, null=True, unique=True)), ('name', models.CharField(max_length=100)), + ('logo', models.URLField(blank=True, null=True)), + ('domain_name', models.CharField(blank=True, max_length=30, null=True, validators=[orgs.models.validate_org_domain_name])), ('created_at', models.DateTimeField(auto_now_add=True)), ('updated_at', models.DateTimeField(auto_now=True)), - ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, to='app_users.appuser')), + ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='app_users.appuser')), ], ), migrations.CreateModel( name='OrgInvitation', fields=[ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('deleted', models.DateTimeField(db_index=True, editable=False, null=True)), + ('deleted_by_cascade', models.BooleanField(default=False, editable=False)), ('invitee_email', models.EmailField(max_length=254)), ('status', models.IntegerField(choices=[(1, 'Pending'), (2, 'Accepted'), (3, 'Rejected'), (4, 'Canceled')], default=1)), ('auto_accepted', models.BooleanField(default=False)), + ('role', models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3)), ('created_at', models.DateTimeField(auto_now_add=True)), ('updated_at', models.DateTimeField(auto_now=True)), ('inviter', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='app_users.appuser')), ('org', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='invitations', to='orgs.org')), ], + options={ + 'abstract': False, + }, ), migrations.CreateModel( name='OrgMembership', fields=[ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('deleted', models.DateTimeField(db_index=True, editable=False, null=True)), + ('deleted_by_cascade', models.BooleanField(default=False, editable=False)), ('role', models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3)), ('created_at', models.DateTimeField(auto_now_add=True)), ('updated_at', models.DateTimeField(auto_now=True)), @@ -49,7 +62,7 @@ class Migration(migrations.Migration): ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='org_memberships', to='app_users.appuser')), ], options={ - 'unique_together': {('org', 'user')}, + 'unique_together': {('org', 'user', 'deleted')}, }, ), migrations.AddField( @@ -57,4 +70,8 @@ class Migration(migrations.Migration): name='members', field=models.ManyToManyField(related_name='orgs', through='orgs.OrgMembership', to='app_users.appuser'), ), + migrations.AlterUniqueTogether( + name='org', + unique_together={('domain_name', 'deleted')}, + ), ] diff --git a/orgs/migrations/0002_org_logo_alter_org_created_by.py b/orgs/migrations/0002_org_logo_alter_org_created_by.py deleted file mode 100644 index 78084ae6a..000000000 --- a/orgs/migrations/0002_org_logo_alter_org_created_by.py +++ /dev/null @@ -1,25 +0,0 @@ -# Generated by Django 4.2.7 on 2024-07-11 17:44 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - dependencies = [ - ('app_users', '0017_alter_appuser_subscription'), - ('orgs', '0001_initial'), - ] - - operations = [ - migrations.AddField( - model_name='org', - name='logo', - field=models.URLField(blank=True, null=True), - ), - migrations.AlterField( - model_name='org', - name='created_by', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='app_users.appuser'), - ), - ] diff --git a/orgs/migrations/0003_orginvitation_role.py b/orgs/migrations/0003_orginvitation_role.py deleted file mode 100644 index 03ccc44a4..000000000 --- a/orgs/migrations/0003_orginvitation_role.py +++ /dev/null @@ -1,18 +0,0 @@ -# Generated by Django 4.2.7 on 2024-07-11 18:36 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('orgs', '0002_org_logo_alter_org_created_by'), - ] - - operations = [ - migrations.AddField( - model_name='orginvitation', - name='role', - field=models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3), - ), - ] diff --git a/orgs/migrations/0004_org_deleted_at.py b/orgs/migrations/0004_org_deleted_at.py deleted file mode 100644 index a7127a6f7..000000000 --- a/orgs/migrations/0004_org_deleted_at.py +++ /dev/null @@ -1,18 +0,0 @@ -# Generated by Django 4.2.7 on 2024-07-15 13:57 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('orgs', '0003_orginvitation_role'), - ] - - operations = [ - migrations.AddField( - model_name='org', - name='deleted_at', - field=models.DateTimeField(blank=True, default=None, null=True), - ), - ] diff --git a/orgs/migrations/0005_alter_org_options_alter_org_managers_org_domain_name.py b/orgs/migrations/0005_alter_org_options_alter_org_managers_org_domain_name.py deleted file mode 100644 index 0af9f3056..000000000 --- a/orgs/migrations/0005_alter_org_options_alter_org_managers_org_domain_name.py +++ /dev/null @@ -1,30 +0,0 @@ -# Generated by Django 4.2.7 on 2024-07-16 15:24 - -from django.db import migrations, models -import django.db.models.manager -import orgs.models - - -class Migration(migrations.Migration): - - dependencies = [ - ('orgs', '0004_org_deleted_at'), - ] - - operations = [ - migrations.AlterModelOptions( - name='org', - options={'default_manager_name': 'all_objects'}, - ), - migrations.AlterModelManagers( - name='org', - managers=[ - ('all_objects', django.db.models.manager.Manager()), - ], - ), - migrations.AddField( - model_name='org', - name='domain_name', - field=models.CharField(blank=True, max_length=30, null=True, unique=True, validators=[orgs.models.validate_org_domain_name]), - ), - ] diff --git a/orgs/models.py b/orgs/models.py index f0a1fee67..314ad85d0 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -2,8 +2,9 @@ from django.db import models, transaction from django.core.exceptions import ValidationError -from django.utils import timezone from django.utils.text import slugify +from safedelete.managers import SafeDeleteManager +from safedelete.models import SafeDeleteModel, SOFT_DELETE_CASCADE from app_users.models import AppUser from daras_ai_v2.crypto import get_random_doc_id @@ -28,14 +29,7 @@ class OrgRole(models.IntegerChoices): MEMBER = 3 -class OrgQuerySet(models.QuerySet): - pass - - -class OrgManager(models.Manager): - def get_queryset(self): - return OrgQuerySet(self.model, using=self._db).filter(deleted_at__isnull=True) - +class OrgManager(SafeDeleteManager): def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwargs): org = self.model( org_id=org_id or get_random_doc_id(), created_by=created_by, **kwargs @@ -51,7 +45,9 @@ def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwar return org -class Org(models.Model): +class Org(SafeDeleteModel): + _safedelete_policy = SOFT_DELETE_CASCADE + org_id = models.CharField(max_length=100, null=True, blank=True, unique=True) name = models.CharField(max_length=100) @@ -70,7 +66,6 @@ class Org(models.Model): max_length=30, blank=True, null=True, - unique=True, validators=[ validate_org_domain_name, ], @@ -78,32 +73,21 @@ class Org(models.Model): created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) - deleted_at = models.DateTimeField(null=True, blank=True, default=None) - objects = OrgManager() # only active orgs - all_objects = OrgQuerySet.as_manager() # for internal & admin use + objects = OrgManager() class Meta: - default_manager_name = "all_objects" + unique_together = ("domain_name", "deleted") def __str__(self): - if self.is_deleted(): + if self.deleted: return f"[Deleted] {self.name}" - return self.name + else: + return self.name def get_slug(self): return slugify(self.name) - def is_deleted(self): - return self.deleted_at is not None - - def soft_delete(self): - with transaction.atomic(): - for m in self.memberships.all(): - m.delete() - self.deleted_at = timezone.now() - self.save() - def invite_user( self, *, @@ -117,11 +101,13 @@ def invite_user( """ for member in self.members.all(): if member.email == invitee_email: - raise ValidationError(f"{member} is already a member of this org") + raise ValidationError(f"{member} is already a member of this team") for invitation in self.invitations.filter(status=OrgInvitation.Status.PENDING): if invitation.invitee_email == invitee_email: - raise ValidationError(f"{invitee_email} was already invited") + raise ValidationError( + f"{invitee_email} was already invited to this team" + ) invitation = OrgInvitation( org=self, @@ -139,7 +125,7 @@ def invite_user( pass -class OrgMembership(models.Model): +class OrgMembership(SafeDeleteModel): org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="memberships") user = models.ForeignKey( "app_users.AppUser", on_delete=models.CASCADE, related_name="org_memberships" @@ -159,7 +145,7 @@ class OrgMembership(models.Model): updated_at = models.DateTimeField(auto_now=True) class Meta: - unique_together = ("org", "user") + unique_together = ("org", "user", "deleted") def __str__(self): return f"{self.get_role_display()} - {self.user} ({self.org})" @@ -193,7 +179,7 @@ def can_transfer_ownership(self): return self.role == OrgRole.OWNER -class OrgInvitation(models.Model): +class OrgInvitation(SafeDeleteModel): class Status(models.IntegerChoices): PENDING = 1 ACCEPTED = 2 @@ -208,10 +194,6 @@ class Status(models.IntegerChoices): auto_accepted = models.BooleanField(default=False) role = models.IntegerField(choices=OrgRole.choices, default=OrgRole.MEMBER) - # TODO: don't spam invitees! - # invitation_email_count = models.IntegerField(default=0) - # last_invitation_sent_at = models.DateTimeField(null=True, blank=True) - created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) From eea33ef1496515a3efbeb117f1af47da22201a78 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 18 Jul 2024 16:36:37 +0530 Subject: [PATCH 025/110] Add django-safedelete to poetry --- poetry.lock | 39 +++++++++++++++++++++++++++++++++++++-- pyproject.toml | 1 + 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 73fc4dd20..4483d43e0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "absl-py" @@ -1181,6 +1181,21 @@ phonenumberslite = {version = ">=7.0.2", optional = true, markers = "extra == \" phonenumbers = ["phonenumbers (>=7.0.2)"] phonenumberslite = ["phonenumberslite (>=7.0.2)"] +[[package]] +name = "django-safedelete" +version = "1.4.0" +description = "Mask your objects instead of deleting them from your database." +optional = false +python-versions = "*" +files = [ + {file = "django_safedelete-1.4.0-py3-none-any.whl", hash = "sha256:f722845088c00398711fad8961f044cf18badfecaf541bcc616102f46339adda"}, + {file = "django_safedelete-1.4.0.tar.gz", hash = "sha256:ce63f2dd101fec303837ef624592628e022691c3ade2a0893c9fc4c7796e8288"}, +] + +[package.dependencies] +Django = "*" +packaging = "*" + [[package]] name = "docker" version = "7.0.0" @@ -2959,6 +2974,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -4486,6 +4511,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4493,8 +4519,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4511,6 +4544,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4518,6 +4552,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -6466,4 +6501,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "3955eb5901ce23cc6e25cf4d45c9a742d830ea2a63a60000e1cfc1d93c6299a6" +content-hash = "3db7b5843c1e50294e913e5a137bc7984c5c28cbbbf82e5124aa636a270b6d2b" diff --git a/pyproject.toml b/pyproject.toml index d44685d2f..b6a97055a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ azure-cognitiveservices-speech = "^1.37.0" twilio = "^9.2.3" sentry-sdk = {version = "1.45.0", extras = ["loguru"]} gooey-gui = "^0.1.0" +django-safedelete = "^1.4.0" [tool.poetry.group.dev.dependencies] watchdog = "^2.1.9" From 9777d7ab8070086528616d72dcb4456a1273ec3d Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 18 Jul 2024 16:37:37 +0530 Subject: [PATCH 026/110] Move transfer ownership functionality to leave modal --- orgs/views.py | 108 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 73 insertions(+), 35 deletions(-) diff --git a/orgs/views.py b/orgs/views.py index 5b0d86475..3dcacb1df 100644 --- a/orgs/views.py +++ b/orgs/views.py @@ -1,6 +1,5 @@ import html as html_lib from django.core.exceptions import ValidationError -from django.db import transaction import gooey_ui as st from app_users.models import AppUser @@ -80,20 +79,37 @@ def render_org_by_membership(membership: OrgMembership): render_pending_invitations_list(org=org, current_user=current_user) with st.div(className="mt-4"): - if membership.role != OrgRole.OWNER: - # Owners can't leave! They can only delete - org_leave_modal = Modal("Leave Org", key="leave-org-modal") - if org_leave_modal.is_open(): - with org_leave_modal.container(): - render_org_leave_view_by_membership( - membership, modal=org_leave_modal - ) - - if st.button( - "Leave Org", + org_leave_modal = Modal("Leave Org", key="leave-org-modal") + if org_leave_modal.is_open(): + with org_leave_modal.container(): + render_org_leave_view_by_membership(membership, modal=org_leave_modal) + + with st.div(className="text-end"): + leave_org = st.button( + "Leave", className="btn btn-theme bg-danger border-danger text-white", - ): - org_leave_modal.open() + ) + if leave_org: + org_leave_modal.open() + + +# if membership.can_transfer_ownership(): +# transfer_ownership_modal = Modal("Transfer Ownership", key="transfer-ownership") +# if transfer_ownership_modal.is_open(): +# with transfer_ownership_modal.container(): +# render_transfer_ownership_view_by_membership( +# membership, modal=transfer_ownership_modal +# ) +# +# with st.div(className="d-flex justify-content-between align-items-center"): +# st.write("Transfer Ownership") +# if st.button( +# f"{icons.transfer} Transfer", +# className="btn btn-theme py-2 bg-danger border-danger text-light", +# unsafe_allow_html=True, +# ): +# transfer_ownership_modal.open() +# def render_org_creation_view(user: AppUser): @@ -162,21 +178,6 @@ def render_danger_zone_by_membership(membership: OrgMembership): ): org_deletion_modal.open() - if membership.can_transfer_ownership(): - with st.div(className="d-flex justify-content-between align-items-center"): - st.write("Transfer Ownership") - if st.button( - f"{icons.transfer} Transfer", - className="btn btn-theme py-2 bg-danger border-danger text-light", - unsafe_allow_html=True, - ): - m.role = OrgRole.OWNER - membership.role = OrgRole.ADMIN - with transaction.atomic(): - m.save() - membership.save() - st.experimental_rerun() - def render_org_deletion_view_by_membership(membership: OrgMembership, *, modal: Modal): st.write( @@ -194,17 +195,54 @@ def render_org_deletion_view_by_membership(membership: OrgMembership, *, modal: ): membership.org.soft_delete() modal.close() + st.experimental_rerun() + +def render_org_leave_view_by_membership( + current_membership: OrgMembership, *, modal: Modal +): + org = current_membership.org -def render_org_leave_view_by_membership(membership: OrgMembership, *, modal: Modal): st.write("Are you sure you want to leave this organization?") - if st.button("Cancel", type="secondary", className="border-danger text-danger"): - modal.close() + new_owner = None + if current_membership.role == OrgRole.OWNER and org.memberships.count() == 1: + st.caption( + "You are the only member. You will lose access to this team if you leave." + ) + elif ( + current_membership.role == OrgRole.OWNER + and org.memberships.filter(role=OrgRole.OWNER).count() == 1 + ): + members_by_uid = { + m.user.uid: m + for m in org.memberships.all().select_related("user") + if m != current_membership + } + new_owner_uid = st.selectbox( + "New Owner", + options=list(members_by_uid), + format_func=lambda uid: format_user_name(members_by_uid[uid].user), + ) + new_owner = members_by_uid[new_owner_uid] + st.caption( + "You are the only owner of this organization. Please choose another member to promote to owner." + ) + + with st.div(className="d-flex"): + if st.button( + "Cancel", type="secondary", className="border-danger text-danger w-50" + ): + modal.close() - if st.button("Leave", className="btn btn-theme bg-danger border-danger text-light"): - membership.org.members.remove(membership.user) - modal.close() + if st.button( + "Leave", className="btn btn-theme bg-danger border-danger text-light w-50" + ): + if new_owner: + new_owner.role = OrgRole.OWNER + new_owner.save() + org.members.remove(current_membership.user) + modal.close() def render_members_list(org: Org, current_membership: OrgMembership): From 195e089aa25cb8d6a8651c80c009b177949baf9b Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 18 Jul 2024 17:11:28 +0530 Subject: [PATCH 027/110] soft_delete -> delete --- orgs/views.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/orgs/views.py b/orgs/views.py index 3dcacb1df..8e7504c93 100644 --- a/orgs/views.py +++ b/orgs/views.py @@ -193,9 +193,8 @@ def render_org_deletion_view_by_membership(membership: OrgMembership, *, modal: if st.button( "Delete", className="btn btn-theme bg-danger border-danger text-light w-50" ): - membership.org.soft_delete() + membership.org.delete() modal.close() - st.experimental_rerun() def render_org_leave_view_by_membership( From ac1ac0386961891552502b2a1ce199a39615aa29 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 18 Jul 2024 18:29:29 +0530 Subject: [PATCH 028/110] don't show invitation role in inline table --- orgs/admin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/orgs/admin.py b/orgs/admin.py index 976e8d38e..969866f41 100644 --- a/orgs/admin.py +++ b/orgs/admin.py @@ -23,7 +23,6 @@ class OrgInvitationInline(admin.TabularInline): fields = [ "invitee_email", "inviter", - "role", "status", "auto_accepted", "created_at", From fac8b8ac2cd9edf1e79c9ed6a5d105b18b054c01 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 18 Jul 2024 18:30:05 +0530 Subject: [PATCH 029/110] python magic to reuse same form for create/edit org --- orgs/views.py | 75 +++++++++++++++++++++++++-------------------------- 1 file changed, 37 insertions(+), 38 deletions(-) diff --git a/orgs/views.py b/orgs/views.py index 8e7504c93..e2d4d0cef 100644 --- a/orgs/views.py +++ b/orgs/views.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import html as html_lib + from django.core.exceptions import ValidationError import gooey_ui as st @@ -93,37 +96,15 @@ def render_org_by_membership(membership: OrgMembership): org_leave_modal.open() -# if membership.can_transfer_ownership(): -# transfer_ownership_modal = Modal("Transfer Ownership", key="transfer-ownership") -# if transfer_ownership_modal.is_open(): -# with transfer_ownership_modal.container(): -# render_transfer_ownership_view_by_membership( -# membership, modal=transfer_ownership_modal -# ) -# -# with st.div(className="d-flex justify-content-between align-items-center"): -# st.write("Transfer Ownership") -# if st.button( -# f"{icons.transfer} Transfer", -# className="btn btn-theme py-2 bg-danger border-danger text-light", -# unsafe_allow_html=True, -# ): -# transfer_ownership_modal.open() -# - - def render_org_creation_view(user: AppUser): st.write(f"# {icons.company} Create an Org", unsafe_allow_html=True) - name = st.text_input("Team Name") - logo = st.file_uploader("Logo", accept=["image/*"]) - domain_name = st.text_input("Domain Name (Optional)", placeholder="e.g. gooey.ai") - if domain_name: - st.caption(f"Add any user with `@{domain_name}` email to this organization.") + org_fields = render_org_create_or_edit_form() if st.button("Create"): try: Org.objects.create_org( - created_by=user, name=name, logo=logo, domain_name=domain_name + created_by=user, + **org_fields, ) except ValidationError as e: st.write(", ".join(e.messages), className="text-danger") @@ -133,16 +114,7 @@ def render_org_creation_view(user: AppUser): def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: Modal): org = membership.org - - org.name = st.text_input("Team Name", value=org.name) - org.logo = st.file_uploader("Logo", accept=["image/*"], value=org.logo) - org.domain_name = st.text_input( - "Domain Name (Optional)", placeholder="e.g. gooey.ai", value=org.domain_name - ) - if org.domain_name: - st.caption( - f"Add any user with `@{org.domain_name}` email to this organization." - ) + render_org_create_or_edit_form(org=org) if st.button("Save", className="w-100", type="primary"): try: @@ -218,15 +190,16 @@ def render_org_leave_view_by_membership( for m in org.memberships.all().select_related("user") if m != current_membership } + + st.caption( + "You are the only owner of this organization. Please choose another member to promote to owner." + ) new_owner_uid = st.selectbox( "New Owner", options=list(members_by_uid), format_func=lambda uid: format_user_name(members_by_uid[uid].user), ) new_owner = members_by_uid[new_owner_uid] - st.caption( - "You are the only owner of this organization. Please choose another member to promote to owner." - ) with st.div(className="d-flex"): if st.button( @@ -357,8 +330,34 @@ def render_invite_creation_view(org: Org, inviter: AppUser, modal: Modal): modal.close() +def render_org_create_or_edit_form(org: Org | None = None) -> AttrDict | Org: + org_proxy = org or AttrDict() + + org_proxy.name = st.text_input("Team Name", value=org and org.name or "") + org_proxy.logo = st.file_uploader( + "Logo", accept=["image/*"], value=org and org.logo or "" + ) + org_proxy.domain_name = st.text_input( + "Domain Name (Optional)", + placeholder="e.g. gooey.ai", + value=org and org.domain_name or "", + ) + if org_proxy.domain_name: + st.caption( + f"Invite any user with `@{org_proxy.domain_name}` email to this organization." + ) + + return org_proxy + + def format_user_name(user: AppUser, current_user: AppUser | None = None): name = user.display_name or user.first_name() if current_user and user == current_user: name += " (You)" return name + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self From 7447e2b526b2a09f18b0dfcc2ef0b7ef8a10fb45 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:07:41 +0530 Subject: [PATCH 030/110] Add invite_id and other Org model+procedure changes --- orgs/migrations/0001_initial.py | 10 +- ...0002_alter_org_unique_together_and_more.py | 35 ++++ orgs/models.py | 171 +++++++++++++++--- 3 files changed, 183 insertions(+), 33 deletions(-) create mode 100644 orgs/migrations/0002_alter_org_unique_together_and_more.py diff --git a/orgs/migrations/0001_initial.py b/orgs/migrations/0001_initial.py index b8d7747a2..7de84737d 100644 --- a/orgs/migrations/0001_initial.py +++ b/orgs/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.7 on 2024-07-18 10:18 +# Generated by Django 4.2.7 on 2024-07-18 15:41 from django.db import migrations, models import django.db.models.deletion @@ -35,14 +35,18 @@ class Migration(migrations.Migration): ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('deleted', models.DateTimeField(db_index=True, editable=False, null=True)), ('deleted_by_cascade', models.BooleanField(default=False, editable=False)), + ('invite_id', models.CharField(max_length=100, unique=True)), ('invitee_email', models.EmailField(max_length=254)), - ('status', models.IntegerField(choices=[(1, 'Pending'), (2, 'Accepted'), (3, 'Rejected'), (4, 'Canceled')], default=1)), + ('status', models.IntegerField(choices=[(1, 'Pending'), (2, 'Accepted'), (3, 'Rejected'), (4, 'Canceled'), (5, 'Expired')], default=1)), ('auto_accepted', models.BooleanField(default=False)), ('role', models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3)), + ('last_email_sent_at', models.DateTimeField(blank=True, default=False, null=True)), + ('status_changed_at', models.DateTimeField(blank=True, default=False, null=True)), ('created_at', models.DateTimeField(auto_now_add=True)), ('updated_at', models.DateTimeField(auto_now=True)), - ('inviter', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='app_users.appuser')), + ('inviter', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='sent_invitations', to='app_users.appuser')), ('org', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='invitations', to='orgs.org')), + ('status_changed_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='app_users.appuser')), ], options={ 'abstract': False, diff --git a/orgs/migrations/0002_alter_org_unique_together_and_more.py b/orgs/migrations/0002_alter_org_unique_together_and_more.py new file mode 100644 index 000000000..2c5384d67 --- /dev/null +++ b/orgs/migrations/0002_alter_org_unique_together_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 4.2.7 on 2024-07-22 14:45 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('orgs', '0001_initial'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='org', + unique_together=set(), + ), + migrations.AlterField( + model_name='orginvitation', + name='last_email_sent_at', + field=models.DateTimeField(blank=True, default=None, null=True), + ), + migrations.AlterField( + model_name='orginvitation', + name='status_changed_at', + field=models.DateTimeField(blank=True, default=None, null=True), + ), + migrations.AddConstraint( + model_name='org', + constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('domain_name',), name='unique_domain_name_when_not_deleted'), + ), + migrations.RemoveField( + model_name='org', + name='members', + ), + ] diff --git a/orgs/models.py b/orgs/models.py index 314ad85d0..df98064ae 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -1,13 +1,20 @@ import re +from datetime import timedelta from django.db import models, transaction from django.core.exceptions import ValidationError +from django.db.backends.base.schema import logger +from django.db.models.query_utils import Q +from django.utils import timezone from django.utils.text import slugify from safedelete.managers import SafeDeleteManager from safedelete.models import SafeDeleteModel, SOFT_DELETE_CASCADE from app_users.models import AppUser +from daras_ai_v2 import settings +from daras_ai_v2.fastapi_tricks import get_route_url from daras_ai_v2.crypto import get_random_doc_id +from orgs.tasks import send_auto_accepted_email, send_invitation_email ORG_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$") @@ -36,11 +43,9 @@ def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwar ) org.full_clean() org.save() - org.members.add( + org.add_member( created_by, - through_defaults={ - "role": OrgRole.OWNER, - }, + role=OrgRole.OWNER, ) return org @@ -51,11 +56,6 @@ class Org(SafeDeleteModel): org_id = models.CharField(max_length=100, null=True, blank=True, unique=True) name = models.CharField(max_length=100) - members = models.ManyToManyField( - "app_users.AppUser", - through="OrgMembership", - related_name="orgs", - ) created_by = models.ForeignKey( "app_users.appuser", on_delete=models.CASCADE, @@ -77,7 +77,14 @@ class Org(SafeDeleteModel): objects = OrgManager() class Meta: - unique_together = ("domain_name", "deleted") + constraints = [ + models.UniqueConstraint( + fields=["domain_name"], + condition=Q(deleted__isnull=True), + name="unique_domain_name_when_not_deleted", + violation_error_message=f"This domain name is already in use by another team. Contact {settings.SUPPORT_EMAIL} if you think this is a mistake.", + ) + ] def __str__(self): if self.deleted: @@ -88,20 +95,30 @@ def __str__(self): def get_slug(self): return slugify(self.name) + def add_member( + self, user: AppUser, role: OrgRole, invitation: "OrgInvitation | None" = None + ): + OrgMembership( + org=self, + user=user, + role=role, + invitation=invitation, + ).save() + def invite_user( self, *, invitee_email: str, inviter: AppUser, role: OrgRole, - auto_accept: bool = True, - ): + auto_accept: bool = False, + ) -> "OrgInvitation": """ auto_accept: If True, the user will be automatically added if they have an account """ - for member in self.members.all(): - if member.email == invitee_email: - raise ValidationError(f"{member} is already a member of this team") + for member in self.memberships.all().select_related("user"): + if member.user.email == invitee_email: + raise ValidationError(f"{member.user} is already a member of this team") for invitation in self.invitations.filter(status=OrgInvitation.Status.PENDING): if invitation.invitee_email == invitee_email: @@ -110,6 +127,7 @@ def invite_user( ) invitation = OrgInvitation( + invite_id=get_random_doc_id(), org=self, invitee_email=invitee_email, inviter=inviter, @@ -120,10 +138,15 @@ def invite_user( if auto_accept: try: - invitation.accept(auto_accepted=True) + invitation.auto_accept() except AppUser.DoesNotExist: pass + if not invitation.auto_accepted: + invitation.send_email() + + return invitation + class OrgMembership(SafeDeleteModel): org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="memberships") @@ -144,6 +167,8 @@ class OrgMembership(SafeDeleteModel): created_at = models.DateTimeField(auto_now_add=True) # same as joining date updated_at = models.DateTimeField(auto_now=True) + objects = SafeDeleteManager() + class Meta: unique_together = ("org", "user", "deleted") @@ -156,9 +181,6 @@ def can_edit_org_metadata(self): def can_delete_org(self): return self.role == OrgRole.OWNER - def can_invite(self): - return self.role in (OrgRole.OWNER, OrgRole.ADMIN) - def has_higher_role_than(self, other: "OrgMembership"): # creator > owner > admin > member match other.role: @@ -178,6 +200,9 @@ def can_kick(self, other: "OrgMembership"): def can_transfer_ownership(self): return self.role == OrgRole.OWNER + def can_invite(self): + return self.role in (OrgRole.OWNER, OrgRole.ADMIN) + class OrgInvitation(SafeDeleteModel): class Status(models.IntegerChoices): @@ -185,43 +210,129 @@ class Status(models.IntegerChoices): ACCEPTED = 2 REJECTED = 3 CANCELED = 4 + EXPIRED = 5 - org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="invitations") + invite_id = models.CharField(max_length=100, unique=True) invitee_email = models.EmailField() - inviter = models.ForeignKey("app_users.AppUser", on_delete=models.CASCADE) + + org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="invitations") + inviter = models.ForeignKey( + "app_users.AppUser", on_delete=models.CASCADE, related_name="sent_invitations" + ) status = models.IntegerField(choices=Status.choices, default=Status.PENDING) auto_accepted = models.BooleanField(default=False) role = models.IntegerField(choices=OrgRole.choices, default=OrgRole.MEMBER) + last_email_sent_at = models.DateTimeField(null=True, blank=True, default=None) + status_changed_at = models.DateTimeField(null=True, blank=True, default=None) + status_changed_by = models.ForeignKey( + "app_users.AppUser", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="received_invitations", + ) + created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) def __str__(self): return f"{self.invitee_email} - {self.org} ({self.get_status_display()})" - def accept(self, *, auto_accepted: bool = False): + def has_expired(self): + return self.status == self.Status.EXPIRED or ( + timezone.now() - (self.last_email_sent_at or self.created_at) + > timedelta(days=settings.ORG_INVITATION_EXPIRY_DAYS) + ) + + def auto_accept(self): + """ + Automatically accept the invitation if user has an account. + + If user is already part of the team, then the invitation will be canceled. + + Raises: ValidationError + """ assert self.status == self.Status.PENDING invitee = AppUser.objects.get(email=self.invitee_email) + self.accept(invitee, auto_accepted=True) + + if self.auto_accepted: + logger.info(f"User {invitee} auto-accepted invitation to org {self.org}") + send_auto_accepted_email.delay(self.pk) + + def get_url(self): + from routers.account import invitation_route + + return get_route_url( + invitation_route, + params={"invite_id": self.invite_id, "org_slug": self.org.get_slug()}, + ) + + def send_email(self): + # pre-emptively set last_email_sent_at to avoid sending multiple emails concurrently + if not self.can_resend_email(): + raise ValidationError("This user has already been invited recently.") + + self.last_email_sent_at = timezone.now() + self.save(update_fields=["last_email_sent_at"]) + + send_invitation_email.delay(invitation_pk=self.pk) + + def accept(self, user: AppUser, *, auto_accepted: bool = False): + """ + Raises: ValidationError + """ + # can't accept an invitation that is already accepted / rejected / canceled + if self.status != self.Status.PENDING: + raise ValidationError( + f"This invitation has been {self.get_status_display().lower()}." + ) + + if self.has_expired(): + self.status = self.Status.EXPIRED + self.save() + raise ValidationError( + "This invitation has expired. Please ask your team admin to send a new one." + ) + + if self.org.memberships.filter(user_id=user.pk).exists(): + raise ValidationError(f"User is already a member of this team.") self.status = self.Status.ACCEPTED self.auto_accepted = auto_accepted + self.status_changed_at = timezone.now() + self.status_changed_by = user + + self.full_clean() with transaction.atomic(): - self.org.members.add( - invitee, - through_defaults={ - "role": self.role, - "invitation": self, - }, + user.org_memberships.all().delete() # delete current memberships + self.org.add_member( + user, + role=self.role, + invitation=self, ) self.save() - def reject(self): + def reject(self, user: AppUser): self.status = self.Status.REJECTED + self.status_changed_at = timezone.now() + self.status_changed_by = user self.save() - def cancel(self): + def cancel(self, user: AppUser): self.status = self.Status.CANCELED + self.status_changed_at = timezone.now() + self.status_changed_by = user self.save() + + def can_resend_email(self): + if not self.last_email_sent_at: + return True + + return timezone.now() - self.last_email_sent_at > timedelta( + seconds=settings.ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL + ) From 547506ff5c717a6bc3e37ffc86e163b82f9205e5 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:08:16 +0530 Subject: [PATCH 031/110] Add invitation page --- routers/account.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/routers/account.py b/routers/account.py index 36dde5d78..e79cc75dc 100644 --- a/routers/account.py +++ b/routers/account.py @@ -8,6 +8,7 @@ from furl import furl from loguru import logger from requests.models import HTTPError +from starlette.responses import Response from bots.models import PublishedRun, PublishedRunVisibility, Workflow from daras_ai_v2 import icons, paypal @@ -17,9 +18,10 @@ from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_content import raw_build_meta_tags from daras_ai_v2.profiles import edit_user_profile_page +from orgs.models import OrgInvitation from payments.webhooks import PaypalWebhookHandler from routers.root import page_wrapper, get_og_url_path -from orgs.views import orgs_page +from orgs.views import invitation_page, orgs_page from routers.custom_api_router import CustomAPIRouter @@ -158,6 +160,33 @@ def orgs_route(request: Request): ) +@app.post("/invitation/{org_slug}/{invite_id}/") +@st.route +def invitation_route(request: Request, org_slug: str, invite_id: str): + from routers.root import login + + if not request.user or request.user.is_anonymous: + next_url = request.url.path + redirect_url = str(furl(get_route_path(login), query_params={"next": next_url})) + raise RedirectException(redirect_url) + + try: + invitation = OrgInvitation.objects.get(invite_id=invite_id) + except OrgInvitation.DoesNotExist: + return Response(status_code=404) + + with page_wrapper(request): + invitation_page(user=request.user, invitation=invitation) + return dict( + meta=raw_build_meta_tags( + url=str(request.url), + title=f"Join {invitation.org.name} • Gooey.AI", + description=f"Invitation to join {invitation.org.name}", + robots="noindex,nofollow", + ) + ) + + class TabData(typing.NamedTuple): title: str route: typing.Callable From 3beddaa57bbe6d176643e096044e4c9532c63072 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:09:56 +0530 Subject: [PATCH 032/110] Add signals to auto-delete orgs and auto-add members --- daras_ai_v2/settings.py | 5 +++++ orgs/signals.py | 49 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 orgs/signals.py diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 693462c26..3cdd88dc8 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -399,3 +399,8 @@ TWILIO_ACCOUNT_SID = config("TWILIO_ACCOUNT_SID", "") TWILIO_API_KEY_SID = config("TWILIO_API_KEY_SID", "") TWILIO_API_KEY_SECRET = config("TWILIO_API_KEY_SECRET", "") + +ORG_INVITATION_EXPIRY_DAYS = config("ORG_INVITATIONS_EXPIRY_IN_DAYS", 10, cast=int) +ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL = config( + "ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL", 60 * 60 * 24, cast=int # 24 hours +) diff --git a/orgs/signals.py b/orgs/signals.py new file mode 100644 index 000000000..bb23b7e06 --- /dev/null +++ b/orgs/signals.py @@ -0,0 +1,49 @@ +from django.db.models.signals import post_save +from django.dispatch import receiver +from loguru import logger +from safedelete.signals import post_softdelete + +from app_users.models import AppUser +from orgs.models import Org, OrgMembership, OrgRole +from orgs.tasks import send_auto_accepted_email + + +@receiver(post_save, sender=AppUser) +def add_user_existing_org(instance: AppUser, **kwargs): + """ + if the domain name matches + """ + if not instance.email: + return + + email_domain = instance.email.split("@")[1] + org = Org.objects.filter(domain_name=email_domain).first() + if not org: + return + + if instance.received_invitations.exists(): + # user has some existing invitations + return + + org_owner = org.memberships.filter(role=OrgRole.OWNER).first() + if not org_owner: + logger.warning( + f"Org {org} has no owner. Skipping auto-accept for user {instance}" + ) + return + + invitation = org.invite_user( + invitee_email=instance.email, + inviter=org_owner.user, + role=OrgRole.MEMBER, + auto_accept=not instance.org_memberships.exists(), # auto-accept only if user has no existing memberships + ) + + +@receiver(post_softdelete, sender=OrgMembership) +def delete_org_if_no_members_left(instance: OrgMembership, **kwargs): + if instance.org.memberships.exists(): + return + + logger.info(f"Deleting org {instance.org} because it has no members left") + instance.org.delete() From de6ed2f156025f0e5c87e1d586a1162f61cf43be Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:10:44 +0530 Subject: [PATCH 033/110] add invitation page --- daras_ai_v2/icons.py | 3 +- orgs/apps.py | 5 + orgs/views.py | 242 +++++++++++++++++++++++++++++++++---------- 3 files changed, 197 insertions(+), 53 deletions(-) diff --git a/daras_ai_v2/icons.py b/daras_ai_v2/icons.py index c5f84193b..30dcc1e01 100644 --- a/daras_ai_v2/icons.py +++ b/daras_ai_v2/icons.py @@ -10,13 +10,14 @@ camera = '' cancel = '' edit = '' -delete = '' +delete = '' link = '' company = '' copy = '' preview = '' add = '' time = '' +email = '' code = '' chat = '' diff --git a/orgs/apps.py b/orgs/apps.py index 70c7fa169..a75310666 100644 --- a/orgs/apps.py +++ b/orgs/apps.py @@ -4,3 +4,8 @@ class OrgsConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" name = "orgs" + + def ready(self): + from . import signals + + assert signals diff --git a/orgs/views.py b/orgs/views.py index e2d4d0cef..78a38bfc4 100644 --- a/orgs/views.py +++ b/orgs/views.py @@ -9,9 +9,54 @@ from gooey_ui.components.modal import Modal from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole from daras_ai_v2 import icons +from daras_ai_v2.fastapi_tricks import get_route_path -DEFAULT_ORG_LOGO = "https://seccdn.libravatar.org/avatar/40f8d096a3777232204cb3f796c577b7?s=80&forcedefault=y&default=monsterid" +DEFAULT_ORG_LOGO = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/74a37c52-8260-11ee-a297-02420a0001ee/gooey.ai%20-%20A%20pop%20art%20illustration%20of%20robots%20taki...y%20Liechtenstein%20mint%20colour%20is%20main%20city%20Seattle.png" + + +def invitation_page(user: AppUser, invitation: OrgInvitation): + from routers.account import orgs_route + + orgs_page_path = get_route_path(orgs_route) + + with st.div(className="text-center my-5"): + st.write( + f"# Invitation to join {invitation.org.name}", className="d-block mb-5" + ) + + if invitation.org.memberships.filter(user=user).exists(): + # redirect to org page + raise st.RedirectException(orgs_page_path) + + if invitation.status != OrgInvitation.Status.PENDING: + st.write(f"This invitation has been {invitation.get_status_display()}.") + return + + st.write( + f"**{format_user_name(invitation.inviter)}** has invited you to join **{invitation.org.name}**." + ) + + if other_m := user.org_memberships.first(): + st.caption( + f"You are currently a member of [{other_m.org.name}]({orgs_page_path}). You will be removed from that team if you accept this invitation." + ) + accept_label = "Leave and Accept" + else: + accept_label = "Accept" + + with st.div( + className="d-flex justify-content-center align-items-center mx-auto", + style={"max-width": "600px"}, + ): + accept_button = st.button(accept_label, type="primary", className="w-50") + reject_button = st.button("Decline", type="secondary", className="w-50") + + if accept_button: + invitation.accept(user=user) + raise st.RedirectException(orgs_page_path) + if reject_button: + invitation.reject(user=user) def orgs_page(user: AppUser): @@ -53,13 +98,15 @@ def render_org_by_membership(membership: OrgMembership): with st.div(className="d-flex align-items-center"): st.image( org.logo or DEFAULT_ORG_LOGO, - className="my-0 me-2", + className="my-0 me-4 rounded", style={"width": "128px", "height": "128px", "object-fit": "contain"}, ) with st.div(className="d-flex flex-column justify-content-center"): st.write(f"# {org.name}") if org.domain_name: - st.write(f"Domain: `@{org.domain_name}`", className="text-muted") + st.write( + f"Org Domain: `@{org.domain_name}`", className="text-muted" + ) with st.div(className="mt-4"): with st.div(className="d-flex justify-content-between align-items-center"): @@ -67,7 +114,7 @@ def render_org_by_membership(membership: OrgMembership): if membership.can_invite(): invite_modal = Modal("Invite Member", key="invite-member-modal") - if st.button(f"{icons.add_user} Invite Member", type="primary"): + if st.button(f"{icons.add_user} Invite"): invite_modal.open() if invite_modal.is_open(): @@ -76,10 +123,10 @@ def render_org_by_membership(membership: OrgMembership): org=org, inviter=current_user, modal=invite_modal ) - render_members_list(org=org, current_membership=membership) + render_members_list(org=org, current_member=membership) with st.div(className="mt-4"): - render_pending_invitations_list(org=org, current_user=current_user) + render_pending_invitations_list(org=org, current_member=membership) with st.div(className="mt-4"): org_leave_modal = Modal("Leave Org", key="leave-org-modal") @@ -169,26 +216,24 @@ def render_org_deletion_view_by_membership(membership: OrgMembership, *, modal: modal.close() -def render_org_leave_view_by_membership( - current_membership: OrgMembership, *, modal: Modal -): - org = current_membership.org +def render_org_leave_view_by_membership(current_member: OrgMembership, *, modal: Modal): + org = current_member.org st.write("Are you sure you want to leave this organization?") new_owner = None - if current_membership.role == OrgRole.OWNER and org.memberships.count() == 1: + if current_member.role == OrgRole.OWNER and org.memberships.count() == 1: st.caption( "You are the only member. You will lose access to this team if you leave." ) elif ( - current_membership.role == OrgRole.OWNER + current_member.role == OrgRole.OWNER and org.memberships.filter(role=OrgRole.OWNER).count() == 1 ): members_by_uid = { m.user.uid: m for m in org.memberships.all().select_related("user") - if m != current_membership + if m != current_member } st.caption( @@ -213,11 +258,11 @@ def render_org_leave_view_by_membership( if new_owner: new_owner.role = OrgRole.OWNER new_owner.save() - org.members.remove(current_membership.user) + current_member.delete() modal.close() -def render_members_list(org: Org, current_membership: OrgMembership): +def render_members_list(org: Org, current_member: OrgMembership): with st.tag("table", className="table table-responsive"): with st.tag("thead"), st.tag("tr"): with st.tag("th", scope="col"): @@ -234,7 +279,7 @@ def render_members_list(org: Org, current_membership: OrgMembership): with st.tag("tr"): with st.tag("td"): name = format_user_name( - m.user, current_user=current_membership.user + m.user, current_user=current_member.user ) if m.user.handle_id: with st.link(to=m.user.handle.get_app_url()): @@ -246,47 +291,92 @@ def render_members_list(org: Org, current_membership: OrgMembership): with st.tag("td"): st.html(m.created_at.strftime("%b %d, %Y")) with st.tag("td", className="text-end"): - render_membership_actions( - m, current_membership=current_membership - ) - - -def render_membership_actions(m: OrgMembership, current_membership: OrgMembership): - if current_membership.can_kick(m): - member_deletion_modal = Modal( - "Remove Member", key=f"remove-member-{m.pk}-modal" - ) - if member_deletion_modal.is_open(): - with member_deletion_modal.container(): - render_member_deletion_view(m, modal=member_deletion_modal) + render_membership_actions(m, current_member=current_member) + + +def render_membership_actions(m: OrgMembership, current_member: OrgMembership): + if current_member.can_change_role(m): + if m.role == OrgRole.MEMBER: + modal, confirmed = button_with_confirmation_modal( + f"{icons.admin} Make Admin", + key=f"promote-member-{m.pk}", + unsafe_allow_html=True, + confirmation_text=f"Are you sure you want to promote **{format_user_name(m.user)}** to an admin?", + modal_title="Make Admin", + modal_key=f"promote-member-{m.pk}-modal", + ) + if confirmed: + m.role = OrgRole.ADMIN + m.save() + modal.close() + elif m.role == OrgRole.ADMIN: + modal, confirmed = button_with_confirmation_modal( + f"{icons.remove_user} Revoke Admin", + key=f"demote-member-{m.pk}", + unsafe_allow_html=True, + confirmation_text=f"Are you sure you want to revoke admin privileges from **{format_user_name(m.user)}**?", + modal_title="Revoke Admin", + modal_key=f"demote-member-{m.pk}-modal", + ) + if confirmed: + m.role = OrgRole.MEMBER + m.save() + modal.close() - if st.button( + if current_member.can_kick(m): + modal, confirmed = button_with_confirmation_modal( f"{icons.remove_user} Remove", - className="btn btn-theme btn-sm my-0 py-0 bg-danger border-danger text-light", + key=f"remove-member-{m.pk}", unsafe_allow_html=True, - ): - member_deletion_modal.open() + confirmation_text=f"Are you sure you want to remove **{format_user_name(m.user)}** from **{m.org.name}**?", + modal_title="Remove Member", + modal_key=f"remove-member-{m.pk}-modal", + className="bg-danger border-danger text-light", + ) + if confirmed: + m.delete() + modal.close() -def render_member_deletion_view(membership: OrgMembership, modal: Modal): - st.write( - f"Are you sure you want to remove **{format_user_name(membership.user)}** from **{membership.org.name}**?" - ) +def button_with_confirmation_modal( + btn_label: str, + confirmation_text: str, + modal_title: str | None = None, + modal_key: str | None = None, + modal_className: str = "", + **btn_props, +) -> tuple[Modal, bool]: + """ + Returns boolean for whether user confirmed the action or not. + """ - with st.div(className="d-flex"): - if st.button( - "Cancel", type="secondary", className="border-danger text-danger w-50" - ): - modal.close() + modal = Modal(modal_title or btn_label, key=modal_key) + + btn_classes = "btn btn-theme btn-sm my-0 py-0 " + btn_props.pop("className", "") + if st.button(btn_label, className=btn_classes, **btn_props): + modal.open() + + if modal.is_open(): + with modal.container(className=modal_className): + st.write(confirmation_text) + with st.div(className="d-flex"): + if st.button( + "Cancel", + type="secondary", + className="border-danger text-danger w-50", + ): + modal.close() + + confirmed = st.button( + "Confirm", + className="btn btn-theme bg-danger border-danger text-light w-50", + ) + return modal, confirmed - if st.button( - "Remove", className="btn btn-theme bg-danger border-danger text-light w-50" - ): - membership.delete() - modal.close() + return modal, False -def render_pending_invitations_list(org: Org, current_user: AppUser): +def render_pending_invitations_list(org: Org, *, current_member: OrgMembership): pending_invitations = org.invitations.filter(status=OrgInvitation.Status.PENDING) if not pending_invitations: return @@ -299,7 +389,9 @@ def render_pending_invitations_list(org: Org, current_user: AppUser): with st.tag("th", scope="col"): st.html("Invited By") with st.tag("th", scope="col"): - st.html(f"{icons.time} Invited on") + st.html(f"{icons.time} Last invited on") + with st.tag("th", scope="col"): + pass with st.tag("tbody"): for invite in pending_invitations: @@ -310,20 +402,66 @@ def render_pending_invitations_list(org: Org, current_user: AppUser): st.html( html_lib.escape( format_user_name( - invite.inviter, current_user=current_user + invite.inviter, current_user=current_member.user ) ) ) with st.tag("td"): - st.html(invite.created_at.strftime("%b %d, %Y")) + last_invited_at = invite.last_email_sent_at or invite.created_at + st.html(last_invited_at.strftime("%b %d, %Y")) + with st.tag("td", className="text-end"): + render_invitation_actions(invite, current_member=current_member) + + +def render_invitation_actions(invitation: OrgInvitation, current_member: OrgMembership): + if current_member.can_invite() and invitation.can_resend_email(): + modal, confirmed = button_with_confirmation_modal( + f"{icons.email} Resend", + className="btn btn-theme btn-sm my-0 py-0", + key=f"resend-invitation-{invitation.pk}", + unsafe_allow_html=True, + confirmation_text=f"Resend invitation to **{invitation.invitee_email}**?", + modal_title="Resend Invitation", + modal_key=f"resend-invitation-{invitation.pk}-modal", + ) + if confirmed: + try: + invitation.send_email() + except ValidationError as e: + pass + finally: + modal.close() + + if current_member.can_invite(): + modal, confirmed = button_with_confirmation_modal( + f"{icons.delete} Cancel", + key=f"cancel-invitation-{invitation.pk}", + unsafe_allow_html=True, + confirmation_text=f"Are you sure you want to cancel the invitation to **{invitation.invitee_email}**?", + modal_title="Cancel Invitation", + modal_key=f"cancel-invitation-{invitation.pk}-modal", + className="bg-danger border-danger text-light", + ) + if confirmed: + invitation.cancel(user=current_member.user) + modal.close() def render_invite_creation_view(org: Org, inviter: AppUser, modal: Modal): email = st.text_input("Email") + if org.domain_name: + st.caption( + f"Users with `@{org.domain_name}` email will be added automatically." + ) if st.button(f"{icons.add_user} Invite", type="primary", unsafe_allow_html=True): try: - org.invite_user(invitee_email=email, inviter=inviter, role=OrgRole.MEMBER) + org.invite_user( + invitee_email=email, + inviter=inviter, + role=OrgRole.MEMBER, + auto_accept=org.domain_name.lower() == email.split("@")[1].lower(), + ) except ValidationError as e: st.write(", ".join(e.messages), className="text-danger") else: From 28925b0941da5f5aaf0c74afa31009e6e1feccb0 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:10:56 +0530 Subject: [PATCH 034/110] Add email templates --- .../org_invitation_auto_accepted_email.html | 19 ++++++++++++++ templates/org_invitation_email.html | 25 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 templates/org_invitation_auto_accepted_email.html create mode 100644 templates/org_invitation_email.html diff --git a/templates/org_invitation_auto_accepted_email.html b/templates/org_invitation_auto_accepted_email.html new file mode 100644 index 000000000..843fb7426 --- /dev/null +++ b/templates/org_invitation_auto_accepted_email.html @@ -0,0 +1,19 @@ +

+ Hi {{ user.first_name() }}, +

+ +

+ You have been added to the team {{ org.name }} on Gooey.AI. + Visit the teams page to see your team. +

+ +

+ Your invite was automatically accepted because your email domain matches the organization's configured email domain. + If you think this shouldn't have happened, you can leave this organization from the + teams page. +

+ +

+ Cheers,
+ Gooey.AI team +

diff --git a/templates/org_invitation_email.html b/templates/org_invitation_email.html new file mode 100644 index 000000000..c8e12dc87 --- /dev/null +++ b/templates/org_invitation_email.html @@ -0,0 +1,25 @@ +

+ Hi! +

+ +

+ {{ invitation.inviter.display_name or invitation.inviter.first_name() }} has invited + you to join their team {{ invitation.org.name }} on Gooey.AI. +

+ +

+ {% set invitation_url = invitation.get_url() %} + Visit this link to view the invitation: + {{ invitation_url }}. +

+ +

+ The link will expire in {{ settings.ORG_INVITATION_EXPIRY_DAYS }} days. +

+ +

+ Cheers,
+ The Gooey.AI team +

+ +{{ "{{{ pm:unsubscribe }}}" }} From 0821d22b821b1fead201ac1feaf70aad9bcfc3f9 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:15:57 +0530 Subject: [PATCH 035/110] Use UniqueConstraint instead of unique_together for membership --- ...e_domain_name_when_not_deleted_and_more.py | 36 +++++++++++++++++++ orgs/models.py | 8 ++++- 2 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py diff --git a/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py b/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py new file mode 100644 index 000000000..6047919f1 --- /dev/null +++ b/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py @@ -0,0 +1,36 @@ +# Generated by Django 4.2.7 on 2024-07-23 11:45 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('app_users', '0019_alter_appusertransaction_reason'), + ('orgs', '0002_alter_org_unique_together_and_more'), + ] + + operations = [ + migrations.RemoveConstraint( + model_name='org', + name='unique_domain_name_when_not_deleted', + ), + migrations.AlterUniqueTogether( + name='orgmembership', + unique_together=set(), + ), + migrations.AlterField( + model_name='orginvitation', + name='status_changed_by', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='received_invitations', to='app_users.appuser'), + ), + migrations.AddConstraint( + model_name='org', + constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('domain_name',), name='unique_domain_name_when_not_deleted', violation_error_message='This domain name is already in use by another team. Contact Gooey.AI Support if you think this is a mistake.'), + ), + migrations.AddConstraint( + model_name='orgmembership', + constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('org', 'user'), name='unique_org_user'), + ), + ] diff --git a/orgs/models.py b/orgs/models.py index df98064ae..33219ead2 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -170,7 +170,13 @@ class OrgMembership(SafeDeleteModel): objects = SafeDeleteManager() class Meta: - unique_together = ("org", "user", "deleted") + constraints = [ + models.UniqueConstraint( + fields=["org", "user"], + condition=Q(deleted__isnull=True), + name="unique_org_user", + ) + ] def __str__(self): return f"{self.get_role_display()} - {self.user} ({self.org})" From bf4b4ffc3ddf203c47cdcb5e9400f15603ffa5d6 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:30:34 +0530 Subject: [PATCH 036/110] rename get_route_url -> get_app_route_url in orgs/ --- orgs/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/orgs/models.py b/orgs/models.py index 33219ead2..5a19dad78 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -12,7 +12,7 @@ from app_users.models import AppUser from daras_ai_v2 import settings -from daras_ai_v2.fastapi_tricks import get_route_url +from daras_ai_v2.fastapi_tricks import get_app_route_url from daras_ai_v2.crypto import get_random_doc_id from orgs.tasks import send_auto_accepted_email, send_invitation_email @@ -272,9 +272,9 @@ def auto_accept(self): def get_url(self): from routers.account import invitation_route - return get_route_url( + return get_app_route_url( invitation_route, - params={"invite_id": self.invite_id, "org_slug": self.org.get_slug()}, + path_params={"invite_id": self.invite_id, "org_slug": self.org.get_slug()}, ) def send_email(self): From 49dff365492975654cf58d7bbdbc76c65bcbb2f4 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:40:21 +0530 Subject: [PATCH 037/110] Add orgs/tasks.py --- orgs/tasks.py | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 orgs/tasks.py diff --git a/orgs/tasks.py b/orgs/tasks.py new file mode 100644 index 000000000..09258c9ec --- /dev/null +++ b/orgs/tasks.py @@ -0,0 +1,68 @@ +from django.utils import timezone +from loguru import logger + +from celeryapp.celeryconfig import app +from daras_ai_v2 import settings +from daras_ai_v2.fastapi_tricks import get_app_route_url +from daras_ai_v2.send_email import send_email_via_postmark +from daras_ai_v2.settings import templates + + +@app.task +def send_invitation_email(invitation_pk: int): + from orgs.models import OrgInvitation + + invitation = OrgInvitation.objects.get(pk=invitation_pk) + + assert invitation.status == invitation.Status.PENDING + + logger.info( + f"Sending inviation email to {invitation.invitee_email} for org {invitation.org}..." + ) + send_email_via_postmark( + to_address=invitation.invitee_email, + from_address=settings.SUPPORT_EMAIL, + subject=f"[Gooey.AI] Invitation to join {invitation.org.name}", + html_body=templates.get_template("org_invitation_email.html").render( + settings=settings, + invitation=invitation, + ), + message_stream="outbound", + ) + + invitation.last_email_sent_at = timezone.now() + invitation.save() + logger.info("Invitation sent. Saved to DB") + + +@app.task +def send_auto_accepted_email(invitation_pk: int): + from orgs.models import OrgInvitation + from routers.account import orgs_route + + invitation = OrgInvitation.objects.get(pk=invitation_pk) + assert invitation.auto_accepted and invitation.status == invitation.Status.ACCEPTED + assert invitation.status_changed_by + + user = invitation.status_changed_by + if not user.email: + logger.warning(f"User {user} has no email. Skipping auto-accepted email.") + return + + logger.info( + f"Sending auto-accepted email to {user.email} for org {invitation.org}..." + ) + send_email_via_postmark( + to_address=user.email, + from_address=settings.SUPPORT_EMAIL, + subject=f"[Gooey.AI] You've been added to a new team!", + html_body=templates.get_template( + "org_invitation_auto_accepted_email.html" + ).render( + settings=settings, + user=user, + org=invitation.org, + orgs_url=get_app_route_url(orgs_route), + ), + message_stream="outbound", + ) From 10f28ec94cf6c1c7011bf99bc315d9fcd4d6f6bb Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 7 Aug 2024 20:36:45 +0530 Subject: [PATCH 038/110] gooey gui renaming --- orgs/views.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/orgs/views.py b/orgs/views.py index 78a38bfc4..f324bb140 100644 --- a/orgs/views.py +++ b/orgs/views.py @@ -2,11 +2,10 @@ import html as html_lib +import gooey_gui as gui from django.core.exceptions import ValidationError -import gooey_ui as st from app_users.models import AppUser -from gooey_ui.components.modal import Modal from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole from daras_ai_v2 import icons from daras_ai_v2.fastapi_tricks import get_route_path @@ -85,7 +84,7 @@ def render_org_by_membership(membership: OrgMembership): ): with st.div(className="d-flex justify-content-center align-items-center"): if membership.can_edit_org_metadata(): - org_edit_modal = Modal("Edit Org", key="edit-org-modal") + org_edit_modal = gui.Modal("Edit Org", key="edit-org-modal") if org_edit_modal.is_open(): with org_edit_modal.container(): render_org_edit_view_by_membership( @@ -113,7 +112,7 @@ def render_org_by_membership(membership: OrgMembership): st.write("## Members") if membership.can_invite(): - invite_modal = Modal("Invite Member", key="invite-member-modal") + invite_modal = gui.Modal("Invite Member", key="invite-member-modal") if st.button(f"{icons.add_user} Invite"): invite_modal.open() @@ -129,7 +128,7 @@ def render_org_by_membership(membership: OrgMembership): render_pending_invitations_list(org=org, current_member=membership) with st.div(className="mt-4"): - org_leave_modal = Modal("Leave Org", key="leave-org-modal") + org_leave_modal = gui.Modal("Leave Org", key="leave-org-modal") if org_leave_modal.is_open(): with org_leave_modal.container(): render_org_leave_view_by_membership(membership, modal=org_leave_modal) @@ -159,7 +158,7 @@ def render_org_creation_view(user: AppUser): st.experimental_rerun() -def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: Modal): +def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: gui.Modal): org = membership.org render_org_create_or_edit_form(org=org) @@ -182,7 +181,7 @@ def render_danger_zone_by_membership(membership: OrgMembership): st.write("### Danger Zone", className="d-block my-2") if membership.can_delete_org(): - org_deletion_modal = Modal("Delete Organization", key="delete-org-modal") + org_deletion_modal = gui.Modal("Delete Organization", key="delete-org-modal") if org_deletion_modal.is_open(): with org_deletion_modal.container(): render_org_deletion_view_by_membership( @@ -198,7 +197,9 @@ def render_danger_zone_by_membership(membership: OrgMembership): org_deletion_modal.open() -def render_org_deletion_view_by_membership(membership: OrgMembership, *, modal: Modal): +def render_org_deletion_view_by_membership( + membership: OrgMembership, *, modal: gui.Modal +): st.write( f"Are you sure you want to delete **{membership.org.name}**? This action is irreversible." ) @@ -216,7 +217,9 @@ def render_org_deletion_view_by_membership(membership: OrgMembership, *, modal: modal.close() -def render_org_leave_view_by_membership(current_member: OrgMembership, *, modal: Modal): +def render_org_leave_view_by_membership( + current_member: OrgMembership, *, modal: gui.Modal +): org = current_member.org st.write("Are you sure you want to leave this organization?") @@ -345,12 +348,12 @@ def button_with_confirmation_modal( modal_key: str | None = None, modal_className: str = "", **btn_props, -) -> tuple[Modal, bool]: +) -> tuple[gui.Modal, bool]: """ Returns boolean for whether user confirmed the action or not. """ - modal = Modal(modal_title or btn_label, key=modal_key) + modal = gui.Modal(modal_title or btn_label, key=modal_key) btn_classes = "btn btn-theme btn-sm my-0 py-0 " + btn_props.pop("className", "") if st.button(btn_label, className=btn_classes, **btn_props): @@ -447,7 +450,7 @@ def render_invitation_actions(invitation: OrgInvitation, current_member: OrgMemb modal.close() -def render_invite_creation_view(org: Org, inviter: AppUser, modal: Modal): +def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal): email = st.text_input("Email") if org.domain_name: st.caption( From 0eab8eef13b6bf462f329a29ae94c489ff1ea8b5 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:04:16 +0530 Subject: [PATCH 039/110] procfile: use && instead of ; between cd and npm run --- Procfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Procfile b/Procfile index 8711211c2..984315504 100644 --- a/Procfile +++ b/Procfile @@ -19,4 +19,4 @@ dashboard: poetry run streamlit run Home.py --server.port 8501 --server.headless celery: poetry run celery -A celeryapp worker -P threads -c 16 -l DEBUG -ui: cd ../gooey-gui/; PORT=3000 npm run dev +ui: cd ../gooey-gui/ && env PORT=3000 npm run dev From a349eeaad55b9f6d73d5d5b080ce6c1074288f8e Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:05:41 +0530 Subject: [PATCH 040/110] rename st->gui in account.py --- routers/account.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/routers/account.py b/routers/account.py index e79cc75dc..3c3e9c881 100644 --- a/routers/account.py +++ b/routers/account.py @@ -143,7 +143,7 @@ def api_keys_route(request: Request): @app.post("/orgs/") -@st.route +@gui.route def orgs_route(request: Request): with account_page_wrapper(request, AccountTabs.orgs): orgs_tab(request) @@ -161,7 +161,7 @@ def orgs_route(request: Request): @app.post("/invitation/{org_slug}/{invite_id}/") -@st.route +@gui.route def invitation_route(request: Request, org_slug: str, invite_id: str): from routers.root import login From 437025503b80a30b48cb1a28e083099f847e5e3c Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 12 Aug 2024 18:49:57 +0530 Subject: [PATCH 041/110] rename st -> gui --- orgs/views.py | 206 ++++++++++++++++++++++----------------------- routers/account.py | 8 +- 2 files changed, 106 insertions(+), 108 deletions(-) diff --git a/orgs/views.py b/orgs/views.py index f324bb140..ed864cb94 100644 --- a/orgs/views.py +++ b/orgs/views.py @@ -19,41 +19,41 @@ def invitation_page(user: AppUser, invitation: OrgInvitation): orgs_page_path = get_route_path(orgs_route) - with st.div(className="text-center my-5"): - st.write( + with gui.div(className="text-center my-5"): + gui.write( f"# Invitation to join {invitation.org.name}", className="d-block mb-5" ) if invitation.org.memberships.filter(user=user).exists(): # redirect to org page - raise st.RedirectException(orgs_page_path) + raise gui.RedirectException(orgs_page_path) if invitation.status != OrgInvitation.Status.PENDING: - st.write(f"This invitation has been {invitation.get_status_display()}.") + gui.write(f"This invitation has been {invitation.get_status_display()}.") return - st.write( + gui.write( f"**{format_user_name(invitation.inviter)}** has invited you to join **{invitation.org.name}**." ) if other_m := user.org_memberships.first(): - st.caption( + gui.caption( f"You are currently a member of [{other_m.org.name}]({orgs_page_path}). You will be removed from that team if you accept this invitation." ) accept_label = "Leave and Accept" else: accept_label = "Accept" - with st.div( + with gui.div( className="d-flex justify-content-center align-items-center mx-auto", style={"max-width": "600px"}, ): - accept_button = st.button(accept_label, type="primary", className="w-50") - reject_button = st.button("Decline", type="secondary", className="w-50") + accept_button = gui.button(accept_label, type="primary", className="w-50") + reject_button = gui.button("Decline", type="secondary", className="w-50") if accept_button: invitation.accept(user=user) - raise st.RedirectException(orgs_page_path) + raise gui.RedirectException(orgs_page_path) if reject_button: invitation.reject(user=user) @@ -61,7 +61,7 @@ def invitation_page(user: AppUser, invitation: OrgInvitation): def orgs_page(user: AppUser): memberships = user.org_memberships.all() if not memberships: - st.write("*You're not part of an organization yet... Create one?*") + gui.write("*You're not part of an organization yet... Create one?*") render_org_creation_view(user) else: @@ -79,10 +79,10 @@ def render_org_by_membership(membership: OrgMembership): org = membership.org current_user = membership.user - with st.div( + with gui.div( className="d-xs-block d-sm-flex flex-row-reverse justify-content-between" ): - with st.div(className="d-flex justify-content-center align-items-center"): + with gui.div(className="d-flex justify-content-center align-items-center"): if membership.can_edit_org_metadata(): org_edit_modal = gui.Modal("Edit Org", key="edit-org-modal") if org_edit_modal.is_open(): @@ -91,29 +91,29 @@ def render_org_by_membership(membership: OrgMembership): membership, modal=org_edit_modal ) - if st.button(f"{icons.edit} Edit", type="secondary"): + if gui.button(f"{icons.edit} Edit", type="secondary"): org_edit_modal.open() - with st.div(className="d-flex align-items-center"): - st.image( + with gui.div(className="d-flex align-items-center"): + gui.image( org.logo or DEFAULT_ORG_LOGO, className="my-0 me-4 rounded", style={"width": "128px", "height": "128px", "object-fit": "contain"}, ) - with st.div(className="d-flex flex-column justify-content-center"): - st.write(f"# {org.name}") + with gui.div(className="d-flex flex-column justify-content-center"): + gui.write(f"# {org.name}") if org.domain_name: - st.write( + gui.write( f"Org Domain: `@{org.domain_name}`", className="text-muted" ) - with st.div(className="mt-4"): - with st.div(className="d-flex justify-content-between align-items-center"): - st.write("## Members") + with gui.div(className="mt-4"): + with gui.div(className="d-flex justify-content-between align-items-center"): + gui.write("## Members") if membership.can_invite(): invite_modal = gui.Modal("Invite Member", key="invite-member-modal") - if st.button(f"{icons.add_user} Invite"): + if gui.button(f"{icons.add_user} Invite"): invite_modal.open() if invite_modal.is_open(): @@ -124,17 +124,17 @@ def render_org_by_membership(membership: OrgMembership): render_members_list(org=org, current_member=membership) - with st.div(className="mt-4"): + with gui.div(className="mt-4"): render_pending_invitations_list(org=org, current_member=membership) - with st.div(className="mt-4"): + with gui.div(className="mt-4"): org_leave_modal = gui.Modal("Leave Org", key="leave-org-modal") if org_leave_modal.is_open(): with org_leave_modal.container(): render_org_leave_view_by_membership(membership, modal=org_leave_modal) - with st.div(className="text-end"): - leave_org = st.button( + with gui.div(className="text-end"): + leave_org = gui.button( "Leave", className="btn btn-theme bg-danger border-danger text-white", ) @@ -143,42 +143,42 @@ def render_org_by_membership(membership: OrgMembership): def render_org_creation_view(user: AppUser): - st.write(f"# {icons.company} Create an Org", unsafe_allow_html=True) + gui.write(f"# {icons.company} Create an Org", unsafe_allow_html=True) org_fields = render_org_create_or_edit_form() - if st.button("Create"): + if gui.button("Create"): try: Org.objects.create_org( created_by=user, **org_fields, ) except ValidationError as e: - st.write(", ".join(e.messages), className="text-danger") + gui.write(", ".join(e.messages), className="text-danger") else: - st.experimental_rerun() + gui.experimental_rerun() def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: gui.Modal): org = membership.org render_org_create_or_edit_form(org=org) - if st.button("Save", className="w-100", type="primary"): + if gui.button("Save", className="w-100", type="primary"): try: org.full_clean() except ValidationError as e: # newlines in markdown - st.write(" \n".join(e.messages), className="text-danger") + gui.write(" \n".join(e.messages), className="text-danger") else: org.save() modal.close() if membership.can_delete_org() or membership.can_transfer_ownership(): - st.write("---") + gui.write("---") render_danger_zone_by_membership(membership) def render_danger_zone_by_membership(membership: OrgMembership): - st.write("### Danger Zone", className="d-block my-2") + gui.write("### Danger Zone", className="d-block my-2") if membership.can_delete_org(): org_deletion_modal = gui.Modal("Delete Organization", key="delete-org-modal") @@ -188,9 +188,9 @@ def render_danger_zone_by_membership(membership: OrgMembership): membership, modal=org_deletion_modal ) - with st.div(className="d-flex justify-content-between align-items-center"): - st.write("Delete Organization") - if st.button( + with gui.div(className="d-flex justify-content-between align-items-center"): + gui.write("Delete Organization") + if gui.button( f"{icons.delete} Delete", className="btn btn-theme py-2 bg-danger border-danger text-white", ): @@ -200,17 +200,17 @@ def render_danger_zone_by_membership(membership: OrgMembership): def render_org_deletion_view_by_membership( membership: OrgMembership, *, modal: gui.Modal ): - st.write( + gui.write( f"Are you sure you want to delete **{membership.org.name}**? This action is irreversible." ) - with st.div(className="d-flex"): - if st.button( + with gui.div(className="d-flex"): + if gui.button( "Cancel", type="secondary", className="border-danger text-danger w-50" ): modal.close() - if st.button( + if gui.button( "Delete", className="btn btn-theme bg-danger border-danger text-light w-50" ): membership.org.delete() @@ -222,11 +222,11 @@ def render_org_leave_view_by_membership( ): org = current_member.org - st.write("Are you sure you want to leave this organization?") + gui.write("Are you sure you want to leave this organization?") new_owner = None if current_member.role == OrgRole.OWNER and org.memberships.count() == 1: - st.caption( + gui.caption( "You are the only member. You will lose access to this team if you leave." ) elif ( @@ -239,23 +239,23 @@ def render_org_leave_view_by_membership( if m != current_member } - st.caption( + gui.caption( "You are the only owner of this organization. Please choose another member to promote to owner." ) - new_owner_uid = st.selectbox( + new_owner_uid = gui.selectbox( "New Owner", options=list(members_by_uid), format_func=lambda uid: format_user_name(members_by_uid[uid].user), ) new_owner = members_by_uid[new_owner_uid] - with st.div(className="d-flex"): - if st.button( + with gui.div(className="d-flex"): + if gui.button( "Cancel", type="secondary", className="border-danger text-danger w-50" ): modal.close() - if st.button( + if gui.button( "Leave", className="btn btn-theme bg-danger border-danger text-light w-50" ): if new_owner: @@ -266,34 +266,34 @@ def render_org_leave_view_by_membership( def render_members_list(org: Org, current_member: OrgMembership): - with st.tag("table", className="table table-responsive"): - with st.tag("thead"), st.tag("tr"): - with st.tag("th", scope="col"): - st.html("Name") - with st.tag("th", scope="col"): - st.html("Role") - with st.tag("th", scope="col"): - st.html(f"{icons.time} Since") - with st.tag("th", scope="col"): - st.html("") - - with st.tag("tbody"): + with gui.tag("table", className="table table-responsive"): + with gui.tag("thead"), gui.tag("tr"): + with gui.tag("th", scope="col"): + gui.html("Name") + with gui.tag("th", scope="col"): + gui.html("Role") + with gui.tag("th", scope="col"): + gui.html(f"{icons.time} Since") + with gui.tag("th", scope="col"): + gui.html("") + + with gui.tag("tbody"): for m in org.memberships.all().order_by("created_at"): - with st.tag("tr"): - with st.tag("td"): + with gui.tag("tr"): + with gui.tag("td"): name = format_user_name( m.user, current_user=current_member.user ) if m.user.handle_id: - with st.link(to=m.user.handle.get_app_url()): - st.html(html_lib.escape(name)) + with gui.link(to=m.user.handle.get_app_url()): + gui.html(html_lib.escape(name)) else: - st.html(html_lib.escape(name)) - with st.tag("td"): - st.html(m.get_role_display()) - with st.tag("td"): - st.html(m.created_at.strftime("%b %d, %Y")) - with st.tag("td", className="text-end"): + gui.html(html_lib.escape(name)) + with gui.tag("td"): + gui.html(m.get_role_display()) + with gui.tag("td"): + gui.html(m.created_at.strftime("%b %d, %Y")) + with gui.tag("td", className="text-end"): render_membership_actions(m, current_member=current_member) @@ -356,21 +356,21 @@ def button_with_confirmation_modal( modal = gui.Modal(modal_title or btn_label, key=modal_key) btn_classes = "btn btn-theme btn-sm my-0 py-0 " + btn_props.pop("className", "") - if st.button(btn_label, className=btn_classes, **btn_props): + if gui.button(btn_label, className=btn_classes, **btn_props): modal.open() if modal.is_open(): with modal.container(className=modal_className): - st.write(confirmation_text) - with st.div(className="d-flex"): - if st.button( + gui.write(confirmation_text) + with gui.div(className="d-flex"): + if gui.button( "Cancel", type="secondary", className="border-danger text-danger w-50", ): modal.close() - confirmed = st.button( + confirmed = gui.button( "Confirm", className="btn btn-theme bg-danger border-danger text-light w-50", ) @@ -384,35 +384,35 @@ def render_pending_invitations_list(org: Org, *, current_member: OrgMembership): if not pending_invitations: return - st.write("## Pending") - with st.tag("table", className="table table-responsive"): - with st.tag("thead"), st.tag("tr"): - with st.tag("th", scope="col"): - st.html("Email") - with st.tag("th", scope="col"): - st.html("Invited By") - with st.tag("th", scope="col"): - st.html(f"{icons.time} Last invited on") - with st.tag("th", scope="col"): + gui.write("## Pending") + with gui.tag("table", className="table table-responsive"): + with gui.tag("thead"), gui.tag("tr"): + with gui.tag("th", scope="col"): + gui.html("Email") + with gui.tag("th", scope="col"): + gui.html("Invited By") + with gui.tag("th", scope="col"): + gui.html(f"{icons.time} Last invited on") + with gui.tag("th", scope="col"): pass - with st.tag("tbody"): + with gui.tag("tbody"): for invite in pending_invitations: - with st.tag("tr", className="text-break"): - with st.tag("td"): - st.html(html_lib.escape(invite.invitee_email)) - with st.tag("td"): - st.html( + with gui.tag("tr", className="text-break"): + with gui.tag("td"): + gui.html(html_lib.escape(invite.invitee_email)) + with gui.tag("td"): + gui.html( html_lib.escape( format_user_name( invite.inviter, current_user=current_member.user ) ) ) - with st.tag("td"): + with gui.tag("td"): last_invited_at = invite.last_email_sent_at or invite.created_at - st.html(last_invited_at.strftime("%b %d, %Y")) - with st.tag("td", className="text-end"): + gui.html(last_invited_at.strftime("%b %d, %Y")) + with gui.tag("td", className="text-end"): render_invitation_actions(invite, current_member=current_member) @@ -451,13 +451,13 @@ def render_invitation_actions(invitation: OrgInvitation, current_member: OrgMemb def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal): - email = st.text_input("Email") + email = gui.text_input("Email") if org.domain_name: - st.caption( + gui.caption( f"Users with `@{org.domain_name}` email will be added automatically." ) - if st.button(f"{icons.add_user} Invite", type="primary", unsafe_allow_html=True): + if gui.button(f"{icons.add_user} Invite", type="primary", unsafe_allow_html=True): try: org.invite_user( invitee_email=email, @@ -466,7 +466,7 @@ def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal): auto_accept=org.domain_name.lower() == email.split("@")[1].lower(), ) except ValidationError as e: - st.write(", ".join(e.messages), className="text-danger") + gui.write(", ".join(e.messages), className="text-danger") else: modal.close() @@ -474,17 +474,17 @@ def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal): def render_org_create_or_edit_form(org: Org | None = None) -> AttrDict | Org: org_proxy = org or AttrDict() - org_proxy.name = st.text_input("Team Name", value=org and org.name or "") - org_proxy.logo = st.file_uploader( + org_proxy.name = gui.text_input("Team Name", value=org and org.name or "") + org_proxy.logo = gui.file_uploader( "Logo", accept=["image/*"], value=org and org.logo or "" ) - org_proxy.domain_name = st.text_input( + org_proxy.domain_name = gui.text_input( "Domain Name (Optional)", placeholder="e.g. gooey.ai", value=org and org.domain_name or "", ) if org_proxy.domain_name: - st.caption( + gui.caption( f"Invite any user with `@{org_proxy.domain_name}` email to this organization." ) diff --git a/routers/account.py b/routers/account.py index 3c3e9c881..a898501db 100644 --- a/routers/account.py +++ b/routers/account.py @@ -3,9 +3,9 @@ from enum import Enum import gooey_gui as gui -from fastapi import APIRouter from fastapi.requests import Request from furl import furl +from gooey_gui.core import RedirectException from loguru import logger from requests.models import HTTPError from starlette.responses import Response @@ -142,8 +142,7 @@ def api_keys_route(request: Request): ) -@app.post("/orgs/") -@gui.route +@gui.route(app, "/orgs/") def orgs_route(request: Request): with account_page_wrapper(request, AccountTabs.orgs): orgs_tab(request) @@ -160,8 +159,7 @@ def orgs_route(request: Request): ) -@app.post("/invitation/{org_slug}/{invite_id}/") -@gui.route +@gui.route(app, "/invitation/{org_slug}/{invite_id}/") def invitation_route(request: Request, org_slug: str, invite_id: str): from routers.root import login From 3c06688ff79f925e09f3214e7fb4eab222027d64 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 30 Aug 2024 13:30:03 +0530 Subject: [PATCH 042/110] make org page only accessible to admins --- routers/account.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/routers/account.py b/routers/account.py index a898501db..b52239b2b 100644 --- a/routers/account.py +++ b/routers/account.py @@ -256,9 +256,30 @@ def api_keys_tab(request: Request): def orgs_tab(request: Request): + """only accessible to admins""" + from daras_ai_v2.base import BasePage + + if not BasePage.is_user_admin(request.user): + raise RedirectException(get_route_path(account_route)) + orgs_page(request.user) +def get_tabs(request: Request) -> list[AccountTabs]: + from daras_ai_v2.base import BasePage + + tab_list = [ + AccountTabs.billing, + AccountTabs.profile, + AccountTabs.saved, + AccountTabs.api_keys, + ] + if BasePage.is_user_admin(request.user): + tab_list.append(AccountTabs.orgs) + + return tab_list + + @contextmanager def account_page_wrapper(request: Request, current_tab: TabData): if not request.user or request.user.is_anonymous: @@ -269,7 +290,7 @@ def account_page_wrapper(request: Request, current_tab: TabData): with page_wrapper(request): gui.div(className="mt-5") with gui.nav_tabs(): - for tab in AccountTabs: + for tab in get_tabs(request): with gui.nav_item(tab.url_path, active=tab == current_tab): gui.html(tab.title) From e5ddad713cdbde3fc3a83348e6b41cf0de0e988d Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Fri, 30 Aug 2024 13:35:26 +0530 Subject: [PATCH 043/110] Add org support with role and UI view --- daras_ai_v2/icons.py | 1 - gooey_ui/components/__init__.py | 1009 +++++++++++++++++++++++++++++++ 2 files changed, 1009 insertions(+), 1 deletion(-) create mode 100644 gooey_ui/components/__init__.py diff --git a/daras_ai_v2/icons.py b/daras_ai_v2/icons.py index 30dcc1e01..90334bbc3 100644 --- a/daras_ai_v2/icons.py +++ b/daras_ai_v2/icons.py @@ -24,7 +24,6 @@ admin = '' remove_user = '' add_user = '' -transfer = '' # brands github = '' diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py new file mode 100644 index 000000000..2c27edd1d --- /dev/null +++ b/gooey_ui/components/__init__.py @@ -0,0 +1,1009 @@ +import base64 +import html as html_lib +import math +import textwrap +import typing +from datetime import datetime, timezone + +import numpy as np +from furl import furl + +from daras_ai.image_input import resize_img_scale +from gooey_ui import state +from gooey_ui.pubsub import md5_values + +T = typing.TypeVar("T") +LabelVisibility = typing.Literal["visible", "collapsed"] + +BLANK_OPTION = "———" + + +def _default_format(value: typing.Any) -> str: + if value is None: + return BLANK_OPTION + return str(value) + + +def dummy(*args, **kwargs): + return state.NestingCtx() + + +spinner = dummy +set_page_config = dummy +form = dummy +dataframe = dummy + + +def countdown_timer( + end_time: datetime, + delay_text: str, +) -> state.NestingCtx: + return _node( + "countdown-timer", + endTime=end_time.astimezone(timezone.utc).isoformat(), + delayText=delay_text, + ) + + +def nav_tabs(): + return _node("nav-tabs") + + +def nav_item(href: str, *, active: bool): + return _node("nav-item", to=href, active="true" if active else None) + + +def nav_tab_content(): + return _node("nav-tab-content") + + +def div(**props) -> state.NestingCtx: + return tag("div", **props) + + +def link(*, to: str, **props) -> state.NestingCtx: + return _node("Link", to=to, **props) + + +def tag(tag_name: str, **props) -> state.NestingCtx: + props["__reactjsxelement"] = tag_name + return _node("tag", **props) + + +def html(body: str, **props): + props["className"] = props.get("className", "") + " gui-html-container" + return _node("html", body=body, **props) + + +def write(*objs: typing.Any, line_clamp: int = None, unsafe_allow_html=False, **props): + for obj in objs: + markdown( + obj if isinstance(obj, str) else repr(obj), + line_clamp=line_clamp, + unsafe_allow_html=unsafe_allow_html, + **props, + ) + + +def center(direction="flex-column", className="") -> state.NestingCtx: + return div( + className=f"d-flex justify-content-center align-items-center text-center {direction} {className}" + ) + + +def newline(): + html("
") + + +def markdown( + body: str | None, *, line_clamp: int = None, unsafe_allow_html=False, **props +): + if body is None: + return _node("markdown", body="", **props) + if not unsafe_allow_html: + body = html_lib.escape(body) + props["className"] = ( + props.get("className", "") + " gui-html-container gui-md-container" + ) + return _node("markdown", body=dedent(body).strip(), lineClamp=line_clamp, **props) + + +def _node(name: str, **props): + node = state.RenderTreeNode(name=name, props=props) + node.mount() + return state.NestingCtx(node) + + +def text(body: str, **props): + state.RenderTreeNode( + name="pre", + props=dict(body=dedent(body), **props), + ).mount() + + +def error( + body: str, + icon: str = "🔥", + *, + unsafe_allow_html=False, + color="rgba(255, 108, 108, 0.2)", + **props, +): + if not isinstance(body, str): + body = repr(body) + with div( + style=dict( + backgroundColor=color, + padding="1rem", + paddingBottom="0", + marginBottom="0.5rem", + borderRadius="0.25rem", + display="flex", + gap="0.5rem", + ) + ): + markdown(icon) + with div(): + markdown(dedent(body), unsafe_allow_html=unsafe_allow_html, **props) + + +def success(body: str, icon: str = "✅", *, unsafe_allow_html=False): + if not isinstance(body, str): + body = repr(body) + with div( + style=dict( + backgroundColor="rgba(108, 255, 108, 0.2)", + padding="1rem", + paddingBottom="0", + marginBottom="0.5rem", + borderRadius="0.25rem", + display="flex", + gap="0.5rem", + ) + ): + markdown(icon) + markdown(dedent(body), unsafe_allow_html=unsafe_allow_html) + + +def caption(body: str, className: str = None, **props): + className = className or "text-muted" + markdown(body, className=className, **props) + + +def tabs(labels: list[str]) -> list[state.NestingCtx]: + parent = state.RenderTreeNode( + name="tabs", + children=[ + state.RenderTreeNode( + name="tab", + props=dict(label=dedent(label)), + ) + for label in labels + ], + ).mount() + return [state.NestingCtx(tab) for tab in parent.children] + + +def controllable_tabs( + labels: list[str], key: str +) -> tuple[list[state.NestingCtx], int]: + index = state.session_state.get(key, 0) + for i, label in enumerate(labels): + if button( + label, + key=f"tab-{i}", + type="primary", + className="replicate-nav", + style={ + "background": "black" if i == index else "white", + "color": "white" if i == index else "black", + }, + ): + state.session_state[key] = index = i + state.experimental_rerun() + ctxs = [] + for i, label in enumerate(labels): + if i == index: + ctxs += [div(className="tab-content")] + else: + ctxs += [div(className="tab-content", style={"display": "none"})] + return ctxs, index + + +def columns( + spec, + *, + gap: str = None, + responsive: bool = True, + column_props: dict = {}, + **props, +) -> tuple[state.NestingCtx, ...]: + if isinstance(spec, int): + spec = [1] * spec + total_weight = sum(spec) + props.setdefault("className", "row") + with div(**props): + return tuple( + div( + className=f"col-lg-{p} {'col-12' if responsive else f'col-{p}'}", + **column_props, + ) + for w in spec + if (p := f"{round(w / total_weight * 12)}") + ) + + +def image( + src: str | np.ndarray, + caption: str = None, + alt: str = None, + href: str = None, + show_download_button: bool = False, + **props, +): + if isinstance(src, np.ndarray): + from daras_ai.image_input import cv2_img_to_bytes + + if not src.shape: + return + # ensure image is not too large + data = resize_img_scale(cv2_img_to_bytes(src), (128, 128)) + # convert to base64 + b64 = base64.b64encode(data).decode("utf-8") + src = "data:image/png;base64," + b64 + if not src: + return + state.RenderTreeNode( + name="img", + props=dict( + src=src, + caption=dedent(caption), + alt=alt or caption, + href=href, + **props, + ), + ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) + + +def video( + src: str, + caption: str = None, + autoplay: bool = False, + show_download_button: bool = False, +): + autoplay_props = {} + if autoplay: + autoplay_props = { + "preload": "auto", + "controls": True, + "autoPlay": True, + "loop": True, + "muted": True, + "playsInline": True, + } + + if not src: + return + if isinstance(src, str): + # https://muffinman.io/blog/hack-for-ios-safari-to-display-html-video-thumbnail/ + f = furl(src) + f.fragment.args["t"] = "0.001" + src = f.url + state.RenderTreeNode( + name="video", + props=dict(src=src, caption=dedent(caption), **autoplay_props), + ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) + + +def audio(src: str, caption: str = None, show_download_button: bool = False): + if not src: + return + state.RenderTreeNode( + name="audio", + props=dict(src=src, caption=dedent(caption)), + ).mount() + if show_download_button: + download_button( + label=' Download', url=src + ) + + +def text_area( + label: str, + value: str = "", + height: int = 500, + key: str = None, + help: str = None, + placeholder: str = None, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + **props, +) -> str: + style = props.setdefault("style", {}) + # if key: + # assert not value, "only one of value or key can be provided" + # else: + if not key: + key = md5_values( + "textarea", + label, + height, + help, + placeholder, + label_visibility, + not disabled or value, + ) + value = str(state.session_state.setdefault(key, value) or "") + if label_visibility != "visible": + label = None + if disabled: + max_height = f"{height}px" + rows = nrows_for_text(value, height) + else: + max_height = "50vh" + rows = nrows_for_text(value, height) + style.setdefault("maxHeight", max_height) + props.setdefault("rows", rows) + state.RenderTreeNode( + name="textarea", + props=dict( + name=key, + label=dedent(label), + defaultValue=value, + help=help, + placeholder=placeholder, + disabled=disabled, + **props, + ), + ).mount() + return value or "" + + +def nrows_for_text( + text: str, + max_height_px: int, + min_rows: int = 1, + row_height_px: int = 30, + row_width_px: int = 70, +) -> int: + max_rows = max_height_px // row_height_px + nrows = math.ceil( + sum( + math.ceil(len(line) / row_width_px) + for line in (text or "").splitlines(keepends=True) + ) + ) + nrows = min(max(nrows, min_rows), max_rows) + return nrows + + +def multiselect( + label: str, + options: typing.Sequence[T], + format_func: typing.Callable[[T], typing.Any] = _default_format, + key: str = None, + help: str = None, + allow_none: bool = False, + *, + disabled: bool = False, +) -> list[T]: + if not options: + return [] + options = list(options) + if not key: + key = md5_values("multiselect", label, options, help) + value = state.session_state.get(key) or [] + if not isinstance(value, list): + value = [value] + value = [o for o in value if o in options] + if not allow_none and not value: + value = [options[0]] + state.session_state[key] = value + state.RenderTreeNode( + name="select", + props=dict( + name=key, + label=dedent(label), + help=help, + isDisabled=disabled, + isMulti=True, + defaultValue=value, + allow_none=allow_none, + options=[ + {"value": option, "label": str(format_func(option))} + for option in options + ], + ), + ).mount() + return value + + +def selectbox( + label: str, + options: typing.Iterable[T], + format_func: typing.Callable[[T], typing.Any] = _default_format, + key: str = None, + help: str = None, + *, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + value: T = None, + allow_none: bool = False, + **props, +) -> T | None: + if not options: + return None + if label_visibility != "visible": + label = None + options = list(options) + if allow_none: + options.insert(0, None) + if not key: + key = md5_values("select", label, options, help, label_visibility) + value = state.session_state.setdefault(key, value) + if value not in options: + value = state.session_state[key] = options[0] + state.RenderTreeNode( + name="select", + props=dict( + name=key, + label=dedent(label), + help=help, + isDisabled=disabled, + defaultValue=value, + options=[ + {"value": option, "label": str(format_func(option))} + for option in options + ], + **props, + ), + ).mount() + return value + + +def download_button( + label: str, + url: str, + key: str = None, + help: str = None, + *, + type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", + disabled: bool = False, + **props, +) -> bool: + url = furl(url).remove(fragment=True).url + return button( + component="download-button", + url=url, + label=label, + key=key, + help=help, + type=type, + disabled=disabled, + **props, + ) + + +def button( + label: str, + key: str = None, + help: str = None, + *, + type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", + disabled: bool = False, + component: typing.Literal["download-button", "gui-button"] = "gui-button", + **props, +) -> bool: + """ + Example: + st.button("Primary", key="test0", type="primary") + st.button("Secondary", key="test1") + st.button("Tertiary", key="test3", type="tertiary") + st.button("Link Button", key="test3", type="link") + """ + if not key: + key = md5_values("button", label, help, type, props) + className = f"btn-{type} " + props.pop("className", "") + state.RenderTreeNode( + name=component, + props=dict( + type="submit", + value="yes", + name=key, + label=dedent(label), + help=help, + disabled=disabled, + className=className, + **props, + ), + ).mount() + return bool(state.session_state.pop(key, False)) + + +def anchor( + label: str, + href: str, + *, + type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", + disabled: bool = False, + unsafe_allow_html: bool = False, + new_tab: bool = False, + **props, +): + className = f"btn btn-theme btn-{type} " + props.pop("className", "") + style = props.pop("style", {}) + if disabled: + style["pointerEvents"] = "none" + if new_tab: + props["target"] = "_blank" + with tag("a", href=href, className=className, style=style, **props): + markdown(dedent(label), unsafe_allow_html=unsafe_allow_html) + + +form_submit_button = button + + +def expander(label: str, *, expanded: bool = False, key: str = None, **props): + node = state.RenderTreeNode( + name="expander", + props=dict( + label=dedent(label), + open=expanded, + name=key or md5_values(label, expanded, props), + **props, + ), + ) + node.mount() + return state.NestingCtx(node) + + +def file_uploader( + label: str, + accept: list[str] = None, + accept_multiple_files=False, + key: str = None, + value: str | list[str] = None, + upload_key: str = None, + help: str = None, + *, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + upload_meta: dict = None, + optional: bool = False, +) -> str | list[str] | None: + if label_visibility != "visible": + label = None + key = upload_key or key + if not key: + key = md5_values( + "file_uploader", + label, + accept, + accept_multiple_files, + help, + label_visibility, + ) + if optional: + if not checkbox( + label, value=bool(state.session_state.get(key, value)), disabled=disabled + ): + state.session_state.pop(key, None) + return None + label = None + value = state.session_state.setdefault(key, value) + if not value: + if accept_multiple_files: + value = [] + else: + value = None + state.session_state[key] = value + state.RenderTreeNode( + name="input", + props=dict( + type="file", + name=key, + label=dedent(label), + help=help, + disabled=disabled, + accept=accept, + multiple=accept_multiple_files, + defaultValue=value, + uploadMeta=upload_meta, + ), + ).mount() + return value + + +def json(value: typing.Any, expanded: bool = False, depth: int = 1): + state.RenderTreeNode( + name="json", + props=dict( + value=value, + expanded=expanded, + defaultInspectDepth=3 if expanded else depth, + ), + ).mount() + + +def data_table(file_url_or_cells: str | list): + if isinstance(file_url_or_cells, str): + file_url = file_url_or_cells + return _node("data-table", fileUrl=file_url) + else: + cells = file_url_or_cells + return _node("data-table-raw", cells=cells) + + +def table(df: "pd.DataFrame"): + with tag("table", className="table table-striped table-sm"): + with tag("thead"): + with tag("tr"): + for col in df.columns: + with tag("th", scope="col"): + html(dedent(col)) + with tag("tbody"): + for row in df.itertuples(index=False): + with tag("tr"): + for value in row: + with tag("td"): + html(dedent(str(value))) + + +def raw_table(header: list[str], className: str = "", **props) -> state.NestingCtx: + className = "table " + className + with tag("table", className=className, **props): + if header: + with tag("thead"), tag("tr"): + for col in header: + with tag("th", scope="col"): + html(dedent(col)) + + return tag("tbody") + + +def table_row(values: list[str], **props): + row = tag("tr", **props) + with row: + for v in values: + with tag("td"): + html(html_lib.escape(v)) + return row + + +def horizontal_radio( + label: str, + options: typing.Sequence[T], + format_func: typing.Callable[[T], typing.Any] = _default_format, + *, + key: str = None, + help: str = None, + value: T = None, + disabled: bool = False, + checked_by_default: bool = True, + label_visibility: LabelVisibility = "visible", + **button_props, +) -> T | None: + if not options: + return None + options = list(options) + if not key: + key = md5_values("horizontal_radio", label, options, help, label_visibility) + value = state.session_state.setdefault(key, value) + if value not in options and checked_by_default: + value = state.session_state[key] = options[0] + if label_visibility != "visible": + label = None + markdown(label) + for option in options: + if button( + format_func(option), + key=f"tab-{key}-{option}", + type="primary", + className="replicate-nav " + ("active" if value == option else ""), + disabled=disabled, + **button_props, + ): + state.session_state[key] = value = option + state.experimental_rerun() + return value + + +def radio( + label: str, + options: typing.Sequence[T], + format_func: typing.Callable[[T], typing.Any] = _default_format, + key: str = None, + value: T = None, + help: str = None, + *, + disabled: bool = False, + checked_by_default: bool = True, + label_visibility: LabelVisibility = "visible", +) -> T | None: + if not options: + return None + options = list(options) + if not key: + key = md5_values("radio", label, options, help, label_visibility) + value = state.session_state.setdefault(key, value) + if value not in options and checked_by_default: + value = state.session_state[key] = options[0] + if label_visibility != "visible": + label = None + markdown(label) + for option in options: + state.RenderTreeNode( + name="input", + props=dict( + type="radio", + name=key, + label=dedent(str(format_func(option))), + value=option, + defaultChecked=bool(value == option), + help=help, + disabled=disabled, + ), + ).mount() + return value + + +def text_input( + label: str, + value: str = "", + max_chars: str = None, + key: str = None, + help: str = None, + *, + placeholder: str = None, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + **props, +) -> str: + value = _input_widget( + input_type="text", + label=label, + value=value, + key=key, + help=help, + disabled=disabled, + label_visibility=label_visibility, + maxLength=max_chars, + placeholder=placeholder, + **props, + ) + return value or "" + + +def date_input( + label: str, + value: str | None = None, + key: str = None, + help: str = None, + *, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + **props, +) -> datetime | None: + value = _input_widget( + input_type="date", + label=label, + value=value, + key=key, + help=help, + disabled=disabled, + label_visibility=label_visibility, + style=dict( + border="1px solid hsl(0, 0%, 80%)", + padding="0.375rem 0.75rem", + borderRadius="0.25rem", + margin="0 0.5rem 0 0.5rem", + ), + **props, + ) + try: + return datetime.strptime(value, "%Y-%m-%d") if value else None + except ValueError: + return None + + +def password_input( + label: str, + value: str = "", + max_chars: str = None, + key: str = None, + help: str = None, + *, + placeholder: str = None, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + **props, +) -> str: + value = _input_widget( + input_type="password", + label=label, + value=value, + key=key, + help=help, + disabled=disabled, + label_visibility=label_visibility, + maxLength=max_chars, + placeholder=placeholder, + **props, + ) + return value or "" + + +def slider( + label: str, + min_value: float = None, + max_value: float = None, + value: float = None, + step: float = None, + key: str = None, + help: str = None, + *, + disabled: bool = False, +) -> float: + value = _input_widget( + input_type="range", + label=label, + value=value, + key=key, + help=help, + disabled=disabled, + min=min_value, + max=max_value, + step=_step_value(min_value, max_value, step), + ) + return value or 0 + + +def number_input( + label: str, + min_value: float = None, + max_value: float = None, + value: float = None, + step: float = None, + key: str = None, + help: str = None, + *, + disabled: bool = False, +) -> float: + value = _input_widget( + input_type="number", + inputMode="decimal", + label=label, + value=value, + key=key, + help=help, + disabled=disabled, + min=min_value, + max=max_value, + step=_step_value(min_value, max_value, step), + ) + return value or 0 + + +def _step_value( + min_value: float | None, max_value: float | None, step: float | None +) -> float: + if step: + return step + elif isinstance(min_value, float) or isinstance(max_value, float): + return 0.1 + else: + return 1 + + +def checkbox( + label: str, + value: bool = False, + key: str = None, + help: str = None, + *, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + **props, +) -> bool: + value = _input_widget( + input_type="checkbox", + label=label, + value=value, + key=key, + help=help, + disabled=disabled, + label_visibility=label_visibility, + default_value_attr="defaultChecked", + **props, + ) + return bool(value) + + +def _input_widget( + *, + input_type: str, + label: str, + value: typing.Any = None, + key: str = None, + help: str = None, + disabled: bool = False, + label_visibility: LabelVisibility = "visible", + default_value_attr: str = "defaultValue", + **kwargs, +) -> typing.Any: + # if key: + # assert not value, "only one of value or key can be provided" + # else: + if not key: + key = md5_values("input", input_type, label, help, label_visibility) + value = state.session_state.setdefault(key, value) + if label_visibility != "visible": + label = None + state.RenderTreeNode( + name="input", + props={ + "type": input_type, + "name": key, + "label": dedent(label), + default_value_attr: value, + "help": help, + "disabled": disabled, + **kwargs, + }, + ).mount() + return value + + +def breadcrumbs(divider: str = "/", **props) -> state.NestingCtx: + style = props.pop("style", {}) | {"--bs-breadcrumb-divider": f"'{divider}'"} + with tag("nav", style=style, **props): + return tag("ol", className="breadcrumb mb-0") + + +def breadcrumb_item(inner_html: str, link_to: str | None = None, **props): + className = "breadcrumb-item " + props.pop("className", "") + with tag("li", className=className, **props): + if link_to: + with tag("a", href=link_to): + html(inner_html) + else: + html(inner_html) + + +def plotly_chart(figure_or_data, **kwargs): + data = ( + figure_or_data.to_plotly_json() + if hasattr(figure_or_data, "to_plotly_json") + else figure_or_data + ) + state.RenderTreeNode( + name="plotly-chart", + props=dict( + chart=data, + args=kwargs, + ), + ).mount() + + +def dedent(text: str | None) -> str | None: + if not text: + return text + return textwrap.dedent(text) + + +def js(src: str, **kwargs): + state.RenderTreeNode( + name="script", + props=dict( + src=src, + args=kwargs, + ), + ).mount() From 64fb8bda877980899ee114eed0939b3b76ea18bd Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 17 Jul 2024 18:45:30 +0530 Subject: [PATCH 044/110] Make modals rounded --- gooey_ui/components/modal.py | 97 ++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 gooey_ui/components/modal.py diff --git a/gooey_ui/components/modal.py b/gooey_ui/components/modal.py new file mode 100644 index 000000000..72e951fc8 --- /dev/null +++ b/gooey_ui/components/modal.py @@ -0,0 +1,97 @@ +from contextlib import contextmanager + +import gooey_ui as st +from gooey_ui import experimental_rerun as rerun + + +class Modal: + def __init__(self, title, key, padding=20, max_width=744): + """ + :param title: title of the Modal shown in the h1 + :param key: unique key identifying this modal instance + :param padding: padding of the content within the modal + :param max_width: maximum width this modal should use + """ + self.title = title + self.padding = padding + self.max_width = str(max_width) + "px" + self.key = key + + self._container = None + + def is_open(self): + return st.session_state.get(f"{self.key}-opened", False) + + def open(self): + st.session_state[f"{self.key}-opened"] = True + rerun() + + def close(self, rerun_condition=True): + st.session_state[f"{self.key}-opened"] = False + if rerun_condition: + rerun() + + def empty(self): + if self._container: + self._container.empty() + + @contextmanager + def container(self, **props): + st.html( + f""" + + """ + ) + + with st.div(className="blur-background"): + with st.div(className="modal-parent"): + container_class = "modal-container " + props.pop("className", "") + self._container = st.div(className=container_class, **props) + + with self._container: + with st.div(className="d-flex justify-content-between align-items-center"): + if self.title: + st.markdown(f"### {self.title}") + else: + st.div() + + close_ = st.button( + "✖", + type="tertiary", + key=f"{self.key}-close", + style={"padding": "0.375rem 0.75rem"}, + ) + if close_: + self.close() + yield self._container From 4c72b4690aef1784a683bbf5789d1da126523767 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:08:16 +0530 Subject: [PATCH 045/110] Add invitation page --- gooey_ui/components/__init__.py | 1009 ------------------------------- gooey_ui/components/modal.py | 97 --- 2 files changed, 1106 deletions(-) delete mode 100644 gooey_ui/components/__init__.py delete mode 100644 gooey_ui/components/modal.py diff --git a/gooey_ui/components/__init__.py b/gooey_ui/components/__init__.py deleted file mode 100644 index 2c27edd1d..000000000 --- a/gooey_ui/components/__init__.py +++ /dev/null @@ -1,1009 +0,0 @@ -import base64 -import html as html_lib -import math -import textwrap -import typing -from datetime import datetime, timezone - -import numpy as np -from furl import furl - -from daras_ai.image_input import resize_img_scale -from gooey_ui import state -from gooey_ui.pubsub import md5_values - -T = typing.TypeVar("T") -LabelVisibility = typing.Literal["visible", "collapsed"] - -BLANK_OPTION = "———" - - -def _default_format(value: typing.Any) -> str: - if value is None: - return BLANK_OPTION - return str(value) - - -def dummy(*args, **kwargs): - return state.NestingCtx() - - -spinner = dummy -set_page_config = dummy -form = dummy -dataframe = dummy - - -def countdown_timer( - end_time: datetime, - delay_text: str, -) -> state.NestingCtx: - return _node( - "countdown-timer", - endTime=end_time.astimezone(timezone.utc).isoformat(), - delayText=delay_text, - ) - - -def nav_tabs(): - return _node("nav-tabs") - - -def nav_item(href: str, *, active: bool): - return _node("nav-item", to=href, active="true" if active else None) - - -def nav_tab_content(): - return _node("nav-tab-content") - - -def div(**props) -> state.NestingCtx: - return tag("div", **props) - - -def link(*, to: str, **props) -> state.NestingCtx: - return _node("Link", to=to, **props) - - -def tag(tag_name: str, **props) -> state.NestingCtx: - props["__reactjsxelement"] = tag_name - return _node("tag", **props) - - -def html(body: str, **props): - props["className"] = props.get("className", "") + " gui-html-container" - return _node("html", body=body, **props) - - -def write(*objs: typing.Any, line_clamp: int = None, unsafe_allow_html=False, **props): - for obj in objs: - markdown( - obj if isinstance(obj, str) else repr(obj), - line_clamp=line_clamp, - unsafe_allow_html=unsafe_allow_html, - **props, - ) - - -def center(direction="flex-column", className="") -> state.NestingCtx: - return div( - className=f"d-flex justify-content-center align-items-center text-center {direction} {className}" - ) - - -def newline(): - html("
") - - -def markdown( - body: str | None, *, line_clamp: int = None, unsafe_allow_html=False, **props -): - if body is None: - return _node("markdown", body="", **props) - if not unsafe_allow_html: - body = html_lib.escape(body) - props["className"] = ( - props.get("className", "") + " gui-html-container gui-md-container" - ) - return _node("markdown", body=dedent(body).strip(), lineClamp=line_clamp, **props) - - -def _node(name: str, **props): - node = state.RenderTreeNode(name=name, props=props) - node.mount() - return state.NestingCtx(node) - - -def text(body: str, **props): - state.RenderTreeNode( - name="pre", - props=dict(body=dedent(body), **props), - ).mount() - - -def error( - body: str, - icon: str = "🔥", - *, - unsafe_allow_html=False, - color="rgba(255, 108, 108, 0.2)", - **props, -): - if not isinstance(body, str): - body = repr(body) - with div( - style=dict( - backgroundColor=color, - padding="1rem", - paddingBottom="0", - marginBottom="0.5rem", - borderRadius="0.25rem", - display="flex", - gap="0.5rem", - ) - ): - markdown(icon) - with div(): - markdown(dedent(body), unsafe_allow_html=unsafe_allow_html, **props) - - -def success(body: str, icon: str = "✅", *, unsafe_allow_html=False): - if not isinstance(body, str): - body = repr(body) - with div( - style=dict( - backgroundColor="rgba(108, 255, 108, 0.2)", - padding="1rem", - paddingBottom="0", - marginBottom="0.5rem", - borderRadius="0.25rem", - display="flex", - gap="0.5rem", - ) - ): - markdown(icon) - markdown(dedent(body), unsafe_allow_html=unsafe_allow_html) - - -def caption(body: str, className: str = None, **props): - className = className or "text-muted" - markdown(body, className=className, **props) - - -def tabs(labels: list[str]) -> list[state.NestingCtx]: - parent = state.RenderTreeNode( - name="tabs", - children=[ - state.RenderTreeNode( - name="tab", - props=dict(label=dedent(label)), - ) - for label in labels - ], - ).mount() - return [state.NestingCtx(tab) for tab in parent.children] - - -def controllable_tabs( - labels: list[str], key: str -) -> tuple[list[state.NestingCtx], int]: - index = state.session_state.get(key, 0) - for i, label in enumerate(labels): - if button( - label, - key=f"tab-{i}", - type="primary", - className="replicate-nav", - style={ - "background": "black" if i == index else "white", - "color": "white" if i == index else "black", - }, - ): - state.session_state[key] = index = i - state.experimental_rerun() - ctxs = [] - for i, label in enumerate(labels): - if i == index: - ctxs += [div(className="tab-content")] - else: - ctxs += [div(className="tab-content", style={"display": "none"})] - return ctxs, index - - -def columns( - spec, - *, - gap: str = None, - responsive: bool = True, - column_props: dict = {}, - **props, -) -> tuple[state.NestingCtx, ...]: - if isinstance(spec, int): - spec = [1] * spec - total_weight = sum(spec) - props.setdefault("className", "row") - with div(**props): - return tuple( - div( - className=f"col-lg-{p} {'col-12' if responsive else f'col-{p}'}", - **column_props, - ) - for w in spec - if (p := f"{round(w / total_weight * 12)}") - ) - - -def image( - src: str | np.ndarray, - caption: str = None, - alt: str = None, - href: str = None, - show_download_button: bool = False, - **props, -): - if isinstance(src, np.ndarray): - from daras_ai.image_input import cv2_img_to_bytes - - if not src.shape: - return - # ensure image is not too large - data = resize_img_scale(cv2_img_to_bytes(src), (128, 128)) - # convert to base64 - b64 = base64.b64encode(data).decode("utf-8") - src = "data:image/png;base64," + b64 - if not src: - return - state.RenderTreeNode( - name="img", - props=dict( - src=src, - caption=dedent(caption), - alt=alt or caption, - href=href, - **props, - ), - ).mount() - if show_download_button: - download_button( - label=' Download', url=src - ) - - -def video( - src: str, - caption: str = None, - autoplay: bool = False, - show_download_button: bool = False, -): - autoplay_props = {} - if autoplay: - autoplay_props = { - "preload": "auto", - "controls": True, - "autoPlay": True, - "loop": True, - "muted": True, - "playsInline": True, - } - - if not src: - return - if isinstance(src, str): - # https://muffinman.io/blog/hack-for-ios-safari-to-display-html-video-thumbnail/ - f = furl(src) - f.fragment.args["t"] = "0.001" - src = f.url - state.RenderTreeNode( - name="video", - props=dict(src=src, caption=dedent(caption), **autoplay_props), - ).mount() - if show_download_button: - download_button( - label=' Download', url=src - ) - - -def audio(src: str, caption: str = None, show_download_button: bool = False): - if not src: - return - state.RenderTreeNode( - name="audio", - props=dict(src=src, caption=dedent(caption)), - ).mount() - if show_download_button: - download_button( - label=' Download', url=src - ) - - -def text_area( - label: str, - value: str = "", - height: int = 500, - key: str = None, - help: str = None, - placeholder: str = None, - disabled: bool = False, - label_visibility: LabelVisibility = "visible", - **props, -) -> str: - style = props.setdefault("style", {}) - # if key: - # assert not value, "only one of value or key can be provided" - # else: - if not key: - key = md5_values( - "textarea", - label, - height, - help, - placeholder, - label_visibility, - not disabled or value, - ) - value = str(state.session_state.setdefault(key, value) or "") - if label_visibility != "visible": - label = None - if disabled: - max_height = f"{height}px" - rows = nrows_for_text(value, height) - else: - max_height = "50vh" - rows = nrows_for_text(value, height) - style.setdefault("maxHeight", max_height) - props.setdefault("rows", rows) - state.RenderTreeNode( - name="textarea", - props=dict( - name=key, - label=dedent(label), - defaultValue=value, - help=help, - placeholder=placeholder, - disabled=disabled, - **props, - ), - ).mount() - return value or "" - - -def nrows_for_text( - text: str, - max_height_px: int, - min_rows: int = 1, - row_height_px: int = 30, - row_width_px: int = 70, -) -> int: - max_rows = max_height_px // row_height_px - nrows = math.ceil( - sum( - math.ceil(len(line) / row_width_px) - for line in (text or "").splitlines(keepends=True) - ) - ) - nrows = min(max(nrows, min_rows), max_rows) - return nrows - - -def multiselect( - label: str, - options: typing.Sequence[T], - format_func: typing.Callable[[T], typing.Any] = _default_format, - key: str = None, - help: str = None, - allow_none: bool = False, - *, - disabled: bool = False, -) -> list[T]: - if not options: - return [] - options = list(options) - if not key: - key = md5_values("multiselect", label, options, help) - value = state.session_state.get(key) or [] - if not isinstance(value, list): - value = [value] - value = [o for o in value if o in options] - if not allow_none and not value: - value = [options[0]] - state.session_state[key] = value - state.RenderTreeNode( - name="select", - props=dict( - name=key, - label=dedent(label), - help=help, - isDisabled=disabled, - isMulti=True, - defaultValue=value, - allow_none=allow_none, - options=[ - {"value": option, "label": str(format_func(option))} - for option in options - ], - ), - ).mount() - return value - - -def selectbox( - label: str, - options: typing.Iterable[T], - format_func: typing.Callable[[T], typing.Any] = _default_format, - key: str = None, - help: str = None, - *, - disabled: bool = False, - label_visibility: LabelVisibility = "visible", - value: T = None, - allow_none: bool = False, - **props, -) -> T | None: - if not options: - return None - if label_visibility != "visible": - label = None - options = list(options) - if allow_none: - options.insert(0, None) - if not key: - key = md5_values("select", label, options, help, label_visibility) - value = state.session_state.setdefault(key, value) - if value not in options: - value = state.session_state[key] = options[0] - state.RenderTreeNode( - name="select", - props=dict( - name=key, - label=dedent(label), - help=help, - isDisabled=disabled, - defaultValue=value, - options=[ - {"value": option, "label": str(format_func(option))} - for option in options - ], - **props, - ), - ).mount() - return value - - -def download_button( - label: str, - url: str, - key: str = None, - help: str = None, - *, - type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", - disabled: bool = False, - **props, -) -> bool: - url = furl(url).remove(fragment=True).url - return button( - component="download-button", - url=url, - label=label, - key=key, - help=help, - type=type, - disabled=disabled, - **props, - ) - - -def button( - label: str, - key: str = None, - help: str = None, - *, - type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", - disabled: bool = False, - component: typing.Literal["download-button", "gui-button"] = "gui-button", - **props, -) -> bool: - """ - Example: - st.button("Primary", key="test0", type="primary") - st.button("Secondary", key="test1") - st.button("Tertiary", key="test3", type="tertiary") - st.button("Link Button", key="test3", type="link") - """ - if not key: - key = md5_values("button", label, help, type, props) - className = f"btn-{type} " + props.pop("className", "") - state.RenderTreeNode( - name=component, - props=dict( - type="submit", - value="yes", - name=key, - label=dedent(label), - help=help, - disabled=disabled, - className=className, - **props, - ), - ).mount() - return bool(state.session_state.pop(key, False)) - - -def anchor( - label: str, - href: str, - *, - type: typing.Literal["primary", "secondary", "tertiary", "link"] = "secondary", - disabled: bool = False, - unsafe_allow_html: bool = False, - new_tab: bool = False, - **props, -): - className = f"btn btn-theme btn-{type} " + props.pop("className", "") - style = props.pop("style", {}) - if disabled: - style["pointerEvents"] = "none" - if new_tab: - props["target"] = "_blank" - with tag("a", href=href, className=className, style=style, **props): - markdown(dedent(label), unsafe_allow_html=unsafe_allow_html) - - -form_submit_button = button - - -def expander(label: str, *, expanded: bool = False, key: str = None, **props): - node = state.RenderTreeNode( - name="expander", - props=dict( - label=dedent(label), - open=expanded, - name=key or md5_values(label, expanded, props), - **props, - ), - ) - node.mount() - return state.NestingCtx(node) - - -def file_uploader( - label: str, - accept: list[str] = None, - accept_multiple_files=False, - key: str = None, - value: str | list[str] = None, - upload_key: str = None, - help: str = None, - *, - disabled: bool = False, - label_visibility: LabelVisibility = "visible", - upload_meta: dict = None, - optional: bool = False, -) -> str | list[str] | None: - if label_visibility != "visible": - label = None - key = upload_key or key - if not key: - key = md5_values( - "file_uploader", - label, - accept, - accept_multiple_files, - help, - label_visibility, - ) - if optional: - if not checkbox( - label, value=bool(state.session_state.get(key, value)), disabled=disabled - ): - state.session_state.pop(key, None) - return None - label = None - value = state.session_state.setdefault(key, value) - if not value: - if accept_multiple_files: - value = [] - else: - value = None - state.session_state[key] = value - state.RenderTreeNode( - name="input", - props=dict( - type="file", - name=key, - label=dedent(label), - help=help, - disabled=disabled, - accept=accept, - multiple=accept_multiple_files, - defaultValue=value, - uploadMeta=upload_meta, - ), - ).mount() - return value - - -def json(value: typing.Any, expanded: bool = False, depth: int = 1): - state.RenderTreeNode( - name="json", - props=dict( - value=value, - expanded=expanded, - defaultInspectDepth=3 if expanded else depth, - ), - ).mount() - - -def data_table(file_url_or_cells: str | list): - if isinstance(file_url_or_cells, str): - file_url = file_url_or_cells - return _node("data-table", fileUrl=file_url) - else: - cells = file_url_or_cells - return _node("data-table-raw", cells=cells) - - -def table(df: "pd.DataFrame"): - with tag("table", className="table table-striped table-sm"): - with tag("thead"): - with tag("tr"): - for col in df.columns: - with tag("th", scope="col"): - html(dedent(col)) - with tag("tbody"): - for row in df.itertuples(index=False): - with tag("tr"): - for value in row: - with tag("td"): - html(dedent(str(value))) - - -def raw_table(header: list[str], className: str = "", **props) -> state.NestingCtx: - className = "table " + className - with tag("table", className=className, **props): - if header: - with tag("thead"), tag("tr"): - for col in header: - with tag("th", scope="col"): - html(dedent(col)) - - return tag("tbody") - - -def table_row(values: list[str], **props): - row = tag("tr", **props) - with row: - for v in values: - with tag("td"): - html(html_lib.escape(v)) - return row - - -def horizontal_radio( - label: str, - options: typing.Sequence[T], - format_func: typing.Callable[[T], typing.Any] = _default_format, - *, - key: str = None, - help: str = None, - value: T = None, - disabled: bool = False, - checked_by_default: bool = True, - label_visibility: LabelVisibility = "visible", - **button_props, -) -> T | None: - if not options: - return None - options = list(options) - if not key: - key = md5_values("horizontal_radio", label, options, help, label_visibility) - value = state.session_state.setdefault(key, value) - if value not in options and checked_by_default: - value = state.session_state[key] = options[0] - if label_visibility != "visible": - label = None - markdown(label) - for option in options: - if button( - format_func(option), - key=f"tab-{key}-{option}", - type="primary", - className="replicate-nav " + ("active" if value == option else ""), - disabled=disabled, - **button_props, - ): - state.session_state[key] = value = option - state.experimental_rerun() - return value - - -def radio( - label: str, - options: typing.Sequence[T], - format_func: typing.Callable[[T], typing.Any] = _default_format, - key: str = None, - value: T = None, - help: str = None, - *, - disabled: bool = False, - checked_by_default: bool = True, - label_visibility: LabelVisibility = "visible", -) -> T | None: - if not options: - return None - options = list(options) - if not key: - key = md5_values("radio", label, options, help, label_visibility) - value = state.session_state.setdefault(key, value) - if value not in options and checked_by_default: - value = state.session_state[key] = options[0] - if label_visibility != "visible": - label = None - markdown(label) - for option in options: - state.RenderTreeNode( - name="input", - props=dict( - type="radio", - name=key, - label=dedent(str(format_func(option))), - value=option, - defaultChecked=bool(value == option), - help=help, - disabled=disabled, - ), - ).mount() - return value - - -def text_input( - label: str, - value: str = "", - max_chars: str = None, - key: str = None, - help: str = None, - *, - placeholder: str = None, - disabled: bool = False, - label_visibility: LabelVisibility = "visible", - **props, -) -> str: - value = _input_widget( - input_type="text", - label=label, - value=value, - key=key, - help=help, - disabled=disabled, - label_visibility=label_visibility, - maxLength=max_chars, - placeholder=placeholder, - **props, - ) - return value or "" - - -def date_input( - label: str, - value: str | None = None, - key: str = None, - help: str = None, - *, - disabled: bool = False, - label_visibility: LabelVisibility = "visible", - **props, -) -> datetime | None: - value = _input_widget( - input_type="date", - label=label, - value=value, - key=key, - help=help, - disabled=disabled, - label_visibility=label_visibility, - style=dict( - border="1px solid hsl(0, 0%, 80%)", - padding="0.375rem 0.75rem", - borderRadius="0.25rem", - margin="0 0.5rem 0 0.5rem", - ), - **props, - ) - try: - return datetime.strptime(value, "%Y-%m-%d") if value else None - except ValueError: - return None - - -def password_input( - label: str, - value: str = "", - max_chars: str = None, - key: str = None, - help: str = None, - *, - placeholder: str = None, - disabled: bool = False, - label_visibility: LabelVisibility = "visible", - **props, -) -> str: - value = _input_widget( - input_type="password", - label=label, - value=value, - key=key, - help=help, - disabled=disabled, - label_visibility=label_visibility, - maxLength=max_chars, - placeholder=placeholder, - **props, - ) - return value or "" - - -def slider( - label: str, - min_value: float = None, - max_value: float = None, - value: float = None, - step: float = None, - key: str = None, - help: str = None, - *, - disabled: bool = False, -) -> float: - value = _input_widget( - input_type="range", - label=label, - value=value, - key=key, - help=help, - disabled=disabled, - min=min_value, - max=max_value, - step=_step_value(min_value, max_value, step), - ) - return value or 0 - - -def number_input( - label: str, - min_value: float = None, - max_value: float = None, - value: float = None, - step: float = None, - key: str = None, - help: str = None, - *, - disabled: bool = False, -) -> float: - value = _input_widget( - input_type="number", - inputMode="decimal", - label=label, - value=value, - key=key, - help=help, - disabled=disabled, - min=min_value, - max=max_value, - step=_step_value(min_value, max_value, step), - ) - return value or 0 - - -def _step_value( - min_value: float | None, max_value: float | None, step: float | None -) -> float: - if step: - return step - elif isinstance(min_value, float) or isinstance(max_value, float): - return 0.1 - else: - return 1 - - -def checkbox( - label: str, - value: bool = False, - key: str = None, - help: str = None, - *, - disabled: bool = False, - label_visibility: LabelVisibility = "visible", - **props, -) -> bool: - value = _input_widget( - input_type="checkbox", - label=label, - value=value, - key=key, - help=help, - disabled=disabled, - label_visibility=label_visibility, - default_value_attr="defaultChecked", - **props, - ) - return bool(value) - - -def _input_widget( - *, - input_type: str, - label: str, - value: typing.Any = None, - key: str = None, - help: str = None, - disabled: bool = False, - label_visibility: LabelVisibility = "visible", - default_value_attr: str = "defaultValue", - **kwargs, -) -> typing.Any: - # if key: - # assert not value, "only one of value or key can be provided" - # else: - if not key: - key = md5_values("input", input_type, label, help, label_visibility) - value = state.session_state.setdefault(key, value) - if label_visibility != "visible": - label = None - state.RenderTreeNode( - name="input", - props={ - "type": input_type, - "name": key, - "label": dedent(label), - default_value_attr: value, - "help": help, - "disabled": disabled, - **kwargs, - }, - ).mount() - return value - - -def breadcrumbs(divider: str = "/", **props) -> state.NestingCtx: - style = props.pop("style", {}) | {"--bs-breadcrumb-divider": f"'{divider}'"} - with tag("nav", style=style, **props): - return tag("ol", className="breadcrumb mb-0") - - -def breadcrumb_item(inner_html: str, link_to: str | None = None, **props): - className = "breadcrumb-item " + props.pop("className", "") - with tag("li", className=className, **props): - if link_to: - with tag("a", href=link_to): - html(inner_html) - else: - html(inner_html) - - -def plotly_chart(figure_or_data, **kwargs): - data = ( - figure_or_data.to_plotly_json() - if hasattr(figure_or_data, "to_plotly_json") - else figure_or_data - ) - state.RenderTreeNode( - name="plotly-chart", - props=dict( - chart=data, - args=kwargs, - ), - ).mount() - - -def dedent(text: str | None) -> str | None: - if not text: - return text - return textwrap.dedent(text) - - -def js(src: str, **kwargs): - state.RenderTreeNode( - name="script", - props=dict( - src=src, - args=kwargs, - ), - ).mount() diff --git a/gooey_ui/components/modal.py b/gooey_ui/components/modal.py deleted file mode 100644 index 72e951fc8..000000000 --- a/gooey_ui/components/modal.py +++ /dev/null @@ -1,97 +0,0 @@ -from contextlib import contextmanager - -import gooey_ui as st -from gooey_ui import experimental_rerun as rerun - - -class Modal: - def __init__(self, title, key, padding=20, max_width=744): - """ - :param title: title of the Modal shown in the h1 - :param key: unique key identifying this modal instance - :param padding: padding of the content within the modal - :param max_width: maximum width this modal should use - """ - self.title = title - self.padding = padding - self.max_width = str(max_width) + "px" - self.key = key - - self._container = None - - def is_open(self): - return st.session_state.get(f"{self.key}-opened", False) - - def open(self): - st.session_state[f"{self.key}-opened"] = True - rerun() - - def close(self, rerun_condition=True): - st.session_state[f"{self.key}-opened"] = False - if rerun_condition: - rerun() - - def empty(self): - if self._container: - self._container.empty() - - @contextmanager - def container(self, **props): - st.html( - f""" - - """ - ) - - with st.div(className="blur-background"): - with st.div(className="modal-parent"): - container_class = "modal-container " + props.pop("className", "") - self._container = st.div(className=container_class, **props) - - with self._container: - with st.div(className="d-flex justify-content-between align-items-center"): - if self.title: - st.markdown(f"### {self.title}") - else: - st.div() - - close_ = st.button( - "✖", - type="tertiary", - key=f"{self.key}-close", - style={"padding": "0.375rem 0.75rem"}, - ) - if close_: - self.close() - yield self._container From c748d70a160f6dc1e631c33fb1be4bfb868cd008 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 28 Aug 2024 14:56:50 +0530 Subject: [PATCH 046/110] feat: add billing support for orgs (db + ux) --- Procfile | 2 +- ...ction_org_alter_appusertransaction_user.py | 25 ++ app_users/models.py | 50 ++- app_users/tasks.py | 10 +- bots/models.py | 7 + daras_ai_v2/base.py | 3 +- daras_ai_v2/billing.py | 7 +- daras_ai_v2/send_email.py | 15 +- orgs/admin.py | 11 +- ..._org_is_paying_org_is_personal_and_more.py | 45 +++ .../0005_org_unique_personal_org_per_user.py | 17 + orgs/models.py | 146 ++++++- orgs/views.py | 382 +++++++++++++++++- payments/models.py | 19 +- payments/tasks.py | 51 +-- payments/webhooks.py | 68 ++-- scripts/migrate_orgs_from_appusers.py | 26 ++ 17 files changed, 780 insertions(+), 104 deletions(-) create mode 100644 app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py create mode 100644 orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py create mode 100644 orgs/migrations/0005_org_unique_personal_org_per_user.py create mode 100644 scripts/migrate_orgs_from_appusers.py diff --git a/Procfile b/Procfile index 984315504..1766991c6 100644 --- a/Procfile +++ b/Procfile @@ -19,4 +19,4 @@ dashboard: poetry run streamlit run Home.py --server.port 8501 --server.headless celery: poetry run celery -A celeryapp worker -P threads -c 16 -l DEBUG -ui: cd ../gooey-gui/ && env PORT=3000 npm run dev +ui: cd ../gooey-gui/ && env PORT=3000 REDIS_URL=redis://localhost:6379 pnpm run dev diff --git a/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py b/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py new file mode 100644 index 000000000..b3e80c708 --- /dev/null +++ b/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.7 on 2024-08-13 14:34 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('orgs', '0005_org_unique_personal_org_per_user'), + ('app_users', '0019_alter_appusertransaction_reason'), + ] + + operations = [ + migrations.AddField( + model_name='appusertransaction', + name='org', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='transactions', to='orgs.org'), + ), + migrations.AlterField( + model_name='appusertransaction', + name='user', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='transactions', to='app_users.appuser'), + ), + ] diff --git a/app_users/models.py b/app_users/models.py index 09832cebc..739ab3bd3 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -90,23 +90,10 @@ class AppUser(models.Model): display_name = models.TextField("name", blank=True) email = models.EmailField(null=True, blank=True) phone_number = PhoneNumberField(null=True, blank=True) - balance = models.IntegerField("bal") is_anonymous = models.BooleanField() is_disabled = models.BooleanField(default=False) photo_url = CustomURLField(default="", blank=True) - stripe_customer_id = models.CharField(max_length=255, default="", blank=True) - is_paying = models.BooleanField("paid", default=False) - - low_balance_email_sent_at = models.DateTimeField(null=True, blank=True) - subscription = models.OneToOneField( - "payments.Subscription", - on_delete=models.SET_NULL, - related_name="user", - null=True, - blank=True, - ) - created_at = models.DateTimeField( "created", editable=False, blank=True, default=timezone.now ) @@ -129,6 +116,18 @@ class AppUser(models.Model): github_username = models.CharField(max_length=255, blank=True, default="") website_url = CustomURLField(blank=True, default="") + balance = models.IntegerField("bal") + is_paying = models.BooleanField("paid", default=False) + stripe_customer_id = models.CharField(max_length=255, default="", blank=True) + subscription = models.OneToOneField( + "payments.Subscription", + on_delete=models.SET_NULL, + related_name="user", + null=True, + blank=True, + ) + low_balance_email_sent_at = models.DateTimeField(null=True, blank=True) + disable_rate_limits = models.BooleanField(default=False) objects = AppUserQuerySet.as_manager() @@ -159,6 +158,9 @@ def first_name_possesive(self) -> str: else: return name + "'s" + def get_personal_org(self) -> "Org | None": + return self.orgs.filter(is_personal=True).first() + @db_middleware @transaction.atomic def add_balance( @@ -246,6 +248,17 @@ def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser": return self + def get_or_create_personal_org(self) -> tuple["Org", bool]: + from orgs.models import Org + + org_membership = self.org_memberships.filter( + org__is_personal=True, org__created_by=self + ).first() + if org_membership: + return org_membership, False + else: + return Org.objects.migrate_from_appuser(self), True + def get_or_create_stripe_customer(self) -> stripe.Customer: customer = self.search_stripe_customer() if not customer: @@ -303,7 +316,16 @@ class TransactionReason(models.IntegerChoices): class AppUserTransaction(models.Model): user = models.ForeignKey( - "AppUser", on_delete=models.CASCADE, related_name="transactions" + "AppUser", + on_delete=models.SET_NULL, + related_name="transactions", + null=True, + ) + org = models.ForeignKey( + "orgs.Org", + on_delete=models.SET_NULL, + related_name="transactions", + null=True, ) invoice_id = models.CharField( max_length=255, diff --git a/app_users/tasks.py b/app_users/tasks.py index 0327ac423..b1d893196 100644 --- a/app_users/tasks.py +++ b/app_users/tasks.py @@ -5,14 +5,14 @@ from celeryapp.celeryconfig import app from payments.models import Subscription from payments.plans import PricingPlan -from payments.webhooks import set_user_subscription +from payments.webhooks import set_org_subscription @app.task def save_stripe_default_payment_method( *, payment_intent_id: str, - uid: str, + org_id: str, amount: int, charged_amount: int, reason: TransactionReason, @@ -41,11 +41,11 @@ def save_stripe_default_payment_method( if ( reason == TransactionReason.ADDON and not Subscription.objects.filter( - user__uid=uid, payment_provider__isnull=False + org__org_id=org_id, payment_provider__isnull=False ).exists() ): - set_user_subscription( - uid=uid, + set_org_subscription( + org_id=org_id, plan=PricingPlan.STARTER, provider=PaymentProvider.STRIPE, external_id=None, diff --git a/bots/models.py b/bots/models.py index e997e8f8a..a6163ee1c 100644 --- a/bots/models.py +++ b/bots/models.py @@ -212,6 +212,13 @@ class SavedRun(models.Model): ) run_id = models.CharField(max_length=128, default=None, null=True, blank=True) uid = models.CharField(max_length=128, default=None, null=True, blank=True) + billed_org = models.ForeignKey( + "orgs.Org", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="billed_runs", + ) state = models.JSONField(default=dict, blank=True, encoder=PostgresJSONEncoder) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 2233a0803..f37a284bb 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -2106,7 +2106,8 @@ def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]: ), "request.user must be set to deduct credits" amount = self.get_price_roundoff(state) - txn = self.request.user.add_balance(-amount, f"gooey_in_{uuid.uuid1()}") + org, _ = self.request.user.get_or_create_personal_org() + txn = org.add_balance(-amount, f"gooey_in_{uuid.uuid1()}") return txn, amount def get_price_roundoff(self, state: dict) -> int: diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 639412464..adc500015 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -9,9 +9,10 @@ from daras_ai_v2.gui_confirm import confirm_modal from daras_ai_v2.settings import templates from daras_ai_v2.user_date_widgets import render_local_date_attrs +from orgs.models import Org from payments.models import PaymentMethodSummary from payments.plans import PricingPlan -from payments.webhooks import StripeWebhookHandler, set_user_subscription +from payments.webhooks import StripeWebhookHandler, set_org_subscription from scripts.migrate_existing_subscriptions import available_subscriptions rounded_border = "w-100 border shadow-sm rounded py-4 px-3" @@ -635,8 +636,8 @@ def render_payment_information(user: AppUser): ): modal.open() if confirmed: - set_user_subscription( - uid=user.uid, + set_org_subscription( + org_id=user.get_personal_org().org_id, plan=PricingPlan.STARTER, provider=None, external_id=None, diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py index a9ff1934d..3c679c6fb 100644 --- a/daras_ai_v2/send_email.py +++ b/daras_ai_v2/send_email.py @@ -3,16 +3,19 @@ import requests -from app_users.models import AppUser from daras_ai_v2 import settings from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.fastapi_tricks import get_app_route_url from daras_ai_v2.settings import templates +if typing.TYPE_CHECKING: + from app_users.models import AppUser + + def send_reported_run_email( *, - user: AppUser, + user: "AppUser", run_uid: str, url: str, recipe_name: str, @@ -41,7 +44,7 @@ def send_reported_run_email( def send_low_balance_email( *, - user: AppUser, + user: "AppUser", total_credits_consumed: int, ): from routers.account import account_route @@ -70,8 +73,8 @@ def send_email_via_postmark( *, from_address: str, to_address: str, - cc: str = None, - bcc: str = None, + cc: str | None = None, + bcc: str | None = None, subject: str = "", html_body: str = "", text_body: str = "", @@ -79,7 +82,7 @@ def send_email_via_postmark( "outbound", "gooey-ai-workflows", "announcements" ] = "outbound", ): - if is_running_pytest: + if is_running_pytest or not settings.POSTMARK_API_TOKEN: pytest_outbox.append( dict( from_address=from_address, diff --git a/orgs/admin.py b/orgs/admin.py index 969866f41..370ca4c4e 100644 --- a/orgs/admin.py +++ b/orgs/admin.py @@ -43,9 +43,16 @@ class OrgAdmin(SafeDeleteAdmin): "updated_at", ] + list(SafeDeleteAdmin.list_display) list_filter = [SafeDeleteAdminFilter] + list(SafeDeleteAdmin.list_filter) - fields = ["name", "domain_name", "created_by", "created_at", "updated_at"] + fields = [ + "name", + "domain_name", + "created_by", + "is_personal", + "created_at", + "updated_at", + ] search_fields = ["name", "domain_name"] - readonly_fields = ["created_at", "updated_at"] + readonly_fields = ["is_personal", "created_at", "updated_at"] inlines = [OrgMembershipInline, OrgInvitationInline] ordering = ["-created_at"] diff --git a/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py b/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py new file mode 100644 index 000000000..9d9fdfc5d --- /dev/null +++ b/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py @@ -0,0 +1,45 @@ +# Generated by Django 4.2.7 on 2024-08-12 14:23 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('payments', '0005_alter_subscription_plan'), + ('orgs', '0003_remove_org_unique_domain_name_when_not_deleted_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='org', + name='balance', + field=models.IntegerField(default=0, verbose_name='bal'), + ), + migrations.AddField( + model_name='org', + name='is_paying', + field=models.BooleanField(default=False, verbose_name='paid'), + ), + migrations.AddField( + model_name='org', + name='is_personal', + field=models.BooleanField(default=False), + ), + migrations.AddField( + model_name='org', + name='low_balance_email_sent_at', + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddField( + model_name='org', + name='stripe_customer_id', + field=models.CharField(blank=True, default='', max_length=255), + ), + migrations.AddField( + model_name='org', + name='subscription', + field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='org', to='payments.subscription'), + ), + ] diff --git a/orgs/migrations/0005_org_unique_personal_org_per_user.py b/orgs/migrations/0005_org_unique_personal_org_per_user.py new file mode 100644 index 000000000..aaaa1cc4d --- /dev/null +++ b/orgs/migrations/0005_org_unique_personal_org_per_user.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.7 on 2024-08-13 14:34 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('orgs', '0004_org_balance_org_is_paying_org_is_personal_and_more'), + ] + + operations = [ + migrations.AddConstraint( + model_name='org', + constraint=models.UniqueConstraint(models.F('created_by'), condition=models.Q(('deleted__isnull', True), ('is_personal', True)), name='unique_personal_org_per_user'), + ), + ] diff --git a/orgs/models.py b/orgs/models.py index 5a19dad78..0c39312c0 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import re from datetime import timedelta +from django.db.models.aggregates import Sum +import stripe from django.db import models, transaction from django.core.exceptions import ValidationError from django.db.backends.base.schema import logger @@ -10,10 +14,10 @@ from safedelete.managers import SafeDeleteManager from safedelete.models import SafeDeleteModel, SOFT_DELETE_CASCADE -from app_users.models import AppUser from daras_ai_v2 import settings from daras_ai_v2.fastapi_tricks import get_app_route_url from daras_ai_v2.crypto import get_random_doc_id +from gooeysite.bg_db_conn import db_middleware from orgs.tasks import send_auto_accepted_email, send_invitation_email @@ -37,7 +41,9 @@ class OrgRole(models.IntegerChoices): class OrgManager(SafeDeleteManager): - def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwargs): + def create_org( + self, *, created_by: "AppUser", org_id: str | None = None, **kwargs + ) -> Org: org = self.model( org_id=org_id or get_random_doc_id(), created_by=created_by, **kwargs ) @@ -49,6 +55,28 @@ def create_org(self, *, created_by: "AppUser", org_id: str | None = None, **kwar ) return org + def get_or_create_from_org_id(self, org_id: str) -> tuple[Org, bool]: + from app_users.models import AppUser + + try: + return self.get(org_id=org_id), False + except self.model.DoesNotExist: + user = AppUser.objects.get_or_create_from_uid(org_id)[0] + return self.migrate_from_appuser(user), True + + def migrate_from_appuser(self, user: "AppUser") -> Org: + return self.create_org( + name=f"{user.first_name()}'s Personal Workspace", + org_id=user.uid or get_random_doc_id(), + created_by=user, + is_personal=True, + balance=user.balance, + stripe_customer_id=user.stripe_customer_id, + subscription=user.subscription, + low_balance_email_sent_at=user.low_balance_email_sent_at, + is_paying=user.is_paying, + ) + class Org(SafeDeleteModel): _safedelete_policy = SOFT_DELETE_CASCADE @@ -71,6 +99,21 @@ class Org(SafeDeleteModel): ], ) + # billing + balance = models.IntegerField("bal", default=0) + is_paying = models.BooleanField("paid", default=False) + stripe_customer_id = models.CharField(max_length=255, default="", blank=True) + subscription = models.OneToOneField( + "payments.Subscription", + on_delete=models.SET_NULL, + related_name="org", + null=True, + blank=True, + ) + low_balance_email_sent_at = models.DateTimeField(null=True, blank=True) + + is_personal = models.BooleanField(default=False) + created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) @@ -83,7 +126,12 @@ class Meta: condition=Q(deleted__isnull=True), name="unique_domain_name_when_not_deleted", violation_error_message=f"This domain name is already in use by another team. Contact {settings.SUPPORT_EMAIL} if you think this is a mistake.", - ) + ), + models.UniqueConstraint( + "created_by", + condition=Q(deleted__isnull=True, is_personal=True), + name="unique_personal_org_per_user", + ), ] def __str__(self): @@ -147,6 +195,90 @@ def invite_user( return invitation + def get_owners(self) -> list[OrgMembership]: + return self.memberships.filter(role=OrgRole.OWNER) + + @db_middleware + @transaction.atomic + def add_balance( + self, amount: int, invoice_id: str, **kwargs + ) -> "AppUserTransaction": + """ + Used to add/deduct credits when they are bought or consumed. + + When credits are bought with stripe -- invoice_id is the stripe + invoice ID. + When credits are deducted due to a run -- invoice_id is of the + form "gooey_in_{uuid}" + """ + from app_users.models import AppUserTransaction + + # if an invoice entry exists + try: + # avoid updating twice for same invoice + return AppUserTransaction.objects.get(invoice_id=invoice_id) + except AppUserTransaction.DoesNotExist: + pass + + # select_for_update() is very important here + # transaction.atomic alone is not enough! + # It won't lock this row for reads, and multiple threads can update the same row leading incorrect balance + # + # Also we're not using .update() here because it won't give back the updated end balance + org: Org = Org.objects.select_for_update().get(pk=self.pk) + org.balance += amount + org.save(update_fields=["balance"]) + kwargs.setdefault("plan", org.subscription and org.subscription.plan) + return AppUserTransaction.objects.create( + org=org, + invoice_id=invoice_id, + amount=amount, + end_balance=org.balance, + **kwargs, + ) + + def get_or_create_stripe_customer(self) -> stripe.Customer: + customer = self.search_stripe_customer() + if not customer: + customer = stripe.Customer.create( + name=self.created_by.display_name, + email=self.created_by.email, + phone=self.created_by.phone, + metadata={"uid": self.org_id, "org_id": self.org_id, "id": self.pk}, + ) + self.stripe_customer_id = customer.id + self.save() + return customer + + def search_stripe_customer(self) -> stripe.Customer | None: + if not self.org_id: + return None + if self.stripe_customer_id: + try: + return stripe.Customer.retrieve(self.stripe_customer_id) + except stripe.error.InvalidRequestError as e: + if e.http_status != 404: + raise + try: + customer = stripe.Customer.search( + query=f'metadata["uid"]:"{self.org_id}"' + ).data[0] + except IndexError: + return None + else: + self.stripe_customer_id = customer.id + self.save() + return customer + + def get_dollars_spent_this_month(self) -> float: + today = timezone.now() + cents_spent = self.transactions.filter( + created_at__month=today.month, + created_at__year=today.year, + amount__gt=0, + ).aggregate(total=Sum("charged_amount"))["total"] + return (cents_spent or 0) / 100 + class OrgMembership(SafeDeleteModel): org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="memberships") @@ -260,6 +392,8 @@ def auto_accept(self): Raises: ValidationError """ + from app_users.models import AppUser + assert self.status == self.Status.PENDING invitee = AppUser.objects.get(email=self.invitee_email) @@ -287,7 +421,7 @@ def send_email(self): send_invitation_email.delay(invitation_pk=self.pk) - def accept(self, user: AppUser, *, auto_accepted: bool = False): + def accept(self, user: "AppUser", *, auto_accepted: bool = False): """ Raises: ValidationError """ @@ -323,13 +457,13 @@ def accept(self, user: AppUser, *, auto_accepted: bool = False): ) self.save() - def reject(self, user: AppUser): + def reject(self, user: "AppUser"): self.status = self.Status.REJECTED self.status_changed_at = timezone.now() self.status_changed_by = user self.save() - def cancel(self, user: AppUser): + def cancel(self, user: "AppUser"): self.status = self.Status.CANCELED self.status_changed_at = timezone.now() self.status_changed_by = user diff --git a/orgs/views.py b/orgs/views.py index ed864cb94..2d6f3c27c 100644 --- a/orgs/views.py +++ b/orgs/views.py @@ -2,18 +2,29 @@ import html as html_lib +import stripe import gooey_gui as gui from django.core.exceptions import ValidationError -from app_users.models import AppUser +from app_users.models import AppUser, PaymentProvider +from daras_ai_v2.billing import format_card_brand, payment_provider_radio +from daras_ai_v2.grid_layout_widget import grid_layout from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole -from daras_ai_v2 import icons -from daras_ai_v2.fastapi_tricks import get_route_path +from daras_ai_v2 import icons, settings +from daras_ai_v2.fastapi_tricks import get_route_path, get_app_route_url +from daras_ai_v2.settings import templates +from daras_ai_v2.user_date_widgets import render_local_date_attrs +from payments.models import PaymentMethodSummary +from payments.plans import PricingPlan +from scripts.migrate_existing_subscriptions import available_subscriptions DEFAULT_ORG_LOGO = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/74a37c52-8260-11ee-a297-02420a0001ee/gooey.ai%20-%20A%20pop%20art%20illustration%20of%20robots%20taki...y%20Liechtenstein%20mint%20colour%20is%20main%20city%20Seattle.png" +rounded_border = "w-100 border shadow-sm rounded py-4 px-3" + + def invitation_page(user: AppUser, invitation: OrgInvitation): from routers.account import orgs_route @@ -107,6 +118,10 @@ def render_org_by_membership(membership: OrgMembership): f"Org Domain: `@{org.domain_name}`", className="text-muted" ) + with gui.div(className="mt-4"): + gui.write("# Billing") + billing_section(org=org, current_member=membership) + with gui.div(className="mt-4"): with gui.div(className="d-flex justify-content-between align-items-center"): gui.write("## Members") @@ -142,6 +157,361 @@ def render_org_by_membership(membership: OrgMembership): org_leave_modal.open() +def billing_section(*, org: Org, current_member: OrgMembership): + render_payments_setup() + + if org.subscription and org.subscription.external_id: + render_current_plan(org) + + with gui.div(className="my-5"): + render_credit_balance(org) + + with gui.div(className="my-5"): + selected_payment_provider = render_all_plans(org) + + with gui.div(className="my-5"): + render_addon_section(org, selected_payment_provider) + + if org.subscription and org.subscription.external_id: + # if org.subscription.payment_provider == PaymentProvider.STRIPE: + # with gui.div(className="my-5"): + # render_auto_recharge_section(user) + with gui.div(className="my-5"): + render_payment_information(org) + + with gui.div(className="my-5"): + render_billing_history(org) + + +def render_payments_setup(): + from routers.account import payment_processing_route + + gui.html( + templates.get_template("payment_setup.html").render( + settings=settings, + payment_processing_url=get_app_route_url(payment_processing_route), + ) + ) + + +def render_current_plan(org: Org): + plan = PricingPlan.from_sub(org.subscription) + provider = ( + PaymentProvider(org.subscription.payment_provider) + if org.subscription.payment_provider + else None + ) + + with gui.div(className=f"{rounded_border} border-dark"): + # ROW 1: Plan title and next invoice date + left, right = left_and_right() + with left: + gui.write(f"#### Gooey.AI {plan.title}") + + if provider: + gui.write( + f"[{icons.edit} Manage Subscription](#payment-information)", + unsafe_allow_html=True, + ) + with right, gui.div(className="d-flex align-items-center gap-1"): + if provider and ( + next_invoice_ts := gui.run_in_thread( + org.subscription.get_next_invoice_timestamp, cache=True + ) + ): + gui.html("Next invoice on ") + gui.pill( + "...", + text_bg="dark", + **render_local_date_attrs( + next_invoice_ts, + date_options={"day": "numeric", "month": "long"}, + ), + ) + + if plan is PricingPlan.ENTERPRISE: + # charge details are not relevant for Enterprise customers + return + + # ROW 2: Plan pricing details + left, right = left_and_right(className="mt-5") + with left: + gui.write(f"# {plan.pricing_title()}", className="no-margin") + if plan.monthly_charge: + provider_text = f" **via {provider.label}**" if provider else "" + gui.caption("per month" + provider_text) + + with right, gui.div(className="text-end"): + gui.write(f"# {plan.credits:,} credits", className="no-margin") + if plan.monthly_charge: + gui.write( + f"**${plan.monthly_charge:,}** monthly renewal for {plan.credits:,} credits" + ) + + +def render_credit_balance(org: Org): + gui.write(f"## Credit Balance: {org.balance:,}") + gui.caption( + "Every time you submit a workflow or make an API call, we deduct credits from your account." + ) + + +def render_all_plans(org: Org) -> PaymentProvider | None: + current_plan = ( + PricingPlan.from_sub(org.subscription) + if org.subscription + else PricingPlan.STARTER + ) + all_plans = [plan for plan in PricingPlan if not plan.deprecated] + + gui.write("## All Plans") + plans_div = gui.div(className="mb-1") + + if org.subscription and org.subscription.payment_provider: + selected_payment_provider = None + else: + with gui.div(): + selected_payment_provider = PaymentProvider[ + payment_provider_radio() or PaymentProvider.STRIPE.name + ] + + def _render_plan(plan: PricingPlan): + if plan == current_plan: + extra_class = "border-dark" + else: + extra_class = "bg-light" + with gui.div(className="d-flex flex-column h-100"): + with gui.div( + className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}" + ): + _render_plan_details(plan) + # _render_plan_action_button( + # user, plan, current_plan, selected_payment_provider + # ) + + with plans_div: + grid_layout(4, all_plans, _render_plan, separator=False) + + with gui.div(className="my-2 d-flex justify-content-center"): + gui.caption( + f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**" + ) + + return selected_payment_provider + + +def _render_plan_details(plan: PricingPlan): + with gui.div(className="flex-grow-1"): + with gui.div(className="mb-4"): + with gui.tag("h4", className="mb-0"): + gui.html(plan.title) + gui.caption( + plan.description, + style={ + "minHeight": "calc(var(--bs-body-line-height) * 2em)", + "display": "block", + }, + ) + with gui.div(className="my-3 w-100"): + with gui.tag("h4", className="my-0 d-inline me-2"): + gui.html(plan.pricing_title()) + with gui.tag("span", className="text-muted my-0"): + gui.html(plan.pricing_caption()) + gui.write(plan.long_description, unsafe_allow_html=True) + + +def render_payment_information(org: Org): + assert org.subscription + + gui.write("## Payment Information", id="payment-information", className="d-block") + col1, col2, col3 = gui.columns(3, responsive=False) + with col1: + gui.write("**Pay via**") + with col2: + provider = PaymentProvider(org.subscription.payment_provider) + gui.write(provider.label) + with col3: + if gui.button(f"{icons.edit} Edit", type="link", key="manage-payment-provider"): + raise gui.RedirectException(org.subscription.get_external_management_url()) + + pm_summary = gui.run_in_thread( + org.subscription.get_payment_method_summary, cache=True + ) + if not pm_summary: + return + pm_summary = PaymentMethodSummary(*pm_summary) + if pm_summary.card_brand and pm_summary.card_last4: + col1, col2, col3 = gui.columns(3, responsive=False) + with col1: + gui.write("**Payment Method**") + with col2: + gui.write( + f"{format_card_brand(pm_summary.card_brand)} ending in {pm_summary.card_last4}", + unsafe_allow_html=True, + ) + with col3: + if gui.button(f"{icons.edit} Edit", type="link", key="edit-payment-method"): + change_payment_method(org) + + if pm_summary.billing_email: + col1, col2, _ = gui.columns(3, responsive=False) + with col1: + gui.write("**Billing Email**") + with col2: + gui.html(pm_summary.billing_email) + + +def change_payment_method(org: Org): + from routers.account import payment_processing_route + from routers.account import account_route + + match org.subscription.payment_provider: + case PaymentProvider.STRIPE: + session = stripe.checkout.Session.create( + mode="setup", + currency="usd", + customer=org.get_or_create_stripe_customer().id, + setup_intent_data={ + "metadata": {"subscription_id": org.subscription.external_id}, + }, + success_url=get_app_route_url(payment_processing_route), + cancel_url=get_app_route_url(account_route), + ) + raise gui.RedirectException(session.url, status_code=303) + case _: + gui.error("Not implemented for this payment provider") + + +def render_billing_history(org: Org, limit: int = 50): + import pandas as pd + + txns = org.transactions.filter(amount__gt=0).order_by("-created_at") + if not txns: + return + + gui.write("## Billing History", className="d-block") + gui.table( + pd.DataFrame.from_records( + [ + { + "Date": txn.created_at.strftime("%m/%d/%Y"), + "Description": txn.reason_note(), + "Amount": f"-${txn.charged_amount / 100:,.2f}", + "Credits": f"+{txn.amount:,}", + "Balance": f"{txn.end_balance:,}", + } + for txn in txns[:limit] + ] + ), + ) + if txns.count() > limit: + gui.caption(f"Showing only the most recent {limit} transactions.") + + +def render_addon_section(org: Org, selected_payment_provider: PaymentProvider): + if org.subscription: + gui.write("# Purchase More Credits") + else: + gui.write("# Purchase Credits") + gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") + + if org.subscription and org.subscription.payment_provider: + provider = PaymentProvider(org.subscription.payment_provider) + else: + provider = selected_payment_provider + match provider: + case PaymentProvider.STRIPE | None: + render_stripe_addon_buttons(org) + case PaymentProvider.PAYPAL: + render_paypal_addon_buttons() + + +def render_paypal_addon_buttons(): + selected_amt = gui.horizontal_radio( + "", + settings.ADDON_AMOUNT_CHOICES, + format_func=lambda amt: f"${amt:,}", + checked_by_default=False, + ) + if selected_amt: + gui.js( + f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})" + ) + gui.div( + id="paypal-addon-buttons", + className="mt-2", + style={"width": "fit-content"}, + ) + gui.div(id="paypal-result-message") + + +def render_stripe_addon_buttons(org: Org): + for dollar_amt in settings.ADDON_AMOUNT_CHOICES: + render_stripe_addon_button(dollar_amt, org) + + +def render_stripe_addon_button(dollar_amt: int, org: Org): + confirm_purchase_modal = gui.Modal( + "Confirm Purchase", key=f"confirm-purchase-{dollar_amt}" + ) + if gui.button(f"${dollar_amt:,}", type="primary"): + if org.subscription and org.subscription.external_id: + confirm_purchase_modal.open() + else: + stripe_addon_checkout_redirect(org, dollar_amt) + + if not confirm_purchase_modal.is_open(): + return + with confirm_purchase_modal.container(): + gui.write( + f""" + Please confirm your purchase: + **{dollar_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollar_amt}**. + """, + className="py-4 d-block text-center", + ) + with gui.div(className="d-flex w-100 justify-content-end"): + if gui.session_state.get("--confirm-purchase"): + success = gui.run_in_thread( + org.subscription.stripe_attempt_addon_purchase, + args=[dollar_amt], + placeholder="Processing payment...", + ) + if success is None: + return + gui.session_state.pop("--confirm-purchase") + if success: + confirm_purchase_modal.close() + else: + gui.error("Payment failed... Please try again.") + return + + if gui.button("Cancel", className="border border-danger text-danger me-2"): + confirm_purchase_modal.close() + gui.button("Buy", type="primary", key="--confirm-purchase") + + +def stripe_addon_checkout_redirect(org: Org, dollar_amt: int): + from routers.account import account_route + from routers.account import payment_processing_route + + line_item = available_subscriptions["addon"]["stripe"].copy() + line_item["quantity"] = dollar_amt * settings.ADDON_CREDITS_PER_DOLLAR + checkout_session = stripe.checkout.Session.create( + line_items=[line_item], + mode="payment", + success_url=get_app_route_url(payment_processing_route), + cancel_url=get_app_route_url(account_route), + customer=org.get_or_create_stripe_customer().id, + invoice_creation={"enabled": True}, + allow_promotion_codes=True, + saved_payment_method_options={ + "payment_method_save": "enabled", + }, + ) + raise gui.RedirectException(checkout_session.url, status_code=303) + + def render_org_creation_view(user: AppUser): gui.write(f"# {icons.company} Create an Org", unsafe_allow_html=True) org_fields = render_org_create_or_edit_form() @@ -502,3 +872,9 @@ class AttrDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.__dict__ = self + + +def left_and_right(*, className: str = "", **props): + className += " d-flex flex-row justify-content-between align-items-center" + with gui.div(className=className, **props): + return gui.div(), gui.div() diff --git a/payments/models.py b/payments/models.py index fe280247d..f647bd5a6 100644 --- a/payments/models.py +++ b/payments/models.py @@ -80,8 +80,10 @@ class Meta: def __str__(self): ret = f"{self.get_plan_display()} | {self.get_payment_provider_display()}" - if self.has_user: - ret = f"{ret} | {self.user}" + # if self.has_user: + # ret = f"{ret} | {self.user}" + if self.has_org: + ret = f"{ret} | {self.org}" if self.auto_recharge_enabled: ret = f"Auto | {ret}" return ret @@ -131,6 +133,15 @@ def has_user(self) -> bool: def is_paid(self) -> bool: return PricingPlan.from_sub(self).monthly_charge > 0 and self.external_id + @property + def has_org(self) -> bool: + try: + self.org + except Subscription.org.RelatedObjectDoesNotExist: + return False + else: + return True + def cancel(self): from payments.webhooks import StripeWebhookHandler @@ -361,12 +372,12 @@ def has_sent_monthly_budget_email_this_month(self) -> bool: ) def should_send_monthly_spending_notification(self) -> bool: - assert self.has_user + assert self.has_org return bool( self.monthly_spending_notification_threshold and not self.has_sent_monthly_spending_notification_this_month() - and self.user.get_dollars_spent_this_month() + and self.org.get_dollars_spent_this_month() >= self.monthly_spending_notification_threshold ) diff --git a/payments/tasks.py b/payments/tasks.py index 252064541..2070db714 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -2,6 +2,7 @@ from loguru import logger from app_users.models import AppUser +from orgs.models import Org from celeryapp import app from daras_ai_v2 import settings from daras_ai_v2.fastapi_tricks import get_app_route_url @@ -10,33 +11,33 @@ @app.task -def send_monthly_spending_notification_email(user_id: int): +def send_monthly_spending_notification_email(id: int): from routers.account import account_route - user = AppUser.objects.get(id=user_id) - if not user.email: - logger.error(f"User doesn't have an email: {user=}") - return - - threshold = user.subscription.monthly_spending_notification_threshold - - send_email_via_postmark( - from_address=settings.SUPPORT_EMAIL, - to_address=user.email, - subject=f"[Gooey.AI] Monthly spending has exceeded ${threshold}", - html_body=templates.get_template( - "monthly_spending_notification_threshold_email.html" - ).render( - user=user, - account_url=get_app_route_url(account_route), - ), - ) - - # IMPORTANT: always use update_fields=... / select_for_update when updating - # subscription info. We don't want to overwrite other changes made to - # subscription during the same time - user.subscription.monthly_spending_notification_sent_at = timezone.now() - user.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) + org = Org.objects.get(id=id) + threshold = org.subscription.monthly_spending_notification_threshold + for owner in org.get_owners(): + if not owner.user.email: + logger.error(f"Org Owner doesn't have an email: {owner=}") + return + + send_email_via_postmark( + from_address=settings.SUPPORT_EMAIL, + to_address=owner.user.email, + subject=f"[Gooey.AI] Monthly spending has exceeded ${threshold}", + html_body=templates.get_template( + "monthly_spending_notification_threshold_email.html" + ).render( + user=owner.user, + account_url=get_app_route_url(account_route), + ), + ) + + # IMPORTANT: always use update_fields=... / select_for_update when updating + # subscription info. We don't want to overwrite other changes made to + # subscription during the same time + org.subscription.monthly_spending_notification_sent_at = timezone.now() + org.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) def send_monthly_budget_reached_email(user: AppUser): diff --git a/payments/webhooks.py b/payments/webhooks.py index 0b822cfe7..c280e129f 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -10,6 +10,7 @@ TransactionReason, ) from daras_ai_v2 import paypal +from orgs.models import Org from .models import Subscription from .plans import PricingPlan from .tasks import send_monthly_spending_notification_email @@ -25,7 +26,7 @@ def handle_sale_completed(cls, sale: paypal.Sale): return pp_sub = paypal.Subscription.retrieve(sale.billing_agreement_id) - assert pp_sub.custom_id, "pp_sub is missing uid" + assert pp_sub.custom_id, "pp_sub is missing org_id" assert pp_sub.plan_id, "pp_sub is missing plan ID" plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) @@ -38,9 +39,9 @@ def handle_sale_completed(cls, sale: paypal.Sale): f"paypal: charged amount ${charged_dollars} does not match plan's monthly charge ${plan.monthly_charge}" ) - uid = pp_sub.custom_id + org_id = pp_sub.custom_id add_balance_for_payment( - uid=uid, + org_id=org_id, amount=plan.credits, invoice_id=sale.id, payment_provider=cls.PROVIDER, @@ -53,7 +54,7 @@ def handle_sale_completed(cls, sale: paypal.Sale): def handle_subscription_updated(cls, pp_sub: paypal.Subscription): logger.info(f"Paypal subscription updated {pp_sub.id}") - assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" + assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing org_id" assert pp_sub.plan_id, f"PayPal subscription {pp_sub.id} is missing plan ID" plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) @@ -65,17 +66,17 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): ) return - set_user_subscription( + set_org_subscription( provider=cls.PROVIDER, plan=plan, - uid=pp_sub.custom_id, + org_id=pp_sub.custom_id, external_id=pp_sub.id, ) @classmethod def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription): assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" - set_user_subscription( + set_org_subscription( uid=pp_sub.custom_id, plan=PricingPlan.STARTER, provider=None, @@ -87,11 +88,9 @@ class StripeWebhookHandler: PROVIDER = PaymentProvider.STRIPE @classmethod - def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice): - from app_users.tasks import save_stripe_default_payment_method - + def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice): kwargs = {} - if invoice.subscription: + if invoice.subscription and invoice.subscription_details: kwargs["plan"] = PricingPlan.get_by_key( invoice.subscription_details.metadata.get("subscription_key") ).db_value @@ -112,7 +111,7 @@ def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice): amount = invoice.lines.data[0].quantity charged_amount = invoice.lines.data[0].amount add_balance_for_payment( - uid=uid, + org_id=org_id, amount=amount, invoice_id=invoice.id, payment_provider=cls.PROVIDER, @@ -130,7 +129,7 @@ def handle_invoice_paid(cls, uid: str, invoice: stripe.Invoice): ) @classmethod - def handle_checkout_session_completed(cls, uid: str, session_data): + def handle_checkout_session_completed(cls, org_id: str, session_data): setup_intent_id = session_data.get("setup_intent") if not setup_intent_id: # not a setup mode checkout -- do nothing @@ -152,7 +151,7 @@ def handle_checkout_session_completed(cls, uid: str, session_data): ) @classmethod - def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription): + def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscription): logger.info(f"Stripe subscription updated: {stripe_sub.id}") assert stripe_sub.plan, f"Stripe subscription {stripe_sub.id} is missing plan" @@ -173,17 +172,18 @@ def handle_subscription_updated(cls, uid: str, stripe_sub: stripe.Subscription): ) return - set_user_subscription( + set_org_subscription( provider=cls.PROVIDER, plan=plan, - uid=uid, + org_id=org_id, external_id=stripe_sub.id, ) @classmethod - def handle_subscription_cancelled(cls, uid: str): - set_user_subscription( - uid=uid, + def handle_subscription_cancelled(cls, org_id: str): + logger.info(f"Stripe subscription cancelled: {stripe_sub.id}") + set_org_subscription( + org_id=org_id, plan=PricingPlan.STARTER, provider=PaymentProvider.STRIPE, external_id=None, @@ -192,15 +192,15 @@ def handle_subscription_cancelled(cls, uid: str): def add_balance_for_payment( *, - uid: str, + org_id: str, amount: int, invoice_id: str, payment_provider: PaymentProvider, charged_amount: int, **kwargs, ): - user = AppUser.objects.get_or_create_from_uid(uid)[0] - user.add_balance( + org = Org.objects.get_or_create_from_org_id(org_id)[0] + org.add_balance( amount=amount, invoice_id=invoice_id, charged_amount=charged_amount, @@ -208,20 +208,20 @@ def add_balance_for_payment( **kwargs, ) - if not user.is_paying: - user.is_paying = True - user.save(update_fields=["is_paying"]) + if not org.is_paying: + org.is_paying = True + org.save(update_fields=["is_paying"]) if ( - user.subscription - and user.subscription.should_send_monthly_spending_notification() + org.subscription + and org.subscription.should_send_monthly_spending_notification() ): - send_monthly_spending_notification_email.delay(user.id) + send_monthly_spending_notification_email.delay(org.id) -def set_user_subscription( +def set_org_subscription( *, - uid: str, + org_id: str, plan: PricingPlan, provider: PaymentProvider | None, external_id: str | None, @@ -229,9 +229,9 @@ def set_user_subscription( charged_amount: int = None, ) -> Subscription: with transaction.atomic(): - user = AppUser.objects.get_or_create_from_uid(uid)[0] + org = Org.objects.get_or_create_from_org_id(org_id)[0] - old_sub = user.subscription + old_sub = org.subscription if old_sub: new_sub = copy(old_sub) else: @@ -245,8 +245,8 @@ def set_user_subscription( new_sub.save() if not old_sub: - user.subscription = new_sub - user.save(update_fields=["subscription"]) + org.subscription = new_sub + org.save(update_fields=["subscription"]) # cancel previous subscription if it's not the same as the new one if old_sub and old_sub.external_id != external_id: diff --git a/scripts/migrate_orgs_from_appusers.py b/scripts/migrate_orgs_from_appusers.py new file mode 100644 index 000000000..d4e868e30 --- /dev/null +++ b/scripts/migrate_orgs_from_appusers.py @@ -0,0 +1,26 @@ +from django.db import IntegrityError +from loguru import logger + +from app_users.models import AppUser +from orgs.models import Org + + +def run(): + users_without_personal_org = AppUser.objects.exclude( + id__in=Org.objects.filter(is_personal=True).values_list("created_by", flat=True) + ) + + done_count = 0 + + for appuser in users_without_personal_org: + try: + Org.objects.migrate_from_appuser(appuser) + except IntegrityError as e: + logger.warning(f"IntegrityError: {e}") + else: + done_count += 1 + + if done_count % 100 == 0: + logger.info(f"Running... {done_count} migrated") + + logger.info(f"Done... {done_count} migrated") From 0a3593fa86fb0a7d6bcf4311a61e90d1ea55f0f7 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 28 Aug 2024 15:17:34 +0530 Subject: [PATCH 047/110] feat: set initial credit balance for first org created by user --- daras_ai_v2/billing.py | 1 - daras_ai_v2/settings.py | 3 +-- orgs/models.py | 20 ++++++++++++++++++-- payments/webhooks.py | 1 - 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index adc500015..7722fa5bf 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -9,7 +9,6 @@ from daras_ai_v2.gui_confirm import confirm_modal from daras_ai_v2.settings import templates from daras_ai_v2.user_date_widgets import render_local_date_attrs -from orgs.models import Org from payments.models import PaymentMethodSummary from payments.plans import PricingPlan from payments.webhooks import StripeWebhookHandler, set_org_subscription diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 3cdd88dc8..05a79d4c8 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -11,7 +11,6 @@ """ import os -import json from pathlib import Path import sentry_sdk @@ -289,9 +288,9 @@ EMAIL_USER_FREE_CREDITS = config("EMAIL_USER_FREE_CREDITS", 0, cast=int) ANON_USER_FREE_CREDITS = config("ANON_USER_FREE_CREDITS", 25, cast=int) LOGIN_USER_FREE_CREDITS = config("LOGIN_USER_FREE_CREDITS", 500, cast=int) +FIRST_ORG_FREE_CREDITS = config("ORG_FREE_CREDITS", 500, cast=int) ADDON_CREDITS_PER_DOLLAR = config("ADDON_CREDITS_PER_DOLLAR", 100, cast=int) - ADDON_AMOUNT_CHOICES = [10, 30, 50, 100, 300, 500, 1000] # USD AUTO_RECHARGE_BALANCE_THRESHOLD_CHOICES = [300, 1000, 3000, 10000] # Credit balance AUTO_RECHARGE_COOLDOWN_SECONDS = config("AUTO_RECHARGE_COOLDOWN_SECONDS", 60, cast=int) diff --git a/orgs/models.py b/orgs/models.py index 0c39312c0..fa1b471b9 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -42,11 +42,27 @@ class OrgRole(models.IntegerChoices): class OrgManager(SafeDeleteManager): def create_org( - self, *, created_by: "AppUser", org_id: str | None = None, **kwargs + self, + *, + created_by: "AppUser", + org_id: str | None = None, + balance: int | None = None, + **kwargs, ) -> Org: org = self.model( - org_id=org_id or get_random_doc_id(), created_by=created_by, **kwargs + org_id=org_id or get_random_doc_id(), + created_by=created_by, + balance=balance, + **kwargs, ) + if ( + balance is None + and Org.all_objects.filter(created_by=created_by).count() <= 1 + ): + # set some balance for first team created by user + # Org.all_objects is important to include deleted orgs + org.balance = settings.FIRST_ORG_FREE_CREDITS + org.full_clean() org.save() org.add_member( diff --git a/payments/webhooks.py b/payments/webhooks.py index c280e129f..a00466bbc 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -181,7 +181,6 @@ def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscriptio @classmethod def handle_subscription_cancelled(cls, org_id: str): - logger.info(f"Stripe subscription cancelled: {stripe_sub.id}") set_org_subscription( org_id=org_id, plan=PricingPlan.STARTER, From d6b5bc1de6566019082694d226bad5393206dd72 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 28 Aug 2024 15:33:41 +0530 Subject: [PATCH 048/110] feat: add billed org to saved run & script to migrate org_id for existing saved runs --- bots/migrations/0082_savedrun_billed_org.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 bots/migrations/0082_savedrun_billed_org.py diff --git a/bots/migrations/0082_savedrun_billed_org.py b/bots/migrations/0082_savedrun_billed_org.py new file mode 100644 index 000000000..9dbe6170d --- /dev/null +++ b/bots/migrations/0082_savedrun_billed_org.py @@ -0,0 +1,20 @@ +# Generated by Django 4.2.7 on 2024-08-28 09:49 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('orgs', '0005_org_unique_personal_org_per_user'), + ('bots', '0081_alter_botintegration_streaming_enabled'), + ] + + operations = [ + migrations.AddField( + model_name='savedrun', + name='billed_org', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='billed_runs', to='orgs.org'), + ), + ] From e0b94cb34a0d9193e1442183924a9d1e9dbb314d Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 28 Aug 2024 15:34:37 +0530 Subject: [PATCH 049/110] fix: add filter condition in billed_org migration script to only run on historical data --- scripts/migrate_billed_org_for_saved_runs.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 scripts/migrate_billed_org_for_saved_runs.py diff --git a/scripts/migrate_billed_org_for_saved_runs.py b/scripts/migrate_billed_org_for_saved_runs.py new file mode 100644 index 000000000..52b86e932 --- /dev/null +++ b/scripts/migrate_billed_org_for_saved_runs.py @@ -0,0 +1,18 @@ +from django.db.models import F, Subquery, OuterRef +from django.db import transaction + +from bots.models import SavedRun +from orgs.models import Org + + +def run(): + # Start a transaction to ensure atomicity + with transaction.atomic(): + # Perform the update where 'uid' matches a valid 'org_id' in the 'Org' table + SavedRun.objects.filter( + billed_org_id__isnull=True, uid__in=Org.objects.values("org_id") + ).update( + billed_org_id=Subquery( + Org.objects.filter(org_id=OuterRef("uid")).values("id")[:1] + ) + ) From 026bc27ae5baf8a4b6b8020a98384cf917fed200 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 1 Sep 2024 19:01:36 +0530 Subject: [PATCH 050/110] fix: sync migrations in bots app with master --- bots/migrations/0082_savedrun_billed_org.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bots/migrations/0082_savedrun_billed_org.py b/bots/migrations/0082_savedrun_billed_org.py index 9dbe6170d..208f46dcc 100644 --- a/bots/migrations/0082_savedrun_billed_org.py +++ b/bots/migrations/0082_savedrun_billed_org.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.7 on 2024-08-28 09:49 +# Generated by Django 4.2.7 on 2024-08-30 08:10 from django.db import migrations, models import django.db.models.deletion @@ -8,7 +8,7 @@ class Migration(migrations.Migration): dependencies = [ ('orgs', '0005_org_unique_personal_org_per_user'), - ('bots', '0081_alter_botintegration_streaming_enabled'), + ('bots', '0081_remove_conversation_bots_conver_bot_int_73ac7b_idx_and_more'), ] operations = [ From 322dbbd8840c0d022e8d50988d963fa5afacbee5 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 1 Sep 2024 19:02:12 +0530 Subject: [PATCH 051/110] fix: type check for user.get_or_create_personal_org --- app_users/models.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/app_users/models.py b/app_users/models.py index 739ab3bd3..46803c1a8 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -1,5 +1,6 @@ import requests import stripe +import typing from django.db import models, IntegrityError, transaction from django.db.models import Sum from django.utils import timezone @@ -14,6 +15,9 @@ from handles.models import Handle from payments.plans import PricingPlan +if typing.TYPE_CHECKING: + from orgs.models import Org + class AppUserQuerySet(models.QuerySet): def get_or_create_from_uid( @@ -249,13 +253,13 @@ def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser": return self def get_or_create_personal_org(self) -> tuple["Org", bool]: - from orgs.models import Org + from orgs.models import Org, OrgMembership - org_membership = self.org_memberships.filter( + org_membership: OrgMembership | None = self.org_memberships.filter( org__is_personal=True, org__created_by=self ).first() if org_membership: - return org_membership, False + return org_membership.org, False else: return Org.objects.migrate_from_appuser(self), True From 0bf1ee98578985c6e1b489458ca5e74e30d6185e Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 1 Sep 2024 19:05:16 +0530 Subject: [PATCH 052/110] add: make billing tab work with org instead of AppUser --- daras_ai_v2/billing.py | 264 +++++++++++++++++++++-------------------- routers/account.py | 3 +- 2 files changed, 137 insertions(+), 130 deletions(-) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 7722fa5bf..e5a5fa27e 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -1,8 +1,10 @@ +import typing + import gooey_gui as gui import stripe from django.core.exceptions import ValidationError -from app_users.models import AppUser, PaymentProvider +from app_users.models import AppUserTransaction, PaymentProvider from daras_ai_v2 import icons, settings, paypal from daras_ai_v2.fastapi_tricks import get_app_route_url from daras_ai_v2.grid_layout_widget import grid_layout @@ -14,34 +16,38 @@ from payments.webhooks import StripeWebhookHandler, set_org_subscription from scripts.migrate_existing_subscriptions import available_subscriptions +if typing.TYPE_CHECKING: + from orgs.models import Org + + rounded_border = "w-100 border shadow-sm rounded py-4 px-3" -def billing_page(user: AppUser): +def billing_page(org: "Org"): render_payments_setup() - if user.subscription and user.subscription.is_paid(): - render_current_plan(user) + if org.subscription and org.subscription.is_paid(): + render_current_plan(org) with gui.div(className="my-5"): - render_credit_balance(user) + render_credit_balance(org) with gui.div(className="my-5"): - selected_payment_provider = render_all_plans(user) + selected_payment_provider = render_all_plans(org) with gui.div(className="my-5"): - render_addon_section(user, selected_payment_provider) + render_addon_section(org, selected_payment_provider) - if user.subscription: - if user.subscription.payment_provider == PaymentProvider.STRIPE: + if org.subscription: + if org.subscription.payment_provider == PaymentProvider.STRIPE: with gui.div(className="my-5"): - render_auto_recharge_section(user) + render_auto_recharge_section(org) with gui.div(className="my-5"): - render_payment_information(user) + render_payment_information(org) with gui.div(className="my-5"): - render_billing_history(user) + render_billing_history(org) def render_payments_setup(): @@ -55,10 +61,10 @@ def render_payments_setup(): ) -def render_current_plan(user: AppUser): - plan = PricingPlan.from_sub(user.subscription) - if user.subscription.payment_provider: - provider = PaymentProvider(user.subscription.payment_provider) +def render_current_plan(org: "Org"): + plan = PricingPlan.from_sub(org.subscription) + if org.subscription.payment_provider: + provider = PaymentProvider(org.subscription.payment_provider) else: provider = None @@ -76,7 +82,7 @@ def render_current_plan(user: AppUser): with right, gui.div(className="d-flex align-items-center gap-1"): if provider and ( next_invoice_ts := gui.run_in_thread( - user.subscription.get_next_invoice_timestamp, cache=True + org.subscription.get_next_invoice_timestamp, cache=True ) ): gui.html("Next invoice on ") @@ -112,17 +118,17 @@ def render_current_plan(user: AppUser): ) -def render_credit_balance(user: AppUser): - gui.write(f"## Credit Balance: {user.balance:,}") +def render_credit_balance(org: "Org"): + gui.write(f"## Credit Balance: {org.balance:,}") gui.caption( "Every time you submit a workflow or make an API call, we deduct credits from your account." ) -def render_all_plans(user: AppUser) -> PaymentProvider: +def render_all_plans(org: "Org") -> PaymentProvider: current_plan = ( - PricingPlan.from_sub(user.subscription) - if user.subscription + PricingPlan.from_sub(org.subscription) + if org.subscription else PricingPlan.STARTER ) all_plans = [plan for plan in PricingPlan if not plan.deprecated] @@ -130,8 +136,8 @@ def render_all_plans(user: AppUser) -> PaymentProvider: gui.write("## All Plans") plans_div = gui.div(className="mb-1") - if user.subscription and user.subscription.payment_provider: - selected_payment_provider = user.subscription.payment_provider + if org.subscription and org.subscription.payment_provider: + selected_payment_provider = org.subscription.payment_provider else: with gui.div(): selected_payment_provider = PaymentProvider[ @@ -149,7 +155,7 @@ def _render_plan(plan: PricingPlan): ): _render_plan_details(plan) _render_plan_action_button( - user=user, + org=org, plan=plan, current_plan=current_plan, payment_provider=selected_payment_provider, @@ -187,7 +193,7 @@ def _render_plan_details(plan: PricingPlan): def _render_plan_action_button( - user: AppUser, + org: "Org", plan: PricingPlan, current_plan: PricingPlan, payment_provider: PaymentProvider | None, @@ -201,75 +207,72 @@ def _render_plan_action_button( className=btn_classes + " btn btn-theme btn-primary", ): gui.html("Contact Us") - elif ( - user.subscription and user.subscription.plan == PricingPlan.ENTERPRISE.db_value - ): + elif org.subscription and org.subscription.plan == PricingPlan.ENTERPRISE.db_value: # don't show upgrade/downgrade buttons for enterprise customers return - else: - if user.subscription and user.subscription.is_paid(): - # subscription exists, show upgrade/downgrade button - if plan.credits > current_plan.credits: - modal, confirmed = confirm_modal( - title="Upgrade Plan", - key=f"--modal-{plan.key}", - text=f""" + elif org.subscription and org.subscription.is_paid(): + # subscription exists, show upgrade/downgrade button + if plan.credits > current_plan.credits: + modal, confirmed = confirm_modal( + title="Upgrade Plan", + key=f"--modal-{plan.key}", + text=f""" Are you sure you want to upgrade from **{current_plan.title} @ {fmt_price(current_plan)}** to **{plan.title} @ {fmt_price(plan)}**? Your payment method will be charged ${plan.monthly_charge:,} today and again every month until you cancel. **{plan.credits:,} Credits** will be added to your account today and with subsequent payments, your account balance will be refreshed to {plan.credits:,} Credits. - """, - button_label="Upgrade", + """, + button_label="Upgrade", + ) + if gui.button( + "Upgrade", className="primary", key=f"--change-sub-{plan.key}" + ): + modal.open() + if confirmed: + change_subscription( + org, + plan, + # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time + billing_cycle_anchor="now", ) - if gui.button( - "Upgrade", className="primary", key=f"--change-sub-{plan.key}" - ): - modal.open() - if confirmed: - change_subscription( - user, - plan, - # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time - billing_cycle_anchor="now", - ) - else: - modal, confirmed = confirm_modal( - title="Downgrade Plan", - key=f"--modal-{plan.key}", - text=f""" + else: + modal, confirmed = confirm_modal( + title="Downgrade Plan", + key=f"--modal-{plan.key}", + text=f""" Are you sure you want to downgrade from: **{current_plan.title} @ {fmt_price(current_plan)}** to **{plan.title} @ {fmt_price(plan)}**? This will take effect from the next billing cycle. - """, - button_label="Downgrade", - button_class="border-danger bg-danger text-white", - ) - if gui.button( - "Downgrade", className="secondary", key=f"--change-sub-{plan.key}" - ): - modal.open() - if confirmed: - change_subscription(user, plan) - else: - assert payment_provider is not None # for sanity - _render_create_subscription_button( - user=user, - plan=plan, - payment_provider=payment_provider, + """, + button_label="Downgrade", + button_class="border-danger bg-danger text-white", ) + if gui.button( + "Downgrade", className="secondary", key=f"--change-sub-{plan.key}" + ): + modal.open() + if confirmed: + change_subscription(org, plan) + else: + assert payment_provider is not None # for sanity + _render_create_subscription_button( + org=org, + plan=plan, + payment_provider=payment_provider, + ) def _render_create_subscription_button( *, - user: AppUser, + org: "Org", plan: PricingPlan, payment_provider: PaymentProvider, ): match payment_provider: case PaymentProvider.STRIPE: - render_stripe_subscription_button(user=user, plan=plan) + render_stripe_subscription_button(org=org, plan=plan) case PaymentProvider.PAYPAL: render_paypal_subscription_button(plan=plan) @@ -281,27 +284,27 @@ def fmt_price(plan: PricingPlan) -> str: return "Free" -def change_subscription(user: AppUser, new_plan: PricingPlan, **kwargs): +def change_subscription(org: "Org", new_plan: PricingPlan, **kwargs): from routers.account import account_route from routers.account import payment_processing_route - current_plan = PricingPlan.from_sub(user.subscription) + current_plan = PricingPlan.from_sub(org.subscription) if new_plan == current_plan: raise gui.RedirectException(get_app_route_url(account_route), status_code=303) if new_plan == PricingPlan.STARTER: - user.subscription.cancel() + org.subscription.cancel() raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 ) - match user.subscription.payment_provider: + match org.subscription.payment_provider: case PaymentProvider.STRIPE: if not new_plan.supports_stripe(): gui.error(f"Stripe subscription not available for {new_plan}") - subscription = stripe.Subscription.retrieve(user.subscription.external_id) + subscription = stripe.Subscription.retrieve(org.subscription.external_id) stripe.Subscription.modify( subscription.id, items=[ @@ -345,20 +348,20 @@ def payment_provider_radio(**props) -> str | None: ) -def render_addon_section(user: AppUser, selected_payment_provider: PaymentProvider): - if user.subscription: +def render_addon_section(org: "Org", selected_payment_provider: PaymentProvider): + if org.subscription: gui.write("# Purchase More Credits") else: gui.write("# Purchase Credits") gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") - if user.subscription and user.subscription.payment_provider: - provider = PaymentProvider(user.subscription.payment_provider) + if org.subscription and org.subscription.payment_provider: + provider = PaymentProvider(org.subscription.payment_provider) else: provider = selected_payment_provider match provider: case PaymentProvider.STRIPE: - render_stripe_addon_buttons(user) + render_stripe_addon_buttons(org) case PaymentProvider.PAYPAL: render_paypal_addon_buttons() @@ -382,8 +385,8 @@ def render_paypal_addon_buttons(): gui.div(id="paypal-result-message") -def render_stripe_addon_buttons(user: AppUser): - if not (user.subscription and user.subscription.payment_provider): +def render_stripe_addon_buttons(org: "Org"): + if not (org.subscription and org.subscription.payment_provider): save_pm = gui.checkbox( "Save payment method for future purchases & auto-recharge", value=True ) @@ -391,10 +394,10 @@ def render_stripe_addon_buttons(user: AppUser): save_pm = True for dollat_amt in settings.ADDON_AMOUNT_CHOICES: - render_stripe_addon_button(dollat_amt, user, save_pm) + render_stripe_addon_button(dollat_amt, org, save_pm) -def render_stripe_addon_button(dollat_amt: int, user: AppUser, save_pm: bool): +def render_stripe_addon_button(dollat_amt: int, org: "Org", save_pm: bool): modal, confirmed = confirm_modal( title="Purchase Credits", key=f"--addon-modal-{dollat_amt}", @@ -408,14 +411,14 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser, save_pm: bool): ) if gui.button(f"${dollat_amt:,}", type="primary"): - if user.subscription and user.subscription.stripe_get_default_payment_method(): + if org.subscription and org.subscription.stripe_get_default_payment_method(): modal.open() else: - stripe_addon_checkout_redirect(user, dollat_amt, save_pm) + stripe_addon_checkout_redirect(org, dollat_amt, save_pm) if confirmed: success = gui.run_in_thread( - user.subscription.stripe_attempt_addon_purchase, + org.subscription.stripe_attempt_addon_purchase, args=[dollat_amt], placeholder="", ) @@ -426,10 +429,10 @@ def render_stripe_addon_button(dollat_amt: int, user: AppUser, save_pm: bool): modal.close() else: # fallback to stripe checkout flow if the auto payment failed - stripe_addon_checkout_redirect(user, dollat_amt, save_pm) + stripe_addon_checkout_redirect(org, dollat_amt, save_pm) -def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int, save_pm: bool): +def stripe_addon_checkout_redirect(org: "Org", dollat_amt: int, save_pm: bool): from routers.account import account_route from routers.account import payment_processing_route @@ -445,7 +448,7 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int, save_pm: bool mode="payment", success_url=get_app_route_url(payment_processing_route), cancel_url=get_app_route_url(account_route), - customer=user.get_or_create_stripe_customer(), + customer=org.get_or_create_stripe_customer(), invoice_creation={"enabled": True}, allow_promotion_codes=True, **kwargs, @@ -455,7 +458,7 @@ def stripe_addon_checkout_redirect(user: AppUser, dollat_amt: int, save_pm: bool def render_stripe_subscription_button( *, - user: AppUser, + org: "Org", plan: PricingPlan, ): if not plan.supports_stripe(): @@ -483,36 +486,38 @@ def render_stripe_subscription_button( key=f"--change-sub-{plan.key}", type="primary", ): - if user.subscription and user.subscription.stripe_get_default_payment_method(): + if org.subscription and org.subscription.stripe_get_default_payment_method(): modal.open() else: - stripe_subscription_create(user=user, plan=plan) + stripe_subscription_create(org=org, plan=plan) if confirmed: - stripe_subscription_create(user=user, plan=plan) + stripe_subscription_create(org=org, plan=plan) -def stripe_subscription_create(user: AppUser, plan: PricingPlan): +def stripe_subscription_create(org: "Org", plan: PricingPlan): from routers.account import account_route from routers.account import payment_processing_route - if user.subscription and user.subscription.plan == plan.db_value: + if org.subscription and org.subscription.is_paid(): # sanity check: already subscribed to some plan - return + gui.rerun() # check for existing subscriptions on stripe - customer = user.get_or_create_stripe_customer() + customer = org.get_or_create_stripe_customer() for sub in stripe.Subscription.list( customer=customer, status="active", limit=1 ).data: - StripeWebhookHandler.handle_subscription_updated(uid=user.uid, stripe_sub=sub) + StripeWebhookHandler.handle_subscription_updated( + org_id=org.org_id, stripe_sub=sub + ) raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 ) # try to directly create the subscription without checkout - pm = user.subscription and user.subscription.stripe_get_default_payment_method() metadata = {settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: plan.key} + pm = org.subscription and org.subscription.stripe_get_default_payment_method() line_items = [plan.get_stripe_line_item()] if pm: sub = stripe.Subscription.create( @@ -562,12 +567,12 @@ def render_paypal_subscription_button( ) -def render_payment_information(user: AppUser): - if not user.subscription: +def render_payment_information(org: "Org"): + if not org.subscription: return pm_summary = gui.run_in_thread( - user.subscription.get_payment_method_summary, cache=True + org.subscription.get_payment_method_summary, cache=True ) if not pm_summary: return @@ -579,7 +584,7 @@ def render_payment_information(user: AppUser): gui.write("**Pay via**") with col2: provider = PaymentProvider( - user.subscription.payment_provider or PaymentProvider.STRIPE + org.subscription.payment_provider or PaymentProvider.STRIPE ) gui.write(provider.label) with col3: @@ -587,7 +592,7 @@ def render_payment_information(user: AppUser): f"{icons.edit} Edit", type="link", key="manage-payment-provider" ): raise gui.RedirectException( - user.subscription.get_external_management_url() + org.subscription.get_external_management_url() ) pm_summary = PaymentMethodSummary(*pm_summary) @@ -607,7 +612,7 @@ def render_payment_information(user: AppUser): if gui.button( f"{icons.edit} Edit", type="link", key="edit-payment-method" ): - change_payment_method(user) + change_payment_method(org) if pm_summary.billing_email: col1, col2, _ = gui.columns(3, responsive=False) @@ -636,12 +641,12 @@ def render_payment_information(user: AppUser): modal.open() if confirmed: set_org_subscription( - org_id=user.get_personal_org().org_id, + org_id=org.org_id, plan=PricingPlan.STARTER, provider=None, external_id=None, ) - pm = user.subscription and user.subscription.stripe_get_default_payment_method() + pm = org.subscription and org.subscription.stripe_get_default_payment_method() if pm: pm.detach() raise gui.RedirectException( @@ -649,18 +654,18 @@ def render_payment_information(user: AppUser): ) -def change_payment_method(user: AppUser): +def change_payment_method(org: "Org"): from routers.account import payment_processing_route from routers.account import account_route - match user.subscription.payment_provider: + match org.subscription.payment_provider: case PaymentProvider.STRIPE: session = stripe.checkout.Session.create( mode="setup", currency="usd", - customer=user.get_or_create_stripe_customer(), + customer=org.get_or_create_stripe_customer(), setup_intent_data={ - "metadata": {"subscription_id": user.subscription.external_id}, + "metadata": {"subscription_id": org.subscription.external_id}, }, success_url=get_app_route_url(payment_processing_route), cancel_url=get_app_route_url(account_route), @@ -674,10 +679,13 @@ def format_card_brand(brand: str) -> str: return icons.card_icons.get(brand.lower(), brand.capitalize()) -def render_billing_history(user: AppUser, limit: int = 50): +def render_billing_history(org: "Org", limit: int = 50): import pandas as pd - txns = user.transactions.filter(amount__gt=0).order_by("-created_at") + txns = AppUserTransaction.objects.filter( + org=org, + amount__gt=0, + ).order_by("-created_at") if not txns: return @@ -700,9 +708,9 @@ def render_billing_history(user: AppUser, limit: int = 50): gui.caption(f"Showing only the most recent {limit} transactions.") -def render_auto_recharge_section(user: AppUser): - assert user.subscription - subscription = user.subscription +def render_auto_recharge_section(org: "Org"): + assert org.subscription + subscription = org.subscription gui.write("## Auto Recharge & Limits") with gui.div(className="h4"): @@ -746,10 +754,10 @@ def render_auto_recharge_section(user: AppUser): """, ) with gui.div(className="d-flex align-items-center"): - user.subscription.monthly_spending_budget = gui.number_input( + subscription.monthly_spending_budget = gui.number_input( "", min_value=10, - value=user.subscription.monthly_spending_budget, + value=subscription.monthly_spending_budget, key="monthly-spending-budget", ) gui.write("USD", className="d-block ms-2") @@ -762,13 +770,11 @@ def render_auto_recharge_section(user: AppUser): """ ) with gui.div(className="d-flex align-items-center"): - user.subscription.monthly_spending_notification_threshold = ( - gui.number_input( - "", - min_value=10, - value=user.subscription.monthly_spending_notification_threshold, - key="monthly-spending-notification-threshold", - ) + subscription.monthly_spending_notification_threshold = gui.number_input( + "", + min_value=10, + value=subscription.monthly_spending_notification_threshold, + key="monthly-spending-notification-threshold", ) gui.write("USD", className="d-block ms-2") diff --git a/routers/account.py b/routers/account.py index b52239b2b..f9194589b 100644 --- a/routers/account.py +++ b/routers/account.py @@ -203,7 +203,8 @@ def url_path(self) -> str: def billing_tab(request: Request): - return billing_page(request.user) + org, _ = request.user.get_or_create_personal_org() + return billing_page(org) def profile_tab(request: Request): From 33359c53589b39e822ebeba293f148ff5763e68b Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 1 Sep 2024 19:05:50 +0530 Subject: [PATCH 053/110] fix: remove billing from org page --- orgs/views.py | 377 +------------------------------------------------- 1 file changed, 5 insertions(+), 372 deletions(-) diff --git a/orgs/views.py b/orgs/views.py index 2d6f3c27c..494bac72a 100644 --- a/orgs/views.py +++ b/orgs/views.py @@ -2,21 +2,13 @@ import html as html_lib -import stripe import gooey_gui as gui from django.core.exceptions import ValidationError -from app_users.models import AppUser, PaymentProvider -from daras_ai_v2.billing import format_card_brand, payment_provider_radio -from daras_ai_v2.grid_layout_widget import grid_layout +from app_users.models import AppUser from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole -from daras_ai_v2 import icons, settings -from daras_ai_v2.fastapi_tricks import get_route_path, get_app_route_url -from daras_ai_v2.settings import templates -from daras_ai_v2.user_date_widgets import render_local_date_attrs -from payments.models import PaymentMethodSummary -from payments.plans import PricingPlan -from scripts.migrate_existing_subscriptions import available_subscriptions +from daras_ai_v2 import icons +from daras_ai_v2.fastapi_tricks import get_route_path DEFAULT_ORG_LOGO = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/74a37c52-8260-11ee-a297-02420a0001ee/gooey.ai%20-%20A%20pop%20art%20illustration%20of%20robots%20taki...y%20Liechtenstein%20mint%20colour%20is%20main%20city%20Seattle.png" @@ -70,7 +62,7 @@ def invitation_page(user: AppUser, invitation: OrgInvitation): def orgs_page(user: AppUser): - memberships = user.org_memberships.all() + memberships = user.org_memberships.filter() if not memberships: gui.write("*You're not part of an organization yet... Create one?*") @@ -118,10 +110,6 @@ def render_org_by_membership(membership: OrgMembership): f"Org Domain: `@{org.domain_name}`", className="text-muted" ) - with gui.div(className="mt-4"): - gui.write("# Billing") - billing_section(org=org, current_member=membership) - with gui.div(className="mt-4"): with gui.div(className="d-flex justify-content-between align-items-center"): gui.write("## Members") @@ -157,361 +145,6 @@ def render_org_by_membership(membership: OrgMembership): org_leave_modal.open() -def billing_section(*, org: Org, current_member: OrgMembership): - render_payments_setup() - - if org.subscription and org.subscription.external_id: - render_current_plan(org) - - with gui.div(className="my-5"): - render_credit_balance(org) - - with gui.div(className="my-5"): - selected_payment_provider = render_all_plans(org) - - with gui.div(className="my-5"): - render_addon_section(org, selected_payment_provider) - - if org.subscription and org.subscription.external_id: - # if org.subscription.payment_provider == PaymentProvider.STRIPE: - # with gui.div(className="my-5"): - # render_auto_recharge_section(user) - with gui.div(className="my-5"): - render_payment_information(org) - - with gui.div(className="my-5"): - render_billing_history(org) - - -def render_payments_setup(): - from routers.account import payment_processing_route - - gui.html( - templates.get_template("payment_setup.html").render( - settings=settings, - payment_processing_url=get_app_route_url(payment_processing_route), - ) - ) - - -def render_current_plan(org: Org): - plan = PricingPlan.from_sub(org.subscription) - provider = ( - PaymentProvider(org.subscription.payment_provider) - if org.subscription.payment_provider - else None - ) - - with gui.div(className=f"{rounded_border} border-dark"): - # ROW 1: Plan title and next invoice date - left, right = left_and_right() - with left: - gui.write(f"#### Gooey.AI {plan.title}") - - if provider: - gui.write( - f"[{icons.edit} Manage Subscription](#payment-information)", - unsafe_allow_html=True, - ) - with right, gui.div(className="d-flex align-items-center gap-1"): - if provider and ( - next_invoice_ts := gui.run_in_thread( - org.subscription.get_next_invoice_timestamp, cache=True - ) - ): - gui.html("Next invoice on ") - gui.pill( - "...", - text_bg="dark", - **render_local_date_attrs( - next_invoice_ts, - date_options={"day": "numeric", "month": "long"}, - ), - ) - - if plan is PricingPlan.ENTERPRISE: - # charge details are not relevant for Enterprise customers - return - - # ROW 2: Plan pricing details - left, right = left_and_right(className="mt-5") - with left: - gui.write(f"# {plan.pricing_title()}", className="no-margin") - if plan.monthly_charge: - provider_text = f" **via {provider.label}**" if provider else "" - gui.caption("per month" + provider_text) - - with right, gui.div(className="text-end"): - gui.write(f"# {plan.credits:,} credits", className="no-margin") - if plan.monthly_charge: - gui.write( - f"**${plan.monthly_charge:,}** monthly renewal for {plan.credits:,} credits" - ) - - -def render_credit_balance(org: Org): - gui.write(f"## Credit Balance: {org.balance:,}") - gui.caption( - "Every time you submit a workflow or make an API call, we deduct credits from your account." - ) - - -def render_all_plans(org: Org) -> PaymentProvider | None: - current_plan = ( - PricingPlan.from_sub(org.subscription) - if org.subscription - else PricingPlan.STARTER - ) - all_plans = [plan for plan in PricingPlan if not plan.deprecated] - - gui.write("## All Plans") - plans_div = gui.div(className="mb-1") - - if org.subscription and org.subscription.payment_provider: - selected_payment_provider = None - else: - with gui.div(): - selected_payment_provider = PaymentProvider[ - payment_provider_radio() or PaymentProvider.STRIPE.name - ] - - def _render_plan(plan: PricingPlan): - if plan == current_plan: - extra_class = "border-dark" - else: - extra_class = "bg-light" - with gui.div(className="d-flex flex-column h-100"): - with gui.div( - className=f"{rounded_border} flex-grow-1 d-flex flex-column p-3 mb-2 {extra_class}" - ): - _render_plan_details(plan) - # _render_plan_action_button( - # user, plan, current_plan, selected_payment_provider - # ) - - with plans_div: - grid_layout(4, all_plans, _render_plan, separator=False) - - with gui.div(className="my-2 d-flex justify-content-center"): - gui.caption( - f"**[See all features & benefits]({settings.PRICING_DETAILS_URL})**" - ) - - return selected_payment_provider - - -def _render_plan_details(plan: PricingPlan): - with gui.div(className="flex-grow-1"): - with gui.div(className="mb-4"): - with gui.tag("h4", className="mb-0"): - gui.html(plan.title) - gui.caption( - plan.description, - style={ - "minHeight": "calc(var(--bs-body-line-height) * 2em)", - "display": "block", - }, - ) - with gui.div(className="my-3 w-100"): - with gui.tag("h4", className="my-0 d-inline me-2"): - gui.html(plan.pricing_title()) - with gui.tag("span", className="text-muted my-0"): - gui.html(plan.pricing_caption()) - gui.write(plan.long_description, unsafe_allow_html=True) - - -def render_payment_information(org: Org): - assert org.subscription - - gui.write("## Payment Information", id="payment-information", className="d-block") - col1, col2, col3 = gui.columns(3, responsive=False) - with col1: - gui.write("**Pay via**") - with col2: - provider = PaymentProvider(org.subscription.payment_provider) - gui.write(provider.label) - with col3: - if gui.button(f"{icons.edit} Edit", type="link", key="manage-payment-provider"): - raise gui.RedirectException(org.subscription.get_external_management_url()) - - pm_summary = gui.run_in_thread( - org.subscription.get_payment_method_summary, cache=True - ) - if not pm_summary: - return - pm_summary = PaymentMethodSummary(*pm_summary) - if pm_summary.card_brand and pm_summary.card_last4: - col1, col2, col3 = gui.columns(3, responsive=False) - with col1: - gui.write("**Payment Method**") - with col2: - gui.write( - f"{format_card_brand(pm_summary.card_brand)} ending in {pm_summary.card_last4}", - unsafe_allow_html=True, - ) - with col3: - if gui.button(f"{icons.edit} Edit", type="link", key="edit-payment-method"): - change_payment_method(org) - - if pm_summary.billing_email: - col1, col2, _ = gui.columns(3, responsive=False) - with col1: - gui.write("**Billing Email**") - with col2: - gui.html(pm_summary.billing_email) - - -def change_payment_method(org: Org): - from routers.account import payment_processing_route - from routers.account import account_route - - match org.subscription.payment_provider: - case PaymentProvider.STRIPE: - session = stripe.checkout.Session.create( - mode="setup", - currency="usd", - customer=org.get_or_create_stripe_customer().id, - setup_intent_data={ - "metadata": {"subscription_id": org.subscription.external_id}, - }, - success_url=get_app_route_url(payment_processing_route), - cancel_url=get_app_route_url(account_route), - ) - raise gui.RedirectException(session.url, status_code=303) - case _: - gui.error("Not implemented for this payment provider") - - -def render_billing_history(org: Org, limit: int = 50): - import pandas as pd - - txns = org.transactions.filter(amount__gt=0).order_by("-created_at") - if not txns: - return - - gui.write("## Billing History", className="d-block") - gui.table( - pd.DataFrame.from_records( - [ - { - "Date": txn.created_at.strftime("%m/%d/%Y"), - "Description": txn.reason_note(), - "Amount": f"-${txn.charged_amount / 100:,.2f}", - "Credits": f"+{txn.amount:,}", - "Balance": f"{txn.end_balance:,}", - } - for txn in txns[:limit] - ] - ), - ) - if txns.count() > limit: - gui.caption(f"Showing only the most recent {limit} transactions.") - - -def render_addon_section(org: Org, selected_payment_provider: PaymentProvider): - if org.subscription: - gui.write("# Purchase More Credits") - else: - gui.write("# Purchase Credits") - gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") - - if org.subscription and org.subscription.payment_provider: - provider = PaymentProvider(org.subscription.payment_provider) - else: - provider = selected_payment_provider - match provider: - case PaymentProvider.STRIPE | None: - render_stripe_addon_buttons(org) - case PaymentProvider.PAYPAL: - render_paypal_addon_buttons() - - -def render_paypal_addon_buttons(): - selected_amt = gui.horizontal_radio( - "", - settings.ADDON_AMOUNT_CHOICES, - format_func=lambda amt: f"${amt:,}", - checked_by_default=False, - ) - if selected_amt: - gui.js( - f"setPaypalAddonQuantity({int(selected_amt) * settings.ADDON_CREDITS_PER_DOLLAR})" - ) - gui.div( - id="paypal-addon-buttons", - className="mt-2", - style={"width": "fit-content"}, - ) - gui.div(id="paypal-result-message") - - -def render_stripe_addon_buttons(org: Org): - for dollar_amt in settings.ADDON_AMOUNT_CHOICES: - render_stripe_addon_button(dollar_amt, org) - - -def render_stripe_addon_button(dollar_amt: int, org: Org): - confirm_purchase_modal = gui.Modal( - "Confirm Purchase", key=f"confirm-purchase-{dollar_amt}" - ) - if gui.button(f"${dollar_amt:,}", type="primary"): - if org.subscription and org.subscription.external_id: - confirm_purchase_modal.open() - else: - stripe_addon_checkout_redirect(org, dollar_amt) - - if not confirm_purchase_modal.is_open(): - return - with confirm_purchase_modal.container(): - gui.write( - f""" - Please confirm your purchase: - **{dollar_amt * settings.ADDON_CREDITS_PER_DOLLAR:,} credits for ${dollar_amt}**. - """, - className="py-4 d-block text-center", - ) - with gui.div(className="d-flex w-100 justify-content-end"): - if gui.session_state.get("--confirm-purchase"): - success = gui.run_in_thread( - org.subscription.stripe_attempt_addon_purchase, - args=[dollar_amt], - placeholder="Processing payment...", - ) - if success is None: - return - gui.session_state.pop("--confirm-purchase") - if success: - confirm_purchase_modal.close() - else: - gui.error("Payment failed... Please try again.") - return - - if gui.button("Cancel", className="border border-danger text-danger me-2"): - confirm_purchase_modal.close() - gui.button("Buy", type="primary", key="--confirm-purchase") - - -def stripe_addon_checkout_redirect(org: Org, dollar_amt: int): - from routers.account import account_route - from routers.account import payment_processing_route - - line_item = available_subscriptions["addon"]["stripe"].copy() - line_item["quantity"] = dollar_amt * settings.ADDON_CREDITS_PER_DOLLAR - checkout_session = stripe.checkout.Session.create( - line_items=[line_item], - mode="payment", - success_url=get_app_route_url(payment_processing_route), - cancel_url=get_app_route_url(account_route), - customer=org.get_or_create_stripe_customer().id, - invoice_creation={"enabled": True}, - allow_promotion_codes=True, - saved_payment_method_options={ - "payment_method_save": "enabled", - }, - ) - raise gui.RedirectException(checkout_session.url, status_code=303) - - def render_org_creation_view(user: AppUser): gui.write(f"# {icons.company} Create an Org", unsafe_allow_html=True) org_fields = render_org_create_or_edit_form() @@ -525,7 +158,7 @@ def render_org_creation_view(user: AppUser): except ValidationError as e: gui.write(", ".join(e.messages), className="text-danger") else: - gui.experimental_rerun() + gui.rerun() def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: gui.Modal): From 19a24efb2a5c89d0e52252acfa7d409c18828cd5 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 1 Sep 2024 19:06:26 +0530 Subject: [PATCH 054/110] feat: use set_org_subscription instead of set_user_subscription --- payments/webhooks.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/payments/webhooks.py b/payments/webhooks.py index a00466bbc..cedd2b0b3 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -4,11 +4,7 @@ from django.db import transaction from loguru import logger -from app_users.models import ( - AppUser, - PaymentProvider, - TransactionReason, -) +from app_users.models import PaymentProvider, TransactionReason from daras_ai_v2 import paypal from orgs.models import Org from .models import Subscription @@ -67,9 +63,9 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): return set_org_subscription( - provider=cls.PROVIDER, - plan=plan, org_id=pp_sub.custom_id, + plan=plan, + provider=cls.PROVIDER, external_id=pp_sub.id, ) @@ -77,7 +73,7 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription): assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" set_org_subscription( - uid=pp_sub.custom_id, + org_id=pp_sub.custom_id, plan=PricingPlan.STARTER, provider=None, external_id=None, @@ -89,6 +85,8 @@ class StripeWebhookHandler: @classmethod def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice): + from app_users.tasks import save_stripe_default_payment_method + kwargs = {} if invoice.subscription and invoice.subscription_details: kwargs["plan"] = PricingPlan.get_by_key( @@ -122,7 +120,7 @@ def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice): save_stripe_default_payment_method.delay( payment_intent_id=invoice.payment_intent, - uid=uid, + org_id=org_id, amount=amount, charged_amount=charged_amount, reason=reason, @@ -173,9 +171,9 @@ def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscriptio return set_org_subscription( - provider=cls.PROVIDER, - plan=plan, org_id=org_id, + plan=plan, + provider=cls.PROVIDER, external_id=stripe_sub.id, ) From 984602985c4d088df1c6b2eccf7f59224a513ae2 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:43:31 +0530 Subject: [PATCH 055/110] fix: phone number field in org.get_or_create_stripe_customer --- orgs/models.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/orgs/models.py b/orgs/models.py index fa1b471b9..6038d99c9 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +import typing from datetime import timedelta from django.db.models.aggregates import Sum @@ -20,6 +21,9 @@ from gooeysite.bg_db_conn import db_middleware from orgs.tasks import send_auto_accepted_email, send_invitation_email +if typing.TYPE_CHECKING: + from app_users.models import AppUser + ORG_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$") @@ -259,7 +263,7 @@ def get_or_create_stripe_customer(self) -> stripe.Customer: customer = stripe.Customer.create( name=self.created_by.display_name, email=self.created_by.email, - phone=self.created_by.phone, + phone=self.created_by.phone_number, metadata={"uid": self.org_id, "org_id": self.org_id, "id": self.pk}, ) self.stripe_customer_id = customer.id From f99e7231bb32711ddb455324e60b3f771c29751f Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:44:16 +0530 Subject: [PATCH 056/110] add org to list view in transactions admin --- app_users/admin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app_users/admin.py b/app_users/admin.py index 56f325fca..f433f86d2 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -216,6 +216,7 @@ class AppUserTransactionAdmin(admin.ModelAdmin): autocomplete_fields = ["user"] list_display = [ "invoice_id", + "org", "user", "amount", "dollar_amount", From 761d5e618f7031212ef22cff630a31f9320ffc30 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:44:45 +0530 Subject: [PATCH 057/110] fix: types in orgs.models --- orgs/models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/orgs/models.py b/orgs/models.py index 6038d99c9..0b0362503 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -4,8 +4,8 @@ import typing from datetime import timedelta -from django.db.models.aggregates import Sum import stripe +from django.db.models.aggregates import Sum from django.db import models, transaction from django.core.exceptions import ValidationError from django.db.backends.base.schema import logger @@ -21,8 +21,9 @@ from gooeysite.bg_db_conn import db_middleware from orgs.tasks import send_auto_accepted_email, send_invitation_email + if typing.TYPE_CHECKING: - from app_users.models import AppUser + from app_users.models import AppUser, AppUserTransaction ORG_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$") @@ -164,7 +165,7 @@ def get_slug(self): return slugify(self.name) def add_member( - self, user: AppUser, role: OrgRole, invitation: "OrgInvitation | None" = None + self, user: "AppUser", role: OrgRole, invitation: "OrgInvitation | None" = None ): OrgMembership( org=self, @@ -177,7 +178,7 @@ def invite_user( self, *, invitee_email: str, - inviter: AppUser, + inviter: "AppUser", role: OrgRole, auto_accept: bool = False, ) -> "OrgInvitation": From b62a128ac2feffb30ce5bd02e284cccca1de46e2 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:45:16 +0530 Subject: [PATCH 058/110] add transaction migration to org billing migration script --- scripts/migrate_orgs_from_appusers.py | 28 +++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/scripts/migrate_orgs_from_appusers.py b/scripts/migrate_orgs_from_appusers.py index d4e868e30..f4cbc7ec9 100644 --- a/scripts/migrate_orgs_from_appusers.py +++ b/scripts/migrate_orgs_from_appusers.py @@ -1,4 +1,4 @@ -from django.db import IntegrityError +from django.db import IntegrityError, connection from loguru import logger from app_users.models import AppUser @@ -6,12 +6,18 @@ def run(): + migrate_personal_orgs() + migrate_txns() + + +def migrate_personal_orgs(): users_without_personal_org = AppUser.objects.exclude( id__in=Org.objects.filter(is_personal=True).values_list("created_by", flat=True) ) done_count = 0 + logger.info("Creating personal orgs...") for appuser in users_without_personal_org: try: Org.objects.migrate_from_appuser(appuser) @@ -23,4 +29,22 @@ def run(): if done_count % 100 == 0: logger.info(f"Running... {done_count} migrated") - logger.info(f"Done... {done_count} migrated") + logger.info(f"Migrated {done_count} personal orgs...") + + +def migrate_txns(): + with connection.cursor() as cursor: + cursor.execute( + """ + UPDATE app_users_appusertransaction AS txn + SET org_id = orgs_org.id + FROM + app_users_appuser + INNER JOIN orgs_org ON app_users_appuser.id = orgs_org.created_by_id + WHERE + txn.user_id = app_users_appuser.id + AND txn.org_id IS NULL + AND orgs_org.is_personal = true + """ + ) + logger.info(f"Updated {cursor.rowcount} txns with personal orgs") From 60f69380ac418091fbc29a9343bd3e923bd8d66c Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 2 Sep 2024 12:54:59 +0530 Subject: [PATCH 059/110] revert accidental changes to Procfile --- Procfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Procfile b/Procfile index 1766991c6..8711211c2 100644 --- a/Procfile +++ b/Procfile @@ -19,4 +19,4 @@ dashboard: poetry run streamlit run Home.py --server.port 8501 --server.headless celery: poetry run celery -A celeryapp worker -P threads -c 16 -l DEBUG -ui: cd ../gooey-gui/ && env PORT=3000 REDIS_URL=redis://localhost:6379 pnpm run dev +ui: cd ../gooey-gui/; PORT=3000 npm run dev From 782859540e7ccd9cd27319b32a68e8db95553dbd Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 2 Sep 2024 12:59:06 +0530 Subject: [PATCH 060/110] remove unused appuser.get_personal_org --- app_users/models.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/app_users/models.py b/app_users/models.py index 46803c1a8..68193c441 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -162,9 +162,6 @@ def first_name_possesive(self) -> str: else: return name + "'s" - def get_personal_org(self) -> "Org | None": - return self.orgs.filter(is_personal=True).first() - @db_middleware @transaction.atomic def add_balance( From 68df010dda8fb67a57e734a3441bf23a0c58102d Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 2 Sep 2024 13:01:01 +0530 Subject: [PATCH 061/110] set user on txn if org is personal works for now, until we introduce team members for all --- orgs/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/orgs/models.py b/orgs/models.py index 0b0362503..33f0b35de 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -252,6 +252,7 @@ def add_balance( kwargs.setdefault("plan", org.subscription and org.subscription.plan) return AppUserTransaction.objects.create( org=org, + user=org.created_by if org.is_personal else None, invoice_id=invoice_id, amount=amount, end_balance=org.balance, From 811854b9d063f076553cfb220635c09e5505dc9b Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 2 Sep 2024 13:02:34 +0530 Subject: [PATCH 062/110] remove debug change --- daras_ai_v2/send_email.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py index 3c679c6fb..2262624e7 100644 --- a/daras_ai_v2/send_email.py +++ b/daras_ai_v2/send_email.py @@ -82,7 +82,7 @@ def send_email_via_postmark( "outbound", "gooey-ai-workflows", "announcements" ] = "outbound", ): - if is_running_pytest or not settings.POSTMARK_API_TOKEN: + if is_running_pytest: pytest_outbox.append( dict( from_address=from_address, From 87252748f4659db395457cfd06b7c9dd9ae01d89 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 2 Sep 2024 13:33:16 +0530 Subject: [PATCH 063/110] use org.subscription instead of user.subscription --- orgs/models.py | 9 +++ payments/auto_recharge.py | 59 +++++++++---------- payments/tasks.py | 51 ++++++---------- routers/paypal.py | 5 +- templates/auto_recharge_failed_email.html | 15 ----- templates/monthly_budget_reached_email.html | 8 +-- ...spending_notification_threshold_email.html | 6 +- 7 files changed, 66 insertions(+), 87 deletions(-) delete mode 100644 templates/auto_recharge_failed_email.html diff --git a/orgs/models.py b/orgs/models.py index 33f0b35de..4c9b2c8e2 100644 --- a/orgs/models.py +++ b/orgs/models.py @@ -98,6 +98,15 @@ def migrate_from_appuser(self, user: "AppUser") -> Org: is_paying=user.is_paying, ) + def get_dollars_spent_this_month(self) -> float: + today = timezone.now() + cents_spent = self.transactions.filter( + created_at__month=today.month, + created_at__year=today.year, + amount__gt=0, + ).aggregate(total=Sum("charged_amount"))["total"] + return (cents_spent or 0) / 100 + class Org(SafeDeleteModel): _safedelete_policy = SOFT_DELETE_CASCADE diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py index 14d6ba49d..3d07493b5 100644 --- a/payments/auto_recharge.py +++ b/payments/auto_recharge.py @@ -3,12 +3,10 @@ import sentry_sdk from loguru import logger -from app_users.models import AppUser, PaymentProvider +from app_users.models import PaymentProvider from daras_ai_v2.redis_cache import redis_lock -from payments.tasks import ( - send_monthly_budget_reached_email, - send_auto_recharge_failed_email, -) +from orgs.models import Org +from payments.tasks import send_monthly_budget_reached_email class AutoRechargeException(Exception): @@ -30,18 +28,18 @@ class AutoRechargeCooldownException(AutoRechargeException): pass -def should_attempt_auto_recharge(user: AppUser): +def should_attempt_auto_recharge(org: Org): return ( - user.subscription - and user.subscription.auto_recharge_enabled - and user.subscription.payment_provider - and user.balance < user.subscription.auto_recharge_balance_threshold + org.subscription + and org.subscription.auto_recharge_enabled + and org.subscription.payment_provider + and org.balance < org.subscription.auto_recharge_balance_threshold ) -def run_auto_recharge_gracefully(user: AppUser): +def run_auto_recharge_gracefully(org: Org): """ - Wrapper over _auto_recharge_user, that handles exceptions so that it can: + Wrapper over _auto_recharge_org, that handles exceptions so that it can: - log exceptions - send emails when auto-recharge fails - not retry if this is run as a background task @@ -49,50 +47,49 @@ def run_auto_recharge_gracefully(user: AppUser): Meant to be used in conjunction with should_attempt_auto_recharge """ try: - with redis_lock(f"gooey/auto_recharge_user/v1/{user.uid}"): - _auto_recharge_user(user) + with redis_lock(f"gooey/auto_recharge_user/v1/{org.org_id}"): + _auto_recharge_org(org) except AutoRechargeCooldownException as e: logger.info( - f"Rejected auto-recharge because auto-recharge is in cooldown period for user" - f"{user=}, {e=}" + f"Rejected auto-recharge because auto-recharge is in cooldown period for org" + f"{org=}, {e=}" ) except MonthlyBudgetReachedException as e: - send_monthly_budget_reached_email(user) + send_monthly_budget_reached_email(org) logger.info( f"Rejected auto-recharge because user has reached monthly budget" - f"{user=}, spending=${e.spending}, budget=${e.budget}" + f"{org=}, spending=${e.spending}, budget=${e.budget}" ) except Exception as e: traceback.print_exc() sentry_sdk.capture_exception(e) - send_auto_recharge_failed_email(user) -def _auto_recharge_user(user: AppUser): +def _auto_recharge_org(org: Org): """ Returns whether a charge was attempted """ from payments.webhooks import StripeWebhookHandler assert ( - user.subscription.payment_provider == PaymentProvider.STRIPE + org.subscription.payment_provider == PaymentProvider.STRIPE ), "Auto recharge is only supported with Stripe" # check for monthly budget - dollars_spent = user.get_dollars_spent_this_month() + dollars_spent = org.get_dollars_spent_this_month() if ( - dollars_spent + user.subscription.auto_recharge_topup_amount - > user.subscription.monthly_spending_budget + dollars_spent + org.subscription.auto_recharge_topup_amount + > org.subscription.monthly_spending_budget ): raise MonthlyBudgetReachedException( "Performing this top-up would exceed your monthly recharge budget", - budget=user.subscription.monthly_spending_budget, + budget=org.subscription.monthly_spending_budget, spending=dollars_spent, ) try: - invoice = user.subscription.stripe_get_or_create_auto_invoice( - amount_in_dollars=user.subscription.auto_recharge_topup_amount, + invoice = org.subscription.stripe_get_or_create_auto_invoice( + amount_in_dollars=org.subscription.auto_recharge_topup_amount, metadata_key="auto_recharge", ) except Exception as e: @@ -106,9 +103,9 @@ def _auto_recharge_user(user: AppUser): # get default payment method and attempt payment assert invoice.status == "open" # sanity check - pm = user.subscription.stripe_get_default_payment_method() + pm = org.subscription.stripe_get_default_payment_method() if not pm: - logger.warning(f"{user} has no default payment method, cannot auto-recharge") + logger.warning(f"{org} has no default payment method, cannot auto-recharge") return try: @@ -119,4 +116,6 @@ def _auto_recharge_user(user: AppUser): ) from e else: assert invoice_data.paid - StripeWebhookHandler.handle_invoice_paid(uid=user.uid, invoice=invoice_data) + StripeWebhookHandler.handle_invoice_paid( + org_id=org.org_id, invoice=invoice_data + ) diff --git a/payments/tasks.py b/payments/tasks.py index 2070db714..c98b8c12e 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -29,6 +29,7 @@ def send_monthly_spending_notification_email(id: int): "monthly_spending_notification_threshold_email.html" ).render( user=owner.user, + org=org, account_url=get_app_route_url(account_route), ), ) @@ -40,43 +41,27 @@ def send_monthly_spending_notification_email(id: int): org.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) -def send_monthly_budget_reached_email(user: AppUser): +def send_monthly_budget_reached_email(org: Org): from routers.account import account_route - if not user.email: - return + for owner in org.get_owners(): + if not owner.user.email: + continue - email_body = templates.get_template("monthly_budget_reached_email.html").render( - user=user, - account_url=get_app_route_url(account_route), - ) - send_email_via_postmark( - from_address=settings.SUPPORT_EMAIL, - to_address=user.email, - subject="[Gooey.AI] Monthly Budget Reached", - html_body=email_body, - ) + email_body = templates.get_template("monthly_budget_reached_email.html").render( + user=owner.user, + org=org, + account_url=get_app_route_url(account_route), + ) + send_email_via_postmark( + from_address=settings.SUPPORT_EMAIL, + to_address=owner.user.email, + subject="[Gooey.AI] Monthly Budget Reached", + html_body=email_body, + ) # IMPORTANT: always use update_fields=... when updating subscription # info. We don't want to overwrite other changes made to subscription # during the same time - user.subscription.monthly_budget_email_sent_at = timezone.now() - user.subscription.save(update_fields=["monthly_budget_email_sent_at"]) - - -def send_auto_recharge_failed_email(user: AppUser): - from routers.account import account_route - - if not user.email: - return - - email_body = templates.get_template("auto_recharge_failed_email.html").render( - user=user, - account_url=get_app_route_url(account_route), - ) - send_email_via_postmark( - from_address=settings.SUPPORT_EMAIL, - to_address=user.email, - subject="[Gooey.AI] Auto-Recharge failed", - html_body=email_body, - ) + org.subscription.monthly_budget_email_sent_at = timezone.now() + org.subscription.save(update_fields=["monthly_budget_email_sent_at"]) diff --git a/routers/paypal.py b/routers/paypal.py index 48e65a623..86f93ce48 100644 --- a/routers/paypal.py +++ b/routers/paypal.py @@ -126,7 +126,8 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json): if plan.deprecated: return JSONResponse({"error": "Deprecated plan"}, status_code=400) - if request.user.subscription and request.user.subscription.is_paid(): + org, _ = request.user.get_or_create_personal_org() + if org.subscription and org.subscription.is_paid(): return JSONResponse( {"error": "User already has an active subscription"}, status_code=400 ) @@ -134,7 +135,7 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json): paypal_plan_info = plan.get_paypal_plan() pp_subscription = paypal.Subscription.create( plan_id=paypal_plan_info["plan_id"], - custom_id=request.user.uid, + custom_id=org.org_id, plan=paypal_plan_info.get("plan", {}), application_context={ "brand_name": "Gooey.AI", diff --git a/templates/auto_recharge_failed_email.html b/templates/auto_recharge_failed_email.html deleted file mode 100644 index 601fab5d8..000000000 --- a/templates/auto_recharge_failed_email.html +++ /dev/null @@ -1,15 +0,0 @@ -

- Hey, {{ user.first_name() }}! -

- -

- Your Gooey.AI account balance is below your threshold. - An auto-recharge was attempted but failed because {{ reason }}. - Please visit your billing settings. -

- -

- Best Wishes, -
- Gooey.AI Team -

diff --git a/templates/monthly_budget_reached_email.html b/templates/monthly_budget_reached_email.html index 0171e320d..6e467a086 100644 --- a/templates/monthly_budget_reached_email.html +++ b/templates/monthly_budget_reached_email.html @@ -1,6 +1,6 @@ -{% set dollars_spent = user.get_dollars_spent_this_month() %} -{% set monthly_budget = user.subscription.monthly_spending_budget %} -{% set threshold = user.subscription.auto_recharge_balance_threshold %} +{% set dollars_spent = org.get_dollars_spent_this_month() %} +{% set monthly_budget = org.subscription.monthly_spending_budget %} +{% set threshold = org.subscription.auto_recharge_balance_threshold %}

Hey, {{ user.first_name() }}! @@ -18,7 +18,7 @@

    -
  • Credit Balance: {{ user.balance }} credits
  • +
  • Credit Balance: {{ org.balance }} credits
  • Monthly Budget: ${{ monthly_budget }}
  • Spending this month: ${{ dollars_spent }}
diff --git a/templates/monthly_spending_notification_threshold_email.html b/templates/monthly_spending_notification_threshold_email.html index ddf54e223..13be0fae5 100644 --- a/templates/monthly_spending_notification_threshold_email.html +++ b/templates/monthly_spending_notification_threshold_email.html @@ -1,4 +1,4 @@ -{% set dollars_spent = user.get_dollars_spent_this_month() %} +{% set dollars_spent = org.get_dollars_spent_this_month() %}

Hi, {{ user.first_name() }}! @@ -6,11 +6,11 @@

Your spend on Gooey.AI so far this month is ${{ dollars_spent }}, exceeding your notification threshold - of ${{ user.subscription.monthly_spending_notification_threshold }}. + of ${{ org.subscription.monthly_spending_notification_threshold }}.

- Your monthly budget is ${{ user.subscription.monthly_spending_budget }}, after which auto-recharge will be + Your monthly budget is ${{ org.subscription.monthly_spending_budget }}, after which auto-recharge will be paused and all runs / API calls will be rejected.

From 4a598a1cb3028641d509f4bc6e0dc06af18efd00 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 2 Sep 2024 18:37:56 +0530 Subject: [PATCH 064/110] fix: s/user.subscription/org.subscription in billing.py --- daras_ai_v2/billing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index e5a5fa27e..7c99a53bf 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -325,7 +325,7 @@ def change_subscription(org: "Org", new_plan: PricingPlan, **kwargs): if not new_plan.supports_paypal(): gui.error(f"Paypal subscription not available for {new_plan}") - subscription = paypal.Subscription.retrieve(user.subscription.external_id) + subscription = paypal.Subscription.retrieve(org.subscription.external_id) paypal_plan_info = new_plan.get_paypal_plan() approval_url = subscription.update_plan( plan_id=paypal_plan_info["plan_id"], From 5039bea0a1854f75ad78fae815fd05b7ecb20df2 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 2 Sep 2024 19:08:19 +0530 Subject: [PATCH 065/110] fix type for set_org_subscription --- payments/models.py | 6 +++++- payments/webhooks.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/payments/models.py b/payments/models.py index f647bd5a6..ff5be4f69 100644 --- a/payments/models.py +++ b/payments/models.py @@ -89,7 +89,11 @@ def __str__(self): return ret def full_clean( - self, amount: int = None, charged_amount: int = None, *args, **kwargs + self, + amount: int | None = None, + charged_amount: int | None = None, + *args, + **kwargs, ): if self.auto_recharge_enabled: if amount is None: diff --git a/payments/webhooks.py b/payments/webhooks.py index cedd2b0b3..36f0499c7 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -222,8 +222,8 @@ def set_org_subscription( plan: PricingPlan, provider: PaymentProvider | None, external_id: str | None, - amount: int = None, - charged_amount: int = None, + amount: int | None = None, + charged_amount: int | None = None, ) -> Subscription: with transaction.atomic(): org = Org.objects.get_or_create_from_org_id(org_id)[0] From e5385742fd8d5c7f08ee98e6ac0e5605b06395c5 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Mon, 2 Sep 2024 19:14:47 +0530 Subject: [PATCH 066/110] fix paypal handle_invoice_paid: uid -> org_id --- routers/paypal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/routers/paypal.py b/routers/paypal.py index 86f93ce48..3771481cf 100644 --- a/routers/paypal.py +++ b/routers/paypal.py @@ -177,7 +177,7 @@ def _handle_invoice_paid(order_id: str): purchase_unit = order["purchase_units"][0] payment_capture = purchase_unit["payments"]["captures"][0] add_balance_for_payment( - uid=payment_capture["custom_id"], + org_id=payment_capture["custom_id"], amount=int(purchase_unit["items"][0]["quantity"]), invoice_id=payment_capture["id"], payment_provider=PaymentProvider.PAYPAL, From c12f3e8a9187d509307cab1c2fbd4ebf4ee7f5ec Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 3 Sep 2024 01:10:18 +0530 Subject: [PATCH 067/110] update yt-dlp --- poetry.lock | 12 ++++++------ pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9950129f5..a7eb2ca4c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6491,13 +6491,13 @@ multidict = ">=4.0" [[package]] name = "yt-dlp" -version = "2024.7.2" +version = "2024.8.6" description = "A feature-rich command-line audio/video downloader" optional = false python-versions = ">=3.8" files = [ - {file = "yt_dlp-2024.7.2-py3-none-any.whl", hash = "sha256:4f76b48244c783e6ac06e8d7627bcf62cbeb4f6d79ba7e3cfc8249e680d4e691"}, - {file = "yt_dlp-2024.7.2.tar.gz", hash = "sha256:2b0c86b579d4a044eaf3c4b00e3d7b24d82e6e26869fa11c288ea4395b387f41"}, + {file = "yt_dlp-2024.8.6-py3-none-any.whl", hash = "sha256:ab507ff600bd9269ad4d654e309646976778f0e243eaa2f6c3c3214278bb2922"}, + {file = "yt_dlp-2024.8.6.tar.gz", hash = "sha256:e8551f26bc8bf67b99c12373cc87ed2073436c3437e53290878d0f4b4bb1f663"}, ] [package.dependencies] @@ -6511,8 +6511,8 @@ urllib3 = ">=1.26.17,<3" websockets = ">=12.0" [package.extras] -build = ["build", "hatchling", "pip", "setuptools", "wheel"] -curl-cffi = ["curl-cffi (==0.5.10)"] +build = ["build", "hatchling", "pip", "setuptools (>=71.0.2)", "wheel"] +curl-cffi = ["curl-cffi (==0.5.10)", "curl-cffi (>=0.5.10,<0.6.dev0 || ==0.7.*)"] dev = ["autopep8 (>=2.0,<3.0)", "pre-commit", "pytest (>=8.1,<9.0)", "ruff (>=0.5.0,<0.6.0)"] py2exe = ["py2exe (>=0.12)"] pyinstaller = ["pyinstaller (>=6.7.0)"] @@ -6538,4 +6538,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "5834cb5e676e83b492e8aec5d9efa15bed653848f6d356c139917e1a1b01e872" +content-hash = "ac4c7f52c5bb619909f5c1ed8c653aeeb3aa0275e542d716724c5e6ebada2f37" diff --git a/pyproject.toml b/pyproject.toml index 90859e035..1d66681df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ oauth2client = "^4.1.3" tiktoken = "^0.7.0" google-cloud-translate = "^3.12.0" google-cloud-speech = "^2.21.0" -yt-dlp = "^2024.7.2" +yt-dlp = "^2024.8.6" Jinja2 = "^3.1.2" Django = "^4.2" django-phonenumber-field = { extras = ["phonenumberslite"], version = "^7.0.2" } From 34cb49775e8d6684d1b1c5a5a5f5d50f40763733 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 4 Sep 2024 17:55:12 +0530 Subject: [PATCH 068/110] Add logging for InvalidRequestError in billing module --- daras_ai_v2/billing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 0254a89fc..c96bff8c7 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -2,6 +2,7 @@ import sentry_sdk import stripe from django.core.exceptions import ValidationError +from loguru import logger from app_users.models import AppUser, PaymentProvider from daras_ai_v2 import icons, settings, paypal @@ -239,6 +240,7 @@ def _render_plan_action_button( except (stripe.CardError, stripe.InvalidRequestError) as e: if isinstance(e, stripe.InvalidRequestError): sentry_sdk.capture_exception(e) + logger.warning(e) # only handle error if it's related to mandates # cancel current subscription & redirect user to new subscription page From ed05dad3bbe5050474f90fd04ef92e621a16f9b6 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 4 Sep 2024 18:41:18 +0530 Subject: [PATCH 069/110] Rename org -> workspace, remove workspace.id --- app_users/admin.py | 2 +- ..._appusertransaction_workspace_and_more.py} | 8 +- app_users/models.py | 21 +- app_users/tasks.py | 34 ++- ...g.py => 0082_savedrun_billed_workspace.py} | 8 +- bots/models.py | 4 +- celeryapp/tasks.py | 39 ++- daras_ai_v2/base.py | 36 ++- daras_ai_v2/billing.py | 187 ++++++++------ daras_ai_v2/send_email.py | 30 ++- daras_ai_v2/settings.py | 12 +- orgs/admin.py | 111 -------- ...0002_alter_org_unique_together_and_more.py | 35 --- ...e_domain_name_when_not_deleted_and_more.py | 36 --- ..._org_is_paying_org_is_personal_and_more.py | 45 ---- .../0005_org_unique_personal_org_per_user.py | 17 -- orgs/signals.py | 49 ---- payments/auto_recharge.py | 54 ++-- payments/models.py | 14 +- payments/tasks.py | 29 ++- payments/webhooks.py | 82 +++--- routers/account.py | 36 +-- routers/api.py | 1 + routers/paypal.py | 8 +- scripts/migrate_billed_org_for_saved_runs.py | 18 -- ...migrate_billed_workspace_for_saved_runs.py | 23 ++ scripts/migrate_orgs_from_appusers.py | 50 ---- scripts/migrate_workspace_from_appusers.py | 52 ++++ templates/monthly_budget_reached_email.html | 8 +- ...spending_notification_threshold_email.html | 6 +- .../org_invitation_auto_accepted_email.html | 10 +- templates/org_invitation_email.html | 4 +- {orgs => workspaces}/__init__.py | 0 workspaces/admin.py | 155 +++++++++++ {orgs => workspaces}/apps.py | 4 +- .../migrations/0001_initial.py | 55 ++-- .../migrations/0002_alter_workspace_logo.py | 19 ++ {orgs => workspaces}/migrations/__init__.py | 0 {orgs => workspaces}/models.py | 223 +++++++++------- workspaces/signals.py | 50 ++++ {orgs => workspaces}/tasks.py | 24 +- {orgs => workspaces}/tests.py | 0 {orgs => workspaces}/views.py | 244 ++++++++++-------- 43 files changed, 958 insertions(+), 885 deletions(-) rename app_users/migrations/{0020_appusertransaction_org_alter_appusertransaction_user.py => 0020_appusertransaction_workspace_and_more.py} (75%) rename bots/migrations/{0082_savedrun_billed_org.py => 0082_savedrun_billed_workspace.py} (74%) delete mode 100644 orgs/admin.py delete mode 100644 orgs/migrations/0002_alter_org_unique_together_and_more.py delete mode 100644 orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py delete mode 100644 orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py delete mode 100644 orgs/migrations/0005_org_unique_personal_org_per_user.py delete mode 100644 orgs/signals.py delete mode 100644 scripts/migrate_billed_org_for_saved_runs.py create mode 100644 scripts/migrate_billed_workspace_for_saved_runs.py delete mode 100644 scripts/migrate_orgs_from_appusers.py create mode 100644 scripts/migrate_workspace_from_appusers.py rename {orgs => workspaces}/__init__.py (100%) create mode 100644 workspaces/admin.py rename {orgs => workspaces}/apps.py (74%) rename {orgs => workspaces}/migrations/0001_initial.py (56%) create mode 100644 workspaces/migrations/0002_alter_workspace_logo.py rename {orgs => workspaces}/migrations/__init__.py (100%) rename {orgs => workspaces}/models.py (69%) create mode 100644 workspaces/signals.py rename {orgs => workspaces}/tasks.py (67%) rename {orgs => workspaces}/tests.py (100%) rename {orgs => workspaces}/views.py (64%) diff --git a/app_users/admin.py b/app_users/admin.py index f433f86d2..ba05b10e1 100644 --- a/app_users/admin.py +++ b/app_users/admin.py @@ -216,7 +216,7 @@ class AppUserTransactionAdmin(admin.ModelAdmin): autocomplete_fields = ["user"] list_display = [ "invoice_id", - "org", + "workspace", "user", "amount", "dollar_amount", diff --git a/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py b/app_users/migrations/0020_appusertransaction_workspace_and_more.py similarity index 75% rename from app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py rename to app_users/migrations/0020_appusertransaction_workspace_and_more.py index b3e80c708..43b2d32d2 100644 --- a/app_users/migrations/0020_appusertransaction_org_alter_appusertransaction_user.py +++ b/app_users/migrations/0020_appusertransaction_workspace_and_more.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.7 on 2024-08-13 14:34 +# Generated by Django 4.2.7 on 2024-09-02 14:07 from django.db import migrations, models import django.db.models.deletion @@ -7,15 +7,15 @@ class Migration(migrations.Migration): dependencies = [ - ('orgs', '0005_org_unique_personal_org_per_user'), + ('workspaces', '0001_initial'), ('app_users', '0019_alter_appusertransaction_reason'), ] operations = [ migrations.AddField( model_name='appusertransaction', - name='org', - field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='transactions', to='orgs.org'), + name='workspace', + field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='transactions', to='workspaces.workspace'), ), migrations.AlterField( model_name='appusertransaction', diff --git a/app_users/models.py b/app_users/models.py index 68193c441..66ce232f3 100644 --- a/app_users/models.py +++ b/app_users/models.py @@ -16,7 +16,7 @@ from payments.plans import PricingPlan if typing.TYPE_CHECKING: - from orgs.models import Org + from workspaces.models import Workspace class AppUserQuerySet(models.QuerySet): @@ -249,16 +249,13 @@ def copy_from_firebase_user(self, user: auth.UserRecord) -> "AppUser": return self - def get_or_create_personal_org(self) -> tuple["Org", bool]: - from orgs.models import Org, OrgMembership + def get_or_create_personal_workspace(self) -> tuple["Workspace", bool]: + from workspaces.models import Workspace - org_membership: OrgMembership | None = self.org_memberships.filter( - org__is_personal=True, org__created_by=self - ).first() - if org_membership: - return org_membership.org, False - else: - return Org.objects.migrate_from_appuser(self), True + try: + return Workspace.objects.get(is_personal=True, created_by=self), False + except Workspace.DoesNotExist: + return Workspace.objects.migrate_from_appuser(self), True def get_or_create_stripe_customer(self) -> stripe.Customer: customer = self.search_stripe_customer() @@ -322,8 +319,8 @@ class AppUserTransaction(models.Model): related_name="transactions", null=True, ) - org = models.ForeignKey( - "orgs.Org", + workspace = models.ForeignKey( + "workspaces.Workspace", on_delete=models.SET_NULL, related_name="transactions", null=True, diff --git a/app_users/tasks.py b/app_users/tasks.py index b1d893196..9bd85eb7a 100644 --- a/app_users/tasks.py +++ b/app_users/tasks.py @@ -3,16 +3,16 @@ from app_users.models import PaymentProvider, TransactionReason from celeryapp.celeryconfig import app -from payments.models import Subscription from payments.plans import PricingPlan -from payments.webhooks import set_org_subscription +from payments.webhooks import set_workspace_subscription +from workspaces.models import Workspace @app.task def save_stripe_default_payment_method( *, + workspace_id_or_uid: int | str, payment_intent_id: str, - org_id: str, amount: int, charged_amount: int, reason: TransactionReason, @@ -36,16 +36,24 @@ def save_stripe_default_payment_method( invoice_settings=dict(default_payment_method=pm), ) - # if user doesn't already have a active billing/autorecharge info, so we don't need to do anything - # set user's subscription to the free plan - if ( - reason == TransactionReason.ADDON - and not Subscription.objects.filter( - org__org_id=org_id, payment_provider__isnull=False - ).exists() - ): - set_org_subscription( - org_id=org_id, + # if user already has a subscription with payment info, we do nothing + # otherwise, we set the user's subscription to the free plan + if reason == TransactionReason.ADDON: + try: + workspace = Workspace.objects.select_related("subscription").get( + int(workspace_id_or_uid) + ) + except (ValueError, Workspace.DoesNotExist): + workspace, _ = Workspace.objects.get_or_create_from_uid(workspace_id_or_uid) + + if workspace.subscription and ( + workspace.subscription.is_paid() or workspace.subscription.payment_provider + ): + # already has a subscription + return + + set_workspace_subscription( + workspace_id_or_uid=workspace.id, plan=PricingPlan.STARTER, provider=PaymentProvider.STRIPE, external_id=None, diff --git a/bots/migrations/0082_savedrun_billed_org.py b/bots/migrations/0082_savedrun_billed_workspace.py similarity index 74% rename from bots/migrations/0082_savedrun_billed_org.py rename to bots/migrations/0082_savedrun_billed_workspace.py index 208f46dcc..502c15269 100644 --- a/bots/migrations/0082_savedrun_billed_org.py +++ b/bots/migrations/0082_savedrun_billed_workspace.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.7 on 2024-08-30 08:10 +# Generated by Django 4.2.7 on 2024-09-02 14:08 from django.db import migrations, models import django.db.models.deletion @@ -7,14 +7,14 @@ class Migration(migrations.Migration): dependencies = [ - ('orgs', '0005_org_unique_personal_org_per_user'), + ('workspaces', '0001_initial'), ('bots', '0081_remove_conversation_bots_conver_bot_int_73ac7b_idx_and_more'), ] operations = [ migrations.AddField( model_name='savedrun', - name='billed_org', - field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='billed_runs', to='orgs.org'), + name='billed_workspace', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='billed_runs', to='workspaces.workspace'), ), ] diff --git a/bots/models.py b/bots/models.py index a6163ee1c..fcdd345cc 100644 --- a/bots/models.py +++ b/bots/models.py @@ -212,8 +212,8 @@ class SavedRun(models.Model): ) run_id = models.CharField(max_length=128, default=None, null=True, blank=True) uid = models.CharField(max_length=128, default=None, null=True, blank=True) - billed_org = models.ForeignKey( - "orgs.Org", + billed_workspace = models.ForeignKey( + "workspaces.Workspace", on_delete=models.SET_NULL, null=True, blank=True, diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index c3ae75b52..d3a54549b 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -29,6 +29,10 @@ run_auto_recharge_gracefully, ) +if typing.TYPE_CHECKING: + from workspaces.models import Workspace + + DEFAULT_RUN_STATUS = "Running..." @@ -121,15 +125,14 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False @app.task def post_runner_tasks(saved_run_id: int): sr = SavedRun.objects.get(id=saved_run_id) - user = AppUser.objects.get(uid=sr.uid) if not sr.is_api_call: send_email_on_completion(sr) - if should_attempt_auto_recharge(user): - run_auto_recharge_gracefully(user) + if should_attempt_auto_recharge(sr.billed_workspace): + run_auto_recharge_gracefully(sr.billed_workspace) - run_low_balance_email_check(user) + run_low_balance_email_check(sr.billed_workspace) def err_msg_for_exc(e: Exception): @@ -158,15 +161,18 @@ def err_msg_for_exc(e: Exception): return f"{type(e).__name__}: {e}" -def run_low_balance_email_check(user: AppUser): +def run_low_balance_email_check(workspace: Workspace): # don't send email if feature is disabled if not settings.LOW_BALANCE_EMAIL_ENABLED: return # don't send email if user is not paying or has enough balance - if not user.is_paying or user.balance > settings.LOW_BALANCE_EMAIL_CREDITS: + if ( + not workspace.is_paying + or workspace.balance > settings.LOW_BALANCE_EMAIL_CREDITS + ): return last_purchase = ( - AppUserTransaction.objects.filter(user=user, amount__gt=0) + AppUserTransaction.objects.filter(workspace=workspace, amount__gt=0) .order_by("-created_at") .first() ) @@ -176,22 +182,27 @@ def run_low_balance_email_check(user: AppUser): # send email if user has not been sent email in last X days or last purchase was after last email sent if ( # user has not been sent any email - not user.low_balance_email_sent_at + not workspace.low_balance_email_sent_at # user was sent email before X days - or (user.low_balance_email_sent_at < email_date_cutoff) + or (workspace.low_balance_email_sent_at < email_date_cutoff) # user has made a purchase after last email sent - or (last_purchase and last_purchase.created_at > user.low_balance_email_sent_at) + or ( + last_purchase + and last_purchase.created_at > workspace.low_balance_email_sent_at + ) ): # calculate total credits consumed in last X days total_credits_consumed = abs( AppUserTransaction.objects.filter( - user=user, amount__lt=0, created_at__gte=email_date_cutoff + workspace=workspace, amount__lt=0, created_at__gte=email_date_cutoff ).aggregate(Sum("amount"))["amount__sum"] or 0 ) - send_low_balance_email(user=user, total_credits_consumed=total_credits_consumed) - user.low_balance_email_sent_at = timezone.now() - user.save(update_fields=["low_balance_email_sent_at"]) + send_low_balance_email( + workspace=workspace, total_credits_consumed=total_credits_consumed + ) + workspace.low_balance_email_sent_at = timezone.now() + workspace.save(update_fields=["low_balance_email_sent_at"]) def send_email_on_completion(sr: SavedRun): diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index f37a284bb..a7c61367d 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -83,6 +83,10 @@ from routers.account import AccountTabs from routers.root import RecipeTabs +if typing.TYPE_CHECKING: + from workspaces.models import Workspace + + DEFAULT_META_IMG = ( # Small "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ec2100aa-1f6e-11ef-ba0b-02420a000159/Main.jpg" @@ -571,7 +575,7 @@ def _render_publish_modal( gui.radio( "", options=[ - 'Anyone at my org (coming soon)' + 'Anyone at my workspace (coming soon)' ], disabled=True, checked_by_default=False, @@ -1204,6 +1208,12 @@ def create_published_run( visibility=visibility, ) + def get_current_workspace(self) -> "Workspace": + assert self.request.user + + workspace, _ = self.request.user.get_or_create_personal_workspace() + return workspace + def duplicate_published_run( self, published_run: PublishedRun, @@ -1599,7 +1609,9 @@ def submit_and_redirect(self): def on_submit(self): try: - sr = self.create_new_run(enable_rate_limits=True) + sr = self.create_new_run( + enable_rate_limits=True, billed_workspace=self.get_current_workspace() + ) except ValidationError as e: gui.session_state[StateKeys.run_status] = None gui.session_state[StateKeys.error_msg] = str(e) @@ -1612,7 +1624,7 @@ def on_submit(self): return sr def should_submit_after_login(self) -> bool: - return ( + return bool( gui.get_query_params().get(SUBMIT_AFTER_LOGIN_Q) and self.request and self.request.user @@ -2084,18 +2096,18 @@ def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict): assert self.request, "request must be set to check credits" assert self.request.user, "request.user must be set to check credits" - user = self.request.user price = self.get_price_roundoff(state) + workspace, _ = self.request.user.get_or_create_personal_workspace() - if user.balance >= price: + if workspace.balance >= price: return - if should_attempt_auto_recharge(user): + if should_attempt_auto_recharge(workspace): yield "Low balance detected. Recharging..." - run_auto_recharge_gracefully(user) - user.refresh_from_db() + run_auto_recharge_gracefully(workspace) + workspace.refresh_from_db() - if user.balance >= price: + if workspace.balance >= price: return raise InsufficientCredits(self.request.user, sr) @@ -2106,8 +2118,8 @@ def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]: ), "request.user must be set to deduct credits" amount = self.get_price_roundoff(state) - org, _ = self.request.user.get_or_create_personal_org() - txn = org.add_balance(-amount, f"gooey_in_{uuid.uuid1()}") + workspace, _ = self.request.user.get_or_create_personal_workspace() + txn = workspace.add_balance(-amount, f"gooey_in_{uuid.uuid1()}") return txn, amount def get_price_roundoff(self, state: dict) -> int: @@ -2204,7 +2216,7 @@ def get_cost_note(self) -> str | None: @classmethod def is_user_admin(cls, user: AppUser) -> bool: email = user.email - return email and email in settings.ADMIN_EMAILS + return bool(email and email in settings.ADMIN_EMAILS) def is_current_user_admin(self) -> bool: if not self.request or not self.request.user: diff --git a/daras_ai_v2/billing.py b/daras_ai_v2/billing.py index 7c99a53bf..d8b3bd44b 100644 --- a/daras_ai_v2/billing.py +++ b/daras_ai_v2/billing.py @@ -13,41 +13,41 @@ from daras_ai_v2.user_date_widgets import render_local_date_attrs from payments.models import PaymentMethodSummary from payments.plans import PricingPlan -from payments.webhooks import StripeWebhookHandler, set_org_subscription +from payments.webhooks import StripeWebhookHandler, set_workspace_subscription from scripts.migrate_existing_subscriptions import available_subscriptions if typing.TYPE_CHECKING: - from orgs.models import Org + from workspaces.models import Workspace rounded_border = "w-100 border shadow-sm rounded py-4 px-3" -def billing_page(org: "Org"): +def billing_page(workspace: "Workspace"): render_payments_setup() - if org.subscription and org.subscription.is_paid(): - render_current_plan(org) + if workspace.subscription and workspace.subscription.is_paid(): + render_current_plan(workspace) with gui.div(className="my-5"): - render_credit_balance(org) + render_credit_balance(workspace) with gui.div(className="my-5"): - selected_payment_provider = render_all_plans(org) + selected_payment_provider = render_all_plans(workspace) with gui.div(className="my-5"): - render_addon_section(org, selected_payment_provider) + render_addon_section(workspace, selected_payment_provider) - if org.subscription: - if org.subscription.payment_provider == PaymentProvider.STRIPE: + if workspace.subscription: + if workspace.subscription.payment_provider == PaymentProvider.STRIPE: with gui.div(className="my-5"): - render_auto_recharge_section(org) + render_auto_recharge_section(workspace) with gui.div(className="my-5"): - render_payment_information(org) + render_payment_information(workspace) with gui.div(className="my-5"): - render_billing_history(org) + render_billing_history(workspace) def render_payments_setup(): @@ -61,10 +61,10 @@ def render_payments_setup(): ) -def render_current_plan(org: "Org"): - plan = PricingPlan.from_sub(org.subscription) - if org.subscription.payment_provider: - provider = PaymentProvider(org.subscription.payment_provider) +def render_current_plan(workspace: "Workspace"): + plan = PricingPlan.from_sub(workspace.subscription) + if workspace.subscription.payment_provider: + provider = PaymentProvider(workspace.subscription.payment_provider) else: provider = None @@ -82,7 +82,7 @@ def render_current_plan(org: "Org"): with right, gui.div(className="d-flex align-items-center gap-1"): if provider and ( next_invoice_ts := gui.run_in_thread( - org.subscription.get_next_invoice_timestamp, cache=True + workspace.subscription.get_next_invoice_timestamp, cache=True ) ): gui.html("Next invoice on ") @@ -118,17 +118,17 @@ def render_current_plan(org: "Org"): ) -def render_credit_balance(org: "Org"): - gui.write(f"## Credit Balance: {org.balance:,}") +def render_credit_balance(workspace: "Workspace"): + gui.write(f"## Credit Balance: {workspace.balance:,}") gui.caption( "Every time you submit a workflow or make an API call, we deduct credits from your account." ) -def render_all_plans(org: "Org") -> PaymentProvider: +def render_all_plans(workspace: "Workspace") -> PaymentProvider: current_plan = ( - PricingPlan.from_sub(org.subscription) - if org.subscription + PricingPlan.from_sub(workspace.subscription) + if workspace.subscription else PricingPlan.STARTER ) all_plans = [plan for plan in PricingPlan if not plan.deprecated] @@ -136,8 +136,8 @@ def render_all_plans(org: "Org") -> PaymentProvider: gui.write("## All Plans") plans_div = gui.div(className="mb-1") - if org.subscription and org.subscription.payment_provider: - selected_payment_provider = org.subscription.payment_provider + if workspace.subscription and workspace.subscription.payment_provider: + selected_payment_provider = workspace.subscription.payment_provider else: with gui.div(): selected_payment_provider = PaymentProvider[ @@ -155,7 +155,7 @@ def _render_plan(plan: PricingPlan): ): _render_plan_details(plan) _render_plan_action_button( - org=org, + workspace=workspace, plan=plan, current_plan=current_plan, payment_provider=selected_payment_provider, @@ -193,7 +193,7 @@ def _render_plan_details(plan: PricingPlan): def _render_plan_action_button( - org: "Org", + workspace: "Workspace", plan: PricingPlan, current_plan: PricingPlan, payment_provider: PaymentProvider | None, @@ -207,10 +207,13 @@ def _render_plan_action_button( className=btn_classes + " btn btn-theme btn-primary", ): gui.html("Contact Us") - elif org.subscription and org.subscription.plan == PricingPlan.ENTERPRISE.db_value: + elif ( + workspace.subscription + and workspace.subscription.plan == PricingPlan.ENTERPRISE.db_value + ): # don't show upgrade/downgrade buttons for enterprise customers return - elif org.subscription and org.subscription.is_paid(): + elif workspace.subscription and workspace.subscription.is_paid(): # subscription exists, show upgrade/downgrade button if plan.credits > current_plan.credits: modal, confirmed = confirm_modal( @@ -232,7 +235,7 @@ def _render_plan_action_button( modal.open() if confirmed: change_subscription( - org, + workspace, plan, # when upgrading, charge the full new amount today: https://docs.stripe.com/billing/subscriptions/billing-cycle#reset-the-billing-cycle-to-the-current-time billing_cycle_anchor="now", @@ -254,11 +257,11 @@ def _render_plan_action_button( ): modal.open() if confirmed: - change_subscription(org, plan) + change_subscription(workspace, plan) else: assert payment_provider is not None # for sanity _render_create_subscription_button( - org=org, + workspace=workspace, plan=plan, payment_provider=payment_provider, ) @@ -266,13 +269,13 @@ def _render_plan_action_button( def _render_create_subscription_button( *, - org: "Org", + workspace: "Workspace", plan: PricingPlan, payment_provider: PaymentProvider, ): match payment_provider: case PaymentProvider.STRIPE: - render_stripe_subscription_button(org=org, plan=plan) + render_stripe_subscription_button(workspace=workspace, plan=plan) case PaymentProvider.PAYPAL: render_paypal_subscription_button(plan=plan) @@ -284,27 +287,29 @@ def fmt_price(plan: PricingPlan) -> str: return "Free" -def change_subscription(org: "Org", new_plan: PricingPlan, **kwargs): +def change_subscription(workspace: "Workspace", new_plan: PricingPlan, **kwargs): from routers.account import account_route from routers.account import payment_processing_route - current_plan = PricingPlan.from_sub(org.subscription) + current_plan = PricingPlan.from_sub(workspace.subscription) if new_plan == current_plan: raise gui.RedirectException(get_app_route_url(account_route), status_code=303) if new_plan == PricingPlan.STARTER: - org.subscription.cancel() + workspace.subscription.cancel() raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 ) - match org.subscription.payment_provider: + match workspace.subscription.payment_provider: case PaymentProvider.STRIPE: if not new_plan.supports_stripe(): gui.error(f"Stripe subscription not available for {new_plan}") - subscription = stripe.Subscription.retrieve(org.subscription.external_id) + subscription = stripe.Subscription.retrieve( + workspace.subscription.external_id + ) stripe.Subscription.modify( subscription.id, items=[ @@ -325,7 +330,9 @@ def change_subscription(org: "Org", new_plan: PricingPlan, **kwargs): if not new_plan.supports_paypal(): gui.error(f"Paypal subscription not available for {new_plan}") - subscription = paypal.Subscription.retrieve(org.subscription.external_id) + subscription = paypal.Subscription.retrieve( + workspace.subscription.external_id + ) paypal_plan_info = new_plan.get_paypal_plan() approval_url = subscription.update_plan( plan_id=paypal_plan_info["plan_id"], @@ -348,20 +355,22 @@ def payment_provider_radio(**props) -> str | None: ) -def render_addon_section(org: "Org", selected_payment_provider: PaymentProvider): - if org.subscription: +def render_addon_section( + workspace: "Workspace", selected_payment_provider: PaymentProvider +): + if workspace.subscription: gui.write("# Purchase More Credits") else: gui.write("# Purchase Credits") gui.caption(f"Buy more credits. $1 per {settings.ADDON_CREDITS_PER_DOLLAR} credits") - if org.subscription and org.subscription.payment_provider: - provider = PaymentProvider(org.subscription.payment_provider) + if workspace.subscription and workspace.subscription.payment_provider: + provider = PaymentProvider(workspace.subscription.payment_provider) else: provider = selected_payment_provider match provider: case PaymentProvider.STRIPE: - render_stripe_addon_buttons(org) + render_stripe_addon_buttons(workspace) case PaymentProvider.PAYPAL: render_paypal_addon_buttons() @@ -385,8 +394,8 @@ def render_paypal_addon_buttons(): gui.div(id="paypal-result-message") -def render_stripe_addon_buttons(org: "Org"): - if not (org.subscription and org.subscription.payment_provider): +def render_stripe_addon_buttons(workspace: "Workspace"): + if not (workspace.subscription and workspace.subscription.payment_provider): save_pm = gui.checkbox( "Save payment method for future purchases & auto-recharge", value=True ) @@ -394,10 +403,10 @@ def render_stripe_addon_buttons(org: "Org"): save_pm = True for dollat_amt in settings.ADDON_AMOUNT_CHOICES: - render_stripe_addon_button(dollat_amt, org, save_pm) + render_stripe_addon_button(dollat_amt, workspace, save_pm) -def render_stripe_addon_button(dollat_amt: int, org: "Org", save_pm: bool): +def render_stripe_addon_button(dollat_amt: int, workspace: "Workspace", save_pm: bool): modal, confirmed = confirm_modal( title="Purchase Credits", key=f"--addon-modal-{dollat_amt}", @@ -411,14 +420,17 @@ def render_stripe_addon_button(dollat_amt: int, org: "Org", save_pm: bool): ) if gui.button(f"${dollat_amt:,}", type="primary"): - if org.subscription and org.subscription.stripe_get_default_payment_method(): + if ( + workspace.subscription + and workspace.subscription.stripe_get_default_payment_method() + ): modal.open() else: - stripe_addon_checkout_redirect(org, dollat_amt, save_pm) + stripe_addon_checkout_redirect(workspace, dollat_amt, save_pm) if confirmed: success = gui.run_in_thread( - org.subscription.stripe_attempt_addon_purchase, + workspace.subscription.stripe_attempt_addon_purchase, args=[dollat_amt], placeholder="", ) @@ -429,10 +441,12 @@ def render_stripe_addon_button(dollat_amt: int, org: "Org", save_pm: bool): modal.close() else: # fallback to stripe checkout flow if the auto payment failed - stripe_addon_checkout_redirect(org, dollat_amt, save_pm) + stripe_addon_checkout_redirect(workspace, dollat_amt, save_pm) -def stripe_addon_checkout_redirect(org: "Org", dollat_amt: int, save_pm: bool): +def stripe_addon_checkout_redirect( + workspace: "Workspace", dollat_amt: int, save_pm: bool +): from routers.account import account_route from routers.account import payment_processing_route @@ -448,7 +462,7 @@ def stripe_addon_checkout_redirect(org: "Org", dollat_amt: int, save_pm: bool): mode="payment", success_url=get_app_route_url(payment_processing_route), cancel_url=get_app_route_url(account_route), - customer=org.get_or_create_stripe_customer(), + customer=workspace.get_or_create_stripe_customer(), invoice_creation={"enabled": True}, allow_promotion_codes=True, **kwargs, @@ -458,7 +472,7 @@ def stripe_addon_checkout_redirect(org: "Org", dollat_amt: int, save_pm: bool): def render_stripe_subscription_button( *, - org: "Org", + workspace: "Workspace", plan: PricingPlan, ): if not plan.supports_stripe(): @@ -486,30 +500,33 @@ def render_stripe_subscription_button( key=f"--change-sub-{plan.key}", type="primary", ): - if org.subscription and org.subscription.stripe_get_default_payment_method(): + if ( + workspace.subscription + and workspace.subscription.stripe_get_default_payment_method() + ): modal.open() else: - stripe_subscription_create(org=org, plan=plan) + stripe_subscription_create(workspace=workspace, plan=plan) if confirmed: - stripe_subscription_create(org=org, plan=plan) + stripe_subscription_create(workspace=workspace, plan=plan) -def stripe_subscription_create(org: "Org", plan: PricingPlan): +def stripe_subscription_create(workspace: "Workspace", plan: PricingPlan): from routers.account import account_route from routers.account import payment_processing_route - if org.subscription and org.subscription.is_paid(): + if workspace.subscription and workspace.subscription.is_paid(): # sanity check: already subscribed to some plan gui.rerun() # check for existing subscriptions on stripe - customer = org.get_or_create_stripe_customer() + customer = workspace.get_or_create_stripe_customer() for sub in stripe.Subscription.list( customer=customer, status="active", limit=1 ).data: StripeWebhookHandler.handle_subscription_updated( - org_id=org.org_id, stripe_sub=sub + workspace_id_or_uid=workspace.id, stripe_sub=sub ) raise gui.RedirectException( get_app_route_url(payment_processing_route), status_code=303 @@ -517,7 +534,10 @@ def stripe_subscription_create(org: "Org", plan: PricingPlan): # try to directly create the subscription without checkout metadata = {settings.STRIPE_USER_SUBSCRIPTION_METADATA_FIELD: plan.key} - pm = org.subscription and org.subscription.stripe_get_default_payment_method() + pm = ( + workspace.subscription + and workspace.subscription.stripe_get_default_payment_method() + ) line_items = [plan.get_stripe_line_item()] if pm: sub = stripe.Subscription.create( @@ -567,12 +587,12 @@ def render_paypal_subscription_button( ) -def render_payment_information(org: "Org"): - if not org.subscription: +def render_payment_information(workspace: "Workspace"): + if not workspace.subscription: return pm_summary = gui.run_in_thread( - org.subscription.get_payment_method_summary, cache=True + workspace.subscription.get_payment_method_summary, cache=True ) if not pm_summary: return @@ -584,7 +604,7 @@ def render_payment_information(org: "Org"): gui.write("**Pay via**") with col2: provider = PaymentProvider( - org.subscription.payment_provider or PaymentProvider.STRIPE + workspace.subscription.payment_provider or PaymentProvider.STRIPE ) gui.write(provider.label) with col3: @@ -592,7 +612,7 @@ def render_payment_information(org: "Org"): f"{icons.edit} Edit", type="link", key="manage-payment-provider" ): raise gui.RedirectException( - org.subscription.get_external_management_url() + workspace.subscription.get_external_management_url() ) pm_summary = PaymentMethodSummary(*pm_summary) @@ -612,7 +632,7 @@ def render_payment_information(org: "Org"): if gui.button( f"{icons.edit} Edit", type="link", key="edit-payment-method" ): - change_payment_method(org) + change_payment_method(workspace) if pm_summary.billing_email: col1, col2, _ = gui.columns(3, responsive=False) @@ -640,13 +660,16 @@ def render_payment_information(org: "Org"): ): modal.open() if confirmed: - set_org_subscription( - org_id=org.org_id, + set_workspace_subscription( + workspace_id_or_uid=workspace.id, plan=PricingPlan.STARTER, provider=None, external_id=None, ) - pm = org.subscription and org.subscription.stripe_get_default_payment_method() + pm = ( + workspace.subscription + and workspace.subscription.stripe_get_default_payment_method() + ) if pm: pm.detach() raise gui.RedirectException( @@ -654,18 +677,18 @@ def render_payment_information(org: "Org"): ) -def change_payment_method(org: "Org"): +def change_payment_method(workspace: "Workspace"): from routers.account import payment_processing_route from routers.account import account_route - match org.subscription.payment_provider: + match workspace.subscription.payment_provider: case PaymentProvider.STRIPE: session = stripe.checkout.Session.create( mode="setup", currency="usd", - customer=org.get_or_create_stripe_customer(), + customer=workspace.get_or_create_stripe_customer(), setup_intent_data={ - "metadata": {"subscription_id": org.subscription.external_id}, + "metadata": {"subscription_id": workspace.subscription.external_id}, }, success_url=get_app_route_url(payment_processing_route), cancel_url=get_app_route_url(account_route), @@ -679,11 +702,11 @@ def format_card_brand(brand: str) -> str: return icons.card_icons.get(brand.lower(), brand.capitalize()) -def render_billing_history(org: "Org", limit: int = 50): +def render_billing_history(workspace: "Workspace", limit: int = 50): import pandas as pd txns = AppUserTransaction.objects.filter( - org=org, + workspace=workspace, amount__gt=0, ).order_by("-created_at") if not txns: @@ -708,9 +731,9 @@ def render_billing_history(org: "Org", limit: int = 50): gui.caption(f"Showing only the most recent {limit} transactions.") -def render_auto_recharge_section(org: "Org"): - assert org.subscription - subscription = org.subscription +def render_auto_recharge_section(workspace: "Workspace"): + assert workspace.subscription + subscription = workspace.subscription gui.write("## Auto Recharge & Limits") with gui.div(className="h4"): diff --git a/daras_ai_v2/send_email.py b/daras_ai_v2/send_email.py index 2262624e7..11799f86a 100644 --- a/daras_ai_v2/send_email.py +++ b/daras_ai_v2/send_email.py @@ -11,6 +11,7 @@ if typing.TYPE_CHECKING: from app_users.models import AppUser + from workspaces.models import Workspace def send_reported_run_email( @@ -44,25 +45,26 @@ def send_reported_run_email( def send_low_balance_email( *, - user: "AppUser", + workspace: "Workspace", total_credits_consumed: int, ): from routers.account import account_route recipeints = "support@gooey.ai, devs@gooey.ai" - html_body = templates.get_template("low_balance_email.html").render( - user=user, - url=get_app_route_url(account_route), - total_credits_consumed=total_credits_consumed, - settings=settings, - ) - send_email_via_postmark( - from_address=settings.SUPPORT_EMAIL, - to_address=user.email or recipeints, - bcc=recipeints, - subject="Your Gooey.AI credit balance is low", - html_body=html_body, - ) + for owner in workspace.get_owners(): + html_body = templates.get_template("low_balance_email.html").render( + user=owner.user, + url=get_app_route_url(account_route), + total_credits_consumed=total_credits_consumed, + settings=settings, + ) + send_email_via_postmark( + from_address=settings.SUPPORT_EMAIL, + to_address=owner.user.email or recipeints, + bcc=recipeints, + subject="Your Gooey.AI credit balance is low", + html_body=html_body, + ) is_running_pytest = "pytest" in sys.modules diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index c73de7223..9bd1d419d 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -63,7 +63,7 @@ "handles", "payments", "functions", - "orgs", + "workspaces", ] MIDDLEWARE = [ @@ -288,7 +288,7 @@ EMAIL_USER_FREE_CREDITS = config("EMAIL_USER_FREE_CREDITS", 0, cast=int) ANON_USER_FREE_CREDITS = config("ANON_USER_FREE_CREDITS", 25, cast=int) LOGIN_USER_FREE_CREDITS = config("LOGIN_USER_FREE_CREDITS", 500, cast=int) -FIRST_ORG_FREE_CREDITS = config("ORG_FREE_CREDITS", 500, cast=int) +FIRST_WORKSPACE_FREE_CREDITS = config("WORKSPACE_FREE_CREDITS", 500, cast=int) ADDON_CREDITS_PER_DOLLAR = config("ADDON_CREDITS_PER_DOLLAR", 100, cast=int) ADDON_AMOUNT_CHOICES = [10, 30, 50, 100, 300, 500, 1000] # USD @@ -399,9 +399,11 @@ TWILIO_API_KEY_SID = config("TWILIO_API_KEY_SID", "") TWILIO_API_KEY_SECRET = config("TWILIO_API_KEY_SECRET", "") -ORG_INVITATION_EXPIRY_DAYS = config("ORG_INVITATIONS_EXPIRY_IN_DAYS", 10, cast=int) -ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL = config( - "ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL", 60 * 60 * 24, cast=int # 24 hours +WORKSPACE_INVITATION_EXPIRY_DAYS = config( + "WORKSPACE_INVITATIONS_EXPIRY_IN_DAYS", 10, cast=int +) +WORKSPACE_INVITATION_EMAIL_COOLDOWN_INTERVAL = config( + "WORKSPACE_INVITATION_EMAIL_COOLDOWN_INTERVAL", 60 * 60 * 24, cast=int # 24 hours ) SCRAPING_PROXY_HOST = config("SCRAPING_PROXY_HOST", "") diff --git a/orgs/admin.py b/orgs/admin.py deleted file mode 100644 index 370ca4c4e..000000000 --- a/orgs/admin.py +++ /dev/null @@ -1,111 +0,0 @@ -from django.contrib import admin -from safedelete.admin import SafeDeleteAdmin, SafeDeleteAdminFilter - -from bots.admin_links import change_obj_url -from orgs.models import Org, OrgMembership, OrgInvitation - - -class OrgMembershipInline(admin.TabularInline): - model = OrgMembership - extra = 0 - show_change_link = True - fields = ["user", "role", "created_at", "updated_at"] - readonly_fields = ["created_at", "updated_at"] - ordering = ["-created_at"] - can_delete = False - show_change_link = True - - -class OrgInvitationInline(admin.TabularInline): - model = OrgInvitation - extra = 0 - show_change_link = True - fields = [ - "invitee_email", - "inviter", - "status", - "auto_accepted", - "created_at", - "updated_at", - ] - readonly_fields = ["auto_accepted", "created_at", "updated_at"] - ordering = ["status", "-created_at"] - can_delete = False - show_change_link = True - - -@admin.register(Org) -class OrgAdmin(SafeDeleteAdmin): - list_display = [ - "name", - "domain_name", - "created_at", - "updated_at", - ] + list(SafeDeleteAdmin.list_display) - list_filter = [SafeDeleteAdminFilter] + list(SafeDeleteAdmin.list_filter) - fields = [ - "name", - "domain_name", - "created_by", - "is_personal", - "created_at", - "updated_at", - ] - search_fields = ["name", "domain_name"] - readonly_fields = ["is_personal", "created_at", "updated_at"] - inlines = [OrgMembershipInline, OrgInvitationInline] - ordering = ["-created_at"] - - -@admin.register(OrgMembership) -class OrgMembershipAdmin(SafeDeleteAdmin): - list_display = [ - "user", - "org", - "role", - "created_at", - "updated_at", - ] + list(SafeDeleteAdmin.list_display) - list_filter = ["org", "role", SafeDeleteAdminFilter] + list( - SafeDeleteAdmin.list_filter - ) - - def get_readonly_fields( - self, request: "HttpRequest", obj: OrgMembership | None = None - ) -> list[str]: - readonly_fields = list(super().get_readonly_fields(request, obj)) - if obj and obj.org and obj.org.deleted: - return readonly_fields + ["deleted_org"] - else: - return readonly_fields - - @admin.display - def deleted_org(self, obj): - org = Org.deleted_objects.get(pk=obj.org_id) - return change_obj_url(org) - - -@admin.register(OrgInvitation) -class OrgInvitationAdmin(SafeDeleteAdmin): - fields = [ - "org", - "invitee_email", - "inviter", - "role", - "status", - "auto_accepted", - "created_at", - "updated_at", - ] - list_display = [ - "org", - "invitee_email", - "inviter", - "status", - "created_at", - "updated_at", - ] + list(SafeDeleteAdmin.list_display) - list_filter = ["org", "inviter", "role", SafeDeleteAdminFilter] + list( - SafeDeleteAdmin.list_filter - ) - readonly_fields = ["auto_accepted"] diff --git a/orgs/migrations/0002_alter_org_unique_together_and_more.py b/orgs/migrations/0002_alter_org_unique_together_and_more.py deleted file mode 100644 index 2c5384d67..000000000 --- a/orgs/migrations/0002_alter_org_unique_together_and_more.py +++ /dev/null @@ -1,35 +0,0 @@ -# Generated by Django 4.2.7 on 2024-07-22 14:45 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('orgs', '0001_initial'), - ] - - operations = [ - migrations.AlterUniqueTogether( - name='org', - unique_together=set(), - ), - migrations.AlterField( - model_name='orginvitation', - name='last_email_sent_at', - field=models.DateTimeField(blank=True, default=None, null=True), - ), - migrations.AlterField( - model_name='orginvitation', - name='status_changed_at', - field=models.DateTimeField(blank=True, default=None, null=True), - ), - migrations.AddConstraint( - model_name='org', - constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('domain_name',), name='unique_domain_name_when_not_deleted'), - ), - migrations.RemoveField( - model_name='org', - name='members', - ), - ] diff --git a/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py b/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py deleted file mode 100644 index 6047919f1..000000000 --- a/orgs/migrations/0003_remove_org_unique_domain_name_when_not_deleted_and_more.py +++ /dev/null @@ -1,36 +0,0 @@ -# Generated by Django 4.2.7 on 2024-07-23 11:45 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - dependencies = [ - ('app_users', '0019_alter_appusertransaction_reason'), - ('orgs', '0002_alter_org_unique_together_and_more'), - ] - - operations = [ - migrations.RemoveConstraint( - model_name='org', - name='unique_domain_name_when_not_deleted', - ), - migrations.AlterUniqueTogether( - name='orgmembership', - unique_together=set(), - ), - migrations.AlterField( - model_name='orginvitation', - name='status_changed_by', - field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='received_invitations', to='app_users.appuser'), - ), - migrations.AddConstraint( - model_name='org', - constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('domain_name',), name='unique_domain_name_when_not_deleted', violation_error_message='This domain name is already in use by another team. Contact Gooey.AI Support if you think this is a mistake.'), - ), - migrations.AddConstraint( - model_name='orgmembership', - constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('org', 'user'), name='unique_org_user'), - ), - ] diff --git a/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py b/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py deleted file mode 100644 index 9d9fdfc5d..000000000 --- a/orgs/migrations/0004_org_balance_org_is_paying_org_is_personal_and_more.py +++ /dev/null @@ -1,45 +0,0 @@ -# Generated by Django 4.2.7 on 2024-08-12 14:23 - -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - dependencies = [ - ('payments', '0005_alter_subscription_plan'), - ('orgs', '0003_remove_org_unique_domain_name_when_not_deleted_and_more'), - ] - - operations = [ - migrations.AddField( - model_name='org', - name='balance', - field=models.IntegerField(default=0, verbose_name='bal'), - ), - migrations.AddField( - model_name='org', - name='is_paying', - field=models.BooleanField(default=False, verbose_name='paid'), - ), - migrations.AddField( - model_name='org', - name='is_personal', - field=models.BooleanField(default=False), - ), - migrations.AddField( - model_name='org', - name='low_balance_email_sent_at', - field=models.DateTimeField(blank=True, null=True), - ), - migrations.AddField( - model_name='org', - name='stripe_customer_id', - field=models.CharField(blank=True, default='', max_length=255), - ), - migrations.AddField( - model_name='org', - name='subscription', - field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='org', to='payments.subscription'), - ), - ] diff --git a/orgs/migrations/0005_org_unique_personal_org_per_user.py b/orgs/migrations/0005_org_unique_personal_org_per_user.py deleted file mode 100644 index aaaa1cc4d..000000000 --- a/orgs/migrations/0005_org_unique_personal_org_per_user.py +++ /dev/null @@ -1,17 +0,0 @@ -# Generated by Django 4.2.7 on 2024-08-13 14:34 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('orgs', '0004_org_balance_org_is_paying_org_is_personal_and_more'), - ] - - operations = [ - migrations.AddConstraint( - model_name='org', - constraint=models.UniqueConstraint(models.F('created_by'), condition=models.Q(('deleted__isnull', True), ('is_personal', True)), name='unique_personal_org_per_user'), - ), - ] diff --git a/orgs/signals.py b/orgs/signals.py deleted file mode 100644 index bb23b7e06..000000000 --- a/orgs/signals.py +++ /dev/null @@ -1,49 +0,0 @@ -from django.db.models.signals import post_save -from django.dispatch import receiver -from loguru import logger -from safedelete.signals import post_softdelete - -from app_users.models import AppUser -from orgs.models import Org, OrgMembership, OrgRole -from orgs.tasks import send_auto_accepted_email - - -@receiver(post_save, sender=AppUser) -def add_user_existing_org(instance: AppUser, **kwargs): - """ - if the domain name matches - """ - if not instance.email: - return - - email_domain = instance.email.split("@")[1] - org = Org.objects.filter(domain_name=email_domain).first() - if not org: - return - - if instance.received_invitations.exists(): - # user has some existing invitations - return - - org_owner = org.memberships.filter(role=OrgRole.OWNER).first() - if not org_owner: - logger.warning( - f"Org {org} has no owner. Skipping auto-accept for user {instance}" - ) - return - - invitation = org.invite_user( - invitee_email=instance.email, - inviter=org_owner.user, - role=OrgRole.MEMBER, - auto_accept=not instance.org_memberships.exists(), # auto-accept only if user has no existing memberships - ) - - -@receiver(post_softdelete, sender=OrgMembership) -def delete_org_if_no_members_left(instance: OrgMembership, **kwargs): - if instance.org.memberships.exists(): - return - - logger.info(f"Deleting org {instance.org} because it has no members left") - instance.org.delete() diff --git a/payments/auto_recharge.py b/payments/auto_recharge.py index 3d07493b5..bc7934311 100644 --- a/payments/auto_recharge.py +++ b/payments/auto_recharge.py @@ -5,7 +5,7 @@ from app_users.models import PaymentProvider from daras_ai_v2.redis_cache import redis_lock -from orgs.models import Org +from workspaces.models import Workspace from payments.tasks import send_monthly_budget_reached_email @@ -28,18 +28,18 @@ class AutoRechargeCooldownException(AutoRechargeException): pass -def should_attempt_auto_recharge(org: Org): - return ( - org.subscription - and org.subscription.auto_recharge_enabled - and org.subscription.payment_provider - and org.balance < org.subscription.auto_recharge_balance_threshold +def should_attempt_auto_recharge(workspace: Workspace) -> bool: + return bool( + workspace.subscription + and workspace.subscription.auto_recharge_enabled + and workspace.subscription.payment_provider + and workspace.balance < workspace.subscription.auto_recharge_balance_threshold ) -def run_auto_recharge_gracefully(org: Org): +def run_auto_recharge_gracefully(workspace: Workspace): """ - Wrapper over _auto_recharge_org, that handles exceptions so that it can: + Wrapper over _auto_recharge_workspace, that handles exceptions so that it can: - log exceptions - send emails when auto-recharge fails - not retry if this is run as a background task @@ -47,49 +47,49 @@ def run_auto_recharge_gracefully(org: Org): Meant to be used in conjunction with should_attempt_auto_recharge """ try: - with redis_lock(f"gooey/auto_recharge_user/v1/{org.org_id}"): - _auto_recharge_org(org) + with redis_lock(f"gooey/auto_recharge_user/v1/{workspace.id}"): + _auto_recharge_workspace(workspace) except AutoRechargeCooldownException as e: logger.info( - f"Rejected auto-recharge because auto-recharge is in cooldown period for org" - f"{org=}, {e=}" + f"Rejected auto-recharge because auto-recharge is in cooldown period for workspace" + f"{workspace=}, {e=}" ) except MonthlyBudgetReachedException as e: - send_monthly_budget_reached_email(org) + send_monthly_budget_reached_email(workspace) logger.info( f"Rejected auto-recharge because user has reached monthly budget" - f"{org=}, spending=${e.spending}, budget=${e.budget}" + f"{workspace=}, spending=${e.spending}, budget=${e.budget}" ) except Exception as e: traceback.print_exc() sentry_sdk.capture_exception(e) -def _auto_recharge_org(org: Org): +def _auto_recharge_workspace(workspace: Workspace): """ Returns whether a charge was attempted """ from payments.webhooks import StripeWebhookHandler assert ( - org.subscription.payment_provider == PaymentProvider.STRIPE + workspace.subscription.payment_provider == PaymentProvider.STRIPE ), "Auto recharge is only supported with Stripe" # check for monthly budget - dollars_spent = org.get_dollars_spent_this_month() + dollars_spent = workspace.get_dollars_spent_this_month() if ( - dollars_spent + org.subscription.auto_recharge_topup_amount - > org.subscription.monthly_spending_budget + dollars_spent + workspace.subscription.auto_recharge_topup_amount + > workspace.subscription.monthly_spending_budget ): raise MonthlyBudgetReachedException( "Performing this top-up would exceed your monthly recharge budget", - budget=org.subscription.monthly_spending_budget, + budget=workspace.subscription.monthly_spending_budget, spending=dollars_spent, ) try: - invoice = org.subscription.stripe_get_or_create_auto_invoice( - amount_in_dollars=org.subscription.auto_recharge_topup_amount, + invoice = workspace.subscription.stripe_get_or_create_auto_invoice( + amount_in_dollars=workspace.subscription.auto_recharge_topup_amount, metadata_key="auto_recharge", ) except Exception as e: @@ -103,9 +103,11 @@ def _auto_recharge_org(org: Org): # get default payment method and attempt payment assert invoice.status == "open" # sanity check - pm = org.subscription.stripe_get_default_payment_method() + pm = workspace.subscription.stripe_get_default_payment_method() if not pm: - logger.warning(f"{org} has no default payment method, cannot auto-recharge") + logger.warning( + f"{workspace} has no default payment method, cannot auto-recharge" + ) return try: @@ -117,5 +119,5 @@ def _auto_recharge_org(org: Org): else: assert invoice_data.paid StripeWebhookHandler.handle_invoice_paid( - org_id=org.org_id, invoice=invoice_data + workspace_id_or_uid=workspace.id, invoice=invoice_data ) diff --git a/payments/models.py b/payments/models.py index ff5be4f69..cebfeda70 100644 --- a/payments/models.py +++ b/payments/models.py @@ -82,8 +82,8 @@ def __str__(self): ret = f"{self.get_plan_display()} | {self.get_payment_provider_display()}" # if self.has_user: # ret = f"{ret} | {self.user}" - if self.has_org: - ret = f"{ret} | {self.org}" + if self.has_workspace: + ret = f"{ret} | {self.workspace}" if self.auto_recharge_enabled: ret = f"Auto | {ret}" return ret @@ -138,10 +138,10 @@ def is_paid(self) -> bool: return PricingPlan.from_sub(self).monthly_charge > 0 and self.external_id @property - def has_org(self) -> bool: + def has_workspace(self) -> bool: try: - self.org - except Subscription.org.RelatedObjectDoesNotExist: + self.workspace + except Subscription.workspace.RelatedObjectDoesNotExist: return False else: return True @@ -376,12 +376,12 @@ def has_sent_monthly_budget_email_this_month(self) -> bool: ) def should_send_monthly_spending_notification(self) -> bool: - assert self.has_org + assert self.has_workspace return bool( self.monthly_spending_notification_threshold and not self.has_sent_monthly_spending_notification_this_month() - and self.org.get_dollars_spent_this_month() + and self.workspace.get_dollars_spent_this_month() >= self.monthly_spending_notification_threshold ) diff --git a/payments/tasks.py b/payments/tasks.py index c98b8c12e..d84f1c748 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -1,8 +1,7 @@ from django.utils import timezone from loguru import logger -from app_users.models import AppUser -from orgs.models import Org +from workspaces.models import Workspace from celeryapp import app from daras_ai_v2 import settings from daras_ai_v2.fastapi_tricks import get_app_route_url @@ -14,11 +13,11 @@ def send_monthly_spending_notification_email(id: int): from routers.account import account_route - org = Org.objects.get(id=id) - threshold = org.subscription.monthly_spending_notification_threshold - for owner in org.get_owners(): + workspace = Workspace.objects.get(id=id) + threshold = workspace.subscription.monthly_spending_notification_threshold + for owner in workspace.get_owners(): if not owner.user.email: - logger.error(f"Org Owner doesn't have an email: {owner=}") + logger.error(f"Workspace Owner doesn't have an email: {owner=}") return send_email_via_postmark( @@ -29,7 +28,7 @@ def send_monthly_spending_notification_email(id: int): "monthly_spending_notification_threshold_email.html" ).render( user=owner.user, - org=org, + workspace=workspace, account_url=get_app_route_url(account_route), ), ) @@ -37,20 +36,22 @@ def send_monthly_spending_notification_email(id: int): # IMPORTANT: always use update_fields=... / select_for_update when updating # subscription info. We don't want to overwrite other changes made to # subscription during the same time - org.subscription.monthly_spending_notification_sent_at = timezone.now() - org.subscription.save(update_fields=["monthly_spending_notification_sent_at"]) + workspace.subscription.monthly_spending_notification_sent_at = timezone.now() + workspace.subscription.save( + update_fields=["monthly_spending_notification_sent_at"] + ) -def send_monthly_budget_reached_email(org: Org): +def send_monthly_budget_reached_email(workspace: Workspace): from routers.account import account_route - for owner in org.get_owners(): + for owner in workspace.get_owners(): if not owner.user.email: continue email_body = templates.get_template("monthly_budget_reached_email.html").render( user=owner.user, - org=org, + workspace=workspace, account_url=get_app_route_url(account_route), ) send_email_via_postmark( @@ -63,5 +64,5 @@ def send_monthly_budget_reached_email(org: Org): # IMPORTANT: always use update_fields=... when updating subscription # info. We don't want to overwrite other changes made to subscription # during the same time - org.subscription.monthly_budget_email_sent_at = timezone.now() - org.subscription.save(update_fields=["monthly_budget_email_sent_at"]) + workspace.subscription.monthly_budget_email_sent_at = timezone.now() + workspace.subscription.save(update_fields=["monthly_budget_email_sent_at"]) diff --git a/payments/webhooks.py b/payments/webhooks.py index 36f0499c7..3d7c3f202 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -6,7 +6,7 @@ from app_users.models import PaymentProvider, TransactionReason from daras_ai_v2 import paypal -from orgs.models import Org +from workspaces.models import Workspace from .models import Subscription from .plans import PricingPlan from .tasks import send_monthly_spending_notification_email @@ -22,7 +22,7 @@ def handle_sale_completed(cls, sale: paypal.Sale): return pp_sub = paypal.Subscription.retrieve(sale.billing_agreement_id) - assert pp_sub.custom_id, "pp_sub is missing org_id" + assert pp_sub.custom_id, "pp_sub is missing workspace_id" assert pp_sub.plan_id, "pp_sub is missing plan ID" plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) @@ -35,9 +35,8 @@ def handle_sale_completed(cls, sale: paypal.Sale): f"paypal: charged amount ${charged_dollars} does not match plan's monthly charge ${plan.monthly_charge}" ) - org_id = pp_sub.custom_id add_balance_for_payment( - org_id=org_id, + workspace_id_or_uid=pp_sub.custom_id, amount=plan.credits, invoice_id=sale.id, payment_provider=cls.PROVIDER, @@ -50,7 +49,9 @@ def handle_sale_completed(cls, sale: paypal.Sale): def handle_subscription_updated(cls, pp_sub: paypal.Subscription): logger.info(f"Paypal subscription updated {pp_sub.id}") - assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing org_id" + assert ( + pp_sub.custom_id + ), f"PayPal subscription {pp_sub.id} is missing workspace_id" assert pp_sub.plan_id, f"PayPal subscription {pp_sub.id} is missing plan ID" plan = PricingPlan.get_by_paypal_plan_id(pp_sub.plan_id) @@ -62,8 +63,8 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): ) return - set_org_subscription( - org_id=pp_sub.custom_id, + set_workspace_subscription( + workspace_id_or_uid=pp_sub.custom_id, plan=plan, provider=cls.PROVIDER, external_id=pp_sub.id, @@ -72,8 +73,8 @@ def handle_subscription_updated(cls, pp_sub: paypal.Subscription): @classmethod def handle_subscription_cancelled(cls, pp_sub: paypal.Subscription): assert pp_sub.custom_id, f"PayPal subscription {pp_sub.id} is missing uid" - set_org_subscription( - org_id=pp_sub.custom_id, + set_workspace_subscription( + workspace_id_or_uid=pp_sub.custom_id, plan=PricingPlan.STARTER, provider=None, external_id=None, @@ -84,7 +85,9 @@ class StripeWebhookHandler: PROVIDER = PaymentProvider.STRIPE @classmethod - def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice): + def handle_invoice_paid( + cls, workspace_id_or_uid: str | int, invoice: stripe.Invoice + ): from app_users.tasks import save_stripe_default_payment_method kwargs = {} @@ -109,7 +112,7 @@ def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice): amount = invoice.lines.data[0].quantity charged_amount = invoice.lines.data[0].amount add_balance_for_payment( - org_id=org_id, + workspace_id_or_uid=workspace_id_or_uid, amount=amount, invoice_id=invoice.id, payment_provider=cls.PROVIDER, @@ -119,15 +122,15 @@ def handle_invoice_paid(cls, org_id: str, invoice: stripe.Invoice): ) save_stripe_default_payment_method.delay( + workspace_id_or_uid=workspace_id_or_uid, payment_intent_id=invoice.payment_intent, - org_id=org_id, amount=amount, charged_amount=charged_amount, reason=reason, ) @classmethod - def handle_checkout_session_completed(cls, org_id: str, session_data): + def handle_checkout_session_completed(cls, workspace_id_or_uid: str, session_data): setup_intent_id = session_data.get("setup_intent") if not setup_intent_id: # not a setup mode checkout -- do nothing @@ -149,7 +152,9 @@ def handle_checkout_session_completed(cls, org_id: str, session_data): ) @classmethod - def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscription): + def handle_subscription_updated( + cls, workspace_id_or_uid: int | str, stripe_sub: stripe.Subscription + ): logger.info(f"Stripe subscription updated: {stripe_sub.id}") assert stripe_sub.plan, f"Stripe subscription {stripe_sub.id} is missing plan" @@ -170,17 +175,17 @@ def handle_subscription_updated(cls, org_id: str, stripe_sub: stripe.Subscriptio ) return - set_org_subscription( - org_id=org_id, + set_workspace_subscription( + workspace_id_or_uid=workspace_id_or_uid, plan=plan, provider=cls.PROVIDER, external_id=stripe_sub.id, ) @classmethod - def handle_subscription_cancelled(cls, org_id: str): - set_org_subscription( - org_id=org_id, + def handle_subscription_cancelled(cls, workspace_id_or_uid: int | str): + set_workspace_subscription( + workspace_id_or_uid=workspace_id_or_uid, plan=PricingPlan.STARTER, provider=PaymentProvider.STRIPE, external_id=None, @@ -189,15 +194,19 @@ def handle_subscription_cancelled(cls, org_id: str): def add_balance_for_payment( *, - org_id: str, + workspace_id_or_uid: int | str, amount: int, invoice_id: str, payment_provider: PaymentProvider, charged_amount: int, **kwargs, ): - org = Org.objects.get_or_create_from_org_id(org_id)[0] - org.add_balance( + try: + workspace = Workspace.objects.get(id=int(workspace_id_or_uid)) + except (ValueError, Workspace.DoesNotExist): + workspace, _ = Workspace.objects.get_or_create_from_uid(workspace_id_or_uid) + + workspace.add_balance( amount=amount, invoice_id=invoice_id, charged_amount=charged_amount, @@ -205,30 +214,33 @@ def add_balance_for_payment( **kwargs, ) - if not org.is_paying: - org.is_paying = True - org.save(update_fields=["is_paying"]) + if not workspace.is_paying: + workspace.is_paying = True + workspace.save(update_fields=["is_paying"]) if ( - org.subscription - and org.subscription.should_send_monthly_spending_notification() + workspace.subscription + and workspace.subscription.should_send_monthly_spending_notification() ): - send_monthly_spending_notification_email.delay(org.id) + send_monthly_spending_notification_email.delay(workspace.id) -def set_org_subscription( +def set_workspace_subscription( *, - org_id: str, + workspace_id_or_uid: int | str, plan: PricingPlan, provider: PaymentProvider | None, external_id: str | None, amount: int | None = None, charged_amount: int | None = None, ) -> Subscription: - with transaction.atomic(): - org = Org.objects.get_or_create_from_org_id(org_id)[0] + try: + workspace = Workspace.objects.get(id=int(workspace_id_or_uid)) + except (ValueError, Workspace.DoesNotExist): + workspace, _ = Workspace.objects.get_or_create_from_uid(workspace_id_or_uid) - old_sub = org.subscription + with transaction.atomic(): + old_sub = workspace.subscription if old_sub: new_sub = copy(old_sub) else: @@ -242,8 +254,8 @@ def set_org_subscription( new_sub.save() if not old_sub: - org.subscription = new_sub - org.save(update_fields=["subscription"]) + workspace.subscription = new_sub + workspace.save(update_fields=["subscription"]) # cancel previous subscription if it's not the same as the new one if old_sub and old_sub.external_id != external_id: diff --git a/routers/account.py b/routers/account.py index f9194589b..b2612d55c 100644 --- a/routers/account.py +++ b/routers/account.py @@ -18,10 +18,10 @@ from daras_ai_v2.manage_api_keys_widget import manage_api_keys from daras_ai_v2.meta_content import raw_build_meta_tags from daras_ai_v2.profiles import edit_user_profile_page -from orgs.models import OrgInvitation +from workspaces.models import WorkspaceInvitation from payments.webhooks import PaypalWebhookHandler from routers.root import page_wrapper, get_og_url_path -from orgs.views import invitation_page, orgs_page +from workspaces.views import invitation_page, workspaces_page from routers.custom_api_router import CustomAPIRouter @@ -142,10 +142,10 @@ def api_keys_route(request: Request): ) -@gui.route(app, "/orgs/") -def orgs_route(request: Request): - with account_page_wrapper(request, AccountTabs.orgs): - orgs_tab(request) +@gui.route(app, "/workspaces/") +def workspaces_route(request: Request): + with account_page_wrapper(request, AccountTabs.workspaces): + workspaces_tab(request) url = get_og_url_path(request) return dict( @@ -159,8 +159,8 @@ def orgs_route(request: Request): ) -@gui.route(app, "/invitation/{org_slug}/{invite_id}/") -def invitation_route(request: Request, org_slug: str, invite_id: str): +@gui.route(app, "/invitation/{workspace_slug}/{invite_id}/") +def invitation_route(request: Request, workspace_slug: str, invite_id: str): from routers.root import login if not request.user or request.user.is_anonymous: @@ -169,8 +169,8 @@ def invitation_route(request: Request, org_slug: str, invite_id: str): raise RedirectException(redirect_url) try: - invitation = OrgInvitation.objects.get(invite_id=invite_id) - except OrgInvitation.DoesNotExist: + invitation = WorkspaceInvitation.objects.get(invite_id=invite_id) + except WorkspaceInvitation.DoesNotExist: return Response(status_code=404) with page_wrapper(request): @@ -178,8 +178,8 @@ def invitation_route(request: Request, org_slug: str, invite_id: str): return dict( meta=raw_build_meta_tags( url=str(request.url), - title=f"Join {invitation.org.name} • Gooey.AI", - description=f"Invitation to join {invitation.org.name}", + title=f"Join {invitation.workspace.name} • Gooey.AI", + description=f"Invitation to join {invitation.workspace.name}", robots="noindex,nofollow", ) ) @@ -195,7 +195,7 @@ class AccountTabs(TabData, Enum): profile = TabData(title=f"{icons.profile} Profile", route=profile_route) saved = TabData(title=f"{icons.save} Saved", route=saved_route) api_keys = TabData(title=f"{icons.api} API Keys", route=api_keys_route) - orgs = TabData(title=f"{icons.company} Teams", route=orgs_route) + workspaces = TabData(title=f"{icons.company} Teams", route=workspaces_route) @property def url_path(self) -> str: @@ -203,8 +203,8 @@ def url_path(self) -> str: def billing_tab(request: Request): - org, _ = request.user.get_or_create_personal_org() - return billing_page(org) + workspace, _ = request.user.get_or_create_personal_workspace() + return billing_page(workspace) def profile_tab(request: Request): @@ -256,14 +256,14 @@ def api_keys_tab(request: Request): manage_api_keys(request.user) -def orgs_tab(request: Request): +def workspaces_tab(request: Request): """only accessible to admins""" from daras_ai_v2.base import BasePage if not BasePage.is_user_admin(request.user): raise RedirectException(get_route_path(account_route)) - orgs_page(request.user) + workspaces_page(request.user) def get_tabs(request: Request) -> list[AccountTabs]: @@ -276,7 +276,7 @@ def get_tabs(request: Request) -> list[AccountTabs]: AccountTabs.api_keys, ] if BasePage.is_user_admin(request.user): - tab_list.append(AccountTabs.orgs) + tab_list.append(AccountTabs.workspaces) return tab_list diff --git a/routers/api.py b/routers/api.py index 9b795d426..5d2b4e42d 100644 --- a/routers/api.py +++ b/routers/api.py @@ -354,6 +354,7 @@ def submit_api_call( enable_rate_limits=enable_rate_limits, is_api_call=True, retention_policy=retention_policy or RetentionPolicy.keep, + billed_workspace=self.get_current_workspace(), ) except ValidationError as e: raise RequestValidationError(e.raw_errors, body=gui.session_state) from e diff --git a/routers/paypal.py b/routers/paypal.py index 3771481cf..84d8a3b83 100644 --- a/routers/paypal.py +++ b/routers/paypal.py @@ -126,8 +126,8 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json): if plan.deprecated: return JSONResponse({"error": "Deprecated plan"}, status_code=400) - org, _ = request.user.get_or_create_personal_org() - if org.subscription and org.subscription.is_paid(): + workspace, _ = request.user.get_or_create_personal_worksace() + if workspace.subscription and workspace.subscription.is_paid(): return JSONResponse( {"error": "User already has an active subscription"}, status_code=400 ) @@ -135,7 +135,7 @@ def create_subscription(request: Request, payload: dict = fastapi_request_json): paypal_plan_info = plan.get_paypal_plan() pp_subscription = paypal.Subscription.create( plan_id=paypal_plan_info["plan_id"], - custom_id=org.org_id, + custom_id=str(workspace.id), plan=paypal_plan_info.get("plan", {}), application_context={ "brand_name": "Gooey.AI", @@ -177,7 +177,7 @@ def _handle_invoice_paid(order_id: str): purchase_unit = order["purchase_units"][0] payment_capture = purchase_unit["payments"]["captures"][0] add_balance_for_payment( - org_id=payment_capture["custom_id"], + workspace_id_or_uid=payment_capture["custom_id"], amount=int(purchase_unit["items"][0]["quantity"]), invoice_id=payment_capture["id"], payment_provider=PaymentProvider.PAYPAL, diff --git a/scripts/migrate_billed_org_for_saved_runs.py b/scripts/migrate_billed_org_for_saved_runs.py deleted file mode 100644 index 52b86e932..000000000 --- a/scripts/migrate_billed_org_for_saved_runs.py +++ /dev/null @@ -1,18 +0,0 @@ -from django.db.models import F, Subquery, OuterRef -from django.db import transaction - -from bots.models import SavedRun -from orgs.models import Org - - -def run(): - # Start a transaction to ensure atomicity - with transaction.atomic(): - # Perform the update where 'uid' matches a valid 'org_id' in the 'Org' table - SavedRun.objects.filter( - billed_org_id__isnull=True, uid__in=Org.objects.values("org_id") - ).update( - billed_org_id=Subquery( - Org.objects.filter(org_id=OuterRef("uid")).values("id")[:1] - ) - ) diff --git a/scripts/migrate_billed_workspace_for_saved_runs.py b/scripts/migrate_billed_workspace_for_saved_runs.py new file mode 100644 index 000000000..39f19ae4a --- /dev/null +++ b/scripts/migrate_billed_workspace_for_saved_runs.py @@ -0,0 +1,23 @@ +from django.db import connection +from loguru import logger + + +def run(): + with connection.cursor() as cursor: + cursor.execute( + """ + UPDATE bots_savedrun + SET billed_workspace_id = workspaces_workspace.id + FROM + workspaces_workspace INNER JOIN + app_users_appuser ON workspaces_workspace.created_by_id = app_users_appuser.id + WHERE + bots_savedrun.billed_workspace_id IS NULL AND + bots_savedrun.uid IS NOT NULL AND + bots_savedrun.uid = app_users_appuser.uid AND + workspaces_workspace.is_personal = true + """ + ) + rows_updated = cursor.rowcount + + logger.info(f"Updated {rows_updated} saved runs with billed workspace") diff --git a/scripts/migrate_orgs_from_appusers.py b/scripts/migrate_orgs_from_appusers.py deleted file mode 100644 index f4cbc7ec9..000000000 --- a/scripts/migrate_orgs_from_appusers.py +++ /dev/null @@ -1,50 +0,0 @@ -from django.db import IntegrityError, connection -from loguru import logger - -from app_users.models import AppUser -from orgs.models import Org - - -def run(): - migrate_personal_orgs() - migrate_txns() - - -def migrate_personal_orgs(): - users_without_personal_org = AppUser.objects.exclude( - id__in=Org.objects.filter(is_personal=True).values_list("created_by", flat=True) - ) - - done_count = 0 - - logger.info("Creating personal orgs...") - for appuser in users_without_personal_org: - try: - Org.objects.migrate_from_appuser(appuser) - except IntegrityError as e: - logger.warning(f"IntegrityError: {e}") - else: - done_count += 1 - - if done_count % 100 == 0: - logger.info(f"Running... {done_count} migrated") - - logger.info(f"Migrated {done_count} personal orgs...") - - -def migrate_txns(): - with connection.cursor() as cursor: - cursor.execute( - """ - UPDATE app_users_appusertransaction AS txn - SET org_id = orgs_org.id - FROM - app_users_appuser - INNER JOIN orgs_org ON app_users_appuser.id = orgs_org.created_by_id - WHERE - txn.user_id = app_users_appuser.id - AND txn.org_id IS NULL - AND orgs_org.is_personal = true - """ - ) - logger.info(f"Updated {cursor.rowcount} txns with personal orgs") diff --git a/scripts/migrate_workspace_from_appusers.py b/scripts/migrate_workspace_from_appusers.py new file mode 100644 index 000000000..f58c0935d --- /dev/null +++ b/scripts/migrate_workspace_from_appusers.py @@ -0,0 +1,52 @@ +from django.db import IntegrityError, connection +from loguru import logger + +from app_users.models import AppUser +from workspaces.models import Workspace + + +def run(): + migrate_personal_workspaces() + migrate_txns() + + +def migrate_personal_workspaces(): + users_without_personal_workspace = AppUser.objects.exclude( + id__in=Workspace.objects.filter(is_personal=True).values_list( + "created_by", flat=True + ) + ) + + done_count = 0 + + logger.info("Creating personal workspaces...") + for appuser in users_without_personal_workspace: + try: + Workspace.objects.migrate_from_appuser(appuser) + except IntegrityError as e: + logger.warning(f"IntegrityError: {e}") + else: + done_count += 1 + + if done_count % 100 == 0: + logger.info(f"Running... {done_count} migrated") + + logger.info(f"Migrated {done_count} personal workspaces...") + + +def migrate_txns(): + with connection.cursor() as cursor: + cursor.execute( + """ + UPDATE app_users_appusertransaction AS txn + SET workspace_id = workspaces_workspace.id + FROM + app_users_appuser + INNER JOIN workspaces_workspace ON app_users_appuser.id = workspaces_workspace.created_by_id + WHERE + txn.user_id = app_users_appuser.id + AND txn.workspace_id IS NULL + AND workspaces_workspace.is_personal = true + """ + ) + logger.info(f"Updated {cursor.rowcount} txns with personal workspaces") diff --git a/templates/monthly_budget_reached_email.html b/templates/monthly_budget_reached_email.html index 6e467a086..861a3c3c4 100644 --- a/templates/monthly_budget_reached_email.html +++ b/templates/monthly_budget_reached_email.html @@ -1,6 +1,6 @@ -{% set dollars_spent = org.get_dollars_spent_this_month() %} -{% set monthly_budget = org.subscription.monthly_spending_budget %} -{% set threshold = org.subscription.auto_recharge_balance_threshold %} +{% set dollars_spent = workspace.get_dollars_spent_this_month() %} +{% set monthly_budget = workspace.subscription.monthly_spending_budget %} +{% set threshold = workspace.subscription.auto_recharge_balance_threshold %}

Hey, {{ user.first_name() }}! @@ -18,7 +18,7 @@

    -
  • Credit Balance: {{ org.balance }} credits
  • +
  • Credit Balance: {{ workspace.balance }} credits
  • Monthly Budget: ${{ monthly_budget }}
  • Spending this month: ${{ dollars_spent }}
diff --git a/templates/monthly_spending_notification_threshold_email.html b/templates/monthly_spending_notification_threshold_email.html index 13be0fae5..c8a6394d1 100644 --- a/templates/monthly_spending_notification_threshold_email.html +++ b/templates/monthly_spending_notification_threshold_email.html @@ -1,4 +1,4 @@ -{% set dollars_spent = org.get_dollars_spent_this_month() %} +{% set dollars_spent = workspace.get_dollars_spent_this_month() %}

Hi, {{ user.first_name() }}! @@ -6,11 +6,11 @@

Your spend on Gooey.AI so far this month is ${{ dollars_spent }}, exceeding your notification threshold - of ${{ org.subscription.monthly_spending_notification_threshold }}. + of ${{ workspace.subscription.monthly_spending_notification_threshold }}.

- Your monthly budget is ${{ org.subscription.monthly_spending_budget }}, after which auto-recharge will be + Your monthly budget is ${{ workspace.subscription.monthly_spending_budget }}, after which auto-recharge will be paused and all runs / API calls will be rejected.

diff --git a/templates/org_invitation_auto_accepted_email.html b/templates/org_invitation_auto_accepted_email.html index 843fb7426..ab34a71de 100644 --- a/templates/org_invitation_auto_accepted_email.html +++ b/templates/org_invitation_auto_accepted_email.html @@ -3,14 +3,14 @@

- You have been added to the team {{ org.name }} on Gooey.AI. - Visit the teams page to see your team. + You have been added to the team {{ workspace.name }} on Gooey.AI. + Visit the teams page to see your team.

- Your invite was automatically accepted because your email domain matches the organization's configured email domain. - If you think this shouldn't have happened, you can leave this organization from the - teams page. + Your invite was automatically accepted because your email domain matches the workspaceanization's configured email domain. + If you think this shouldn't have happened, you can leave this workspaceanization from the + teams page.

diff --git a/templates/org_invitation_email.html b/templates/org_invitation_email.html index c8e12dc87..dabb1dd40 100644 --- a/templates/org_invitation_email.html +++ b/templates/org_invitation_email.html @@ -4,7 +4,7 @@

{{ invitation.inviter.display_name or invitation.inviter.first_name() }} has invited - you to join their team {{ invitation.org.name }} on Gooey.AI. + you to join their team {{ invitation.workspace.name }} on Gooey.AI.

@@ -14,7 +14,7 @@

- The link will expire in {{ settings.ORG_INVITATION_EXPIRY_DAYS }} days. + The link will expire in {{ settings.WORKSPACE_INVITATION_EXPIRY_DAYS }} days.

diff --git a/orgs/__init__.py b/workspaces/__init__.py similarity index 100% rename from orgs/__init__.py rename to workspaces/__init__.py diff --git a/workspaces/admin.py b/workspaces/admin.py new file mode 100644 index 000000000..3c0e74de7 --- /dev/null +++ b/workspaces/admin.py @@ -0,0 +1,155 @@ +from django.contrib import admin +from django.db.models import Sum +from safedelete.admin import SafeDeleteAdmin, SafeDeleteAdminFilter + +from bots.admin_links import change_obj_url +from usage_costs.models import UsageCost +from .models import Workspace, WorkspaceMembership, WorkspaceInvitation + + +class WorkspaceMembershipInline(admin.TabularInline): + model = WorkspaceMembership + extra = 0 + show_change_link = True + fields = ["user", "role", "created_at", "updated_at"] + readonly_fields = ["created_at", "updated_at"] + ordering = ["-created_at"] + can_delete = False + show_change_link = True + + +class WorkspaceInvitationInline(admin.TabularInline): + model = WorkspaceInvitation + extra = 0 + show_change_link = True + fields = [ + "invitee_email", + "inviter", + "status", + "auto_accepted", + "created_at", + "updated_at", + ] + readonly_fields = ["auto_accepted", "created_at", "updated_at"] + ordering = ["status", "-created_at"] + can_delete = False + show_change_link = True + + +@admin.register(Workspace) +class WorkspaceAdmin(SafeDeleteAdmin): + list_display = [ + "name", + "domain_name", + "created_at", + "updated_at", + ] + list(SafeDeleteAdmin.list_display) + list_filter = [SafeDeleteAdminFilter] + list(SafeDeleteAdmin.list_filter) + fields = [ + "name", + "domain_name", + "created_by", + "is_personal", + "is_paying", + ("balance", "subscription"), + ("total_payments", "total_charged", "total_usage_cost"), + "created_at", + "updated_at", + ] + search_fields = ["name", "domain_name"] + readonly_fields = [ + "is_personal", + "created_at", + "updated_at", + "total_payments", + "total_charged", + "total_usage_cost", + ] + inlines = [WorkspaceMembershipInline, WorkspaceInvitationInline] + ordering = ["-created_at"] + + @admin.display(description="Total Payments") + def total_payments(self, workspace: Workspace): + return "$" + str( + ( + workspace.transactions.aggregate(Sum("charged_amount"))[ + "charged_amount__sum" + ] + or 0 + ) + / 100 + ) + + @admin.display(description="Total Charged") + def total_charged(self, workspace: Workspace): + credits_charged = -1 * ( + workspace.transactions.filter(amount__lt=0).aggregate(Sum("amount"))[ + "amount__sum" + ] + or 0 + ) + return f"{credits_charged} Credits" + + @admin.display(description="Total Usage Cost") + def total_usage_cost(self, workspace: Workspace): + total_cost = ( + UsageCost.objects.filter( + saved_run__billed_workspace_id=workspace.id + ).aggregate(Sum("dollar_amount"))["dollar_amount__sum"] + or 0 + ) + return round(total_cost, 2) + + +@admin.register(WorkspaceMembership) +class WorkspaceMembershipAdmin(SafeDeleteAdmin): + list_display = [ + "user", + "workspace", + "role", + "created_at", + "updated_at", + ] + list(SafeDeleteAdmin.list_display) + list_filter = ["workspace", "role", SafeDeleteAdminFilter] + list( + SafeDeleteAdmin.list_filter + ) + + def get_readonly_fields( + self, request: "HttpRequest", obj: WorkspaceMembership | None = None + ) -> list[str]: + readonly_fields = list(super().get_readonly_fields(request, obj)) + if obj and obj.workspace and obj.workspace.deleted: + return readonly_fields + ["deleted_workspace"] + else: + return readonly_fields + + @admin.display + def deleted_workspace(self, obj): + workspace = Workspace.deleted_objects.get(pk=obj.workspace_id) + return change_obj_url(workspace) + + +@admin.register(WorkspaceInvitation) +class WorkspaceInvitationAdmin(SafeDeleteAdmin): + fields = [ + "workspace", + "invitee_email", + "inviter", + "role", + "status", + "auto_accepted", + "created_at", + "updated_at", + ] + list_display = [ + "workspace", + "invitee_email", + "inviter", + "status", + "created_at", + "updated_at", + ] + list(SafeDeleteAdmin.list_display) + list_filter = ["workspace", "inviter", "role", SafeDeleteAdminFilter] + list( + SafeDeleteAdmin.list_filter + ) + readonly_fields = ["auto_accepted"] diff --git a/orgs/apps.py b/workspaces/apps.py similarity index 74% rename from orgs/apps.py rename to workspaces/apps.py index a75310666..dfc799939 100644 --- a/orgs/apps.py +++ b/workspaces/apps.py @@ -1,9 +1,9 @@ from django.apps import AppConfig -class OrgsConfig(AppConfig): +class WorkspacesConfig(AppConfig): default_auto_field = "django.db.models.BigAutoField" - name = "orgs" + name = "workspaces" def ready(self): from . import signals diff --git a/orgs/migrations/0001_initial.py b/workspaces/migrations/0001_initial.py similarity index 56% rename from orgs/migrations/0001_initial.py rename to workspaces/migrations/0001_initial.py index 7de84737d..b7183be45 100644 --- a/orgs/migrations/0001_initial.py +++ b/workspaces/migrations/0001_initial.py @@ -1,8 +1,8 @@ -# Generated by Django 4.2.7 on 2024-07-18 15:41 +# Generated by Django 4.2.7 on 2024-09-02 14:07 from django.db import migrations, models import django.db.models.deletion -import orgs.models +import workspaces.models class Migration(migrations.Migration): @@ -10,27 +10,34 @@ class Migration(migrations.Migration): initial = True dependencies = [ + ('payments', '0005_alter_subscription_plan'), ('app_users', '0019_alter_appusertransaction_reason'), ] operations = [ migrations.CreateModel( - name='Org', + name='Workspace', fields=[ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('deleted', models.DateTimeField(db_index=True, editable=False, null=True)), ('deleted_by_cascade', models.BooleanField(default=False, editable=False)), - ('org_id', models.CharField(blank=True, max_length=100, null=True, unique=True)), + ('workspace_id', models.CharField(blank=True, max_length=100, null=True, unique=True)), ('name', models.CharField(max_length=100)), ('logo', models.URLField(blank=True, null=True)), - ('domain_name', models.CharField(blank=True, max_length=30, null=True, validators=[orgs.models.validate_org_domain_name])), + ('domain_name', models.CharField(blank=True, max_length=30, null=True, validators=[workspaces.models.validate_workspace_domain_name])), + ('balance', models.IntegerField(default=0, verbose_name='bal')), + ('is_paying', models.BooleanField(default=False, verbose_name='paid')), + ('stripe_customer_id', models.CharField(blank=True, default='', max_length=255)), + ('low_balance_email_sent_at', models.DateTimeField(blank=True, null=True)), + ('is_personal', models.BooleanField(default=False)), ('created_at', models.DateTimeField(auto_now_add=True)), ('updated_at', models.DateTimeField(auto_now=True)), ('created_by', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='app_users.appuser')), + ('subscription', models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='workspace', to='payments.subscription')), ], ), migrations.CreateModel( - name='OrgInvitation', + name='WorkspaceInvitation', fields=[ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('deleted', models.DateTimeField(db_index=True, editable=False, null=True)), @@ -40,20 +47,20 @@ class Migration(migrations.Migration): ('status', models.IntegerField(choices=[(1, 'Pending'), (2, 'Accepted'), (3, 'Rejected'), (4, 'Canceled'), (5, 'Expired')], default=1)), ('auto_accepted', models.BooleanField(default=False)), ('role', models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3)), - ('last_email_sent_at', models.DateTimeField(blank=True, default=False, null=True)), - ('status_changed_at', models.DateTimeField(blank=True, default=False, null=True)), + ('last_email_sent_at', models.DateTimeField(blank=True, default=None, null=True)), + ('status_changed_at', models.DateTimeField(blank=True, default=None, null=True)), ('created_at', models.DateTimeField(auto_now_add=True)), ('updated_at', models.DateTimeField(auto_now=True)), ('inviter', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='sent_invitations', to='app_users.appuser')), - ('org', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='invitations', to='orgs.org')), - ('status_changed_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='app_users.appuser')), + ('status_changed_by', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='received_invitations', to='app_users.appuser')), + ('workspace', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='invitations', to='workspaces.workspace')), ], options={ 'abstract': False, }, ), migrations.CreateModel( - name='OrgMembership', + name='WorkspaceMembership', fields=[ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('deleted', models.DateTimeField(db_index=True, editable=False, null=True)), @@ -61,21 +68,21 @@ class Migration(migrations.Migration): ('role', models.IntegerField(choices=[(1, 'Owner'), (2, 'Admin'), (3, 'Member')], default=3)), ('created_at', models.DateTimeField(auto_now_add=True)), ('updated_at', models.DateTimeField(auto_now=True)), - ('invitation', models.OneToOneField(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='membership', to='orgs.orginvitation')), - ('org', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='memberships', to='orgs.org')), - ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='org_memberships', to='app_users.appuser')), + ('invitation', models.OneToOneField(blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='membership', to='workspaces.workspaceinvitation')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='workspace_memberships', to='app_users.appuser')), + ('workspace', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='memberships', to='workspaces.workspace')), ], - options={ - 'unique_together': {('org', 'user', 'deleted')}, - }, ), - migrations.AddField( - model_name='org', - name='members', - field=models.ManyToManyField(related_name='orgs', through='orgs.OrgMembership', to='app_users.appuser'), + migrations.AddConstraint( + model_name='workspacemembership', + constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('workspace', 'user'), name='unique_workspace_user'), + ), + migrations.AddConstraint( + model_name='workspace', + constraint=models.UniqueConstraint(condition=models.Q(('deleted__isnull', True)), fields=('domain_name',), name='unique_domain_name_when_not_deleted', violation_error_message='This domain name is already in use by another team. Contact Gooey.AI Support if you think this is a mistake.'), ), - migrations.AlterUniqueTogether( - name='org', - unique_together={('domain_name', 'deleted')}, + migrations.AddConstraint( + model_name='workspace', + constraint=models.UniqueConstraint(models.F('created_by'), condition=models.Q(('deleted__isnull', True), ('is_personal', True)), name='unique_personal_workspace_per_user'), ), ] diff --git a/workspaces/migrations/0002_alter_workspace_logo.py b/workspaces/migrations/0002_alter_workspace_logo.py new file mode 100644 index 000000000..c28aba367 --- /dev/null +++ b/workspaces/migrations/0002_alter_workspace_logo.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.7 on 2024-09-03 12:59 + +import bots.custom_fields +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('workspaces', '0001_initial'), + ] + + operations = [ + migrations.AlterField( + model_name='workspace', + name='logo', + field=bots.custom_fields.CustomURLField(blank=True, max_length=2048, null=True), + ), + ] diff --git a/orgs/migrations/__init__.py b/workspaces/migrations/__init__.py similarity index 100% rename from orgs/migrations/__init__.py rename to workspaces/migrations/__init__.py diff --git a/orgs/models.py b/workspaces/models.py similarity index 69% rename from orgs/models.py rename to workspaces/models.py index 4c9b2c8e2..56020aef4 100644 --- a/orgs/models.py +++ b/workspaces/models.py @@ -15,80 +15,80 @@ from safedelete.managers import SafeDeleteManager from safedelete.models import SafeDeleteModel, SOFT_DELETE_CASCADE +from bots.custom_fields import CustomURLField from daras_ai_v2 import settings from daras_ai_v2.fastapi_tricks import get_app_route_url from daras_ai_v2.crypto import get_random_doc_id from gooeysite.bg_db_conn import db_middleware -from orgs.tasks import send_auto_accepted_email, send_invitation_email +from .tasks import send_auto_accepted_email, send_invitation_email if typing.TYPE_CHECKING: from app_users.models import AppUser, AppUserTransaction -ORG_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$") +WORKSPACE_DOMAIN_NAME_RE = re.compile(r"^[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]+$") -def validate_org_domain_name(value): +def validate_workspace_domain_name(value): from handles.models import COMMON_EMAIL_DOMAINS - if not ORG_DOMAIN_NAME_RE.fullmatch(value): + if not WORKSPACE_DOMAIN_NAME_RE.fullmatch(value): raise ValidationError("Invalid domain name") if value in COMMON_EMAIL_DOMAINS: raise ValidationError("This domain name is reserved") -class OrgRole(models.IntegerChoices): +class WorkspaceRole(models.IntegerChoices): OWNER = 1 ADMIN = 2 MEMBER = 3 -class OrgManager(SafeDeleteManager): - def create_org( +class WorkspaceManager(SafeDeleteManager): + def create_workspace( self, *, created_by: "AppUser", - org_id: str | None = None, balance: int | None = None, **kwargs, - ) -> Org: - org = self.model( - org_id=org_id or get_random_doc_id(), + ) -> Workspace: + workspace = self.model( created_by=created_by, balance=balance, **kwargs, ) if ( balance is None - and Org.all_objects.filter(created_by=created_by).count() <= 1 + and Workspace.all_objects.filter(created_by=created_by).count() <= 1 ): # set some balance for first team created by user - # Org.all_objects is important to include deleted orgs - org.balance = settings.FIRST_ORG_FREE_CREDITS + # Workspace.all_objects is important to include deleted workspaces + workspace.balance = settings.FIRST_WORKSPACE_FREE_CREDITS - org.full_clean() - org.save() - org.add_member( + workspace.full_clean() + workspace.save() + workspace.add_member( created_by, - role=OrgRole.OWNER, + role=WorkspaceRole.OWNER, ) - return org + return workspace - def get_or_create_from_org_id(self, org_id: str) -> tuple[Org, bool]: - from app_users.models import AppUser + def get_or_create_from_uid(self, uid: str) -> tuple[Workspace, bool]: + workspace = Workspace.objects.filter( + is_personal=True, created_by__uid=uid + ).first() + if workspace: + return workspace, False - try: - return self.get(org_id=org_id), False - except self.model.DoesNotExist: - user = AppUser.objects.get_or_create_from_uid(org_id)[0] - return self.migrate_from_appuser(user), True + user, _ = AppUser.objects.get_or_create_from_uid(uid) + workspace = self.migrate_from_appuser(user) + return workspace, True - def migrate_from_appuser(self, user: "AppUser") -> Org: - return self.create_org( + def migrate_from_appuser(self, user: "AppUser") -> Workspace: + return self.create_workspace( name=f"{user.first_name()}'s Personal Workspace", - org_id=user.uid or get_random_doc_id(), created_by=user, is_personal=True, balance=user.balance, @@ -108,24 +108,22 @@ def get_dollars_spent_this_month(self) -> float: return (cents_spent or 0) / 100 -class Org(SafeDeleteModel): +class Workspace(SafeDeleteModel): _safedelete_policy = SOFT_DELETE_CASCADE - org_id = models.CharField(max_length=100, null=True, blank=True, unique=True) - name = models.CharField(max_length=100) created_by = models.ForeignKey( "app_users.appuser", on_delete=models.CASCADE, ) - logo = models.URLField(null=True, blank=True) + logo = CustomURLField(null=True, blank=True) domain_name = models.CharField( max_length=30, blank=True, null=True, validators=[ - validate_org_domain_name, + validate_workspace_domain_name, ], ) @@ -136,7 +134,7 @@ class Org(SafeDeleteModel): subscription = models.OneToOneField( "payments.Subscription", on_delete=models.SET_NULL, - related_name="org", + related_name="workspace", null=True, blank=True, ) @@ -147,7 +145,7 @@ class Org(SafeDeleteModel): created_at = models.DateTimeField(auto_now_add=True) updated_at = models.DateTimeField(auto_now=True) - objects = OrgManager() + objects = WorkspaceManager() class Meta: constraints = [ @@ -160,7 +158,7 @@ class Meta: models.UniqueConstraint( "created_by", condition=Q(deleted__isnull=True, is_personal=True), - name="unique_personal_org_per_user", + name="unique_personal_workspace_per_user", ), ] @@ -174,10 +172,13 @@ def get_slug(self): return slugify(self.name) def add_member( - self, user: "AppUser", role: OrgRole, invitation: "OrgInvitation | None" = None + self, + user: "AppUser", + role: WorkspaceRole, + invitation: "WorkspaceInvitation | None" = None, ): - OrgMembership( - org=self, + WorkspaceMembership( + workspace=self, user=user, role=role, invitation=invitation, @@ -188,9 +189,9 @@ def invite_user( *, invitee_email: str, inviter: "AppUser", - role: OrgRole, + role: WorkspaceRole, auto_accept: bool = False, - ) -> "OrgInvitation": + ) -> "WorkspaceInvitation": """ auto_accept: If True, the user will be automatically added if they have an account """ @@ -198,15 +199,17 @@ def invite_user( if member.user.email == invitee_email: raise ValidationError(f"{member.user} is already a member of this team") - for invitation in self.invitations.filter(status=OrgInvitation.Status.PENDING): + for invitation in self.invitations.filter( + status=WorkspaceInvitation.Status.PENDING + ): if invitation.invitee_email == invitee_email: raise ValidationError( f"{invitee_email} was already invited to this team" ) - invitation = OrgInvitation( + invitation = WorkspaceInvitation( invite_id=get_random_doc_id(), - org=self, + workspace=self, invitee_email=invitee_email, inviter=inviter, role=role, @@ -225,8 +228,8 @@ def invite_user( return invitation - def get_owners(self) -> list[OrgMembership]: - return self.memberships.filter(role=OrgRole.OWNER) + def get_owners(self) -> models.QuerySet[WorkspaceMembership]: + return self.memberships.filter(role=WorkspaceRole.OWNER) @db_middleware @transaction.atomic @@ -255,51 +258,58 @@ def add_balance( # It won't lock this row for reads, and multiple threads can update the same row leading incorrect balance # # Also we're not using .update() here because it won't give back the updated end balance - org: Org = Org.objects.select_for_update().get(pk=self.pk) - org.balance += amount - org.save(update_fields=["balance"]) - kwargs.setdefault("plan", org.subscription and org.subscription.plan) + workspace: Workspace = Workspace.objects.select_for_update().get(pk=self.pk) + workspace.balance += amount + workspace.save(update_fields=["balance"]) + kwargs.setdefault( + "plan", workspace.subscription and workspace.subscription.plan + ) return AppUserTransaction.objects.create( - org=org, - user=org.created_by if org.is_personal else None, + workspace=workspace, + user=workspace.created_by if workspace.is_personal else None, invoice_id=invoice_id, amount=amount, - end_balance=org.balance, + end_balance=workspace.balance, **kwargs, ) def get_or_create_stripe_customer(self) -> stripe.Customer: customer = self.search_stripe_customer() if not customer: + metadata = {"workspace_id": self.id} + if self.is_personal: + metadata["uid"] = self.created_by.uid + customer = stripe.Customer.create( name=self.created_by.display_name, email=self.created_by.email, phone=self.created_by.phone_number, - metadata={"uid": self.org_id, "org_id": self.org_id, "id": self.pk}, + metadata=metadata, ) self.stripe_customer_id = customer.id self.save() return customer def search_stripe_customer(self) -> stripe.Customer | None: - if not self.org_id: - return None if self.stripe_customer_id: try: return stripe.Customer.retrieve(self.stripe_customer_id) - except stripe.error.InvalidRequestError as e: + except stripe.InvalidRequestError as e: if e.http_status != 404: raise + try: customer = stripe.Customer.search( - query=f'metadata["uid"]:"{self.org_id}"' + query=f'metadata["workspace_id"]:"{self.id}"' ).data[0] except IndexError: - return None - else: - self.stripe_customer_id = customer.id - self.save() - return customer + customer = self.is_personal and self.created_by.search_stripe_customer() + if not customer: + return None + + self.stripe_customer_id = customer.id + self.save() + return customer def get_dollars_spent_this_month(self) -> float: today = timezone.now() @@ -311,13 +321,17 @@ def get_dollars_spent_this_month(self) -> float: return (cents_spent or 0) / 100 -class OrgMembership(SafeDeleteModel): - org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="memberships") +class WorkspaceMembership(SafeDeleteModel): + workspace = models.ForeignKey( + Workspace, on_delete=models.CASCADE, related_name="memberships" + ) user = models.ForeignKey( - "app_users.AppUser", on_delete=models.CASCADE, related_name="org_memberships" + "app_users.AppUser", + on_delete=models.CASCADE, + related_name="workspace_memberships", ) invitation = models.OneToOneField( - "OrgInvitation", + "WorkspaceInvitation", on_delete=models.SET_NULL, blank=True, null=True, @@ -325,7 +339,9 @@ class OrgMembership(SafeDeleteModel): related_name="membership", ) - role = models.IntegerField(choices=OrgRole.choices, default=OrgRole.MEMBER) + role = models.IntegerField( + choices=WorkspaceRole.choices, default=WorkspaceRole.MEMBER + ) created_at = models.DateTimeField(auto_now_add=True) # same as joining date updated_at = models.DateTimeField(auto_now=True) @@ -335,45 +351,45 @@ class OrgMembership(SafeDeleteModel): class Meta: constraints = [ models.UniqueConstraint( - fields=["org", "user"], + fields=["workspace", "user"], condition=Q(deleted__isnull=True), - name="unique_org_user", + name="unique_workspace_user", ) ] def __str__(self): - return f"{self.get_role_display()} - {self.user} ({self.org})" + return f"{self.get_role_display()} - {self.user} ({self.workspace})" - def can_edit_org_metadata(self): - return self.role in (OrgRole.OWNER, OrgRole.ADMIN) + def can_edit_workspace_metadata(self): + return self.role in (WorkspaceRole.OWNER, WorkspaceRole.ADMIN) - def can_delete_org(self): - return self.role == OrgRole.OWNER + def can_delete_workspace(self): + return self.role == WorkspaceRole.OWNER - def has_higher_role_than(self, other: "OrgMembership"): + def has_higher_role_than(self, other: "WorkspaceMembership"): # creator > owner > admin > member match other.role: - case OrgRole.OWNER: - return self.org.created_by == OrgRole.OWNER - case OrgRole.ADMIN: - return self.role == OrgRole.OWNER - case OrgRole.MEMBER: - return self.role in (OrgRole.OWNER, OrgRole.ADMIN) - - def can_change_role(self, other: "OrgMembership"): + case WorkspaceRole.OWNER: + return self.workspace.created_by == WorkspaceRole.OWNER + case WorkspaceRole.ADMIN: + return self.role == WorkspaceRole.OWNER + case WorkspaceRole.MEMBER: + return self.role in (WorkspaceRole.OWNER, WorkspaceRole.ADMIN) + + def can_change_role(self, other: "WorkspaceMembership"): return self.has_higher_role_than(other) - def can_kick(self, other: "OrgMembership"): + def can_kick(self, other: "WorkspaceMembership"): return self.has_higher_role_than(other) def can_transfer_ownership(self): - return self.role == OrgRole.OWNER + return self.role == WorkspaceRole.OWNER def can_invite(self): - return self.role in (OrgRole.OWNER, OrgRole.ADMIN) + return self.role in (WorkspaceRole.OWNER, WorkspaceRole.ADMIN) -class OrgInvitation(SafeDeleteModel): +class WorkspaceInvitation(SafeDeleteModel): class Status(models.IntegerChoices): PENDING = 1 ACCEPTED = 2 @@ -384,14 +400,18 @@ class Status(models.IntegerChoices): invite_id = models.CharField(max_length=100, unique=True) invitee_email = models.EmailField() - org = models.ForeignKey(Org, on_delete=models.CASCADE, related_name="invitations") + workspace = models.ForeignKey( + Workspace, on_delete=models.CASCADE, related_name="invitations" + ) inviter = models.ForeignKey( "app_users.AppUser", on_delete=models.CASCADE, related_name="sent_invitations" ) status = models.IntegerField(choices=Status.choices, default=Status.PENDING) auto_accepted = models.BooleanField(default=False) - role = models.IntegerField(choices=OrgRole.choices, default=OrgRole.MEMBER) + role = models.IntegerField( + choices=WorkspaceRole.choices, default=WorkspaceRole.MEMBER + ) last_email_sent_at = models.DateTimeField(null=True, blank=True, default=None) status_changed_at = models.DateTimeField(null=True, blank=True, default=None) @@ -407,12 +427,12 @@ class Status(models.IntegerChoices): updated_at = models.DateTimeField(auto_now=True) def __str__(self): - return f"{self.invitee_email} - {self.org} ({self.get_status_display()})" + return f"{self.invitee_email} - {self.workspace} ({self.get_status_display()})" def has_expired(self): return self.status == self.Status.EXPIRED or ( timezone.now() - (self.last_email_sent_at or self.created_at) - > timedelta(days=settings.ORG_INVITATION_EXPIRY_DAYS) + > timedelta(days=settings.WORKSPACE_INVITATION_EXPIRY_DAYS) ) def auto_accept(self): @@ -431,7 +451,9 @@ def auto_accept(self): self.accept(invitee, auto_accepted=True) if self.auto_accepted: - logger.info(f"User {invitee} auto-accepted invitation to org {self.org}") + logger.info( + f"User {invitee} auto-accepted invitation to workspace {self.workspace}" + ) send_auto_accepted_email.delay(self.pk) def get_url(self): @@ -439,7 +461,10 @@ def get_url(self): return get_app_route_url( invitation_route, - path_params={"invite_id": self.invite_id, "org_slug": self.org.get_slug()}, + path_params={ + "invite_id": self.invite_id, + "workspace_slug": self.workspace.get_slug(), + }, ) def send_email(self): @@ -469,7 +494,7 @@ def accept(self, user: "AppUser", *, auto_accepted: bool = False): "This invitation has expired. Please ask your team admin to send a new one." ) - if self.org.memberships.filter(user_id=user.pk).exists(): + if self.workspace.memberships.filter(user_id=user.pk).exists(): raise ValidationError(f"User is already a member of this team.") self.status = self.Status.ACCEPTED @@ -480,8 +505,8 @@ def accept(self, user: "AppUser", *, auto_accepted: bool = False): self.full_clean() with transaction.atomic(): - user.org_memberships.all().delete() # delete current memberships - self.org.add_member( + user.workspace_memberships.all().delete() # delete current memberships + self.workspace.add_member( user, role=self.role, invitation=self, @@ -505,5 +530,5 @@ def can_resend_email(self): return True return timezone.now() - self.last_email_sent_at > timedelta( - seconds=settings.ORG_INVITATION_EMAIL_COOLDOWN_INTERVAL + seconds=settings.WORKSPACE_INVITATION_EMAIL_COOLDOWN_INTERVAL ) diff --git a/workspaces/signals.py b/workspaces/signals.py new file mode 100644 index 000000000..962ffe794 --- /dev/null +++ b/workspaces/signals.py @@ -0,0 +1,50 @@ +from django.db.models.signals import post_save +from django.dispatch import receiver +from loguru import logger +from safedelete.signals import post_softdelete + +from app_users.models import AppUser +from .models import Workspace, WorkspaceMembership, WorkspaceRole + + +@receiver(post_save, sender=AppUser) +def add_user_existing_workspace(instance: AppUser, **kwargs): + """ + if the domain name matches + """ + if not instance.email: + return + + email_domain = instance.email.split("@")[1] + workspace = Workspace.objects.filter(domain_name=email_domain).first() + if not workspace: + return + + if instance.received_invitations.exists(): + # user has some existing invitations + return + + workspace_owner = workspace.memberships.filter(role=WorkspaceRole.OWNER).first() + if not workspace_owner: + logger.warning( + f"Workspace {workspace} has no owner. Skipping auto-accept for user {instance}" + ) + return + + workspace.invite_user( + invitee_email=instance.email, + inviter=workspace_owner.user, + role=WorkspaceRole.MEMBER, + auto_accept=not instance.workspace_memberships.exists(), # auto-accept only if user has no existing memberships + ) + + +@receiver(post_softdelete, sender=WorkspaceMembership) +def delete_workspace_if_no_members_left(instance: WorkspaceMembership, **kwargs): + if instance.workspace.memberships.exists(): + return + + logger.info( + f"Deleting workspace {instance.workspace} because it has no members left" + ) + instance.workspace.delete() diff --git a/orgs/tasks.py b/workspaces/tasks.py similarity index 67% rename from orgs/tasks.py rename to workspaces/tasks.py index 09258c9ec..bdfd416ed 100644 --- a/orgs/tasks.py +++ b/workspaces/tasks.py @@ -10,20 +10,20 @@ @app.task def send_invitation_email(invitation_pk: int): - from orgs.models import OrgInvitation + from workspaces.models import WorkspaceInvitation - invitation = OrgInvitation.objects.get(pk=invitation_pk) + invitation = WorkspaceInvitation.objects.get(pk=invitation_pk) assert invitation.status == invitation.Status.PENDING logger.info( - f"Sending inviation email to {invitation.invitee_email} for org {invitation.org}..." + f"Sending inviation email to {invitation.invitee_email} for workspace {invitation.workspace}..." ) send_email_via_postmark( to_address=invitation.invitee_email, from_address=settings.SUPPORT_EMAIL, - subject=f"[Gooey.AI] Invitation to join {invitation.org.name}", - html_body=templates.get_template("org_invitation_email.html").render( + subject=f"[Gooey.AI] Invitation to join {invitation.workspace.name}", + html_body=templates.get_template("workspace_invitation_email.html").render( settings=settings, invitation=invitation, ), @@ -37,10 +37,10 @@ def send_invitation_email(invitation_pk: int): @app.task def send_auto_accepted_email(invitation_pk: int): - from orgs.models import OrgInvitation - from routers.account import orgs_route + from workspaces.models import WorkspaceInvitation + from routers.account import workspaces_route - invitation = OrgInvitation.objects.get(pk=invitation_pk) + invitation = WorkspaceInvitation.objects.get(pk=invitation_pk) assert invitation.auto_accepted and invitation.status == invitation.Status.ACCEPTED assert invitation.status_changed_by @@ -50,19 +50,19 @@ def send_auto_accepted_email(invitation_pk: int): return logger.info( - f"Sending auto-accepted email to {user.email} for org {invitation.org}..." + f"Sending auto-accepted email to {user.email} for workspace {invitation.workspace}..." ) send_email_via_postmark( to_address=user.email, from_address=settings.SUPPORT_EMAIL, subject=f"[Gooey.AI] You've been added to a new team!", html_body=templates.get_template( - "org_invitation_auto_accepted_email.html" + "workspace_invitation_auto_accepted_email.html" ).render( settings=settings, user=user, - org=invitation.org, - orgs_url=get_app_route_url(orgs_route), + workspace=invitation.workspace, + workspaces_url=get_app_route_url(workspaces_route), ), message_stream="outbound", ) diff --git a/orgs/tests.py b/workspaces/tests.py similarity index 100% rename from orgs/tests.py rename to workspaces/tests.py diff --git a/orgs/views.py b/workspaces/views.py similarity index 64% rename from orgs/views.py rename to workspaces/views.py index 494bac72a..75121d1cd 100644 --- a/orgs/views.py +++ b/workspaces/views.py @@ -5,43 +5,44 @@ import gooey_gui as gui from django.core.exceptions import ValidationError +from .models import Workspace, WorkspaceInvitation, WorkspaceMembership, WorkspaceRole from app_users.models import AppUser -from orgs.models import Org, OrgInvitation, OrgMembership, OrgRole from daras_ai_v2 import icons from daras_ai_v2.fastapi_tricks import get_route_path -DEFAULT_ORG_LOGO = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/74a37c52-8260-11ee-a297-02420a0001ee/gooey.ai%20-%20A%20pop%20art%20illustration%20of%20robots%20taki...y%20Liechtenstein%20mint%20colour%20is%20main%20city%20Seattle.png" +DEFAULT_WORKSPACE_LOGO = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/74a37c52-8260-11ee-a297-02420a0001ee/gooey.ai%20-%20A%20pop%20art%20illustration%20of%20robots%20taki...y%20Liechtenstein%20mint%20colour%20is%20main%20city%20Seattle.png" rounded_border = "w-100 border shadow-sm rounded py-4 px-3" -def invitation_page(user: AppUser, invitation: OrgInvitation): - from routers.account import orgs_route +def invitation_page(user: AppUser, invitation: WorkspaceInvitation): + from routers.account import workspaces_route - orgs_page_path = get_route_path(orgs_route) + workspaces_page_path = get_route_path(workspaces_route) with gui.div(className="text-center my-5"): gui.write( - f"# Invitation to join {invitation.org.name}", className="d-block mb-5" + f"# Invitation to join {invitation.workspace.name}", + className="d-block mb-5", ) - if invitation.org.memberships.filter(user=user).exists(): - # redirect to org page - raise gui.RedirectException(orgs_page_path) + if invitation.workspace.memberships.filter(user=user).exists(): + # redirect to workspace page + raise gui.RedirectException(workspaces_page_path) - if invitation.status != OrgInvitation.Status.PENDING: + if invitation.status != WorkspaceInvitation.Status.PENDING: gui.write(f"This invitation has been {invitation.get_status_display()}.") return gui.write( - f"**{format_user_name(invitation.inviter)}** has invited you to join **{invitation.org.name}**." + f"**{format_user_name(invitation.inviter)}** has invited you to join **{invitation.workspace.name}**." ) - if other_m := user.org_memberships.first(): + if other_m := user.workspace_memberships.first(): gui.caption( - f"You are currently a member of [{other_m.org.name}]({orgs_page_path}). You will be removed from that team if you accept this invitation." + f"You are currently a member of [{other_m.workspace.name}]({workspaces_page_path}). You will be removed from that team if you accept this invitation." ) accept_label = "Leave and Accept" else: @@ -56,58 +57,61 @@ def invitation_page(user: AppUser, invitation: OrgInvitation): if accept_button: invitation.accept(user=user) - raise gui.RedirectException(orgs_page_path) + raise gui.RedirectException(workspaces_page_path) if reject_button: invitation.reject(user=user) -def orgs_page(user: AppUser): - memberships = user.org_memberships.filter() +def workspaces_page(user: AppUser): + memberships = user.workspace_memberships.filter() if not memberships: - gui.write("*You're not part of an organization yet... Create one?*") + gui.write("*You're not part of an workspaceanization yet... Create one?*") - render_org_creation_view(user) + render_workspace_creation_view(user) else: - # only support one org for now - render_org_by_membership(memberships.first()) + # only support one workspace for now + render_workspace_by_membership(memberships.first()) -def render_org_by_membership(membership: OrgMembership): +def render_workspace_by_membership(membership: WorkspaceMembership): """ membership object has all the information we need: - - org + - workspace - current user - - current user's role in the org (and other metadata) + - current user's role in the workspace (and other metadata) """ - org = membership.org + workspace = membership.workspace current_user = membership.user with gui.div( className="d-xs-block d-sm-flex flex-row-reverse justify-content-between" ): with gui.div(className="d-flex justify-content-center align-items-center"): - if membership.can_edit_org_metadata(): - org_edit_modal = gui.Modal("Edit Org", key="edit-org-modal") - if org_edit_modal.is_open(): - with org_edit_modal.container(): - render_org_edit_view_by_membership( - membership, modal=org_edit_modal + if membership.can_edit_workspace_metadata(): + workspace_edit_modal = gui.Modal( + "Edit Workspace", key="edit-workspace-modal" + ) + if workspace_edit_modal.is_open(): + with workspace_edit_modal.container(): + render_workspace_edit_view_by_membership( + membership, modal=workspace_edit_modal ) if gui.button(f"{icons.edit} Edit", type="secondary"): - org_edit_modal.open() + workspace_edit_modal.open() with gui.div(className="d-flex align-items-center"): gui.image( - org.logo or DEFAULT_ORG_LOGO, + workspace.logo or DEFAULT_WORKSPACE_LOGO, className="my-0 me-4 rounded", style={"width": "128px", "height": "128px", "object-fit": "contain"}, ) with gui.div(className="d-flex flex-column justify-content-center"): - gui.write(f"# {org.name}") - if org.domain_name: + gui.write(f"# {workspace.name}") + if workspace.domain_name: gui.write( - f"Org Domain: `@{org.domain_name}`", className="text-muted" + f"Workspace Domain: `@{workspace.domain_name}`", + className="text-muted", ) with gui.div(className="mt-4"): @@ -122,38 +126,44 @@ def render_org_by_membership(membership: OrgMembership): if invite_modal.is_open(): with invite_modal.container(): render_invite_creation_view( - org=org, inviter=current_user, modal=invite_modal + workspace=workspace, + inviter=current_user, + modal=invite_modal, ) - render_members_list(org=org, current_member=membership) + render_members_list(workspace=workspace, current_member=membership) with gui.div(className="mt-4"): - render_pending_invitations_list(org=org, current_member=membership) + render_pending_invitations_list(workspace=workspace, current_member=membership) with gui.div(className="mt-4"): - org_leave_modal = gui.Modal("Leave Org", key="leave-org-modal") - if org_leave_modal.is_open(): - with org_leave_modal.container(): - render_org_leave_view_by_membership(membership, modal=org_leave_modal) + workspace_leave_modal = gui.Modal( + "Leave Workspace", key="leave-workspace-modal" + ) + if workspace_leave_modal.is_open(): + with workspace_leave_modal.container(): + render_workspace_leave_view_by_membership( + membership, modal=workspace_leave_modal + ) with gui.div(className="text-end"): - leave_org = gui.button( + leave_workspace = gui.button( "Leave", className="btn btn-theme bg-danger border-danger text-white", ) - if leave_org: - org_leave_modal.open() + if leave_workspace: + workspace_leave_modal.open() -def render_org_creation_view(user: AppUser): - gui.write(f"# {icons.company} Create an Org", unsafe_allow_html=True) - org_fields = render_org_create_or_edit_form() +def render_workspace_creation_view(user: AppUser): + gui.write(f"# {icons.company} Create an Workspace", unsafe_allow_html=True) + workspace_fields = render_workspace_create_or_edit_form() if gui.button("Create"): try: - Org.objects.create_org( + Workspace.objects.create_workspace( created_by=user, - **org_fields, + **workspace_fields, ) except ValidationError as e: gui.write(", ".join(e.messages), className="text-danger") @@ -161,50 +171,54 @@ def render_org_creation_view(user: AppUser): gui.rerun() -def render_org_edit_view_by_membership(membership: OrgMembership, *, modal: gui.Modal): - org = membership.org - render_org_create_or_edit_form(org=org) +def render_workspace_edit_view_by_membership( + membership: WorkspaceMembership, *, modal: gui.Modal +): + workspace = membership.workspace + render_workspace_create_or_edit_form(workspace=workspace) if gui.button("Save", className="w-100", type="primary"): try: - org.full_clean() + workspace.full_clean() except ValidationError as e: # newlines in markdown gui.write(" \n".join(e.messages), className="text-danger") else: - org.save() + workspace.save() modal.close() - if membership.can_delete_org() or membership.can_transfer_ownership(): + if membership.can_delete_workspace() or membership.can_transfer_ownership(): gui.write("---") render_danger_zone_by_membership(membership) -def render_danger_zone_by_membership(membership: OrgMembership): +def render_danger_zone_by_membership(membership: WorkspaceMembership): gui.write("### Danger Zone", className="d-block my-2") - if membership.can_delete_org(): - org_deletion_modal = gui.Modal("Delete Organization", key="delete-org-modal") - if org_deletion_modal.is_open(): - with org_deletion_modal.container(): - render_org_deletion_view_by_membership( - membership, modal=org_deletion_modal + if membership.can_delete_workspace(): + workspace_deletion_modal = gui.Modal( + "Delete Workspaceanization", key="delete-workspace-modal" + ) + if workspace_deletion_modal.is_open(): + with workspace_deletion_modal.container(): + render_workspace_deletion_view_by_membership( + membership, modal=workspace_deletion_modal ) with gui.div(className="d-flex justify-content-between align-items-center"): - gui.write("Delete Organization") + gui.write("Delete Workspaceanization") if gui.button( f"{icons.delete} Delete", className="btn btn-theme py-2 bg-danger border-danger text-white", ): - org_deletion_modal.open() + workspace_deletion_modal.open() -def render_org_deletion_view_by_membership( - membership: OrgMembership, *, modal: gui.Modal +def render_workspace_deletion_view_by_membership( + membership: WorkspaceMembership, *, modal: gui.Modal ): gui.write( - f"Are you sure you want to delete **{membership.org.name}**? This action is irreversible." + f"Are you sure you want to delete **{membership.workspace.name}**? This action is irreversible." ) with gui.div(className="d-flex"): @@ -216,34 +230,37 @@ def render_org_deletion_view_by_membership( if gui.button( "Delete", className="btn btn-theme bg-danger border-danger text-light w-50" ): - membership.org.delete() + membership.workspace.delete() modal.close() -def render_org_leave_view_by_membership( - current_member: OrgMembership, *, modal: gui.Modal +def render_workspace_leave_view_by_membership( + current_member: WorkspaceMembership, *, modal: gui.Modal ): - org = current_member.org + workspace = current_member.workspace - gui.write("Are you sure you want to leave this organization?") + gui.write("Are you sure you want to leave this workspaceanization?") new_owner = None - if current_member.role == OrgRole.OWNER and org.memberships.count() == 1: + if ( + current_member.role == WorkspaceRole.OWNER + and workspace.memberships.count() == 1 + ): gui.caption( "You are the only member. You will lose access to this team if you leave." ) elif ( - current_member.role == OrgRole.OWNER - and org.memberships.filter(role=OrgRole.OWNER).count() == 1 + current_member.role == WorkspaceRole.OWNER + and workspace.memberships.filter(role=WorkspaceRole.OWNER).count() == 1 ): members_by_uid = { m.user.uid: m - for m in org.memberships.all().select_related("user") + for m in workspace.memberships.all().select_related("user") if m != current_member } gui.caption( - "You are the only owner of this organization. Please choose another member to promote to owner." + "You are the only owner of this workspaceanization. Please choose another member to promote to owner." ) new_owner_uid = gui.selectbox( "New Owner", @@ -262,13 +279,13 @@ def render_org_leave_view_by_membership( "Leave", className="btn btn-theme bg-danger border-danger text-light w-50" ): if new_owner: - new_owner.role = OrgRole.OWNER + new_owner.role = WorkspaceRole.OWNER new_owner.save() current_member.delete() modal.close() -def render_members_list(org: Org, current_member: OrgMembership): +def render_members_list(workspace: Workspace, current_member: WorkspaceMembership): with gui.tag("table", className="table table-responsive"): with gui.tag("thead"), gui.tag("tr"): with gui.tag("th", scope="col"): @@ -281,7 +298,7 @@ def render_members_list(org: Org, current_member: OrgMembership): gui.html("") with gui.tag("tbody"): - for m in org.memberships.all().order_by("created_at"): + for m in workspace.memberships.all().order_by("created_at"): with gui.tag("tr"): with gui.tag("td"): name = format_user_name( @@ -300,9 +317,11 @@ def render_members_list(org: Org, current_member: OrgMembership): render_membership_actions(m, current_member=current_member) -def render_membership_actions(m: OrgMembership, current_member: OrgMembership): +def render_membership_actions( + m: WorkspaceMembership, current_member: WorkspaceMembership +): if current_member.can_change_role(m): - if m.role == OrgRole.MEMBER: + if m.role == WorkspaceRole.MEMBER: modal, confirmed = button_with_confirmation_modal( f"{icons.admin} Make Admin", key=f"promote-member-{m.pk}", @@ -312,10 +331,10 @@ def render_membership_actions(m: OrgMembership, current_member: OrgMembership): modal_key=f"promote-member-{m.pk}-modal", ) if confirmed: - m.role = OrgRole.ADMIN + m.role = WorkspaceRole.ADMIN m.save() modal.close() - elif m.role == OrgRole.ADMIN: + elif m.role == WorkspaceRole.ADMIN: modal, confirmed = button_with_confirmation_modal( f"{icons.remove_user} Revoke Admin", key=f"demote-member-{m.pk}", @@ -325,7 +344,7 @@ def render_membership_actions(m: OrgMembership, current_member: OrgMembership): modal_key=f"demote-member-{m.pk}-modal", ) if confirmed: - m.role = OrgRole.MEMBER + m.role = WorkspaceRole.MEMBER m.save() modal.close() @@ -334,7 +353,7 @@ def render_membership_actions(m: OrgMembership, current_member: OrgMembership): f"{icons.remove_user} Remove", key=f"remove-member-{m.pk}", unsafe_allow_html=True, - confirmation_text=f"Are you sure you want to remove **{format_user_name(m.user)}** from **{m.org.name}**?", + confirmation_text=f"Are you sure you want to remove **{format_user_name(m.user)}** from **{m.workspace.name}**?", modal_title="Remove Member", modal_key=f"remove-member-{m.pk}-modal", className="bg-danger border-danger text-light", @@ -382,8 +401,12 @@ def button_with_confirmation_modal( return modal, False -def render_pending_invitations_list(org: Org, *, current_member: OrgMembership): - pending_invitations = org.invitations.filter(status=OrgInvitation.Status.PENDING) +def render_pending_invitations_list( + workspace: Workspace, *, current_member: WorkspaceMembership +): + pending_invitations = workspace.invitations.filter( + status=WorkspaceInvitation.Status.PENDING + ) if not pending_invitations: return @@ -419,7 +442,9 @@ def render_pending_invitations_list(org: Org, *, current_member: OrgMembership): render_invitation_actions(invite, current_member=current_member) -def render_invitation_actions(invitation: OrgInvitation, current_member: OrgMembership): +def render_invitation_actions( + invitation: WorkspaceInvitation, current_member: WorkspaceMembership +): if current_member.can_invite() and invitation.can_resend_email(): modal, confirmed = button_with_confirmation_modal( f"{icons.email} Resend", @@ -453,20 +478,23 @@ def render_invitation_actions(invitation: OrgInvitation, current_member: OrgMemb modal.close() -def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal): +def render_invite_creation_view( + workspace: Workspace, inviter: AppUser, modal: gui.Modal +): email = gui.text_input("Email") - if org.domain_name: + if workspace.domain_name: gui.caption( - f"Users with `@{org.domain_name}` email will be added automatically." + f"Users with `@{workspace.domain_name}` email will be added automatically." ) if gui.button(f"{icons.add_user} Invite", type="primary", unsafe_allow_html=True): try: - org.invite_user( + workspace.invite_user( invitee_email=email, inviter=inviter, - role=OrgRole.MEMBER, - auto_accept=org.domain_name.lower() == email.split("@")[1].lower(), + role=WorkspaceRole.MEMBER, + auto_accept=workspace.domain_name.lower() + == email.split("@")[1].lower(), ) except ValidationError as e: gui.write(", ".join(e.messages), className="text-danger") @@ -474,24 +502,28 @@ def render_invite_creation_view(org: Org, inviter: AppUser, modal: gui.Modal): modal.close() -def render_org_create_or_edit_form(org: Org | None = None) -> AttrDict | Org: - org_proxy = org or AttrDict() +def render_workspace_create_or_edit_form( + workspace: Workspace | None = None, +) -> AttrDict | Workspace: + workspace_proxy = workspace or AttrDict() - org_proxy.name = gui.text_input("Team Name", value=org and org.name or "") - org_proxy.logo = gui.file_uploader( - "Logo", accept=["image/*"], value=org and org.logo or "" + workspace_proxy.name = gui.text_input( + "Team Name", value=workspace and workspace.name or "" + ) + workspace_proxy.logo = gui.file_uploader( + "Logo", accept=["image/*"], value=workspace and workspace.logo or "" ) - org_proxy.domain_name = gui.text_input( + workspace_proxy.domain_name = gui.text_input( "Domain Name (Optional)", placeholder="e.g. gooey.ai", - value=org and org.domain_name or "", + value=workspace and workspace.domain_name or "", ) - if org_proxy.domain_name: + if workspace_proxy.domain_name: gui.caption( - f"Invite any user with `@{org_proxy.domain_name}` email to this organization." + f"Invite any user with `@{workspace_proxy.domain_name}` email to this workspaceanization." ) - return org_proxy + return workspace_proxy def format_user_name(user: AppUser, current_user: AppUser | None = None): From 555cfafc29cf2e06378b695a0f07b9345cb3f7cc Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 4 Sep 2024 18:43:37 +0530 Subject: [PATCH 070/110] fix: /v1/balance API should return balance from personal workspace --- routers/api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/routers/api.py b/routers/api.py index 5d2b4e42d..59d3e95b7 100644 --- a/routers/api.py +++ b/routers/api.py @@ -434,7 +434,8 @@ class BalanceResponse(BaseModel): @app.get("/v1/balance/", response_model=BalanceResponse, tags=["Misc"]) def get_balance(user: AppUser = Depends(api_auth_header)): - return BalanceResponse(balance=user.balance) + workspace, _ = user.get_or_create_personal_workspace() + return BalanceResponse(balance=workspace.balance) @app.get("/status") From 2edea041e84993328c95bb055aba5a7db367cd94 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 4 Sep 2024 19:01:41 +0530 Subject: [PATCH 071/110] remove useless debug logging --- payments/webhooks.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/payments/webhooks.py b/payments/webhooks.py index 2c1820065..bff97390f 100644 --- a/payments/webhooks.py +++ b/payments/webhooks.py @@ -194,15 +194,12 @@ def handle_subscription_cancelled(cls, uid: str): @classmethod def handle_invoice_failed(cls, uid: str, data: dict): - logger.info(f"Invoice failed: {data}") - if stripe.Charge.list(payment_intent=data["payment_intent"], limit=1).has_more: # we must have already sent an invoice for this to the user. so we should just ignore this event logger.info("Charge already exists for this payment intent") return if data.get("metadata", {}).get("auto_recharge"): - logger.info("auto recharge failed... sending invoice email") send_payment_failed_email_with_invoice.delay( uid=uid, invoice_url=data["hosted_invoice_url"], @@ -210,17 +207,12 @@ def handle_invoice_failed(cls, uid: str, data: dict): subject="Payment failure on your Gooey.AI auto-recharge", ) elif data.get("subscription_details", {}): - print("subscription failed") send_payment_failed_email_with_invoice.delay( uid=uid, invoice_url=data["hosted_invoice_url"], dollar_amt=data["amount_due"] / 100, subject="Payment failure on your Gooey.AI subscription", ) - else: - print("not auto recharge or subscription") - print(f"{data.get('metadata')=}") - return def add_balance_for_payment( From 8d29766dd969df4b0b837b1de4d910667f6651f3 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 4 Sep 2024 19:05:37 +0530 Subject: [PATCH 072/110] cleanup: remove base_email.html template --- templates/base_email.html | 22 -------- .../off_session_payment_failed_email.html | 51 ++++++++++++------- 2 files changed, 32 insertions(+), 41 deletions(-) delete mode 100644 templates/base_email.html diff --git a/templates/base_email.html b/templates/base_email.html deleted file mode 100644 index 63ab8b012..000000000 --- a/templates/base_email.html +++ /dev/null @@ -1,22 +0,0 @@ - - - - - - - - {% block title %}{% endblock title %} - - {% block head %}{% endblock head %} - - - - - -

- {% block content %}{% endblock content %} -
- - - - diff --git a/templates/off_session_payment_failed_email.html b/templates/off_session_payment_failed_email.html index c8624fec5..6eaaba7d2 100644 --- a/templates/off_session_payment_failed_email.html +++ b/templates/off_session_payment_failed_email.html @@ -1,25 +1,38 @@ -{% extends 'base_email.html' %} + + -{% block title %}Payment failed{% endblock title %} + + + -{% block content %} -

Hi {{ user.first_name() }},

+ + -

We attempted to process your payment for ${{ dollar_amt }} but your payment method was declined.

+ +
-

- Please make a payment on Gooey.AI for continued service or update - your payment method on your account. -

+

Hi {{ user.first_name() }},

-

- - - -

+

We attempted to process your payment for ${{ dollar_amt }} but your payment method was declined.

-

- Cheers,
- The Gooey.AI team -

-{% endblock content %} +

+ Please make a payment on Gooey.AI for continued service or update + your payment method on your account. +

+ +

+ + + +

+ +

+ Cheers,
+ The Gooey.AI team +

+ +
+ + + + From fbb022ee21b08f091a672ad8661a58f4ebdeeb2f Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 4 Sep 2024 19:06:30 +0530 Subject: [PATCH 073/110] cleanup: rename off_session_payment_failed_email.html -> auto_payment_failed_email.html more obvious --- payments/tasks.py | 4 +--- ...yment_failed_email.html => auto_payment_failed_email.html} | 0 2 files changed, 1 insertion(+), 3 deletions(-) rename templates/{off_session_payment_failed_email.html => auto_payment_failed_email.html} (100%) diff --git a/payments/tasks.py b/payments/tasks.py index f0ee87fce..4a602c3ba 100644 --- a/payments/tasks.py +++ b/payments/tasks.py @@ -57,9 +57,7 @@ def send_payment_failed_email_with_invoice( from_address=settings.PAYMENT_EMAIL, to_address=user.email, subject=subject, - html_body=templates.get_template( - "off_session_payment_failed_email.html" - ).render( + html_body=templates.get_template("auto_payment_failed_email.html").render( user=user, dollar_amt=f"{dollar_amt:.2f}", invoice_url=invoice_url, diff --git a/templates/off_session_payment_failed_email.html b/templates/auto_payment_failed_email.html similarity index 100% rename from templates/off_session_payment_failed_email.html rename to templates/auto_payment_failed_email.html From 6cab6e1b76c9eadb481402e5ee4f64a175bf29b6 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 4 Sep 2024 19:09:10 +0530 Subject: [PATCH 074/110] fix type error with import --- celeryapp/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index d3a54549b..e9a843b0e 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -161,7 +161,7 @@ def err_msg_for_exc(e: Exception): return f"{type(e).__name__}: {e}" -def run_low_balance_email_check(workspace: Workspace): +def run_low_balance_email_check(workspace: "Workspace"): # don't send email if feature is disabled if not settings.LOW_BALANCE_EMAIL_ENABLED: return From 1250368e487c91244763310526fac2901519e128 Mon Sep 17 00:00:00 2001 From: Kaustubh Maske Patil <37668193+nikochiko@users.noreply.github.com> Date: Wed, 4 Sep 2024 20:12:23 +0530 Subject: [PATCH 075/110] cleanup: remove unused GoogleLLM.py --- recipes/GoogleLLM.py | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 recipes/GoogleLLM.py diff --git a/recipes/GoogleLLM.py b/recipes/GoogleLLM.py deleted file mode 100644 index 5779f8b5a..000000000 --- a/recipes/GoogleLLM.py +++ /dev/null @@ -1,2 +0,0 @@ -class GoogleLLM: - pass From b9e0db149d54724b8d164747a28ae54ce717a2d8 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 4 Sep 2024 20:17:08 +0530 Subject: [PATCH 076/110] Refactor Lipsync cost calculation and duration handling - Show saved price from db instead of calculating price everytime - Rename `seconds` to `duration_sec` to match gpu api - Remove `truncated` from api, always show the truncating note in the input side - add sadtalker model pricing - migrate from `truncate_to_seconds` -> `max_frames` for lipsync --- daras_ai_v2/base.py | 6 +- daras_ai_v2/lipsync_api.py | 18 +++-- recipes/Lipsync.py | 68 +++++++++---------- scripts/init_llm_pricing.py | 2 +- scripts/init_self_hosted_pricing.py | 8 ++- .../0019_alter_modelpricing_model_name.py | 18 +++++ usage_costs/models.py | 1 + 7 files changed, 73 insertions(+), 48 deletions(-) create mode 100644 usage_costs/migrations/0019_alter_modelpricing_model_name.py diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 2233a0803..3ab715ea2 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1328,7 +1328,11 @@ def render_submit_button(self, key="--submit-1"): def render_run_cost(self): url = self.get_credits_click_url() - run_cost = self.get_price_roundoff(gui.session_state) + sr = self.get_current_sr() + if sr.price: + run_cost = sr.price + else: + run_cost = self.get_price_roundoff(gui.session_state) ret = f'Run cost = {run_cost} credits' cost_note = self.get_cost_note() diff --git a/daras_ai_v2/lipsync_api.py b/daras_ai_v2/lipsync_api.py index 9b724d41e..7aa81cc12 100644 --- a/daras_ai_v2/lipsync_api.py +++ b/daras_ai_v2/lipsync_api.py @@ -10,8 +10,8 @@ class LipsyncModel(Enum): - Wav2Lip = "SD: Fast but low-res" - SadTalker = "HD (SadTalker): Hi-res but slow" + Wav2Lip = "SD, Low-res (~480p), Fast (Rudrabha/Wav2Lip)" + SadTalker = "HD, Hi-res (max 1080p), Slow (OpenTalker/SadTalker)" class SadTalkerSettings(BaseModel): @@ -72,7 +72,7 @@ def run_sadtalker( settings: SadTalkerSettings, face: str, audio: str, - truncate_to_seconds: float | None = None, + max_frames: int | None = None, ) -> tuple[str, float]: links, metadata = call_celery_task_outfile_with_ret( "lipsync.sadtalker", @@ -80,11 +80,9 @@ def run_sadtalker( model_id="SadTalker_V0.0.2_512.safetensors", preprocess=settings.preprocess, ), - inputs=settings.dict() - | dict( - source_image=face, - driven_audio=audio, - truncate_to_seconds=truncate_to_seconds, + inputs=( + settings.dict() + | dict(source_image=face, driven_audio=audio, max_frames=max_frames) ), content_type="video/mp4", filename=f"gooey.ai lipsync.mp4", @@ -98,7 +96,7 @@ def run_wav2lip( face: str, audio: str, pads: tuple[int, int, int, int], - truncate_to_seconds: float | None = None, + max_frames: int | None = None, ) -> tuple[str, float]: try: links, metadata = call_celery_task_outfile_with_ret( @@ -114,7 +112,7 @@ def run_wav2lip( # "out_height": 480, # "smooth": True, # "fps": 25, - truncate_to_seconds=truncate_to_seconds, + max_frames=max_frames, ), content_type="video/mp4", filename=f"gooey.ai lipsync.mp4", diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index e737e1d4d..742dcaad1 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -1,9 +1,9 @@ import typing +from math import ceil -import requests +import gooey_gui as gui from pydantic import BaseModel -import gooey_gui as gui from bots.models import Workflow from daras_ai_v2.base import BasePage from daras_ai_v2.enum_selector_widget import enum_selector @@ -12,9 +12,18 @@ from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.pydantic_validation import FieldHttpUrl +DEFAULT_LIPSYNC_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7fc4d302-9402-11ee-98dc-02420a0001ca/Lip%20Sync.jpg.png" + + CREDITS_PER_MINUTE = 36 -DEFAULT_LIPSYNC_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7fc4d302-9402-11ee-98dc-02420a0001ca/Lip%20Sync.jpg.png" + +def price_for_model(selected_model: str | None) -> float: + if selected_model == LipsyncModel.SadTalker.name: + multiplier = 2 + else: + multiplier = 1 + return CREDITS_PER_MINUTE * multiplier class LipsyncPage(BasePage): @@ -31,8 +40,7 @@ class RequestModel(LipsyncSettings, BasePage.RequestModel): class ResponseModel(BaseModel): output_video: FieldHttpUrl - seconds: float = 0 - truncated: bool = False + duration_sec: float | None def preview_image(self, state: dict) -> str | None: return DEFAULT_LIPSYNC_META_IMG @@ -55,6 +63,12 @@ def render_form_v2(self): """, key="input_audio", ) + if not (self.is_current_user_paying() or self.is_current_user_admin()): + gui.error( + "Input Audio longer than 10 seconds will be truncated for free users. Please [upgrade](/account) to generate long videos.", + icon="⚠️", + color="#ffe8b2", + ) enum_selector( LipsyncModel, @@ -74,17 +88,15 @@ def run(self, state: dict) -> typing.Iterator[str | None]: request = self.RequestModel.parse_obj(state) if self.is_current_user_paying() or self.is_current_user_admin(): - truncate_to_seconds = None - state["truncated"] = False + max_frames = None else: - truncate_to_seconds = 10 - state["truncated"] = True + max_frames = 250 model = LipsyncModel[request.selected_model] yield f"Running {model.value}..." match model: case LipsyncModel.Wav2Lip: - state["output_video"], state["seconds"] = run_wav2lip( + state["output_video"], state["duration_sec"] = run_wav2lip( face=request.input_face, audio=request.input_audio, pads=( @@ -93,14 +105,14 @@ def run(self, state: dict) -> typing.Iterator[str | None]: request.face_padding_left or 0, request.face_padding_right or 0, ), - truncate_to_seconds=truncate_to_seconds, + max_frames=max_frames, ) case LipsyncModel.SadTalker: - state["output_video"], state["seconds"] = run_sadtalker( + state["output_video"], state["duration_sec"] = run_sadtalker( request.sadtalker_settings, face=request.input_face, audio=request.input_audio, - truncate_to_seconds=truncate_to_seconds, + max_frames=max_frames, ) def render_example(self, state: dict): @@ -110,12 +122,6 @@ def render_example(self, state: dict): gui.video(output_video, autoplay=True, show_download_button=True) else: gui.div() - if state.get("truncated"): - st.error( - "Audio durations greater than 10 seconds will be truncated for free users. Please upgrade to process longer audio files.", - icon="⚠️", - color="orange", - ) def render_output(self): self.render_example(gui.session_state) @@ -135,22 +141,16 @@ def preview_description(self, state: dict) -> str: return "Create high-quality, realistic Lipsync animations from any audio file. Input a sample face gif/video + audio and we will automatically generate a lipsync animation that matches your audio." def get_cost_note(self) -> str | None: - multiplier = ( - 2 - if gui.session_state.get("selected_model") == LipsyncModel.SadTalker.name - else 1 - ) - return f"{CREDITS_PER_MINUTE * multiplier}/minute" + selected_model = gui.session_state.get("selected_model") + return f"{price_for_model(selected_model)}/minute" def get_raw_price(self, state: dict) -> float: - from math import ceil + try: + duration_sec = state["duration_sec"] + except KeyError: + return 1 + duration_sec = ceil(duration_sec / 5) * 5 # round up to nearest 5 seconds - seconds = self.get_duration(state) - seconds = ceil(seconds / 5) * 5 # round up to nearest 5 seconds - multiplier = ( - 2 if state.get("selected_model") == LipsyncModel.SadTalker.name else 1 - ) - return seconds * CREDITS_PER_MINUTE * multiplier / 60 + price = price_for_model(state.get("selected_model")) - def get_duration(self, state: dict) -> float: - return state.get("seconds", 0.0) + return duration_sec / 60 * price diff --git a/scripts/init_llm_pricing.py b/scripts/init_llm_pricing.py index a8be03a63..067cd0d2b 100644 --- a/scripts/init_llm_pricing.py +++ b/scripts/init_llm_pricing.py @@ -719,7 +719,7 @@ def llm_pricing_create( ), ) if created: - print(f"created {obj}") + print("created", obj) obj, created = ModelPricing.objects.get_or_create( model_id=model_id, sku=ModelSku.llm_completion, diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py index d1df62a58..68bd63512 100644 --- a/scripts/init_self_hosted_pricing.py +++ b/scripts/init_self_hosted_pricing.py @@ -31,11 +31,13 @@ def run(): add_model(model_ids[m], m.name) except KeyError: pass + add_model("wav2lip_gan.pth", "wav2lip") + add_model("SadTalker_V0.0.2_512.safetensors", "sadtalker") -def add_model(model_id, model_name): - ModelPricing.objects.get_or_create( +def add_model(model_id: str, model_name: str): + obj, created = ModelPricing.objects.get_or_create( model_id=build_queue_name("gooey-gpu", model_id), sku=ModelSku.gpu_ms, defaults=dict( @@ -48,3 +50,5 @@ def add_model(model_id, model_name): pricing_url="https://azure.microsoft.com/en-in/pricing/details/virtual-machines/linux/#pricing", ), ) + if created: + print("created", obj) diff --git a/usage_costs/migrations/0019_alter_modelpricing_model_name.py b/usage_costs/migrations/0019_alter_modelpricing_model_name.py new file mode 100644 index 000000000..15de455fe --- /dev/null +++ b/usage_costs/migrations/0019_alter_modelpricing_model_name.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2024-09-05 08:19 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('usage_costs', '0018_alter_modelpricing_model_name'), + ] + + operations = [ + migrations.AlterField( + model_name='modelpricing', + name='model_name', + field=models.CharField(choices=[('gpt_4_o', 'GPT-4o (openai)'), ('gpt_4_o_mini', 'GPT-4o-mini (openai)'), ('chatgpt_4_o', 'ChatGPT-4o (openai) 🧪'), ('gpt_4_turbo_vision', 'GPT-4 Turbo with Vision (openai)'), ('gpt_4_vision', 'GPT-4 Vision (openai) 🔻'), ('gpt_4_turbo', 'GPT-4 Turbo (openai)'), ('gpt_4', 'GPT-4 (openai)'), ('gpt_4_32k', 'GPT-4 32K (openai) 🔻'), ('gpt_3_5_turbo', 'ChatGPT (openai)'), ('gpt_3_5_turbo_16k', 'ChatGPT 16k (openai)'), ('gpt_3_5_turbo_instruct', 'GPT-3.5 Instruct (openai) 🔻'), ('llama3_70b', 'Llama 3 70b (Meta AI)'), ('llama_3_groq_70b_tool_use', 'Llama 3 Groq 70b Tool Use'), ('llama3_8b', 'Llama 3 8b (Meta AI)'), ('llama_3_groq_8b_tool_use', 'Llama 3 Groq 8b Tool Use'), ('llama2_70b_chat', 'Llama 2 70b Chat [Deprecated] (Meta AI)'), ('mixtral_8x7b_instruct_0_1', 'Mixtral 8x7b Instruct v0.1 (Mistral)'), ('gemma_2_9b_it', 'Gemma 2 9B (Google)'), ('gemma_7b_it', 'Gemma 7B (Google)'), ('gemini_1_5_flash', 'Gemini 1.5 Flash (Google)'), ('gemini_1_5_pro', 'Gemini 1.5 Pro (Google)'), ('gemini_1_pro_vision', 'Gemini 1.0 Pro Vision (Google)'), ('gemini_1_pro', 'Gemini 1.0 Pro (Google)'), ('palm2_chat', 'PaLM 2 Chat (Google)'), ('palm2_text', 'PaLM 2 Text (Google)'), ('claude_3_5_sonnet', 'Claude 3.5 Sonnet (Anthropic)'), ('claude_3_opus', 'Claude 3 Opus [L] (Anthropic)'), ('claude_3_sonnet', 'Claude 3 Sonnet [M] (Anthropic)'), ('claude_3_haiku', 'Claude 3 Haiku [S] (Anthropic)'), ('sea_lion_7b_instruct', 'SEA-LION-7B-Instruct [Deprecated] (aisingapore)'), ('llama3_8b_cpt_sea_lion_v2_instruct', 'Llama3 8B CPT SEA-LIONv2 Instruct (aisingapore)'), ('sarvam_2b', 'Sarvam 2B (sarvamai)'), ('text_davinci_003', 'GPT-3.5 Davinci-3 [Deprecated] (openai)'), ('text_davinci_002', 'GPT-3.5 Davinci-2 [Deprecated] (openai)'), ('code_davinci_002', 'Codex [Deprecated] (openai)'), ('text_curie_001', 'Curie [Deprecated] (openai)'), ('text_babbage_001', 'Babbage [Deprecated] (openai)'), ('text_ada_001', 'Ada [Deprecated] (openai)'), ('protogen_2_2', 'Protogen V2.2 (darkstorm2150)'), ('epicdream', 'epiCDream (epinikion)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'DALL·E 2 (OpenAI)'), ('dall_e_3', 'DALL·E 3 (OpenAI)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero)'), ('openjourney', 'Open Journey (PromptHero)'), ('analog_diffusion', 'Analog Diffusion (wavymulder)'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('deepfloyd_if', 'DeepFloyd IF [Deprecated] (stability.ai)'), ('dream_shaper', 'DreamShaper (Lykon)'), ('dreamlike_2', 'Dreamlike Photoreal 2.0 (dreamlike.art)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('sd_1_5', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('instruct_pix2pix', '✨ InstructPix2Pix (Tim Brooks)'), ('openjourney_2', 'Open Journey v2 beta (PromptHero) 🐢'), ('openjourney', 'Open Journey (PromptHero) 🐢'), ('analog_diffusion', 'Analog Diffusion (wavymulder) 🐢'), ('protogen_5_3', 'Protogen v5.3 (darkstorm2150) 🐢'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('rodent_diffusion_1_5', 'Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)'), ('sd_2', 'Stable Diffusion v2.1 (stability.ai)'), ('runway_ml', 'Stable Diffusion v1.5 (RunwayML)'), ('dall_e', 'Dall-E (OpenAI)'), ('jack_qiao', 'Stable Diffusion v1.4 [Deprecated] (Jack Qiao)'), ('wav2lip', 'LipSync (wav2lip)'), ('sadtalker', 'LipSync (sadtalker)')], help_text='The name of the model. Only used for Display purposes.', max_length=255), + ), + ] diff --git a/usage_costs/models.py b/usage_costs/models.py index bfea9ec61..8941d1cd7 100644 --- a/usage_costs/models.py +++ b/usage_costs/models.py @@ -67,6 +67,7 @@ def get_model_choices(): + [(model.name, model.value) for model in Img2ImgModels] + [(model.name, model.value) for model in InpaintingModels] + [("wav2lip", "LipSync (wav2lip)")] + + [("sadtalker", "LipSync (sadtalker)")] ) From 8e82a01de7fb8b2776ae82ba323f1f8481fca25b Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 4 Sep 2024 17:46:30 +0530 Subject: [PATCH 077/110] Add Dockerfile and Captain Definition for Cloudflared deployment --- scripts/deployment/cloudflared.Dockerfile | 2 ++ scripts/deployment/cloudflared.captain-defintion | 4 ++++ 2 files changed, 6 insertions(+) create mode 100644 scripts/deployment/cloudflared.Dockerfile create mode 100644 scripts/deployment/cloudflared.captain-defintion diff --git a/scripts/deployment/cloudflared.Dockerfile b/scripts/deployment/cloudflared.Dockerfile new file mode 100644 index 000000000..2ee7b60b4 --- /dev/null +++ b/scripts/deployment/cloudflared.Dockerfile @@ -0,0 +1,2 @@ +FROM cloudflare/cloudflared +CMD ["tunnel", "run"] diff --git a/scripts/deployment/cloudflared.captain-defintion b/scripts/deployment/cloudflared.captain-defintion new file mode 100644 index 000000000..01e809ea9 --- /dev/null +++ b/scripts/deployment/cloudflared.captain-defintion @@ -0,0 +1,4 @@ +{ + "schemaVersion": 2, + "dockerfilePath": "./cloudflared.Dockerfile" +} From df0ef6eef7db9ad04163c440c3dcbc217348fefe Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 5 Sep 2024 15:13:57 +0530 Subject: [PATCH 078/110] Add new hashed entries to .gitleaksignore for explore.py --- .gitleaksignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitleaksignore b/.gitleaksignore index 022ac599f..abe66ae69 100644 --- a/.gitleaksignore +++ b/.gitleaksignore @@ -1,6 +1,9 @@ 4749e3ef005e8ddc6562d1bd82a00e752a7e94e3:explore.py:aws-access-token:16 +ede582d5a859726ff04fe371e9066f9dffb393f7:explore.py:aws-access-token:16 4749e3ef005e8ddc6562d1bd82a00e752a7e94e3:explore.py:private-key:23 +ede582d5a859726ff04fe371e9066f9dffb393f7:explore.py:private-key:23 4749e3ef005e8ddc6562d1bd82a00e752a7e94e3:explore.py:generic-api-key:32 +ede582d5a859726ff04fe371e9066f9dffb393f7:explore.py:generic-api-key:32 b0c80dac8e22faafa319d5466947df8723dfaa4a:daras_ai_v2/img_model_settings_widgets.py:generic-api-key:372 8670036e722f40530dbff3e0e7573e9b5aac85c9:routers/slack.py:slack-webhook-url:73 b6ad1fc0168832711adcff07287907660f3305fb:bots/location.py:generic-api-key:12 From 878a88066a33991f0d341cc4dd8fdc06ff3b9d59 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 5 Sep 2024 15:23:43 +0530 Subject: [PATCH 079/110] Update free user video length warning to 250 frames in Lipsync and LipsyncTTS --- recipes/Lipsync.py | 2 +- recipes/LipsyncTTS.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py index 742dcaad1..6dae9320e 100644 --- a/recipes/Lipsync.py +++ b/recipes/Lipsync.py @@ -65,7 +65,7 @@ def render_form_v2(self): ) if not (self.is_current_user_paying() or self.is_current_user_admin()): gui.error( - "Input Audio longer than 10 seconds will be truncated for free users. Please [upgrade](/account) to generate long videos.", + "Output videos will be truncated to 250 frames for free users. Please [upgrade](/account) to generate long videos.", icon="⚠️", color="#ffe8b2", ) diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 79983f36e..339cee3dd 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -67,6 +67,12 @@ def render_form_v2(self): """, key="text_prompt", ) + if not (self.is_current_user_paying() or self.is_current_user_admin()): + gui.error( + "Output videos will be truncated to 250 frames for free users. Please [upgrade](/account) to generate long videos.", + icon="⚠️", + color="#ffe8b2", + ) enum_selector( LipsyncModel, From 12c4e7b03c89d3c5e3133e79de5863afd0f91ffa Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 5 Sep 2024 15:28:59 +0530 Subject: [PATCH 080/110] Add duration_sec field to LipsyncTTSPage model --- recipes/LipsyncTTS.py | 1 + 1 file changed, 1 insertion(+) diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 339cee3dd..995f5d43f 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -39,6 +39,7 @@ class ResponseModel(BaseModel): audio_url: str | None output_video: FieldHttpUrl + duration_sec: float | None def related_workflows(self) -> list: from recipes.VideoBots import VideoBotsPage From e000a4d86b43a294366269248e541794fb9aee47 Mon Sep 17 00:00:00 2001 From: anish-work Date: Thu, 5 Sep 2024 16:20:22 +0530 Subject: [PATCH 081/110] add enableConversations config --- daras_ai_v2/bot_integration_widgets.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 15d89b4d6..36cca09ec 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -426,6 +426,7 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None): enablePhotoUpload=False, autoPlayResponses=True, enableAudioMessage=True, + enableConversations=True, branding=( dict(showPoweredByGooey=True) | bi.web_config_extras.get("branding", {}) @@ -442,6 +443,9 @@ def web_widget_config(bi: BotIntegration, user: AppUser | None): config["enablePhotoUpload"] = gui.checkbox( "Allow Photo Upload", value=config["enablePhotoUpload"] ) + config["enableConversations"] = gui.checkbox( + 'Show "New Chat"', value=config["enableConversations"] + ) with scol2: config["enableAudioMessage"] = gui.checkbox( "Enable Audio Message", value=config["enableAudioMessage"] From 6c9bf391098e1fa9c03b5708cbc748d7320579a0 Mon Sep 17 00:00:00 2001 From: anish-work Date: Thu, 5 Sep 2024 18:04:28 +0530 Subject: [PATCH 082/110] fix integrations tab ui jump --- routers/root.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/routers/root.py b/routers/root.py index b3a636207..e234a5443 100644 --- a/routers/root.py +++ b/routers/root.py @@ -761,7 +761,7 @@ class RecipeTabs(TabData, Enum): route=history_route, ) integrations = TabData( - title=f'Facebook, Whatsapp, Slack, Instagram Icons Integrations', + title=f'Facebook, Whatsapp, Slack, Instagram Icons Integrations', label="Integrations", route=integrations_route, ) From 39197548d5e2169f1f89939a5168e279759ea864 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 5 Sep 2024 21:01:30 +0530 Subject: [PATCH 083/110] Add SEA-LIONv2.1 to language model and pricing scripts, deprecate SEA-LIONv2 --- daras_ai_v2/language_model.py | 10 +++++++++- scripts/init_llm_pricing.py | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index d6eb526d2..fbb744d7f 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -338,11 +338,19 @@ class LargeLanguageModels(Enum): is_deprecated=True, ) llama3_8b_cpt_sea_lion_v2_instruct = LLMSpec( - label="Llama3 8B CPT SEA-LIONv2 Instruct (aisingapore)", + label="Llama3 8B CPT SEA-LIONv2 Instruct [Deprecated] (aisingapore)", model_id="aisingapore/llama3-8b-cpt-sea-lionv2-instruct", llm_api=LLMApis.self_hosted, context_window=8192, price=1, + is_deprecated=True, + ) + llama3_8b_cpt_sea_lion_v2_1_instruct = LLMSpec( + label="Llama3 8B CPT SEA-LIONv2.1 Instruct (aisingapore)", + model_id="aisingapore/llama3-8b-cpt-sea-lionv2.1-instruct", + llm_api=LLMApis.self_hosted, + context_window=8192, + price=1, ) sarvam_2b = LLMSpec( label="Sarvam 2B (sarvamai)", diff --git a/scripts/init_llm_pricing.py b/scripts/init_llm_pricing.py index 067cd0d2b..18f69f0a4 100644 --- a/scripts/init_llm_pricing.py +++ b/scripts/init_llm_pricing.py @@ -673,7 +673,6 @@ def run(): provider=ModelProvider.aks, notes="Same as GPT-4o. Note that the actual cost of this model is in GPU Milliseconds", ) - llm_pricing_create( model_id="aisingapore/llama3-8b-cpt-sea-lionv2-instruct", model_name=LargeLanguageModels.llama3_8b_cpt_sea_lion_v2_instruct.name, @@ -683,6 +682,15 @@ def run(): provider=ModelProvider.aks, notes="Same as GPT-4o. Note that the actual cost of this model is in GPU Milliseconds", ) + llm_pricing_create( + model_id="aisingapore/llama3-8b-cpt-sea-lionv2.1-instruct", + model_name=LargeLanguageModels.llama3_8b_cpt_sea_lion_v2_1_instruct.name, + unit_cost_input=5, + unit_cost_output=15, + unit_quantity=10**6, + provider=ModelProvider.aks, + notes="Same as GPT-4o. Note that the actual cost of this model is in GPU Milliseconds", + ) llm_pricing_create( model_id="sarvamai/sarvam-2b-v0.5", From 670e64ab98ea4f47656368cf58cbf3c57c1e8b8d Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Thu, 29 Aug 2024 18:53:17 +0530 Subject: [PATCH 084/110] Refactor `BasePage` methods to consolidate `SavedRun` and `PublishedRun` retrieval logic into `get_sr_pr` Remove usage of global gui.get_query_params --- bots/admin.py | 5 +- bots/models.py | 12 +- celeryapp/tasks.py | 23 +- conftest.py | 4 +- daras_ai_v2/base.py | 344 ++++++++------------- daras_ai_v2/bot_integration_widgets.py | 10 +- daras_ai_v2/bots.py | 2 +- daras_ai_v2/doc_search_settings_widgets.py | 2 +- daras_ai_v2/meta_content.py | 5 +- daras_ai_v2/safety_checker.py | 2 +- daras_ai_v2/workflow_url_input.py | 4 +- explore.py | 2 +- recipes/DocSearch.py | 4 +- recipes/Functions.py | 8 +- recipes/GoogleGPT.py | 2 +- recipes/VideoBots.py | 11 +- recipes/VideoBotsStats.py | 7 +- routers/api.py | 21 +- routers/bots_api.py | 2 +- routers/root.py | 21 +- routers/twilio_api.py | 2 +- tests/test_apis.py | 4 +- tests/test_pricing.py | 26 +- url_shortener/models.py | 13 +- usage_costs/cost_utils.py | 13 +- 25 files changed, 239 insertions(+), 310 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index eaf200b04..b7e696dfd 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -439,7 +439,10 @@ def rerun_tasks(self, request, queryset): sr: SavedRun for sr in queryset.all(): page = Workflow(sr.workflow).page_cls( - request=SimpleNamespace(user=AppUser.objects.get(uid=sr.uid)) + request=SimpleNamespace( + user=AppUser.objects.get(uid=sr.uid), + query_params=dict(run_id=sr.run_id, uid=sr.uid), + ) ) page.call_runner_task(sr, deduct_credits=False) self.message_user( diff --git a/bots/models.py b/bots/models.py index e997e8f8a..3ace6bc33 100644 --- a/bots/models.py +++ b/bots/models.py @@ -127,16 +127,12 @@ def get_or_create_metadata(self) -> "WorkflowMetadata": workflow=self, create=lambda **kwargs: WorkflowMetadata.objects.create( **kwargs, - short_title=( - self.page_cls.get_root_published_run().title or self.page_cls.title - ), + short_title=(self.page_cls.get_root_pr().title or self.page_cls.title), default_image=self.page_cls.explore_image or "", - meta_title=( - self.page_cls.get_root_published_run().title or self.page_cls.title - ), + meta_title=(self.page_cls.get_root_pr().title or self.page_cls.title), meta_description=( self.page_cls().preview_description(state={}) - or self.page_cls.get_root_published_run().notes + or self.page_cls.get_root_pr().notes ), meta_image=self.page_cls.explore_image or "", ), @@ -389,7 +385,7 @@ def submit_api_call( ), ) - return result, page.run_doc_sr(run_id, uid) + return result, page.current_sr def get_creator(self) -> AppUser | None: if self.uid: diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index c3ae75b52..b2e257365 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -1,5 +1,6 @@ import datetime import html +import threading import traceback import typing from time import time @@ -31,6 +32,15 @@ DEFAULT_RUN_STATUS = "Running..." +threadlocal = threading.local() + + +def get_running_saved_run() -> SavedRun | None: + try: + return threadlocal.saved_run + except AttributeError: + return None + @app.task def runner_task( @@ -81,12 +91,16 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # save to db page.dump_state_to_sr(gui.session_state | output, sr) - user = AppUser.objects.get(id=user_id) - page = page_cls(request=SimpleNamespace(user=user)) + page = page_cls( + request=SimpleNamespace( + user=AppUser.objects.get(id=user_id), + query_params=dict(run_id=run_id, uid=uid), + ), + ) page.setup_sentry() - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr + threadlocal.saved_run = sr gui.set_session_state(sr.to_dict() | (unsaved_state or {})) - gui.set_query_params(dict(run_id=run_id, uid=uid)) try: save_on_step() @@ -114,6 +128,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False # save everything, mark run as completed finally: save_on_step(done=True) + threadlocal.saved_run = None post_runner_tasks.delay(sr.id) diff --git a/conftest.py b/conftest.py index 57d5d5cca..539ea848e 100644 --- a/conftest.py +++ b/conftest.py @@ -64,10 +64,10 @@ def mock_celery_tasks(): def _mock_runner_task( *, page_cls: typing.Type[BasePage], run_id: str, uid: str, **kwargs ): - sr = page_cls.run_doc_sr(run_id, uid) + sr = page_cls.get_sr_from_ids(run_id, uid) sr.set(sr.parent.to_dict()) sr.save() - channel = page_cls().realtime_channel_name(run_id, uid) + channel = page_cls.realtime_channel_name(run_id, uid) _mock_realtime_push(channel, sr.to_dict()) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 3ab715ea2..d1f3715c1 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -8,6 +8,7 @@ import uuid from copy import deepcopy, copy from enum import Enum +from functools import cached_property from itertools import pairwise from random import Random from time import sleep @@ -151,10 +152,13 @@ class RequestModel(BaseModel): def __init__( self, - tab: RecipeTabs = "", - request: Request | SimpleNamespace = None, - run_user: AppUser = None, + *, + tab: RecipeTabs = RecipeTabs.run, + request: Request | SimpleNamespace | None = None, + run_user: AppUser | None = None, ): + if request is None: + request = SimpleNamespace(user=None, query_params={}) self.tab = tab self.request = request self.run_user = run_user @@ -164,9 +168,8 @@ def __init__( def endpoint(cls) -> str: return f"/v2/{cls.slug_versions[0]}" - @classmethod def current_app_url( - cls, + self, tab: RecipeTabs = RecipeTabs.run, *, query_params: dict = None, @@ -174,8 +177,8 @@ def current_app_url( ) -> str: if query_params is None: query_params = {} - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - return cls.app_url( + example_id, run_id, uid = extract_query_params(self.request.query_params) + return self.app_url( tab=tab, example_id=example_id, run_id=run_id, @@ -209,7 +212,7 @@ def app_url( run_slug = None if example_id: try: - pr = cls.get_published_run(published_run_id=example_id) + pr = cls.get_pr_from_example_id(example_id=example_id) except PublishedRun.DoesNotExist: pr = None if pr and pr.title: @@ -225,11 +228,6 @@ def app_url( ) ) - @classmethod - def current_api_url(cls) -> furl | None: - pr = cls.get_current_published_run() - return cls.api_url(example_id=pr and pr.published_run_id) - @classmethod def api_url( cls, @@ -276,12 +274,12 @@ def sentry_event_set_request(self, event, hint): ) else: request["url"] = self.app_url( - tab=self.tab, query_params=gui.get_query_params() + tab=self.tab, query_params=dict(self.request.query_params) ) return event def sentry_event_set_user(self, event, hint): - if user := self.request and self.request.user: + if user := self.request.user: event["user"] = { "id": user.id, "name": user.display_name, @@ -305,7 +303,7 @@ def sentry_event_set_user(self, event, hint): return event def refresh_state(self): - sr = self.get_current_sr() + sr = self.current_sr channel = self.realtime_channel_name(sr.run_id, sr.uid) output = gui.realtime_pull([channel])[0] if output: @@ -341,14 +339,11 @@ def render(self): self._render_header() def _render_header(self): - current_run = self.get_current_sr() - published_run = self.get_current_published_run() - is_example = published_run.saved_run == current_run - is_root_example = is_example and published_run.is_root() - tbreadcrumbs = get_title_breadcrumbs( - self, current_run, published_run, tab=self.tab - ) - can_save = self.can_user_save_run(current_run, published_run) + sr, pr = self.current_sr_pr + is_example = pr.saved_run == sr + is_root_example = is_example and pr.is_root() + tbreadcrumbs = get_title_breadcrumbs(self, sr, pr, tab=self.tab) + can_save = self.can_user_save_run(sr, pr) request_changed = self._has_request_changed() with gui.div(className="d-flex justify-content-between mt-4"): @@ -360,15 +355,13 @@ def _render_header(self): with gui.tag("div", className="me-3 mb-1 mb-lg-0 py-2 py-lg-0"): render_breadcrumbs( tbreadcrumbs, - is_api_call=( - current_run.is_api_call and self.tab == RecipeTabs.run - ), + is_api_call=(sr.is_api_call and self.tab == RecipeTabs.run), ) if is_example: - author = published_run.created_by + author = pr.created_by else: - author = self.run_user or current_run.get_creator() + author = self.run_user or sr.get_creator() if not is_root_example: self.render_author(author) @@ -389,10 +382,7 @@ def _render_header(self): show_save_buttons = request_changed or can_save if show_save_buttons: - self._render_published_run_save_buttons( - current_run=current_run, - published_run=published_run, - ) + self._render_published_run_save_buttons(sr=sr, pr=pr) self._render_social_buttons(show_button_text=not show_save_buttons) if tbreadcrumbs.has_breadcrumbs() or self.run_user: @@ -401,15 +391,15 @@ def _render_header(self): if self.tab != RecipeTabs.run: return - if published_run and published_run.notes: - gui.write(published_run.notes, line_clamp=2) + if pr and pr.notes: + gui.write(pr.notes, line_clamp=2) elif is_root_example and self.tab != RecipeTabs.integrations: - gui.write(self.preview_description(current_run.to_dict()), line_clamp=2) + gui.write(self.preview_description(sr.to_dict()), line_clamp=2) def can_user_save_run( self, current_run: SavedRun, - published_run: PublishedRun | None, + published_run: PublishedRun, ) -> bool: return ( self.is_current_user_admin() @@ -425,13 +415,9 @@ def can_user_save_run( ) ) - def can_user_edit_published_run( - self, published_run: PublishedRun | None = None - ) -> bool: - published_run = published_run or self.get_current_published_run() + def can_user_edit_published_run(self, published_run: PublishedRun) -> bool: return self.is_current_user_admin() or bool( - published_run - and self.request + self.request and self.request.user and published_run.created_by_id and published_run.created_by_id == self.request.user.id @@ -460,13 +446,8 @@ def _render_social_buttons(self, show_button_text: bool = False): className="mb-0 ms-lg-2", ) - def _render_published_run_save_buttons( - self, - *, - current_run: SavedRun, - published_run: PublishedRun, - ): - can_edit = self.can_user_edit_published_run(published_run) + def _render_published_run_save_buttons(self, *, sr: SavedRun, pr: PublishedRun): + can_edit = self.can_user_edit_published_run(pr) with gui.div(className="d-flex justify-content-end"): gui.html( @@ -497,8 +478,8 @@ def _render_published_run_save_buttons( if options_modal.is_open(): with options_modal.container(style={"minWidth": "min(300px, 100vw)"}): self._render_options_modal( - current_run=current_run, - published_run=published_run, + current_run=sr, + published_run=pr, modal=options_modal, ) @@ -518,8 +499,8 @@ def _render_published_run_save_buttons( if publish_modal.is_open(): with publish_modal.container(style={"minWidth": "min(500px, 100vw)"}): self._render_publish_modal( - current_run=current_run, - published_run=published_run, + sr=sr, + pr=pr, modal=publish_modal, is_update_mode=can_edit, ) @@ -527,12 +508,12 @@ def _render_published_run_save_buttons( def _render_publish_modal( self, *, - current_run: SavedRun, - published_run: PublishedRun, + sr: SavedRun, + pr: PublishedRun, modal: gui.Modal, is_update_mode: bool = False, ): - if published_run.is_root() and self.is_current_user_admin(): + if pr.is_root() and self.is_current_user_admin(): with gui.div(className="text-danger"): gui.write( "###### You're about to update the root workflow as an admin. " @@ -564,7 +545,7 @@ def _render_publish_modal( "", options=options, format_func=options.__getitem__, - value=str(published_run.visibility), + value=str(pr.visibility), ) ) ) @@ -579,9 +560,9 @@ def _render_publish_modal( with gui.div(className="mt-4"): if is_update_mode: - title = published_run.title or self.title + title = pr.title or self.title else: - recipe_title = self.get_root_published_run().title or self.title + recipe_title = self.get_root_pr().title or self.title title = f"{self.request.user.first_name_possesive()} {recipe_title}" published_run_title = gui.text_input( "##### Title", @@ -591,11 +572,7 @@ def _render_publish_modal( published_run_notes = gui.text_area( "##### Notes", key="published_run_notes", - value=( - published_run.notes - or self.preview_description(gui.session_state) - or "" - ), + value=(pr.notes or self.preview_description(gui.session_state) or ""), ) with gui.div(className="mt-4 d-flex justify-content-center"): @@ -605,12 +582,12 @@ def _render_publish_modal( type="primary", ) - self._render_admin_options(current_run, published_run) + self._render_admin_options(sr, pr) if not pressed_save: return - is_root_published_run = is_update_mode and published_run.is_root() + is_root_published_run = is_update_mode and pr.is_root() if not is_root_published_run: try: self._validate_published_run_title(published_run_title) @@ -619,33 +596,31 @@ def _render_publish_modal( return if self._has_request_changed(): - current_run = self.on_submit() - if not current_run: + sr = self.on_submit() + if not sr: modal.close() if is_update_mode: updates = dict( - saved_run=current_run, + saved_run=sr, title=published_run_title.strip(), notes=published_run_notes.strip(), visibility=published_run_visibility, ) - if not self._has_published_run_changed( - published_run=published_run, **updates - ): + if not self._has_published_run_changed(published_run=pr, **updates): gui.error("No changes to publish", icon="⚠️") return - published_run.add_version(user=self.request.user, **updates) + pr.add_version(user=self.request.user, **updates) else: - published_run = self.create_published_run( + pr = self.create_published_run( published_run_id=get_random_doc_id(), - saved_run=current_run, + saved_run=sr, user=self.request.user, title=published_run_title.strip(), notes=published_run_notes.strip(), visibility=published_run_visibility, ) - raise gui.RedirectException(published_run.get_app_url()) + raise gui.RedirectException(pr.get_app_url()) def _validate_published_run_title(self, title: str): if slugify(title) in settings.DISALLOWED_TITLE_SLUGS: @@ -813,7 +788,7 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR className="text-danger", ) if gui.button("👌 Yes, Update the Root Workflow"): - root_run = self.get_root_published_run() + root_run = self.get_root_pr() root_run.add_version( user=self.request.user, title=published_run.title, @@ -825,7 +800,7 @@ def _render_admin_options(self, current_run: SavedRun, published_run: PublishedR @classmethod def get_recipe_title(cls) -> str: - return cls.get_root_published_run().title or cls.title or cls.workflow.label + return cls.get_root_pr().title or cls.title or cls.workflow.label def get_explore_image(self) -> str: meta = self.workflow.get_or_create_metadata() @@ -853,7 +828,7 @@ def get_tabs(self): def render_selected_tab(self): match self.tab: case RecipeTabs.run: - if self.get_current_sr().retention_policy == RetentionPolicy.delete: + if self.current_sr.retention_policy == RetentionPolicy.delete: self.render_deleted_output() return @@ -884,15 +859,12 @@ def render_selected_tab(self): self._saved_tab() def _render_version_history(self): - published_run = self.get_current_published_run() - - if published_run: - versions = published_run.versions.all() - first_version = versions[0] - for version, older_version in pairwise(versions): - first_version = older_version - self._render_version_row(version, older_version) - self._render_version_row(first_version, None) + versions = self.current_pr.versions.all() + first_version = versions[0] + for version, older_version in pairwise(versions): + first_version = older_version + self._render_version_row(version, older_version) + self._render_version_row(first_version, None) def _render_version_row( self, @@ -957,7 +929,7 @@ def render_related_workflows(self): def _render(page_cls: typing.Type[BasePage]): page = page_cls() - root_run = page.get_root_published_run() + root_run = page.get_root_pr() state = root_run.saved_run.to_dict() preview_image = page.get_explore_image() @@ -1034,11 +1006,9 @@ def render_report_form(self): gui.error("Reason for report cannot be empty") return - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - send_reported_run_email( user=self.request.user, - run_uid=uid, + run_uid=str(self.run_user.uid), url=self.current_app_url(), recipe_name=self.title, report_type=report_type, @@ -1047,7 +1017,7 @@ def render_report_form(self): ) if report_type == inappropriate_radio_text: - self.update_flag_for_run(run_id=run_id, uid=uid, is_flagged=True) + self.update_flag_for_run(is_flagged=True) # gui.success("Reported.") gui.session_state["show_report_workflow"] = False @@ -1064,10 +1034,8 @@ def _check_if_flagged(self): if not unflag_pressed: return with gui.spinner("Removing flag..."): - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - if run_id and uid: - self.update_flag_for_run(run_id=run_id, uid=uid, is_flagged=False) - gui.success("Removed flag.", icon="✅") + self.update_flag_for_run(is_flagged=False) + gui.success("Removed flag.") sleep(2) gui.rerun() else: @@ -1077,87 +1045,47 @@ def _check_if_flagged(self): # Return and Don't render the run any further gui.stop() - @classmethod - def get_runs_from_query_params( - cls, example_id: str, run_id: str, uid: str - ) -> tuple[SavedRun, PublishedRun | None]: - if run_id and uid: - sr = cls.run_doc_sr(run_id, uid) - pr = sr.parent_published_run() - else: - pr = cls.get_published_run(published_run_id=example_id or "") - sr = pr.saved_run - return sr, pr + def update_flag_for_run(self, is_flagged: bool): + sr = self.current_sr + sr.is_flagged = is_flagged + sr.save(update_fields=["is_flagged"]) + gui.session_state["is_flagged"] = is_flagged - @classmethod - def get_current_published_run(cls) -> PublishedRun | None: - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - return cls.get_pr_from_query_params(example_id, run_id, uid) + @property + def current_sr(self) -> SavedRun: + return self.current_sr_pr[0] - @classmethod - def get_pr_from_query_params( - cls, example_id: str, run_id: str, uid: str - ) -> PublishedRun | None: - if run_id and uid: - sr = cls.get_sr_from_query_params(example_id, run_id, uid) - return sr.parent_published_run() or cls.get_root_published_run() - elif example_id: - return cls.get_published_run(published_run_id=example_id) - else: - return cls.get_root_published_run() + @property + def current_pr(self) -> PublishedRun: + return self.current_sr_pr[1] - @classmethod - def get_published_run(cls, *, published_run_id: str): - return PublishedRun.objects.get( - workflow=cls.workflow, - published_run_id=published_run_id, + @cached_property + def current_sr_pr(self) -> tuple[SavedRun, PublishedRun]: + return self.get_sr_pr_from_query_params( + *extract_query_params(self.request.query_params) ) @classmethod - def get_current_sr(cls) -> SavedRun: - return cls.get_sr_from_query_params_dict(gui.get_query_params()) - - @classmethod - def get_sr_from_query_params_dict(cls, query_params) -> SavedRun: - example_id, run_id, uid = extract_query_params(query_params) - return cls.get_sr_from_query_params(example_id, run_id, uid) - - @classmethod - def get_sr_from_query_params( - cls, example_id: str | None, run_id: str | None, uid: str | None - ) -> SavedRun: - try: - if run_id and uid: - sr = cls.run_doc_sr(run_id, uid) - elif example_id: - pr = cls.get_published_run(published_run_id=example_id) - assert ( - pr.saved_run is not None - ), "invalid published run: without a saved run" - sr = pr.saved_run - else: - sr = cls.recipe_doc_sr() - return sr - except (SavedRun.DoesNotExist, PublishedRun.DoesNotExist): - raise HTTPException(status_code=404) - - @classmethod - def get_total_runs(cls) -> int: - # TODO: fix to also handle published run case - return SavedRun.objects.filter(workflow=cls.workflow).count() - - @classmethod - def recipe_doc_sr(cls, create: bool = True) -> SavedRun: - if create: - return cls.get_root_published_run().saved_run + def get_sr_pr_from_query_params( + cls, example_id: str, run_id: str, uid: str + ) -> tuple[SavedRun, PublishedRun]: + if run_id and uid: + sr = cls.get_sr_from_ids(run_id, uid) + pr = sr.parent_published_run() or cls.get_root_pr() else: - return cls.get_root_published_run().saved_run + if example_id: + pr = cls.get_pr_from_example_id(example_id=example_id) + else: + pr = cls.get_root_pr() + sr = pr.saved_run + return sr, pr @classmethod - def run_doc_sr( + def get_sr_from_ids( cls, run_id: str, uid: str, + *, create: bool = False, defaults: dict = None, ) -> SavedRun: @@ -1168,7 +1096,14 @@ def run_doc_sr( return SavedRun.objects.get(**config) @classmethod - def get_root_published_run(cls) -> PublishedRun: + def get_pr_from_example_id(cls, *, example_id: str): + return PublishedRun.objects.get( + workflow=cls.workflow, + published_run_id=example_id, + ) + + @classmethod + def get_root_pr(cls) -> PublishedRun: return PublishedRun.objects.get_or_create_with_version( workflow=cls.workflow, published_run_id="", @@ -1219,6 +1154,11 @@ def duplicate_published_run( visibility=visibility, ) + @classmethod + def get_total_runs(cls) -> int: + # TODO: fix to also handle published run case + return SavedRun.objects.filter(workflow=cls.workflow).count() + def render_description(self): pass @@ -1328,9 +1268,8 @@ def render_submit_button(self, key="--submit-1"): def render_run_cost(self): url = self.get_credits_click_url() - sr = self.get_current_sr() - if sr.price: - run_cost = sr.price + if self.current_sr.price: + run_cost = self.current_sr.price else: run_cost = self.get_price_roundoff(gui.session_state) ret = f'Run cost = {run_cost} credits' @@ -1356,7 +1295,7 @@ def _render_step_row(self): with col2: placeholder = gui.div() render_called_functions( - saved_run=self.get_current_sr(), trigger=FunctionTrigger.pre + saved_run=self.current_sr, trigger=FunctionTrigger.pre ) try: self.render_steps() @@ -1366,7 +1305,7 @@ def _render_step_row(self): with placeholder: gui.write("##### 👣 Steps") render_called_functions( - saved_run=self.get_current_sr(), trigger=FunctionTrigger.post + saved_run=self.current_sr, trigger=FunctionTrigger.post ) def _render_help(self): @@ -1447,9 +1386,10 @@ def run_v2( raise NotImplementedError def _render_report_button(self): - example_id, run_id, uid = extract_query_params(gui.get_query_params()) - # only logged in users can report a run (but not examples/default runs) - if not (self.request.user and run_id and uid): + sr, pr = self.current_sr_pr + is_example = pr.saved_run_id == sr.id + # only logged in users can report a run (but not examples/root runs) + if not self.request.user or is_example: return reported = gui.button( @@ -1461,12 +1401,6 @@ def _render_report_button(self): gui.session_state["show_report_workflow"] = reported gui.rerun() - def update_flag_for_run(self, run_id: str, uid: str, is_flagged: bool): - ref = self.run_doc_sr(uid=uid, run_id=run_id) - ref.is_flagged = is_flagged - ref.save(update_fields=["is_flagged"]) - gui.session_state["is_flagged"] = is_flagged - # Functions in every recipe feels like overkill for now, hide it in settings functions_in_settings = True show_settings = True @@ -1617,8 +1551,7 @@ def on_submit(self): def should_submit_after_login(self) -> bool: return ( - gui.get_query_params().get(SUBMIT_AFTER_LOGIN_Q) - and self.request + self.request.query_params.get(SUBMIT_AFTER_LOGIN_Q) and self.request.user and not self.request.user.is_anonymous ) @@ -1647,19 +1580,13 @@ def create_new_run( run_id = get_random_doc_id() - parent_example_id, parent_run_id, parent_uid = extract_query_params( - gui.get_query_params() - ) - parent = self.get_sr_from_query_params( - parent_example_id, parent_run_id, parent_uid - ) - published_run = self.get_current_published_run() + parent, pr = self.current_sr_pr try: - parent_version = published_run and published_run.versions.latest() + parent_version = pr.versions.latest() except PublishedRunVersion.DoesNotExist: parent_version = None - sr = self.run_doc_sr( + sr = self.get_sr_from_ids( run_id, uid, create=True, @@ -1697,7 +1624,7 @@ def call_runner_task(self, sr: SavedRun, deduct_credits: bool = True): ) @classmethod - def realtime_channel_name(cls, run_id, uid): + def realtime_channel_name(cls, run_id: str, uid: str) -> str: return f"gooey-outputs/{cls.slug_versions[0]}/{uid}/{run_id}" def generate_credit_error_message(self, run_id, uid) -> str: @@ -1849,7 +1776,7 @@ def _history_tab(self): if self.is_current_user_admin(): uid = self.request.query_params.get("uid", uid) - before = gui.get_query_params().get("updated_at__lt", None) + before = self.request.query_params.get("updated_at__lt", None) if before: before = datetime.datetime.fromisoformat(before) else: @@ -2051,11 +1978,10 @@ def run_as_api_tab(self): as_async = gui.checkbox("##### Run Async") as_form_data = gui.checkbox("##### Upload Files via Form Data") - pr = self.get_current_published_run() api_url, request_body = self.get_example_request( gui.session_state, include_all=include_all, - pr=pr, + pr=self.current_pr, ) response_body = self.get_example_response_body( gui.session_state, as_async=as_async, include_all=include_all @@ -2105,9 +2031,7 @@ def ensure_credits_and_auto_recharge(self, sr: SavedRun, state: dict): raise InsufficientCredits(self.request.user, sr) def deduct_credits(self, state: dict) -> tuple[AppUserTransaction, int]: - assert ( - self.request and self.request.user - ), "request.user must be set to deduct credits" + assert self.request.user, "request.user must be set to deduct credits" amount = self.get_price_roundoff(state) txn = self.request.user.add_balance(-amount, f"gooey_in_{uuid.uuid1()}") @@ -2124,7 +2048,7 @@ def get_raw_price(self, state: dict) -> float: def get_total_linked_usage_cost_in_credits(self, default=1): """Return the sum of the linked usage costs in gooey credits.""" - sr = self.get_current_sr() + sr = self.current_sr total = sr.usage_costs.aggregate(total=Sum("dollar_amount"))["total"] if not total: return default @@ -2132,10 +2056,8 @@ def get_total_linked_usage_cost_in_credits(self, default=1): def get_grouped_linked_usage_cost_in_credits(self): """Return the linked usage costs grouped by model name in gooey credits.""" - qs = ( - self.get_current_sr() - .usage_costs.values("pricing__model_name") - .annotate(total=Sum("dollar_amount") * settings.ADDON_CREDITS_PER_DOLLAR) + qs = self.current_sr.usage_costs.values("pricing__model_name").annotate( + total=Sum("dollar_amount") * settings.ADDON_CREDITS_PER_DOLLAR ) return {item["pricing__model_name"]: item["total"] for item in qs} @@ -2179,7 +2101,7 @@ def get_example_response_body( run_id=run_id, uid=self.request.user and self.request.user.uid, ) - sr = self.get_current_sr() + sr = self.current_sr output = sr.api_output(extract_model_fields(self.ResponseModel, state)) if as_async: return dict( @@ -2210,17 +2132,13 @@ def is_user_admin(cls, user: AppUser) -> bool: return email and email in settings.ADMIN_EMAILS def is_current_user_admin(self) -> bool: - if not self.request or not self.request.user: - return False - return self.is_user_admin(self.request.user) + return self.request.user and self.is_user_admin(self.request.user) def is_current_user_paying(self) -> bool: - return bool(self.request and self.request.user and self.request.user.is_paying) + return bool(self.request.user and self.request.user.is_paying) def is_current_user_owner(self) -> bool: - return bool( - self.request and self.request.user and self.run_user == self.request.user - ) + return bool(self.request.user and self.run_user == self.request.user) def started_at_text(dt: datetime.datetime): diff --git a/daras_ai_v2/bot_integration_widgets.py b/daras_ai_v2/bot_integration_widgets.py index 36cca09ec..8e767b351 100644 --- a/daras_ai_v2/bot_integration_widgets.py +++ b/daras_ai_v2/bot_integration_widgets.py @@ -7,6 +7,7 @@ from django.db import transaction from django.utils.text import slugify from furl import furl +from starlette.requests import Request from app_users.models import AppUser from bots.models import BotIntegration, BotIntegrationAnalysisRun, Platform @@ -54,7 +55,7 @@ def integrations_welcome_screen(title: str): gui.caption("Analyze your usage. Update your Saved Run to test changes.") -def general_integration_settings(bi: BotIntegration, current_user: AppUser): +def general_integration_settings(bi: BotIntegration, request: Request): if gui.session_state.get(f"_bi_reset_{bi.id}"): gui.session_state[f"_bi_streaming_enabled_{bi.id}"] = ( BotIntegration._meta.get_field("streaming_enabled").default @@ -101,9 +102,10 @@ def general_integration_settings(bi: BotIntegration, current_user: AppUser): "📊 View Results", str( furl( - VideoBotsPage.current_app_url( - RecipeTabs.integrations, + VideoBotsPage.app_url( + tab=RecipeTabs.integrations, path_params=dict(integration_id=bi.api_integration_id()), + query_params=dict(request.query_params), ) ) / "analysis/" @@ -119,7 +121,7 @@ def render_workflow_url_input(key: str, del_key: str | None, d: dict): key=key, internal_state=d, del_key=del_key, - current_user=current_user, + current_user=request.user, ) if not ret: return diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index 90fef04ee..eea020936 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -440,7 +440,7 @@ def _process_and_send_msg( # wait for the celery task to finish get_celery_result_db_safe(result) # get the final state from db - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr state = sr.to_dict() bot.recipe_run_state = page.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" diff --git a/daras_ai_v2/doc_search_settings_widgets.py b/daras_ai_v2/doc_search_settings_widgets.py index 9b46c9cdf..186075646 100644 --- a/daras_ai_v2/doc_search_settings_widgets.py +++ b/daras_ai_v2/doc_search_settings_widgets.py @@ -129,7 +129,7 @@ def doc_extract_selector(current_user: AppUser | None): gui.write("###### Create Synthetic Data") gui.caption( f""" - To improve answer quality, pick a [synthetic data maker workflow]({DocExtractPage.get_root_published_run().get_app_url()}) to scan & OCR any images in your documents or transcribe & translate any videos. It also can synthesize a helpful FAQ. Adds ~2 minutes of one-time processing per file. + To improve answer quality, pick a [synthetic data maker workflow]({DocExtractPage.get_root_pr().get_app_url()}) to scan & OCR any images in your documents or transcribe & translate any videos. It also can synthesize a helpful FAQ. Adds ~2 minutes of one-time processing per file. """ ) workflow_url_input( diff --git a/daras_ai_v2/meta_content.py b/daras_ai_v2/meta_content.py index c3335434b..d2c255a62 100644 --- a/daras_ai_v2/meta_content.py +++ b/daras_ai_v2/meta_content.py @@ -17,11 +17,8 @@ def build_meta_tags( url: str, page: "BasePage", state: dict, - run_id: str, - uid: str, - example_id: str, ) -> list[dict]: - sr, pr = page.get_runs_from_query_params(example_id, run_id, uid) + sr, pr = page.current_sr_pr metadata = page.workflow.get_or_create_metadata() title = meta_title_for_page( diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 84c14baf5..8a8d67336 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -29,7 +29,7 @@ def safety_checker_text(text_input: str): # run in a thread to avoid messing up threadlocals result, sr = ( CompareLLMPage() - .get_published_run(published_run_id=settings.SAFETY_CHECKER_EXAMPLE_ID) + .get_pr_from_example_id(example_id=settings.SAFETY_CHECKER_EXAMPLE_ID) .submit_api_call( current_user=billing_account, request_body=dict(variables=dict(input=text_input)), diff --git a/daras_ai_v2/workflow_url_input.py b/daras_ai_v2/workflow_url_input.py index 0b82cd2a8..b9f23cc56 100644 --- a/daras_ai_v2/workflow_url_input.py +++ b/daras_ai_v2/workflow_url_input.py @@ -136,7 +136,7 @@ def url_to_runs( assert match, "Not a valid Gooey.AI URL" page_cls = page_slug_map[normalize_slug(match.matched_params["page_slug"])] example_id, run_id, uid = extract_query_params(furl(url).query.params) - sr, pr = page_cls.get_runs_from_query_params( + sr, pr = page_cls.get_sr_pr_from_query_params( example_id or match.matched_params.get("example_id"), run_id, uid ) return page_cls, sr, pr @@ -177,7 +177,7 @@ def get_published_run_options( if include_root: # include root recipe if requested options_dict = { - page_cls.get_root_published_run().get_app_url(): "Default", + page_cls.get_root_pr().get_app_url(): "Default", } | options_dict return options_dict diff --git a/explore.py b/explore.py index cd4ec3d00..5bbe380fc 100644 --- a/explore.py +++ b/explore.py @@ -85,7 +85,7 @@ def render_description(page: BasePage): with gui.link(to=page.app_url()): gui.markdown(f"#### {page.get_recipe_title()}") - root_pr = page.get_root_published_run() + root_pr = page.get_root_pr() notes = root_pr.notes or page.preview_description(state=page.sane_defaults) with gui.tag("p", style={"marginBottom": "25px"}): gui.write(notes, line_clamp=4) diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index 04e034dde..d0c49f3c7 100644 --- a/recipes/DocSearch.py +++ b/recipes/DocSearch.py @@ -136,7 +136,7 @@ def render_settings(self): gui.write("---") gui.write("##### 🔎 Document Search Settings") citation_style_selector() - doc_extract_selector(self.request and self.request.user) + doc_extract_selector(self.request.user) query_instructions_widget() gui.write("---") doc_search_advanced_settings() @@ -175,7 +175,7 @@ def run_v2( "search_query": response.final_search_query, }, ), - current_user=self.request and self.request.user, + current_user=self.request.user, ) # empty search result, abort! diff --git a/recipes/Functions.py b/recipes/Functions.py index 356381343..81b99c946 100644 --- a/recipes/Functions.py +++ b/recipes/Functions.py @@ -59,10 +59,8 @@ def run_v2( request: "FunctionsPage.RequestModel", response: "FunctionsPage.ResponseModel", ) -> typing.Iterator[str | None]: - query_params = gui.get_query_params() - run_id = query_params.get("run_id") - uid = query_params.get("uid") - tag = f"run_id={run_id}&uid={uid}" + sr = self.current_sr + tag = f"run_id={sr.run_id}&uid={sr.uid}" yield "Running your code..." # this will run functions/executor.js in deno deploy @@ -86,7 +84,7 @@ def render_form_v2(self): ) def get_price_roundoff(self, state: dict) -> float: - if CalledFunction.objects.filter(function_run=self.get_current_sr()).exists(): + if CalledFunction.objects.filter(function_run=self.current_sr).exists(): return 0 return super().get_price_roundoff(state) diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 6a9eaf999..611483287 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -254,7 +254,7 @@ def run_v2( }, ), is_user_url=False, - current_user=self.request and self.request.user, + current_user=self.request.user, ) # add pretty titles to references for ref in response.references: diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 41e7a33a8..26b7e785e 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -85,7 +85,6 @@ from daras_ai_v2.prompt_vars import render_prompt_vars from daras_ai_v2.pydantic_validation import FieldHttpUrl from daras_ai_v2.query_generator import generate_final_search_query -from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.search_ref import ( parse_refs, CitationStyles, @@ -521,7 +520,7 @@ def render_settings(self): citation_style_selector() gui.checkbox("🔗 Shorten Citation URLs", key="use_url_shortener") - doc_extract_selector(self.request and self.request.user) + doc_extract_selector(self.request.user) gui.write("---") @@ -886,7 +885,7 @@ def run_v2( "keyword_query": response.final_keyword_query, }, ), - current_user=self.request and self.request.user, + current_user=self.request.user, ) if request.use_url_shortener: for reference in response.references: @@ -1061,9 +1060,7 @@ def render_integrations_tab(self): gui.anchor("Get Started", href=self.get_auth_url(), type="primary") return - sr, pr = self.get_runs_from_query_params( - *extract_query_params(gui.get_query_params()) - ) + sr, pr = self.current_sr_pr # make user the user knows that they are on a saved run not the published run if pr and pr.saved_run_id != sr.id: @@ -1377,7 +1374,7 @@ def render_integrations_settings( slack_specific_settings(bi, run_title) if bi.platform == Platform.TWILIO: twilio_specific_settings(bi) - general_integration_settings(bi, self.request.user) + general_integration_settings(bi, self.request) if bi.platform in [Platform.SLACK, Platform.WHATSAPP, Platform.TWILIO]: gui.newline() diff --git a/recipes/VideoBotsStats.py b/recipes/VideoBotsStats.py index 648a2894f..3f3bbacb6 100644 --- a/recipes/VideoBotsStats.py +++ b/recipes/VideoBotsStats.py @@ -84,8 +84,9 @@ def show_title_breadcrumb_share( ) gui.breadcrumb_item( "Integrations", - link_to=VideoBotsPage.current_app_url( - RecipeTabs.integrations, + link_to=VideoBotsPage.app_url( + tab=RecipeTabs.integrations, + query_params=dict(self.request.query_params), path_params=dict( integration_id=bi.api_integration_id() ), @@ -152,7 +153,7 @@ def render(self): ) ) - run_url = VideoBotsPage.current_app_url() + run_url = VideoBotsPage.app_url(query_params=dict(self.request.query_params)) if bi.published_run_id: run_title = bi.published_run.title else: diff --git a/routers/api.py b/routers/api.py index 9b795d426..dd74b5a00 100644 --- a/routers/api.py +++ b/routers/api.py @@ -258,8 +258,13 @@ def get_run_status( run_id: str, user: AppUser = Depends(api_auth_header), ): - self = page_cls() - sr = self.get_sr_from_query_params(example_id=None, run_id=run_id, uid=user.uid) + # init a new page for every request + self = page_cls( + request=SimpleNamespace( + user=user, query_params=dict(run_id=run_id, uid=user.uid) + ) + ) + sr = self.current_sr web_url = str(furl(self.app_url(run_id=run_id, uid=user.uid))) ret = { "run_id": run_id, @@ -335,18 +340,17 @@ def submit_api_call( deduct_credits: bool = True, ) -> tuple[BasePage, "celery.result.AsyncResult", str, str]: # init a new page for every request - self = page_cls(request=SimpleNamespace(user=user)) + query_params.setdefault("uid", user.uid) + self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) # get saved state from db - query_params.setdefault("uid", user.uid) - sr = self.get_sr_from_query_params_dict(query_params) + sr = self.current_sr state = self.load_state_from_sr(sr) # load request data state.update(request_body) # set streamlit session state gui.set_session_state(state) - gui.set_query_params(query_params) # create a new run try: @@ -369,7 +373,7 @@ def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: / page.endpoint.replace("v2", "v3") / "status/" ) - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr return dict( run_id=run_id, web_url=web_url, @@ -388,7 +392,8 @@ def build_sync_api_response( web_url = page.app_url(run_id=run_id, uid=uid) # wait for the result get_celery_result_db_safe(result) - sr = page.run_doc_sr(run_id, uid) + sr = page.current_sr + sr.refresh_from_db() if sr.retention_policy == RetentionPolicy.delete: sr.state = {} sr.save(update_fields=["state"]) diff --git a/routers/bots_api.py b/routers/bots_api.py index 780a7b918..64285cbd2 100644 --- a/routers/bots_api.py +++ b/routers/bots_api.py @@ -302,7 +302,7 @@ def runner(self): msg_handler(self) # raise ValueError("Stream ended") if self.run_id and self.uid: - sr = self.page_cls.run_doc_sr(run_id=self.run_id, uid=self.uid) + sr = self.page_cls.get_sr_from_ids(run_id=self.run_id, uid=self.uid) state = sr.to_dict() self.queue.put( FinalResponse( diff --git a/routers/root.py b/routers/root.py index e234a5443..b2375439d 100644 --- a/routers/root.py +++ b/routers/root.py @@ -39,7 +39,6 @@ from daras_ai_v2.meta_content import build_meta_tags, raw_build_meta_tags from daras_ai_v2.meta_preview_url import meta_preview_url from daras_ai_v2.profiles import user_profile_page, get_meta_tags_for_profile -from daras_ai_v2.query_params_util import extract_query_params from daras_ai_v2.settings import templates from handles.models import Handle from routers.custom_api_router import CustomAPIRouter @@ -314,7 +313,7 @@ def _api_docs_page(request): as_form_data = gui.checkbox("Upload Files via Form Data") page = workflow.page_cls(request=request) - state = page.get_root_published_run().saved_run.to_dict() + state = page.get_root_pr().saved_run.to_dict() api_url, request_body = page.get_example_request(state, include_all=include_all) response_body = page.get_example_response_body( state, as_async=as_async, include_all=include_all @@ -669,12 +668,13 @@ def render_recipe_page( return RedirectResponse(str(new_url.set(origin=None)), status_code=301) # this is because the code still expects example_id to be in the query params - gui.set_query_params(dict(request.query_params) | dict(example_id=example_id)) - _, run_id, uid = extract_query_params(request.query_params) + request._query_params = dict(request.query_params) | dict(example_id=example_id) + + page = page_cls(tab=tab, request=request) + sr = page.current_sr + page.run_user = get_run_user(request, sr.uid) - page = page_cls(tab=tab, request=request, run_user=get_run_user(request, uid)) if not gui.session_state: - sr = page.get_sr_from_query_params(example_id, run_id, uid) gui.session_state.update(page.load_state_from_sr(sr)) with page_wrapper(request): @@ -682,12 +682,7 @@ def render_recipe_page( return dict( meta=build_meta_tags( - url=get_og_url_path(request), - page=page, - state=gui.session_state, - run_id=run_id, - uid=uid, - example_id=example_id, + url=get_og_url_path(request), page=page, state=gui.session_state ), ) @@ -698,7 +693,7 @@ def get_og_url_path(request) -> str: ) -def get_run_user(request, uid) -> AppUser | None: +def get_run_user(request: Request, uid: str) -> AppUser | None: if not uid: return if request.user and request.user.uid == uid: diff --git a/routers/twilio_api.py b/routers/twilio_api.py index 7893c15d1..0223636a9 100644 --- a/routers/twilio_api.py +++ b/routers/twilio_api.py @@ -262,7 +262,7 @@ def resp_say_or_tts_play( tts_state = TextToSpeechPage.RequestModel.parse_obj( {**bot.saved_run.state, "text_prompt": text} ).dict() - result, sr = TextToSpeechPage.get_root_published_run().submit_api_call( + result, sr = TextToSpeechPage.get_root_pr().submit_api_call( current_user=AppUser.objects.get(uid=bot.billing_account_uid), request_body=tts_state, ) diff --git a/tests/test_apis.py b/tests/test_apis.py index fa897eb83..7220798e3 100644 --- a/tests/test_apis.py +++ b/tests/test_apis.py @@ -21,7 +21,7 @@ def test_apis_sync(mock_celery_tasks, force_authentication, threadpool_subtest): def _test_api_sync(page_cls: typing.Type[BasePage]): - state = page_cls.recipe_doc_sr().state + state = page_cls.get_root_pr().saved_run.state r = client.post( f"/v2/{page_cls.slug_versions[0]}/", json=page_cls.get_example_request(state)[1], @@ -38,7 +38,7 @@ def test_apis_async(mock_celery_tasks, force_authentication, threadpool_subtest) def _test_api_async(page_cls: typing.Type[BasePage]): - state = page_cls.recipe_doc_sr().state + state = page_cls.get_root_pr().saved_run.state r = client.post( f"/v3/{page_cls.slug_versions[0]}/async/", diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 6ac6e0591..73e05e4a9 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import gooey_gui as gui import pytest from starlette.testclient import TestClient @@ -46,8 +48,12 @@ def test_copilot_get_raw_price_round_up(): unit_quantity=model_pricing.unit_quantity, dollar_amount=model_pricing.unit_cost * 1 / model_pricing.unit_quantity, ) - copilot_page = VideoBotsPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + copilot_page = VideoBotsPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ), + ) assert ( copilot_page.get_price_roundoff(state=state) == 210 + copilot_page.PROFIT_CREDITS @@ -107,8 +113,12 @@ def test_multiple_llm_sums_usage_cost(): dollar_amount=model_pricing2.unit_cost * 1 / model_pricing2.unit_quantity, ) - llm_page = CompareLLMPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + llm_page = CompareLLMPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ) + ) assert llm_page.get_price_roundoff(state=state) == (310 + llm_page.PROFIT_CREDITS) @@ -152,8 +162,12 @@ def test_workflowmetadata_2x_multiplier(): metadata.price_multiplier = 2 metadata.save() - llm_page = CompareLLMPage(run_user=user) - gui.set_query_params({"run_id": bot_saved_run.run_id or "", "uid": user.uid or ""}) + llm_page = CompareLLMPage( + request=SimpleNamespace( + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), + ) + ) assert ( llm_page.get_price_roundoff(state=state) == (210 + llm_page.PROFIT_CREDITS) * 2 ) diff --git a/url_shortener/models.py b/url_shortener/models.py index 21b0864d7..515b293a2 100644 --- a/url_shortener/models.py +++ b/url_shortener/models.py @@ -6,10 +6,9 @@ from app_users.models import AppUser from bots.custom_fields import CustomURLField from bots.models import Workflow, SavedRun +from celeryapp.tasks import get_running_saved_run from daras_ai.image_input import truncate_filename from daras_ai_v2 import settings -from daras_ai_v2.query_params_util import extract_query_params -import gooey_gui as gui class ShortenedURLQuerySet(models.QuerySet): @@ -17,14 +16,8 @@ def get_or_create_for_workflow( self, *, user: AppUser, workflow: Workflow, **kwargs ) -> tuple["ShortenedURL", bool]: surl, created = self.filter_first_or_create(user=user, **kwargs) - _, run_id, uid = extract_query_params(gui.get_query_params()) - surl.saved_runs.add( - SavedRun.objects.get_or_create( - workflow=workflow, - run_id=run_id, - uid=uid, - )[0], - ) + sr = get_running_saved_run() + surl.saved_runs.add(sr) return surl, created def filter_first_or_create(self, defaults=None, **kwargs): diff --git a/usage_costs/cost_utils.py b/usage_costs/cost_utils.py index 6596c0d22..b380ee63f 100644 --- a/usage_costs/cost_utils.py +++ b/usage_costs/cost_utils.py @@ -1,19 +1,16 @@ from loguru import logger -from daras_ai_v2.query_params_util import extract_query_params +from celeryapp.tasks import get_running_saved_run from usage_costs.models import ( UsageCost, ModelSku, ModelPricing, ) -import gooey_gui as gui def record_cost_auto(model: str, sku: ModelSku, quantity: int): - from bots.models import SavedRun - - _, run_id, uid = extract_query_params(gui.get_query_params()) - if not run_id or not uid: + sr = get_running_saved_run() + if not sr: return try: @@ -22,10 +19,8 @@ def record_cost_auto(model: str, sku: ModelSku, quantity: int): logger.warning(f"Cant find pricing for {model=} {sku=}: {e=}") return - saved_run = SavedRun.objects.get(run_id=run_id, uid=uid) - UsageCost.objects.create( - saved_run=saved_run, + saved_run=sr, pricing=pricing, quantity=quantity, unit_cost=pricing.unit_cost, From 9db79affd2448089deb40e7cd80071d3591bc8c2 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 30 Aug 2024 20:17:19 +0530 Subject: [PATCH 085/110] Refactor load_state_from_sr method to current_sr_to_session_state across the codebase --- daras_ai_v2/base.py | 7 +++---- recipes/asr_page.py | 4 ++-- routers/api.py | 3 +-- routers/root.py | 5 ++--- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index d1f3715c1..e2f05a66f 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -1683,12 +1683,11 @@ def _render_after_output(self): gui.session_state[StateKeys.pressed_randomize] = True gui.rerun() - @classmethod - def load_state_from_sr(cls, sr: SavedRun) -> dict: - state = sr.to_dict() + def current_sr_to_session_state(self) -> dict: + state = self.current_sr.to_dict() if state is None: raise HTTPException(status_code=404) - return cls.load_state_defaults(state) + return self.load_state_defaults(state) @classmethod def load_state_defaults(cls, state: dict): diff --git a/recipes/asr_page.py b/recipes/asr_page.py index 58d49ffa4..f68be1a08 100644 --- a/recipes/asr_page.py +++ b/recipes/asr_page.py @@ -64,8 +64,8 @@ class ResponseModel(BaseModel): raw_output_text: list[str] | None output_text: list[str | AsrOutputJson] - def load_state_from_sr(self, sr: SavedRun) -> dict: - state = super().load_state_from_sr(sr) + def current_sr_to_session_state(self) -> dict: + state = super().current_sr_to_session_state() google_translate_target = state.pop("google_translate_target", None) translation_model = state.get("translation_model") if google_translate_target and not translation_model: diff --git a/routers/api.py b/routers/api.py index dd74b5a00..d1878bd53 100644 --- a/routers/api.py +++ b/routers/api.py @@ -344,8 +344,7 @@ def submit_api_call( self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) # get saved state from db - sr = self.current_sr - state = self.load_state_from_sr(sr) + state = self.current_sr_to_session_state() # load request data state.update(request_body) diff --git a/routers/root.py b/routers/root.py index b2375439d..9f678416b 100644 --- a/routers/root.py +++ b/routers/root.py @@ -671,11 +671,10 @@ def render_recipe_page( request._query_params = dict(request.query_params) | dict(example_id=example_id) page = page_cls(tab=tab, request=request) - sr = page.current_sr - page.run_user = get_run_user(request, sr.uid) + page.run_user = get_run_user(request, page.current_sr.uid) if not gui.session_state: - gui.session_state.update(page.load_state_from_sr(sr)) + gui.session_state.update(page.current_sr_to_session_state()) with page_wrapper(request): page.render() From 5524cc0b3bd1f68a8d8b8d20d90ffb24967e8c10 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 6 Sep 2024 14:00:12 +0530 Subject: [PATCH 086/110] Refactor consistent usage of submit_api_call() - Introduce `SavedRun.wait_for_celery_result` to encapsulate common logic. - Change BasePage `endpoint` to method `api_endpoint`. - Update `submit_api_call`, `build_sync_api_response`, and `build_async_api_response` signatures. --- bots/models.py | 13 ++++--- bots/tasks.py | 4 +- daras_ai_v2/base.py | 7 ++-- daras_ai_v2/bots.py | 22 +++++------ daras_ai_v2/safety_checker.py | 3 +- functions/recipe_functions.py | 3 +- recipes/BulkRunner.py | 12 +++--- routers/api.py | 72 +++++++++++++++++------------------ routers/bots_api.py | 14 +++---- routers/twilio_api.py | 3 +- 10 files changed, 72 insertions(+), 81 deletions(-) diff --git a/bots/models.py b/bots/models.py index 3ace6bc33..6c3c1d92b 100644 --- a/bots/models.py +++ b/bots/models.py @@ -18,6 +18,7 @@ from daras_ai_v2.crypto import get_random_doc_id from daras_ai_v2.language_model import format_chat_entry from functions.models import CalledFunction, CalledFunctionResponse +from gooeysite.bg_db_conn import get_celery_result_db_safe from gooeysite.custom_create import get_or_create_lazy if typing.TYPE_CHECKING: @@ -358,8 +359,8 @@ def submit_api_call( current_user: AppUser, request_body: dict, enable_rate_limits: bool = False, - parent_pr: "PublishedRun" = None, deduct_credits: bool = True, + parent_pr: "PublishedRun" = None, ) -> tuple["celery.result.AsyncResult", "SavedRun"]: from routers.api import submit_api_call @@ -373,19 +374,21 @@ def submit_api_call( query_params = page_cls.clean_query_params( example_id=self.example_id, run_id=self.run_id, uid=self.uid ) - page, result, run_id, uid = pool.apply( + return pool.apply( submit_api_call, kwds=dict( page_cls=page_cls, query_params=query_params, - user=current_user, + current_user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, deduct_credits=deduct_credits, ), ) - return result, page.current_sr + def wait_for_celery_result(self, result: "celery.result.AsyncResult"): + get_celery_result_db_safe(result) + self.refresh_from_db() def get_creator(self) -> AppUser | None: if self.uid: @@ -1839,8 +1842,8 @@ def submit_api_call( current_user=current_user, request_body=request_body, enable_rate_limits=enable_rate_limits, - parent_pr=self, deduct_credits=deduct_credits, + parent_pr=self, ) diff --git a/bots/tasks.py b/bots/tasks.py index 70f381bf1..0a76e42bc 100644 --- a/bots/tasks.py +++ b/bots/tasks.py @@ -97,9 +97,7 @@ def msg_analysis(self, msg_id: int, anal_id: int, countdown: int | None): # save the run before the result is ready Message.objects.filter(id=msg_id).update(analysis_run=sr) - # wait for the result - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # if failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index e2f05a66f..90432553c 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -164,8 +164,7 @@ def __init__( self.run_user = run_user @classmethod - @property - def endpoint(cls) -> str: + def api_endpoint(cls) -> str: return f"/v2/{cls.slug_versions[0]}" def current_app_url( @@ -241,7 +240,9 @@ def api_url( query_params = dict(run_id=run_id, uid=uid) elif example_id: query_params = dict(example_id=example_id) - return furl(settings.API_BASE_URL, query_params=query_params) / cls.endpoint + return ( + furl(settings.API_BASE_URL, query_params=query_params) / cls.api_endpoint() + ) @classmethod def clean_query_params(cls, *, example_id, run_id, uid) -> dict: diff --git a/daras_ai_v2/bots.py b/daras_ai_v2/bots.py index eea020936..e8fb61c6b 100644 --- a/daras_ai_v2/bots.py +++ b/daras_ai_v2/bots.py @@ -1,6 +1,7 @@ import mimetypes import typing from datetime import datetime +from types import SimpleNamespace import gooey_gui as gui from django.db import transaction @@ -199,9 +200,7 @@ def get_input_documents(self) -> list[str] | None: def get_interactive_msg_info(self) -> ButtonPressed: raise NotImplementedError("This bot does not support interactive messages.") - def on_run_created( - self, page: BasePage, result: "celery.result.AsyncResult", run_id: str, uid: str - ): + def on_run_created(self, sr: "SavedRun"): pass def send_run_status(self, update_msg_id: str | None) -> str | None: @@ -376,13 +375,13 @@ def _process_and_send_msg( variables.update(bot.request_overrides["variables"]) except KeyError: pass - page, result, run_id, uid = submit_api_call( + result, sr = submit_api_call( page_cls=bot.page_cls, - user=billing_account_user, - request_body=body, query_params=bot.query_params, + current_user=billing_account_user, + request_body=body, ) - bot.on_run_created(page, result, run_id, uid) + bot.on_run_created(sr) if bot.show_feedback_buttons: buttons = _feedback_start_buttons() @@ -394,10 +393,10 @@ def _process_and_send_msg( last_idx = 0 # this is the last index of the text sent to the user if bot.streaming_enabled: # subscribe to the realtime channel for updates - channel = page.realtime_channel_name(run_id, uid) + channel = bot.page_cls.realtime_channel_name(sr.run_id, sr.uid) with gui.realtime_subscribe(channel) as realtime_gen: for state in realtime_gen: - bot.recipe_run_state = page.get_run_state(state) + bot.recipe_run_state = bot.page_cls.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" # check for errors if bot.recipe_run_state == RecipeRunState.failed: @@ -438,11 +437,10 @@ def _process_and_send_msg( break # we're done streaming, stop the loop # wait for the celery task to finish - get_celery_result_db_safe(result) + sr.wait_for_celery_result(result) # get the final state from db - sr = page.current_sr state = sr.to_dict() - bot.recipe_run_state = page.get_run_state(state) + bot.recipe_run_state = bot.page_cls.get_run_state(state) bot.run_status = state.get(StateKeys.run_status) or "" # check for errors err_msg = state.get(StateKeys.error_msg) diff --git a/daras_ai_v2/safety_checker.py b/daras_ai_v2/safety_checker.py index 8a8d67336..7faf6b66b 100644 --- a/daras_ai_v2/safety_checker.py +++ b/daras_ai_v2/safety_checker.py @@ -38,8 +38,7 @@ def safety_checker_text(text_input: str): ) # wait for checker - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # if checker failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) diff --git a/functions/recipe_functions.py b/functions/recipe_functions.py index b7fd36fdb..21d3fa185 100644 --- a/functions/recipe_functions.py +++ b/functions/recipe_functions.py @@ -63,8 +63,7 @@ def call_recipe_functions( # wait for the result if its a pre request function if trigger == FunctionTrigger.post: continue - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # if failed, raise error if sr.error_msg: raise RuntimeError(sr.error_msg) diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py index f700effac..7ad9e67e9 100644 --- a/recipes/BulkRunner.py +++ b/recipes/BulkRunner.py @@ -2,10 +2,10 @@ import typing import uuid +import gooey_gui as gui from furl import furl from pydantic import BaseModel, Field -import gooey_gui as gui from bots.models import Workflow, SavedRun from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import icons @@ -322,8 +322,7 @@ def run_v2( request_body=request_body, parent_pr=pr, ) - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) run_time = datetime.timedelta( seconds=int(sr.run_time.total_seconds()) @@ -390,10 +389,11 @@ def run_v2( documents=response.output_documents ).dict(exclude_unset=True) result, sr = sr.submit_api_call( - current_user=self.request.user, request_body=request_body, parent_pr=pr + current_user=self.request.user, + request_body=request_body, + parent_pr=pr, ) - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) response.eval_runs.append(sr.get_app_url()) def preview_description(self, state: dict) -> str: diff --git a/routers/api.py b/routers/api.py index d1878bd53..68ad4c2b6 100644 --- a/routers/api.py +++ b/routers/api.py @@ -31,7 +31,7 @@ from app_users.models import AppUser from auth.token_authentication import api_auth_header -from bots.models import RetentionPolicy +from bots.models import RetentionPolicy, Workflow from daras_ai.image_input import upload_file_from_bytes from daras_ai_v2 import settings from daras_ai_v2.all_pages import all_api_pages @@ -41,9 +41,12 @@ ) from daras_ai_v2.fastapi_tricks import fastapi_request_form from functions.models import CalledFunctionResponse -from gooeysite.bg_db_conn import get_celery_result_db_safe from routers.custom_api_router import CustomAPIRouter +if typing.TYPE_CHECKING: + from bots.models import SavedRun + import celery.result + app = CustomAPIRouter() @@ -117,7 +120,7 @@ class RunSettings(BaseModel): def script_to_api(page_cls: typing.Type[BasePage]): - endpoint = page_cls().endpoint.rstrip("/") + endpoint = page_cls.api_endpoint().rstrip("/") # add the common settings to the request model request_model = create_model( page_cls.__name__ + "Request", @@ -156,15 +159,15 @@ def run_api_json( page_request: request_model, user: AppUser = Depends(api_auth_header), ): - page, result, run_id, uid = submit_api_call( + result, sr = submit_api_call( page_cls=page_cls, - user=user, - request_body=page_request.dict(exclude_unset=True), query_params=dict(request.query_params), retention_policy=RetentionPolicy[page_request.settings.retention_policy], + current_user=user, + request_body=page_request.dict(exclude_unset=True), enable_rate_limits=True, ) - return build_sync_api_response(page=page, result=result, run_id=run_id, uid=uid) + return build_sync_api_response(result, sr) @app.post( os.path.join(endpoint, "form"), @@ -205,15 +208,15 @@ def run_api_json_async( page_request: request_model, user: AppUser = Depends(api_auth_header), ): - page, _, run_id, uid = submit_api_call( + result, sr = submit_api_call( page_cls=page_cls, - user=user, - request_body=page_request.dict(exclude_unset=True), query_params=dict(request.query_params), retention_policy=RetentionPolicy[page_request.settings.retention_policy], + current_user=user, + request_body=page_request.dict(exclude_unset=True), enable_rate_limits=True, ) - ret = build_async_api_response(page=page, run_id=run_id, uid=uid) + ret = build_async_api_response(sr) response.headers["Location"] = ret["status_url"] response.headers["Access-Control-Expose-Headers"] = "Location" return ret @@ -332,19 +335,21 @@ def _parse_form_data( def submit_api_call( *, page_cls: typing.Type[BasePage], - request_body: dict, - user: AppUser, query_params: dict, retention_policy: RetentionPolicy = None, + current_user: AppUser, + request_body: dict, enable_rate_limits: bool = False, deduct_credits: bool = True, -) -> tuple[BasePage, "celery.result.AsyncResult", str, str]: +) -> tuple["celery.result.AsyncResult", "SavedRun"]: # init a new page for every request - query_params.setdefault("uid", user.uid) - self = page_cls(request=SimpleNamespace(user=user, query_params=query_params)) + query_params.setdefault("uid", current_user.uid) + page = page_cls( + request=SimpleNamespace(user=current_user, query_params=query_params) + ) # get saved state from db - state = self.current_sr_to_session_state() + state = page.current_sr_to_session_state() # load request data state.update(request_body) @@ -353,7 +358,7 @@ def submit_api_call( # create a new run try: - sr = self.create_new_run( + sr = page.create_new_run( enable_rate_limits=enable_rate_limits, is_api_call=True, retention_policy=retention_policy or RetentionPolicy.keep, @@ -361,20 +366,19 @@ def submit_api_call( except ValidationError as e: raise RequestValidationError(e.raw_errors, body=gui.session_state) from e # submit the task - result = self.call_runner_task(sr, deduct_credits=deduct_credits) - return self, result, sr.run_id, sr.uid + result = page.call_runner_task(sr, deduct_credits=deduct_credits) + return result, sr -def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: - web_url = page.app_url(run_id=run_id, uid=uid) +def build_async_api_response(sr: "SavedRun") -> dict: + web_url = sr.get_app_url() status_url = str( - furl(settings.API_BASE_URL, query_params=dict(run_id=run_id)) - / page.endpoint.replace("v2", "v3") + furl(settings.API_BASE_URL, query_params=dict(run_id=sr.run_id)) + / Workflow(sr.workflow).page_cls.api_endpoint().replace("v2", "v3") / "status/" ) - sr = page.current_sr return dict( - run_id=run_id, + run_id=sr.run_id, web_url=web_url, created_at=sr.created_at.isoformat(), status_url=status_url, @@ -382,17 +386,11 @@ def build_async_api_response(*, page: BasePage, run_id: str, uid: str) -> dict: def build_sync_api_response( - *, - page: BasePage, - result: "celery.result.AsyncResult", - run_id: str, - uid: str, + result: "celery.result.AsyncResult", sr: "SavedRun" ) -> JSONResponse: - web_url = page.app_url(run_id=run_id, uid=uid) + web_url = sr.get_app_url() # wait for the result - get_celery_result_db_safe(result) - sr = page.current_sr - sr.refresh_from_db() + sr.wait_for_celery_result(result) if sr.retention_policy == RetentionPolicy.delete: sr.state = {} sr.save(update_fields=["state"]) @@ -401,7 +399,7 @@ def build_sync_api_response( return JSONResponse( dict( detail=dict( - id=run_id, + id=sr.run_id, url=web_url, created_at=sr.created_at.isoformat(), error=sr.error_msg, @@ -414,7 +412,7 @@ def build_sync_api_response( return JSONResponse( jsonable_encoder( dict( - id=run_id, + id=sr.run_id, url=web_url, created_at=sr.created_at.isoformat(), output=sr.api_output(), diff --git a/routers/bots_api.py b/routers/bots_api.py index 64285cbd2..021caae4b 100644 --- a/routers/bots_api.py +++ b/routers/bots_api.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field from starlette.responses import StreamingResponse, Response -from bots.models import Platform, Conversation, BotIntegration, Message +from bots.models import Platform, Conversation, BotIntegration, Message, SavedRun from celeryapp.tasks import err_msg_for_exc from daras_ai_v2 import settings from daras_ai_v2.base import RecipeRunState, BasePage, StateKeys @@ -320,14 +320,10 @@ def runner(self): finally: self.queue.put(None) - def on_run_created( - self, page: BasePage, result: "celery.result.AsyncResult", run_id: str, uid: str - ): - self.run_id = run_id - self.uid = uid - self.queue.put( - RunStart(**build_async_api_response(page=page, run_id=run_id, uid=uid)) - ) + def on_run_created(self, sr: SavedRun): + self.run_id = sr.run_id + self.uid = sr.uid + self.queue.put(RunStart(**build_async_api_response(sr))) def send_run_status(self, update_msg_id: str | None) -> str | None: self.queue.put( diff --git a/routers/twilio_api.py b/routers/twilio_api.py index 0223636a9..108f3e47b 100644 --- a/routers/twilio_api.py +++ b/routers/twilio_api.py @@ -267,8 +267,7 @@ def resp_say_or_tts_play( request_body=tts_state, ) # wait for the TTS to finish - get_celery_result_db_safe(result) - sr.refresh_from_db() + sr.wait_for_celery_result(result) # check for errors if sr.error_msg: raise RuntimeError(sr.error_msg) From 76b205198c4a72588138d94621224b500067e525 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Fri, 6 Sep 2024 15:19:15 +0530 Subject: [PATCH 087/110] Remove usage of `SimpleNamespace` for request handling in BasePage Move BasePage.run_user -> cached property current_sr_user --- bots/admin.py | 7 ++---- celeryapp/tasks.py | 6 +---- daras_ai_v2/base.py | 54 +++++++++++++++++++++++++++++-------------- recipes/LipsyncTTS.py | 6 ++--- recipes/VideoBots.py | 8 ++----- routers/api.py | 11 ++------- routers/root.py | 29 ++++++++--------------- tests/test_pricing.py | 21 +++++------------ 8 files changed, 62 insertions(+), 80 deletions(-) diff --git a/bots/admin.py b/bots/admin.py index b7e696dfd..8b52e233f 100644 --- a/bots/admin.py +++ b/bots/admin.py @@ -1,6 +1,5 @@ import datetime import json -from types import SimpleNamespace import django.db.models from django import forms @@ -439,10 +438,8 @@ def rerun_tasks(self, request, queryset): sr: SavedRun for sr in queryset.all(): page = Workflow(sr.workflow).page_cls( - request=SimpleNamespace( - user=AppUser.objects.get(uid=sr.uid), - query_params=dict(run_id=sr.run_id, uid=sr.uid), - ) + user=AppUser.objects.get(uid=sr.uid), + query_params=dict(run_id=sr.run_id, uid=sr.uid), ) page.call_runner_task(sr, deduct_credits=False) self.message_user( diff --git a/celeryapp/tasks.py b/celeryapp/tasks.py index b2e257365..3fed1442f 100644 --- a/celeryapp/tasks.py +++ b/celeryapp/tasks.py @@ -4,7 +4,6 @@ import traceback import typing from time import time -from types import SimpleNamespace import gooey_gui as gui import requests @@ -92,10 +91,7 @@ def save_on_step(yield_val: str | tuple[str, dict] = None, *, done: bool = False page.dump_state_to_sr(gui.session_state | output, sr) page = page_cls( - request=SimpleNamespace( - user=AppUser.objects.get(id=user_id), - query_params=dict(run_id=run_id, uid=uid), - ), + user=AppUser.objects.get(id=user_id), query_params=dict(run_id=run_id, uid=uid) ) page.setup_sentry() sr = page.current_sr diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 90432553c..870debdb1 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -12,7 +12,6 @@ from itertools import pairwise from random import Random from time import sleep -from types import SimpleNamespace import gooey_gui as gui import sentry_sdk @@ -26,7 +25,6 @@ from sentry_sdk.tracing import ( TRANSACTION_SOURCE_ROUTE, ) -from starlette.requests import Request from app_users.models import AppUser, AppUserTransaction from bots.models import ( @@ -94,7 +92,6 @@ MAX_SEED = 4294967294 gooey_rng = Random() - SUBMIT_AFTER_LOGIN_Q = "submitafterlogin" @@ -117,6 +114,12 @@ class StateKeys: hidden = "__hidden" +class BasePageRequest: + user: AppUser | None + session: dict + query_params: dict + + class BasePage: title: str workflow: Workflow @@ -154,14 +157,20 @@ def __init__( self, *, tab: RecipeTabs = RecipeTabs.run, - request: Request | SimpleNamespace | None = None, - run_user: AppUser | None = None, + request: BasePageRequest | None = None, + user: AppUser | None = None, + request_session: dict | None = None, + query_params: dict | None = None, ): - if request is None: - request = SimpleNamespace(user=None, query_params={}) self.tab = tab + + if not request: + request = BasePageRequest() + request.user = user + request.session = request_session or {} + request.query_params = query_params or {} + self.request = request - self.run_user = run_user @classmethod def api_endpoint(cls) -> str: @@ -349,7 +358,7 @@ def _render_header(self): with gui.div(className="d-flex justify-content-between mt-4"): with gui.div(className="d-lg-flex d-block align-items-center"): - if not tbreadcrumbs.has_breadcrumbs() and not self.run_user: + if not tbreadcrumbs.has_breadcrumbs() and not self.current_sr_user: self._render_title(tbreadcrumbs.h1_title) if tbreadcrumbs: @@ -362,7 +371,7 @@ def _render_header(self): if is_example: author = pr.created_by else: - author = self.run_user or sr.get_creator() + author = self.current_sr_user or sr.get_creator() if not is_root_example: self.render_author(author) @@ -386,7 +395,7 @@ def _render_header(self): self._render_published_run_save_buttons(sr=sr, pr=pr) self._render_social_buttons(show_button_text=not show_save_buttons) - if tbreadcrumbs.has_breadcrumbs() or self.run_user: + if tbreadcrumbs.has_breadcrumbs() or self.current_sr_user: # only render title here if the above row was not empty self._render_title(tbreadcrumbs.h1_title) @@ -810,7 +819,7 @@ def get_explore_image(self) -> str: return meta_preview_url(img, fallback_img) def _user_disabled_check(self): - if self.run_user and self.run_user.is_disabled: + if self.current_sr_user and self.current_sr_user.is_disabled: msg = ( "This Gooey.AI account has been disabled for violating our [Terms of Service](/terms). " "Contact us at support@gooey.ai if you think this is a mistake." @@ -1009,7 +1018,7 @@ def render_report_form(self): send_reported_run_email( user=self.request.user, - run_uid=str(self.run_user.uid), + run_uid=str(self.current_sr_user.uid), url=self.current_app_url(), recipe_name=self.title, report_type=report_type, @@ -1052,11 +1061,22 @@ def update_flag_for_run(self, is_flagged: bool): sr.save(update_fields=["is_flagged"]) gui.session_state["is_flagged"] = is_flagged - @property + @cached_property + def current_sr_user(self) -> AppUser | None: + if not self.current_sr.uid: + return None + if self.request.user and self.request.user.uid == self.current_sr.uid: + return self.request.user + try: + return AppUser.objects.get(uid=self.current_sr.uid) + except AppUser.DoesNotExist: + return None + + @cached_property def current_sr(self) -> SavedRun: return self.current_sr_pr[0] - @property + @cached_property def current_pr(self) -> PublishedRun: return self.current_sr_pr[1] @@ -1571,7 +1591,7 @@ def create_new_run( uid = self.request.user.uid else: uid = auth.create_user().uid - self.request.scope["user"] = AppUser.objects.create( + self.request.user = AppUser.objects.create( uid=uid, is_anonymous=True, balance=settings.ANON_USER_FREE_CREDITS ) self.request.session[ANONYMOUS_USER_COOKIE] = dict(uid=uid) @@ -2138,7 +2158,7 @@ def is_current_user_paying(self) -> bool: return bool(self.request.user and self.request.user.is_paying) def is_current_user_owner(self) -> bool: - return bool(self.request.user and self.run_user == self.request.user) + return bool(self.request.user and self.current_sr_user == self.request.user) def started_at_text(dt: datetime.datetime): diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py index 995f5d43f..d557cd663 100644 --- a/recipes/LipsyncTTS.py +++ b/recipes/LipsyncTTS.py @@ -122,12 +122,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]: if not self.request.user.disable_safety_checker: safety_checker(text=state["text_prompt"]) - yield from TextToSpeechPage(request=self.request, run_user=self.run_user).run( - state - ) + yield from TextToSpeechPage(request=self.request).run(state) # IMP: Copy output of TextToSpeechPage "audio_url" to Lipsync as "input_audio" state["input_audio"] = state["audio_url"] - yield from LipsyncPage(request=self.request, run_user=self.run_user).run(state) + yield from LipsyncPage(request=self.request).run(state) def render_example(self, state: dict): output_video = state.get("output_video") diff --git a/recipes/VideoBots.py b/recipes/VideoBots.py index 26b7e785e..77b266c26 100644 --- a/recipes/VideoBots.py +++ b/recipes/VideoBots.py @@ -1015,9 +1015,7 @@ def run_v2( tts_state = TextToSpeechPage.RequestModel.parse_obj( {**gui.session_state, "text_prompt": text} ).dict() - yield from TextToSpeechPage( - request=self.request, run_user=self.run_user - ).run(tts_state) + yield from TextToSpeechPage(request=self.request).run(tts_state) response.output_audio.append(tts_state["audio_url"]) if not request.input_face: @@ -1031,9 +1029,7 @@ def run_v2( "selected_model": request.lipsync_model, } ).dict() - yield from LipsyncPage(request=self.request, run_user=self.run_user).run( - lip_state - ) + yield from LipsyncPage(request=self.request).run(lip_state) response.output_video.append(lip_state["output_video"]) def get_tabs(self): diff --git a/routers/api.py b/routers/api.py index 68ad4c2b6..73c5ec7b9 100644 --- a/routers/api.py +++ b/routers/api.py @@ -3,7 +3,6 @@ import os.path import os.path import typing -from types import SimpleNamespace import gooey_gui as gui from fastapi import Depends @@ -262,11 +261,7 @@ def get_run_status( user: AppUser = Depends(api_auth_header), ): # init a new page for every request - self = page_cls( - request=SimpleNamespace( - user=user, query_params=dict(run_id=run_id, uid=user.uid) - ) - ) + self = page_cls(user=user, query_params=dict(run_id=run_id, uid=user.uid)) sr = self.current_sr web_url = str(furl(self.app_url(run_id=run_id, uid=user.uid))) ret = { @@ -344,9 +339,7 @@ def submit_api_call( ) -> tuple["celery.result.AsyncResult", "SavedRun"]: # init a new page for every request query_params.setdefault("uid", current_user.uid) - page = page_cls( - request=SimpleNamespace(user=current_user, query_params=query_params) - ) + page = page_cls(user=current_user, query_params=query_params) # get saved state from db state = page.current_sr_to_session_state() diff --git a/routers/root.py b/routers/root.py index 9f678416b..40884fd55 100644 --- a/routers/root.py +++ b/routers/root.py @@ -242,7 +242,7 @@ def explore_page(request: Request): @gui.route(app, "/api/") def api_docs_page(request: Request): with page_wrapper(request): - _api_docs_page(request) + _api_docs_page() return dict( meta=raw_build_meta_tags( url=get_og_url_path(request), @@ -255,7 +255,7 @@ def api_docs_page(request: Request): ) -def _api_docs_page(request): +def _api_docs_page(): from daras_ai_v2.all_pages import all_api_pages api_docs_url = str(furl(settings.API_BASE_URL) / "docs") @@ -312,7 +312,7 @@ def _api_docs_page(request): as_async = gui.checkbox("Run Async") as_form_data = gui.checkbox("Upload Files via Form Data") - page = workflow.page_cls(request=request) + page = workflow.page_cls() state = page.get_root_pr().saved_run.to_dict() api_url, request_body = page.get_example_request(state, include_all=include_all) response_body = page.get_example_response_body( @@ -667,11 +667,13 @@ def render_recipe_page( ) return RedirectResponse(str(new_url.set(origin=None)), status_code=301) - # this is because the code still expects example_id to be in the query params - request._query_params = dict(request.query_params) | dict(example_id=example_id) - - page = page_cls(tab=tab, request=request) - page.run_user = get_run_user(request, page.current_sr.uid) + page = page_cls( + tab=tab, + user=request.user, + request_session=request.session, + # this is because the code still expects example_id to be in the query params + query_params=dict(request.query_params) | dict(example_id=example_id), + ) if not gui.session_state: gui.session_state.update(page.current_sr_to_session_state()) @@ -692,17 +694,6 @@ def get_og_url_path(request) -> str: ) -def get_run_user(request: Request, uid: str) -> AppUser | None: - if not uid: - return - if request.user and request.user.uid == uid: - return request.user - try: - return AppUser.objects.get(uid=uid) - except AppUser.DoesNotExist: - pass - - @contextmanager def page_wrapper(request: Request, className=""): context = { diff --git a/tests/test_pricing.py b/tests/test_pricing.py index 73e05e4a9..c78bf5376 100644 --- a/tests/test_pricing.py +++ b/tests/test_pricing.py @@ -1,6 +1,3 @@ -from types import SimpleNamespace - -import gooey_gui as gui import pytest from starlette.testclient import TestClient @@ -49,10 +46,8 @@ def test_copilot_get_raw_price_round_up(): dollar_amount=model_pricing.unit_cost * 1 / model_pricing.unit_quantity, ) copilot_page = VideoBotsPage( - request=SimpleNamespace( - user=user, - query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), - ), + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), ) assert ( copilot_page.get_price_roundoff(state=state) @@ -114,10 +109,8 @@ def test_multiple_llm_sums_usage_cost(): ) llm_page = CompareLLMPage( - request=SimpleNamespace( - user=user, - query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), - ) + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), ) assert llm_page.get_price_roundoff(state=state) == (310 + llm_page.PROFIT_CREDITS) @@ -163,10 +156,8 @@ def test_workflowmetadata_2x_multiplier(): metadata.save() llm_page = CompareLLMPage( - request=SimpleNamespace( - user=user, - query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), - ) + user=user, + query_params=dict(run_id=bot_saved_run.run_id or "", uid=user.uid or ""), ) assert ( llm_page.get_price_roundoff(state=state) == (210 + llm_page.PROFIT_CREDITS) * 2 From 2ce8187bb703ed913b0e179267596743541b6ce5 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 10 Sep 2024 12:58:25 +0530 Subject: [PATCH 088/110] add github icon to all workflow pages --- daras_ai_v2/base.py | 11 +++++++++++ daras_ai_v2/github_tools.py | 25 +++++++++++++++++++++++++ server.py | 19 +------------------ 3 files changed, 37 insertions(+), 18 deletions(-) create mode 100644 daras_ai_v2/github_tools.py diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 870debdb1..77571d235 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -51,6 +51,7 @@ ) from daras_ai_v2.exceptions import InsufficientCredits from daras_ai_v2.fastapi_tricks import get_route_path +from daras_ai_v2.github_tools import github_url_for_file from daras_ai_v2.grid_layout_widget import grid_layout from daras_ai_v2.gui_confirm import confirm_modal from daras_ai_v2.html_spinner_widget import html_spinner @@ -348,6 +349,16 @@ def render(self): with header_placeholder: self._render_header() + github_url = github_url_for_file(inspect.getfile(self.__class__)) + gui.html( + f""" + + + Fork me on GitHub + + """ + ) + def _render_header(self): sr, pr = self.current_sr_pr is_example = pr.saved_run == sr diff --git a/daras_ai_v2/github_tools.py b/daras_ai_v2/github_tools.py new file mode 100644 index 000000000..939848b08 --- /dev/null +++ b/daras_ai_v2/github_tools.py @@ -0,0 +1,25 @@ +import os +import traceback + +from furl import furl + +from daras_ai_v2 import settings + +GITHUB_REPO = "https://github.com/GooeyAI/gooey-server/" +_base_dir = str(settings.BASE_DIR) + + +def github_url_for_exc(exc: Exception) -> str | None: + for frame in reversed(traceback.extract_tb(exc.__traceback__)): + if not frame.filename.startswith(_base_dir): + continue + return github_url_for_file(frame.filename, frame.lineno) + return GITHUB_REPO + + +def github_url_for_file(filename: str, lineno: str | None = None) -> str: + ref = (os.environ.get("CAPROVER_GIT_COMMIT_SHA") or "master").strip() + path = os.path.relpath(filename, _base_dir) + return str( + furl(GITHUB_REPO, fragment_path=lineno and f"L{lineno}") / "blob" / ref / path + ) diff --git a/server.py b/server.py index a1d4ec2e5..e42d53221 100644 --- a/server.py +++ b/server.py @@ -1,4 +1,3 @@ -import os import traceback from fastapi.exception_handlers import ( @@ -6,7 +5,6 @@ request_validation_exception_handler, ) from fastapi.exceptions import RequestValidationError -from furl import furl from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import JSONResponse @@ -15,6 +13,7 @@ HTTP_405_METHOD_NOT_ALLOWED, ) +from daras_ai_v2.github_tools import github_url_for_exc from daras_ai_v2.pydantic_validation import convert_errors from daras_ai_v2.settings import templates from gooeysite import wsgi @@ -158,19 +157,3 @@ async def _exc_handler(request: Request, exc: Exception, template_name: str): ), status_code=500, ) - - -GITHUB_REPO = "https://github.com/GooeyAI/gooey-server/" - - -def github_url_for_exc(exc: Exception) -> str | None: - base_dir = str(settings.BASE_DIR) - ref = (os.environ.get("CAPROVER_GIT_COMMIT_SHA") or "master").strip() - for frame in reversed(traceback.extract_tb(exc.__traceback__)): - if not frame.filename.startswith(base_dir): - continue - path = os.path.relpath(frame.filename, base_dir) - return str( - furl(GITHUB_REPO, fragment_path=f"L{frame.lineno}") / "blob" / ref / path - ) - return GITHUB_REPO From f00402158acafa220a26e263b9bd6385306a2b58 Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Tue, 10 Sep 2024 15:20:35 +0530 Subject: [PATCH 089/110] fix overlapping github fork button with account link --- daras_ai_v2/base.py | 19 +++++++++---------- routers/root.py | 2 +- templates/header.html | 2 +- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 77571d235..13e0a4ae6 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -349,16 +349,6 @@ def render(self): with header_placeholder: self._render_header() - github_url = github_url_for_file(inspect.getfile(self.__class__)) - gui.html( - f""" - - - Fork me on GitHub - - """ - ) - def _render_header(self): sr, pr = self.current_sr_pr is_example = pr.saved_run == sr @@ -456,6 +446,15 @@ def _render_unpublished_changes_indicator(self): def _render_social_buttons(self, show_button_text: bool = False): if show_button_text: + github_url = github_url_for_file(inspect.getfile(self.__class__)) + gui.anchor( + ' GitHub', + href=github_url, + unsafe_allow_html=True, + target="_blank", + type="tertiary", + ) + button_text = ' Copy Link' else: button_text = "" diff --git a/routers/root.py b/routers/root.py index 40884fd55..15aa5d0bd 100644 --- a/routers/root.py +++ b/routers/root.py @@ -711,7 +711,7 @@ def page_wrapper(request: Request, className=""): gui.html(templates.get_template("header.html").render(**context)) gui.html(copy_to_clipboard_scripts) - with gui.div(id="main-content", className="container " + className): + with gui.div(id="main-content", className="container-xxl " + className): yield gui.html(templates.get_template("footer.html").render(**context)) diff --git a/templates/header.html b/templates/header.html index bb5208be5..ea21fa876 100644 --- a/templates/header.html +++ b/templates/header.html @@ -1,6 +1,6 @@