Skip to content

Commit

Permalink
Make LTI http service take an LTIRegistration
Browse files Browse the repository at this point in the history
Until now all call to LTIA API happen on the context of a launch so it
made sense to always scope them to the registration of that launch.

In preparation for start fetching course rosters, likely on a celery
task, outside the context of a launch, take an explicit registration for
LTIA calls.
  • Loading branch information
marcospri committed Aug 27, 2024
1 parent 4de0d14 commit e9976f3
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 77 deletions.
9 changes: 8 additions & 1 deletion lms/services/lti_grading/_v13.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from datetime import datetime, timezone
from urllib.parse import urlparse

from lms.models import LTIRegistration
from lms.product.family import Family
from lms.product.plugin.misc import MiscPlugin
from lms.services.exceptions import ExternalRequestError, StudentNotInCourse
Expand All @@ -23,18 +24,20 @@ class LTI13GradingService(LTIGradingService):
"https://purl.imsglobal.org/spec/lti-ags/scope/score",
]

def __init__( # noqa: PLR0913
def __init__( # noqa: PLR0913, PLR0917
self,
line_item_url,
line_item_container_url,
ltia_service: LTIAHTTPService,
product_family: Family,
misc_plugin: MiscPlugin,
lti_registration: LTIRegistration,
):
super().__init__(line_item_url, line_item_container_url)
self._ltia_service = ltia_service
self._product_family = product_family
self._misc_plugin = misc_plugin
self._lti_registration = lti_registration

def read_result(self, grading_id) -> GradingResult:
result = GradingResult(score=None, comment=None)
Expand All @@ -47,6 +50,7 @@ def read_result(self, grading_id) -> GradingResult:

try:
response = self._ltia_service.request(
self._lti_registration,
"GET",
self._service_url(self.line_item_url, "/results"),
scopes=self.LTIA_SCOPES,
Expand Down Expand Up @@ -98,6 +102,7 @@ def record_result(self, grading_id, score=None, pre_record_hook=None, comment=No

try:
return self._ltia_service.request(
self._lti_registration,
"POST",
self._service_url(self.line_item_url, "/scores"),
scopes=self.LTIA_SCOPES,
Expand Down Expand Up @@ -138,6 +143,7 @@ def create_line_item(self, resource_link_id, label, score_maximum=100):
"resourceLinkId": resource_link_id,
}
return self._ltia_service.request(
self._lti_registration,
"POST",
self.line_item_container_url,
scopes=self.LTIA_SCOPES,
Expand All @@ -155,6 +161,7 @@ def _read_grading_configuration(self, resource_link_id) -> dict:
containers = []
try:
containers = self._ltia_service.request(
self._lti_registration,
"GET",
self.line_item_container_url,
scopes=self.LTIA_SCOPES,
Expand Down
1 change: 1 addition & 0 deletions lms/services/lti_grading/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def service_factory(_context, request):
ltia_service=request.find_service(LTIAHTTPService),
product_family=request.product.family,
misc_plugin=request.product.plugin.misc,
lti_registration=request.lti_user.application_instance.lti_registration,
)

return LTI11GradingService(
Expand Down
18 changes: 8 additions & 10 deletions lms/services/lti_names_roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from typing import TypedDict

from lms.models import LTIRegistration
from lms.services.ltia_http import LTIAHTTPService


Expand All @@ -31,20 +32,22 @@ class LTINamesRolesService:
"https://purl.imsglobal.org/spec/lti-nrps/scope/contextmembership.readonly"
]

def __init__(self, service_url: str, ltia_http_service: LTIAHTTPService):
self._service_url = service_url
def __init__(self, ltia_http_service: LTIAHTTPService):
self._ltia_service = ltia_http_service

def get_context_memberships(self) -> list[Member]:
def get_context_memberships(
self, lti_registration: LTIRegistration, service_url: str
) -> list[Member]:
"""
Get all the memberships of a context (a course).
The course is defined by the service URL which will obtain
from a LTI launch parameter and is always linked to an specific context.
"""
response = self._ltia_service.request(
lti_registration,
"GET",
self._service_url,
service_url,
scopes=self.LTIA_SCOPES,
headers={
"Accept": "application/vnd.ims.lti-nrps.v2.membershipcontainer+json"
Expand All @@ -55,9 +58,4 @@ def get_context_memberships(self) -> list[Member]:


def factory(_context, request):
return LTINamesRolesService(
service_url=request.lti_jwt.get(
"https://purl.imsglobal.org/spec/lti-nrps/claim/namesroleservice", {}
).get("context_memberships_url"),
ltia_http_service=request.find_service(LTIAHTTPService),
)
return LTINamesRolesService(ltia_http_service=request.find_service(LTIAHTTPService))
39 changes: 24 additions & 15 deletions lms/services/ltia_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,52 @@
class LTIAHTTPService:
"""Send LTI Advantage requests and return the responses."""

def __init__( # noqa: PLR0913
def __init__(
self,
lti_registration: LTIRegistration,
plugin: MiscPlugin,
jwt_service: JWTService,
http,
jwt_oauth2_token_service: JWTOAuth2TokenService,
):
self._lti_registration = lti_registration
self._jwt_service = jwt_service
self._http = http
self._plugin = plugin
self._jwt_oauth2_token_service = jwt_oauth2_token_service

def request(self, method, url, scopes, headers=None, **kwargs):
def request( # noqa: PLR0913
self,
lti_registration: LTIRegistration,
method,
url,
scopes,
headers=None,
**kwargs,
):
headers = headers or {}

assert "Authorization" not in headers

access_token = self._get_access_token(scopes)
access_token = self._get_access_token(lti_registration, scopes)
headers["Authorization"] = f"Bearer {access_token}"

return self._http.request(method, url, headers=headers, **kwargs)

def _get_access_token(self, scopes: list[str]) -> str:
def _get_access_token(
self, lti_registration: LTIRegistration, scopes: list[str]
) -> str:
"""Get a valid access token from the DB or get a new one from the LMS."""
token = self._jwt_oauth2_token_service.get_token(self._lti_registration, scopes)
token = self._jwt_oauth2_token_service.get_token(lti_registration, scopes)
if not token:
LOG.debug("Requesting new LTIA JWT token")
token = self._get_new_access_token(scopes)
token = self._get_new_access_token(lti_registration, scopes)
else:
LOG.debug("Using cached LTIA JWT token")

return token.access_token

def _get_new_access_token(self, scopes: list[str]) -> JWTOAuth2Token:
def _get_new_access_token(
self, lti_registration: LTIRegistration, scopes: list[str]
) -> JWTOAuth2Token:
"""
Get an access token from the LMS to use in LTA services.
Expand All @@ -62,15 +72,15 @@ def _get_new_access_token(self, scopes: list[str]) -> JWTOAuth2Token:
{
"exp": now + timedelta(hours=1),
"iat": now,
"iss": self._lti_registration.client_id,
"sub": self._lti_registration.client_id,
"aud": self._plugin.get_ltia_aud_claim(self._lti_registration),
"iss": lti_registration.client_id,
"sub": lti_registration.client_id,
"aud": self._plugin.get_ltia_aud_claim(lti_registration),
"jti": uuid.uuid4().hex,
}
)

response = self._http.post(
self._lti_registration.token_url,
lti_registration.token_url,
data={
"grant_type": "client_credentials",
"client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
Expand All @@ -87,7 +97,7 @@ def _get_new_access_token(self, scopes: list[str]) -> JWTOAuth2Token:
raise

token = self._jwt_oauth2_token_service.save_token(
lti_registration=self._lti_registration,
lti_registration=lti_registration,
scopes=scopes,
access_token=token_data["access_token"],
expires_in=token_data["expires_in"],
Expand All @@ -97,7 +107,6 @@ def _get_new_access_token(self, scopes: list[str]) -> JWTOAuth2Token:

def factory(_context, request):
return LTIAHTTPService(
request.lti_user.application_instance.lti_registration,
request.product.plugin.misc,
request.find_service(JWTService),
request.find_service(name="http"),
Expand Down
31 changes: 23 additions & 8 deletions tests/unit/lms/services/lti_grading/_v13_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@

class TestLTI13GradingService:
@freeze_time("2022-04-04")
def test_read_lti_result(self, svc, response, ltia_http_service):
def test_read_lti_result(self, svc, response, ltia_http_service, lti_registration):
ltia_http_service.request.return_value.json.return_value = response
svc.line_item_url = "https://lms.com/lineitems?param=1"

result = svc.read_result(sentinel.user_id)

ltia_http_service.request.assert_called_once_with(
lti_registration,
"GET",
"https://lms.com/lineitems/results?param=1",
scopes=svc.LTIA_SCOPES,
Expand Down Expand Up @@ -73,14 +74,20 @@ def test_read_bad_response_lti_result(self, svc, ltia_http_service, bad_response
assert not result.comment

def test_read_result_blackboard(
self, blackboard_svc, ltia_http_service, blackboard_response, misc_plugin
self,
blackboard_svc,
ltia_http_service,
blackboard_response,
misc_plugin,
lti_registration,
):
ltia_http_service.request.return_value.json.return_value = blackboard_response
blackboard_svc.line_item_url = "https://lms.com/lineitems?param=1"

result = blackboard_svc.read_result(sentinel.user_id)

ltia_http_service.request.assert_called_once_with(
lti_registration,
"GET",
"https://lms.com/lineitems/results?param=1",
scopes=blackboard_svc.LTIA_SCOPES,
Expand All @@ -97,7 +104,7 @@ def test_read_result_blackboard(
)
assert result.comment == misc_plugin.clean_lms_grading_comment.return_value

def test_get_score_maximum(self, svc, ltia_http_service):
def test_get_score_maximum(self, svc, ltia_http_service, lti_registration):
ltia_http_service.request.return_value.json.return_value = [
{"scoreMaximum": sentinel.score_max, "id": svc.line_item_url},
{"scoreMaximum": 1, "id": sentinel.other_lineitem},
Expand All @@ -106,6 +113,7 @@ def test_get_score_maximum(self, svc, ltia_http_service):
score = svc.get_score_maximum(sentinel.resource_link_id)

ltia_http_service.request.assert_called_once_with(
lti_registration,
"GET",
"http://example.com/lineitems",
scopes=svc.LTIA_SCOPES,
Expand All @@ -130,7 +138,9 @@ def test_get_score_maximum_no_line_item(self, svc, ltia_http_service):

@freeze_time("2022-04-04")
@pytest.mark.parametrize("comment", [sentinel.comment, None])
def test_record_result(self, svc, ltia_http_service, comment, misc_plugin):
def test_record_result(
self, svc, ltia_http_service, comment, misc_plugin, lti_registration
):
svc.line_item_url = "https://lms.com/lineitems?param=1"

response = svc.record_result(sentinel.user_id, sentinel.score, comment=comment)
Expand All @@ -148,6 +158,7 @@ def test_record_result(self, svc, ltia_http_service, comment, misc_plugin):
payload["comment"] = misc_plugin.format_grading_comment_for_lms.return_value

ltia_http_service.request.assert_called_once_with(
lti_registration,
"POST",
"https://lms.com/lineitems/scores?param=1",
scopes=svc.LTIA_SCOPES,
Expand Down Expand Up @@ -182,14 +193,15 @@ def test_record_result_raises_StudentNotInCourse(
with pytest.raises(StudentNotInCourse):
svc.record_result(sentinel.user_id, sentinel.score)

def test_create_line_item(self, svc, ltia_http_service):
def test_create_line_item(self, svc, ltia_http_service, lti_registration):
response = svc.create_line_item(
sentinel.resource_link_id,
sentinel.label,
sentinel.score_maximum,
)

ltia_http_service.request.assert_called_once_with(
lti_registration,
"POST",
svc.line_item_container_url,
scopes=svc.LTIA_SCOPES,
Expand All @@ -202,13 +214,14 @@ def test_create_line_item(self, svc, ltia_http_service):
)
assert response == ltia_http_service.request.return_value.json.return_value

def test_record_result_calls_hook(self, svc, ltia_http_service):
def test_record_result_calls_hook(self, svc, ltia_http_service, lti_registration):
my_hook = Mock(return_value={"my_dict": 1})

svc.record_result(sentinel.user_id, score=1.5, pre_record_hook=my_hook)

my_hook.assert_called_once_with(request_body=Any.dict(), score=1.5)
ltia_http_service.request.assert_called_once_with(
lti_registration,
"POST",
"http://example.com/lineitem/scores",
scopes=svc.LTIA_SCOPES,
Expand Down Expand Up @@ -251,21 +264,23 @@ def blackboard_response(self):
]

@pytest.fixture
def svc(self, ltia_http_service, misc_plugin):
def svc(self, ltia_http_service, misc_plugin, lti_registration):
return LTI13GradingService(
"http://example.com/lineitem",
"http://example.com/lineitems",
ltia_http_service,
product_family=Family.CANVAS,
misc_plugin=misc_plugin,
lti_registration=lti_registration,
)

@pytest.fixture
def blackboard_svc(self, ltia_http_service, misc_plugin):
def blackboard_svc(self, ltia_http_service, misc_plugin, lti_registration):
return LTI13GradingService(
"http://example.com/lineitem",
"http://example.com/lineitems",
ltia_http_service,
product_family=Family.BLACKBOARD,
misc_plugin=misc_plugin,
lti_registration=lti_registration,
)
2 changes: 2 additions & 0 deletions tests/unit/lms/services/lti_grading/factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_v13(
ltia_http_service,
pyramid_request.product.family,
misc_plugin,
pyramid_request.lti_user.application_instance.lti_registration,
)
assert svc == LTI13GradingService.return_value

Expand All @@ -48,6 +49,7 @@ def test_v13_line_item_url_from_lti_params(
ltia_http_service,
pyramid_request.product.family,
misc_plugin,
pyramid_request.lti_user.application_instance.lti_registration,
)
assert svc == LTI13GradingService.return_value

Expand Down
Loading

0 comments on commit e9976f3

Please sign in to comment.