Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 73 additions & 63 deletions invokeai/frontend/web/src/common/util/promptAST.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,72 +7,76 @@ describe('promptAST', () => {
it('should tokenize basic text', () => {
const tokens = tokenize('a cat');
expect(tokens).toEqual([
{ type: 'word', value: 'a' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', value: 'cat' },
{ type: 'word', value: 'a', start: 0, end: 1 },
{ type: 'whitespace', value: ' ', start: 1, end: 2 },
{ type: 'word', value: 'cat', start: 2, end: 5 },
]);
});

it('should tokenize groups with parentheses', () => {
const tokens = tokenize('(a cat)');
expect(tokens).toEqual([
{ type: 'lparen' },
{ type: 'word', value: 'a' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', value: 'cat' },
{ type: 'rparen' },
{ type: 'lparen', start: 0, end: 1 },
{ type: 'word', value: 'a', start: 1, end: 2 },
{ type: 'whitespace', value: ' ', start: 2, end: 3 },
{ type: 'word', value: 'cat', start: 3, end: 6 },
{ type: 'rparen', start: 6, end: 7 },
]);
});

it('should tokenize escaped parentheses', () => {
const tokens = tokenize('\\(medium\\)');
expect(tokens).toEqual([
{ type: 'escaped_paren', value: '(' },
{ type: 'word', value: 'medium' },
{ type: 'escaped_paren', value: ')' },
{ type: 'escaped_paren', value: '(', start: 0, end: 2 },
{ type: 'word', value: 'medium', start: 2, end: 8 },
{ type: 'escaped_paren', value: ')', start: 8, end: 10 },
]);
});

it('should tokenize mixed escaped and unescaped parentheses', () => {
const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
expect(tokens).toEqual([
{ type: 'word', value: 'colored' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', value: 'pencil' },
{ type: 'whitespace', value: ' ' },
{ type: 'escaped_paren', value: '(' },
{ type: 'word', value: 'medium' },
{ type: 'escaped_paren', value: ')' },
{ type: 'whitespace', value: ' ' },
{ type: 'lparen' },
{ type: 'word', value: 'enhanced' },
{ type: 'rparen' },
{ type: 'word', value: 'colored', start: 0, end: 7 },
{ type: 'whitespace', value: ' ', start: 7, end: 8 },
{ type: 'word', value: 'pencil', start: 8, end: 14 },
{ type: 'whitespace', value: ' ', start: 14, end: 15 },
{ type: 'escaped_paren', value: '(', start: 15, end: 17 },
{ type: 'word', value: 'medium', start: 17, end: 23 },
{ type: 'escaped_paren', value: ')', start: 23, end: 25 },
{ type: 'whitespace', value: ' ', start: 25, end: 26 },
{ type: 'lparen', start: 26, end: 27 },
{ type: 'word', value: 'enhanced', start: 27, end: 35 },
{ type: 'rparen', start: 35, end: 36 },
]);
});

it('should tokenize groups with weights', () => {
const tokens = tokenize('(a cat)1.2');
expect(tokens).toEqual([
{ type: 'lparen' },
{ type: 'word', value: 'a' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', value: 'cat' },
{ type: 'rparen' },
{ type: 'weight', value: 1.2 },
{ type: 'lparen', start: 0, end: 1 },
{ type: 'word', value: 'a', start: 1, end: 2 },
{ type: 'whitespace', value: ' ', start: 2, end: 3 },
{ type: 'word', value: 'cat', start: 3, end: 6 },
{ type: 'rparen', start: 6, end: 7 },
{ type: 'weight', value: 1.2, start: 7, end: 10 },
]);
});

it('should tokenize words with weights', () => {
const tokens = tokenize('cat+');
expect(tokens).toEqual([
{ type: 'word', value: 'cat' },
{ type: 'weight', value: '+' },
{ type: 'word', value: 'cat', start: 0, end: 3 },
{ type: 'weight', value: '+', start: 3, end: 4 },
]);
});

it('should tokenize embeddings', () => {
const tokens = tokenize('<embedding_name>');
expect(tokens).toEqual([{ type: 'lembed' }, { type: 'word', value: 'embedding_name' }, { type: 'rembed' }]);
expect(tokens).toEqual([
{ type: 'lembed', start: 0, end: 1 },
{ type: 'word', value: 'embedding_name', start: 1, end: 15 },
{ type: 'rembed', start: 15, end: 16 },
]);
});
});

Expand All @@ -81,9 +85,9 @@ describe('promptAST', () => {
const tokens = tokenize('a cat');
const ast = parseTokens(tokens);
expect(ast).toEqual([
{ type: 'word', text: 'a' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', text: 'cat' },
{ type: 'word', text: 'a', range: { start: 0, end: 1 }, attention: undefined },
{ type: 'whitespace', value: ' ', range: { start: 1, end: 2 } },
{ type: 'word', text: 'cat', range: { start: 2, end: 5 }, attention: undefined },
]);
});

Expand All @@ -93,10 +97,12 @@ describe('promptAST', () => {
expect(ast).toEqual([
{
type: 'group',
range: { start: 0, end: 7 },
attention: undefined,
children: [
{ type: 'word', text: 'a' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', text: 'cat' },
{ type: 'word', text: 'a', range: { start: 1, end: 2 }, attention: undefined },
{ type: 'whitespace', value: ' ', range: { start: 2, end: 3 } },
{ type: 'word', text: 'cat', range: { start: 3, end: 6 }, attention: undefined },
],
},
]);
Expand All @@ -106,27 +112,29 @@ describe('promptAST', () => {
const tokens = tokenize('\\(medium\\)');
const ast = parseTokens(tokens);
expect(ast).toEqual([
{ type: 'escaped_paren', value: '(' },
{ type: 'word', text: 'medium' },
{ type: 'escaped_paren', value: ')' },
{ type: 'escaped_paren', value: '(', range: { start: 0, end: 2 } },
{ type: 'word', text: 'medium', range: { start: 2, end: 8 }, attention: undefined },
{ type: 'escaped_paren', value: ')', range: { start: 8, end: 10 } },
]);
});

it('should parse mixed escaped and unescaped parentheses', () => {
const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
const ast = parseTokens(tokens);
expect(ast).toEqual([
{ type: 'word', text: 'colored' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', text: 'pencil' },
{ type: 'whitespace', value: ' ' },
{ type: 'escaped_paren', value: '(' },
{ type: 'word', text: 'medium' },
{ type: 'escaped_paren', value: ')' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', text: 'colored', range: { start: 0, end: 7 }, attention: undefined },
{ type: 'whitespace', value: ' ', range: { start: 7, end: 8 } },
{ type: 'word', text: 'pencil', range: { start: 8, end: 14 }, attention: undefined },
{ type: 'whitespace', value: ' ', range: { start: 14, end: 15 } },
{ type: 'escaped_paren', value: '(', range: { start: 15, end: 17 } },
{ type: 'word', text: 'medium', range: { start: 17, end: 23 }, attention: undefined },
{ type: 'escaped_paren', value: ')', range: { start: 23, end: 25 } },
{ type: 'whitespace', value: ' ', range: { start: 25, end: 26 } },
{
type: 'group',
children: [{ type: 'word', text: 'enhanced' }],
range: { start: 26, end: 36 },
attention: undefined,
children: [{ type: 'word', text: 'enhanced', range: { start: 27, end: 35 }, attention: undefined }],
},
]);
});
Expand All @@ -138,10 +146,11 @@ describe('promptAST', () => {
{
type: 'group',
attention: 1.2,
range: { start: 0, end: 10 },
children: [
{ type: 'word', text: 'a' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', text: 'cat' },
{ type: 'word', text: 'a', range: { start: 1, end: 2 }, attention: undefined },
{ type: 'whitespace', value: ' ', range: { start: 2, end: 3 } },
{ type: 'word', text: 'cat', range: { start: 3, end: 6 }, attention: undefined },
],
},
]);
Expand All @@ -150,13 +159,13 @@ describe('promptAST', () => {
it('should parse words with attention', () => {
const tokens = tokenize('cat+');
const ast = parseTokens(tokens);
expect(ast).toEqual([{ type: 'word', text: 'cat', attention: '+' }]);
expect(ast).toEqual([{ type: 'word', text: 'cat', attention: '+', range: { start: 0, end: 4 } }]);
});

it('should parse embeddings', () => {
const tokens = tokenize('<embedding_name>');
const ast = parseTokens(tokens);
expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name' }]);
expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name', range: { start: 0, end: 16 } }]);
});
});

Expand Down Expand Up @@ -243,19 +252,20 @@ describe('promptAST', () => {

// Should have escaped parens as nodes and a group with attention
expect(ast).toEqual([
{ type: 'word', text: 'portrait' },
{ type: 'whitespace', value: ' ' },
{ type: 'escaped_paren', value: '(' },
{ type: 'word', text: 'realistic' },
{ type: 'escaped_paren', value: ')' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', text: 'portrait', range: { start: 0, end: 8 }, attention: undefined },
{ type: 'whitespace', value: ' ', range: { start: 8, end: 9 } },
{ type: 'escaped_paren', value: '(', range: { start: 9, end: 11 } },
{ type: 'word', text: 'realistic', range: { start: 11, end: 20 }, attention: undefined },
{ type: 'escaped_paren', value: ')', range: { start: 20, end: 22 } },
{ type: 'whitespace', value: ' ', range: { start: 22, end: 23 } },
{
type: 'group',
attention: 1.2,
range: { start: 23, end: 40 },
children: [
{ type: 'word', text: 'high' },
{ type: 'whitespace', value: ' ' },
{ type: 'word', text: 'quality' },
{ type: 'word', text: 'high', range: { start: 24, end: 28 }, attention: undefined },
{ type: 'whitespace', value: ' ', range: { start: 28, end: 29 } },
{ type: 'word', text: 'quality', range: { start: 29, end: 36 }, attention: undefined },
],
},
]);
Expand Down
Loading