Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions firebase_admin/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ def set_custom_user_claims(uid, custom_claims, app=None):
FirebaseError: If an error occurs while updating the user account.
"""
user_manager = _get_auth_service(app).user_manager
if custom_claims is None:
custom_claims = DELETE_ATTRIBUTE
user_manager.update_user(uid, custom_claims=custom_claims)


Expand Down
6 changes: 3 additions & 3 deletions integration/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,9 @@ def test_update_custom_user_claims(new_user):
def test_disable_user(new_user_with_params):
user = auth.update_user(
new_user_with_params.uid,
display_name=None,
photo_url=None,
phone_number=None,
display_name=auth.DELETE_ATTRIBUTE,
photo_url=auth.DELETE_ATTRIBUTE,
phone_number=auth.DELETE_ATTRIBUTE,
disabled=True)
assert user.uid == new_user_with_params.uid
assert user.email == new_user_with_params.email
Expand Down
5 changes: 3 additions & 2 deletions tests/test_user_mgt.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,10 @@ def test_set_custom_user_claims_str(self, user_mgt_app):
request = json.loads(recorder[0].body.decode())
assert request == {'localId' : 'testuser', 'customAttributes' : claims}

def test_set_custom_user_claims_remove(self, user_mgt_app):
@pytest.mark.parametrize('claims', [None, auth.DELETE_ATTRIBUTE])
def test_set_custom_user_claims_remove(self, user_mgt_app, claims):
_, recorder = _instrument_user_manager(user_mgt_app, 200, '{"localId":"testuser"}')
auth.set_custom_user_claims('testuser', auth.DELETE_ATTRIBUTE, app=user_mgt_app)
auth.set_custom_user_claims('testuser', claims, app=user_mgt_app)
request = json.loads(recorder[0].body.decode())
assert request == {'localId' : 'testuser', 'customAttributes' : json.dumps({})}

Expand Down