Skip to content

Commit

Permalink
Merge pull request #51 from hkulekci/recommendation
Browse files Browse the repository at this point in the history
Improvement for Recommendation Endpoint
  • Loading branch information
hkulekci authored Jun 12, 2024
2 parents 72b2a4b + 62469cd commit 0ba31dc
Show file tree
Hide file tree
Showing 11 changed files with 529 additions and 28 deletions.
21 changes: 6 additions & 15 deletions src/Endpoints/Collections/Points.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@

use Qdrant\Endpoints\AbstractEndpoint;
use Qdrant\Endpoints\Collections\Points\Payload;
use Qdrant\Endpoints\Collections\Points\Recommend;
use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\PointsStruct;
use Qdrant\Models\Request\PointsBatch;
use Qdrant\Models\Request\RecommendRequest;
use Qdrant\Models\Request\ScrollRequest;
use Qdrant\Models\Request\SearchRequest;
use Qdrant\Response;
Expand All @@ -28,6 +28,11 @@ public function payload(): Payload
return (new Payload($this->client))->setCollectionName($this->collectionName);
}

public function recommend(): Recommend
{
return (new Recommend($this->client))->setCollectionName($this->collectionName);
}

/**
* @throws InvalidArgumentException
*/
Expand Down Expand Up @@ -180,18 +185,4 @@ public function batch(PointsBatch $points, array $queryParams = []): Response
)
);
}

/**
* @throws InvalidArgumentException
*/
public function recommend(RecommendRequest $recommendParams): Response
{
return $this->client->execute(
$this->createRequest(
'POST',
'collections/' . $this->collectionName . '/points/recommend',
$recommendParams->toArray()
)
);
}
}
62 changes: 62 additions & 0 deletions src/Endpoints/Collections/Points/Recommend.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
<?php
/**
* Payload
*
* @since Mar 2023
* @author Haydar KULEKCI <haydarkulekci@gmail.com>
*/

namespace Qdrant\Endpoints\Collections\Points;

use Qdrant\Endpoints\AbstractEndpoint;
use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Request\Points\BatchRecommendRequest;
use Qdrant\Models\Request\Points\GroupRecommendRequest;
use Qdrant\Models\Request\Points\RecommendRequest;
use Qdrant\Response;

class Recommend extends AbstractEndpoint
{
/**
* Retrieves points that are closer to stored positive examples and further from negative examples.
*
* @throws InvalidArgumentException
*/
public function recommend(RecommendRequest $request, array $queryParams = []): Response
{
return $this->client->execute(
$this->createRequest(
'POST',
'/collections/' . $this->getCollectionName() . '/points/recommend' . $this->queryBuild($queryParams),
$request->toArray()
)
);
}

/**
* Retrieves points in batches that are closer to stored positive examples and further from negative examples.
*
* @param BatchRecommendRequest $request
* @param array $queryParams
* @return Response
*/
public function batch(BatchRecommendRequest $request, array $queryParams = []): Response
{

return $this->client->execute(
$this->createRequest(
'POST',
'/collections/' . $this->getCollectionName() . '/points/recommend/batch' . $this->queryBuild($queryParams),
$request->toArray()
)
);
}

/**
* @throws InvalidArgumentException
*/
public function groups($request, array $queryParams = []): Response
{
throw new \RuntimeException('Not implemented on client!');
}
}
42 changes: 42 additions & 0 deletions src/Models/Filter/Condition/GeoPolygon.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
<?php
/**
* @since May 2023
* @author Haydar KULEKCI <haydarkulekci@gmail.com>
*/

namespace Qdrant\Models\Filter\Condition;

use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Domain\Assert;

class GeoPolygon extends AbstractCondition implements ConditionInterface
{
public function __construct(string $key, protected array $exterior, protected ?array $interiors = null)
{
parent::__construct($key);

if (empty($this->exterior)) {
throw new InvalidArgumentException('Exteriors required!');
}

foreach ($this->exterior as $point) {
Assert::keysExists($point, ['lat', 'lon'], 'Each point of polygon needs lat and lon parameters');
}
if ($interiors) {
foreach ($this->interiors as $point) {
Assert::keysExists($point, ['lat', 'lon'], 'Each point of polygon needs lat and lon parameters');
}
}
}

public function toArray(): array
{
return [
'key' => $this->key,
'geo_polygon' => [
'exterior' => $this->exterior,
'interiors' => $this->interiors ?? []
]
];
}
}
26 changes: 26 additions & 0 deletions src/Models/Filter/Filter.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class Filter implements ConditionInterface
protected array $must = [];
protected array $must_not = [];
protected array $should = [];
protected array $minShould = [];
protected ?int $minShouldCount;

