Skip to content

Commit

Permalink
refactor models.get_original* queries to cache more and query less
Browse files Browse the repository at this point in the history
for #1149
  • Loading branch information
snarfed committed Sep 20, 2024
1 parent 0792ac1 commit 5330dca
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 91 deletions.
1 change: 0 additions & 1 deletion flask_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import common

logger = logging.getLogger(__name__)
# logging.getLogger('lexrpc').setLevel(logging.INFO)
logging.getLogger('negotiator').setLevel(logging.WARNING)

app_dir = Path(__file__).parent
Expand Down
10 changes: 5 additions & 5 deletions ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,9 @@ def translate_user_id(*, id, from_, to):
if user:
if copy := user.get_copy(to):
return copy
if orig := models.get_original(id):
if isinstance(orig, to):
return orig.key.id()
if orig := models.get_original_user_key(id):
if orig.kind() == to._get_kind():
return orig.id()

match from_.LABEL, to.LABEL:
case _, 'atproto' | 'nostr':
Expand Down Expand Up @@ -327,8 +327,8 @@ def translate_object_id(*, id, from_, to):
if obj := from_.load(id, remote=False):
if copy := obj.get_copy(to):
return copy
if orig := models.get_original(id):
return orig.key.id()
if orig := models.get_original_object_key(id):
return orig.id()

match from_.LABEL, to.LABEL:
case _, 'atproto' | 'nostr':
Expand Down
109 changes: 42 additions & 67 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def _run():
return user

else:
if orig := get_original(id):
if orig_key := get_original_user_key(id):
orig = orig_key.get()
if orig.status and not allow_opt_out:
return None
orig.existing = False
Expand Down Expand Up @@ -1212,61 +1213,47 @@ def resolve_ids(self):
return

inner_obj = outer_obj['object'] = as1.get_object(outer_obj)
fields = ['actor', 'author', 'inReplyTo']

# collect relevant ids
ids = [inner_obj.get('id')]
for obj in outer_obj, inner_obj:
for tag in as1.get_objects(obj, 'tags'):
if tag.get('objectType') == 'mention':
ids.append(tag.get('url'))
for field in fields:
for val in as1.get_objects(obj, field):
ids.append(val.get('id'))

ids = util.trim_nulls(ids)
if not ids:
return

# batch lookup matching users
origs = {} # maps str copy URI to str original URI
for obj in get_originals(tuple(ids)):
for copy in obj.copies:
if copy.protocol in (self_proto.LABEL, self_proto.ABBREV):
origs[copy.uri] = obj.key.id()

logger.debug(f'Resolving {self_proto.LABEL} ids; originals: {origs}')
replaced = False

def replace(val):
def replace(val, orig_fn):
id = val.get('id') if isinstance(val, dict) else val
orig = origs.get(id)
if not id:
return id

orig = orig_fn(id)
if not orig:
return val

nonlocal replaced
replaced = True
if isinstance(val, dict) and val.keys() > {'id'}:
val['id'] = orig
logger.debug(f'Resolved copy id {val} to original {orig.id()}')

if isinstance(val, dict) and util.trim_nulls(val).keys() > {'id'}:
val['id'] = orig.id()
return val
else:
return orig
return orig.id()

# actually replace ids
#
# object field could be either object (eg repost) or actor (eg follow)
outer_obj['object'] = replace(inner_obj, get_original_object_key)
if not replaced:
outer_obj['object'] = replace(inner_obj, get_original_user_key)

for obj in outer_obj, inner_obj:
for tag in as1.get_objects(obj, 'tags'):
if tag.get('objectType') == 'mention':
tag['url'] = replace(tag.get('url'))
for field in fields:
obj[field] = [replace(val) for val in util.get_list(obj, field)]
tag['url'] = replace(tag.get('url'), get_original_user_key)
for field, fn in (
('actor', get_original_user_key),
('author', get_original_user_key),
('inReplyTo', get_original_object_key),
):
obj[field] = [replace(val, fn) for val in util.get_list(obj, field)]
if len(obj[field]) == 1:
obj[field] = obj[field][0]

outer_obj['object'] = replace(inner_obj)

if util.trim_nulls(outer_obj['object']).keys() == {'id'}:
outer_obj['object'] = outer_obj['object']['id']

if replaced:
self.our_as1 = util.trim_nulls(outer_obj)

Expand Down Expand Up @@ -1639,47 +1626,35 @@ def get_paging_param(param):
return results, new_before, new_after


def get_original(copy_id, keys_only=None):
"""Fetches a user or object with a given id in copies.
Thin wrapper around :func:`get_copies` that returns the first
matching result.
Also see :Object:`get_copy` and :User:`get_copy`.
@lru_cache(maxsize=100000)
def get_original_object_key(copy_id):
"""Finds the :class:`Object` with a given copy id, if any.
Args:
copy_id (str)
keys_only (bool): passed through to :class:`google.cloud.ndb.Query`
Returns:
User or Object:
google.cloud.ndb.Key or None
"""
got = get_originals((copy_id,), keys_only=keys_only)
if got:
return got[0]
assert copy_id

return Object.query(Object.copies.uri == copy_id).get(keys_only=True)

@lru_cache(maxsize=10000)
def get_originals(copy_ids, keys_only=None):
"""Fetches users (across all protocols) for a given set of copies.

