Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cursor pagination of get_all_users in /admin/users route #1441

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions memgpt/client/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ def __init__(self, base_url: str, token: str):
self.token = token
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}

def get_users(self):
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers)
def get_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50):
payload = {"cursor": str(cursor) if cursor else None, "limit": limit}
response = requests.get(f"{self.base_url}/admin/users", headers=self.headers, json=payload)
if response.status_code != 200:
raise HTTPError(response.json())
return GetAllUsersResponse(**response.json())
Expand Down
17 changes: 13 additions & 4 deletions memgpt/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
String,
TypeDecorator,
create_engine,
desc,
func,
)
from sqlalchemy.dialects.postgresql import UUID
Expand Down Expand Up @@ -655,11 +656,19 @@ def get_user(self, user_id: uuid.UUID) -> Optional[User]:
return results[0].to_record()

@enforce_types
def get_all_users(self) -> List[User]:
# TODO make paginated
def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]):
with self.session_maker() as session:
results = session.query(UserModel).all()
return [r.to_record() for r in results]
query = session.query(UserModel).order_by(desc(UserModel.id))
if cursor:
query = query.filter(UserModel.id < cursor)
results = query.limit(limit).all()
if not results:
return None, []
user_records = [r.to_record() for r in results]
next_cursor = user_records[-1].id
assert isinstance(next_cursor, uuid.UUID)

return next_cursor, user_records

@enforce_types
def get_source(
Expand Down
12 changes: 9 additions & 3 deletions memgpt/server/rest_api/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
router = APIRouter()


class GetAllUsersRequest(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor to which to start the paginated request.")
limit: Optional[int] = Field(50, description="Maximum number of users to retrieve per page.")


class GetAllUsersResponse(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor for the next page in the response.")
user_list: List[dict] = Field(..., description="A list of users.")


Expand Down Expand Up @@ -54,18 +60,18 @@ class DeleteUserResponse(BaseModel):

def setup_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
def get_all_users():
def get_all_users(request: GetAllUsersRequest = Body(...)):
"""
Get a list of all users in the database
"""
try:
users = server.ms.get_all_users()
next_cursor, users = server.ms.get_all_users(request.cursor, request.limit)
processed_users = [{"user_id": user.id} for user in users]
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return GetAllUsersResponse(user_list=processed_users)
return GetAllUsersResponse(cursor=next_cursor, user_list=processed_users)

@router.post("/users", tags=["admin"], response_model=CreateUserResponse)
def create_user(request: Optional[CreateUserRequest] = Body(None)):
Expand Down
51 changes: 51 additions & 0 deletions tests/test_admin_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,54 @@ def test_admin_client(admin_client):
# list users
users = admin_client.get_users()
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"


def test_get_users_pagination(admin_client):
_reset_config()

page_size = 5
num_users = 7
expected_users_remainder = num_users - page_size

# create users
all_user_ids = []
for i in range(num_users):

user_id = uuid.uuid4()
all_user_ids.append(user_id)
key_name = "test_key" + f"{i}"

create_user_response = admin_client.create_user(user_id)
admin_client.create_key(create_user_response.user_id, key_name)

# list users in page 1
get_all_users_response1 = admin_client.get_users(limit=page_size)
cursor1 = get_all_users_response1.cursor
user_list1 = get_all_users_response1.user_list
assert len(user_list1) == page_size

# list users in page 2 using cursor
get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
cursor2 = get_all_users_response2.cursor
user_list2 = get_all_users_response2.user_list

assert len(user_list2) == expected_users_remainder
assert cursor1 != cursor2

# delete users
clean_up_users_and_keys(all_user_ids)

# list users to check pagination with no users
users = admin_client.get_users()
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"


def clean_up_users_and_keys(user_id_list):
admin_client = Admin(test_base_url, test_server_token)

# clean up all keys and users
for user_id in user_id_list:
keys_list = admin_client.get_keys(user_id)
for key in keys_list:
admin_client.delete_key(key)
admin_client.delete_user(user_id)
Loading