Commit 6e94d59 (1 parent: a3215e0), showing 11 changed files with 1,223 additions and 126 deletions.
@@ -0,0 +1,2 @@
/.DS_Store
vercel
@@ -0,0 +1,72 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Text Generation with Transformer</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 0;
            display: flex;
            flex-direction: column;
            align-items: center;
            justify-content: center;
            height: 100vh;
            background-color: #f0f0f0;
        }
        .container {
            background: white;
            padding: 20px;
            border-radius: 8px;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
        }
        input, button, select {
            margin: 10px 0;
            padding: 10px;
            width: 100%;
            box-sizing: border-box;
        }
        button {
            background-color: #007bff;
            color: white;
            border: none;
            cursor: pointer;
        }
        button:hover {
            background-color: #0056b3;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>Text Generation with Transformer</h1>
        <input type="text" id="seedText" placeholder="Enter seed text" value="Once upon a time">
        <input type="number" id="numWords" placeholder="Number of words to generate" value="50">
        <input type="number" id="temperature" placeholder="Temperature" step="0.1" value="1.0">
        <input type="number" id="topP" placeholder="Top-p (nucleus) sampling" step="0.1" value="0.9">
        <button onclick="generateText()">Generate Text</button>
        <p id="generatedText"></p>
    </div>
    <script>
        // Collect the form values and request a completion from the backend.
        async function generateText() {
            const seedText = document.getElementById('seedText').value;
            const numWords = parseInt(document.getElementById('numWords').value);
            const temperature = parseFloat(document.getElementById('temperature').value);
            const topP = parseFloat(document.getElementById('topP').value);

            const response = await fetch('https://vish-ih1tg4m9j-vishesh9131s-projects.vercel.app/generate_text/', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
                },
                body: JSON.stringify({ seed_text: seedText, num_words: numWords, temperature: temperature, top_p: topP })
            });

            // Surface backend failures instead of rendering "undefined".
            if (!response.ok) {
                document.getElementById('generatedText').innerText = 'Error: HTTP ' + response.status;
                return;
            }

            const data = await response.json();
            document.getElementById('generatedText').innerText = data.generated_text;
        }
    </script>
</body>
</html>
@@ -0,0 +1,169 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch.nn as nn
import json
import torch
import numpy as np


# Define the TransformerEncoder and other necessary classes
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(TransformerEncoder, self).__init__()
        self.att = nn.MultiheadAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim)
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(rate)
        self.dropout2 = nn.Dropout(rate)

    def forward(self, x):
        # Self-attention followed by a position-wise feed-forward block,
        # each with a residual connection and layer normalization.
        attn_output, _ = self.att(x, x, x)
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(x + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        return self.layernorm2(out1 + ffn_output)


class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, max_length):
        super(TransformerModel, self).__init__()
        # Note: the third positional argument of nn.Embedding is padding_idx,
        # so max_length is passed as the padding index here (kept as in the original).
        self.embedding = nn.Embedding(vocab_size, embedding_dim, max_length)
        self.transformer_encoder = TransformerEncoder(embed_dim=embedding_dim, num_heads=8, ff_dim=512)
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(embedding_dim, vocab_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(1, 0, 2)  # (batch_size, seq_len, embed_dim) -> (seq_len, batch_size, embed_dim)
        x = self.transformer_encoder(x)
        x = x.permute(1, 2, 0)  # (seq_len, batch_size, embed_dim) -> (batch_size, embed_dim, seq_len)
        x = self.global_avg_pool(x).squeeze(-1)
        x = self.fc(x)
        return self.softmax(x)

class Tokenizer:
    def __init__(self):
        self.word_index = {}
        self.index_word = {}
        self.num_words = 0

    def fit_on_texts(self, texts):
        word_freq = {}
        for text in texts:
            words = text.split()
            for word in words:
                if word not in word_freq:
                    word_freq[word] = 1
                else:
                    word_freq[word] += 1
        sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
        self.word_index = {word: idx + 1 for idx, (word, _) in enumerate(sorted_words)}
        self.index_word = {idx: word for word, idx in self.word_index.items()}
        self.num_words = len(self.word_index) + 1

    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            sequences.append([self.word_index.get(word, 0) for word in text.split()])
        return sequences


def pad_sequences(sequences, maxlen, padding='post'):
    padded_sequences = np.zeros((len(sequences), maxlen), dtype=int)
    for i, seq in enumerate(sequences):
        if len(seq) > maxlen:
            padded_sequences[i] = seq[:maxlen]
        else:
            if padding == 'post':
                padded_sequences[i, :len(seq)] = seq
            elif padding == 'pre':
                padded_sequences[i, -len(seq):] = seq
    return padded_sequences
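
# Illustration (not part of the original file): pad_sequences([[5, 3]], maxlen=4, padding='post')
# returns array([[5, 3, 0, 0]]); with padding='pre' it returns array([[0, 0, 5, 3]]).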

def generate_text(model, tokenizer, seed_text, max_length, num_words, device, temperature=1.0, top_p=0.9):
    model.eval()
    seed_sequence = tokenizer.texts_to_sequences([seed_text])[0]
    generated_text = seed_text

    for _ in range(num_words):
        padded_sequence = pad_sequences([seed_sequence], maxlen=max_length, padding='post')
        padded_sequence = torch.tensor(padded_sequence, dtype=torch.long).to(device)
        with torch.no_grad():
            predicted_probs = model(padded_sequence).cpu().numpy()[0]

        # Apply temperature
        predicted_probs = np.log(predicted_probs + 1e-9) / temperature
        predicted_probs = np.exp(predicted_probs) / np.sum(np.exp(predicted_probs))

        # Top-p (nucleus) sampling
        sorted_indices = np.argsort(predicted_probs)[::-1]
        cumulative_probs = np.cumsum(predicted_probs[sorted_indices])
        top_p_indices = sorted_indices[cumulative_probs <= top_p]
        if len(top_p_indices) == 0:
            top_p_indices = sorted_indices[:1]
        top_p_probs = predicted_probs[top_p_indices]
        top_p_probs = top_p_probs / np.sum(top_p_probs)
        predicted_word_index = np.random.choice(top_p_indices, p=top_p_probs)

        predicted_word = tokenizer.index_word.get(predicted_word_index, '')

        if predicted_word == '':
            break

        # Slide the context window: append the sampled token and drop the oldest one.
        seed_sequence.append(predicted_word_index)
        seed_sequence = seed_sequence[1:]
        generated_text += ' ' + predicted_word

    return generated_text

# Load the tokenizer
with open('tokenizer.json', 'r') as f:
    word_index = json.load(f)
tokenizer = Tokenizer()
tokenizer.word_index = word_index
tokenizer.index_word = {v: k for k, v in word_index.items()}

# Model parameters
vocab_size = len(tokenizer.word_index) + 1
embedding_dim = 96
max_length = 100

# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransformerModel(vocab_size, embedding_dim, max_length).to(device)
model.load_state_dict(torch.load('VISH.pth', map_location=device))
model.eval()

app = FastAPI()

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

class TextGenerationRequest(BaseModel):
    seed_text: str
    num_words: int
    temperature: float
    top_p: float

@app.post("/generate_text/")
def generate_text_endpoint(request: TextGenerationRequest):
    try:
        generated_text = generate_text(model, tokenizer, request.seed_text, max_length, request.num_words, device, request.temperature, request.top_p)
        return {"generated_text": generated_text}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
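For reference, a minimal client sketch (not part of this commit) showing how the /generate_text/ endpoint can be exercised; it assumes the server is running locally via uvicorn on port 8000 and that the third-party requests package is installed:

import requests

# Hypothetical local call; the deployed Vercel URL used in index.html works the same way.
payload = {
    "seed_text": "Once upon a time",
    "num_words": 50,
    "temperature": 1.0,
    "top_p": 0.9,
}
response = requests.post("http://localhost:8000/generate_text/", json=payload)
response.raise_for_status()
print(response.json()["generated_text"])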
@@ -0,0 +1,15 @@
{
    "version": 2,
    "builds": [
        {
            "src": "main.py",
            "use": "@vercel/python"
        }
    ],
    "routes": [
        {
            "src": "/(.*)",
            "dest": "main.py"
        }
    ]
}