diff --git a/application/src/main/java/org/togetherjava/tjbot/commands/mathcommands/TeXCommand.java b/application/src/main/java/org/togetherjava/tjbot/commands/mathcommands/TeXCommand.java index 1f5be8d3ea..00d1037c57 100644 --- a/application/src/main/java/org/togetherjava/tjbot/commands/mathcommands/TeXCommand.java +++ b/application/src/main/java/org/togetherjava/tjbot/commands/mathcommands/TeXCommand.java @@ -2,8 +2,10 @@ import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent; import net.dv8tion.jda.api.events.interaction.component.ButtonInteractionEvent; +import net.dv8tion.jda.api.interactions.callbacks.IDeferrableCallback; import net.dv8tion.jda.api.interactions.commands.OptionType; import net.dv8tion.jda.api.interactions.components.buttons.Button; +import net.dv8tion.jda.api.interactions.components.buttons.ButtonStyle; import org.jetbrains.annotations.NotNull; import org.scilab.forge.jlatexmath.ParseException; import org.scilab.forge.jlatexmath.TeXConstants; @@ -21,6 +23,8 @@ import java.io.IOException; import java.util.List; import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; /** * Implementation of a tex command which takes a string and renders an image corresponding to the @@ -31,11 +35,18 @@ * message. */ -public class TeXCommand extends SlashCommandAdapter { - - private static final String LATEX_OPTION = "latex"; +public final class TeXCommand extends SlashCommandAdapter { + static final String LATEX_OPTION = "latex"; + // Matches regions between two dollars, like '$foo$'. + private static final String MATH_REGION = "(\\$[^$]+\\$)"; + private static final String TEXT_REGION = "([^$]+)"; + private static final Pattern INLINE_LATEX_REPLACEMENT = + Pattern.compile(MATH_REGION + "|" + TEXT_REGION); private static final String RENDERING_ERROR = "There was an error generating the image"; - private static final float DEFAULT_IMAGE_SIZE = 40F; + static final String BAD_LATEX_ERROR_PREFIX = "That is an invalid latex: "; + static final String INVALID_INLINE_FORMAT_ERROR_MESSAGE = + "The amount of $-symbols must be divisible by two. Did you forget to close an expression?"; + private static final float DEFAULT_IMAGE_SIZE = 40.0F; private static final Color BACKGROUND_COLOR = Color.decode("#36393F"); private static final Color FOREGROUND_COLOR = Color.decode("#FFFFFF"); private static final Logger logger = LoggerFactory.getLogger(TeXCommand.class); @@ -44,8 +55,7 @@ public class TeXCommand extends SlashCommandAdapter { * Creates a new Instance. */ public TeXCommand() { - super("tex", - "This command accepts a latex expression and generates an image corresponding to it.", + super("tex", "Renders LaTeX, also supports inline $-regions like 'see this $\frac{x}{2}$'.", SlashCommandVisibility.GUILD); getData().addOption(OptionType.STRING, LATEX_OPTION, "The latex which is rendered as an image", true); @@ -56,42 +66,102 @@ public void onSlashCommand(@NotNull final SlashCommandInteractionEvent event) { String latex = Objects.requireNonNull(event.getOption(LATEX_OPTION)).getAsString(); String userID = (Objects.requireNonNull(event.getMember()).getId()); TeXFormula formula; + try { + if (latex.contains("$")) { + latex = convertInlineLatexToFull(latex); + } formula = new TeXFormula(latex); } catch (ParseException e) { - event.reply("That is an invalid latex: " + e.getMessage()).setEphemeral(true).queue(); + event.reply(BAD_LATEX_ERROR_PREFIX + e.getMessage()).setEphemeral(true).queue(); return; } + event.deferReply().queue(); - Image image = formula.createBufferedImage(TeXConstants.STYLE_DISPLAY, DEFAULT_IMAGE_SIZE, - FOREGROUND_COLOR, BACKGROUND_COLOR); - if (image.getWidth(null) == -1 || image.getHeight(null) == -1) { - event.getHook().setEphemeral(true).editOriginal(RENDERING_ERROR).queue(); - logger.warn( - "Unable to render latex, image does not have an accessible width or height. Formula was {}", - latex); - return; - } - BufferedImage renderedTextImage = new BufferedImage(image.getWidth(null), - image.getHeight(null), BufferedImage.TYPE_4BYTE_ABGR); - renderedTextImage.getGraphics().drawImage(image, 0, 0, null); - ByteArrayOutputStream renderedTextImageStream = new ByteArrayOutputStream(); try { - ImageIO.write(renderedTextImage, "png", renderedTextImageStream); + Image image = renderImage(formula); + sendImage(event, userID, image); } catch (IOException e) { - event.getHook().setEphemeral(true).editOriginal(RENDERING_ERROR).queue(); + event.getHook().editOriginal(RENDERING_ERROR).queue(); logger.warn( "Unable to render latex, could not convert the image into an attachable form. Formula was {}", latex, e); - return; + + } catch (IllegalStateException e) { + event.getHook().editOriginal(RENDERING_ERROR).queue(); + + logger.warn( + "Unable to render latex, image does not have an accessible width or height. Formula was {}", + latex, e); } + } + + private @NotNull Image renderImage(@NotNull TeXFormula formula) { + Image image = formula.createBufferedImage(TeXConstants.STYLE_DISPLAY, DEFAULT_IMAGE_SIZE, + FOREGROUND_COLOR, BACKGROUND_COLOR); + + if (image.getWidth(null) == -1 || image.getHeight(null) == -1) { + throw new IllegalStateException("Image has no height or width"); + } + return image; + } + + private void sendImage(@NotNull IDeferrableCallback event, @NotNull String userID, + @NotNull Image image) throws IOException { + ByteArrayOutputStream renderedTextImageStream = getRenderedTextImageStream(image); event.getHook() .editOriginal(renderedTextImageStream.toByteArray(), "tex.png") - .setActionRow(Button.danger(generateComponentId(userID), "Delete")) + .setActionRow(Button.of(ButtonStyle.DANGER, generateComponentId(userID), "Delete")) .queue(); } + @NotNull + private ByteArrayOutputStream getRenderedTextImageStream(@NotNull Image image) + throws IOException { + BufferedImage renderedTextImage = new BufferedImage(image.getWidth(null), + image.getHeight(null), BufferedImage.TYPE_4BYTE_ABGR); + + renderedTextImage.getGraphics().drawImage(image, 0, 0, null); + ByteArrayOutputStream renderedTextImageStream = new ByteArrayOutputStream(); + + ImageIO.write(renderedTextImage, "png", renderedTextImageStream); + + return renderedTextImageStream; + } + + /** + * Converts inline latex like: {@code hello $\frac{x}{2}$ world} to full latex + * {@code \text{hello}\frac{x}{2}\text{ world}}. + * + * @param latex the latex to convert + * @return the converted latex + */ + @NotNull + private String convertInlineLatexToFull(@NotNull String latex) { + if (isInvalidInlineFormat(latex)) { + throw new ParseException(INVALID_INLINE_FORMAT_ERROR_MESSAGE); + } + + Matcher matcher = INLINE_LATEX_REPLACEMENT.matcher(latex); + StringBuilder sb = new StringBuilder(latex.length()); + + while (matcher.find()) { + boolean isInsideMathRegion = matcher.group(1) != null; + if (isInsideMathRegion) { + sb.append(matcher.group(1).replace("$", "")); + } else { + sb.append("\\text{").append(matcher.group(2)).append("}"); + } + } + + return sb.toString(); + } + + private boolean isInvalidInlineFormat(@NotNull String latex) { + return latex.chars().filter(charAsInt -> charAsInt == '$').count() % 2 == 1; + } + @Override public void onButtonClick(@NotNull final ButtonInteractionEvent event, @NotNull final List args) { diff --git a/application/src/test/java/org/togetherjava/tjbot/commands/mathcommands/TeXCommandTest.java b/application/src/test/java/org/togetherjava/tjbot/commands/mathcommands/TeXCommandTest.java new file mode 100644 index 0000000000..be2dd18c0e --- /dev/null +++ b/application/src/test/java/org/togetherjava/tjbot/commands/mathcommands/TeXCommandTest.java @@ -0,0 +1,109 @@ +package org.togetherjava.tjbot.commands.mathcommands; + +import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.togetherjava.tjbot.commands.SlashCommand; +import org.togetherjava.tjbot.jda.JdaTester; + +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.description; +import static org.mockito.Mockito.verify; + +final class TeXCommandTest { + private JdaTester jdaTester; + private SlashCommand command; + + @BeforeEach + void setUp() { + jdaTester = new JdaTester(); + command = jdaTester.spySlashCommand(new TeXCommand()); + } + + private @NotNull SlashCommandInteractionEvent triggerSlashCommand(@NotNull String latex) { + SlashCommandInteractionEvent event = jdaTester.createSlashCommandInteractionEvent(command) + .setOption(TeXCommand.LATEX_OPTION, latex) + .build(); + + command.onSlashCommand(event); + return event; + } + + private void verifySuccessfulResponse(@NotNull SlashCommandInteractionEvent event, + @NotNull String query) { + verify(jdaTester.getInteractionHookMock(), description("Testing query: " + query)) + .editOriginal(any(byte[].class), eq("tex.png")); + } + + private static List provideSupportedQueries() { + List fullLatex = List.of("\\frac{x}{2}", "f \\in \\mathcal{O}(n^2)", + "a^{\\varphi(n)} \\equiv 1\\ (\\textrm{mod}\\ n)", "\\textrm{I like } \\xi"); + + List inlineLatex = List.of("$\\frac{x}{2}$", "$x$ hello", "hello $x$", + "hello $x$ world $y$", "$x$$y$$z$", "$x \\cdot y$"); + + List edgeCases = List.of("", " ", " \n "); + + List allQueries = new ArrayList<>(); + allQueries.addAll(fullLatex); + allQueries.addAll(inlineLatex); + allQueries.addAll(edgeCases); + + return allQueries; + } + + @ParameterizedTest + @MethodSource("provideSupportedQueries") + @DisplayName("The command supports and renders all supported latex queries") + void canRenderSupportedQuery(@NotNull String supportedQuery) { + // GIVEN a supported latex query + + // WHEN triggering the command + SlashCommandInteractionEvent event = triggerSlashCommand(supportedQuery); + + // THEN the command send a successful response + verifySuccessfulResponse(event, supportedQuery); + } + + private static List provideBadInlineQueries() { + return List.of("hello $x world", "$", " $ ", "hello $x$ world$", "$$$$$", "$x$$y$$z"); + } + + @ParameterizedTest + @MethodSource("provideBadInlineQueries") + @DisplayName("The command does not support bad inline latex queries, for example with missing dollars") + void failsOnBadInlineQuery(@NotNull String badInlineQuery) { + // GIVEN a bad inline latex query + + // WHEN triggering the command + SlashCommandInteractionEvent event = triggerSlashCommand(badInlineQuery); + + // THEN the command send a failure response + verify(event, description("Testing query: " + badInlineQuery)) + .reply(contains(TeXCommand.INVALID_INLINE_FORMAT_ERROR_MESSAGE)); + } + + private static List provideBadQueries() { + return List.of("__", "\\foo", "\\left(x + y)"); + } + + @ParameterizedTest + @MethodSource("provideBadQueries") + @DisplayName("The command does not support bad latex queries, for example with unknown symbols or incomplete braces") + void failsOnBadQuery(@NotNull String badQuery) { + // GIVEN a bad inline latex query + + // WHEN triggering the command + SlashCommandInteractionEvent event = triggerSlashCommand(badQuery); + + // THEN the command send a failure response + verify(event, description("Testing query: " + badQuery)) + .reply(startsWith(TeXCommand.BAD_LATEX_ERROR_PREFIX)); + } +} diff --git a/application/src/test/java/org/togetherjava/tjbot/jda/JdaTester.java b/application/src/test/java/org/togetherjava/tjbot/jda/JdaTester.java index 4c96a216ae..2b07ba9c50 100644 --- a/application/src/test/java/org/togetherjava/tjbot/jda/JdaTester.java +++ b/application/src/test/java/org/togetherjava/tjbot/jda/JdaTester.java @@ -6,6 +6,7 @@ import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent; import net.dv8tion.jda.api.events.interaction.component.ButtonInteractionEvent; import net.dv8tion.jda.api.exceptions.ErrorResponseException; +import net.dv8tion.jda.api.interactions.InteractionHook; import net.dv8tion.jda.api.interactions.callbacks.IReplyCallback; import net.dv8tion.jda.api.interactions.components.ItemComponent; import net.dv8tion.jda.api.requests.ErrorResponse; @@ -20,6 +21,7 @@ import net.dv8tion.jda.internal.requests.Requester; import net.dv8tion.jda.internal.requests.restaction.AuditableRestActionImpl; import net.dv8tion.jda.internal.requests.restaction.MessageActionImpl; +import net.dv8tion.jda.internal.requests.restaction.WebhookMessageUpdateActionImpl; import net.dv8tion.jda.internal.requests.restaction.interactions.ReplyCallbackActionImpl; import net.dv8tion.jda.internal.utils.config.AuthorizationConfig; import org.jetbrains.annotations.NotNull; @@ -83,8 +85,11 @@ public final class JdaTester { private final ReplyCallbackActionImpl replyAction; private final AuditableRestActionImpl auditableRestAction; private final MessageActionImpl messageAction; + private final WebhookMessageUpdateActionImpl webhookMessageUpdateAction; private final TextChannelImpl textChannel; private final PrivateChannelImpl privateChannel; + private final InteractionHook interactionHook; + private final ReplyCallbackAction replyCallbackAction; /** * Creates a new instance. The instance uses a fresh and isolated mocked JDA setup. @@ -108,6 +113,8 @@ public JdaTester() { textChannel = spy(new TextChannelImpl(TEXT_CHANNEL_ID, guild)); privateChannel = spy(new PrivateChannelImpl(jda, PRIVATE_CHANNEL_ID, user)); messageAction = mock(MessageActionImpl.class); + webhookMessageUpdateAction = mock(WebhookMessageUpdateActionImpl.class); + replyCallbackAction = mock(ReplyCallbackAction.class); EntityBuilder entityBuilder = mock(EntityBuilder.class); Role everyoneRole = new RoleImpl(GUILD_ID, guild); @@ -149,12 +156,22 @@ public JdaTester() { doNothing().when(auditableRestAction).queue(); doNothing().when(messageAction).queue(); + doNothing().when(webhookMessageUpdateAction).queue(); + doReturn(webhookMessageUpdateAction).when(webhookMessageUpdateAction) + .setActionRow(any(ItemComponent.class)); doReturn(everyoneRole).when(guild).getPublicRole(); doReturn(selfMember).when(guild).getMember(selfUser); doReturn(member).when(guild).getMember(not(eq(selfUser))); doReturn(null).when(textChannel).retrieveMessageById(any()); + + interactionHook = mock(InteractionHook.class); + when(interactionHook.editOriginal(anyString())).thenReturn(webhookMessageUpdateAction); + when(interactionHook.editOriginal(any(Message.class))) + .thenReturn(webhookMessageUpdateAction); + when(interactionHook.editOriginal(any(byte[].class), any(), any())) + .thenReturn(webhookMessageUpdateAction); } /** @@ -251,6 +268,19 @@ public JdaTester() { return replyAction; } + /** + * Gets the Mockito mock used as universal interaction hook by all mocks created by this tester + * instance. + *

+ * For example the events created by {@link #createSlashCommandInteractionEvent(SlashCommand)} + * will return this mock on several of their methods. + * + * @return the interaction hook mock used by this tester + */ + public @NotNull InteractionHook getInteractionHookMock() { + return interactionHook; + } + /** * Gets the text channel spy used as universal text channel by all mocks created by this tester * instance. @@ -371,6 +401,10 @@ private void mockInteraction(@NotNull IReplyCallback interaction) { doReturn(textChannel).when(interaction).getTextChannel(); doReturn(textChannel).when(interaction).getGuildChannel(); doReturn(privateChannel).when(interaction).getPrivateChannel(); + + doReturn(interactionHook).when(interaction).getHook(); + doReturn(replyCallbackAction).when(interaction).deferReply(); + doReturn(replyCallbackAction).when(interaction).deferReply(anyBoolean()); } private void mockButtonClickEvent(@NotNull ButtonInteractionEvent event) {