Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add options to auto generate embeddings and summaries #11833

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/main/java/org/jabref/gui/JabRefGUI.java
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,15 @@ public void initialize() {
Injector.setModelOrService(ClipBoardManager.class, clipBoardManager);

JabRefGUI.aiService = new AiService(
stateManager,
preferences.getAiPreferences(),
preferences.getFilePreferences(),
dialogService,
taskExecutor);
Injector.setModelOrService(AiService.class, aiService);

JabRefGUI.chatHistoryService = new ChatHistoryService(
stateManager,
preferences.getCitationKeyPatternPreferences(),
dialogService);
Injector.setModelOrService(ChatHistoryService.class, chatHistoryService);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.jabref.model.groups.AbstractGroup;
import org.jabref.model.groups.GroupTreeNode;

import com.airhacks.afterburner.injection.Injector;
import com.google.common.eventbus.Subscribe;
import dev.langchain4j.data.message.ChatMessage;
import org.slf4j.Logger;
Expand Down Expand Up @@ -58,7 +57,7 @@ public class ChatHistoryService implements AutoCloseable {

private static final String CHAT_HISTORY_FILE_NAME = "chat-histories.mv";

private final StateManager stateManager = Injector.instantiateModelOrService(StateManager.class);
private final StateManager stateManager;

private final CitationKeyPatternPreferences citationKeyPatternPreferences;

Expand All @@ -79,31 +78,25 @@ private record ChatHistoryManagementRecord(Optional<BibDatabaseContext> bibDatab
return o1 == o2 ? 0 : o1.getGroup().getName().compareTo(o2.getGroup().getName());
});

public ChatHistoryService(CitationKeyPatternPreferences citationKeyPatternPreferences, NotificationService notificationService) {
public ChatHistoryService(StateManager stateManager, CitationKeyPatternPreferences citationKeyPatternPreferences, NotificationService notificationService) {
this.stateManager = stateManager;
this.citationKeyPatternPreferences = citationKeyPatternPreferences;
this.implementation = new MVStoreChatHistoryStorage(Directories.getAiFilesDirectory().resolve(CHAT_HISTORY_FILE_NAME), notificationService);
configureHistoryTransfer();
}

public ChatHistoryService(CitationKeyPatternPreferences citationKeyPatternPreferences,
ChatHistoryStorage implementation) {
this.citationKeyPatternPreferences = citationKeyPatternPreferences;
this.implementation = implementation;

configureHistoryTransfer();
configureDatabaseListeners();
}

private void configureHistoryTransfer() {
private void configureDatabaseListeners() {
stateManager.getOpenDatabases().addListener((ListChangeListener<BibDatabaseContext>) change -> {
while (change.next()) {
if (change.wasAdded()) {
change.getAddedSubList().forEach(this::configureHistoryTransfer);
change.getAddedSubList().forEach(this::configureDatabaseListeners);
}
}
});
}

private void configureHistoryTransfer(BibDatabaseContext bibDatabaseContext) {
private void configureDatabaseListeners(BibDatabaseContext bibDatabaseContext) {
bibDatabaseContext.getMetaData().getGroups().ifPresent(rootGroupTreeNode -> {
rootGroupTreeNode.iterateOverTree().forEach(groupNode -> {
groupNode.getGroup().nameProperty().addListener((observable, oldValue, newValue) -> {
Expand Down Expand Up @@ -264,6 +257,7 @@ public void close() {
new HashSet<>(groupsChatHistory.keySet()).forEach(this::closeChatHistoryForGroup);

implementation.commit();
implementation.close();
}

private void transferGroupHistory(BibDatabaseContext bibDatabaseContext, GroupTreeNode groupTreeNode, String oldName, String newName) {
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/org/jabref/gui/preferences/ai/AiTab.fxml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,21 @@
</children>
</HBox>

<CheckBox fx:id="autoGenerateEmbeddings"
mnemonicParsing="false"
text="%Automatically generate embeddings for new entries"
HBox.hgrow="ALWAYS"
maxWidth="Infinity"/>

<CheckBox fx:id="autoGenerateSummaries"
mnemonicParsing="false"
text="%Automatically generate summaries for new entries"
HBox.hgrow="ALWAYS"
maxWidth="Infinity"/>

<Label styleClass="sectionHeader"
text="%Connection"/>

<HBox alignment="CENTER_LEFT"
layoutX="10.0"
layoutY="306.0"
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/jabref/gui/preferences/ai/AiTab.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public class AiTab extends AbstractPreferenceTabView<AiTabViewModel> implements
private static final String HUGGING_FACE_CHAT_MODEL_PROMPT = "TinyLlama/TinyLlama_v1.1 (or any other model name)";

@FXML private CheckBox enableAi;
@FXML private CheckBox autoGenerateEmbeddings;
@FXML private CheckBox autoGenerateSummaries;

@FXML private ComboBox<AiProvider> aiProviderComboBox;
@FXML private ComboBox<String> chatModelComboBox;
Expand Down Expand Up @@ -72,6 +74,10 @@ public void initialize() {
this.viewModel = new AiTabViewModel(preferences);

enableAi.selectedProperty().bindBidirectional(viewModel.enableAi());
autoGenerateSummaries.selectedProperty().bindBidirectional(viewModel.autoGenerateSummaries());
autoGenerateSummaries.disableProperty().bind(viewModel.disableAutoGenerateSummaries());
autoGenerateEmbeddings.selectedProperty().bindBidirectional(viewModel.autoGenerateEmbeddings());
autoGenerateEmbeddings.disableProperty().bind(viewModel.disableAutoGenerateEmbeddings());

new ViewModelListCellFactory<AiProvider>()
.withText(AiProvider::toString)
Expand Down
24 changes: 24 additions & 0 deletions src/main/java/org/jabref/gui/preferences/ai/AiTabViewModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ public class AiTabViewModel implements PreferenceTabViewModel {
private final Locale oldLocale;

private final BooleanProperty enableAi = new SimpleBooleanProperty();
private final BooleanProperty autoGenerateEmbeddings = new SimpleBooleanProperty();
private final BooleanProperty disableAutoGenerateEmbeddings = new SimpleBooleanProperty();
private final BooleanProperty autoGenerateSummaries = new SimpleBooleanProperty();
private final BooleanProperty disableAutoGenerateSummaries = new SimpleBooleanProperty();

private final ListProperty<AiProvider> aiProvidersList =
new SimpleListProperty<>(FXCollections.observableArrayList(AiProvider.values()));
Expand Down Expand Up @@ -295,6 +299,8 @@ public void setValues() {
huggingFaceChatModel.setValue(aiPreferences.getHuggingFaceChatModel());

enableAi.setValue(aiPreferences.getEnableAi());
autoGenerateSummaries.setValue(aiPreferences.getAutoGenerateSummaries());
autoGenerateEmbeddings.setValue(aiPreferences.getAutoGenerateEmbeddings());

selectedAiProvider.setValue(aiPreferences.getAiProvider());

Expand All @@ -313,6 +319,8 @@ public void setValues() {
@Override
public void storeSettings() {
aiPreferences.setEnableAi(enableAi.get());
aiPreferences.setAutoGenerateEmbeddings(autoGenerateEmbeddings.get());
aiPreferences.setAutoGenerateSummaries(autoGenerateSummaries.get());

aiPreferences.setAiProvider(selectedAiProvider.get());

Expand Down Expand Up @@ -407,6 +415,22 @@ public BooleanProperty enableAi() {
return enableAi;
}

public BooleanProperty autoGenerateEmbeddings() {
return autoGenerateEmbeddings;
}

public BooleanProperty disableAutoGenerateEmbeddings() {
return disableAutoGenerateEmbeddings;
}

public BooleanProperty autoGenerateSummaries() {
return autoGenerateSummaries;
}

public BooleanProperty disableAutoGenerateSummaries() {
return disableAutoGenerateSummaries;
}

public ReadOnlyListProperty<AiProvider> aiProvidersProperty() {
return aiProvidersList;
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/org/jabref/logic/ai/AiDefaultPreferences.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ public class AiDefaultPreferences {
);

public static final boolean ENABLE_CHAT = false;
public static final boolean AUTO_GENERATE_EMBEDDINGS = false;
public static final boolean AUTO_GENERATE_SUMMARIES = false;

public static final AiProvider PROVIDER = AiProvider.OPEN_AI;

Expand Down
30 changes: 30 additions & 0 deletions src/main/java/org/jabref/logic/ai/AiPreferences.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public class AiPreferences {
private static final String KEYRING_AI_SERVICE_ACCOUNT = "apiKey";

private final BooleanProperty enableAi;
private final BooleanProperty autoGenerateEmbeddings;
private final BooleanProperty autoGenerateSummaries;

private final ObjectProperty<AiProvider> aiProvider;

Expand Down Expand Up @@ -58,6 +60,8 @@ public class AiPreferences {
private Runnable apiKeyChangeListener;

public AiPreferences(boolean enableAi,
boolean autoGenerateEmbeddings,
boolean autoGenerateSummaries,
AiProvider aiProvider,
String openAiChatModel,
String mistralAiChatModel,
Expand All @@ -78,6 +82,8 @@ public AiPreferences(boolean enableAi,
double ragMinScore
) {
this.enableAi = new SimpleBooleanProperty(enableAi);
this.autoGenerateEmbeddings = new SimpleBooleanProperty(autoGenerateEmbeddings);
this.autoGenerateSummaries = new SimpleBooleanProperty(autoGenerateSummaries);

this.aiProvider = new SimpleObjectProperty<>(aiProvider);

Expand Down Expand Up @@ -143,6 +149,30 @@ public void setEnableAi(boolean enableAi) {
this.enableAi.set(enableAi);
}

public BooleanProperty autoGenerateEmbeddingsProperty() {
return autoGenerateEmbeddings;
}

public boolean getAutoGenerateEmbeddings() {
return autoGenerateEmbeddings.get();
}

public void setAutoGenerateEmbeddings(boolean autoGenerateEmbeddings) {
this.autoGenerateEmbeddings.set(autoGenerateEmbeddings);
}

public BooleanProperty autoGenerateSummariesProperty() {
return autoGenerateSummaries;
}

public boolean getAutoGenerateSummaries() {
return autoGenerateSummaries.get();
}

public void setAutoGenerateSummaries(boolean autoGenerateSummaries) {
this.autoGenerateSummaries.set(autoGenerateSummaries);
}

public ObjectProperty<AiProvider> aiProviderProperty() {
return aiProvider;
}
Expand Down
18 changes: 16 additions & 2 deletions src/main/java/org/jabref/logic/ai/AiService.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import javafx.beans.property.BooleanProperty;
import javafx.beans.property.SimpleBooleanProperty;

import org.jabref.gui.StateManager;
import org.jabref.logic.FilePreferences;
import org.jabref.logic.ai.chatting.AiChatService;
import org.jabref.logic.ai.chatting.model.JabRefChatLanguageModel;
Expand Down Expand Up @@ -52,7 +53,8 @@ public class AiService implements AutoCloseable {
private final IngestionService ingestionService;
private final SummariesService summariesService;

public AiService(AiPreferences aiPreferences,
public AiService(StateManager stateManager,
AiPreferences aiPreferences,
FilePreferences filePreferences,
NotificationService notificationService,
TaskExecutor taskExecutor
Expand All @@ -64,8 +66,11 @@ public AiService(AiPreferences aiPreferences,
this.mvStoreSummariesStorage = new MVStoreSummariesStorage(Directories.getAiFilesDirectory().resolve(SUMMARIES_FILE_NAME), notificationService);

this.jabRefEmbeddingModel = new JabRefEmbeddingModel(aiPreferences, notificationService, taskExecutor);

this.aiChatService = new AiChatService(aiPreferences, jabRefChatLanguageModel, jabRefEmbeddingModel, mvStoreEmbeddingStore, cachedThreadPool);

this.ingestionService = new IngestionService(
stateManager,
aiPreferences,
shutdownSignal,
jabRefEmbeddingModel,
Expand All @@ -74,7 +79,16 @@ public AiService(AiPreferences aiPreferences,
filePreferences,
taskExecutor
);
this.summariesService = new SummariesService(aiPreferences, mvStoreSummariesStorage, jabRefChatLanguageModel, shutdownSignal, filePreferences, taskExecutor);

this.summariesService = new SummariesService(
stateManager,
aiPreferences,
mvStoreSummariesStorage,
jabRefChatLanguageModel,
shutdownSignal,
filePreferences,
taskExecutor
);
}

public JabRefChatLanguageModel getChatLanguageModel() {
Expand Down
60 changes: 56 additions & 4 deletions src/main/java/org/jabref/logic/ai/ingestion/IngestionService.java
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
package org.jabref.logic.ai.ingestion;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import javafx.beans.property.ReadOnlyBooleanProperty;
import javafx.beans.property.StringProperty;
import javafx.collections.ListChangeListener;

import org.jabref.gui.StateManager;
import org.jabref.logic.FilePreferences;
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.processingstatus.ProcessingInfo;
import org.jabref.logic.ai.processingstatus.ProcessingState;
import org.jabref.logic.util.TaskExecutor;
import org.jabref.model.database.BibDatabaseContext;
import org.jabref.model.database.event.EntriesAddedEvent;
import org.jabref.model.entry.LinkedFile;
import org.jabref.model.entry.event.FieldChangedEvent;
import org.jabref.model.entry.field.StandardField;

import com.google.common.eventbus.Subscribe;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
Expand All @@ -25,25 +30,29 @@
* Use this class in the logic and UI.
*/
public class IngestionService {
private final Map<LinkedFile, ProcessingInfo<LinkedFile, Void>> ingestionStatusMap = new HashMap<>();
// We use a {@link TreeMap} here for the same reasons we use it in {@link ChatHistoryService}.
private final TreeMap<LinkedFile, ProcessingInfo<LinkedFile, Void>> ingestionStatusMap = new TreeMap<>((o1, o2) -> o1 == o2 ? 0 : o1.getLink().compareTo(o2.getLink()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use the more easier way with Compartor.comparing( ... )
https://www.baeldung.com/java-8-comparator-comparing


private final List<List<LinkedFile>> listsUnderIngestion = new ArrayList<>();

private final AiPreferences aiPreferences;
private final FilePreferences filePreferences;
private final TaskExecutor taskExecutor;

private final FileEmbeddingsManager fileEmbeddingsManager;

private final ReadOnlyBooleanProperty shutdownSignal;

public IngestionService(AiPreferences aiPreferences,
public IngestionService(StateManager stateManager,
AiPreferences aiPreferences,
ReadOnlyBooleanProperty shutdownSignal,
EmbeddingModel embeddingModel,
EmbeddingStore<TextSegment> embeddingStore,
FullyIngestedDocumentsTracker fullyIngestedDocumentsTracker,
FilePreferences filePreferences,
TaskExecutor taskExecutor
) {
this.aiPreferences = aiPreferences;
this.filePreferences = filePreferences;
this.taskExecutor = taskExecutor;

Expand All @@ -56,6 +65,49 @@ public IngestionService(AiPreferences aiPreferences,
);

this.shutdownSignal = shutdownSignal;

configureDatabaseListeners(stateManager);
}

private void configureDatabaseListeners(StateManager stateManager) {
stateManager.getOpenDatabases().addListener((ListChangeListener<BibDatabaseContext>) change -> {
while (change.next()) {
if (change.wasAdded()) {
change.getAddedSubList().forEach(this::configureDatabaseListeners);
}
}
});
}

private void configureDatabaseListeners(BibDatabaseContext bibDatabaseContext) {
// GC was eating the listeners, so we have to fall back to the event bus.
bibDatabaseContext.getDatabase().registerListener(new EntriesChangedListener(bibDatabaseContext));
}

private class EntriesChangedListener {
private final BibDatabaseContext bibDatabaseContext;

public EntriesChangedListener(BibDatabaseContext bibDatabaseContext) {
this.bibDatabaseContext = bibDatabaseContext;
}

@Subscribe
public void listen(EntriesAddedEvent e) {
e.getBibEntries().forEach(entry -> {
if (aiPreferences.getAutoGenerateEmbeddings()) {
entry.getFiles().forEach(linkedFile -> ingest(linkedFile, bibDatabaseContext));
}

entry.registerListener(this);
});
}

@Subscribe
public void listen(FieldChangedEvent e) {
if (e.getField() == StandardField.FILE && aiPreferences.getAutoGenerateEmbeddings()) {
e.getBibEntry().getFiles().forEach(linkedFile -> ingest(linkedFile, bibDatabaseContext));
}
}
}

/**
Expand Down
Loading
Loading