11
11
import hashlib
12
12
import hmac
13
13
import json
14
+ import aiohttp
15
+ import urllib .parse
14
16
15
17
from aiohttp import ClientConnectionError
16
18
@@ -37,6 +39,27 @@ def __init__(self):
37
39
super ().__init__ ("You may only replace your puppet with your own Matrix account." )
38
40
39
41
42
+ class CouldNotDetermineHomeServerURL (CustomPuppetError ):
43
+ """
44
+ Will be raised when any are true:
45
+ - .well-known/matrix/client returns 200 with mangled JSON body
46
+ - .well-known's JSON key [""m.homeserver"]["base_url"] does not exist
47
+ - .well-known's JSON key [""m.homeserver"]["base_url"] is not a valid URL
48
+ - .well-known's supplied homeserver URL, or the base domain URL, errors when validating it's version endpoint
49
+
50
+ This is in accordance with: https://matrix.org/docs/spec/client_server/r0.6.1#id178
51
+ """
52
+
53
+ def __init__ (self , domain : str ):
54
+ super ().__init__ (f"Could not discover a valid homeserver URL from domain { domain } " )
55
+
56
+
57
+ class OnlyLoginLocalDomain (CustomPuppetError ):
58
+ """Will be raised when CustomPuppetMixin.allow_external_custom_puppets is set to False"""
59
+ def __init__ (self , domain : str ):
60
+ super ().__init__ (f"You may only replace your puppet with an account from { domain } " )
61
+
62
+
40
63
class CustomPuppetMixin (ABC ):
41
64
"""
42
65
Mixin for the Puppet class to enable Matrix puppeting.
@@ -63,6 +86,7 @@ class CustomPuppetMixin(ABC):
63
86
"""
64
87
65
88
sync_with_custom_puppets : bool = True
89
+ allow_external_custom_puppets : bool = False
66
90
only_handle_own_synced_events : bool = True
67
91
login_shared_secret : Optional [bytes ] = None
68
92
login_device_name : Optional [str ] = None
@@ -78,6 +102,7 @@ class CustomPuppetMixin(ABC):
78
102
default_mxid_intent : IntentAPI
79
103
custom_mxid : Optional [UserID ]
80
104
access_token : Optional [str ]
105
+ base_url : Optional [str ]
81
106
next_batch : Optional [SyncToken ]
82
107
83
108
intent : IntentAPI
@@ -99,9 +124,63 @@ def is_real_user(self) -> bool:
99
124
return bool (self .custom_mxid and self .access_token )
100
125
101
126
def _fresh_intent (self ) -> IntentAPI :
102
- return (self .az .intent .user (self .custom_mxid , self .access_token )
127
+ return (self .az .intent .user (self .custom_mxid , self .access_token , self . base_url )
103
128
if self .is_real_user else self .default_mxid_intent )
104
129
130
+ async def _discover_homeserver_endpoint (self , domain : str ) -> str :
131
+ domain_is_valid = False
132
+
133
+ async def validate_versions_api (base_url : str ) -> bool :
134
+
135
+ async with self .az .http_session .get (urllib .parse .urljoin (base_url , "_matrix/client/versions" )) as response :
136
+ if response .status != 200 :
137
+ return False
138
+
139
+ try :
140
+ obj = await response .json (content_type = None )
141
+ if len (obj ["versions" ]) > 1 :
142
+ return True
143
+ except (KeyError , json .JSONDecodeError ):
144
+ return False
145
+
146
+ async def get_well_known_homeserver_base_url (probable_domain : str ) -> Optional [str ]:
147
+ async with self .az .http_session .get (f"https://{ probable_domain } /.well-known/matrix/client" ) as response :
148
+ if response .status != 200 :
149
+ return None
150
+
151
+ try :
152
+ obj = await response .json (content_type = None )
153
+ return obj ["m.homeserver" ]["base_url" ]
154
+ except (KeyError , json .JSONDecodeError ) as e :
155
+ raise CouldNotDetermineHomeServerURL (domain ) from e
156
+
157
+ try :
158
+ if await validate_versions_api (f"https://{ domain } " ):
159
+ # Flag front domain as valid, but keep looking
160
+ domain_is_valid = True
161
+ except aiohttp .ClientError :
162
+ pass
163
+
164
+ try :
165
+ base_url = await get_well_known_homeserver_base_url (domain )
166
+
167
+ if base_url is None :
168
+ if domain_is_valid :
169
+ # If we found a valid domain already, we just return that
170
+ return f"https://{ domain } "
171
+ else :
172
+ raise CouldNotDetermineHomeServerURL (domain )
173
+
174
+ if await validate_versions_api (base_url ):
175
+ return base_url
176
+ elif await validate_versions_api (base_url + "/" ):
177
+ return base_url + "/"
178
+ except aiohttp .ClientError as e :
179
+ if domain_is_valid :
180
+ # Earlier we already found a valid domain, so we ignore the error and return the base domain instead
181
+ return f"https://{ domain } "
182
+ raise CouldNotDetermineHomeServerURL (domain ) from e
183
+
105
184
@classmethod
106
185
def can_auto_login (cls , mxid : UserID ) -> bool :
107
186
if not cls .login_shared_secret :
@@ -131,7 +210,8 @@ async def _login_with_shared_secret(cls, mxid: UserID) -> Optional[str]:
131
210
data = await resp .json ()
132
211
return data ["access_token" ]
133
212
134
- async def switch_mxid (self , access_token : Optional [str ], mxid : Optional [UserID ]) -> None :
213
+ async def switch_mxid (self , access_token : Optional [str ], mxid : Optional [UserID ],
214
+ base_url : Optional [str ] = None ) -> None :
135
215
"""
136
216
Switch to a real Matrix user or away from one.
137
217
@@ -140,15 +220,28 @@ async def switch_mxid(self, access_token: Optional[str], mxid: Optional[UserID])
140
220
the appservice-owned ID.
141
221
mxid: The expected Matrix user ID of the custom account, or ``None`` when
142
222
``access_token`` is None.
223
+ base_url: An optional base URL to direct API calls to. If ``None``, and ``mxid`` is not ``None``,
224
+ and ``mxid`` ``server_part`` is the not the appservice domain, autodiscovery is tried.
143
225
"""
144
226
if access_token == "auto" :
145
227
access_token = await self ._login_with_shared_secret (mxid )
146
228
if not access_token :
147
229
raise ValueError ("Failed to log in with shared secret" )
148
230
self .log .debug (f"Logged in for { mxid } using shared secret" )
231
+
232
+ if mxid is not None :
233
+ mxid_domain = self .az .intent .parse_user_id (mxid )[1 ]
234
+ if mxid_domain != self .az .domain :
235
+ if not self .allow_external_custom_puppets :
236
+ raise OnlyLoginLocalDomain (self .az .domain )
237
+ elif base_url is None :
238
+ # This can throw CouldNotDetermineHomeServerURL
239
+ base_url = await self ._discover_homeserver_endpoint (mxid_domain )
240
+
149
241
prev_mxid = self .custom_mxid
150
242
self .custom_mxid = mxid
151
243
self .access_token = access_token
244
+ self .base_url = base_url
152
245
self .intent = self ._fresh_intent ()
153
246
154
247
await self .start ()
0 commit comments