From 1cbc3766333af3edad8388f9a258f443d5929c13 Mon Sep 17 00:00:00 2001 From: mahitamahesh Date: Wed, 23 Nov 2022 13:36:38 -0500 Subject: [PATCH] Modifying logic to generate passages from search result document (#33) Signed-off-by: Mahita Mahesh --- .../KendraIntelligentRanker.java | 55 +++-- .../preprocess/PassageGenerator.java | 182 ++++++++++++++++ .../preprocess/SentenceSplitter.java | 47 ++++ .../preprocess/SlidingWindowTextSplitter.java | 159 -------------- .../preprocess/TextTokenizer.java | 79 ++++++- .../preprocess/PassageGeneratorTests.java | 200 ++++++++++++++++++ .../preprocess/SentenceSplitterTests.java | 39 ++++ .../SlidingWindowTextSplitterTests.java | 157 -------------- .../preprocess/TextTokenizerTests.java | 59 ++++-- src/test/resources/splitter/input.txt | 15 -- 10 files changed, 615 insertions(+), 377 deletions(-) create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/PassageGenerator.java create mode 100644 src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SentenceSplitter.java delete mode 100644 src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SlidingWindowTextSplitter.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/PassageGeneratorTests.java create mode 100644 src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SentenceSplitterTests.java delete mode 100644 src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SlidingWindowTextSplitterTests.java delete mode 100644 src/test/resources/splitter/input.txt diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRanker.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRanker.java index 6c5597c..458dc55 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRanker.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRanker.java @@ -38,20 +38,21 @@ import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreResult; import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreResultItem; import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.BM25Scorer; +import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.PassageGenerator; import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.QueryParser; import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.QueryParser.QueryParserResult; -import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.SlidingWindowTextSplitter; import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.TextTokenizer; import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings; public class KendraIntelligentRanker implements ResultTransformer { - - private static final int PASSAGE_SIZE_LIMIT = 600; - private static final int SLIDING_WINDOW_STEP = PASSAGE_SIZE_LIMIT - 50; - private static final int MAXIMUM_PASSAGES = 10; - private static final double BM25_B_VALUE = 0.75; - private static 
final double BM25_K1_VALUE = 1.6; - private static final int TOP_K_PASSAGES = 3; + private static final int MAX_SENTENCE_LENGTH_IN_TOKENS = 35; + private static final int MIN_PASSAGE_LENGTH_IN_TOKENS = 100; + private static final int MAX_PASSAGE_COUNT = 10; + private static final int TITLE_TOKENS_TRIMMED = 15; + private static final int BODY_PASSAGE_TRIMMED = 200; + private static final double BM25_B_VALUE = 0.75; + private static final double BM25_K1_VALUE = 1.6; + private static final int TOP_K_PASSAGES = 3; private static final Logger logger = LogManager.getLogger(KendraIntelligentRanker.class); @@ -93,8 +94,7 @@ public boolean shouldTransform(final SearchRequest request, final ResultTransfor } @Override - public SearchRequest preprocessRequest(final SearchRequest request, - final ResultTransformerConfiguration configuration) { + public SearchRequest preprocessRequest(final SearchRequest request, final ResultTransformerConfiguration configuration) { // Source is returned in response hits by default. If disabled by the user, overwrite and enable // in order to access document contents for reranking, then suppress at response time. if (request.source() != null && request.source().fetchSource() != null && @@ -142,7 +142,7 @@ public SearchHits transform(final SearchHits hits, Map idToSearchHitMap = new HashMap<>(); for (int j = 0; j < numberOfHitsToRerank; ++j) { Map docSourceMap = originalHits.get(j).getSourceAsMap(); - SlidingWindowTextSplitter textSplitter = new SlidingWindowTextSplitter(PASSAGE_SIZE_LIMIT, SLIDING_WINDOW_STEP, MAXIMUM_PASSAGES); + PassageGenerator passageGenerator = new PassageGenerator(); String bodyFieldName = queryParserResult.getBodyFieldName(); String titleFieldName = queryParserResult.getTitleFieldName(); if (docSourceMap.get(bodyFieldName) == null) { @@ -152,20 +152,30 @@ public SearchHits transform(final SearchHits hits, logger.error(errorMessage); throw new KendraIntelligentRankingException(errorMessage); } - List splitPassages = textSplitter.split(docSourceMap.get(bodyFieldName).toString()); - List> topPassages = getTopPassages(queryParserResult.getQueryText(), splitPassages); + List> passages = passageGenerator.generatePassages(docSourceMap.get(bodyFieldName).toString(), + MAX_SENTENCE_LENGTH_IN_TOKENS, MIN_PASSAGE_LENGTH_IN_TOKENS, MAX_PASSAGE_COUNT); + List> topPassages = getTopPassages(queryParserResult.getQueryText(), passages); List tokenizedTitle = null; if (titleFieldName != null && docSourceMap.get(titleFieldName) != null) { tokenizedTitle = textTokenizer.tokenize(docSourceMap.get(queryParserResult.getTitleFieldName()).toString()); // If tokens list is empty, use null if (tokenizedTitle.isEmpty()) { tokenizedTitle = null; + } else if (tokenizedTitle.size() > TITLE_TOKENS_TRIMMED) { + tokenizedTitle = tokenizedTitle.subList(0, TITLE_TOKENS_TRIMMED); } } for (int i = 0; i < topPassages.size(); ++i) { - originalHitsAsDocuments.add( - new Document(originalHits.get(j).getId() + "@" + (i + 1), originalHits.get(j).getId(), tokenizedTitle, topPassages.get(i), originalHits.get(j).getScore()) - ); + List passageTokens = topPassages.get(i); + if (passageTokens != null && !passageTokens.isEmpty() && passageTokens.size() > BODY_PASSAGE_TRIMMED) { + passageTokens = passageTokens.subList(0, BODY_PASSAGE_TRIMMED); + } + originalHitsAsDocuments.add(new Document( + originalHits.get(j).getId() + "@" + (i + 1), + originalHits.get(j).getId(), + tokenizedTitle, + passageTokens, + originalHits.get(j).getScore())); } // Map search hits by their ID in order to map Kendra 
response documents back to hits later idToSearchHitMap.put(originalHits.get(j).getId(), originalHits.get(j)); @@ -184,14 +194,14 @@ public SearchHits transform(final SearchHits hits, rescoreResultItem.getDocumentId()); logger.error(errorMessage); throw new KendraIntelligentRankingException(errorMessage); - } - searchHit.score(rescoreResultItem.getScore()); - maxScore = Math.max(maxScore, rescoreResultItem.getScore()); - newSearchHits.add(searchHit); + } + searchHit.score(rescoreResultItem.getScore()); + maxScore = Math.max(maxScore, rescoreResultItem.getScore()); + newSearchHits.add(searchHit); } // Add remaining hits to response, which are already sorted by OpenSearch score for (int i = numberOfHitsToRerank; i < originalHits.size(); ++i) { - newSearchHits.add(originalHits.get(i)); + newSearchHits.add(originalHits.get(i)); } return new SearchHits(newSearchHits.toArray(new SearchHit[newSearchHits.size()]), hits.getTotalHits(), maxScore); } catch (Exception ex) { @@ -200,9 +210,8 @@ public SearchHits transform(final SearchHits hits, } } - private List> getTopPassages(final String queryText, final List splitPassages) { + private List> getTopPassages(final String queryText, final List> passages) { List query = textTokenizer.tokenize(queryText); - List> passages = textTokenizer.tokenize(splitPassages); BM25Scorer bm25Scorer = new BM25Scorer(BM25_B_VALUE, BM25_K1_VALUE, passages); PriorityQueue pq = new PriorityQueue<>(Comparator.comparingDouble(x -> x.getScore())); diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/PassageGenerator.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/PassageGenerator.java new file mode 100644 index 0000000..3742d57 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/PassageGenerator.java @@ -0,0 +1,182 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +public class PassageGenerator { + private static final int MAX_SENTENCE_LENGTH_IN_TOKENS = 35; + private static final int MIN_PASSAGE_LENGTH_IN_TOKENS = 100; + private static final int MAX_PASSAGE_COUNT = 10; + + private SentenceSplitter sentenceSplitter; + private TextTokenizer textTokenizer; + + public PassageGenerator() { + this.sentenceSplitter = new SentenceSplitter(); + this.textTokenizer = new TextTokenizer(); + } + + public List> generatePassages(final String document, final int maxSentenceLengthInTokens, + final int minPassageLengthInTokens, final int maxPassageCount) { + if (document == null || document.isBlank()) { + return new ArrayList<>(); + } + + List> tokenizedSentences = generateTokenizedSentences(document, maxSentenceLengthInTokens); + + // To generate N passages with overlap, generate N/2 + 1 passages first, then overlap and exclude last passage + List>> passages = combineSentencesIntoPassages(tokenizedSentences, + minPassageLengthInTokens, (maxPassageCount / 2 + 1)); + + return generatePassagesWithOverlap(passages); + } + + + List> generatePassagesWithOverlap(final List>> passages) { + final int passageCount = passages.size(); + final List passageSentenceCounts = passages.stream() + .map(p -> p.size()) + .collect(Collectors.toList()); + + // Generate list of passages, with each passage being a list of tokens, by combining sentences in each passage + List> passagesWithOverlap = new ArrayList<>(); + + if (passageCount == 0) { + return passagesWithOverlap; + } + + if (passageCount == 1) { + passagesWithOverlap.add(combineSentencesIntoSinglePassage(passages.get(0))); + return passagesWithOverlap; + } + + for (int i = 0; i < (passageCount - 1); ++i) { + // Start at the middle sentence of the first passage + final int firstPassageMidSentenceIndex = (int) Math.floor(passageSentenceCounts.get(i) / 2.0); + + // Stop at the middle sentence of the next passage. 
If there is only one sentence, take it + final int nextPassageMidSentenceIndex = (int) Math.max(1, Math.floor(passageSentenceCounts.get(i + 1) / 2.0)); + + // Add first passage to overall list, combining tokenized sentences into a single list of tokens + passagesWithOverlap.add(combineSentencesIntoSinglePassage(passages.get(i))); + + // Generate the passage with overlap + final List<String> newPassage = new ArrayList<>(); + // Use final integer values for stream operation + newPassage.addAll(combineSentencesIntoSinglePassage( + passages.get(i).subList(firstPassageMidSentenceIndex, passageSentenceCounts.get(i)))); + newPassage.addAll(combineSentencesIntoSinglePassage( + passages.get(i + 1).subList(0, nextPassageMidSentenceIndex))); + + // Add passage with overlap to overall list + passagesWithOverlap.add(newPassage); + } + + // Do not add the last passage, in order to limit the overall passage count + return passagesWithOverlap; + } + + List<String> combineSentencesIntoSinglePassage(final List<List<String>> tokenizedSentences) { + return tokenizedSentences.stream() + .flatMap(List::stream) + .collect(Collectors.toList()); + } + + /** + * Combine sentences into passages with minimum length {@code minPassageLengthInTokens}, + * without splitting up sentences, up to a maximum number of passages + * @param tokenizedSentences list of tokenized sentences + * @return List of passages, where each passage is a list of tokenized sentences + */ + List<List<List<String>>> combineSentencesIntoPassages(final List<List<String>> tokenizedSentences, + int minPassageLengthInTokens, + int maxPassageCount) { + final int sentenceCount = tokenizedSentences.size(); + // Maintain list of lengths of each sentence + final List<Integer> sentenceLengthsInTokens = tokenizedSentences.stream() + .map(s -> s.size()) + .collect(Collectors.toList()); + + int currentPassageLengthInTokens = 0; + // Each passage is a list of tokenized sentences, tokens from all sentences are not collapsed + // into a single string because sentences are required for further processing + List<List<String>> currentPassage = new ArrayList<>(); + List<List<List<String>>> passages = new ArrayList<>(); + + for (int i = 0; i < sentenceCount; ++i) { + // Add the sentence to the current passage + currentPassage.add(tokenizedSentences.get(i)); + currentPassageLengthInTokens += sentenceLengthsInTokens.get(i); + + // If the token count from all remaining sentences is no more than half the minimum passage size, + // append all remaining sentences to the current passage and end + if (i < (sentenceCount - 1)) { + final int tokenCountFromRemainingSentences = sentenceLengthsInTokens.subList(i + 1, sentenceCount).stream() + .reduce(0, Integer::sum); + if (tokenCountFromRemainingSentences <= (minPassageLengthInTokens / 2)) { + currentPassage.addAll(tokenizedSentences.subList(i + 1, sentenceCount)); + passages.add(currentPassage); + break; + } + } + + // If min passage length is reached, or this is the last sentence, add current passage to list of passages + if (currentPassageLengthInTokens >= minPassageLengthInTokens || i == (sentenceCount - 1)) { + passages.add(currentPassage); + // Reset current passage and its length + currentPassage = new ArrayList<>(); + currentPassageLengthInTokens = 0; + } + + // If max number of passages is reached, end + if (passages.size() == maxPassageCount) { + break; + } + } + + return passages; + } + + /** + * Split a text document into tokenized sentences, while breaking up large sentences + * @param document input document + * @return List of tokenized sentences, where each member of the list is a list of tokens + */ + List<List<String>> generateTokenizedSentences(final String 
document, final int maxSentenceLengthInTokens) { + List> tokenizedSentences = new ArrayList<>(); + + List sentences = sentenceSplitter.split(document); + for (String sentence: sentences) { + List currentSentence = textTokenizer.tokenize(sentence); + if (currentSentence.isEmpty()) { + continue; + } + // Break up long sentences + if (currentSentence.size() <= maxSentenceLengthInTokens) { + tokenizedSentences.add(currentSentence); + } else { + final int sentenceLengthInTokens = currentSentence.size(); + for (int i = 0; i < sentenceLengthInTokens; i += maxSentenceLengthInTokens) { + final int tokensRemainingInSentence = + sentenceLengthInTokens - (i + maxSentenceLengthInTokens); + // If the remaining text is too short, add it to the current sentence and end + if (tokensRemainingInSentence <= (maxSentenceLengthInTokens / 2)) { + tokenizedSentences.add(currentSentence.subList(i, sentenceLengthInTokens)); + break; + } + tokenizedSentences.add(currentSentence.subList(i, + Math.min(sentenceLengthInTokens, i + maxSentenceLengthInTokens))); + } + } + } + return tokenizedSentences; + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SentenceSplitter.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SentenceSplitter.java new file mode 100644 index 0000000..47c50be --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SentenceSplitter.java @@ -0,0 +1,47 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess; + +import com.ibm.icu.text.BreakIterator; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +public class SentenceSplitter { + + /** + * Split the input text into sentences + * @param text input text + * @return list of strings, each a sentence + */ + public List split(final String text) { + if (text == null) { + return new ArrayList<>(); + } + + final BreakIterator breakIterator = BreakIterator.getSentenceInstance(Locale.ENGLISH); + breakIterator.setText(text); + + List sentences = new ArrayList(); + int start = breakIterator.first(); + String currentSentence; + + for (int end = breakIterator.next(); end != BreakIterator.DONE; start = end, end = breakIterator.next()) { + currentSentence = text.substring(start, end).stripTrailing(); + if (!currentSentence.isEmpty()) { + sentences.add(currentSentence); + } + } + return sentences; + } +} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SlidingWindowTextSplitter.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SlidingWindowTextSplitter.java deleted file mode 100644 index 6ff8455..0000000 --- a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SlidingWindowTextSplitter.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. 
- */ -package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess; - -import com.ibm.icu.text.BreakIterator; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.LinkedList; -import java.util.List; -import java.util.Locale; -import org.apache.commons.lang3.tuple.Pair; - -/** - * Applies a sliding window to split input text into passages. However, splitting is aware of - * sentence boundaries and hence the size of splits may be larger than the configured window size. - */ -public class SlidingWindowTextSplitter { - - /** - * Minimum size of window of text to be selected in a split. Window size can be larger - * in practice because of respecting sentence boundaries. - */ - private int windowSize; - - /** - * The minimum step size by which to move the sliding window after a split. Step size can be larger - * in practice because of respecting sentence boundaries. - */ - private int stepSize; - - /** - * The maximum number of passages to be extracted from the input text - */ - private int maximumPassages; - - //local parameters to control sentence/token iteration - private Pair lastSentenceBoundary; - private boolean useTokenIterator; - - public SlidingWindowTextSplitter(int windowSize, int stepSize, int maximumPassages) { - setSlidingWindow(windowSize, stepSize); - this.maximumPassages = maximumPassages; - this.useTokenIterator = Boolean.FALSE; - this.lastSentenceBoundary = null; - } - - /** - * Set parameters of sliding window text splitter - * @param updatedWindowSize text window size - * @param updatedStepSize step size to move the window - */ - public void setSlidingWindow(int updatedWindowSize, int updatedStepSize) { - if (updatedStepSize > updatedWindowSize) { - throw new IllegalArgumentException("Step size " + updatedStepSize + " is larger than window size " + updatedWindowSize); - } - - this.windowSize = updatedWindowSize; - this.stepSize = updatedStepSize; - } - - /** - * Obtains a list of split passages, aware of sentence boundaries, from the input text. - * @param text Text to be split - * @return List of passages extracted from input text - */ - public List split(String text) { - if (text.isEmpty()) { - return new ArrayList<>(); - } - - if (text.length() <= windowSize) { - return Arrays.asList(text); - } - - List splitText = new LinkedList<>(); - - BreakIterator sentenceIterator = BreakIterator.getSentenceInstance(Locale.ENGLISH); - BreakIterator tokenIterator = BreakIterator.getLineInstance(Locale.ENGLISH); - sentenceIterator.setText(text); - - int passageCounter = 0; - int startBoundaryIndex = 0; - int endBoundaryIndex = 0; - - int nextStartBoundaryIndex = 0; - boolean nextStartBoundaryIndexUpdated = false; - - while ((passageCounter < this.maximumPassages) && (endBoundaryIndex != BreakIterator.DONE)) { - // If current passage length is already larger than the step size, - // use the end index as the start index for next window - if (!nextStartBoundaryIndexUpdated && (endBoundaryIndex - startBoundaryIndex + 1) >= stepSize) { - nextStartBoundaryIndexUpdated = true; - nextStartBoundaryIndex = endBoundaryIndex; - } - - if (endBoundaryIndex == text.length() ) { - // End of the input text. Add the passage, irrespective of its length. 
- splitText.add(text.substring(startBoundaryIndex, endBoundaryIndex)); - ++passageCounter; - } else if ((endBoundaryIndex - startBoundaryIndex + 1) >= windowSize && endBoundaryIndex > nextStartBoundaryIndex ) { - // If current passage length is greater than both step and window size, extend the window - // such that all passages have some overlap - splitText.add(text.substring(startBoundaryIndex, endBoundaryIndex)); - ++passageCounter; - - startBoundaryIndex = nextStartBoundaryIndex; - - // Check whether current end boundary index can be used as next start boundary index - if ((endBoundaryIndex - startBoundaryIndex + 1) >= stepSize) { - nextStartBoundaryIndex = endBoundaryIndex; - nextStartBoundaryIndexUpdated = true; - } else { - nextStartBoundaryIndexUpdated = false; - } - } - - endBoundaryIndex = getNextEndBoundaryIndex(sentenceIterator, tokenIterator, endBoundaryIndex, text); - } - - return splitText; - } - - /** - * Get the end boundary index of the next split. In common cases, move the iterator by a sentence. - * If a sentence is too long, move the iterator by tokens rather than sentences. - * @param sentenceIterator iterator respecting sentence boundaries - * @param tokenIterator iterator respecting tokens - * @param previousEndBoundaryIndex end boundary index of previous split - * @param text input text to split - */ - private int getNextEndBoundaryIndex(BreakIterator sentenceIterator, BreakIterator tokenIterator, int previousEndBoundaryIndex, String text) { - if (!useTokenIterator) { - int nextSentenceBoundary = sentenceIterator.next(); - int sentenceLength = nextSentenceBoundary - previousEndBoundaryIndex; - if (nextSentenceBoundary == BreakIterator.DONE || sentenceLength < windowSize) { - return nextSentenceBoundary; - } else { - // Sentence is too long, use tokenIterator - lastSentenceBoundary = Pair.of(previousEndBoundaryIndex, nextSentenceBoundary); - useTokenIterator = Boolean.TRUE; - tokenIterator.setText(text.substring(previousEndBoundaryIndex, nextSentenceBoundary)); - } - } - - // Always add the offset of previous sentence - int nextTokenBoundary = lastSentenceBoundary.getLeft() + tokenIterator.next(); - if (nextTokenBoundary == lastSentenceBoundary.getRight()) { - // Finished token iterations, unset booleans to start from next sentence. 
- useTokenIterator = Boolean.FALSE; - lastSentenceBoundary = null; - } - return nextTokenBoundary; - } -} \ No newline at end of file diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/TextTokenizer.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/TextTokenizer.java index daf89f3..ca7f2e2 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/TextTokenizer.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/TextTokenizer.java @@ -7,16 +7,20 @@ */ package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess; +import com.ibm.icu.text.BreakIterator; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Set; +import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; public class TextTokenizer { + private static final int MINIMUM_WORD_LENGTH = 2; + private static final int MAXIMUM_WORD_LENGTH = 25; private static final Set STOP_WORDS = new HashSet<>( Arrays.asList("i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you", "your", "yours", "yourself", "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself", "it", "its", "itself", "they", "them", "their", "theirs", "themselves", "what", "which", "who", @@ -26,26 +30,83 @@ public class TextTokenizer { "on", "off", "over", "under", "again", "further", "then", "once", "here", "there", "when", "where", "why", "how", "all", "any", "both", "each", "few", "more", "most", "other", "some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s", "t", "can", "will", "just", "don", "should", "now")); - private static final Pattern SPLIT_PATTERN = Pattern.compile("[\\p{Punct}\\s]+"); + private static final Pattern ALL_PUNCTUATIONS_REGEX = Pattern.compile("^\\p{Pc}+$|^\\p{Pd}+$|^\\p{Pe}+$|^\\p{Pf}+$|^\\p{Pi}+$|^\\p{Po}+$|^\\p{Ps}+$"); + private static final Pattern PUNCTUATIONS_REGEX_PATTERN = Pattern.compile("\\p{Pc}|\\p{Pd}|\\p{Pe}|\\p{Pf}|\\p{Pi}|\\p{Po}|\\p{Ps}"); public List> tokenize(List texts) { + if (texts == null) { + return new ArrayList<>(); + } + return texts.stream() .map(text -> tokenize(text)) .collect(Collectors.toList()); } + /** + * Split the input text into tokens, with post-processing to remove stop words, punctuation, etc. 
+ * @param text input text + * @return list of tokens + */ public List tokenize(String text) { - String[] tokens = text.split(SPLIT_PATTERN.pattern()); - List validTokens = new ArrayList<>(); - for (String token : tokens) { - if (token.length() == 0) { + if (text == null) { + return new ArrayList<>(); + } + + final BreakIterator breakIterator = BreakIterator.getWordInstance(Locale.ENGLISH); + breakIterator.setText(text); + + List tokens = new ArrayList(); + int start = breakIterator.first(); + String currentWord; + for (int end = breakIterator.next(); end != BreakIterator.DONE; start = end, end = breakIterator.next()) { + currentWord = text.substring(start, end).stripTrailing().toLowerCase(Locale.ENGLISH); + if (currentWord.isEmpty()) { continue; } - String lowerCased = token.toLowerCase(Locale.ENGLISH); - if (!STOP_WORDS.contains(lowerCased)) { - validTokens.add(lowerCased); + // Split long words + List shortenedTokens = new ArrayList<>(); + if (currentWord.length() <= MAXIMUM_WORD_LENGTH) { + shortenedTokens.add(currentWord); + } else { + for (int i = 0; i < currentWord.length(); i += MAXIMUM_WORD_LENGTH) { + shortenedTokens.add(currentWord.substring(i, Math.min(currentWord.length(), i + MAXIMUM_WORD_LENGTH))); + } + } + // Filter out punctuation, short words, numbers + for (String shortenedToken : shortenedTokens) { + if (!isWordAllPunctuation(shortenedToken) && !STOP_WORDS.contains(shortenedToken) && + shortenedToken.length() >= MINIMUM_WORD_LENGTH && !isNumeric(shortenedToken)) { + String tokenWithInWordPunctuationRemoved = removeInWordPunctuation(shortenedToken); + if (!tokenWithInWordPunctuationRemoved.isEmpty()) { + tokens.add(tokenWithInWordPunctuationRemoved); + } + } } } - return validTokens; + return tokens; + } + + boolean isWordAllPunctuation(final String token) { + return (token != null) && ALL_PUNCTUATIONS_REGEX.matcher(token).matches(); + } + + boolean isNumeric(final String token) { + if (token == null) { + return false; + } + try { + Double.parseDouble(token); + } catch (NumberFormatException e) { + return false; + } + return true; + } + + String removeInWordPunctuation(String token) { + if (token == null) { + return null; + } + return PUNCTUATIONS_REGEX_PATTERN.matcher(token).replaceAll(""); } } diff --git a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/PassageGeneratorTests.java b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/PassageGeneratorTests.java new file mode 100644 index 0000000..eea81e4 --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/PassageGeneratorTests.java @@ -0,0 +1,200 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.opensearch.test.OpenSearchTestCase; + +public class PassageGeneratorTests extends OpenSearchTestCase { + private static final int MAX_SENTENCE_LENGTH_IN_TOKENS = 4; + private static final int MIN_PASSAGE_LENGTH_IN_TOKENS = 14; + + private static final List SUFFIX_FOR_LONGER_SENTENCE_TOKENS = Arrays.asList("longer", "collection", "suffix"); + private static final String SUFFIX_FOR_LONGER_SENTENCE = String.join(" ", SUFFIX_FOR_LONGER_SENTENCE_TOKENS) + "."; + private static final List PASSAGE_1_SENTENCE_1_TOKENS = Arrays.asList("Words", "comprising", "passage1", "sentence1"); + private static final String PASSAGE_1_SENTENCE_1 = String.join(" ", PASSAGE_1_SENTENCE_1_TOKENS) + "."; + private static final List PASSAGE_1_SENTENCE_2_TOKENS = Arrays.asList("Words", "comprising", "passage1", "sentence2"); + private static final String PASSAGE_1_SENTENCE_2 = String.join(" ", PASSAGE_1_SENTENCE_2_TOKENS) + "."; + private static final List PASSAGE_1_SENTENCE_3_TOKENS = Arrays.asList("Words", "comprising", "passage1", "sentence3"); + private static final List PASSAGE_1_LONG_SENTENCE_TOKENS = Stream.concat(PASSAGE_1_SENTENCE_3_TOKENS.stream(), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream()).collect(Collectors.toList()); + private static final String PASSAGE_1_LONG_SENTENCE = String.join(" ", PASSAGE_1_LONG_SENTENCE_TOKENS) + "."; + private static final List PASSAGE_2_SENTENCE_1_TOKENS = Arrays.asList("Words", "comprising", "passage2", "sentence1"); + private static final String PASSAGE_2_SENTENCE_1 = String.join(" ", PASSAGE_2_SENTENCE_1_TOKENS) + "."; + private static final List PASSAGE_2_SENTENCE_2_TOKENS = Arrays.asList("Words", "comprising", "passage2", "sentence2"); + private static final String PASSAGE_2_SENTENCE_2 = String.join(" ", PASSAGE_2_SENTENCE_2_TOKENS) + "."; + private static final List PASSAGE_2_SENTENCE_3_TOKENS = Arrays.asList("Words", "comprising", "passage2", "sentence3"); + private static final List PASSAGE_2_LONG_SENTENCE_TOKENS = Stream.concat(PASSAGE_2_SENTENCE_3_TOKENS.stream(), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream()).collect(Collectors.toList()); + private static final String PASSAGE_2_LONG_SENTENCE = String.join(" ", PASSAGE_2_LONG_SENTENCE_TOKENS) + "."; + + private PassageGenerator passageGenerator = new PassageGenerator(); + + public void testGeneratePassages_BlankDocument() { + assertEquals(Collections.emptyList(), passageGenerator.generatePassages( + null, MAX_SENTENCE_LENGTH_IN_TOKENS, MIN_PASSAGE_LENGTH_IN_TOKENS, 1)); + assertEquals(Collections.emptyList(), passageGenerator.generatePassages( + "", MAX_SENTENCE_LENGTH_IN_TOKENS, MIN_PASSAGE_LENGTH_IN_TOKENS, 1)); + assertEquals(Collections.emptyList(), passageGenerator.generatePassages( + " ", MAX_SENTENCE_LENGTH_IN_TOKENS, MIN_PASSAGE_LENGTH_IN_TOKENS, 1)); + } + + public void testGeneratePassages_ValidDocument() { + final String document = String.join(" ", PASSAGE_1_SENTENCE_1, PASSAGE_1_SENTENCE_2, + PASSAGE_1_LONG_SENTENCE, PASSAGE_2_SENTENCE_1, PASSAGE_2_SENTENCE_2, PASSAGE_2_LONG_SENTENCE); + + List expectedPassage1 = new ArrayList<>(); + expectedPassage1.addAll(PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + 
expectedPassage1.addAll(PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedPassage1.addAll(PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedPassage1.addAll(SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + + List expectedOverlapPassage = new ArrayList<>(); + expectedOverlapPassage.addAll(PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedOverlapPassage.addAll(SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedOverlapPassage.addAll(PASSAGE_2_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedOverlapPassage.addAll(PASSAGE_2_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + + List> expectedOutput = Arrays.asList(expectedPassage1, expectedOverlapPassage); + + List> actualOutput = passageGenerator.generatePassages( + document, MAX_SENTENCE_LENGTH_IN_TOKENS, MIN_PASSAGE_LENGTH_IN_TOKENS, 2); + assertEquals(expectedOutput, actualOutput); + + } + + public void testGeneratePassagesWithOverlap_EmptyInput() { + assertEquals(Collections.emptyList(), passageGenerator.generatePassagesWithOverlap(new ArrayList<>())); + } + + public void testGeneratePassagesWithOverlap_SinglePassage() { + List>> passages = Arrays.asList( + Arrays.asList( + PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()) + ) + ); + + List expectedPassage = new ArrayList<>(); + expectedPassage.addAll(PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedPassage.addAll(PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedPassage.addAll(PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedPassage.addAll(SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + List> expectedOutput = Arrays.asList(expectedPassage); + + List> actualOutput = passageGenerator.generatePassagesWithOverlap(passages); + assertEquals(expectedOutput, actualOutput); + } + + public void testGeneratePassagesWithOverlap_TwoPassages() { + List>> passages = Arrays.asList( + Arrays.asList( + PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()) + ), + Arrays.asList( + PASSAGE_2_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_2_SENTENCE_2_TOKENS.stream().map(t -> 
t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_2_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()) + ) + ); + + List expectedPassage1 = new ArrayList<>(); + expectedPassage1.addAll(PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedPassage1.addAll(PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedPassage1.addAll(PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedPassage1.addAll(SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + + List expectedOverlapPassage = new ArrayList<>(); + expectedOverlapPassage.addAll(PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedOverlapPassage.addAll(SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedOverlapPassage.addAll(PASSAGE_2_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expectedOverlapPassage.addAll(PASSAGE_2_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + + List> expectedOutput = Arrays.asList(expectedPassage1, expectedOverlapPassage); + + List> actualOutput = passageGenerator.generatePassagesWithOverlap(passages); + assertEquals(expectedOutput, actualOutput); + } + + public void testCombineSentencesIntoSinglePassage() { + List> tokenizedSentences = Arrays.asList( + PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()) + ); + + List expected = new ArrayList<>(); + expected.addAll(PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expected.addAll(PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expected.addAll(PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + expected.addAll(SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList())); + + List actual = passageGenerator.combineSentencesIntoSinglePassage(tokenizedSentences); + assertEquals(expected, actual); + } + + public void testCombineSentencesIntoPassages() { + List> tokenizedSentences = Arrays.asList( + PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_2_SENTENCE_1_TOKENS.stream().map(t -> 
t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_2_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_2_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()) + ); + + List>> expected = Arrays.asList( + Arrays.asList( + PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()) + ) + ); + + List>> actual = passageGenerator.combineSentencesIntoPassages( + tokenizedSentences, MIN_PASSAGE_LENGTH_IN_TOKENS, 1); + assertEquals(expected, actual); + } + + public void testGenerateTokenizedSentences() { + final String document = String.join(" ", PASSAGE_1_SENTENCE_1, PASSAGE_1_SENTENCE_2, + PASSAGE_1_LONG_SENTENCE, PASSAGE_2_SENTENCE_1, PASSAGE_2_SENTENCE_2, PASSAGE_2_LONG_SENTENCE); + + List> expected = Arrays.asList( + PASSAGE_1_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_1_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_2_SENTENCE_1_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_2_SENTENCE_2_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + PASSAGE_2_SENTENCE_3_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()), + SUFFIX_FOR_LONGER_SENTENCE_TOKENS.stream().map(t -> t.toLowerCase(Locale.ENGLISH)).collect(Collectors.toList()) + ); + + List> actual = passageGenerator.generateTokenizedSentences(document, MAX_SENTENCE_LENGTH_IN_TOKENS); + assertEquals(expected, actual); + } + + +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SentenceSplitterTests.java b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SentenceSplitterTests.java new file mode 100644 index 0000000..21644a8 --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SentenceSplitterTests.java @@ -0,0 +1,39 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.opensearch.test.OpenSearchTestCase; + +public class SentenceSplitterTests extends OpenSearchTestCase { + + private static final String TEXT_1 = "What is the capital of the United States?"; + private static final String TEXT_2 = "OPENSEARCH IS OPEN SOURCE SEARCH AND ANALYTICS SUITE."; + private static final String TEXT_3 = + "You can install OpenSearch by following instructions at https://opensearch.org/docs/latest/opensearch/install/index/."; + private static final String TEXT_4 = "Testing lots of spaces ! and a long word Pneumonoultramicroscopicsilicovolcanoconiosis"; + + private SentenceSplitter sentenceSplitter = new SentenceSplitter(); + + public void testSplit_BlankInput() { + assertEquals(Collections.emptyList(), sentenceSplitter.split(null)); + assertEquals(Collections.emptyList(), sentenceSplitter.split("")); + assertEquals(Collections.emptyList(), sentenceSplitter.split(" ")); + } + + public void testSplit() { + final String text = String.join(" ", TEXT_1 + " ", TEXT_2, TEXT_3, TEXT_4); + List splitSentences = sentenceSplitter.split(text); + + assertEquals(Arrays.asList(TEXT_1, TEXT_2, TEXT_3, + "Testing lots of spaces !", + "and a long word Pneumonoultramicroscopicsilicovolcanoconiosis"), splitSentences); + } +} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SlidingWindowTextSplitterTests.java b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SlidingWindowTextSplitterTests.java deleted file mode 100644 index cf630de..0000000 --- a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/SlidingWindowTextSplitterTests.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. 
- */ -package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess; - -import static org.opensearch.common.io.PathUtils.get; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Locale; -import java.util.Scanner; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import org.opensearch.test.OpenSearchTestCase; - -public class SlidingWindowTextSplitterTests extends OpenSearchTestCase { - private static final int MAXIMUM_PASSAGES = 10; - private static final String TEST_FILE_PATH = "splitter/input.txt"; - - public void testConstructWithInvalidInputs() { - // Step size cannot be larger than window size - assertThrows(IllegalArgumentException.class, () -> new SlidingWindowTextSplitter(3, 4, MAXIMUM_PASSAGES)); - } - - public void testSplitWithEmptyInput() { - SlidingWindowTextSplitter splitter = new SlidingWindowTextSplitter(10, 4, MAXIMUM_PASSAGES); - assertEquals(Collections.emptyList(), splitter.split("")); - } - - public void testSetSlidingWindowWithInvalidStepSize() { - SlidingWindowTextSplitter splitter = new SlidingWindowTextSplitter(10, 4, MAXIMUM_PASSAGES); - assertThrows(IllegalArgumentException.class, () -> splitter.setSlidingWindow(2, 4)); - } - - public void testSetSlidingWindowUpdatesSplitter() { - SlidingWindowTextSplitter splitter = new SlidingWindowTextSplitter(30, 10, MAXIMUM_PASSAGES); - final String testInput = generateTestInput(1, 5, false); - List splitText = splitter.split(testInput); - - assertEquals(4, splitText.size()); - assertEquals(Arrays.asList( - generateTestInput(1, 2, true), - generateTestInput(2, 3, true), - generateTestInput(3, 4, true), - generateTestInput(4, 5, false)), - splitText); - - // double window size, expect each passage to be loonger - splitter.setSlidingWindow(60, 10); - splitText = splitter.split(testInput); - - assertEquals(2, splitText.size()); - assertEquals(Arrays.asList( - generateTestInput(1, 4, true), - generateTestInput(2, 5, false)), - splitText); - - // double - splitter.setSlidingWindow(60, 30); - splitText = splitter.split(testInput); - - assertEquals(2, splitText.size()); - assertEquals(Arrays.asList( - generateTestInput(1, 4, true), - generateTestInput(3, 5, false)), - splitText); - } - - public void testSplitObeysMaximumPassagesLimit() { - final int maximumPassages = 3; - final String inputText = generateTestInput(1, 5, false); - - SlidingWindowTextSplitter splitter = new SlidingWindowTextSplitter(30, 10, maximumPassages); - List splitText = splitter.split(inputText); - - assertEquals(maximumPassages, splitText.size()); - assertEquals(Arrays.asList( - generateTestInput(1, 2, true), - generateTestInput(2, 3, true), - generateTestInput(3, 4, true)), - splitText); - } - - public void testSplitWhenLastSplitIsShorterThanWindow() { - final String shortText = "Short text"; - final String inputText = String.join(" ", generateTestInput(1, 5, false), shortText); - final String expectedFinalSplit = String.join(" ", "This is a test 5.", shortText); - - SlidingWindowTextSplitter splitter = new SlidingWindowTextSplitter(30, 10, MAXIMUM_PASSAGES); - List splitText = splitter.split(inputText); - - assertEquals(5, splitText.size()); - assertEquals(Arrays.asList( - generateTestInput(1, 2, true), - generateTestInput(2, 3, true), - generateTestInput(3, 4, true), - generateTestInput(4, 5, true), - expectedFinalSplit), - splitText); - } - - public void testSplitWithOverlap() throws 
IOException { - final int windowSize = 1500; - final int stepSize = 1300; - // Because of respecting sentence boundaries, actual overlap might be smaller. - // Provide a buffer of 20 characters - final int expectedOverlap = windowSize - stepSize - 20; - - SlidingWindowTextSplitter splitter = new SlidingWindowTextSplitter(windowSize, stepSize, MAXIMUM_PASSAGES); - final String input = loadTestInputFromFile(); - - List splitText = splitter.split(input); - - // Overlap means that we get more splits - assertTrue(splitText.size() >= input.length() / windowSize); - - // Verify overlap for every pair of splits - for(int i = 0; i < splitText.size() - 1; ++i) { - assertTrue(splitText.get(i).contains(splitText.get(i + 1).substring(0, expectedOverlap))); - } - } - - public void testSplitUsesLineBreakWhenSentenceIsTooLong() throws IOException { - final int windowSize = 1500; - final int stepSize = 1300; - - SlidingWindowTextSplitter splitter = new SlidingWindowTextSplitter(windowSize, stepSize, MAXIMUM_PASSAGES); - final String input = loadTestInputFromFile(); - String inputWithoutSentenceBoundaries = input.replace(".", " "); - - List splitText = splitter.split(inputWithoutSentenceBoundaries); - - // Ensure that text is split - assertTrue(splitText.size() > 1); - } - - private String generateTestInput(int start, int end, boolean addTerminalSpace) { - final String testInput = IntStream.range(start, end + 1).boxed().map( - i -> String.format(Locale.ENGLISH, "This is a test %s.", i) - ).collect(Collectors.joining(" ")); - return addTerminalSpace ? testInput + " " : testInput; - } - - private String loadTestInputFromFile() throws IOException { - final Scanner scanner = new Scanner(SlidingWindowTextSplitterTests.class.getClassLoader() - .getResourceAsStream(TEST_FILE_PATH), - StandardCharsets.UTF_8.name()).useDelimiter("\\A"); - return scanner.hasNext() ? scanner.next() : ""; - } -} diff --git a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/TextTokenizerTests.java b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/TextTokenizerTests.java index e8c0a21..b40e261 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/TextTokenizerTests.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/preprocess/TextTokenizerTests.java @@ -8,6 +8,7 @@ package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.junit.Assert; @@ -15,24 +16,54 @@ public class TextTokenizerTests extends OpenSearchTestCase { - public static final String TEXT_1 = "What is the capital of the United States?"; - public static final List EXPECTED_1 = Arrays.asList("capital", "united", "states"); - public static final String TEXT_2 = "OPENSEARCH IS OPEN SOURCE SEARCH AND ANALYTICS SUITE."; - public static final List EXPECTED_2 = Arrays.asList("opensearch", "open", "source", "search", "analytics", "suite"); + private static final String TEXT_1 = "What is the capital of the United States? 
"; + private static final List EXPECTED_1 = Arrays.asList("capital", "united", "states"); + private static final String TEXT_2 = "OPENSEARCH IS OPEN SOURCE SEARCH AND ANALYTICS SUITE NUMBER1."; + private static final List EXPECTED_2 = Arrays.asList("opensearch", "open", "source", "search", "analytics", "suite", "number1"); + private static final String TEXT_3 = + "You can install OpenSearch by following instructions at https://opensearch.org/docs/latest/opensearch/install/index/"; + private static final List EXPECTED_3 = Arrays.asList("install", "opensearch", "following", + "instructions", "https", "opensearchorg", "docs", "latest", "opensearch", "install", "index"); + private static final String TEXT_4 = "Testing lots of spaces and a long word Pneumonoultramicroscopicsilicovolcanoconiosis"; + private static final List EXPECTED_4 = Arrays.asList("testing", "lots", "spaces", "long", "word", + "pneumonoultramicroscopics", "ilicovolcanoconiosis"); + private TextTokenizer textTokenizer = new TextTokenizer(); - public void testTokenize1() { - List actual = textTokenizer.tokenize(TEXT_1); - Assert.assertEquals(EXPECTED_1, actual); - } - - public void testTokenize2() { - List actual = textTokenizer.tokenize(TEXT_2); - Assert.assertEquals(EXPECTED_2, actual); + public void testTokenize() { + List testCases = Arrays.asList(null, "", TEXT_1, TEXT_2, TEXT_3, TEXT_4); + List> expectedResults = Arrays.asList(Collections.emptyList(), Collections.emptyList(), EXPECTED_1, EXPECTED_2, EXPECTED_3, EXPECTED_4); + for (int i = 0; i < testCases.size(); ++i) { + assertEquals("Test case " + testCases.get(i) + " failed", expectedResults.get(i), textTokenizer.tokenize(testCases.get(i))); + } } public void testTokenizeMultiple() { - List> actual = textTokenizer.tokenize(Arrays.asList(TEXT_1, TEXT_2)); - Assert.assertEquals(Arrays.asList(EXPECTED_1, EXPECTED_2), actual); + List> actual = textTokenizer.tokenize(Arrays.asList(TEXT_1, TEXT_2, TEXT_3, TEXT_4)); + assertEquals(Arrays.asList(EXPECTED_1, EXPECTED_2, EXPECTED_3, EXPECTED_4), actual); + } + + public void testIsWordAllPunctuation() { + List testCases = Arrays.asList(null, "", " ", "!@./?"); + List expectedResults = Arrays.asList(false, false, false, true); + for (int i = 0; i < testCases.size(); ++i) { + assertEquals("Test case " + testCases.get(i) + " failed", expectedResults.get(i), textTokenizer.isWordAllPunctuation(testCases.get(i))); + } + } + + public void testIsNumeric() { + List testCases = Arrays.asList(null, "", " ", "!@./?", "abc", " 22 ", "22", "5.028", "-20.0d"); + List expectedResults = Arrays.asList(false, false, false, false, false, true, true, true, true); + for (int i = 0; i < testCases.size(); ++i) { + assertEquals("Test case " + testCases.get(i) + " failed", expectedResults.get(i), textTokenizer.isNumeric(testCases.get(i))); + } + } + + public void testRemoveInWordPunctuation() { + List testCases = Arrays.asList(null, "", " ", "!@./?", "ab!!c", "a b!c,22"); + List expectedResults = Arrays.asList(null, "", " ", "", "abc", "a bc22"); + for (int i = 0; i < testCases.size(); ++i) { + assertEquals("Test case " + testCases.get(i) + " failed", expectedResults.get(i), textTokenizer.removeInWordPunctuation(testCases.get(i))); + } } } diff --git a/src/test/resources/splitter/input.txt b/src/test/resources/splitter/input.txt deleted file mode 100644 index 025aadd..0000000 --- a/src/test/resources/splitter/input.txt +++ /dev/null @@ -1,15 +0,0 @@ -The history of natural language processing (NLP) generally started in the 1950s, although work can be 
found from earlier periods. In 1950, Alan Turing published an article titled "Computing Machinery and Intelligence" which proposed what is now called the Turing test as a criterion of intelligence[clarification needed]. - -The Georgetown experiment in 1954 involved fully automatic translation of more than sixty Russian sentences into English. The authors claimed that within three or five years, machine translation would be a solved problem.[2] However, real progress was much slower, and after the ALPAC report in 1966, which found that ten-year-long research had failed to fulfill the expectations, funding for machine translation was dramatically reduced. Little further research in machine translation was conducted until the late 1980s, when the first statistical machine translation systems were developed. - -Some notably successful natural language processing systems developed in the 1960s were SHRDLU, a natural language system working in restricted "blocks worlds" with restricted vocabularies, and ELIZA, a simulation of a Rogerian psychotherapist, written by Joseph Weizenbaum between 1964 and 1966. Using almost no information about human thought or emotion, ELIZA sometimes provided a startlingly human-like interaction. When the "patient" exceeded the very small knowledge base, ELIZA might provide a generic response, for example, responding to "My head hurts" with "Why do you say your head hurts?". - -During the 1970s, many programmers began to write "conceptual ontologies", which structured real-world information into computer-understandable data. Examples are MARGIE (Schank, 1975), SAM (Cullingford, 1978), PAM (Wilensky, 1978), TaleSpin (Meehan, 1976), QUALM (Lehnert, 1977), Politics (Carbonell, 1979), and Plot Units (Lehnert 1981). During this time, many chatterbots were written including PARRY, Racter, and Jabberwacky. - -Up to the 1980s, most natural language processing systems were based on complex sets of hand-written rules. Starting in the late 1980s, however, there was a revolution in natural language processing with the introduction of machine learning algorithms for language processing. This was due to both the steady increase in computational power (see Moore's law) and the gradual lessening of the dominance of Chomskyan theories of linguistics (e.g. transformational grammar), whose theoretical underpinnings discouraged the sort of corpus linguistics that underlies the machine-learning approach to language processing.[3] Some of the earliest-used machine learning algorithms, such as decision trees, produced systems of hard if-then rules similar to existing hand-written rules. However, part-of-speech tagging introduced the use of hidden Markov models to natural language processing, and increasingly, research has focused on statistical models, which make soft, probabilistic decisions based on attaching real-valued weights to the features making up the input data. The cache language models upon which many speech recognition systems now rely are examples of such statistical models. Such models are generally more robust when given unfamiliar input, especially input that contains errors (as is very common for real-world data), and produce more reliable results when integrated into a larger system comprising multiple subtasks. - -Many of the notable early successes occurred in the field of machine translation, due especially to work at IBM Research, where successively more complicated statistical models were developed. 
These systems were able to take advantage of existing multilingual textual corpora that had been produced by the Parliament of Canada and the European Union as a result of laws calling for the translation of all governmental proceedings into all official languages of the corresponding systems of government. However, most other systems depended on corpora specifically developed for the tasks implemented by these systems, which was (and often continues to be) a major limitation in the success of these systems. As a result, a great deal of research has gone into methods of more effectively learning from limited amounts of data. - -Recent research has increasingly focused on unsupervised and semi-supervised learning algorithms. Such algorithms are able to learn from data that has not been hand-annotated with the desired answers, or using a combination of annotated and non-annotated data. Generally, this task is much more difficult than supervised learning, and typically produces less accurate results for a given amount of input data. However, there is an enormous amount of non-annotated data available (including, among other things, the entire content of the World Wide Web), which can often make up for the inferior results if the algorithm used has a low enough time complexity to be practical. - -In the 2010s, representation learning and deep neural network-style machine learning methods became widespread in natural language processing, due in part to a flurry of results showing that such techniques[4][5] can achieve state-of-the-art results in many natural language tasks, for example in language modeling,[6] parsing,[7][8] and many others. Popular techniques include the use of word embeddings to capture semantic properties of words, and an increase in end-to-end learning of a higher-level task (e.g., question answering) instead of relying on a pipeline of separate intermediate tasks (e.g., part-of-speech tagging and dependency parsing). In some areas, this shift has entailed substantial changes in how NLP systems are designed, such that deep neural network-based approaches may be viewed as a new paradigm distinct from statistical natural language processing. For instance, the term neural machine translation (NMT) emphasizes the fact that deep learning-based approaches to machine translation directly learn sequence-to-sequence transformations, obviating the need for intermediate steps such as word alignment and language modeling that were used in statistical machine translation (SMT). \ No newline at end of file
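
For reference, a minimal usage sketch of the preprocessing pipeline this patch introduces, assuming the PassageGenerator and TextTokenizer classes added above. The example class name, the sample body/title strings, and the printed output are illustrative assumptions only; the constant values mirror the defaults used in KendraIntelligentRanker (35, 100, 10, 15, 200).

package org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess;

import java.util.List;

// Hypothetical demo class, not part of this patch.
public class PassageGenerationExample {

    private static final int MAX_SENTENCE_LENGTH_IN_TOKENS = 35;
    private static final int MIN_PASSAGE_LENGTH_IN_TOKENS = 100;
    private static final int MAX_PASSAGE_COUNT = 10;
    private static final int TITLE_TOKENS_TRIMMED = 15;
    private static final int BODY_PASSAGE_TRIMMED = 200;

    public static void main(String[] args) {
        // Sample field values standing in for a search hit's source document.
        String body = "OpenSearch is a community-driven, open source search and analytics suite. "
            + "The Kendra Intelligent Ranking plugin rescores the top hits of a query. "
            + "Each hit's body field is split into sentences, tokenized, and grouped into overlapping passages.";
        String title = "Kendra Intelligent Ranking for OpenSearch";

        // Split the body into overlapping passages of tokens, as transform() now does for each hit.
        List<List<String>> passages = new PassageGenerator().generatePassages(
            body, MAX_SENTENCE_LENGTH_IN_TOKENS, MIN_PASSAGE_LENGTH_IN_TOKENS, MAX_PASSAGE_COUNT);

        // Tokenize the title separately and keep only the first TITLE_TOKENS_TRIMMED tokens.
        List<String> titleTokens = new TextTokenizer().tokenize(title);
        if (titleTokens.size() > TITLE_TOKENS_TRIMMED) {
            titleTokens = titleTokens.subList(0, TITLE_TOKENS_TRIMMED);
        }

        // Trim each body passage to BODY_PASSAGE_TRIMMED tokens, as done before building Documents for rescoring.
        for (List<String> passage : passages) {
            List<String> trimmed = passage.size() > BODY_PASSAGE_TRIMMED
                ? passage.subList(0, BODY_PASSAGE_TRIMMED)
                : passage;
            System.out.println("passage tokens: " + trimmed);
        }
        System.out.println("title tokens: " + titleTokens);
    }
}

With this flow, each search hit's body yields at most MAX_PASSAGE_COUNT overlapping token passages; all of them are scored with BM25 against the tokenized query, and the top TOP_K_PASSAGES are kept as Documents sent to the Kendra rescore API.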