public function addMust(ConditionInterface $condition): Filter
{
Expand All @@ -35,6 +37,20 @@ public function addShould(ConditionInterface $condition): Filter
return $this;
}

public function addMinShould(ConditionInterface $condition): Filter
{
$this->minShould[] = $condition;

return $this;
}

public function setMinShouldCount(int $count): Filter
{
$this->minShouldCount = $count;

return $this;
}

public function toArray(): array
{
$filter = [];
Expand All @@ -59,6 +75,16 @@ public function toArray(): array
$filter['should'][] = $should->toArray();
}
}
if ($this->minShould && $this->minShouldCount) {
$filter['min_should'] = [
'conditions' => [],
'min_count' => $this->minShouldCount
];
foreach ($this->minShould as $should) {
/** ConditionInterface $must */
$filter['min_should']['conditions'][] = $should->toArray();
}
}

return $filter;
}
Expand Down
50 changes: 50 additions & 0 deletions src/Models/Request/Points/BatchRecommendRequest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
<?php
/**
* RecommendRequest
*
* @since Jun 2023
* @author Greg Priday <greg@siteorigin.com>
*/
namespace Qdrant\Models\Request\Points;

use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\Traits\ProtectedPropertyAccessor;

class BatchRecommendRequest
{
use ProtectedPropertyAccessor;

/** @var RecommendRequest[] $searches */
protected array $searches = [];

/**
* @param RecommendRequest[] $searches
*/
public function __construct(array $searches)
{
foreach ($searches as $search) {
$this->addSearch($search);
}
}

public function addSearch(RecommendRequest $request): static
{
$this->searches[] = $request;

return $this;
}

public function toArray(): array
{
$searches = [];

foreach ($this->searches as $search) {
$searches[] = $search->toArray();
}

return [
'searches' => $searches
];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,31 @@
* @since Jun 2023
* @author Greg Priday <greg@siteorigin.com>
*/
namespace Qdrant\Models\Request;
namespace Qdrant\Models\Request\Points;

use Qdrant\Exception\InvalidArgumentException;
use Qdrant\Models\Filter\Filter;
use Qdrant\Models\Traits\ProtectedPropertyAccessor;

class RecommendRequest
{
use ProtectedPropertyAccessor;

/**
* average_vector - Average positive and negative vectors and create a single query with the formula
* query = avg_pos + avg_pos - avg_neg. Then performs normal search.
*/
const STRATEGY_AVERAGE_VECTOR = 'average_vector';

/**
* best_score - Uses custom search objective. Each candidate is compared against all examples, its
* score is then chosen from the max(max_pos_score, max_neg_score). If the max_neg_score is chosen
* then it is squared and negated, otherwise it is just the max_pos_score.
*/
const STRATEGY_BEST_SCORE = 'best_score';

protected ?string $shardKey = null;
protected ?string $strategy = null;
protected ?Filter $filter = null;
protected ?string $using = null;
protected ?int $limit = null;
Expand All @@ -31,6 +47,27 @@ public function setFilter(Filter $filter): static
return $this;
}

public function setShardKey(string $shardKey): static
{
$this->shardKey = $shardKey;

return $this;
}

public function setStrategy(string $strategy): static
{
$strategies = [
self::STRATEGY_AVERAGE_VECTOR,
self::STRATEGY_BEST_SCORE,
];
if (!in_array($strategy, $strategies)) {
throw new InvalidArgumentException('Invalid strategy for recommendation.');
}
$this->strategy = $strategy;

return $this;
}

public function setScoreThreshold(float $scoreThreshold): static
{
$this->scoreThreshold = $scoreThreshold;
Expand Down Expand Up @@ -66,19 +103,25 @@ public function toArray(): array
'negative' => $this->negative,
];

if ($this->shardKey !== null) {
$body['shard_key'] = $this->shardKey;
}
if ($this->filter !== null && $this->filter->toArray()) {
$body['filter'] = $this->filter->toArray();
}
if($this->scoreThreshold) {
if($this->scoreThreshold !== null) {
$body['score_threshold'] = $this->scoreThreshold;
}
if ($this->using) {
if ($this->using !== null) {
$body['using'] = $this->using;
}
if ($this->limit) {
if ($this->limit !== null) {
$body['limit'] = $this->limit;
}
if ($this->offset) {
if ($this->strategy !== null) {
$body['strategy'] = $this->strategy;
}
if ($this->offset !== null) {
$body['offset'] = $this->offset;
}

Expand Down
Loading

0 comments on commit 0ba31dc

Please sign in to comment.