Skip to content

Commit

Permalink
Client: when we have binary input and we refresh token, reset the inp…
Browse files Browse the repository at this point in the history
…ut stream

for snarfed/bridgy#1670 maybe?
  • Loading branch information
snarfed committed Apr 30, 2024
1 parent 7b55ca9 commit 395555a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 3 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ Here's how to package, test, and ship a new release.
* Fix websocket subscription server hang with blocking server XRPC methods due to exhausting worker thread pool ([#8](https://github.com/snarfed/lexrpc/issues/8)).
* Add `truncate` kwarg to `Client` and `Server` constructors to automatically truncate (ellipsize) string values that are longer than their ``maxGraphemes`` or ``maxLength`` in their lexicon. Defaults to `False`.
* `Client`:
* Bug fix: when input is a binary stream and we refresh the token, reset the input stream so it still works for the retry of the original call.
### 0.6 - 2024-03-16
Expand Down
8 changes: 7 additions & 1 deletion lexrpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* asyncio support for subscription websockets
"""
import copy
from io import BytesIO
from io import BufferedRandom, BytesIO
import json
import logging
from urllib.parse import urljoin
Expand Down Expand Up @@ -179,6 +179,10 @@ def loggable(val):
# query or procedure
fn = requests.get if type == 'query' else requests.post
logger.debug(f'Running requests.{fn} {url} {loggable(input)} {params_str} {log_headers}')

if input and not isinstance(input, (dict, str, bytes)):
input = BufferedRandom(input)

resp = fn(
url,
json=input if input and isinstance(input, dict) else None,
Expand Down Expand Up @@ -208,6 +212,8 @@ def loggable(val):
elif not resp.ok: # token expired, try to refresh it
if output and output.get('error') in TOKEN_ERRORS:
self.call(REFRESH_NSID)
if isinstance(input, BufferedRandom):
input.seek(0)
return self.call(nsid, input=input, headers=req_headers, **params) # retry

resp.raise_for_status()
Expand Down
61 changes: 59 additions & 2 deletions lexrpc/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Unit tests for client.py."""
from io import BytesIO
import json
from unittest import skip, TestCase
from unittest.mock import call, patch
from unittest.mock import ANY, call, patch
import urllib.parse

import dag_cbor
Expand All @@ -10,7 +11,7 @@
import simple_websocket

from .lexicons import LEXICONS
from .. import client, Client
from .. import base, client, Client

HEADERS = {
**client.DEFAULT_HEADERS,
Expand Down Expand Up @@ -445,3 +446,59 @@ def test_binary_data(self, mock_post):
**client.DEFAULT_HEADERS,
'Content-Type': 'foo/bar',
})

@patch('requests.post')
def test_binary_input_stream_with_refresh_token(self, mock_post):
session = {
'accessJwt': 'new-towkin',
'refreshJwt': 'reephrush',
'handle': 'handull',
'did': 'dyd',
}
post_resps = [
response(status=400, body={ # io.example.encodings, fails auth
'error': 'ExpiredToken',
'message': 'Token has expired'
}),
response(session), # refreshSession
response({'ok': 'ok'}), # io.example.encodings retry, succeeds
]

cur = -1
def check_posts(url, data=None, **kwargs):
if url.endswith('/io.example.encodings'):
# consume data stream to test that we reset it for the retry
self.assertEqual(b'foo bar', data.read())
nonlocal cur
cur += 1
return post_resps[cur]

mock_post.side_effect = check_posts

cli = Client('http://ser.ver', lexicons=LEXICONS + base._bundled_lexicons,
access_token='towkin', refresh_token='reephrush')
input = BytesIO(b'foo bar')
resp = cli.io.example.encodings(input, headers={'Content-Type': 'foo/bar'})
self.assertEqual({'ok': 'ok'}, resp)
self.assertEqual(session, cli.session)

mock_post.assert_has_calls([
call('http://ser.ver/xrpc/io.example.encodings',
json=None, data=ANY, headers={
**client.DEFAULT_HEADERS,
'Content-Type': 'foo/bar',
'Authorization': 'Bearer towkin'
}),
call('http://ser.ver/xrpc/com.atproto.server.refreshSession',
json=None, data=None, headers={
**client.DEFAULT_HEADERS,
'Content-Type': 'application/json',
'Authorization': 'Bearer reephrush',
}),
call('http://ser.ver/xrpc/io.example.encodings',
json=None, data=ANY, headers={
**client.DEFAULT_HEADERS,
'Content-Type': 'foo/bar',
'Authorization': 'Bearer new-towkin'
}),
], any_order=True)

0 comments on commit 395555a

Please sign in to comment.