Skip to content

Commit

Permalink
Added storing state on session for mfa on radius
Browse files Browse the repository at this point in the history
  • Loading branch information
dkmstr committed Oct 12, 2023
1 parent d21c371 commit de33a96
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
5 changes: 4 additions & 1 deletion server/src/uds/auths/Radius/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,10 @@ def authenticate(
) -> auths.AuthenticationResult:
try:
connection = self.radiusClient()
groups, mfaCode = connection.authenticate(username=username, password=credentials, mfaField=self.mfaAttr.value.strip())
groups, mfaCode, state = connection.authenticate(username=username, password=credentials, mfaField=self.mfaAttr.value.strip())
# If state, store in session
if state:
request.session[client.STATE_VAR_NAME] = state.decode()
# store the user mfa attribute if it is set
if mfaCode:
self.storage.putPickle(
Expand Down
8 changes: 5 additions & 3 deletions server/src/uds/auths/Radius/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
NOT_CHECKED, INCORRECT, CORRECT = -1, 0, 1 # for pwd and otp
NOT_NEEDED, NEEDED = INCORRECT, CORRECT # for otp_needed

STATE_VAR_NAME = 'radius_state'


class RadiusAuthenticationError(Exception):
pass
Expand Down Expand Up @@ -129,7 +131,7 @@ def sendAccessRequest(self, username: str, password: str, **kwargs) -> pyrad.pac
# Second element of return value is the mfa code from field
def authenticate(
self, username: str, password: str, mfaField: str = ''
) -> typing.Tuple[typing.List[str], str]:
) -> typing.Tuple[typing.List[str], str, bytes]:
reply = self.sendAccessRequest(username, password)

if reply.code not in (pyrad.packet.AccessAccept, pyrad.packet.AccessChallenge):
Expand All @@ -147,7 +149,7 @@ def authenticate(
]
else:
logger.info('No "Class (25)" attribute found')
return ([], '')
return ([], '', b'')

# ...and mfa code
mfaCode = ''
Expand All @@ -157,7 +159,7 @@ def authenticate(
for i in typing.cast(typing.Iterable[bytes], reply['Class'])
if i.startswith(groupClassPrefix)
)
return (groups, mfaCode)
return (groups, mfaCode, typing.cast(typing.List[bytes], reply.get('State') or [b''])[0])

def authenticate_only(self, username: str, password: str) -> RadiusResult:
reply = self.sendAccessRequest(username, password)
Expand Down
6 changes: 3 additions & 3 deletions server/src/uds/mfas/Radius/mfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def process(
return self.checkResult(self.allowLoginWithoutMFA.value, request)

# Store state for later use, related to this user
request.session['radius_state'] = auth_reply.state or b''
request.session[client.STATE_VAR_NAME] = auth_reply.state or b''

# correct password and otp_needed
return mfas.MFA.RESULT.OK
Expand Down Expand Up @@ -254,10 +254,10 @@ def validate(
web_pwd = webPassword(request)
try:
connection = self.radiusClient()
state = request.session.get('radius_state', b'')
state = request.session.get(client.STATE_VAR_NAME, b'')
if state:
# Remove state from session
del request.session['radius_state']
del request.session[client.STATE_VAR_NAME]
# Use state to validate
auth_reply = connection.authenticate_challenge(username, otp=code, state=state)
else: # No state, so full authentication
Expand Down
4 changes: 2 additions & 2 deletions server/src/uds/web/views/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def mfa(request: ExtendedHttpRequest) -> HttpResponse: # pylint: disable=too-ma
logger.warning('MFA: No user or user is already authorized')
return HttpResponseRedirect(reverse('page.index')) # No user, no MFA

mfaProvider: typing.Optional['models.MFA'] = request.user.manager.mfa
mfaProvider = typing.cast('None|models.MFA', request.user.manager.mfa)
if not mfaProvider:
logger.warning('MFA: No MFA provider for user')
return HttpResponseRedirect(reverse('page.index'))
Expand All @@ -187,7 +187,7 @@ def mfa(request: ExtendedHttpRequest) -> HttpResponse: # pylint: disable=too-ma

# Obtain MFA data
authInstance = request.user.manager.getInstance()
mfaInstance: 'mfas.MFA' = mfaProvider.getInstance()
mfaInstance = typing.cast('mfas.MFA', mfaProvider.getInstance())

# Get validity duration
validity = mfaProvider.validity * 60
Expand Down

0 comments on commit de33a96

Please sign in to comment.