Skip to content

Adding Tex inline support #441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.events.interaction.component.ButtonInteractionEvent;
import net.dv8tion.jda.api.interactions.callbacks.IDeferrableCallback;
import net.dv8tion.jda.api.interactions.commands.OptionType;
import net.dv8tion.jda.api.interactions.components.buttons.Button;
import net.dv8tion.jda.api.interactions.components.buttons.ButtonStyle;
import org.jetbrains.annotations.NotNull;
import org.scilab.forge.jlatexmath.ParseException;
import org.scilab.forge.jlatexmath.TeXConstants;
Expand All @@ -21,6 +23,8 @@
import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* Implementation of a tex command which takes a string and renders an image corresponding to the
Expand All @@ -31,11 +35,18 @@
* message.
*/

public class TeXCommand extends SlashCommandAdapter {

private static final String LATEX_OPTION = "latex";
public final class TeXCommand extends SlashCommandAdapter {
static final String LATEX_OPTION = "latex";
// Matches regions between two dollars, like '$foo$'.
private static final String MATH_REGION = "(\\$[^$]+\\$)";
private static final String TEXT_REGION = "([^$]+)";
private static final Pattern INLINE_LATEX_REPLACEMENT =
Pattern.compile(MATH_REGION + "|" + TEXT_REGION);
private static final String RENDERING_ERROR = "There was an error generating the image";
private static final float DEFAULT_IMAGE_SIZE = 40F;
static final String BAD_LATEX_ERROR_PREFIX = "That is an invalid latex: ";
static final String INVALID_INLINE_FORMAT_ERROR_MESSAGE =
"The amount of $-symbols must be divisible by two. Did you forget to close an expression?";
private static final float DEFAULT_IMAGE_SIZE = 40.0F;
private static final Color BACKGROUND_COLOR = Color.decode("#36393F");
private static final Color FOREGROUND_COLOR = Color.decode("#FFFFFF");
private static final Logger logger = LoggerFactory.getLogger(TeXCommand.class);
Expand All @@ -44,8 +55,7 @@ public class TeXCommand extends SlashCommandAdapter {
* Creates a new Instance.
*/
public TeXCommand() {
super("tex",
"This command accepts a latex expression and generates an image corresponding to it.",
super("tex", "Renders LaTeX, also supports inline $-regions like 'see this $\frac{x}{2}$'.",
SlashCommandVisibility.GUILD);
getData().addOption(OptionType.STRING, LATEX_OPTION,
"The latex which is rendered as an image", true);
Expand All @@ -56,42 +66,102 @@ public void onSlashCommand(@NotNull final SlashCommandInteractionEvent event) {
String latex = Objects.requireNonNull(event.getOption(LATEX_OPTION)).getAsString();
String userID = (Objects.requireNonNull(event.getMember()).getId());
TeXFormula formula;

try {
if (latex.contains("$")) {
latex = convertInlineLatexToFull(latex);
}
formula = new TeXFormula(latex);
} catch (ParseException e) {
event.reply("That is an invalid latex: " + e.getMessage()).setEphemeral(true).queue();
event.reply(BAD_LATEX_ERROR_PREFIX + e.getMessage()).setEphemeral(true).queue();
return;
}

event.deferReply().queue();
Image image = formula.createBufferedImage(TeXConstants.STYLE_DISPLAY, DEFAULT_IMAGE_SIZE,
FOREGROUND_COLOR, BACKGROUND_COLOR);
if (image.getWidth(null) == -1 || image.getHeight(null) == -1) {
event.getHook().setEphemeral(true).editOriginal(RENDERING_ERROR).queue();
logger.warn(
"Unable to render latex, image does not have an accessible width or height. Formula was {}",
latex);
return;
}
BufferedImage renderedTextImage = new BufferedImage(image.getWidth(null),
image.getHeight(null), BufferedImage.TYPE_4BYTE_ABGR);
renderedTextImage.getGraphics().drawImage(image, 0, 0, null);
ByteArrayOutputStream renderedTextImageStream = new ByteArrayOutputStream();

try {
ImageIO.write(renderedTextImage, "png", renderedTextImageStream);
Image image = renderImage(formula);
sendImage(event, userID, image);
} catch (IOException e) {
event.getHook().setEphemeral(true).editOriginal(RENDERING_ERROR).queue();
event.getHook().editOriginal(RENDERING_ERROR).queue();
logger.warn(
"Unable to render latex, could not convert the image into an attachable form. Formula was {}",
latex, e);
return;

} catch (IllegalStateException e) {
event.getHook().editOriginal(RENDERING_ERROR).queue();

logger.warn(
"Unable to render latex, image does not have an accessible width or height. Formula was {}",
latex, e);
}
}

private @NotNull Image renderImage(@NotNull TeXFormula formula) {
Image image = formula.createBufferedImage(TeXConstants.STYLE_DISPLAY, DEFAULT_IMAGE_SIZE,
FOREGROUND_COLOR, BACKGROUND_COLOR);

if (image.getWidth(null) == -1 || image.getHeight(null) == -1) {
throw new IllegalStateException("Image has no height or width");
}
return image;
}

private void sendImage(@NotNull IDeferrableCallback event, @NotNull String userID,
@NotNull Image image) throws IOException {
ByteArrayOutputStream renderedTextImageStream = getRenderedTextImageStream(image);
event.getHook()
.editOriginal(renderedTextImageStream.toByteArray(), "tex.png")
.setActionRow(Button.danger(generateComponentId(userID), "Delete"))
.setActionRow(Button.of(ButtonStyle.DANGER, generateComponentId(userID), "Delete"))
.queue();
}

@NotNull
private ByteArrayOutputStream getRenderedTextImageStream(@NotNull Image image)
throws IOException {
BufferedImage renderedTextImage = new BufferedImage(image.getWidth(null),
image.getHeight(null), BufferedImage.TYPE_4BYTE_ABGR);

renderedTextImage.getGraphics().drawImage(image, 0, 0, null);
ByteArrayOutputStream renderedTextImageStream = new ByteArrayOutputStream();

ImageIO.write(renderedTextImage, "png", renderedTextImageStream);

return renderedTextImageStream;
}

/**
* Converts inline latex like: {@code hello $\frac{x}{2}$ world} to full latex
* {@code \text{hello}\frac{x}{2}\text{ world}}.
*
* @param latex the latex to convert
* @return the converted latex
*/
@NotNull
private String convertInlineLatexToFull(@NotNull String latex) {
if (isInvalidInlineFormat(latex)) {
throw new ParseException(INVALID_INLINE_FORMAT_ERROR_MESSAGE);
}

Matcher matcher = INLINE_LATEX_REPLACEMENT.matcher(latex);
StringBuilder sb = new StringBuilder(latex.length());

while (matcher.find()) {
boolean isInsideMathRegion = matcher.group(1) != null;
if (isInsideMathRegion) {
sb.append(matcher.group(1).replace("$", ""));
} else {
sb.append("\\text{").append(matcher.group(2)).append("}");
}
}

return sb.toString();
}

private boolean isInvalidInlineFormat(@NotNull String latex) {
return latex.chars().filter(charAsInt -> charAsInt == '$').count() % 2 == 1;
}

@Override
public void onButtonClick(@NotNull final ButtonInteractionEvent event,
@NotNull final List<String> args) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package org.togetherjava.tjbot.commands.mathcommands;

import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.togetherjava.tjbot.commands.SlashCommand;
import org.togetherjava.tjbot.jda.JdaTester;

import java.util.ArrayList;
import java.util.List;

import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.description;
import static org.mockito.Mockito.verify;

final class TeXCommandTest {
private JdaTester jdaTester;
private SlashCommand command;

@BeforeEach
void setUp() {
jdaTester = new JdaTester();
command = jdaTester.spySlashCommand(new TeXCommand());
}

private @NotNull SlashCommandInteractionEvent triggerSlashCommand(@NotNull String latex) {
SlashCommandInteractionEvent event = jdaTester.createSlashCommandInteractionEvent(command)
.setOption(TeXCommand.LATEX_OPTION, latex)
.build();

command.onSlashCommand(event);
return event;
}

private void verifySuccessfulResponse(@NotNull SlashCommandInteractionEvent event,
@NotNull String query) {
verify(jdaTester.getInteractionHookMock(), description("Testing query: " + query))
.editOriginal(any(byte[].class), eq("tex.png"));
}

private static List<String> provideSupportedQueries() {
List<String> fullLatex = List.of("\\frac{x}{2}", "f \\in \\mathcal{O}(n^2)",
"a^{\\varphi(n)} \\equiv 1\\ (\\textrm{mod}\\ n)", "\\textrm{I like } \\xi");

List<String> inlineLatex = List.of("$\\frac{x}{2}$", "$x$ hello", "hello $x$",
"hello $x$ world $y$", "$x$$y$$z$", "$x \\cdot y$");

List<String> edgeCases = List.of("", " ", " \n ");

List<String> allQueries = new ArrayList<>();
allQueries.addAll(fullLatex);
allQueries.addAll(inlineLatex);
allQueries.addAll(edgeCases);

return allQueries;
}

@ParameterizedTest
@MethodSource("provideSupportedQueries")
@DisplayName("The command supports and renders all supported latex queries")
void canRenderSupportedQuery(@NotNull String supportedQuery) {
// GIVEN a supported latex query

// WHEN triggering the command
SlashCommandInteractionEvent event = triggerSlashCommand(supportedQuery);

// THEN the command send a successful response
verifySuccessfulResponse(event, supportedQuery);
}

private static List<String> provideBadInlineQueries() {
return List.of("hello $x world", "$", " $ ", "hello $x$ world$", "$$$$$", "$x$$y$$z");
}

@ParameterizedTest
@MethodSource("provideBadInlineQueries")
@DisplayName("The command does not support bad inline latex queries, for example with missing dollars")
void failsOnBadInlineQuery(@NotNull String badInlineQuery) {
// GIVEN a bad inline latex query

// WHEN triggering the command
SlashCommandInteractionEvent event = triggerSlashCommand(badInlineQuery);

// THEN the command send a failure response
verify(event, description("Testing query: " + badInlineQuery))
.reply(contains(TeXCommand.INVALID_INLINE_FORMAT_ERROR_MESSAGE));
}

private static List<String> provideBadQueries() {
return List.of("__", "\\foo", "\\left(x + y)");
}

@ParameterizedTest
@MethodSource("provideBadQueries")
@DisplayName("The command does not support bad latex queries, for example with unknown symbols or incomplete braces")
void failsOnBadQuery(@NotNull String badQuery) {
// GIVEN a bad inline latex query

// WHEN triggering the command
SlashCommandInteractionEvent event = triggerSlashCommand(badQuery);

// THEN the command send a failure response
verify(event, description("Testing query: " + badQuery))
.reply(startsWith(TeXCommand.BAD_LATEX_ERROR_PREFIX));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import net.dv8tion.jda.api.events.interaction.command.SlashCommandInteractionEvent;
import net.dv8tion.jda.api.events.interaction.component.ButtonInteractionEvent;
import net.dv8tion.jda.api.exceptions.ErrorResponseException;
import net.dv8tion.jda.api.interactions.InteractionHook;
import net.dv8tion.jda.api.interactions.callbacks.IReplyCallback;
import net.dv8tion.jda.api.interactions.components.ItemComponent;
import net.dv8tion.jda.api.requests.ErrorResponse;
Expand All @@ -20,6 +21,7 @@
import net.dv8tion.jda.internal.requests.Requester;
import net.dv8tion.jda.internal.requests.restaction.AuditableRestActionImpl;
import net.dv8tion.jda.internal.requests.restaction.MessageActionImpl;
import net.dv8tion.jda.internal.requests.restaction.WebhookMessageUpdateActionImpl;
import net.dv8tion.jda.internal.requests.restaction.interactions.ReplyCallbackActionImpl;
import net.dv8tion.jda.internal.utils.config.AuthorizationConfig;
import org.jetbrains.annotations.NotNull;
Expand Down Expand Up @@ -83,8 +85,11 @@ public final class JdaTester {
private final ReplyCallbackActionImpl replyAction;
private final AuditableRestActionImpl<Void> auditableRestAction;
private final MessageActionImpl messageAction;
private final WebhookMessageUpdateActionImpl webhookMessageUpdateAction;
private final TextChannelImpl textChannel;
private final PrivateChannelImpl privateChannel;
private final InteractionHook interactionHook;
private final ReplyCallbackAction replyCallbackAction;

/**
* Creates a new instance. The instance uses a fresh and isolated mocked JDA setup.
Expand All @@ -108,6 +113,8 @@ public JdaTester() {
textChannel = spy(new TextChannelImpl(TEXT_CHANNEL_ID, guild));
privateChannel = spy(new PrivateChannelImpl(jda, PRIVATE_CHANNEL_ID, user));
messageAction = mock(MessageActionImpl.class);
webhookMessageUpdateAction = mock(WebhookMessageUpdateActionImpl.class);
replyCallbackAction = mock(ReplyCallbackAction.class);
EntityBuilder entityBuilder = mock(EntityBuilder.class);
Role everyoneRole = new RoleImpl(GUILD_ID, guild);

Expand Down Expand Up @@ -149,12 +156,22 @@ public JdaTester() {
doNothing().when(auditableRestAction).queue();

doNothing().when(messageAction).queue();
doNothing().when(webhookMessageUpdateAction).queue();
doReturn(webhookMessageUpdateAction).when(webhookMessageUpdateAction)
.setActionRow(any(ItemComponent.class));

doReturn(everyoneRole).when(guild).getPublicRole();
doReturn(selfMember).when(guild).getMember(selfUser);
doReturn(member).when(guild).getMember(not(eq(selfUser)));

doReturn(null).when(textChannel).retrieveMessageById(any());

interactionHook = mock(InteractionHook.class);
when(interactionHook.editOriginal(anyString())).thenReturn(webhookMessageUpdateAction);
when(interactionHook.editOriginal(any(Message.class)))
.thenReturn(webhookMessageUpdateAction);
when(interactionHook.editOriginal(any(byte[].class), any(), any()))
.thenReturn(webhookMessageUpdateAction);
}

/**
Expand Down Expand Up @@ -251,6 +268,19 @@ public JdaTester() {
return replyAction;
}

/**
* Gets the Mockito mock used as universal interaction hook by all mocks created by this tester
* instance.
* <p>
* For example the events created by {@link #createSlashCommandInteractionEvent(SlashCommand)}
* will return this mock on several of their methods.
*
* @return the interaction hook mock used by this tester
*/
public @NotNull InteractionHook getInteractionHookMock() {
return interactionHook;
}

/**
* Gets the text channel spy used as universal text channel by all mocks created by this tester
* instance.
Expand Down Expand Up @@ -371,6 +401,10 @@ private void mockInteraction(@NotNull IReplyCallback interaction) {
doReturn(textChannel).when(interaction).getTextChannel();
doReturn(textChannel).when(interaction).getGuildChannel();
doReturn(privateChannel).when(interaction).getPrivateChannel();

doReturn(interactionHook).when(interaction).getHook();
doReturn(replyCallbackAction).when(interaction).deferReply();
doReturn(replyCallbackAction).when(interaction).deferReply(anyBoolean());
}

private void mockButtonClickEvent(@NotNull ButtonInteractionEvent event) {
Expand Down