Skip to content

Commit

Permalink
Update abstra-lib
Browse files Browse the repository at this point in the history
  • Loading branch information
abstra-bot committed Jul 5, 2024
1 parent 3edfa3b commit 8bf5560
Show file tree
Hide file tree
Showing 225 changed files with 564 additions and 455 deletions.
6 changes: 6 additions & 0 deletions abstra_internals/cloud_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def get_ai_messages(messages, stage, thread_id, headers: dict):
)


def generate_project(prompt: str, abstra_json_version: str, headers: dict):
url = f"{CLOUD_API_ENDPOINT}/cli/ai/generate"
body = {"prompt": prompt, "version": abstra_json_version}
return requests.post(url, headers=headers, json=body).json()


def create_thread(headers: dict):
url = f"{CLOUD_API_ENDPOINT}/cli/ai/thread"
return requests.post(url, headers=headers).json()
Expand Down
20 changes: 13 additions & 7 deletions abstra_internals/repositories/project/json_migrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def create_backup(data: dict, location: Path, filename: str):
return f"{CONFIG_FILE_BACKUPS}/{filename}"


def migrate(data: dict, path: Path):
def migrate(data: dict, path: Path, verbose=True):
if "version" not in data:
data["version"] = "0.1"

Expand All @@ -77,14 +77,17 @@ def migrate(data: dict, path: Path):
if not next_migration:
return data

print("Your abstra.json is outdated, running migrations...")
if verbose:
print("Your abstra.json is outdated, running migrations...")

try:
filename = f"abstra_{datetime.now().strftime('%Y%m%d%H%M%S')}.json.backup"
backup_path = create_backup(data, path, filename)
print(f"Backup file created: {backup_path}")
if verbose:
print(f"Backup file created: {backup_path}")
except Exception as e:
print(f"Failed to create backup file: {e}")
if verbose:
print(f"Failed to create backup file: {e}")
raise

while next_migration:
Expand All @@ -100,15 +103,18 @@ def migrate(data: dict, path: Path):

except Exception as e:
AbstraLogger.capture_exception(e)
print(f"An error occurred during migration from {current_version}: {e}")
if verbose:
print(f"An error occurred during migration from {current_version}: {e}")
raise

print(f"Upgrade to version {data['version']} complete.")
if verbose:
print(f"Upgrade to version {data['version']} complete.")

next_migration = next_migration = next(
(m for m in MIGRATIONS if m.source_version() == data["version"]), None
)

print("Your abstra.json is up to date ✅")
if verbose:
print("Your abstra.json is up to date ✅")

return data
66 changes: 39 additions & 27 deletions abstra_internals/repositories/project/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ def delete_stage(self, id: str, remove_file: bool = False):
self.conditions = [c for c in self.conditions if c.id != id]

@staticmethod
def from_dict(data: dict):
def __from_dict(data: dict):
target_stages = set()
nodes = []
edges = []
Expand All @@ -1283,31 +1283,42 @@ def from_dict(data: dict):
else:
data[key][index]["is_initial"] = True

scripts = [ScriptStage.from_dict(script) for script in data["scripts"]]
forms = [FormStage.from_dict(form) for form in data["forms"]]
hooks = [HookStage.from_dict(hook) for hook in data["hooks"]]
jobs = [JobStage.from_dict(job) for job in data["jobs"]]
iterators = [IteratorStage.from_dict(i) for i in data["iterators"]]
conditions = [ConditionStage.from_dict(c) for c in data["conditions"]]

workspace = StyleSettings.from_dict(data["workspace"])
kanban = KanbanView.from_dict(data.get("kanban", {}))
home = Home.from_dict(data.get("home", {}))

return Project(
workspace=workspace,
scripts=scripts,
forms=forms,
hooks=hooks,
jobs=jobs,
iterators=iterators,
conditions=conditions,
kanban=kanban,
home=home,
_graph=Graph.from_primitives(nodes=nodes, edges=edges),
)

@staticmethod
def validate(data: dict):
try:
scripts = [ScriptStage.from_dict(script) for script in data["scripts"]]
forms = [FormStage.from_dict(form) for form in data["forms"]]
hooks = [HookStage.from_dict(hook) for hook in data["hooks"]]
jobs = [JobStage.from_dict(job) for job in data["jobs"]]
iterators = [IteratorStage.from_dict(i) for i in data["iterators"]]
conditions = [ConditionStage.from_dict(c) for c in data["conditions"]]

workspace = StyleSettings.from_dict(data["workspace"])
kanban = KanbanView.from_dict(data.get("kanban", {}))
home = Home.from_dict(data.get("home", {}))

return Project(
workspace=workspace,
scripts=scripts,
forms=forms,
hooks=hooks,
jobs=jobs,
iterators=iterators,
conditions=conditions,
kanban=kanban,
home=home,
_graph=Graph.from_primitives(nodes=nodes, edges=edges),
)
Project.__from_dict(data)
return True
except Exception:
return False

@staticmethod
def from_dict(data: dict):
try:
return Project.__from_dict(data)
except Exception as e:
print("Error: incompatible abstra.json file.")
AbstraLogger.capture_exception(e)
Expand Down Expand Up @@ -1372,7 +1383,7 @@ def save(cls, project: Project):
Path.rmdir(temp_file.parent)

