Modifying logic to generate passages from search result document (#33)
Signed-off-by: Mahita Mahesh <mahitam@amazon.com>
mahitamahesh authored and msfroh committed Nov 23, 2022
1 parent 414673a commit 1cbc376
Showing 10 changed files with 615 additions and 377 deletions.
KendraIntelligentRanker.java
@@ -38,20 +38,21 @@
 import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreResult;
 import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreResultItem;
 import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.BM25Scorer;
+import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.PassageGenerator;
 import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.QueryParser;
 import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.QueryParser.QueryParserResult;
-import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.SlidingWindowTextSplitter;
 import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.TextTokenizer;
 import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings;
 
 public class KendraIntelligentRanker implements ResultTransformer {
 
-  private static final int PASSAGE_SIZE_LIMIT = 600;
-  private static final int SLIDING_WINDOW_STEP = PASSAGE_SIZE_LIMIT - 50;
-  private static final int MAXIMUM_PASSAGES = 10;
-  private static final double BM25_B_VALUE = 0.75;
-  private static final double BM25_K1_VALUE = 1.6;
-  private static final int TOP_K_PASSAGES = 3;
+  private static final int MAX_SENTENCE_LENGTH_IN_TOKENS = 35;
+  private static final int MIN_PASSAGE_LENGTH_IN_TOKENS = 100;
+  private static final int MAX_PASSAGE_COUNT = 10;
+  private static final int TITLE_TOKENS_TRIMMED = 15;
+  private static final int BODY_PASSAGE_TRIMMED = 200;
+  private static final double BM25_B_VALUE = 0.75;
+  private static final double BM25_K1_VALUE = 1.6;
+  private static final int TOP_K_PASSAGES = 3;
 
   private static final Logger logger = LogManager.getLogger(KendraIntelligentRanker.class);

@@ -93,8 +94,7 @@ public boolean shouldTransform(final SearchRequest request, final ResultTransformerConfiguration configuration) {
   }
 
   @Override
-  public SearchRequest preprocessRequest(final SearchRequest request,
-      final ResultTransformerConfiguration configuration) {
+  public SearchRequest preprocessRequest(final SearchRequest request, final ResultTransformerConfiguration configuration) {
     // Source is returned in response hits by default. If disabled by the user, overwrite and enable
     // in order to access document contents for reranking, then suppress at response time.
     if (request.source() != null && request.source().fetchSource() != null &&
@@ -142,7 +142,7 @@ public SearchHits transform(final SearchHits hits,
     Map<String, SearchHit> idToSearchHitMap = new HashMap<>();
     for (int j = 0; j < numberOfHitsToRerank; ++j) {
       Map<String, Object> docSourceMap = originalHits.get(j).getSourceAsMap();
-      SlidingWindowTextSplitter textSplitter = new SlidingWindowTextSplitter(PASSAGE_SIZE_LIMIT, SLIDING_WINDOW_STEP, MAXIMUM_PASSAGES);
+      PassageGenerator passageGenerator = new PassageGenerator();
       String bodyFieldName = queryParserResult.getBodyFieldName();
       String titleFieldName = queryParserResult.getTitleFieldName();
       if (docSourceMap.get(bodyFieldName) == null) {
@@ -152,20 +152,30 @@ public SearchHits transform(final SearchHits hits,
         logger.error(errorMessage);
         throw new KendraIntelligentRankingException(errorMessage);
       }
-      List<String> splitPassages = textSplitter.split(docSourceMap.get(bodyFieldName).toString());
-      List<List<String>> topPassages = getTopPassages(queryParserResult.getQueryText(), splitPassages);
+      List<List<String>> passages = passageGenerator.generatePassages(docSourceMap.get(bodyFieldName).toString(),
+          MAX_SENTENCE_LENGTH_IN_TOKENS, MIN_PASSAGE_LENGTH_IN_TOKENS, MAX_PASSAGE_COUNT);
+      List<List<String>> topPassages = getTopPassages(queryParserResult.getQueryText(), passages);
       List<String> tokenizedTitle = null;
       if (titleFieldName != null && docSourceMap.get(titleFieldName) != null) {
         tokenizedTitle = textTokenizer.tokenize(docSourceMap.get(queryParserResult.getTitleFieldName()).toString());
         // If tokens list is empty, use null
         if (tokenizedTitle.isEmpty()) {
           tokenizedTitle = null;
+        } else if (tokenizedTitle.size() > TITLE_TOKENS_TRIMMED) {
+          tokenizedTitle = tokenizedTitle.subList(0, TITLE_TOKENS_TRIMMED);
         }
       }
       for (int i = 0; i < topPassages.size(); ++i) {
-        originalHitsAsDocuments.add(
-            new Document(originalHits.get(j).getId() + "@" + (i + 1), originalHits.get(j).getId(), tokenizedTitle, topPassages.get(i), originalHits.get(j).getScore())
-        );
+        List<String> passageTokens = topPassages.get(i);
+        if (passageTokens != null && !passageTokens.isEmpty() && passageTokens.size() > BODY_PASSAGE_TRIMMED) {
+          passageTokens = passageTokens.subList(0, BODY_PASSAGE_TRIMMED);
+        }
+        originalHitsAsDocuments.add(new Document(
+            originalHits.get(j).getId() + "@" + (i + 1),
+            originalHits.get(j).getId(),
+            tokenizedTitle,
+            passageTokens,
+            originalHits.get(j).getScore()));
       }
       // Map search hits by their ID in order to map Kendra response documents back to hits later
       idToSearchHitMap.put(originalHits.get(j).getId(), originalHits.get(j));
@@ -184,14 +194,14 @@ public SearchHits transform(final SearchHits hits,
             rescoreResultItem.getDocumentId());
         logger.error(errorMessage);
         throw new KendraIntelligentRankingException(errorMessage);
       }
-        searchHit.score(rescoreResultItem.getScore());
-        maxScore = Math.max(maxScore, rescoreResultItem.getScore());
-        newSearchHits.add(searchHit);
-      }
+      searchHit.score(rescoreResultItem.getScore());
+      maxScore = Math.max(maxScore, rescoreResultItem.getScore());
+      newSearchHits.add(searchHit);
+    }
     // Add remaining hits to response, which are already sorted by OpenSearch score
     for (int i = numberOfHitsToRerank; i < originalHits.size(); ++i) {
-        newSearchHits.add(originalHits.get(i));
+      newSearchHits.add(originalHits.get(i));
     }
     return new SearchHits(newSearchHits.toArray(new SearchHit[newSearchHits.size()]), hits.getTotalHits(), maxScore);
   } catch (Exception ex) {
@@ -200,9 +210,8 @@ public SearchHits transform(final SearchHits hits,
     }
   }
 
-  private List<List<String>> getTopPassages(final String queryText, final List<String> splitPassages) {
+  private List<List<String>> getTopPassages(final String queryText, final List<List<String>> passages) {
     List<String> query = textTokenizer.tokenize(queryText);
-    List<List<String>> passages = textTokenizer.tokenize(splitPassages);
     BM25Scorer bm25Scorer = new BM25Scorer(BM25_B_VALUE, BM25_K1_VALUE, passages);
     PriorityQueue<PassageScore> pq = new PriorityQueue<>(Comparator.comparingDouble(x -> x.getScore()));
 
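For orientation, here is a minimal, self-contained sketch of the per-hit expansion this change introduces: each top passage becomes its own rescore document, trimmed to BODY_PASSAGE_TRIMMED tokens, with an ID of the form parentId@n so responses can be mapped back to the original hit. The Doc record and expandHit helper are hypothetical stand-ins for the plugin's Document DTO and the loop above, not part of the commit.

import java.util.ArrayList;
import java.util.List;

public class PassageExpansionSketch {
  // Hypothetical stand-in for the plugin's Document DTO (id, groupId, title, body, score).
  record Doc(String id, String groupId, List<String> title, List<String> body, float score) {}

  static final int BODY_PASSAGE_TRIMMED = 200;

  // Expand one search hit into one rescore document per top passage, trimming long passages.
  static List<Doc> expandHit(String hitId, float hitScore, List<String> tokenizedTitle,
                             List<List<String>> topPassages) {
    List<Doc> docs = new ArrayList<>();
    for (int i = 0; i < topPassages.size(); ++i) {
      List<String> passageTokens = topPassages.get(i);
      if (passageTokens.size() > BODY_PASSAGE_TRIMMED) {
        passageTokens = passageTokens.subList(0, BODY_PASSAGE_TRIMMED);
      }
      // "doc-42@1", "doc-42@2", ... lets per-passage scores be grouped by parent hit.
      docs.add(new Doc(hitId + "@" + (i + 1), hitId, tokenizedTitle, passageTokens, hitScore));
    }
    return docs;
  }

  public static void main(String[] args) {
    List<List<String>> passages = List.of(
        List.of("opensearch", "ranks", "results"),
        List.of("kendra", "reranks", "passages"));
    expandHit("doc-42", 1.5f, List.of("sample", "title"), passages)
        .forEach(d -> System.out.println(d.id() + " -> " + d.body()));
  }
}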
PassageGenerator.java (new file)
@@ -0,0 +1,182 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class PassageGenerator {
private static final int MAX_SENTENCE_LENGTH_IN_TOKENS = 35;
private static final int MIN_PASSAGE_LENGTH_IN_TOKENS = 100;
private static final int MAX_PASSAGE_COUNT = 10;

private SentenceSplitter sentenceSplitter;
private TextTokenizer textTokenizer;

public PassageGenerator() {
this.sentenceSplitter = new SentenceSplitter();
this.textTokenizer = new TextTokenizer();
}

public List<List<String>> generatePassages(final String document, final int maxSentenceLengthInTokens,
final int minPassageLengthInTokens, final int maxPassageCount) {
if (document == null || document.isBlank()) {
return new ArrayList<>();
}

List<List<String>> tokenizedSentences = generateTokenizedSentences(document, maxSentenceLengthInTokens);

// To generate N passages with overlap, generate N/2 + 1 passages first, then overlap and exclude last passage
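// e.g., with maxPassageCount = 10, six base passages are built; the overlap step then
// emits each of the first five base passages plus an overlap passage after each, for ten total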
List<List<List<String>>> passages = combineSentencesIntoPassages(tokenizedSentences,
minPassageLengthInTokens, (maxPassageCount / 2 + 1));

return generatePassagesWithOverlap(passages);
}


List<List<String>> generatePassagesWithOverlap(final List<List<List<String>>> passages) {
final int passageCount = passages.size();
final List<Integer> passageSentenceCounts = passages.stream()
.map(p -> p.size())
.collect(Collectors.toList());

// Generate list of passages, with each passage being a list of tokens, by combining sentences in each passage
List<List<String>> passagesWithOverlap = new ArrayList<>();

if (passageCount == 0) {
return passagesWithOverlap;
}

if (passageCount == 1) {
passagesWithOverlap.add(combineSentencesIntoSinglePassage(passages.get(0)));
return passagesWithOverlap;
}

for (int i = 0; i < (passageCount - 1); ++i) {
// Start at the middle sentence of the first passage
final int firstPassageMidSentenceIndex = (int) Math.floor(passageSentenceCounts.get(i) / 2.0);

// Stop at the middle sentence of the next passage. If there is only one sentence, take it
final int nextPassageMidSentenceIndex = (int) Math.max(1, Math.floor(passageSentenceCounts.get(i + 1) / 2.0));

// Add first passage to overall list, combining tokenized sentences into a single list of tokens
passagesWithOverlap.add(combineSentencesIntoSinglePassage(passages.get(i)));

// Generate the overlap passage: the second half of the current passage
// followed by the first half of the next passage
final List<String> newPassage = new ArrayList<>();
newPassage.addAll(combineSentencesIntoSinglePassage(
passages.get(i).subList(firstPassageMidSentenceIndex, passageSentenceCounts.get(i))));
newPassage.addAll(combineSentencesIntoSinglePassage(
passages.get(i + 1).subList(0, nextPassageMidSentenceIndex)));

// Add passage with overlap to overall list
passagesWithOverlap.add(newPassage);
}

// Do not add the last passage, in order to limit the overall passage count
return passagesWithOverlap;
}

List<String> combineSentencesIntoSinglePassage(final List<List<String>> tokenizedSentences) {
return tokenizedSentences.stream()
.flatMap(List::stream)
.collect(Collectors.toList());
}

/**
 * Combine sentences into passages with minimum length {@code minPassageLengthInTokens},
 * without splitting up sentences, up to a maximum of {@code maxPassageCount} passages
 * @param tokenizedSentences list of tokenized sentences
 * @param minPassageLengthInTokens minimum passage length in tokens
 * @param maxPassageCount maximum number of passages to generate
 * @return List of passages, where each passage is a list of tokenized sentences
 */
List<List<List<String>>> combineSentencesIntoPassages(final List<List<String>> tokenizedSentences,
int minPassageLengthInTokens,
int maxPassageCount) {
final int sentenceCount = tokenizedSentences.size();
// Maintain list of lengths of each sentence
final List<Integer> sentenceLengthsInTokens = tokenizedSentences.stream()
.map(s -> s.size())
.collect(Collectors.toList());

int currentPassageLengthInTokens = 0;
// Each passage is a list of tokenized sentences, tokens from all sentences are not collapsed
// into a single string because sentences are required for further processing
List<List<String>> currentPassage = new ArrayList<>();
List<List<List<String>>> passages = new ArrayList<>();

for (int i = 0; i < sentenceCount; ++i) {
// Add the sentence to the current passage
currentPassage.add(tokenizedSentences.get(i));
currentPassageLengthInTokens += sentenceLengthsInTokens.get(i);

// If the token count of all remaining sentences does not exceed half the minimum passage
// length, append all remaining sentences to the current passage and stop
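// e.g., with minPassageLengthInTokens = 100, a 40-token tail (40 <= 100 / 2) is folded
// into the current passage rather than forming an undersized final passage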
if (i < (sentenceCount - 1)) {
final int tokenCountFromRemainingSentences = sentenceLengthsInTokens.subList(i + 1, sentenceCount).stream()
.reduce(0, Integer::sum);
if (tokenCountFromRemainingSentences <= (minPassageLengthInTokens / 2)) {
currentPassage.addAll(tokenizedSentences.subList(i + 1, sentenceCount));
passages.add(currentPassage);
break;
}
}

// If min passage length is reached, or this is the last sentence, add current passage to list of passages
if (currentPassageLengthInTokens >= minPassageLengthInTokens || i == (sentenceCount - 1)) {
passages.add(currentPassage);
// Reset current passage and its length
currentPassage = new ArrayList<>();
currentPassageLengthInTokens = 0;
}

// If max number of passages is reached, end
if (passages.size() == maxPassageCount) {
break;
}
}

return passages;
}

/**
 * Split a text document into tokenized sentences, while breaking up sentences longer
 * than {@code maxSentenceLengthInTokens} tokens
 * @param document input document
 * @param maxSentenceLengthInTokens maximum sentence length in tokens
 * @return List, where each member of the list is the list of tokens of one sentence
 */
List<List<String>> generateTokenizedSentences(final String document, final int maxSentenceLengthInTokens) {
List<List<String>> tokenizedSentences = new ArrayList<>();

List<String> sentences = sentenceSplitter.split(document);
for (String sentence: sentences) {
List<String> currentSentence = textTokenizer.tokenize(sentence);
if (currentSentence.isEmpty()) {
continue;
}
// Break up long sentences
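// e.g., an 80-token sentence with maxSentenceLengthInTokens = 35 yields chunks of 35 and
// 45 tokens: the trailing 10 tokens (10 <= 35 / 2) are merged into the second chunk below
// rather than left as a short fragment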
if (currentSentence.size() <= maxSentenceLengthInTokens) {
tokenizedSentences.add(currentSentence);
} else {
final int sentenceLengthInTokens = currentSentence.size();
for (int i = 0; i < sentenceLengthInTokens; i += maxSentenceLengthInTokens) {
final int tokensRemainingInSentence =
sentenceLengthInTokens - (i + maxSentenceLengthInTokens);
// If the remaining text is too short, add it to the current sentence and end
if (tokensRemainingInSentence <= (maxSentenceLengthInTokens / 2)) {
tokenizedSentences.add(currentSentence.subList(i, sentenceLengthInTokens));
break;
}
tokenizedSentences.add(currentSentence.subList(i,
Math.min(sentenceLengthInTokens, i + maxSentenceLengthInTokens)));
}
}
}
return tokenizedSentences;
}
}
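A short usage sketch for the class above (not part of the commit; the sample text is illustrative, and the parameter values mirror the constants KendraIntelligentRanker passes in):

import java.util.List;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.PassageGenerator;

public class PassageGeneratorDemo {
  public static void main(String[] args) {
    String document = "OpenSearch is a search engine. Kendra Intelligent Ranking reranks results. "
        + "Passages are built from sentences. Long sentences are split into chunks.";
    PassageGenerator generator = new PassageGenerator();
    // maxSentenceLengthInTokens = 35, minPassageLengthInTokens = 100, maxPassageCount = 10
    List<List<String>> passages = generator.generatePassages(document, 35, 100, 10);
    for (List<String> passage : passages) {
      System.out.println(passage.size() + " tokens: " + String.join(" ", passage));
    }
  }
}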
SentenceSplitter.java (new file)
@@ -0,0 +1,47 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess;

import com.ibm.icu.text.BreakIterator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

public class SentenceSplitter {

/**
* Split the input text into sentences
* @param text input text
* @return list of strings, each a sentence
*/
public List<String> split(final String text) {
if (text == null) {
return new ArrayList<>();
}

final BreakIterator breakIterator = BreakIterator.getSentenceInstance(Locale.ENGLISH);
breakIterator.setText(text);

List<String> sentences = new ArrayList<>();
int start = breakIterator.first();
String currentSentence;

for (int end = breakIterator.next(); end != BreakIterator.DONE; start = end, end = breakIterator.next()) {
currentSentence = text.substring(start, end).stripTrailing();
if (!currentSentence.isEmpty()) {
sentences.add(currentSentence);
}
}
return sentences;
}
}
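A corresponding sketch for SentenceSplitter (not part of the commit; it needs the same ICU4J dependency the class imports, and the sample text is illustrative):

import java.util.List;
import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.SentenceSplitter;

public class SentenceSplitterDemo {
  public static void main(String[] args) {
    SentenceSplitter splitter = new SentenceSplitter();
    List<String> sentences = splitter.split("This is one sentence. Here is another! And a third?");
    // Prints each detected sentence on its own line, trailing whitespace stripped
    sentences.forEach(System.out::println);
  }
}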