18
18
from attr import dataclass
19
19
import asyncpg
20
20
21
- from mautrix .types import RoomID , ContentURI
21
+ from mautrix .types import RoomID , ContentURI , UserID
22
22
from mautrix .util .async_db import Database
23
23
24
24
fake_db = Database ("" ) if TYPE_CHECKING else None
@@ -37,22 +37,23 @@ class Portal:
37
37
encrypted : bool
38
38
name_set : bool
39
39
avatar_set : bool
40
+ relay_user_id : Optional [UserID ]
40
41
41
42
async def insert (self ) -> None :
42
43
q = ("INSERT INTO portal (thread_id, receiver, other_user_pk, mxid, name, avatar_url, "
43
- " encrypted, name_set, avatar_set) "
44
- "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" )
44
+ " encrypted, name_set, avatar_set, relay_user_id ) "
45
+ "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10 )" )
45
46
await self .db .execute (q , self .thread_id , self .receiver , self .other_user_pk ,
46
47
self .mxid , self .name , self .avatar_url , self .encrypted ,
47
- self .name_set , self .avatar_set )
48
+ self .name_set , self .avatar_set , self . relay_user_id )
48
49
49
50
async def update (self ) -> None :
50
51
q = ("UPDATE portal SET other_user_pk=$3, mxid=$4, name=$5, avatar_url=$6, encrypted=$7,"
51
- " name_set=$8, avatar_set=$9 "
52
+ " name_set=$8, avatar_set=$9, relay_user_id=$10 "
52
53
"WHERE thread_id=$1 AND receiver=$2" )
53
54
await self .db .execute (q , self .thread_id , self .receiver , self .other_user_pk ,
54
55
self .mxid , self .name , self .avatar_url , self .encrypted ,
55
- self .name_set , self .avatar_set )
56
+ self .name_set , self .avatar_set , self . relay_user_id )
56
57
57
58
@classmethod
58
59
def _from_row (cls , row : asyncpg .Record ) -> 'Portal' :
@@ -61,7 +62,7 @@ def _from_row(cls, row: asyncpg.Record) -> 'Portal':
61
62
@classmethod
62
63
async def get_by_mxid (cls , mxid : RoomID ) -> Optional ['Portal' ]:
63
64
q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
64
- " name_set, avatar_set "
65
+ " name_set, avatar_set, relay_user_id "
65
66
"FROM portal WHERE mxid=$1" )
66
67
row = await cls .db .fetchrow (q , mxid )
67
68
if not row :
@@ -72,7 +73,7 @@ async def get_by_mxid(cls, mxid: RoomID) -> Optional['Portal']:
72
73
async def get_by_thread_id (cls , thread_id : str , receiver : int ,
73
74
rec_must_match : bool = True ) -> Optional ['Portal' ]:
74
75
q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
75
- " name_set, avatar_set "
76
+ " name_set, avatar_set, relay_user_id "
76
77
"FROM portal WHERE thread_id=$1 AND receiver=$2" )
77
78
if not rec_must_match :
78
79
q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
@@ -86,23 +87,23 @@ async def get_by_thread_id(cls, thread_id: str, receiver: int,
86
87
@classmethod
87
88
async def find_private_chats_of (cls , receiver : int ) -> List ['Portal' ]:
88
89
q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
89
- " name_set, avatar_set "
90
+ " name_set, avatar_set, relay_user_id "
90
91
"FROM portal WHERE receiver=$1 AND other_user_pk IS NOT NULL" )
91
92
rows = await cls .db .fetch (q , receiver )
92
93
return [cls ._from_row (row ) for row in rows ]
93
94
94
95
@classmethod
95
96
async def find_private_chats_with (cls , other_user : int ) -> List ['Portal' ]:
96
97
q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
97
- " name_set, avatar_set "
98
+ " name_set, avatar_set, relay_user_id "
98
99
"FROM portal WHERE other_user_pk=$1" )
99
100
rows = await cls .db .fetch (q , other_user )
100
101
return [cls ._from_row (row ) for row in rows ]
101
102
102
103
@classmethod
103
104
async def all_with_room (cls ) -> List ['Portal' ]:
104
105
q = ("SELECT thread_id, receiver, other_user_pk, mxid, name, avatar_url, encrypted, "
105
- " name_set, avatar_set "
106
+ " name_set, avatar_set, relay_user_id "
106
107
"FROM portal WHERE mxid IS NOT NULL" )
107
108
rows = await cls .db .fetch (q )
108
109
return [cls ._from_row (row ) for row in rows ]
0 commit comments