@classmethod
def migrate_config_file(cls):
def migrate_config_file(cls, verbose=True):
if not cls.exists():
return
data = json.loads(cls.get_file_path().read_text(encoding="utf-8"))
Expand All @@ -1381,6 +1392,7 @@ def migrate_config_file(cls):
migrated_data = json_migrations.migrate(
data,
Settings.root_path,
verbose=verbose,
)

if migrated_data["version"] != initial_version:
Expand All @@ -1393,8 +1405,8 @@ def load(cls) -> Project:
return Project.from_dict(data)

@classmethod
def initialize_or_migrate(cls):
def initialize_or_migrate(cls, verbose=True):
if not cls.exists():
cls.initialize()
else:
cls.migrate_config_file()
cls.migrate_config_file(verbose=verbose)
93 changes: 92 additions & 1 deletion abstra_internals/server/controller/ai.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,88 @@
import json
from dataclasses import dataclass
from typing import Any, Dict, List

import flask

from abstra_internals.cloud_api import create_thread, generate_project, get_ai_messages
from abstra_internals.credentials import resolve_headers
from abstra_internals.repositories.project.json_migrations import get_latest_version
from abstra_internals.repositories.project.project import Project, ProjectRepository
from abstra_internals.server.controller.linters import fix_all_linters
from abstra_internals.server.controller.main import MainController
from abstra_internals.settings import Settings
from abstra_internals.usage import usage


def get_editor_bp(controller: MainController):
@dataclass
class PythonFile:
filename: str
content: str
stage: str

@staticmethod
def from_dict(data: Dict[str, Any]) -> "PythonFile":
return PythonFile(data["filename"], data["content"], data["stage"])


class AiController:
def __init__(self, controller: MainController):
self.controller = controller

def send_ai_message(self, messages, stage, thread_id):
headers = resolve_headers() or {}
yield from get_ai_messages(messages, stage, thread_id, headers)

def create_thread(self):
headers = resolve_headers()
if headers is None:
return None
return create_thread(headers)

def generate_project(self, prompt: str, tries: int = 0):
headers = resolve_headers() or {}
abstra_json_version = get_latest_version()
try:
res = generate_project(prompt, abstra_json_version, headers)
is_abstra_json_valid = Project.validate(res["abstra_json"])
if not is_abstra_json_valid:
raise Exception("Generated abstra.json is not valid")

generated_abstra_json = json.dumps(res["abstra_json"], indent=4)
python_files = [PythonFile.from_dict(file) for file in res["python_files"]]

self.init_stages(python_files)
Settings.root_path.joinpath("abstra.json").write_text(
generated_abstra_json, encoding="utf-8"
)

ProjectRepository.initialize_or_migrate(verbose=False)
fix_all_linters()
except Exception as e:
if tries >= 3:
raise e
return self.generate_project(prompt, tries + 1)

def init_stages(self, python_files: List[PythonFile]):
for file in python_files:
if file.stage == "form":
script = self.controller.create_form(file.filename[:-3], file.filename)
elif file.stage == "hook":
script = self.controller.create_hook(file.filename[:-3], file.filename)
elif file.stage == "script":
script = self.controller.create_script(
file.filename[:-3], file.filename
)
elif file.stage == "job":
script = self.controller.create_job(file.filename[:-3], file.filename)
else:
raise Exception(f"Invalid stage {file.stage}")
script.file_path.write_text(file.content, encoding="utf-8")


def get_editor_bp(main_controller: MainController):
bp = flask.Blueprint("editor_ai", __name__)
controller = AiController(main_controller)

@bp.post("/message")
@usage
Expand All @@ -31,4 +108,18 @@ def _create_thread():
flask.abort(403)
return thread

@bp.post("/generate")
@usage
def _generate():
body = flask.request.json
if not body:
flask.abort(400)

prompt = body.get("prompt")
if not prompt:
flask.abort(400)

controller.generate_project(prompt)
return {"success": True}

return bp
8 changes: 8 additions & 0 deletions abstra_internals/server/controller/linters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def fix_linter(rule_name: str, fix_name: str):
raise Exception(f"Could not find fix {fix_name} for rule {rule_name}")


def fix_all_linters():
for rule in rules:
if rule.type != "info":
for issue in rule.find_issues():
for fix in issue.fixes:
fix.fix()


def get_editor_bp():
bp = flask.Blueprint("editor_linters", __name__)

Expand Down
18 changes: 1 addition & 17 deletions abstra_internals/server/controller/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
import pkg_resources
from werkzeug.datastructures import FileStorage

from abstra_internals.cloud_api import (
create_thread,
get_ai_messages,
get_api_key_info,
get_project_info,
)
from abstra_internals.cloud_api import get_api_key_info, get_project_info
from abstra_internals.controllers.workflow import WorkflowEngine
from abstra_internals.credentials import (
delete_credentials,
Expand Down Expand Up @@ -460,17 +455,6 @@ def get_project_info(self):
flask.abort(401)
return get_project_info(headers)

# AI
def send_ai_message(self, messages, stage, thread_id):
headers = resolve_headers() or {}
yield from get_ai_messages(messages, stage, thread_id, headers)

def create_thread(self):
headers = resolve_headers()
if headers is None:
return None
return create_thread(headers)

# access_control
def list_access_controls(self):
project = ProjectRepository.load()
Expand Down
Loading

0 comments on commit 8bf5560

Please sign in to comment.