Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support all_users & all_sender_channels for segment #164

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion stream_chat/async_chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
QuerySegmentTargetsOptions,
SegmentData,
SegmentType,
SegmentUpdatableFields,
)

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -591,7 +592,7 @@ async def query_segments(
return await self.post("segments/query", data=payload)

async def update_segment(
self, segment_id: str, data: SegmentData
self, segment_id: str, data: SegmentUpdatableFields
) -> StreamResponse:
return await self.put(f"segments/{segment_id}", data=data)

Expand Down
15 changes: 13 additions & 2 deletions stream_chat/async_chat/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from stream_chat.base.segment import SegmentInterface
from stream_chat.types.base import SortParam
from stream_chat.types.segment import QuerySegmentTargetsOptions, SegmentData
from stream_chat.types.segment import (
QuerySegmentTargetsOptions,
SegmentData,
SegmentUpdatableFields,
)
from stream_chat.types.stream_response import StreamResponse


Expand All @@ -24,24 +28,29 @@ async def create(
return state

async def get(self) -> StreamResponse:
super().verify_segment_id()
return await self.client.get_segment(segment_id=self.segment_id) # type: ignore

async def update(self, data: SegmentData) -> StreamResponse:
async def update(self, data: SegmentUpdatableFields) -> StreamResponse:
super().verify_segment_id()
return await self.client.update_segment( # type: ignore
segment_id=self.segment_id, data=data
)

async def delete(self) -> StreamResponse:
super().verify_segment_id()
return await self.client.delete_segment( # type: ignore
segment_id=self.segment_id
)

async def target_exists(self, target_id: str) -> StreamResponse:
super().verify_segment_id()
return await self.client.segment_target_exists( # type: ignore
segment_id=self.segment_id, target_id=target_id
)

async def add_targets(self, target_ids: list) -> StreamResponse:
super().verify_segment_id()
return await self.client.add_segment_targets( # type: ignore
segment_id=self.segment_id, target_ids=target_ids
)
Expand All @@ -52,6 +61,7 @@ async def query_targets(
sort: Optional[List[SortParam]] = None,
options: Optional[QuerySegmentTargetsOptions] = None,
) -> StreamResponse:
super().verify_segment_id()
return await self.client.query_segment_targets( # type: ignore
segment_id=self.segment_id,
filter_conditions=filter_conditions,
Expand All @@ -60,6 +70,7 @@ async def query_targets(
)

async def remove_targets(self, target_ids: list) -> StreamResponse:
super().verify_segment_id()
return await self.client.remove_segment_targets( # type: ignore
segment_id=self.segment_id, target_ids=target_ids
)
3 changes: 2 additions & 1 deletion stream_chat/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
QuerySegmentTargetsOptions,
SegmentData,
SegmentType,
SegmentUpdatableFields,
)

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -982,7 +983,7 @@ def query_segments(

@abc.abstractmethod
def update_segment(
self, segment_id: str, data: SegmentData
self, segment_id: str, data: SegmentUpdatableFields
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
"""
Update a segment by id
Expand Down
10 changes: 9 additions & 1 deletion stream_chat/base/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
QuerySegmentTargetsOptions,
SegmentData,
SegmentType,
SegmentUpdatableFields,
)
from stream_chat.types.stream_response import StreamResponse

Expand Down Expand Up @@ -36,7 +37,7 @@ def get(self) -> Union[StreamResponse, Awaitable[StreamResponse]]:

@abc.abstractmethod
def update(
self, data: SegmentData
self, data: SegmentUpdatableFields
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
pass

Expand Down Expand Up @@ -70,3 +71,10 @@ def remove_targets(
self, target_ids: List[str]
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
pass

def verify_segment_id(self) -> None:
if not self.segment_id:
raise ValueError(
"Segment id is missing. Either create the segment using segment.create() "
"or set the id during instantiation - segment = Segment(segment_id=segment_id)"
)
5 changes: 4 additions & 1 deletion stream_chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
QuerySegmentTargetsOptions,
SegmentData,
SegmentType,
SegmentUpdatableFields,
)

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -569,7 +570,9 @@ def query_segments(
payload.update(cast(dict, options))
return self.post("segments/query", data=payload)

def update_segment(self, segment_id: str, data: SegmentData) -> StreamResponse:
def update_segment(
self, segment_id: str, data: SegmentUpdatableFields
) -> StreamResponse:
return self.put(f"segments/{segment_id}", data=data)

def delete_segment(self, segment_id: str) -> StreamResponse:
Expand Down
15 changes: 13 additions & 2 deletions stream_chat/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from stream_chat.base.segment import SegmentInterface
from stream_chat.types.base import SortParam
from stream_chat.types.segment import QuerySegmentTargetsOptions, SegmentData
from stream_chat.types.segment import (
QuerySegmentTargetsOptions,
SegmentData,
SegmentUpdatableFields,
)
from stream_chat.types.stream_response import StreamResponse


Expand All @@ -24,22 +28,27 @@ def create(
return state # type: ignore

def get(self) -> StreamResponse:
super().verify_segment_id()
return self.client.get_segment(segment_id=self.segment_id) # type: ignore

def update(self, data: SegmentData) -> StreamResponse:
def update(self, data: SegmentUpdatableFields) -> StreamResponse:
super().verify_segment_id()
return self.client.update_segment( # type: ignore
segment_id=self.segment_id, data=data
)

def delete(self) -> StreamResponse:
super().verify_segment_id()
return self.client.delete_segment(segment_id=self.segment_id) # type: ignore

def target_exists(self, target_id: str) -> StreamResponse:
super().verify_segment_id()
return self.client.segment_target_exists( # type: ignore
segment_id=self.segment_id, target_id=target_id
)

def add_targets(self, target_ids: list) -> StreamResponse:
super().verify_segment_id()
return self.client.add_segment_targets( # type: ignore
segment_id=self.segment_id, target_ids=target_ids
)
Expand All @@ -50,6 +59,7 @@ def query_targets(
sort: Optional[List[SortParam]] = None,
options: Optional[QuerySegmentTargetsOptions] = None,
) -> StreamResponse:
super().verify_segment_id()
return self.client.query_segment_targets( # type: ignore
segment_id=self.segment_id,
sort=sort,
Expand All @@ -58,6 +68,7 @@ def query_targets(
)

def remove_targets(self, target_ids: list) -> StreamResponse:
super().verify_segment_id()
return self.client.remove_segment_targets( # type: ignore
segment_id=self.segment_id, target_ids=target_ids
)
17 changes: 15 additions & 2 deletions stream_chat/types/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class SegmentType(Enum):
USER = "user"


class SegmentData(TypedDict, total=False):
class SegmentUpdatableFields(TypedDict, total=False):
"""
Represents the data structure for a segment.
Represents the updatable data structure for a segment.

Parameters:
name: The name of the segment.
Expand All @@ -38,6 +38,19 @@ class SegmentData(TypedDict, total=False):
filter: Optional[Dict]


class SegmentData(SegmentUpdatableFields, total=False):
"""
Represents the data structure for a segment.

Parameters:
all_users: Whether to target all users.
all_sender_channels: Whether to target all sender channels.
"""

all_users: Optional[bool]
all_sender_channels: Optional[bool]


class QuerySegmentsOptions(Pager, total=False):
pass

Expand Down
Loading