Skip to content

Commit 9b44ec4

Browse files
committed
ref: generation to use strategy design pattern (in-progress)
1 parent 75a67d3 commit 9b44ec4

File tree

11 files changed

+302
-36
lines changed

11 files changed

+302
-36
lines changed

backend/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"helmet": "^8.1.0",
5454
"jsonwebtoken": "^9.0.2",
5555
"stripe": "^18.1.1",
56+
"together-ai": "^0.16.0",
5657
"winston": "^3.17.0",
5758
"zod": "^3.24.4"
5859
}

backend/src/modules/ai/config/models.ts

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,16 @@ export const SUMMARIZATION_MODEL_CONFIGS: Record<AIModel, ModelConfig> = {
9292
};
9393

9494
export const IMAGE_GENERATION_MODEL_CONFIGS: Record<AIGenerationModel, GenerationModelConfig> = {
95-
[AIGenerationModel.Gemini20FlashImageGenPreview]: {},
96-
[AIGenerationModel.TogetherFlux1SchnellFree]: {},
97-
[AIGenerationModel.TogetherFlux1Schnell]: {},
98-
[AIGenerationModel.TogetherFlux1Dev]: {}
95+
[AIGenerationModel.Gemini20FlashImageGenPreview]: {
96+
97+
},
98+
[AIGenerationModel.TogetherFlux1SchnellFree]: {
99+
100+
},
101+
[AIGenerationModel.TogetherFlux1Schnell]: {
102+
103+
},
104+
[AIGenerationModel.TogetherFlux1Dev]: {
105+
106+
}
99107
};
Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,55 @@
1-
import { GenerativeModel, GoogleGenerativeAI } from "@google/generative-ai";
2-
import { PromptTemplate } from '@langchain/core/prompts';
3-
import { AIGenerationModel, HigherTierImageGenerationRequest, LowerTierImageGenerationResponse, SummarizationRequest, SummarizationResponse } from '../types';
4-
import { IMAGE_GENERATION_USER_PROMPT } from '../config/prompts';
5-
import { IMAGE_GENERATION_MODEL_CONFIGS } from '../config/models';
61
import { AppError } from '../../../shared/errors';
7-
import { logger } from '../../../shared/utils/logger';
8-
import { env } from '../../../shared/config/environment';
9-
import { getEncoding } from 'js-tiktoken';
2+
3+
import { HigherTierImageGenerationRequest, LowerTierImageGenerationRequest, LowerTierImageGenerationResponse } from '../types';
4+
import { AIProvider, checkEnvironmentVariables } from '../utils/ai.utils';
5+
import { ImageGenerationStrategy } from '../strategy/generation.strategy';
106

117
const APP_ERROR_SOURCE = 'image.generation.service';
128