Also see :Object:`get_copy` and :User:`get_copy`.
@lru_cache(maxsize=100000)
def get_original_user_key(copy_id):
"""Finds the user with a given copy id, if any.
Args:
copy_ids (tuple (not list!) of str)
keys_only (bool): passed through to :class:`google.cloud.ndb.Query`
copy_id (str)
not_proto (Protocol): optional, don't query this protocol
Returns:
sequence of User and/or Object
google.cloud.ndb.Key or None
"""
assert copy_ids

classes = set(cls for cls in PROTOCOLS.values() if cls and cls.LABEL != 'ui')
classes.add(Object)
assert copy_id

return list(itertools.chain(*(
cls.query(cls.copies.uri.IN(copy_ids)).iter(keys_only=keys_only)
for cls in classes)))

# TODO: default to looking up copy_ids as key ids, across protocols? is
# that useful anywhere?
for proto in PROTOCOLS.values():
if proto and proto.LABEL != 'ui' and not proto.owns_id(copy_id):
if orig := proto.query(proto.copies.uri == copy_id).get(keys_only=True):
return orig
31 changes: 17 additions & 14 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,8 @@ def test_resolve_ids_copies_follow(self):
obj.resolve_ids()
self.assert_equals(follow, obj.our_as1)

models.get_originals.cache_clear()
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()

# matching copy users
self.make_user('other:alice', cls=OtherFake,
Expand Down Expand Up @@ -1008,7 +1009,8 @@ def test_resolve_ids_copies_reply(self):
obj.resolve_ids()
self.assert_equals(reply, obj.our_as1)

models.get_originals.cache_clear()
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()

# matching copies
self.make_user('other:alice', cls=OtherFake,
Expand Down Expand Up @@ -1048,7 +1050,8 @@ def test_resolve_ids_multiple_in_reply_to(self):
obj.resolve_ids()
self.assert_equals(note, obj.our_as1)

models.get_originals.cache_clear()
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()

# matching copies
self.store_object(id='other:a',
Expand Down Expand Up @@ -1153,19 +1156,19 @@ def test_normalize_ids_reply(self):
},
}, obj.our_as1)

def test_get_originals(self):
self.assertEqual([], models.get_originals(('foo', 'did:plc:bar')))
def test_get_original_user_key(self):
self.assertIsNone(models.get_original_user_key('other:user'))
models.get_original_user_key.cache_clear()
user = self.make_user('fake:user', cls=Fake,
copies=[Target(uri='other:user', protocol='other')])
self.assertEqual(user.key, models.get_original_user_key('other:user'))

def test_get_original_object_key(self):
self.assertIsNone(models.get_original_object_key('other:post'))
models.get_original_object_key.cache_clear()
obj = self.store_object(id='fake:post',
copies=[Target(uri='other:foo', protocol='other')])
user = self.make_user('other:user', cls=OtherFake,
copies=[Target(uri='fake:bar', protocol='fake')])

self.assert_entities_equal(
[obj, user], models.get_originals(('other:foo', 'fake:bar', 'baz')))

self.assert_entities_equal(
[obj, user], models.get_originals(('other:foo', 'fake:bar', 'baz')))
copies=[Target(uri='other:post', protocol='other')])
self.assertEqual(obj.key, models.get_original_object_key('other:post'))

def test_get_copy(self):
obj = Object(id='x')
Expand Down
9 changes: 6 additions & 3 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2555,7 +2555,8 @@ def test_resolve_ids_follow(self):
copies=[Target(uri='fake:bob', protocol='fake')])

common.memcache.clear()
models.get_originals.cache_clear()
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()

obj.new = True
OtherFake.fetchable = {
Expand Down Expand Up @@ -2588,7 +2589,8 @@ def test_resolve_ids_share(self):
copies=[Target(uri='fake:post', protocol='fake')])

common.memcache.clear()
models.get_originals.cache_clear()
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()
obj.new = True

_, code = Fake.receive(obj, authed_as='fake:alice')
Expand Down Expand Up @@ -2637,7 +2639,8 @@ def test_resolve_ids_reply_mentions(self):
id='fake:post', our_as1={'foo': 9}, source_protocol='fake',
copies=[Target(uri='other:post', protocol='other')])

models.get_originals.cache_clear()
models.get_original_user_key.cache_clear()
models.get_original_object_key.cache_clear()

obj.new = True
self.assertEqual(('OK', 202),
Expand Down
3 changes: 2 additions & 1 deletion tests/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def setUp(self):

memcache.clear()
global_cache.clear()
models.get_originals.cache_clear()
models.get_original_object_key.cache_clear()
models.get_original_user_key.cache_clear()
activitypub.WEB_OPT_OUT_DOMAINS = set()

# clear datastore
Expand Down

0 comments on commit 5330dca

Please sign in to comment.