import logging

from fastapi import FastAPI
from pydantic import BaseModel

import model

# Application object; the /generate/ route in this file registers itself on it via @app.post.
app = FastAPI()

class Input(BaseModel):
    # Request body schema for POST /generate/. The validated instance is handed
    # as-is to model.eval, so the exact meaning of each field is defined there.
    # Fields with a default may be omitted by the client.

    # Presumably the maximum number of tokens to generate -- confirm in model.eval.
    generate_tokens_limit: int = 100
    # Presumably the nucleus (top-p) sampling threshold -- confirm in model.eval.
    top_p: float = 0.7
    # NOTE(review): declared as float although top-k is conventionally an integer
    # count; verify what model.eval expects before tightening the type.
    top_k: float = 0
    # Presumably the sampling temperature -- confirm in model.eval.
    temperature: float = 1.0
    # Required (no default). Presumably the prompt text to complete.
    text: str


@app.post("/generate/")
async def generate(input: Input):
    """Run the model on the request body and return its output.

    On success the response is ``{"completion": <model output>}``. If the model
    raises, the failure is logged server-side with its traceback and the
    response is ``{"error": <exception message>}``, returned as a normal
    response body rather than raised as an exception.
    """
    # The parameter name `input` shadows the builtin but is part of the
    # handler's signature, so it is kept unchanged.
    #
    # We intentionally make a non-await (blocking) call to the model: on the GPU
    # implementation generation can't be parallelized.
    # For parallel generation please check running GPT-J on Google TPU:
    # https://github.com/kingoflolz/mesh-transformer-jax
    try:
        output = model.eval(input)
    except Exception as e:
        # Top-level request boundary: keep the existing {"error": ...} contract
        # for clients, but record the traceback so failures are not invisible.
        logging.getLogger(__name__).exception("Text generation failed")
        return {"error": str(e)}
    return {"completion": output}