Skip to content

Commit

Permalink
feat: AbstractModelをModelに継承させるように
Browse files Browse the repository at this point in the history
feat: UserActions.searchのdetailを使用した際の型を正確に
  • Loading branch information
yupix committed Sep 14, 2023
1 parent fecdb01 commit 2fe31ee
Show file tree
Hide file tree
Showing 31 changed files with 181 additions and 99 deletions.
10 changes: 8 additions & 2 deletions mipac/abstract/model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from mipac.manager.client import ClientManager


__all__ = ("AbstractModel",)


class AbstractModel(ABC):
@property
@abstractmethod
def action(self):
def __init__(self, data: Any, *, client: ClientManager) -> None:
pass
16 changes: 8 additions & 8 deletions mipac/actions/admins/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from mipac.abstract.action import AbstractAction
from mipac.config import config
from mipac.errors.base import NotSupportVersion, ParameterError
from mipac.errors.base import NotSupportVersion, NotSupportVersionText, ParameterError
from mipac.http import HTTPClient, Route
from mipac.models.admin import IndexStat, ModerationLog, ServerInfo, UserIP
from mipac.models.meta import AdminMeta
Expand Down Expand Up @@ -45,7 +45,7 @@ async def vacuum(self, full: bool = False, analyze: bool = False) -> bool:

