Skip to content

Commit

Permalink
test: add test cases for distance computations
Browse files Browse the repository at this point in the history
  • Loading branch information
bernard-ng committed Sep 5, 2024
1 parent aaa1f99 commit 02dc822
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/Model/Prompt/PromptTemplate.php
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
*/
final class PromptTemplate implements \Stringable
{
private readonly string $template;
private string $template;

private ?string $prompt = null;

Expand Down
2 changes: 1 addition & 1 deletion src/Retrieval/Loader/Directory/DirectoryLoader.php
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public function load(AbstractReader $reader = new FileReader()): iterable
public function loadAndSplit(SplitterInterface $splitter, AbstractReader $reader = new FileReader()): iterable
{
foreach ($this->load($reader) as $document) {
yield from $splitter->splitDocuments($document);
yield from $splitter->splitDocument($document);
}
}
}
3 changes: 2 additions & 1 deletion src/Retrieval/Splitter/SplitterInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ interface SplitterInterface
public function splitText(string $text): iterable;

/**
* @return iterable<Document>
* @param iterable<int, string> $splits
* @return iterable<int, Document>
*/
public function createDocuments(iterable $splits): iterable;

Expand Down
2 changes: 1 addition & 1 deletion src/Retrieval/Splitter/TextSplitter.php
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public function createDocuments(iterable $splits): iterable
{
foreach ($splits as $index => $chunk) {
yield new Document($chunk, metadata: new Metadata(
hash: md5((string) $chunk),
hash: md5($chunk),
chunkNumber: $index
));
}
Expand Down
61 changes: 61 additions & 0 deletions tests/Retrieval/VectorStore/DistanceTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
<?php

declare(strict_types=1);

namespace Devscast\Lugha\Tests\Retrieval\VectorStore;

use PHPUnit\Framework\TestCase;
use Devscast\Lugha\Retrieval\VectorStore\Distance;

/**
* Class DistanceTest.
*
* @author bernard-ng <bernard@devscast.tech>
*/
final class DistanceTest extends TestCase
{
public function testVectorWithDifferentDimensions(): void
{
$this->expectException(\InvalidArgumentException::class);
$this->expectExceptionMessage('Vectors must have the same dimension.');

$distance = Distance::COSINE;
$a = [2, 4, 1, 3];
$b = [3, 5, 2];

$distance->compute($a, $b);
}

public function testDifferentVectorDistances(): void
{
$a = [2, 4, 1, 3];
$b = [3, 5, 2, 1];

$this->assertEquals(0.0937, round(Distance::COSINE->compute($a, $b), 4));
$this->assertEquals(2.6458, round(Distance::L2->compute($a, $b), 4));
$this->assertEquals(5.0, round(Distance::L1->compute($a, $b), 4));
$this->assertEquals(31.0, round(Distance::INNER_PRODUCT->compute($a, $b), 4));
}

public function testSimilarVectorDistances(): void
{
$a = [2, 4, 1, 3];
$b = [2, 4, 1, 2];

$this->assertEquals(0.0141, round(Distance::COSINE->compute($a, $b), 4));
$this->assertEquals(1.0, round(Distance::L2->compute($a, $b), 4));
$this->assertEquals(1.0, round(Distance::L1->compute($a, $b), 4));
$this->assertEquals(27.0, round(Distance::INNER_PRODUCT->compute($a, $b), 4));
}

public function testSameVectorDistances(): void
{
$a = [2, 4, 1, 3];
$b = [2, 4, 1, 3];

$this->assertEquals(0.0, round(Distance::COSINE->compute($a, $b), 4));
$this->assertEquals(0.0, round(Distance::L2->compute($a, $b), 4));
$this->assertEquals(0.0, round(Distance::L1->compute($a, $b), 4));
$this->assertEquals(30.0, round(Distance::INNER_PRODUCT->compute($a, $b), 4));
}
}

0 comments on commit 02dc822

Please sign in to comment.