-
Notifications
You must be signed in to change notification settings - Fork 10
/
in_memory.py
86 lines (66 loc) · 2.25 KB
/
in_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, Iterable, List
from haystack import default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage
from haystack_experimental.chat_message_stores.types import ChatMessageStore
# Module-level logger obtained from haystack's logging wrapper (imported as `logging` above).
# NOTE(review): not referenced by any code currently in this module.
logger = logging.getLogger(__name__)
class InMemoryChatMessageStore(ChatMessageStore):
    """
    Stores chat messages in-memory.

    Messages are kept in a plain Python list on this instance. They are not part of the
    serialized form produced by `to_dict`, so they do not survive a to_dict/from_dict round trip.
    """

    def __init__(self):
        """
        Initializes the InMemoryChatMessageStore with an empty message list.
        """
        self.messages: List[ChatMessage] = []

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        No init parameters are passed to `default_to_dict`, so the stored messages are not
        included in the result.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "InMemoryChatMessageStore":
        """
        Deserializes the component from a dictionary.

        `__init__` takes no parameters and always starts with an empty message list, so a
        store rebuilt this way holds no messages.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        return default_from_dict(cls, data)

    def count_messages(self) -> int:
        """
        Returns the number of chat messages stored.

        :returns: The number of messages.
        """
        return len(self.messages)

    def write_messages(self, messages: Iterable[ChatMessage]) -> int:
        """
        Writes chat messages to the ChatMessageStore.

        :param messages: The ChatMessages to write. Typically a list, but any iterable
            (tuple, generator, ...) is accepted and is consumed exactly once.
        :returns: The number of messages written.
        :raises ValueError: If messages is not an iterable of ChatMessages.
        """
        if not isinstance(messages, Iterable):
            raise ValueError("Please provide a list of ChatMessages.")
        # Materialize once: validating a one-shot iterable (e.g. a generator) with any()
        # would otherwise exhaust it before extend(), and len() would fail on it.
        batch = list(messages)
        if any(not isinstance(message, ChatMessage) for message in batch):
            raise ValueError("Please provide a list of ChatMessages.")
        # Validation runs before mutation, so an invalid batch leaves the store unchanged.
        self.messages.extend(batch)
        return len(batch)

    def delete_messages(self) -> None:
        """
        Deletes all stored chat messages.
        """
        self.messages = []

    def retrieve(self) -> List[ChatMessage]:
        """
        Retrieves all stored chat messages.

        :returns: A new list of the stored chat messages, in insertion order. Mutating the
            returned list does not affect the store; the ChatMessage objects themselves
            are shared, not copied.
        """
        return list(self.messages)