139
export class ImageGenerationService {
14-
private readonly defaultModelType: AIGenerationModel;
15-
private prompt: PromptTemplate;
10+
private strategy: ImageGenerationStrategy;
1611

17-
constructor() {
18-
if (!env.ai.geminiKey) {
19-
throw new AppError(500, 'Gemini API key not found in environment variables. Please set GEMINI_API_KEY.', 'image.generation.service');
20-
}
12+
constructor(strategy?: ImageGenerationStrategy) {
13+
checkEnvironmentVariables(APP_ERROR_SOURCE, AIProvider.GEMINI);
14+
checkEnvironmentVariables(APP_ERROR_SOURCE, AIProvider.TOGETHER);
2115

22-
this.defaultModelType = AIGenerationModel.Gemini20FlashImageGenPreview;
23-
this.prompt = PromptTemplate.fromTemplate(IMAGE_GENERATION_USER_PROMPT);
16+
if (!strategy) {
17+
throw new AppError(
18+
400,
19+
'Invalid strategy provided. Please provide a valid ImageGenerationStrategy instance.',
20+
APP_ERROR_SOURCE
21+
);
22+
}
23+
this.strategy = strategy;
2424
}
2525

26-
async generateImageSingle(request: LowerTierImageGenerationResponse | HigherTierImageGenerationRequest): Promise<LowerTierImageGenerationResponse | HigherTierImageGenerationRequest> {
27-
return [];
26+
getStrategy = (): ImageGenerationStrategy => {
27+
return this.strategy;
2828
}
2929

30-
async generateImageMultiple(request: LowerTierImageGenerationResponse | HigherTierImageGenerationRequest): Promise<LowerTierImageGenerationResponse | HigherTierImageGenerationRequest> {
31-
return [];
30+
setStrategy = (strategy: ImageGenerationStrategy): boolean => {
31+
if (!strategy) {
32+
throw new AppError(
33+
400,
34+
'Invalid strategy provided. Please provide a valid ImageGenerationStrategy instance.',
35+
APP_ERROR_SOURCE
36+
);
37+
}
38+
39+
this.strategy = strategy;
40+
return true;
3241
}
3342

34-
async generateImageStream(request: LowerTierImageGenerationResponse | HigherTierImageGenerationRequest): Promise<LowerTierImageGenerationResponse | HigherTierImageGenerationRequest> {
35-
return [];
36-
}
37-
}
43+
async generate(
44+
request: LowerTierImageGenerationRequest | HigherTierImageGenerationRequest
45+
): Promise<LowerTierImageGenerationResponse | HigherTierImageGenerationRequest> {
46+
if (!request) {
47+
throw new AppError(
48+
400,
49+
'Invalid request provided. Please provide a valid image generation request.',
50+
APP_ERROR_SOURCE
51+
);
52+
}
53+
return this.strategy.generate(request);
54+
}
55+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import { getEncoding, Tiktoken } from 'js-tiktoken';
2+
3+
import { GoogleGenAI } from '@google/genai';
4+
import { PromptTemplate } from '@langchain/core/prompts';
5+
import Together from 'together-ai';
6+
7+
import { AppError } from '../../../shared/errors';
8+
import { env } from '../../../shared/config/environment';
9+
import { logger } from '../../../shared/utils/logger';
10+
11+
import { AIGenerationModel, HigherTierImageGenerationRequest, LowerTierImageGenerationResponse } from '../types';
12+
import { IMAGE_GENERATION_USER_PROMPT } from '../config/prompts';
13+
import { IMAGE_GENERATION_MODEL_CONFIGS } from '../config/models';
14+
import { AIProvider, checkEnvironmentVariables } from '../utils/ai.utils';
15+
import { ImageGenerationStrategy } from '../strategy/generation.strategy';
16+
17+
const APP_ERROR_SOURCE = 'image.google.generation.service';
18+
19+
export class GeminiImageGenerationService implements ImageGenerationStrategy {
20+
private readonly defaultModelType: AIGenerationModel;
21+
private prompt: PromptTemplate;
22+
private encoding: Tiktoken;
23+
24+
constructor() {
25+
checkEnvironmentVariables(APP_ERROR_SOURCE, AIProvider.GOOGLE);
26+
27+
this.defaultModelType = AIGenerationModel.Gemini20FlashImageGenPreview;
28+
this.prompt = PromptTemplate.fromTemplate(IMAGE_GENERATION_USER_PROMPT);
29+
this.encoding = getEncoding('cl100k_base');
30+
}
31+
32+
setupGoogleGenAIClient(): GoogleGenAI {
33+
const genAI = new GoogleGenAI({
34+
apiKey: process.env.GEMINI_API_KEY!,
35+
vertexai: process.env.GOOGLE_USE_VERTEX_AI === 'true'
36+
});
37+
38+
if (!genAI) {
39+
throw new AppError(
40+
500,
41+
'Failed to initialize Google GenAI client. Please check your environment variables.',
42+
APP_ERROR_SOURCE
43+
);
44+
}
45+
46+
return genAI;
47+
}
48+
49+
async generate(
50+
request: LowerTierImageGenerationResponse | HigherTierImageGenerationRequest
51+
): Promise<LowerTierImageGenerationResponse | HigherTierImageGenerationRequest> {
52+
const startTime = Date.now();
53+
const modelType = request.options?.model || this.defaultModelType;
54+
55+
try {
56+
const genAI = this.setupGoogleGenAIClient();
57+
const formattedPrompt = await this.prompt.format({
58+
text: request.text,
59+
});
60+
61+
const result = await genAI.models.generateImages({
62+
model: this.defaultModelType,
63+
prompt: formattedPrompt,
64+
config: {
65+
...IMAGE_GENERATION_MODEL_CONFIGS[modelType],
66+
}
67+
});
68+
69+
const imageUrl = result?.generatedImages?.[0]?.image?.imageBytes;
70+
if (!imageUrl) {
71+
throw new AppError(
72+
500,
73+
'Image generation was unsuccessful, there was no image returned for the request.',
74+
APP_ERROR_SOURCE
75+
);
76+
}
77+
78+
const endTime = Date.now();
79+
const tokenCount = this.encoding.encode(formattedPrompt).length;
80+
const response: LowerTierImageGenerationResponse = {};
81+
82+
// TODO: Finish implementation for generate method in GeminiImageGenerationService
83+
84+
} catch (error) {
85+
if (error instanceof AppError) {
86+
throw error;
87+
} else {
88+
logger.error(
89+
APP_ERROR_SOURCE,
90+
'An unexpected error occurred during image generation.',
91+
error
92+
);
93+
throw new AppError(
94+
500,
95+
'An unexpected error occurred during image generation. Please try again later.',
96+
APP_ERROR_SOURCE
97+
);
98+
}
99+
}
100+
}
101+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export * from './google.service';
2+
export * from './together.service';
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import { getEncoding, Tiktoken } from 'js-tiktoken';
2+
3+
import { GoogleGenAI } from '@google/genai';
4+
import { PromptTemplate } from '@langchain/core/prompts';
5+
import Together from 'together-ai';
6+
7+
import { AppError } from '../../../shared/errors';
8+
import { env } from '../../../shared/config/environment';
9+
import { logger } from '../../../shared/utils/logger';
10+
11+
import { AIGenerationModel, HigherTierImageGenerationRequest, LowerTierImageGenerationResponse } from '../types';
12+
import { IMAGE_GENERATION_USER_PROMPT } from '../config/prompts';
13+
import { IMAGE_GENERATION_MODEL_CONFIGS } from '../config/models';
14+
import { AIProvider, checkEnvironmentVariables } from '../utils/ai.utils';
15+
import { ImageGenerationStrategy } from '../strategy/generation.strategy';
16+
import { LowerTierImageGenerationRequest } from '../../types';
17+
18+
const APP_ERROR_SOURCE = 'image.together.generation.service';
19+
20+
export class TogetherImageGenerationService implements ImageGenerationStrategy {
21+
private readonly defaultModelType: AIGenerationModel;
22+
private prompt: PromptTemplate;
23+
private encoding: Tiktoken;
24+
25+
constructor() {
26+
checkEnvironmentVariables(APP_ERROR_SOURCE, AIProvider.TOGETHER);
27+
28+
this.defaultModelType = AIGenerationModel.TogetherVQGAN;
29+
this.prompt = PromptTemplate.fromTemplate(IMAGE_GENERATION_USER_PROMPT);
30+
this.encoding = getEncoding('cl100k_base');
31+
}
32+
33+
setupTogetherAIClient(): Together {
34+
const togetherAI = new Together({
35+
apiKey: env.ai.togetherKey,
36+
baseURL: env.ai.togetherBaseUrl,
37+
});
38+
39+
if (!togetherAI) {
40+
throw new AppError(
41+
500,
42+
'Failed to initialize Together AI client. Please check your environment variables.',
43+
APP_ERROR_SOURCE
44+
);
45+
}
46+
47+
return togetherAI;
48+
}
49+
50+
async generate(
51+
request: LowerTierImageGenerationRequest | HigherTierImageGenerationRequest
52+
): Promise<LowerTierImageGenerationResponse | HigherTierImageGenerationRequest> {
53+
const startTime = Date.now();
54+
const modelType = request.options?.model || this.defaultModelType;
55+
56+
try {
57+
/**
58+
* TODO: Implement Together AI Image Generation Logc
59+
*/
60+
return {};
61+
62+
} catch (error) {
63+
logger.error(`Error generating image using model ${modelType}`, { error });
64+
if (error instanceof AppError) {
65+
throw error;
66+
}
67+
68+
throw new AppError(
69+
500,
70+
'Failed to generate image. Please try again later.',
71+
APP_ERROR_SOURCE
72+
);
73+
}
74+
}
75+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
export * from './summarization.service';
2+
export * from './generation.service';
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import {
2+
HigherTierImageGenerationRequest,
3+
HigherTierImageGenerationResponse,
4+
LowerTierImageGenerationRequest,
5+
LowerTierImageGenerationResponse
6+
} from "../types";
7+
8+
9+
export interface ImageGenerationStrategy {
10+
generate(
11+
request: LowerTierImageGenerationRequest | HigherTierImageGenerationRequest
12+
): Promise<LowerTierImageGenerationResponse | HigherTierImageGenerationResponse>;
13+
}

backend/src/modules/ai/types/generation.types.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,18 @@ export const HigherTierImageGenerationRequestSchema = z.object({
3333
});
3434

3535
// NOTE: Lower tier image generation using Imagen 3/Gemini 2.0 Flash (Image Generation) models
36+
// TODO: Implement interface LowerTierImageGenerationRequest
3637
export interface LowerTierImageGenerationRequest {
37-
38+
prompt: string;
3839
}
3940

41+
// TODO: Implement interface LowerTierImageGenerationResponse
4042
export interface LowerTierImageGenerationResponse {
41-
43+
prompt: string;
4244
}
4345

4446
// NOTE: Higher tier image generation using FLUX.1-dev/FLUX.1-schnell-free models
47+
// TODO: Refactor implementation HigherTierImageGenerationRequest
4548
export interface HigherTierImageGenerationRequest {
4649
prompt: string;
4750
options?: {
@@ -53,6 +56,7 @@ export interface HigherTierImageGenerationRequest {
5356
};
5457
}
5558

59+
// TODO: Refactor implementation HigherTierImageGenerationResponse
5660
export interface HigherTierImageGenerationResponse {
5761
imageUrl: string;
5862
metadata: {

0 commit comments

Comments
 (0)