Skip to content

Commit

Permalink
feat: refactor and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthieu-OD committed Nov 21, 2024
1 parent 3dab603 commit f1902e7
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 69 deletions.
75 changes: 9 additions & 66 deletions src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { ReadStream } from 'fs';
import { v4 as uuidv4 } from 'uuid';

import { LiteralClient } from '.';
import { PromptCacheManager } from './cache/prompt-cache-manager';
import { sharedCache } from './cache/sharedcache';
import {
Dataset,
DatasetExperiment,
Expand Down Expand Up @@ -325,63 +327,6 @@ type CreateAttachmentParams = {
metadata?: Maybe<Record<string, any>>;
};

export class SharedCache {
private static instance: SharedCache | null = null;
private cache: Map<string, any>;

private constructor() {
this.cache = new Map();
}

static getInstance(): SharedCache {
if (!SharedCache.instance) {
SharedCache.instance = new SharedCache();
}
return SharedCache.instance;
}

public getPromptCacheKey(
id?: string,
name?: string,
version?: number
): string {
if (id) {
return id;
} else if (name && (version || version === 0)) {
return `${name}:${version}`;
} else if (name) {
return name;
}
throw new Error('Either id or name must be provided');
}

public getPrompt(key: string): Prompt {
return this.get(key);
}

public putPrompt(prompt: Prompt): void {
this.put(prompt.id, prompt);
this.put(prompt.name, prompt);
this.put(`${prompt.name}:${prompt.version}`, prompt);
}

public getCache(): Map<string, any> {
return this.cache;
}

public get(key: string): any {
return this.cache.get(key);
}

public put(key: string, value: any): void {
this.cache.set(key, value);
}

public clear(): void {
this.cache.clear();
}
}

/**
* Represents the API client for interacting with the Literal service.
* This class handles API requests, authentication, and provides methods
Expand All @@ -398,7 +343,7 @@ export class SharedCache {
*/
export class API {
/** @ignore */
public cache: SharedCache;
private cache: typeof sharedCache;
/** @ignore */
public client: LiteralClient;
/** @ignore */
Expand Down Expand Up @@ -431,7 +376,7 @@ export class API {
throw new Error('LITERAL_API_URL not set');
}

this.cache = SharedCache.getInstance();
this.cache = sharedCache;

this.apiKey = apiKey;
this.url = url;
Expand Down Expand Up @@ -2199,7 +2144,7 @@ export class API {
}
`;

return this.getPromptWithQuery(query, { id });
return await this.getPromptWithQuery(query, { id });
}

/**
Expand All @@ -2210,8 +2155,8 @@ export class API {
variables: Record<string, any>
) {
const { id, name, version } = variables;
const cachedPrompt = this.cache.getPrompt(
this.cache.getPromptCacheKey(id, name, version)
const cachedPrompt = sharedCache.get(
PromptCacheManager.getPromptCacheKey({ id, name, version })
);
const timeout = cachedPrompt ? 1000 : undefined;

Expand All @@ -2231,11 +2176,9 @@ export class API {
}

const prompt = new Prompt(this, promptData);
this.cache.putPrompt(prompt);
PromptCacheManager.putPrompt(prompt);
return prompt;
} catch (error) {
console.log('key: ', this.cache.getPromptCacheKey(id, name, version));
console.log('cachedPrompt: ', cachedPrompt);
return cachedPrompt;
}
}
Expand Down Expand Up @@ -2268,7 +2211,7 @@ export class API {
}
}
`;
return this.getPromptWithQuery(query, { name, version });
return await this.getPromptWithQuery(query, { name, version });
}

/**
Expand Down
29 changes: 29 additions & 0 deletions src/cache/prompt-cache-manager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { Prompt } from '../prompt-engineering/prompt';
import { sharedCache } from './sharedcache';

export class PromptCacheManager {
public static getPromptCacheKey({
id,
name,
version
}: {
id?: string;
name?: string;
version?: number;
}): string {
if (id) {
return id;
} else if (name && typeof version === 'number') {
return `${name}:${version}`;
} else if (name) {
return name;
}
throw new Error('Either id or name must be provided');
}

public static putPrompt(prompt: Prompt): void {
sharedCache.put(prompt.id, prompt);
sharedCache.put(prompt.name, prompt);
sharedCache.put(`${prompt.name}:${prompt.version}`, prompt);
}
}
36 changes: 36 additions & 0 deletions src/cache/sharedcache.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
const cache: Map<string, any> = new Map();

class SharedCache {
private static instance: SharedCache;

public constructor() {
if (SharedCache.instance) {
throw new Error('SharedCache can only be created once');
}
SharedCache.instance = this;
}

public getInstance(): SharedCache {
return this;
}

public getCache(): Map<string, any> {
return cache;
}

public get(key: string): any {
return cache.get(key);
}

public put(key: string, value: any): void {
cache.set(key, value);
}

public clear(): void {
cache.clear();
}
}

export const sharedCache = new SharedCache();

export default sharedCache;
8 changes: 5 additions & 3 deletions tests/api.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import 'dotenv/config';
import { v4 as uuidv4 } from 'uuid';

import { ChatGeneration, IGenerationMessage, LiteralClient } from '../src';
import { PromptCacheManager } from '../src/cache/prompt-cache-manager';
import { sharedCache } from '../src/cache/sharedcache';
import { Dataset } from '../src/evaluation/dataset';
import { Score } from '../src/evaluation/score';
import { Prompt, PromptConstructor } from '../src/prompt-engineering/prompt';
Expand Down Expand Up @@ -685,7 +687,7 @@ is a templated list.`;

it('should fallback to cache when getPromptById DB call fails', async () => {
const prompt = new Prompt(client.api, mockPromptData);
client.api.cache.putPrompt(prompt);
PromptCacheManager.putPrompt(prompt);

jest
.spyOn(client.api as any, 'makeGqlCall')
Expand All @@ -697,7 +699,7 @@ is a templated list.`;

it('should fallback to cache when getPrompt DB call fails', async () => {
const prompt = new Prompt(client.api, mockPromptData);
client.api.cache.putPrompt(prompt);
PromptCacheManager.putPrompt(prompt);
jest.spyOn(axios, 'post').mockRejectedValueOnce(new Error('DB Error'));

const result = await client.api.getPrompt(prompt.id);
Expand All @@ -713,7 +715,7 @@ is a templated list.`;

await client.api.getPromptById(prompt.id);

const cachedPrompt = await client.api.cache.get(prompt.id);
const cachedPrompt = sharedCache.get(prompt.id);
expect(cachedPrompt).toBeDefined();
expect(cachedPrompt?.id).toBe(prompt.id);
});
Expand Down
121 changes: 121 additions & 0 deletions tests/cache.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import { API } from '../src/api';
import { PromptCacheManager } from '../src/cache/prompt-cache-manager';
import { sharedCache } from '../src/cache/sharedcache';
import { Prompt, PromptConstructor } from '../src/prompt-engineering/prompt';

describe('Cache', () => {
let api: API;
let mockPrompt: Prompt;

beforeAll(() => {
api = {} as API;
});

beforeEach(() => {
sharedCache.clear();

const mockPromptData: PromptConstructor = {
id: 'test-id',
type: 'CHAT',
createdAt: '2023-01-01T00:00:00Z',
name: 'test-name',
version: 1,
metadata: {},
items: [],
templateMessages: [{ role: 'user', content: 'Hello', uuid: '123' }],
provider: 'test-provider',
settings: {
provider: 'test-provider',
model: 'test-model',
frequency_penalty: 0,
max_tokens: 100,
presence_penalty: 0,
temperature: 0.7,
top_p: 1
},
variables: []
};
mockPrompt = new Prompt(api, mockPromptData);
});

describe('PromptCacheManager', () => {
describe('getPromptCacheKey', () => {
it('should return id when provided', () => {
const key = PromptCacheManager.getPromptCacheKey({
id: 'test-id',
name: 'test-name',
version: 1
});
expect(key).toBe('test-id');
});

it('should return name:version when id not provided but name and version are', () => {
const key = PromptCacheManager.getPromptCacheKey({
name: 'test-name',
version: 1
});
expect(key).toBe('test-name:1');
});

it('should return name when only name provided', () => {
const key = PromptCacheManager.getPromptCacheKey({ name: 'test-name' });
expect(key).toBe('test-name');
});

it('should throw error when neither id nor name provided', () => {
expect(() =>
PromptCacheManager.getPromptCacheKey({ version: 0 })
).toThrow('Either id or name must be provided');
});
});

describe('putPrompt', () => {
it('should store prompt with multiple keys', () => {
PromptCacheManager.putPrompt(mockPrompt);

expect(sharedCache.get('test-id')).toEqual(mockPrompt);
expect(sharedCache.get('test-name')).toEqual(mockPrompt);
expect(sharedCache.get('test-name:1')).toEqual(mockPrompt);
});
});
});

describe('SharedCache', () => {
it('should return undefined for non-existent key', () => {
const value = sharedCache.get('non-existent');
expect(value).toBeUndefined();
});

it('should store and retrieve values', () => {
sharedCache.put('test-key', 'test-value');
expect(sharedCache.get('test-key')).toBe('test-value');
});

it('should clear all values', () => {
sharedCache.put('key1', 'value1');
sharedCache.put('key2', 'value2');

sharedCache.clear();

expect(sharedCache.get('key1')).toBeUndefined();
expect(sharedCache.get('key2')).toBeUndefined();
});

it('should maintain singleton behavior', () => {
const instance1 = sharedCache;
const instance2 = sharedCache;

instance1.put('test', 'value');
expect(instance2.get('test')).toBe('value');
expect(instance1).toBe(instance2);
});

it('should expose cache map', () => {
sharedCache.put('test', 'value');
const cacheMap = sharedCache.getCache();

expect(cacheMap instanceof Map).toBe(true);
expect(cacheMap.get('test')).toBe('value');
});
});
});

0 comments on commit f1902e7

Please sign in to comment.