Skip to content

Commit f4f5a0a

Browse files
committed
Backport search API changes from v2
1 parent b85d96d commit f4f5a0a

12 files changed

+73
-94
lines changed

src/index/index.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ void Index::applyUpdates(const OpBatch &batch) {
187187

188188
}
189189

190-
std::vector<SearchResult> Index::search(const std::vector<uint32_t> &terms, int64_t timeoutInMSecs) {
190+
std::vector<SearchResult> Index::search(const std::vector<uint32_t> &hashes, int64_t timeoutInMSecs) {
191191
auto reader = openReader();
192-
return reader->search(terms.data(), terms.size(), timeoutInMSecs);
192+
return reader->search(hashes, timeoutInMSecs);
193193
}

src/index/index_reader.cpp

+22-20
Original file line numberDiff line numberDiff line change
@@ -38,31 +38,33 @@ SegmentDataReader* IndexReader::segmentDataReader(const SegmentInfo& segment)
3838
return new SegmentDataReader(m_dir->openFile(segment.dataFileName()), BLOCK_SIZE);
3939
}
4040

41-
void IndexReader::search(const uint32_t* fingerprint, size_t length, Collector* collector, int64_t timeoutInMSecs)
41+
std::vector<SearchResult> IndexReader::search(const std::vector<uint32_t> &hashesIn, int64_t timeoutInMSecs)
4242
{
4343
auto deadline = timeoutInMSecs > 0 ? (QDateTime::currentMSecsSinceEpoch() + timeoutInMSecs) : 0;
44-
std::vector<uint32_t> fp(fingerprint, fingerprint + length);
45-
std::sort(fp.begin(), fp.end());
46-
const SegmentInfoList& segments = m_info.segments();
47-
for (int i = 0; i < segments.size(); i++) {
48-
if (deadline > 0) {
49-
if (QDateTime::currentMSecsSinceEpoch() > deadline) {
50-
throw TimeoutExceeded();
51-
}
52-
}
53-
const SegmentInfo& s = segments.at(i);
54-
SegmentSearcher searcher(s.index(), segmentDataReader(s), s.lastKey());
55-
searcher.search(fp.data(), fp.size(), collector);
44+
45+
std::vector<uint32_t> hashes(hashesIn);
46+
std::sort(hashes.begin(), hashes.end());
47+
48+
std::unordered_map<uint32_t, int> hits;
49+
50+
const SegmentInfoList& segments = m_info.segments();
51+
for (auto segment : segments) {
52+
if (deadline > 0) {
53+
if (QDateTime::currentMSecsSinceEpoch() > deadline) {
54+
throw TimeoutExceeded();
55+
}
5656
}
57-
}
57+
SegmentSearcher searcher(segment.index(), segmentDataReader(segment), segment.lastKey());
58+
searcher.search(hashes, hits);
59+
}
5860

59-
std::vector<SearchResult> IndexReader::search(const uint32_t* fingerprint, size_t length, int64_t timeoutInMSecs)
60-
{
61-
TopHitsCollector collector(1000);
62-
search(fingerprint, length, &collector, timeoutInMSecs);
6361
std::vector<SearchResult> results;
64-
for (const auto result : collector.topResults()) {
65-
results.emplace_back(result.id(), result.score());
62+
results.reserve(hits.size());
63+
for (const auto &hit : hits) {
64+
results.emplace_back(hit.first, hit.second);
6665
}
66+
67+
sortSearchResults(results);
68+
6769
return results;
6870
}

src/index/index_reader.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ class IndexReader
2929
return m_index;
3030
}
3131

32-
void search(const uint32_t *fingerprint, size_t length, Collector *collector, int64_t timeoutInMSecs = 0);
33-
std::vector<SearchResult> search(const uint32_t *fingerprint, size_t length, int64_t timeoutInMSecs = 0);
32+
std::vector<SearchResult> search(const std::vector<uint32_t> &hashes, int64_t timeoutInMSecs = 0);
3433

3534
SegmentDataReader* segmentDataReader(const SegmentInfo& segment);
3635

src/index/index_reader_test.cpp

+10-14
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,25 @@ TEST(IndexReaderTest, Search)
1818
DirectorySharedPtr dir(new RAMDirectory());
1919
IndexSharedPtr index(new Index(dir, true));
2020

21-
uint32_t fp1[] = { 7, 9, 12 };
22-
auto fp1len = 3;
23-
24-
uint32_t fp2[] = { 7, 9, 11 };
25-
auto fp2len = 3;
21+
std::vector<uint32_t> fp1 = { 7, 9, 12 };
22+
std::vector<uint32_t> fp2 = { 7, 9, 11 };
2623

2724
{
2825
auto writer = index->openWriter();
29-
writer->addDocument(1, fp1, fp1len);
26+
writer->addDocument(1, fp1.data(), fp1.size());
3027
writer->commit();
31-
writer->addDocument(2, fp2, fp2len);
28+
writer->addDocument(2, fp2.data(), fp2.size());
3229
writer->commit();
3330
}
3431

3532
{
3633
IndexReader reader(index);
37-
TopHitsCollector collector(100);
38-
reader.search(fp1, fp1len, &collector);
39-
ASSERT_EQ(2, collector.topResults().size());
40-
ASSERT_EQ(1, collector.topResults().at(0).id());
41-
ASSERT_EQ(3, collector.topResults().at(0).score());
42-
ASSERT_EQ(2, collector.topResults().at(1).id());
43-
ASSERT_EQ(2, collector.topResults().at(1).score());
34+
auto results = reader.search(fp1);
35+
ASSERT_EQ(2, results.size());
36+
ASSERT_EQ(1, results.at(0).docId());
37+
ASSERT_EQ(3, results.at(0).score());
38+
ASSERT_EQ(2, results.at(1).docId());
39+
ASSERT_EQ(2, results.at(1).score());
4440
}
4541
}
4642

