version 1.0.1
vishesh9131 committed Aug 3, 2024
1 parent a3215e0 commit 6e94d59
Showing 11 changed files with 1,223 additions and 126 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
/.DS_Store
vercel
Binary file added VISH.pth
Binary file not shown.
72 changes: 72 additions & 0 deletions api/index.html
@@ -0,0 +1,72 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Text Generation with Transformer</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 0;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            height: 100vh;
            background-color: #f0f0f0;
        }
        .container {
            background: white;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
        }
        input, button, select {
            margin: 10px 0;
            padding: 10px;
            width: 100%;
            box-sizing: border-box;
        }
        button {
            background-color: #007bff;
            color: white;
            border: none;
            cursor: pointer;
        }
        button:hover {
            background-color: #0056b3;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>Text Generation with Transformer</h1>
        <input type="text" id="seedText" placeholder="Enter seed text" value="Once upon a time">
        <input type="number" id="numWords" placeholder="Number of words to generate" value="50">
        <input type="number" id="temperature" placeholder="Temperature" step="0.1" value="1.0">
        <input type="number" id="topP" placeholder="Top-p (nucleus) sampling" step="0.1" value="0.9">
        <button onclick="generateText()">Generate Text</button>
        <p id="generatedText"></p>
    </div>
    <script>
        async function generateText() {
            const seedText = document.getElementById('seedText').value;
            const numWords = parseInt(document.getElementById('numWords').value);
            const temperature = parseFloat(document.getElementById('temperature').value);
            const topP = parseFloat(document.getElementById('topP').value);

            const response = await fetch('https://vish-ih1tg4m9j-vishesh9131s-projects.vercel.app/generate_text/', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ seed_text: seedText, num_words: numWords, temperature: temperature, top_p: topP })
            });

            const data = await response.json();
            document.getElementById('generatedText').innerText = data.generated_text;
        }
    </script>
</body>
</html>
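
Note: the request the page's generateText() function sends can be reproduced from Python. A minimal sketch, assuming the FastAPI app from api/main.py below is served locally on port 8000 (matching its uvicorn.run call; the requests package is a third-party dependency, and the deployed Vercel URL from the page above accepts the same body):

import requests

# Mirrors the JSON body built by generateText() in api/index.html.
payload = {
    "seed_text": "Once upon a time",
    "num_words": 50,
    "temperature": 1.0,
    "top_p": 0.9,
}

response = requests.post("http://localhost:8000/generate_text/", json=payload)
response.raise_for_status()  # surfaces the 500 the endpoint raises on failure
print(response.json()["generated_text"])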
169 changes: 169 additions & 0 deletions api/main.py
@@ -0,0 +1,169 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch.nn as nn
import json
import torch
import numpy as np

# Define the TransformerEncoder and other necessary classes
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TransformerEncoder, self).__init__()
        self.att = nn.MultiheadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, x):
        attn_output, _ = self.att(x, x, x)
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(x + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)
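
# Shape note (illustrative, not in the original): nn.MultiheadAttention defaults
# to batch_first=False, so this encoder expects input of shape
# (seq_len, batch_size, embed_dim) and returns a tensor of the same shape.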

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, max_length):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.transformer_encoder = TransformerEncoder(embed_dim=embedding_dim, num_heads=8, ff_dim=512)
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(embedding_dim, vocab_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(1, 0, 2)  # (batch_size, seq_len, embed_dim) -> (seq_len, batch_size, embed_dim)
        x = self.transformer_encoder(x)
        x = x.permute(1, 2, 0)  # (seq_len, batch_size, embed_dim) -> (batch_size, embed_dim, seq_len)
        x = self.global_avg_pool(x).squeeze(-1)
        x = self.fc(x)
        return self.softmax(x)
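
# Illustrative forward-pass shapes (added note, not in the original): token ids
# (batch, seq_len) -> embedding (batch, seq_len, embed_dim) -> encoder
# (seq_len, batch, embed_dim) -> pooled (batch, embed_dim) -> fc + softmax
# (batch, vocab_size), i.e. one next-word distribution per input sequence.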

class Tokenizer:
    def __init__(self):
        self.word_index = {}
        self.index_word = {}
        self.num_words = 0

    def fit_on_texts(self, texts):
        word_freq = {}
        for text in texts:
            words = text.split()
            for word in words:
                if word not in word_freq:
                    word_freq[word] = 1
                else:
                    word_freq[word] += 1
        sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
        self.word_index = {word: idx + 1 for idx, (word, _) in enumerate(sorted_words)}
        self.index_word = {idx: word for word, idx in self.word_index.items()}
        self.num_words = len(self.word_index) + 1

    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            sequences.append([self.word_index.get(word, 0) for word in text.split()])
        return sequences
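
# Illustrative usage (not in the original): after
#   t = Tokenizer(); t.fit_on_texts(["the cat sat", "the dog"])
# the most frequent word gets index 1, so t.word_index ==
# {'the': 1, 'cat': 2, 'sat': 3, 'dog': 4}, and texts_to_sequences
# maps unseen words to 0.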

def pad_sequences(sequences, maxlen, padding='post'):
    padded_sequences = np.zeros((len(sequences), maxlen), dtype=int)
    for i, seq in enumerate(sequences):
        if len(seq) > maxlen:
            padded_sequences[i] = seq[:maxlen]
        else:
            if padding == 'post':
                padded_sequences[i, :len(seq)] = seq
            elif padding == 'pre':
                padded_sequences[i, -len(seq):] = seq
    return padded_sequences
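
# Illustrative examples (not in the original):
#   pad_sequences([[1, 2, 3]], maxlen=5)                 -> [[1, 2, 3, 0, 0]]
#   pad_sequences([[1, 2, 3]], maxlen=5, padding='pre')  -> [[0, 0, 1, 2, 3]]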

def generate_text(model, tokenizer, seed_text, max_length, num_words, device, temperature=1.0, top_p=0.9):
    model.eval()
    seed_sequence = tokenizer.texts_to_sequences([seed_text])[0]
    generated_text = seed_text

    for _ in range(num_words):
        padded_sequence = pad_sequences([seed_sequence], maxlen=max_length, padding='post')
        padded_sequence = torch.tensor(padded_sequence, dtype=torch.long).to(device)
        with torch.no_grad():
            predicted_probs = model(padded_sequence).cpu().numpy()[0]

        # Apply temperature
        predicted_probs = np.log(predicted_probs + 1e-9) / temperature
        predicted_probs = np.exp(predicted_probs) / np.sum(np.exp(predicted_probs))

        # Top-p (nucleus) sampling
        sorted_indices = np.argsort(predicted_probs)[::-1]
        cumulative_probs = np.cumsum(predicted_probs[sorted_indices])
        top_p_indices = sorted_indices[cumulative_probs <= top_p]
        if len(top_p_indices) == 0:
            top_p_indices = sorted_indices[:1]
        top_p_probs = predicted_probs[top_p_indices]
        top_p_probs = top_p_probs / np.sum(top_p_probs)
        predicted_word_index = np.random.choice(top_p_indices, p=top_p_probs)

        predicted_word = tokenizer.index_word.get(predicted_word_index, '')

        if predicted_word == '':
            break

        seed_sequence.append(predicted_word_index)
        if len(seed_sequence) > max_length:
            # Keep only the most recent max_length tokens instead of dropping
            # the first token on every step, which would discard seed context.
            seed_sequence = seed_sequence[-max_length:]
        generated_text += ' ' + predicted_word

    return generated_text
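
# Worked example for the top-p step above (illustrative, not in the original):
# with probabilities [0.5, 0.3, 0.15, 0.05] and top_p = 0.9, the cumulative
# sums are [0.5, 0.8, 0.95, 1.0]; only the first two tokens satisfy
# cumulative <= 0.9, so sampling draws from them renormalized to [0.625, 0.375].
# The temperature step before it rescales log-probabilities, so temperature < 1
# sharpens the distribution and temperature > 1 flattens it.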

# Load the tokenizer
with open('tokenizer.json', 'r') as f:
    word_index = json.load(f)
tokenizer = Tokenizer()
tokenizer.word_index = word_index
tokenizer.index_word = {v: k for k, v in word_index.items()}

# Model parameters
vocab_size = len(tokenizer.word_index) + 1
embedding_dim = 96
max_length = 100

# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransformerModel(vocab_size, embedding_dim, max_length).to(device)
model.load_state_dict(torch.load('VISH.pth', map_location=device))
model.eval()

app = FastAPI()

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

class TextGenerationRequest(BaseModel):
    seed_text: str
    num_words: int
    temperature: float
    top_p: float

@app.post("/generate_text/")
def generate_text_endpoint(request: TextGenerationRequest):
    try:
        generated_text = generate_text(model, tokenizer, request.seed_text, max_length,
                                       request.num_words, device, request.temperature, request.top_p)
        return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
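
Note: for a quick offline check without starting the server, the generation function can be called directly. A minimal sketch, assuming tokenizer.json and VISH.pth are present in the working directory so the module-level loading above succeeds; the seed text and sampling parameters are illustrative:

# Offline smoke test (illustrative): uses the module-level model, tokenizer,
# max_length, and device defined above.
sample = generate_text(model, tokenizer, "Once upon a time", max_length,
                       num_words=20, device=device, temperature=0.8, top_p=0.9)
print(sample)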
15 changes: 15 additions & 0 deletions api/vercel.json
@@ -0,0 +1,15 @@
{
  "version": 2,
  "builds": [
    {
      "src": "main.py",
      "use": "@vercel/python"
    }
  ],
  "routes": [
    {
      "src": "/(.*)",
      "dest": "main.py"
    }
  ]
}