Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Node post methods #25

Merged
merged 25 commits into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ repos:
- passlib
- pytest~=3.6,<5.0.0
- sphinx<4
- importlib_metadata~=4.3
exclude: >
(?x)^(
docs/.*|
Expand Down
3 changes: 2 additions & 1 deletion aiida_restapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from fastapi import FastAPI

from aiida_restapi.graphql import main
from aiida_restapi.routers import auth, computers, groups, users
from aiida_restapi.routers import auth, computers, groups, nodes, users

app = FastAPI()
app.include_router(auth.router)
app.include_router(computers.router)
app.include_router(nodes.router)
app.include_router(groups.router)
app.include_router(users.router)
app.add_route("/graphql", main.app, name="graphql")
137 changes: 137 additions & 0 deletions aiida_restapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,47 @@
"""
# pylint: disable=too-few-public-methods

import inspect
import io
from datetime import datetime
from typing import ClassVar, Dict, List, Optional, Type, TypeVar
from uuid import UUID

from aiida import orm
from aiida.restapi.common.identifiers import load_entry_point_from_full_type
from fastapi import Form
from pydantic import BaseModel, Field

# Template type for subclasses of `AiidaModel`
ModelType = TypeVar("ModelType", bound="AiidaModel")


def as_form(cls: Type[BaseModel]) -> Type[BaseModel]:
"""
Adds an as_form class method to decorated models. The as_form class method
can be used with FastAPI endpoints

Note: Taken from https://github.com/tiangolo/fastapi/issues/2387
"""
new_params = [
inspect.Parameter(
field.alias,
inspect.Parameter.POSITIONAL_ONLY,
default=(Form(field.default) if not field.required else Form(...)),
)
for field in cls.__fields__.values()
]

async def _as_form(**data: Dict) -> BaseModel:
return cls(**data)

sig = inspect.signature(_as_form)
sig = sig.replace(parameters=new_params)
_as_form.__signature__ = sig # type: ignore
setattr(cls, "as_form", _as_form)
return cls


class AiidaModel(BaseModel):
"""A mapping of an AiiDA entity to a pydantic model."""

Expand All @@ -26,6 +56,7 @@ class Config:
"""The models configuration."""

orm_mode = True
extra = "forbid"

@classmethod
def get_projectable_properties(cls) -> List[str]:
Expand Down Expand Up @@ -88,6 +119,11 @@ class User(AiidaModel):

_orm_entity = orm.User

class Config:
"""The models configuration."""

extra = "allow"

id: Optional[int] = Field(description="Unique user id (pk)")
email: str = Field(description="Email address of the user")
first_name: Optional[str] = Field(description="First name of the user")
Expand Down Expand Up @@ -119,6 +155,107 @@ class Computer(AiidaModel):
description="General settings for these communication and management protocols"
)

description: Optional[str] = Field(description="Description of node")


class Node(AiidaModel):
"""AiiDA Node Model."""

_orm_entity = orm.Node

id: Optional[int] = Field(description="Unique id (pk)")
uuid: Optional[UUID] = Field(description="Unique uuid")
node_type: Optional[str] = Field(description="Node type")
process_type: Optional[str] = Field(description="Process type")
label: str = Field(description="Label of node")
description: Optional[str] = Field(description="Description of node")
ctime: Optional[datetime] = Field(description="Creation time")
mtime: Optional[datetime] = Field(description="Last modification time")
user_id: Optional[int] = Field(description="Created by user id (pk)")
dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)")
attributes: Optional[Dict] = Field(
description="Variable attributes of the node",
)
extras: Optional[Dict] = Field(
description="Variable extras (unsealed) of the node",
)


@as_form
class Node_Post(AiidaModel):
"""AiiDA model for posting Nodes."""

node_type: Optional[str] = Field(description="Node type")
process_type: Optional[str] = Field(description="Process type")
label: str = Field(description="Label of node")
description: Optional[str] = Field(description="Description of node")
user_id: Optional[int] = Field(description="Created by user id (pk)")
dbcomputer_id: Optional[int] = Field(description="Associated computer id (pk)")
attributes: Optional[Dict] = Field(
description="Variable attributes of the node",
)
extras: Optional[Dict] = Field(
description="Variable extras (unsealed) of the node",
)

@classmethod
def create_new_node(
cls: Type[ModelType],
node_type: str,
node_dict: dict,
) -> orm.Node:
"Create and Store new Node"

orm_class = load_entry_point_from_full_type(node_type)
attributes = node_dict.pop("attributes", {})
extras = node_dict.pop("extras", {})

if issubclass(orm_class, orm.BaseType):
orm_object = orm_class(
attributes["value"],
**node_dict,
)
elif issubclass(orm_class, orm.Dict):
orm_object = orm_class(
dict=attributes,
**node_dict,
)
elif issubclass(orm_class, orm.Code):
orm_object = orm_class()
orm_object.set_remote_computer_exec(
(
orm.Computer.get(id=node_dict.get("dbcomputer_id")),
attributes["remote_exec_path"],
)
)
orm_object.label = node_dict.get("label")
else:
orm_object = load_entry_point_from_full_type(node_type)(**node_dict)
orm_object.set_attribute_many(attributes)

orm_object.set_extra_many(extras)
orm_object.store()
return orm_object

@classmethod
def create_new_node_with_file(
cls: Type[ModelType],
node_type: str,
node_dict: dict,
file: bytes,
) -> orm.Node:
"Create and Store new Node with file"
attributes = node_dict.pop("attributes", {})
extras = node_dict.pop("extras", {})

orm_object = load_entry_point_from_full_type(node_type)(
file=io.BytesIO(file), **node_dict, **attributes
)

orm_object.set_extra_many(extras)
orm_object.store()
return orm_object


class Group(AiidaModel):
"""AiiDA Group model."""
Expand Down
1 change: 1 addition & 0 deletions aiida_restapi/routers/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ async def read_group(group_id: int) -> Optional[Group]:


@router.post("/groups", response_model=Group)
@with_dbenv()
async def create_group(
group: Group_Post,
current_user: User = Depends(
Expand Down
97 changes: 97 additions & 0 deletions aiida_restapi/routers/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
"""Declaration of FastAPI application."""
from typing import List, Optional

from aiida import orm
from aiida.cmdline.utils.decorators import with_dbenv
from fastapi import APIRouter, Depends, File, HTTPException
from importlib_metadata import entry_points

from aiida_restapi import models

from .auth import get_current_active_user

router = APIRouter()

ENTRY_POINTS = entry_points()


@router.get("/nodes", response_model=List[models.Node])
@with_dbenv()
async def read_nodes() -> List[models.Node]:
"""Get list of all nodes"""
return models.Node.get_entities()


@router.get("/nodes/projectable_properties", response_model=List[str])
async def get_nodes_projectable_properties() -> List[str]:
"""Get projectable properties for nodes endpoint"""

return models.Node.get_projectable_properties()


@router.get("/nodes/{nodes_id}", response_model=models.Node)
@with_dbenv()
async def read_node(nodes_id: int) -> Optional[models.Node]:
"""Get nodes by id."""
qbobj = orm.QueryBuilder()

qbobj.append(orm.Node, filters={"id": nodes_id}, project=["**"], tag="node").limit(
1
)
return qbobj.dict()[0]["node"]


@router.post("/nodes", response_model=models.Node)
@with_dbenv()
async def create_node(
node: models.Node_Post,
current_user: models.User = Depends(
get_current_active_user
), # pylint: disable=unused-argument
) -> models.Node:
"""Create new AiiDA node."""

node_dict = node.dict(exclude_unset=True)
node_type = node_dict.pop("node_type", None)

try:
(entry_point_node,) = ENTRY_POINTS.select(
group="aiida.rest.post", name=node_type
)
except ValueError as exc:
raise HTTPException(
status_code=404, detail="Entry point '{}' not recognized.".format(node_type)
) from exc

try:
orm_object = entry_point_node.load().create_new_node(node_type, node_dict)
except (TypeError, ValueError, KeyError) as err:
raise HTTPException(status_code=400, detail="Error: {0}".format(err)) from err

return models.Node.from_orm(orm_object)


@router.post("/nodes/singlefile", response_model=models.Node)
@with_dbenv()
async def create_upload_file(
upload_file: bytes = File(...),
params: models.Node_Post = Depends(models.Node_Post.as_form), # type: ignore # pylint: disable=maybe-no-member
current_user: models.User = Depends(
get_current_active_user
), # pylint: disable=unused-argument
) -> models.Node:
"""Endpoint for uploading file data"""
node_dict = params.dict(exclude_unset=True, exclude_none=True)
node_type = node_dict.pop("node_type", None)

try:
(entry_point_node,) = entry_points(group="aiida.rest.post", name=node_type)
except KeyError as exc:
raise KeyError("Entry point '{}' not recognized.".format(node_type)) from exc

orm_object = entry_point_node.load().create_new_node_with_file(
node_type, node_dict, upload_file
)

return models.Node.from_orm(orm_object)
1 change: 1 addition & 0 deletions aiida_restapi/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ async def read_user(user_id: int) -> Optional[User]:


@router.post("/users", response_model=User)
@with_dbenv()
async def create_user(
user: User,
current_user: User = Depends(
Expand Down
11 changes: 11 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ def default_groups():
return [group_1.id, group_2.id]


@pytest.fixture(scope="function")
def default_nodes():
"""Populate database with some nodes."""
node_1 = orm.Int(1).store()
node_2 = orm.Float(1.1).store()
node_3 = orm.Str("test_string").store()
node_4 = orm.Bool(False).store()

return [node_1.id, node_2.id, node_3.id, node_4.id]


@pytest.fixture(scope="function")
def authenticate():
"""Authenticate user.
Expand Down
15 changes: 14 additions & 1 deletion setup.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@
],
"aiida.cmdline.data": [
"restapi = aiida_restapi.cli:data_cli"
],
"aiida.rest.post": [
"data.str.Str.| = aiida_restapi.models:Node_Post",
"data.float.Float.| = aiida_restapi.models:Node_Post",
"data.int.Int.| = aiida_restapi.models:Node_Post",
"data.bool.Bool.| = aiida_restapi.models:Node_Post",
"data.structure.StructureData.| = aiida_restapi.models:Node_Post",
"data.orbital.OrbitalData.| = aiida_restapi.models:Node_Post",
"data.list.List.| = aiida_restapi.models:Node_Post",
"data.dict.Dict.| = aiida_restapi.models:Node_Post",
"data.singlefile.SingleFileData.| = aiida_restapi.models:Node_Post",
"data.code.Code.| = aiida_restapi.models:Node_Post"
]
},
"include_package_data": true,
Expand All @@ -36,7 +48,8 @@
"pydantic~=1.8.2",
"graphene~=2.0",
"python-dateutil~=2.0",
"lark~=0.11.0"
"lark~=0.11.0",
"importlib_metadata~=4.3"
],
"extras_require": {
"testing": [
Expand Down
1 change: 1 addition & 0 deletions tests/test_computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_get_computers_projectable(client):
"scheduler_type",
"transport_type",
"metadata",
"description",
]


Expand Down
2 changes: 1 addition & 1 deletion tests/test_graphql/test_full/test_full.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
data:
aiidaVersion: 1.6.4
aiidaVersion: 1.6.5
node:
label: node 1
3 changes: 2 additions & 1 deletion tests/test_models/test_computer_get_entities.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
- hostname: localhost_1
- description: ''
hostname: localhost_1
id: int
metadata: {}
name: test_comp_1
Expand Down
Loading