Skip to content

Commit

Permalink
Merge branch 'ossobv-fix-timezone-support'
Browse files Browse the repository at this point in the history
  • Loading branch information
mhindery committed Jul 31, 2020
2 parents 9a2073d + 8e8476d commit e123199
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 10 deletions.
6 changes: 3 additions & 3 deletions djangosaml2idp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _should_refresh(self) -> bool:
def _refresh_from_remote(self) -> bool:
try:
self.local_metadata = validate_metadata(fetch_metadata(self.remote_metadata_url))
self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata).replace(tzinfo=None)
self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata)
# Return True if it is now valid, False (+ log an error) otherwise
if now() > self.metadata_expiration_dt:
raise ValidationError(f'Remote metadata for SP {self.entity_id} was refreshed, but contains an expired validity datetime.')
Expand All @@ -92,7 +92,7 @@ def _refresh_from_remote(self) -> bool:
def _refresh_from_local(self) -> bool:
try:
# Try to extract a valid expiration datetime from the local metadata
self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata).replace(tzinfo=None)
self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata)
# Return True if it is now valid, False (+ log an error) otherwise
if now() > self.metadata_expiration_dt:
raise ValidationError(f'Local metadata for SP {self.entity_id} contains an expired validity datetime or none at all, no remote metadata found to refresh.')
Expand Down Expand Up @@ -157,7 +157,7 @@ def __str__(self):

def save(self, *args, **kwargs):
if not self.metadata_expiration_dt:
self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata).replace(tzinfo=None)
self.metadata_expiration_dt = extract_validuntil_from_metadata(self.local_metadata)
super().save(*args, **kwargs)
IDP.load(force_refresh=True)

Expand Down
3 changes: 3 additions & 0 deletions djangosaml2idp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import xml.etree.ElementTree as ET
import zlib
from xml.parsers.expat import ExpatError
from django.conf import settings
from django.utils.translation import gettext as _
import arrow
import requests
Expand Down Expand Up @@ -68,4 +69,6 @@ def extract_validuntil_from_metadata(metadata: str) -> datetime.datetime:
except Exception as e:
raise ValidationError(f'Could not extra ValidUntil timestamp from metadata: {e}')

if not settings.USE_TZ:
return metadata_expiration_dt.replace(tzinfo=None)
return metadata_expiration_dt
13 changes: 9 additions & 4 deletions tests/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
from django.contrib.auth import get_user_model
from django.utils import timezone
from djangosaml2idp.forms import ServiceProviderAdminForm

User = get_user_model()
Expand All @@ -19,7 +20,9 @@ def test_nometadata_given(self):
assert 'Either a remote metadata URL, or a local metadata xml needs to be provided.' in form.errors['__all__']

@pytest.mark.django_db
def test_valid_local_metadata(self, sp_metadata_xml):
@pytest.mark.parametrize('use_tz, tzinfo', [(True, timezone.utc), (False, None)])
def test_valid_local_metadata(self, settings, sp_metadata_xml, use_tz, tzinfo):
settings.USE_TZ = use_tz
form = ServiceProviderAdminForm({
'entity_id': 'entity-id',
'_processor': 'djangosaml2idp.processors.BaseProcessor',
Expand All @@ -36,7 +39,7 @@ def test_valid_local_metadata(self, sp_metadata_xml):
instance = form.save()
assert instance.remote_metadata_url == ''
assert instance.local_metadata == sp_metadata_xml
assert instance.metadata_expiration_dt == datetime.datetime(2021, 2, 14, 17, 43, 34)
assert instance.metadata_expiration_dt == datetime.datetime(2021, 2, 14, 17, 43, 34, tzinfo=tzinfo)

@pytest.mark.django_db
def test_invalid_local_metadata(self):
Expand All @@ -57,7 +60,9 @@ def test_invalid_local_metadata(self):

@pytest.mark.django_db
@mock.patch('requests.get')
def test_valid_remote_metadata_url(self, mock_get, sp_metadata_xml):
@pytest.mark.parametrize('use_tz, tzinfo', [(True, timezone.utc), (False, None)])
def test_valid_remote_metadata_url(self, mock_get, settings, sp_metadata_xml, use_tz, tzinfo):
settings.USE_TZ = use_tz
mock_get.return_value = mock.Mock(status_code=200, text=sp_metadata_xml)
form = ServiceProviderAdminForm({
'entity_id': 'entity-id',
Expand All @@ -75,7 +80,7 @@ def test_valid_remote_metadata_url(self, mock_get, sp_metadata_xml):
instance = form.save()
assert instance.remote_metadata_url == 'https://ok'
assert instance.local_metadata == sp_metadata_xml
assert instance.metadata_expiration_dt == datetime.datetime(2021, 2, 14, 17, 43, 34)
assert instance.metadata_expiration_dt == datetime.datetime(2021, 2, 14, 17, 43, 34, tzinfo=tzinfo)

@pytest.mark.django_db
@mock.patch('requests.get')
Expand Down
9 changes: 6 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import datetime
import xml
from unittest import mock

import arrow
import pytest
from django.core.exceptions import ValidationError
from django.utils import timezone

from djangosaml2idp.utils import (encode_saml,
extract_validuntil_from_metadata,
Expand Down Expand Up @@ -52,9 +53,11 @@ def test_extract_validuntil_from_metadata_invalid(self):
with pytest.raises(ValidationError):
extract_validuntil_from_metadata('')

def test_extract_validuntil_from_metadata_valid(self, sp_metadata_xml):
@pytest.mark.parametrize('use_tz, tzinfo', [(True, timezone.utc), (False, None)])
def test_extract_validuntil_from_metadata_valid(self, settings, sp_metadata_xml, use_tz, tzinfo):
settings.USE_TZ = use_tz
valid_until_dt_extracted = extract_validuntil_from_metadata(sp_metadata_xml)
assert valid_until_dt_extracted == arrow.get("2021-02-14T17:43:34Z")
assert valid_until_dt_extracted == datetime.datetime(2021, 2, 14, 17, 43, 34, tzinfo=tzinfo)


class TestUtils:
Expand Down

0 comments on commit e123199

Please sign in to comment.