async def update_user_note(self, user_id: str, text: str) -> bool:
if config.use_version < 12:
raise NotSupportVersion("ご利用のインスタンスのバージョンではサポートされていない機能です")
raise NotSupportVersion(NotSupportVersionText)
body = {"userId": user_id, "text": text}
return bool(
await self.__session.request(
Expand Down Expand Up @@ -154,7 +154,7 @@ async def get_moderation_logs(
get_all: bool = False,
) -> AsyncGenerator[ModerationLog, None]:
if config.use_version < 12:
raise NotSupportVersion("ご利用のインスタンスのバージョンではサポートされていない機能です")
raise NotSupportVersion(NotSupportVersionText)

if limit > 100:
raise ParameterError("limit must be less than 100")
Expand Down Expand Up @@ -192,7 +192,7 @@ async def send_email(self, to: str, subject: str, text: str) -> bool:

async def resolve_abuse_user_report(self, report_id: str, forward: bool = False) -> bool:
if config.use_version < 12:
raise NotSupportVersion("ご利用のインスタンスのバージョンではサポートされていない機能です")
raise NotSupportVersion(NotSupportVersionText)

body = {"reportId": report_id, "forward": forward}
return bool(
Expand All @@ -202,17 +202,17 @@ async def resolve_abuse_user_report(self, report_id: str, forward: bool = False)
)

async def reset_password(self, user_id: str) -> str:
"""指定したIDのユーザーのパスワードをリセットします
"""target user's password reset
Parameters
----------
user_id : str
パスワードをリセットする対象のユーザーID
target user's id
Returns
-------
str
新しいパスワード
new password
"""
return await self.__session.request(
Route("POST", "/api/admin/reset-password"), auth=True, json={"userId": user_id}
Expand All @@ -229,7 +229,7 @@ async def get_index_stats(self) -> list[IndexStat]:

async def get_user_ips(self, user_id: str) -> list[UserIP]:
if config.use_version < 12:
raise NotSupportVersion("ご利用のインスタンスのバージョンではサポートされていない機能です")
raise NotSupportVersion(NotSupportVersionText)

res: list[IUserIP] = await self.__session.request(
Route("POST", "/api/admin/get-user-ips"),
Expand Down
111 changes: 80 additions & 31 deletions mipac/actions/user.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
from __future__ import annotations

from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional
from typing import TYPE_CHECKING, AsyncGenerator, Literal, Optional, TypeVar, Union, overload

from mipac.config import config
from mipac.errors.base import NotExistRequiredData, NotSupportVersion, ParameterError
from mipac.errors.base import (
NotExistRequiredData,
NotSupportVersion,
NotSupportVersionText,
ParameterError,
)
from mipac.http import HTTPClient, Route
from mipac.models.clip import Clip
from mipac.models.note import Note
from mipac.models.user import Achievement, LiteUser, UserDetailed
from mipac.types.clip import IClip
from mipac.types.note import INote
from mipac.types.user import ILiteUser, IUserDetailed
from mipac.utils.cache import cache
from mipac.utils.format import remove_dict_empty
from mipac.utils.pagination import Pagination
from mipac.utils.pagination import Pagination, pagination_iterator
from mipac.utils.util import check_multi_arg

if TYPE_CHECKING:
from mipac.manager.client import ClientManager

__all__ = ["UserActions"]

T = TypeVar("T", bound=Union[LiteUser, UserDetailed])


class UserActions:
def __init__(
Expand Down Expand Up @@ -70,22 +78,23 @@ async def get(
**kwargs,
) -> UserDetailed:
"""
ユーザーのプロフィールを取得します。一度のみサーバーにアクセスしキャッシュをその後は使います。
fetch_userを使った場合はキャッシュが廃棄され再度サーバーにアクセスします。
Retrieve user information from the user ID using the cache.
If there is no cache, `fetch` is automatically used.
The `fetch` method is recommended if you want up-to-date user information.
Parameters
----------
user_id : str
取得したいユーザーのユーザーID
target user id
username : str
取得したいユーザーのユーザー名
target username
host : str, default=None
取得したいユーザーがいるインスタンスのhost
Hosts with target users
Returns
-------
UserDetailed
ユーザー情報
user information
"""

field = remove_dict_empty({"userId": user_id, "username": username, "host": host})
Expand All @@ -101,16 +110,19 @@ async def fetch(
host: str | None = None,
) -> UserDetailed:
"""
サーバーにアクセスし、ユーザーのプロフィールを取得します。基本的には get_userをお使いください。
Retrieve the latest user information using the target user ID or username.
If you do not need the latest information, you should basically use the `get` method.
This method accesses the server each time,
which may increase the number of server accesses.
Parameters
----------
user_id : str
取得したいユーザーのユーザーID
target user id
username : str
取得したいユーザーのユーザー名
target username
host : str, default=None
取得したいユーザーがいるインスタンスのhost
Hosts with target users
Returns
-------
Expand All @@ -135,7 +147,7 @@ async def get_notes(
get_all: bool = False,
) -> AsyncGenerator[Note, None]:
if check_multi_arg(user_id, self.__user) is False:
raise ParameterError("user_idがありません", user_id, self.__user)
raise ParameterError("missing required argument: user_id", user_id, self.__user)

user_id = user_id or self.__user and self.__user.id
data = {
Expand Down Expand Up @@ -174,7 +186,7 @@ def get_mention(self, user: Optional[LiteUser] = None) -> str:
Parameters
----------
user : Optional[User], default=None
メンションを取得したいユーザーのオブジェクト
The object of the user whose mentions you want to retrieve
Returns
-------
Expand All @@ -188,13 +200,39 @@ def get_mention(self, user: Optional[LiteUser] = None) -> str:
raise NotExistRequiredData("Required parameters: user")
return f"@{user.username}@{user.host}" if user.instance else f"@{user.username}"

@overload
async def search(
self,
query: str,
limit: int = 100,
offset: int = 0,
origin: Literal["local", "remote", "combined"] = "combined",
detail: bool = True,
detail: Literal[False] = ...,
*,
get_all: bool = False,
) -> AsyncGenerator[LiteUser, None]:
...

@overload
async def search(
self,
query: str,
limit: int = 100,
offset: int = 0,
origin: Literal["local", "remote", "combined"] = "combined",
detail: Literal[True] = True,
*,
get_all: bool = False,
) -> AsyncGenerator[UserDetailed, None]:
...

async def search(
self,
query: str,
limit: int = 100,
offset: int = 0,
origin: Literal["local", "remote", "combined"] = "combined",
detail: Literal[True, False] = True,
*,
get_all: bool = False,
) -> AsyncGenerator[UserDetailed | LiteUser, None]:
Expand All @@ -211,14 +249,14 @@ async def search(
The number of users to skip.
origin : Literal['local', 'remote', 'combined'], default='combined'
The origin of users to search.
detail : bool, default=True
detail : Literal[True, False], default=True
Whether to return detailed user information.
get_all : bool, default=False
Whether to return all users.
Returns
-------
AsyncGenerator[UserDetailed | LiteUser, None]
AsyncGenerator[Union[LiteUser, UserDetailed], None]
A AsyncGenerator of users.
"""

Expand All @@ -232,18 +270,29 @@ async def search(
{"query": query, "limit": limit, "offset": offset, "origin": origin, "detail": detail}
)

pagination = Pagination[UserDetailed | LiteUser](
self.__session, Route("POST", "/api/users/search"), json=body, pagination_type="count"
)

while True:
res_users = await pagination.next()
for user in res_users:
if detail:
yield UserDetailed(user, client=self.__client)
yield LiteUser(user, client=self.__client)
if get_all is False or pagination.is_final:
break
if detail:
pagination = Pagination[IUserDetailed](
self.__session,
Route("POST", "/api/users/search"),
json=body,
pagination_type="count",
)
iterator = pagination_iterator(
pagination, get_all, model=UserDetailed, client=self.__client
)
else:
pagination = Pagination[ILiteUser](
self.__session,
Route("POST", "/api/users/search"),
json=body,
pagination_type="count",
)

iterator = pagination_iterator(
pagination, get_all=get_all, model=LiteUser, client=self.__client
)
async for user in iterator:
yield user

async def search_by_username_and_host(
self,
Expand Down Expand Up @@ -295,7 +344,7 @@ async def get_achievements(self, user_id: str | None = None) -> list[Achievement
"""Get achievements of user."""

if config.use_version < 13:
raise NotSupportVersion("ご利用のインスタンスのバージョンではサポートされていない機能です")
raise NotSupportVersion(NotSupportVersionText)

user_id = user_id or self.__user and self.__user.id

Expand Down
4 changes: 2 additions & 2 deletions mipac/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Any, Literal, TypeVar

import aiohttp

from mipac import __version__
from mipac.config import config
from mipac.errors.base import APIError
from mipac.types.endpoints import ENDPOINTS
Expand All @@ -15,8 +17,6 @@
from mipac.utils.format import remove_dict_empty, upper_to_lower
from mipac.utils.util import COLORS, _from_json

from mipac import __version__

_log = logging.getLogger(__name__)


Expand Down
3 changes: 2 additions & 1 deletion mipac/manager/admins/invite.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from mipac.abstract.manager import AbstractManager
from mipac.http import HTTPClient
from mipac.actions.admins.invite import AdminInviteActions
from mipac.http import HTTPClient

if TYPE_CHECKING:
from mipac.manager.client import ClientManager
Expand Down
3 changes: 2 additions & 1 deletion mipac/models/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Literal

from mipac.abstract.model import AbstractModel
from mipac.types.ads import IAd
from mipac.utils.format import str_to_datetime

Expand All @@ -11,7 +12,7 @@
from mipac.manager.client import ClientManager


class Ad:
class Ad(AbstractModel):
def __init__(self, ad_data: IAd, *, client: ClientManager) -> None:
self.__ad_data: IAd = ad_data
self.__client: ClientManager = client
Expand Down
Loading

0 comments on commit 2fe31ee

Please sign in to comment.