Skip to content

Commit

Permalink
Use IPv8 MockRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
qstokkink committed Dec 16, 2024
1 parent 15ad76a commit 4da9ced
Show file tree
Hide file tree
Showing 17 changed files with 416 additions and 857 deletions.
103 changes: 0 additions & 103 deletions src/tribler/test_unit/base_restapi.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from ipv8.peerdiscovery.network import Network
from ipv8.test.base import TestBase
from ipv8.test.mocking.endpoint import AutoMockEndpoint
from ipv8.test.REST.rest_base import MockRequest, response_to_json

from tribler.core.content_discovery.community import ContentDiscoveryCommunity
from tribler.core.content_discovery.restapi.search_endpoint import SearchEndpoint
from tribler.core.restapi.rest_endpoint import HTTP_BAD_REQUEST
from tribler.test_unit.base_restapi import MockRequest, response_to_json


class MockContentDiscoveryCommunity(ContentDiscoveryCommunity):
Expand All @@ -33,19 +33,6 @@ def send_search_request(self, **kwargs) -> tuple[UUID, list[Peer]]:
return UUID(int=1), [self.my_peer]


class SearchRequest(MockRequest):
"""
A MockRequest that mimics SearchRequests.
"""

def __init__(self, query: dict) -> None:
"""
Create a new SearchRequest.
"""
super().__init__(query, "PUT", "/search/remote")
self.context = [MockContentDiscoveryCommunity()]


class TestSearchEndpoint(TestBase):
"""
Tests for the SearchEndpoint REST endpoint.
Expand All @@ -56,8 +43,10 @@ async def test_remote_search_bad_request(self) -> None:
Test if a bad request returns the bad request status.
"""
endpoint = SearchEndpoint()
request = MockRequest("/api/search/remote", "PUT", {"channel_pk": "GG"})
request.context = [MockContentDiscoveryCommunity()]

response = await endpoint.remote_search(SearchRequest({"channel_pk": "GG"}))
response = await endpoint.remote_search(request)

self.assertEqual(HTTP_BAD_REQUEST, response.status)

Expand All @@ -66,8 +55,10 @@ async def test_remote_search(self) -> None:
Test if a good search request returns a dict with the UUID and serving peers.
"""
endpoint = SearchEndpoint()
request = MockRequest("/api/search/remote", "PUT", {"channel_pk": "AA", "fts_text": ""})
request.context = [MockContentDiscoveryCommunity()]

response = await endpoint.remote_search(SearchRequest({"channel_pk": "AA", "fts_text": ""}))
response = await endpoint.remote_search(request)
response_body_json = await response_to_json(response)

self.assertEqual(200, response.status)
Expand Down
115 changes: 36 additions & 79 deletions src/tribler/test_unit/core/database/restapi/test_database_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,17 @@
from __future__ import annotations

from asyncio import sleep
from typing import TYPE_CHECKING, Callable
from typing import Callable
from unittest.mock import AsyncMock, Mock, call

from aiohttp.web_urldispatcher import UrlMappingMatchInfo
from ipv8.test.base import TestBase
from ipv8.test.REST.rest_base import MockRequest, response_to_json
from multidict import MultiDict, MultiDictProxy

from tribler.core.database.layers.knowledge import ResourceType, SimpleStatement
from tribler.core.database.restapi.database_endpoint import DatabaseEndpoint, parse_bool
from tribler.core.database.serialization import REGULAR_TORRENT
from tribler.core.restapi.rest_endpoint import HTTP_BAD_REQUEST
from tribler.test_unit.base_restapi import MockRequest, response_to_json

if TYPE_CHECKING:
from tribler.core.database.store import MetadataStore


class TorrentHealthRequest(MockRequest):
"""
A MockRequest that mimics TorrentHealthRequests.
"""

def __init__(self, query: dict, infohash: str, mds: MetadataStore | None) -> None:
"""
Create a new TorrentHealthRequest.
"""
super().__init__(query, "GET", f"/metadata/torrents/{infohash}/health")
self._infohash = infohash
self.context = (mds,)

@property
def match_info(self) -> UrlMappingMatchInfo:
"""
Get the match info (the infohash in the url).
"""
return UrlMappingMatchInfo({"infohash": self._infohash}, Mock())


class PopularTorrentsRequest(MockRequest):
"""
A MockRequest that mimics PopularTorrentsRequests.
"""

def __init__(self, query: dict, mds: MetadataStore | None) -> None:
"""
Create a new PopularTorrentsRequest.
"""
super().__init__(query, "GET", "/metadata/torrents/popular")
self.context = (mds,)


class SearchLocalRequest(MockRequest):
"""
A MockRequest that mimics SearchLocalRequests.
"""

def __init__(self, query: dict, mds: MetadataStore | None) -> None:
"""
Create a new SearchLocalRequest.
"""
default_query = {"fts_text": ""}
default_query.update(query)
super().__init__(default_query, "GET", "/metadata/search/local")
self.context = (mds,)


class SearchCompletionsRequest(MockRequest):
"""
A MockRequest that mimics SearchCompletionsRequests.
"""

def __init__(self, query: dict, mds: MetadataStore | None) -> None:
"""
Create a new SearchCompletionsRequest.
"""
super().__init__(query, "GET", "/metadata/search/completions")
self.context = (mds, )


class TestDatabaseEndpoint(TestBase):
Expand Down Expand Up @@ -144,8 +78,10 @@ async def test_get_torrent_health_bad_timeout(self) -> None:
Test if a bad timeout value in get_torrent_health leads to a HTTP_BAD_REQUEST status.
"""
endpoint = DatabaseEndpoint()
request = MockRequest("/metadata/torrents/AA/health", query={"timeout": "AA"}, match_info={"infohash": "AA"})
request.context = [endpoint.mds]

