diff --git a/packages/mermaid/src/diagrams/common/common.spec.ts b/packages/mermaid/src/diagrams/common/common.spec.ts index 9a78482f6c..4dac5b33c1 100644 --- a/packages/mermaid/src/diagrams/common/common.spec.ts +++ b/packages/mermaid/src/diagrams/common/common.spec.ts @@ -1,4 +1,4 @@ -import { sanitizeText, removeScript, parseGenericTypes } from './common.js'; +import { sanitizeText, removeScript, parseGenericTypes, countOccurrence } from './common.js'; describe('when securityLevel is antiscript, all script must be removed', () => { /** @@ -59,15 +59,29 @@ describe('Sanitize text', () => { }); describe('generic parser', () => { - it('should parse generic types', () => { - expect(parseGenericTypes('test~T~')).toEqual('test'); - expect(parseGenericTypes('test~Array~Array~string~~~')).toEqual('test>>'); - expect(parseGenericTypes('test~Array~Array~string[]~~~')).toEqual( - 'test>>' - ); - expect(parseGenericTypes('test ~Array~Array~string[]~~~')).toEqual( - 'test >>' - ); - expect(parseGenericTypes('~test')).toEqual('~test'); + it.each([ + ['test~T~', 'test'], + ['test~Array~Array~string~~~', 'test>>'], + ['test~Array~Array~string[]~~~', 'test>>'], + ['test ~Array~Array~string[]~~~', 'test >>'], + ['~test', '~test'], + ['~test~T~', '~test'], + ])('should parse generic types: %s to %s', (input: string, expected: string) => { + expect(parseGenericTypes(input)).toEqual(expected); }); }); + +it.each([ + ['', '', 0], + ['', 'x', 0], + ['test', 'x', 0], + ['test', 't', 2], + ['test', 'te', 1], + ['test~T~', '~', 2], + ['test~Array~Array~string~~~', '~', 6], +])( + 'should count `%s` to contain occurrences of `%s` to be `%i`', + (str: string, substring: string, count: number) => { + expect(countOccurrence(str, substring)).toEqual(count); + } +); diff --git a/packages/mermaid/src/diagrams/common/common.ts b/packages/mermaid/src/diagrams/common/common.ts index bb9c6b649d..e0ca2929db 100644 --- a/packages/mermaid/src/diagrams/common/common.ts +++ b/packages/mermaid/src/diagrams/common/common.ts @@ -208,21 +208,33 @@ export const parseGenericTypes = function (input: string): string { return output.join(''); }; +export const countOccurrence = (string: string, substring: string): number => { + return Math.max(0, string.split(substring).length - 1); +}; + const shouldCombineSets = (previousSet: string, nextSet: string): boolean => { - const prevCount = [...previousSet].reduce((count, char) => (char === '~' ? count + 1 : count), 0); - const nextCount = [...nextSet].reduce((count, char) => (char === '~' ? count + 1 : count), 0); + const prevCount = countOccurrence(previousSet, '~'); + const nextCount = countOccurrence(nextSet, '~'); return prevCount === 1 && nextCount === 1; }; const processSet = (input: string): string => { - const chars = [...input]; - const tildeCount = chars.reduce((count, char) => (char === '~' ? count + 1 : count), 0); + const tildeCount = countOccurrence(input, '~'); + let hasStartingTilde = false; if (tildeCount <= 1) { return input; } + // If there is an odd number of tildes, and the input starts with a tilde, we need to remove it and add it back in later + if (tildeCount % 2 !== 0 && input.startsWith('~')) { + input = input.substring(1); + hasStartingTilde = true; + } + + const chars = [...input]; + let first = chars.indexOf('~'); let last = chars.lastIndexOf('~'); @@ -234,6 +246,11 @@ const processSet = (input: string): string => { last = chars.lastIndexOf('~'); } + // Add the starting tilde back in if we removed it + if (hasStartingTilde) { + chars.unshift('~'); + } + return chars.join(''); };