Skip to content

Commit

Permalink
fix(agent): response waiting time (#340)
Browse files Browse the repository at this point in the history
Upgrade agent response waiting time based on message length and random
ranges:
* Add inference time calculation
* Calculate waiting time based on inference time & generated response
length
  • Loading branch information
ramidecodes authored Feb 19, 2024
1 parent c3ec99d commit cfb6f6e
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 31 deletions.
80 changes: 51 additions & 29 deletions src/server/service/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export class Agent {
private _characterId: CharacterId;
private _characterName: CharacterName;
private _match: Match;
private _agentExtraversion: number;
private _systemPrompt: string;

private _triggeredAt = Date.now();
Expand All @@ -44,13 +45,8 @@ export class Agent {
this._match = match;
this._characterId = characterId;
this._characterName = this.getCharacterName(characterId);
this._systemPrompt = [
`Your name is ${this._characterName}.`,
"You should never say that you're a bot, nor a computer program, nor an AI assistant.",
"Addopt a random personality with its own traits and ways of talking, mantain it through out the entire conversation.",
"Always reply with short sentences that don't excede 150 characters.",
// `First decide if you would reply, otherwise say: ${this._silenceToken}.`,
].join(" ");
this._agentExtraversion = getRandomInt({ min: 4, max: 6 }); // 1-10 Threshold used to decide if agent will reply to last message
this._systemPrompt = this.generateSystemPrompt();

ee.on(matchEvent(match.id), this.handleMessageEvent);
}
Expand All @@ -71,9 +67,9 @@ export class Agent {
};

private async triggerResponse() {
// TODO: Add time based trigger if player hasen't replyed in a while
// TODO: Add time based trigger if player hasn't replyed in a while
const shouldTrigger =
getRandomInt({ min: 1, max: 10 }) < getRandomInt({ min: 4, max: 6 });
getRandomInt({ min: 1, max: 10 }) >= this._agentExtraversion;
if (!shouldTrigger) return;

this._triggeredAt = Date.now();
Expand All @@ -83,22 +79,40 @@ export class Agent {
// If inference failed or bot decided not to reply, let the agent be silent
if (!response || response.includes(this._silenceToken)) return;

const cleanResponse = this.parseResponse(response);
const cleanResponse = this.parseLLMResponse(response);
if (!cleanResponse) return; // Stay silent if something went wrong with parsing

const waitTime = this.calculateWaitingTime(cleanResponse);
await wait(waitTime);

const payload: ChatMessagePayload = {
sender: this.id,
message: cleanResponse,
sentAt: Date.now(),
};
this._match.addMessage(payload);
}

const waitTime =
this._match.messages.length === 1
? getRandomInt({ min: 10500, max: 15000 }) // First reply would be longer in response to host prompt
: getRandomInt({ min: 6500, max: 12000 }); // Otherwise replying to ongoing conversation
private calculateWaitingTime(response: string) {
const inferenceTime = Date.now() - this._triggeredAt;

await wait(waitTime);
const hostPromptOffsetTime = 6000; // Waiting for host prompt to be rendered & players to read it
const typingTime = response.length * 275; // Average typing time per character in a word is 0.3s
const minTypingTime = typingTime * 0.8;
const maxTypingTime = typingTime * 1;

this._match.addMessage(payload);
const waitTime =
this._match.messages.length === 1 // Check if it's the start of the match with one message from host
? getRandomInt({
min: minTypingTime + hostPromptOffsetTime,
max: maxTypingTime + hostPromptOffsetTime,
}) // First reply would be longer in response to host prompt
: getRandomInt({
min: minTypingTime,
max: maxTypingTime,
}); // Otherwise reply to ongoing conversation

return waitTime - inferenceTime;
}

private async requestMessageFromLLM() {
Expand All @@ -121,15 +135,14 @@ export class Agent {
return promptMessage;
});

// TODO Limit amount of messages sent for inference
const prompt = this.generatePrompt(promptDialog);

const body = JSON.stringify({
inputs: prompt,
parameters: {
max_new_tokens: 58, // amount of words generated
top_p: 0.9, // higher value = more varied answers
temperature: 1, // higher value = more creative answers
top_p: 1, // 0-1 higher value = more varied words in answers
temperature: 1, // 0-1 higher value = more creative answers
},
});

Expand All @@ -147,7 +160,6 @@ export class Agent {
});

const textRes = await response.text();

if (!textRes) return this._silenceToken;

const result = JSON.parse(textRes) as { body: string };
Expand Down Expand Up @@ -186,14 +198,6 @@ export class Agent {
};
}

private getMessageRole(senderID: string): SenderRole {
return senderID === this._id ? "assistant" : "user";
}

private getCharacterName(characterId: CharacterId): CharacterName {
return CHARACTERS[characterId].name;
}

generatePrompt(messages: PromptMessage[]): string {
// First message is always from Host
const hostMessage = messages.shift();
Expand All @@ -216,11 +220,29 @@ export class Agent {
return prompt;
}

parseResponse(input: string): string {
parseLLMResponse(input: string): string {
// Removes //ufffd || </s> || *some expresion* || [INST] || (words in parenthesis)
return input.replace(
/(\ufffd|<\/s>|(\*[^*]*\*)|\[INST\]|\[\/INST\]|\[INST(?:\])?|\([^)]*\))/g,
"",
);
}

private getMessageRole(senderID: string): SenderRole {
return senderID === this._id ? "assistant" : "user";
}

private getCharacterName(characterId: CharacterId): CharacterName {
return CHARACTERS[characterId].name;
}

private generateSystemPrompt(): string {
return [
`Your name is ${this._characterName}.`,
"Never say that you're AI nor a bot.",
"Always reply with short sentences. Use quick replies to keep up with a conversatio or full sentences that don't excede 40 words.",
"Use slang language sometimes.",
// `First decide if you would reply, otherwise say: ${this._silenceToken}.`,
].join(" ");
}
}
4 changes: 2 additions & 2 deletions src/server/service/match.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ export class Match {

this._players = lodash.shuffle([...botPlayers, ...humanPlayers]);

this.addPrompt();
this.addHostPrompt();

const userIds = users.map((u) => u.id);

Expand Down Expand Up @@ -153,7 +153,7 @@ export class Match {
}
}

private addPrompt() {
private addHostPrompt() {
const randomPrompt =
matchPrompts[getRandomInt({ max: matchPrompts.length })];
if (!randomPrompt) throw new Error("No random prompt found");
Expand Down

0 comments on commit cfb6f6e

Please sign in to comment.