response = await endpoint.get_torrent_health(TorrentHealthRequest({"timeout": "AA"}, "AA", endpoint.mds))
response = await endpoint.get_torrent_health(request)

self.assertEqual(HTTP_BAD_REQUEST, response.status)

Expand All @@ -154,8 +90,10 @@ async def test_get_torrent_health_no_checker(self) -> None:
Test if calling get_torrent_health without a torrent checker leads to a false checking status.
"""
endpoint = DatabaseEndpoint()
request = MockRequest("/metadata/torrents/AA/health", match_info={"infohash": "AA"})
request.context = [endpoint.mds]

response = await endpoint.get_torrent_health(TorrentHealthRequest({}, "AA", endpoint.mds))
response = await endpoint.get_torrent_health(request)
response_body_json = await response_to_json(response)

self.assertEqual(200, response.status)
Expand All @@ -168,8 +106,10 @@ async def test_get_torrent_health(self) -> None:
endpoint = DatabaseEndpoint()
check_torrent_health = AsyncMock()
endpoint.torrent_checker = Mock(check_torrent_health=check_torrent_health)
request = MockRequest("/metadata/torrents/AA/health", match_info={"infohash": "AA"})
request.context = [endpoint.mds]

response = await endpoint.get_torrent_health(TorrentHealthRequest({}, "AA", endpoint.mds))
response = await endpoint.get_torrent_health(request)
response_body_json = await response_to_json(response)

self.assertEqual(200, response.status)
Expand Down Expand Up @@ -228,8 +168,10 @@ async def test_get_popular_torrents(self) -> None:
tdef=Mock(infohash="AA"))
endpoint.download_manager = Mock(get_download=Mock(return_value=download), metainfo_requests=[])
endpoint.mds = Mock(get_entries=Mock(return_value=[Mock(to_simple_dict=Mock(return_value=metadata))]))
request = MockRequest("/api/metadata/torrents/popular")
request.context = [endpoint.mds]

response = await endpoint.get_popular_torrents(PopularTorrentsRequest(metadata, endpoint.mds))
response = await endpoint.get_popular_torrents(request)
response_body_json = await response_to_json(response)
response_results = response_body_json["results"][0]

Expand All @@ -249,8 +191,10 @@ async def test_local_search_bad_query(self) -> None:
Test if a bad value leads to a bad request status.
"""
endpoint = DatabaseEndpoint()
request = MockRequest("/api/metadata/search/local", query={"fts_text": "", "first": "bla"})
request.context = [endpoint.mds]