src/index/search_result.h

+1-7
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,7 @@ class SearchResult {
2929
inline void sortSearchResults(std::vector<SearchResult> &results)
3030
{
3131
std::sort(results.begin(), results.end(), [](const SearchResult &a, const SearchResult &b) {
32-
if (a.score() > b.score()) {
33-
return true;
34-
} else if (a.score() < b.score()) {
35-
return false;
36-
} else {
37-
return a.docId() < b.docId();
38-
}
32+
return a.score() > b.score() || (a.score() == b.score() && a.docId() < b.docId());
3933
});
4034
}
4135

src/index/segment_searcher.cpp

+11-12
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// Distributed under the MIT license, see the LICENSE file for details.
33

44
#include <algorithm>
5-
#include "collector.h"
65
#include "segment_data_reader.h"
76
#include "segment_searcher.h"
87

@@ -17,17 +16,17 @@ SegmentSearcher::~SegmentSearcher()
1716
{
1817
}
1918

20-
void SegmentSearcher::search(uint32_t *fingerprint, size_t length, Collector *collector)
19+
void SegmentSearcher::search(const std::vector<uint32_t> &hashes, std::unordered_map<uint32_t, int> &hits)
2120
{
2221
size_t i = 0, block = 0, lastBlock = SIZE_MAX;
23-
while (i < length) {
22+
while (i < hashes.size()) {
2423
if (block > lastBlock || lastBlock == SIZE_MAX) {
2524
size_t localFirstBlock, localLastBlock;
26-
if (fingerprint[i] > m_lastKey) {
25+
if (hashes[i] > m_lastKey) {
2726
// All following items are larger than the last segment's key.
2827
return;
2928
}
30-
if (m_index->search(fingerprint[i], &localFirstBlock, &localLastBlock)) {
29+
if (m_index->search(hashes[i], &localFirstBlock, &localLastBlock)) {
3130
if (block > localLastBlock) {
3231
// We already searched this block and the fingerprint item was not found.
3332
i++;
@@ -48,24 +47,24 @@ void SegmentSearcher::search(uint32_t *fingerprint, size_t length, Collector *co
4847
std::unique_ptr<BlockDataIterator> blockData(m_dataReader->readBlock(block, firstKey));
4948
while (blockData->next()) {
5049
uint32_t key = blockData->key();
51-
if (key >= fingerprint[i]) {
52-
while (key > fingerprint[i]) {
50+
if (key >= hashes[i]) {
51+
while (key > hashes[i]) {
5352
i++;
54-
if (i >= length) {
53+
if (i >= hashes.size()) {
5554
return;
5655
}
57-
else if (lastKey < fingerprint[i]) {
56+
else if (lastKey < hashes[i]) {
5857
// There are no longer any items in this block that we could match.
5958
goto nextBlock;
6059
}
6160
}
62-
if (key == fingerprint[i]) {
63-
collector->collect(blockData->value());
61+
if (key == hashes[i]) {
62+
auto docId = blockData->value();
63+
hits[docId]++;
6464
}
6565
}
6666
}
6767
nextBlock:
6868
block++;
6969
}
7070
}
71-

src/index/segment_searcher.h

+1-7
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,14 @@
1010
namespace Acoustid {
1111

1212
class SegmentDataReader;
13-
class Collector;
1413

1514
class SegmentSearcher
1615
{
1716
public:
1817
SegmentSearcher(SegmentIndexSharedPtr index, SegmentDataReader *dataReader, uint32_t lastKey = UINT32_MAX);
1918
virtual ~SegmentSearcher();
2019

21-
/**
22-
* Search for the fingerprint in one segment.
23-
*
24-
* The fingerprint must be sorted.
25-
*/
26-
void search(uint32_t *fingerprint, size_t length, Collector *collector);
20+
void search(const std::vector<uint32_t> &hashes, std::unordered_map<uint32_t, int> &hits);
2721

2822
private:
2923
SegmentIndexSharedPtr m_index;

src/server/http.cpp

+3-7
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,13 @@ static HttpResponse handleSearchRequest(const HttpRequest &request, const QShare
204204
limit = 100;
205205
}
206206

207-
auto collector = QSharedPointer<TopHitsCollector>::create(limit);
208-
{
209-
auto reader = index->openReader();
210-
reader->search(query.data(), query.size(), collector.data());
211-
}
212-
auto results = collector->topResults();
207+
auto results = index->search(query);
208+
filterSearchResults(results, limit);
213209

214210
QJsonArray resultsJson;
215211
for (auto &result : results) {
216212
resultsJson.append(QJsonObject{
217-
{"id", qint64(result.id())},
213+
{"id", qint64(result.docId())},
218214
{"score", result.score()},
219215
});
220216
}

src/server/protocol.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44

55
namespace Acoustid { namespace Server {
66

7-
QVector<uint32_t> parseFingerprint(const QString &input) {
8-
QStringList inputParts = input.split(',');
9-
QVector<uint32_t> output;
7+
std::vector<uint32_t> parseFingerprint(const QString &input) {
8+
QStringList inputParts = input.split(',');
9+
std::vector<uint32_t> output;
1010
output.reserve(inputParts.size());
1111
for (int i = 0; i < inputParts.size(); i++) {
1212
bool ok;
1313
auto value = inputParts.at(i).toInt(&ok);
1414
if (!ok) {
1515
throw HandlerException("invalid fingerprint");
1616
}
17-
output.append(value);
17+
output.push_back(value);
1818
}
19-
if (output.isEmpty()) {
19+
if (output.empty()) {
2020
throw HandlerException("empty fingerprint");
2121
}
2222
return output;
@@ -92,9 +92,9 @@ ScopedHandlerFunc buildHandler(const QString &command, const QStringList &args)
9292
auto results = session->search(hashes);
9393
QStringList output;
9494
output.reserve(results.size());
95-
for (int i = 0; i < results.size(); i++) {
96-
output.append(QString("%1:%2").arg(results[i].id()).arg(results[i].score()));
97-
}
95+
for (auto result : results) {
96+
output.append(QString("%1:%2").arg(result.docId()).arg(result.score()));
97+
}
9898
QString outputString = output.join(" ");
9999
session->clearTraceId();
100100
return outputString;

src/server/session.cpp

+6-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include "errors.h"
66
#include "index/index.h"
77
#include "index/index_writer.h"
8-
#include "index/top_hits_collector.h"
98

109
using namespace Acoustid;
1110
using namespace Acoustid::Server;
@@ -99,24 +98,24 @@ void Session::setAttribute(const QString &name, const QString &value) {
9998
m_indexWriter->setAttribute(name, value);
10099
}
101100

102-
void Session::insert(uint32_t id, const QVector<uint32_t> &hashes) {
101+
void Session::insert(uint32_t id, const std::vector<uint32_t> &hashes) {
103102
QMutexLocker locker(&m_mutex);
104103
if (m_indexWriter.isNull()) {
105104
throw NotInTransactionException();
106105
}
107106
m_indexWriter->addDocument(id, hashes.data(), hashes.size());
108107
}
109108

110-
QList<Result> Session::search(const QVector<uint32_t> &hashes) {
109+
std::vector<SearchResult> Session::search(const std::vector<uint32_t> &hashes) {
111110
QMutexLocker locker(&m_mutex);
112-
TopHitsCollector collector(m_maxResults, m_topScorePercent);
111+
std::vector<SearchResult> results;
113112
try {
114-
auto reader = m_index->openReader();
115-
reader->search(hashes.data(), hashes.size(), &collector, m_timeout);
113+
results = m_index->search(hashes, m_timeout);
116114
} catch (TimeoutExceeded &ex) {
117115
throw HandlerException("timeout exceeded");
118116
}
119-
return collector.topResults();
117+
filterSearchResults(results, m_maxResults, m_topScorePercent);
118+
return results;
120119
}
121120

122121
QString Session::getTraceId() {

src/server/session.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#include <QMutex>
88
#include <QSharedPointer>
9-
#include "index/top_hits_collector.h"
9+
#include "index/search_result.h"
1010

1111
namespace Acoustid {
1212

@@ -28,8 +28,8 @@ class Session
2828
void rollback();
2929
void optimize();
3030
void cleanup();
31-
void insert(uint32_t id, const QVector<uint32_t> &hashes);
32-
QList<Result> search(const QVector<uint32_t> &hashes);
31+
void insert(uint32_t id, const std::vector<uint32_t> &hashes);
32+
std::vector<SearchResult> search(const std::vector<uint32_t> &hashes);
3333

3434
QString getAttribute(const QString &name);
3535
void setAttribute(const QString &name, const QString &value);

src/server/session_test.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,18 @@ TEST(SessionTest, InsertAndSearch)
5252
{
5353
auto results = session->search({ 1, 2, 3 });
5454
ASSERT_EQ(2, results.size());
55-
ASSERT_EQ(1, results[0].id());
55+
ASSERT_EQ(1, results[0].docId());
5656
ASSERT_EQ(3, results[0].score());
57-
ASSERT_EQ(2, results[1].id());
57+
ASSERT_EQ(2, results[1].docId());
5858
ASSERT_EQ(1, results[1].score());
5959
}
6060

6161
{
6262
auto results = session->search({ 1, 200, 300 });
6363
ASSERT_EQ(2, results.size());
64-
ASSERT_EQ(2, results[0].id());
64+
ASSERT_EQ(2, results[0].docId());
6565
ASSERT_EQ(3, results[0].score());
66-
ASSERT_EQ(1, results[1].id());
66+
ASSERT_EQ(1, results[1].docId());
6767
ASSERT_EQ(1, results[1].score());
6868
}
6969

@@ -72,7 +72,7 @@ TEST(SessionTest, InsertAndSearch)
7272
{
7373
auto results = session->search({ 1, 2, 3 });
7474
ASSERT_EQ(1, results.size());
75-
ASSERT_EQ(1, results[0].id());
75+
ASSERT_EQ(1, results[0].docId());
7676
ASSERT_EQ(3, results[0].score());
7777
}
7878
}

0 commit comments

Comments
 (0)