response = await endpoint.local_search(SearchLocalRequest({"first": "bla"}, endpoint.mds))
response = await endpoint.local_search(request)

self.assertEqual(HTTP_BAD_REQUEST, response.status)

Expand All @@ -262,8 +206,10 @@ async def test_local_search_errored_search(self) -> None:
"""
endpoint = DatabaseEndpoint()
endpoint.tribler_db = Mock()
request = MockRequest("/api/metadata/search/local", query={"fts_text": ""})
request.context = [endpoint.mds]

response = await endpoint.local_search(SearchLocalRequest({}, endpoint.mds))
response = await endpoint.local_search(request)

self.assertEqual(HTTP_BAD_REQUEST, response.status)

Expand All @@ -276,8 +222,10 @@ async def test_local_search_no_knowledge(self) -> None:
endpoint.mds = Mock(run_threaded=self.mds_run_now, get_total_count=Mock(), get_max_rowid=Mock(),
get_entries=Mock(return_value=[Mock(to_simple_dict=Mock(return_value={"test": "test",
"type": -1}))]))
request = MockRequest("/api/metadata/search/local", query={"fts_text": ""})
request.context = [endpoint.mds]

response = await endpoint.local_search(SearchLocalRequest({}, endpoint.mds))
response = await endpoint.local_search(request)
response_body_json = await response_to_json(response)

self.assertEqual(200, response.status)
Expand All @@ -297,8 +245,11 @@ async def test_local_search_no_knowledge_include_total(self) -> None:
get_max_rowid=Mock(return_value=7),
get_entries=Mock(return_value=[Mock(to_simple_dict=Mock(return_value={"test": "test",
"type": -1}))]))
request = MockRequest("/api/metadata/search/local", query={"fts_text": "",
"include_total": "I would like this"})
request.context = [endpoint.mds]

response = await endpoint.local_search(SearchLocalRequest({"include_total": "I would like this"}, endpoint.mds))
response = await endpoint.local_search(request)
response_body_json = await response_to_json(response)

self.assertEqual(200, response.status)
Expand All @@ -315,8 +266,10 @@ async def test_completions_bad_query(self) -> None:
Test if a missing query leads to a bad request status.
"""
endpoint = DatabaseEndpoint()
request = MockRequest("/api/metadata/search/completions")
request.context = [endpoint.mds]

response = await endpoint.completions(SearchCompletionsRequest({}, endpoint.mds))
response = await endpoint.completions(request)

self.assertEqual(HTTP_BAD_REQUEST, response.status)

Expand All @@ -326,8 +279,10 @@ async def test_completions_lowercase_search(self) -> None:
"""
endpoint = DatabaseEndpoint()
endpoint.mds = Mock(get_auto_complete_terms=Mock(return_value=["test1", "test2"]))
request = MockRequest("/api/metadata/search/completions", query={"q": "test"})
request.context = [endpoint.mds]

response = await endpoint.completions(SearchCompletionsRequest({"q": "test"}, endpoint.mds))
response = await endpoint.completions(request)
response_body_json = await response_to_json(response)

self.assertEqual(200, response.status)
Expand All @@ -340,8 +295,10 @@ async def test_completions_mixed_case_search(self) -> None:
"""
endpoint = DatabaseEndpoint()
endpoint.mds = Mock(get_auto_complete_terms=Mock(return_value=["test1", "test2"]))
request = MockRequest("/api/metadata/search/completions", query={"q": "TeSt"})
request.context = [endpoint.mds]

response = await endpoint.completions(SearchCompletionsRequest({"q": "TeSt"}, endpoint.mds))
response = await endpoint.completions(request)
response_body_json = await response_to_json(response)

self.assertEqual(200, response.status)
Expand Down
Loading

0 comments on commit 4da9ced

Please sign in to comment.