diff --git a/WindowsAppRuntime.sln b/WindowsAppRuntime.sln index 434d55a54f..6e50801299 100644 --- a/WindowsAppRuntime.sln +++ b/WindowsAppRuntime.sln @@ -529,6 +529,11 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PackageManager.Test.M.White EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PackageManager.Test.M.White.msix", "test\PackageManager\data\PackageManager.Test.M.White.msix\PackageManager.Test.M.White.msix.vcxproj", "{28DCF9CE-D9F4-4A7D-8AD1-F2EFC0D3B4DF}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "OAuth", "OAuth", "{1A6F936D-7350-4177-8195-146BDDFACF3B}" +EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "OAuth", "dev\OAuth\OAuth.vcxitems", "{3E7FD510-8B66-40E7-A80B-780CB8972F83}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Security.Authentication.OAuth.Projection", "dev\Projections\CS\Microsoft.Security.Authentication.OAuth\Microsoft.Security.Authentication.OAuth.Projection.csproj", "{1D24CC70-85B1-4864-B847-3328F40AF01E}" Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Interop", "Interop", "{3B706C5C-55E0-4B76-BF59-89E20FE46795}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "CameraCaptureUI", "CameraCaptureUI", "{0833D8EF-6E11-4133-B0EE-9B7625CD615E}" @@ -1883,6 +1888,22 @@ Global {28DCF9CE-D9F4-4A7D-8AD1-F2EFC0D3B4DF}.Release|x64.Build.0 = Release|x64 {28DCF9CE-D9F4-4A7D-8AD1-F2EFC0D3B4DF}.Release|x86.ActiveCfg = Release|Win32 {28DCF9CE-D9F4-4A7D-8AD1-F2EFC0D3B4DF}.Release|x86.Build.0 = Release|Win32 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Debug|Any CPU.ActiveCfg = Debug|x64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Debug|Any CPU.Build.0 = Debug|x64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Debug|ARM64.ActiveCfg = Debug|arm64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Debug|ARM64.Build.0 = Debug|arm64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Debug|x64.ActiveCfg = Debug|x64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Debug|x64.Build.0 = Debug|x64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Debug|x86.ActiveCfg = Debug|x86 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Debug|x86.Build.0 = Debug|x86 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Release|Any CPU.ActiveCfg = Release|x64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Release|Any CPU.Build.0 = Release|x64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Release|ARM64.ActiveCfg = Release|arm64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Release|ARM64.Build.0 = Release|arm64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Release|x64.ActiveCfg = Release|x64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Release|x64.Build.0 = Release|x64 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Release|x86.ActiveCfg = Release|x86 + {1D24CC70-85B1-4864-B847-3328F40AF01E}.Release|x86.Build.0 = Release|x86 {1DAA2342-CF55-48E5-B49C-982FA5C07014}.Debug|Any CPU.ActiveCfg = Debug|x64 {1DAA2342-CF55-48E5-B49C-982FA5C07014}.Debug|Any CPU.Build.0 = Debug|x64 {1DAA2342-CF55-48E5-B49C-982FA5C07014}.Debug|ARM64.ActiveCfg = Debug|x64 @@ -2093,6 +2114,9 @@ Global {7C240089-0F22-4247-9C91-51255C8DC18B} = {6213B1A3-E854-498F-AAFA-4CFC1E71023E} {AC79B8FF-4C27-4326-AD20-BBC70059FF51} = {7C240089-0F22-4247-9C91-51255C8DC18B} {28DCF9CE-D9F4-4A7D-8AD1-F2EFC0D3B4DF} = {6759ECC6-9381-4172-89E6-853F81A03D28} + {1A6F936D-7350-4177-8195-146BDDFACF3B} = {448ED2E5-0B37-4D97-9E6B-8C10A507976A} + {3E7FD510-8B66-40E7-A80B-780CB8972F83} = {1A6F936D-7350-4177-8195-146BDDFACF3B} + {1D24CC70-85B1-4864-B847-3328F40AF01E} = {716C26A0-E6B0-4981-8412-D14A4D410531} {3B706C5C-55E0-4B76-BF59-89E20FE46795} = {448ED2E5-0B37-4D97-9E6B-8C10A507976A} {0833D8EF-6E11-4133-B0EE-9B7625CD615E} = {3B706C5C-55E0-4B76-BF59-89E20FE46795} {95409D1E-843F-4316-8D8E-471B3E203F94} = {0833D8EF-6E11-4133-B0EE-9B7625CD615E} @@ -2106,6 +2130,7 @@ Global test\inc\inc.vcxitems*{08bc78e0-63c6-49a7-81b3-6afc3deac4de}*SharedItemsImports = 4 dev\PushNotifications\PushNotifications.vcxitems*{103c0c23-7ba8-4d44-a63c-83488e2e3a81}*SharedItemsImports = 9 dev\EnvironmentManager\API\Microsoft.Process.Environment.vcxitems*{2f3fad1b-d3df-4866-a3a3-c2c777d55638}*SharedItemsImports = 9 + dev\OAuth\OAuth.vcxitems*{3e7fd510-8b66-40e7-a80b-780cb8972f83}*SharedItemsImports = 9 test\inc\inc.vcxitems*{412d023e-8635-4ad2-a0ea-e19e08d36915}*SharedItemsImports = 4 test\inc\inc.vcxitems*{4b30c685-8490-440f-9879-a75d45daa361}*SharedItemsImports = 4 dev\UndockedRegFreeWinRT\UndockedRegFreeWinRT.vcxitems*{56371ca6-144b-4989-a4e9-391ad4fa7651}*SharedItemsImports = 9 diff --git a/build/CopyFilesToStagingDir.ps1 b/build/CopyFilesToStagingDir.ps1 index e2dd07084f..2de517a091 100644 --- a/build/CopyFilesToStagingDir.ps1 +++ b/build/CopyFilesToStagingDir.ps1 @@ -51,6 +51,7 @@ PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windo PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.Media.Capture.winmd $FullPublishDir\Microsoft.WindowsAppRuntime\ PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.PushNotifications.winmd $FullPublishDir\Microsoft.WindowsAppRuntime\ PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.Security.AccessControl.winmd $FullPublishDir\Microsoft.WindowsAppRuntime\ +PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Security.Authentication.OAuth.winmd $FullPublishDir\Microsoft.WindowsAppRuntime\ PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.Storage.winmd $FullPublishDir\Microsoft.WindowsAppRuntime\ PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.System.Power.winmd $FullPublishDir\Microsoft.WindowsAppRuntime\ PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.System.winmd $FullPublishDir\Microsoft.WindowsAppRuntime\ @@ -121,6 +122,8 @@ PublishFile $FullBuildOutput\Microsoft.Windows.PushNotifications.Projection\Micr PublishFile $FullBuildOutput\Microsoft.Windows.PushNotifications.Projection\Microsoft.Windows.PushNotifications.Projection.pdb $NugetDir\lib\net6.0-windows10.0.17763.0 PublishFile $FullBuildOutput\Microsoft.Windows.Security.AccessControl.Projection\Microsoft.Windows.Security.AccessControl.Projection.dll $NugetDir\lib\net6.0-windows10.0.17763.0 PublishFile $FullBuildOutput\Microsoft.Windows.Security.AccessControl.Projection\Microsoft.Windows.Security.AccessControl.Projection.pdb $NugetDir\lib\net6.0-windows10.0.17763.0 +PublishFile $FullBuildOutput\Microsoft.Security.Authentication.OAuth.Projection\Microsoft.Security.Authentication.OAuth.Projection.dll $NugetDir\lib\net6.0-windows10.0.17763.0 +PublishFile $FullBuildOutput\Microsoft.Security.Authentication.OAuth.Projection\Microsoft.Security.Authentication.OAuth.Projection.pdb $NugetDir\lib\net6.0-windows10.0.17763.0 PublishFile $FullBuildOutput\Microsoft.Windows.Storage.Projection\Microsoft.Windows.Storage.Projection.dll $NugetDir\lib\net6.0-windows10.0.17763.0 PublishFile $FullBuildOutput\Microsoft.Windows.Storage.Projection\Microsoft.Windows.Storage.Projection.pdb $NugetDir\lib\net6.0-windows10.0.17763.0 PublishFile $FullBuildOutput\Microsoft.Windows.System.Power.Projection\Microsoft.Windows.System.Power.Projection.dll $NugetDir\lib\net6.0-windows10.0.17763.0 @@ -190,6 +193,7 @@ PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windo PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.Media.Capture.winmd $NugetDir\lib\uap10.0 PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.PushNotifications.winmd $NugetDir\lib\uap10.0 PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.Security.AccessControl.winmd $NugetDir\lib\uap10.0 +PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Security.Authentication.OAuth.winmd $NugetDir\lib\uap10.0 PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.Storage.winmd $NugetDir\lib\uap10.0 PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.System.Power.winmd $NugetDir\lib\uap10.0 PublishFile $FullBuildOutput\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Windows.System.winmd $NugetDir\lib\uap10.0 diff --git a/build/NuSpecs/AppxManifest.xml b/build/NuSpecs/AppxManifest.xml index 78e52452ba..954952fd6e 100644 --- a/build/NuSpecs/AppxManifest.xml +++ b/build/NuSpecs/AppxManifest.xml @@ -30,6 +30,12 @@ Microsoft.WindowsAppRuntime.dll + + + + + + diff --git a/build/NuSpecs/WindowsAppSDK-Nuget-Native.WinRt.props b/build/NuSpecs/WindowsAppSDK-Nuget-Native.WinRt.props index 960dabbc33..068d3d3c49 100644 --- a/build/NuSpecs/WindowsAppSDK-Nuget-Native.WinRt.props +++ b/build/NuSpecs/WindowsAppSDK-Nuget-Native.WinRt.props @@ -53,6 +53,12 @@ $(MSBuildThisFileDirectory)..\..\runtimes\win10-$(_WindowsAppSDKFoundationPlatform)\native\Microsoft.WindowsAppRuntime.dll true + + $(MSBuildThisFileDirectory)..\..\lib\uap10.0\Microsoft.Security.Authentication.OAuth.winmd + $(MSBuildThisFileDirectory)..\..\runtimes\win10-$(_WindowsAppSDKFoundationPlatform)\native\Microsoft.WindowsAppRuntime.dll + true + $(MSBuildThisFileDirectory)..\..\lib\uap10.0\Microsoft.Windows.Media.Capture.winmd diff --git a/build/NuSpecs/WindowsAppSDK-Nuget-Native.targets b/build/NuSpecs/WindowsAppSDK-Nuget-Native.targets index dff01b0bf6..6a1a66bb1f 100644 --- a/build/NuSpecs/WindowsAppSDK-Nuget-Native.targets +++ b/build/NuSpecs/WindowsAppSDK-Nuget-Native.targets @@ -74,6 +74,13 @@ Microsoft.WindowsAppRuntime.dll + + + false + Microsoft.WindowsAppRuntime.dll + + + + + + + + + + + Feature_OAuth + OAuth for the WindowsAppRuntime + AlwaysEnabled + + diff --git a/dev/OAuth/AuthFailure.cpp b/dev/OAuth/AuthFailure.cpp new file mode 100644 index 0000000000..6cb4242adf --- /dev/null +++ b/dev/OAuth/AuthFailure.cpp @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "AuthFailure.h" + +#include + +using namespace std::literals; +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + AuthFailure::AuthFailure(const Uri& responseUri) + { + std::map additionalParams; + + auto parseComponents = [&](const winrt::hstring& str) { + if (str.empty()) + { + return; // Avoid unnecessary construction/activation + } + + for (auto&& entry : WwwFormUrlDecoder(str)) + { + auto name = entry.Name(); + if (name == L"error"sv) + { + m_error = entry.Value(); + } + else if (name == L"error_description"sv) + { + m_errorDescription = entry.Value(); + } + else if (name == L"error_uri"sv) + { + m_errorUri = Uri(entry.Value()); + } + else if (name == L"state"sv) + { + m_state = entry.Value(); + } + else + { + additionalParams.emplace(std::move(name), entry.Value()); + } + } + }; + + parseComponents(responseUri.Query()); + parseComponents(fragment_component(responseUri)); + + m_additionalParams = winrt::single_threaded_map(std::move(additionalParams)).GetView(); + } + + winrt::hstring AuthFailure::Error() + { + return m_error; + } + + winrt::hstring AuthFailure::ErrorDescription() + { + return m_errorDescription; + } + + Uri AuthFailure::ErrorUri() + { + return m_errorUri; + } + + winrt::hstring AuthFailure::State() + { + return m_state; + } + + IMapView AuthFailure::AdditionalParams() + { + return m_additionalParams; + } +} diff --git a/dev/OAuth/AuthFailure.h b/dev/OAuth/AuthFailure.h new file mode 100644 index 0000000000..cd78ff45cd --- /dev/null +++ b/dev/OAuth/AuthFailure.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct AuthFailure : AuthFailureT + { + AuthFailure(const foundation::Uri& responseUri); + + winrt::hstring Error(); + winrt::hstring ErrorDescription(); + foundation::Uri ErrorUri(); + winrt::hstring State(); + collections::IMapView AdditionalParams(); + + private: + winrt::hstring m_error; + winrt::hstring m_errorDescription; + foundation::Uri m_errorUri{ nullptr }; + winrt::hstring m_state; + collections::IMapView m_additionalParams; + }; +} diff --git a/dev/OAuth/AuthRequestAsyncOperation.cpp b/dev/OAuth/AuthRequestAsyncOperation.cpp new file mode 100644 index 0000000000..60571a97e7 --- /dev/null +++ b/dev/OAuth/AuthRequestAsyncOperation.cpp @@ -0,0 +1,582 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "OAuth2Manager.h" +#include "AuthRequestAsyncOperation.h" +#include "AuthRequestResult.h" + +using namespace std::literals; +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; +using namespace winrt::Windows::Security::Cryptography; + +AuthRequestAsyncOperation::AuthRequestAsyncOperation(winrt::hstring& state) +{ + try + { + if (state.empty()) + { + while (true) + { + state = random_base64urlencoded_string(32); + if (try_create_pipe(state)) + { + break; + } + + // 'FILE_FLAG_FIRST_PIPE_INSTANCE' is documented as failing with 'ERROR_ACCESS_DENIED' if a pipe + // with the same name has already been created. + if (auto err = ::GetLastError(); err != ERROR_ACCESS_DENIED) + { + throw winrt::hresult_error(HRESULT_FROM_WIN32(err), + L"Generation of a unique state value unexpectedly failed"); + } + } + } + else if (!try_create_pipe(state)) + { + auto err = ::GetLastError(); + auto msg = + (err == ERROR_ACCESS_DENIED) ? L"Provided state value is not unique" : L"Failed to create named pipe"; + throw winrt::hresult_error(HRESULT_FROM_WIN32(err), msg); + } + + m_overlapped.hEvent = ::CreateEventW(nullptr, true, false, nullptr); + if (!m_overlapped.hEvent) + { + throw winrt::hresult_error(HRESULT_FROM_WIN32(::GetLastError()), L"Failed to create an event"); + } + + m_ptp.reset(::CreateThreadpoolWait(async_callback, this, nullptr)); // Use reset() to initialize + if (!m_ptp) + { + throw winrt::hresult_error(HRESULT_FROM_WIN32(::GetLastError()), L"Failed to create threadpool wait"); + } + connect_to_new_client(); + } + catch (...) + { + // Throwing in a constructor will cause the destructor not to run... + destroy(); + throw; + } +} + +AuthRequestAsyncOperation::AuthRequestAsyncOperation(implementation::AuthRequestParams* params) : + m_params(params->get_strong()) +{ + try + { + // Calling 'finalize' will (1) prevent subsequent changes from being made to the params, (2) validate + // consistency in the parameters that are set, and (3) throw an exception if 'finalize' was previously called by + // someone else. If no exception is thrown, it signals that this object effectively owns the request parameters + // and is able to read and set necessary properties without fear of them being modified by another call + m_params->finalize(); + + if ((m_params->CodeChallengeMethod() != CodeChallengeMethodKind::None) && m_params->CodeChallenge().empty()) + { + m_params->set_code_challenge(winrt::hstring{ random_base64urlencoded_string(32) }); + } + + if (m_params->State().empty()) + { + while (true) + { + winrt::hstring state{ random_base64urlencoded_string(32) }; + if (try_create_pipe(state)) + { + m_params->set_state(state); + break; + } + + // 'FILE_FLAG_FIRST_PIPE_INSTANCE' is documented as failing with 'ERROR_ACCESS_DENIED' if a pipe + // with the same name has already been created. + if (auto err = ::GetLastError(); err != ERROR_ACCESS_DENIED) + { + throw winrt::hresult_error(HRESULT_FROM_WIN32(err), + L"Generation of a unique state value unexpectedly failed"); + } + } + } + else if (!try_create_pipe(m_params->State())) + { + auto err = ::GetLastError(); + auto msg = + (err == ERROR_ACCESS_DENIED) ? L"Provided state value is not unique" : L"Failed to create named pipe"; + throw winrt::hresult_error(HRESULT_FROM_WIN32(err), msg); + } + + m_overlapped.hEvent = ::CreateEventW(nullptr, true, false, nullptr); + if (!m_overlapped.hEvent) + { + throw winrt::hresult_error(HRESULT_FROM_WIN32(::GetLastError()), L"Failed to create an event"); + } + + m_ptp.reset(::CreateThreadpoolWait(async_callback, this, nullptr)); // Use reset() to initialize + if (!m_ptp) + { + throw winrt::hresult_error(HRESULT_FROM_WIN32(::GetLastError()), L"Failed to create threadpool wait"); + } + connect_to_new_client(); + } + catch (...) + { + // Throwing in a constructor will cause the destructor not to run... + destroy(); + throw; + } +} + +AuthRequestAsyncOperation::~AuthRequestAsyncOperation() +{ + destroy(); +} + +void AuthRequestAsyncOperation::destroy() +{ + { + // Expects lock to be held and is required since we haven't ensured all callbacks have completed + std::unique_lock guard{ m_mutex }; + close_pipe(); + } + + // Note that we don't hold the lock here for two reasons. The big reason is that 'WaitForThreadpoolWaitCallbacks may + // wait on a callback trying to acquire the lock. The second reason - and the reason we get away with this - is that + // this code path only gets called on destruction, meaning nothing except callbacks (which we wait for) will access + // or modify object state + if (m_ptp) + { + if (!::SetThreadpoolWaitEx(m_ptp.get(), nullptr, nullptr, nullptr)) + { + // False here means that there's a callback in progress. This would realistically only happen if there was + // a race between the client calling 'Cancel' and someone connecting to the pipe + ::WaitForThreadpoolWaitCallbacks(m_ptp.get(), true); + } + + m_ptp.reset(); + } + + if (m_overlapped.hEvent) + { + ::CloseHandle(m_overlapped.hEvent); + m_overlapped.hEvent = nullptr; + } +} + +void AuthRequestAsyncOperation::close_pipe() +{ + auto lastState = std::exchange(m_state, request_state::closed); + if (lastState == request_state::closed) + { + return; + } + + if (m_pipe) + { + ::CancelIoEx(m_pipe.get(), &m_overlapped); + m_pipe.reset(); + } +} + +winrt::hresult AuthRequestAsyncOperation::ErrorCode() +{ + std::shared_lock guard{ m_mutex }; + return m_error; +} + +uint32_t AuthRequestAsyncOperation::Id() +{ + return 1; // NOTE: This is copying the C++/WinRT implementation +} + +winrt::Windows::Foundation::AsyncStatus AuthRequestAsyncOperation::Status() +{ + std::shared_lock guard{ m_mutex }; + return m_status; +} + +void AuthRequestAsyncOperation::Cancel() +{ + winrt::make_self()->cancel(this); +} + +void AuthRequestAsyncOperation::Close() +{ + // TODO? C++/WinRT does a noop here +} + +AsyncOperationCompletedHandler AuthRequestAsyncOperation::Completed() +{ + std::shared_lock guard{ m_mutex }; + return m_handler; +} + +void AuthRequestAsyncOperation::Completed(const AsyncOperationCompletedHandler& handler) +{ + bool shouldInvoke = false; + { + std::lock_guard guard{ m_mutex }; + if (m_handlerSet) + { + throw winrt::hresult_illegal_delegate_assignment(); + } + + m_handlerSet = true; + if (!handler) + { + WINRT_ASSERT(!m_handler); + return; + } + + if (m_status != winrt::Windows::Foundation::AsyncStatus::Started) + { + shouldInvoke = true; + } + else + { + m_handler = handler; + } + } + + if (shouldInvoke) + { + invoke_handler(handler); + } +} + +AuthRequestResult AuthRequestAsyncOperation::GetResults() +{ + std::shared_lock guard{ m_mutex }; + if (m_status == winrt::Windows::Foundation::AsyncStatus::Completed) + { + return m_result; + } + else if (m_error < 0) + { + throw winrt::hresult_error(m_error); + } + + WINRT_ASSERT(m_status == winrt::Windows::Foundation::AsyncStatus::Started); + throw winrt::hresult_illegal_method_call(); +} + +void AuthRequestAsyncOperation::complete(const Uri& responseUri) +{ + transition_state(winrt::Windows::Foundation::AsyncStatus::Completed, responseUri); +} + +void AuthRequestAsyncOperation::cancel() +{ + transition_state(winrt::Windows::Foundation::AsyncStatus::Canceled, nullptr, HRESULT_FROM_WIN32(ERROR_CANCELLED)); +} + +void AuthRequestAsyncOperation::error(winrt::hresult hr) +{ + transition_state(winrt::Windows::Foundation::AsyncStatus::Error, nullptr, hr); +} + +void AuthRequestAsyncOperation::transition_state(winrt::Windows::Foundation::AsyncStatus status, const Uri& responseUri, winrt::hresult hr) +{ + AsyncOperationCompletedHandler handler; + { + std::lock_guard guard{ m_mutex }; + close_pipe(); + + // State change is initiated by OAuth2Manager and should never happen twice + WINRT_ASSERT(m_status == winrt::Windows::Foundation::AsyncStatus::Started); + m_status = status; + m_error = hr; + + if (responseUri) + { + WINRT_ASSERT(hr >= 0); + m_result = winrt::make(m_params.get(), responseUri); + } + else + { + WINRT_ASSERT(hr < 0); + } + + handler = m_handler; + } + + if (handler) + { + invoke_handler(handler); + } +} + +void CALLBACK AuthRequestAsyncOperation::async_callback(PTP_CALLBACK_INSTANCE, PVOID context, PTP_WAIT, + TP_WAIT_RESULT waitResult) +{ + auto pThis = static_cast(context); + pThis->callback(waitResult); +} + +void AuthRequestAsyncOperation::callback(TP_WAIT_RESULT waitResult) +{ + try + { + request_state currentState; + DWORD bytes = 0; + DWORD overlappedError = ERROR_SUCCESS; + { + std::shared_lock guard{ m_mutex }; + currentState = m_state; + if (currentState == request_state::closed) + { + // Nothing productive we can do if the pipe was closed. This also likely means the result was an error + return; + } + + if (waitResult == WAIT_OBJECT_0) + { + if (!::GetOverlappedResult(m_pipe.get(), &m_overlapped, &bytes, false)) + { + overlappedError = ::GetLastError(); + } + } + } + + switch (currentState) + { + case request_state::connecting: + { + WINRT_ASSERT(waitResult == WAIT_OBJECT_0); // TODO: Is this valid? Maybe when we cancelled? Error? + if (waitResult != WAIT_OBJECT_0) + { + WINRT_ASSERT(waitResult == WAIT_TIMEOUT); + throw winrt::hresult_error(HRESULT_FROM_WIN32(ERROR_TIMEOUT), + L"Timed out waiting for a client to connect to the pipe"); + } + else if (overlappedError != ERROR_SUCCESS) + { + // If ConnectNamedClient failed, assume we hit an unrecoverable failure + throw winrt::hresult_error(HRESULT_FROM_WIN32(overlappedError), + L"Failed waiting for a client to connect to the pipe"); + } + + initiate_read(); + } + break; + + case request_state::reading: + { + if (overlappedError == ERROR_MORE_DATA) + { + // NOTE: Pipe server is effectively single threaded, hence no synchronization needed here + m_pipeReadData.insert(m_pipeReadData.end(), m_pipeReadBuffer, + m_pipeReadBuffer + m_overlapped.InternalHigh); + initiate_read(); // Need more data before we can complete + } + else if ((waitResult != WAIT_OBJECT_0) || (overlappedError != ERROR_SUCCESS)) + { + // Ideally we could assume that read timeouts/failures are fatal, however we don't know if the client is + // trustworthy and we don't want some arbitrary process to bait us into terminating the request + connect_to_new_client(true); + } + else + { + on_read_complete(); + } + } + break; + + default: + WINRT_ASSERT(false); + throw winrt::hresult_error(E_UNEXPECTED, L"Unexpected failure waiting for AuthRequest result"); + break; + } + } + catch (...) + { + winrt::make_self()->error(this, winrt::to_hresult()); + } +} + +bool AuthRequestAsyncOperation::try_create_pipe(const winrt::hstring& state) +{ + // NOTE: Called on construction where no synchronization is needed + auto name = request_pipe_name(state); + m_pipe.reset(::CreateNamedPipeW(name.c_str(), PIPE_ACCESS_INBOUND | FILE_FLAG_FIRST_PIPE_INSTANCE | FILE_FLAG_OVERLAPPED, + PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS, 1, 1024, 1024, 0, nullptr)); + + if (m_pipe) + { + m_pipeName = std::move(name); + return true; + } + + return false; +} + +void AuthRequestAsyncOperation::connect_to_new_client(bool disconnect) +{ + m_pipeReadData.clear(); + + DWORD lastError; + { + std::shared_lock guard{ m_mutex }; + if (m_state == request_state::closed) + { + return; + } + + if (disconnect) + { + [[maybe_unused]] auto disconnectResult = ::DisconnectNamedPipe(m_pipe.get()); + WINRT_ASSERT(disconnectResult); // TODO: Correct if the client disconnected from us? + } + + [[maybe_unused]] auto connectResult = ::ConnectNamedPipe(m_pipe.get(), &m_overlapped); + WINRT_ASSERT(!connectResult); // Only non-zero in asynchronous mode, even if already connected + lastError = ::GetLastError(); + } + + if (lastError == ERROR_PIPE_CONNECTED) + { + // Client already connected + initiate_read(); + } + else if (lastError != ERROR_IO_PENDING) + { + throw winrt::hresult_error(HRESULT_FROM_WIN32(lastError), L"Failed to listen for clients on the pipe"); + } + else + { + { + std::lock_guard guard{ m_mutex }; + if (m_state == request_state::closed) + { + // Don't set the threadpool wait again as we may have just cleared it! + return; + } + + m_state = request_state::connecting; + } + ::SetThreadpoolWait(m_ptp.get(), m_overlapped.hEvent, nullptr); + } +} + +void AuthRequestAsyncOperation::initiate_read() +{ + while (true) + { + BOOL readResult; + { + std::shared_lock guard{ m_mutex }; + if (m_state == request_state::closed) + { + // No pipe to read from + return; + } + + readResult = ::ReadFile(m_pipe.get(), m_pipeReadBuffer, sizeof(m_pipeReadBuffer), nullptr, &m_overlapped); + } + + if (readResult) + { + // Immediate success. No need to wait + on_read_complete(); + break; + } + + auto err = ::GetLastError(); + if (err == ERROR_MORE_DATA) + { + // Partial read successful; save data and continue loop to try and read more data + m_pipeReadData.insert(m_pipeReadData.end(), m_pipeReadBuffer, m_pipeReadBuffer + m_overlapped.InternalHigh); + } + else if (err == ERROR_IO_PENDING) + { + // Reading asynchronously + std::lock_guard guard{ m_mutex }; + if (m_state == request_state::closed) + { + // Simultaneously closed; don't set the threadpool wait as we may have just cleared it! + return; + } + + m_state = request_state::reading; + std::int64_t timeout = std::chrono::duration_cast(-50ms).count(); // 50ms timeout + ::SetThreadpoolWait(m_ptp.get(), m_overlapped.hEvent, reinterpret_cast(&timeout)); + break; + } + else + { + connect_to_new_client(true); + break; + } + } +} + +void AuthRequestAsyncOperation::on_read_complete() +{ + m_pipeReadData.insert(m_pipeReadData.end(), m_pipeReadBuffer, m_pipeReadBuffer + m_overlapped.InternalHigh); + + bool shouldReconnect = true; + try + { + auto expectedState = m_params->State(); + auto encryptedBuffer = CryptographicBuffer::CreateFromByteArray(m_pipeReadData); + auto uriString = decrypt(encryptedBuffer, expectedState); + + // An exception is unlikely (we needed the state from the URI to open the pipe in the first place), but could + // happen if someone is connecting and sending garbage data. We'll catch below, so all is okay + Uri responseUri(uriString); + winrt::hstring state; + auto tryFindState = [&](const winrt::hstring& str) + { + if (str.empty()) + { + return; // Avoid unnecessary construction/activation + } + + for (auto&& entry : WwwFormUrlDecoder(str)) + { + if (entry.Name() == L"state") + { + state = entry.Value(); + break; + } + } + }; + + tryFindState(responseUri.Query()); + if (state.empty()) + { + tryFindState(fragment_component(responseUri)); + } + + if (state == expectedState) + { + shouldReconnect = + winrt::make_self()->try_complete_local(state, responseUri); + } + } + catch (...) + { + // Likely handed bad data; just disconnect and attempt a reconnect + } + + if (shouldReconnect) + { + connect_to_new_client(true); + } + // Otherwise the 'try_complete_local' call should have closed the pipe +} + +void AuthRequestAsyncOperation::invoke_handler(const AsyncOperationCompletedHandler& handler) +{ + try + { + handler(*this, m_status); + } + catch (...) + { + // Just eat exceptions as they're not relevant to the caller at all + } +} diff --git a/dev/OAuth/AuthRequestAsyncOperation.h b/dev/OAuth/AuthRequestAsyncOperation.h new file mode 100644 index 0000000000..07fa7f2ecf --- /dev/null +++ b/dev/OAuth/AuthRequestAsyncOperation.h @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +#include "AuthRequestParams.h" + +struct AuthRequestAsyncOperation : + winrt::implements, + foundation::IAsyncInfo> +{ + AuthRequestAsyncOperation(winrt::hstring& state); + AuthRequestAsyncOperation(oauth::implementation::AuthRequestParams* params); + ~AuthRequestAsyncOperation(); + + // IAsyncInfo + winrt::hresult ErrorCode(); + uint32_t Id(); + foundation::AsyncStatus Status(); + void Cancel(); + void Close(); + + // IAsyncOperation + foundation::AsyncOperationCompletedHandler Completed(); + void Completed(const foundation::AsyncOperationCompletedHandler& handler); + oauth::AuthRequestResult GetResults(); + + // Internal functions + void complete(const foundation::Uri& responseUri); + void cancel(); + void error(winrt::hresult hr); + +private: + enum class request_state + { + closed, + connecting, + reading, + }; + + static void CALLBACK async_callback(PTP_CALLBACK_INSTANCE, PVOID context, PTP_WAIT, TP_WAIT_RESULT waitResult); + void callback(TP_WAIT_RESULT waitResult); + + bool try_create_pipe(const winrt::hstring& state); + void close_pipe(); + void connect_to_new_client(bool disconnect = false); + void initiate_read(); + void on_read_complete(); + + void transition_state(foundation::AsyncStatus status, const foundation::Uri& responseUri = nullptr, + winrt::hresult hr = {}); + void invoke_handler(const foundation::AsyncOperationCompletedHandler& handler); + + void destroy(); + + std::shared_mutex m_mutex; + + winrt::com_ptr m_params; + std::wstring m_pipeName; + wil::unique_handle m_pipe; + request_state m_state = request_state::connecting; + OVERLAPPED m_overlapped = {}; + wil::unique_threadpool_wait m_ptp; + std::vector m_pipeReadData; + std::uint8_t m_pipeReadBuffer[128]; + + // IAsyncOperation state + oauth::AuthRequestResult m_result{ nullptr }; + bool m_handlerSet = false; + foundation::AsyncOperationCompletedHandler m_handler; + foundation::AsyncStatus m_status = foundation::AsyncStatus::Started; + winrt::hresult m_error = {}; +}; diff --git a/dev/OAuth/AuthRequestParams.cpp b/dev/OAuth/AuthRequestParams.cpp new file mode 100644 index 0000000000..ad83df4d8a --- /dev/null +++ b/dev/OAuth/AuthRequestParams.cpp @@ -0,0 +1,254 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "AuthRequestParams.h" + +#include + +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; +using namespace winrt::Windows::Security::Cryptography; +using namespace winrt::Windows::Security::Cryptography::Core; + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + AuthRequestParams::AuthRequestParams(const winrt::hstring& responseType, const winrt::hstring& clientId) : + m_responseType(responseType), + m_clientId(clientId) + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + } + + AuthRequestParams::AuthRequestParams(const winrt::hstring& responseType, const winrt::hstring& clientId, + const Uri& redirectUri) : + m_responseType(responseType), + m_clientId(clientId), + m_redirectUri(redirectUri) + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + } + + oauth::AuthRequestParams AuthRequestParams::CreateForAuthorizationCodeRequest(const winrt::hstring& clientId) + { + return CreateForAuthorizationCodeRequest(clientId, nullptr); + } + + oauth::AuthRequestParams AuthRequestParams::CreateForAuthorizationCodeRequest(const winrt::hstring& clientId, + const Uri& redirectUri) + { + auto result = winrt::make_self(L"code", clientId, redirectUri); + result->m_codeChallengeMethod = CodeChallengeMethodKind::S256; + return *result; + } + + oauth::AuthRequestParams AuthRequestParams::CreateForImplicitRequest(const winrt::hstring& clientId) + { + return winrt::make(L"token", clientId); + } + + oauth::AuthRequestParams AuthRequestParams::CreateForImplicitRequest(const winrt::hstring& clientId, + const Uri& redirectUri) + { + return winrt::make(L"token", clientId, redirectUri); + } + + winrt::hstring AuthRequestParams::ResponseType() + { + std::shared_lock guard{ m_mutex }; + return m_responseType; + } + + void AuthRequestParams::ResponseType(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_responseType = value; + } + + winrt::hstring AuthRequestParams::ClientId() + { + std::shared_lock guard{ m_mutex }; + return m_clientId; + } + + void AuthRequestParams::ClientId(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_clientId = value; + } + + Uri AuthRequestParams::RedirectUri() + { + std::shared_lock guard{ m_mutex }; + return m_redirectUri; + } + + void AuthRequestParams::RedirectUri(const Uri& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_redirectUri = value; + } + + winrt::hstring AuthRequestParams::State() + { + std::shared_lock guard{ m_mutex }; + return m_state; + } + + void AuthRequestParams::State(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_state = value; + } + + winrt::hstring AuthRequestParams::Scope() + { + std::shared_lock guard{ m_mutex }; + return m_scope; + } + + void AuthRequestParams::Scope(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_scope = value; + } + + winrt::hstring AuthRequestParams::CodeChallenge() + { + std::shared_lock guard{ m_mutex }; + return m_codeChallenge; + } + + void AuthRequestParams::CodeChallenge(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_codeChallenge = value; + } + + CodeChallengeMethodKind AuthRequestParams::CodeChallengeMethod() + { + std::shared_lock guard{ m_mutex }; + return m_codeChallengeMethod; + } + + void AuthRequestParams::CodeChallengeMethod(CodeChallengeMethodKind value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_codeChallengeMethod = value; + } + + IMap AuthRequestParams::AdditionalParams() + { + std::shared_lock guard{ m_mutex }; + return *m_additionalParams; + } + + void AuthRequestParams::finalize() + { + std::lock_guard guard{ m_mutex }; + if (m_finalized) + { + throw winrt::hresult_illegal_method_call(L"AuthRequestParams can only be used for a single request call"); + } + + m_finalized = true; + m_additionalParams->lock(); + + if (!m_codeChallenge.empty() && (m_codeChallengeMethod == CodeChallengeMethodKind::None)) + { + throw winrt::hresult_illegal_method_call( + L"'CodeChallenge' cannot be set when 'CodeChallengeMethod' is set to 'None'"); + } + } + + void AuthRequestParams::set_state(winrt::hstring value) + { + std::lock_guard guard{ m_mutex }; + WINRT_ASSERT(m_state.empty()); + m_state = std::move(value); + } + + void AuthRequestParams::set_code_challenge(winrt::hstring value) + { + std::lock_guard guard{ m_mutex }; + WINRT_ASSERT(m_codeChallenge.empty()); + WINRT_ASSERT(m_codeChallengeMethod != CodeChallengeMethodKind::None); + m_codeChallenge = std::move(value); + } + + std::wstring AuthRequestParams::create_url(const Uri& authEndpoint) + { + std::shared_lock guard{ m_mutex }; + WINRT_ASSERT(m_finalized); + + // Per RFC 6749 section 3.1, the auth endpoint URI *MAY* contain a query string, which must be retained + std::wstring result{ authEndpoint.RawUri() }; + if (authEndpoint.Query().empty()) + { + result += L"?state="; + } + else + { + result += L"&state="; + } + + result += Uri::EscapeComponent(m_state); + + if (!m_responseType.empty()) + { + result += L"&response_type="; + result += Uri::EscapeComponent(m_responseType); + } + + if (!m_clientId.empty()) + { + result += L"&client_id="; + result += Uri::EscapeComponent(m_clientId); + } + + if (m_redirectUri) + { + result += L"&redirect_uri="; + result += Uri::EscapeComponent(m_redirectUri.RawUri()); + } + + if (!m_scope.empty()) + { + result += L"&scope="; + result += Uri::EscapeComponent(m_scope); + } + + if (m_codeChallengeMethod == CodeChallengeMethodKind::S256) + { + result += L"&code_challenge_method=S256&code_challenge="; + result += base64urlencode(sha256(m_codeChallenge)); + } + else if (m_codeChallengeMethod == CodeChallengeMethodKind::Plain) + { + result += L"&code_challenge_method=plain&code_challenge="; + result += Uri::EscapeComponent(m_codeChallenge); + } + + if (m_additionalParams) + { + for (auto&& pair : IMap{ *m_additionalParams }) + { + result += L"&"; + result += Uri::EscapeComponent(pair.Key()); + result += L"="; + result += Uri::EscapeComponent(pair.Value()); + } + } + + return result; + } +} diff --git a/dev/OAuth/AuthRequestParams.h b/dev/OAuth/AuthRequestParams.h new file mode 100644 index 0000000000..6cb3ba67a4 --- /dev/null +++ b/dev/OAuth/AuthRequestParams.h @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +#include + +#include "LockableMap.h" + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct AuthRequestParams : AuthRequestParamsT + { + AuthRequestParams(const winrt::hstring& responseType, const winrt::hstring& clientId); + AuthRequestParams(const winrt::hstring& responseType, const winrt::hstring& clientId, + const foundation::Uri& redirectUri); + + static oauth::AuthRequestParams CreateForAuthorizationCodeRequest(const winrt::hstring& clientId); + static oauth::AuthRequestParams CreateForAuthorizationCodeRequest(const winrt::hstring& clientId, + const foundation::Uri& redirectUri); + static oauth::AuthRequestParams CreateForImplicitRequest(const winrt::hstring& clientId); + static oauth::AuthRequestParams CreateForImplicitRequest(const winrt::hstring& clientId, + const foundation::Uri& redirectUri); + + // Interface functions + winrt::hstring ResponseType(); + void ResponseType(const winrt::hstring& value); + winrt::hstring ClientId(); + void ClientId(const winrt::hstring& value); + foundation::Uri RedirectUri(); + void RedirectUri(const foundation::Uri& value); + winrt::hstring State(); + void State(const winrt::hstring& value); + winrt::hstring Scope(); + void Scope(const winrt::hstring& value); + winrt::hstring CodeChallenge(); + void CodeChallenge(const winrt::hstring& value); + oauth::CodeChallengeMethodKind CodeChallengeMethod(); + void CodeChallengeMethod(oauth::CodeChallengeMethodKind value); + collections::IMap AdditionalParams(); + + // Implementation functions + void finalize(); + void set_state(winrt::hstring value); + void set_code_challenge(winrt::hstring value); + std::wstring create_url(const foundation::Uri& authEndpoint); + + private: + void check_not_finalized() + { + // NOTE: Lock should be held when calling + if (m_finalized) + { + throw winrt::hresult_illegal_method_call( + L"AuthRequestParams object cannot be modified after being used to initiate a request"); + } + } + + std::shared_mutex m_mutex; + bool m_finalized = false; + winrt::hstring m_responseType; + winrt::hstring m_clientId; + foundation::Uri m_redirectUri{ nullptr }; + winrt::hstring m_state; + winrt::hstring m_scope; + winrt::hstring m_codeChallenge; + oauth::CodeChallengeMethodKind m_codeChallengeMethod = oauth::CodeChallengeMethodKind::None; + winrt::com_ptr> m_additionalParams = + winrt::make_self>(); + }; +} + +namespace winrt::Microsoft::Security::Authentication::OAuth::factory_implementation +{ + struct AuthRequestParams : AuthRequestParamsT + { + }; +} diff --git a/dev/OAuth/AuthRequestResult.cpp b/dev/OAuth/AuthRequestResult.cpp new file mode 100644 index 0000000000..075f8f7a56 --- /dev/null +++ b/dev/OAuth/AuthRequestResult.cpp @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "AuthFailure.h" +#include "AuthRequestResult.h" +#include "AuthResponse.h" + +#include + +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + AuthRequestResult::AuthRequestResult(AuthRequestParams* params, const Uri& responseUri) : m_responseUri(responseUri) + { + // We first need to figure out if this is a success or failure response + bool isError = false; + bool isSuccess = false; + auto checkComponent = [&](const winrt::hstring& str) { + if (str.empty()) + { + return; // Avoid unnecessary construction/activation + } + + for (auto&& entry : WwwFormUrlDecoder(str)) + { + auto name = entry.Name(); + if ((name == L"code") || (name == L"access_token")) + { + isSuccess = true; + break; + } + else if (name == L"error") + { + isError = true; + break; + } + } + }; + + checkComponent(responseUri.Query()); + if (!isError && !isSuccess) + { + checkComponent(fragment_component(responseUri)); + } + + // If we don't recognize the response as an error, interpret it as success. The application may be using an + // extension that we don't recognize + if (isError) + { + m_failure = winrt::make(m_responseUri); + } + else + { + m_response = winrt::make(params, m_responseUri); + } + } + + Uri AuthRequestResult::ResponseUri() + { + return m_responseUri; + } + + oauth::AuthResponse AuthRequestResult::Response() + { + return m_response; + } + + oauth::AuthFailure AuthRequestResult::Failure() + { + return m_failure; + } +} diff --git a/dev/OAuth/AuthRequestResult.h b/dev/OAuth/AuthRequestResult.h new file mode 100644 index 0000000000..f0a42eca60 --- /dev/null +++ b/dev/OAuth/AuthRequestResult.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +#include "AuthRequestParams.h" + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct AuthRequestResult : AuthRequestResultT + { + AuthRequestResult(AuthRequestParams* params, const foundation::Uri& responseUri); + + foundation::Uri ResponseUri(); + oauth::AuthResponse Response(); + oauth::AuthFailure Failure(); + + private: + foundation::Uri m_responseUri; + oauth::AuthResponse m_response{ nullptr }; + oauth::AuthFailure m_failure{ nullptr }; + }; +} diff --git a/dev/OAuth/AuthResponse.cpp b/dev/OAuth/AuthResponse.cpp new file mode 100644 index 0000000000..d338ee7cd0 --- /dev/null +++ b/dev/OAuth/AuthResponse.cpp @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "AuthResponse.h" + +#include + +using namespace std::literals; +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + AuthResponse::AuthResponse(AuthRequestParams* requestParams, const Uri& responseUri) : + m_requestParams(requestParams->get_strong()) + { + std::map additionalParams; + auto parseComponents = [&](const winrt::hstring& str) { + if (str.empty()) + { + return; // Avoid unnecessary construction/activation + } + + for (auto&& entry : WwwFormUrlDecoder(str)) + { + auto name = entry.Name(); + if (name == L"state"sv) + { + m_state = entry.Value(); + } + else if (name == L"code"sv) + { + m_code = entry.Value(); + } + else if (name == L"access_token"sv) + { + m_accessToken = entry.Value(); + } + else if (name == L"token_type"sv) + { + m_tokenType = entry.Value(); + } + else if (name == L"expires_in"sv) + { + m_expiresIn = entry.Value(); + } + else if (name == L"scope"sv) + { + m_scope = entry.Value(); + } + else + { + additionalParams.emplace(std::move(name), entry.Value()); + } + } + }; + + parseComponents(responseUri.Query()); + parseComponents(fragment_component(responseUri)); + + m_additionalParams = winrt::single_threaded_map(std::move(additionalParams)).GetView(); + } + + winrt::hstring AuthResponse::State() + { + return m_state; + } + + winrt::hstring AuthResponse::Code() + { + return m_code; + } + + winrt::hstring AuthResponse::AccessToken() + { + return m_accessToken; + } + + winrt::hstring AuthResponse::TokenType() + { + return m_tokenType; + } + + winrt::hstring AuthResponse::ExpiresIn() + { + return m_expiresIn; + } + + winrt::hstring AuthResponse::Scope() + { + return m_scope; + } + + IMapView AuthResponse::AdditionalParams() + { + return m_additionalParams; + } +} diff --git a/dev/OAuth/AuthResponse.h b/dev/OAuth/AuthResponse.h new file mode 100644 index 0000000000..8c53755cc5 --- /dev/null +++ b/dev/OAuth/AuthResponse.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +#include "AuthRequestParams.h" + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct AuthResponse : AuthResponseT + { + AuthResponse(AuthRequestParams* params, const foundation::Uri& responseUri); + + winrt::hstring State(); + winrt::hstring Code(); + winrt::hstring AccessToken(); + winrt::hstring TokenType(); + winrt::hstring ExpiresIn(); + winrt::hstring Scope(); + collections::IMapView AdditionalParams(); + + // Implementation functions + const winrt::com_ptr& request_params() const noexcept + { + return m_requestParams; + } + + private: + winrt::com_ptr m_requestParams; + + winrt::hstring m_state; + winrt::hstring m_code; + winrt::hstring m_accessToken; + winrt::hstring m_tokenType; + winrt::hstring m_expiresIn; + winrt::hstring m_scope; + collections::IMapView m_additionalParams; + }; +} diff --git a/dev/OAuth/ClientAuthentication.cpp b/dev/OAuth/ClientAuthentication.cpp new file mode 100644 index 0000000000..9e1dddd74e --- /dev/null +++ b/dev/OAuth/ClientAuthentication.cpp @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "ClientAuthentication.h" + +#include + +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; +using namespace winrt::Windows::Security::Cryptography; +using namespace winrt::Windows::Web::Http::Headers; + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + ClientAuthentication::ClientAuthentication() + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + } + + ClientAuthentication::ClientAuthentication(const HttpCredentialsHeaderValue& authorization) : + m_authorization(authorization) + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + } + + oauth::ClientAuthentication ClientAuthentication::CreateForBasicAuthorization(const winrt::hstring& clientId, + const winrt::hstring& clientSecret) + { + auto authString = clientId + L":" + clientSecret; + auto buffer = CryptographicBuffer::ConvertStringToBinary(authString, BinaryStringEncoding::Utf8); + auto base64Token = CryptographicBuffer::EncodeToBase64String(buffer); + HttpCredentialsHeaderValue header(L"Basic", base64Token); + return winrt::make(header); + } + + HttpCredentialsHeaderValue ClientAuthentication::Authorization() + { + std::shared_lock guard{ m_mutex }; + return m_authorization; + } + + void ClientAuthentication::Authorization(const HttpCredentialsHeaderValue& value) + { + std::lock_guard guard{ m_mutex }; + m_authorization = value; + } + + HttpCredentialsHeaderValue ClientAuthentication::ProxyAuthorization() + { + std::shared_lock guard{ m_mutex }; + return m_proxyAuthorization; + } + + void ClientAuthentication::ProxyAuthorization(const HttpCredentialsHeaderValue& value) + { + std::lock_guard guard{ m_mutex }; + m_proxyAuthorization = value; + } + + winrt::Windows::Foundation::Collections::IMap ClientAuthentication::AdditionalHeaders() + { + std::shared_lock guard{ m_mutex }; + return m_additionalHeaders; + } +} diff --git a/dev/OAuth/ClientAuthentication.h b/dev/OAuth/ClientAuthentication.h new file mode 100644 index 0000000000..0c40e1f40e --- /dev/null +++ b/dev/OAuth/ClientAuthentication.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +#include + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct ClientAuthentication : ClientAuthenticationT + { + ClientAuthentication(); + ClientAuthentication(http::Headers::HttpCredentialsHeaderValue const& authorization); + + static oauth::ClientAuthentication CreateForBasicAuthorization(const winrt::hstring& clientId, + const winrt::hstring& clientSecret); + + http::Headers::HttpCredentialsHeaderValue Authorization(); + void Authorization(http::Headers::HttpCredentialsHeaderValue const& value); + http::Headers::HttpCredentialsHeaderValue ProxyAuthorization(); + void ProxyAuthorization(http::Headers::HttpCredentialsHeaderValue const& value); + collections::IMap AdditionalHeaders(); + + private: + std::shared_mutex m_mutex; + http::Headers::HttpCredentialsHeaderValue m_authorization{ nullptr }; + http::Headers::HttpCredentialsHeaderValue m_proxyAuthorization{ nullptr }; + collections::IMap m_additionalHeaders = + winrt::multi_threaded_map(); + }; +} +namespace winrt::Microsoft::Security::Authentication::OAuth::factory_implementation +{ + struct ClientAuthentication : ClientAuthenticationT + { + }; +} diff --git a/dev/OAuth/Crypto.h b/dev/OAuth/Crypto.h new file mode 100644 index 0000000000..4f83a2e4f4 --- /dev/null +++ b/dev/OAuth/Crypto.h @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +// Helpers using the cryptographic APIs +#pragma once + +#include + +// This function is for encoding binary data into a format that can be safely included in URLs and other web contexts +inline std::wstring base64urlencode(const streams::IBuffer& buffer) +{ + using namespace winrt::Windows::Security::Cryptography; + + std::wstring result = CryptographicBuffer::EncodeToBase64String(buffer).c_str(); + std::replace(result.begin(), result.end(), '+', '-'); + std::replace(result.begin(), result.end(), '/', '_'); + result.erase(std::remove(result.begin(), result.end(), '='), result.end()); + + return result; +} + +inline std::wstring random_base64urlencoded_string(std::uint32_t octets) +{ + using namespace winrt::Windows::Security::Cryptography; + auto buffer = CryptographicBuffer::GenerateRandom(octets); + return base64urlencode(buffer); +} + +// This function computes the SHA-256 hash of a given text string +inline streams::IBuffer sha256(const winrt::hstring& text, + crypto::BinaryStringEncoding encoding = crypto::BinaryStringEncoding::Utf8) +{ + using namespace winrt::Windows::Security::Cryptography; + using namespace winrt::Windows::Security::Cryptography::Core; + + auto algo = HashAlgorithmProvider::OpenAlgorithm(HashAlgorithmNames::Sha256()); + return algo.HashData(CryptographicBuffer::ConvertStringToBinary(text, encoding)); +} + +// This function computes the SHA-256 hash of a given text string and then encodes the result as a base64 string +inline winrt::hstring sha256_base64encoded(const winrt::hstring& text) +{ + auto buffer = sha256(text); + return crypto::CryptographicBuffer::EncodeToBase64String(buffer); +} + +inline std::wstring request_pipe_name(const winrt::hstring& state) +{ + // In order to try and protect the state and auth code, we use a hash of the state value for the pipe name + std::wstring result = LR"^-^(\\.\pipe\oauth\)^-^"; + result += sha256_base64encoded(state); + return result; +} + +// The create_key function generates a symmetric cryptographic key from a given string.This key is used for both encryption and decryption. +inline crypto::Core::CryptographicKey create_key(const winrt::hstring& keyString) +{ + using namespace winrt::Windows::Security::Cryptography; + using namespace winrt::Windows::Security::Cryptography::Core; + using namespace winrt::Windows::Storage::Streams; + + WINRT_ASSERT(!keyString.empty()); + + // AES key must be 128, 192, or 256 bits (16, 24, or 32 bytes). Note that the key doesn't have to make a valid + // string. If we end up slicing a UTF-8 character, that's okay + auto keyBuffer = CryptographicBuffer::ConvertStringToBinary(keyString, BinaryStringEncoding::Utf8); + auto keyBufferBegin = keyBuffer.data(); + auto keyBufferEnd = keyBufferBegin + keyBuffer.Length(); + + // Repeat the key string as necessary to achieve the desired length + std::vector buffer(keyBufferBegin, keyBufferEnd); + auto desiredSize = (buffer.size() <= 16) ? 16 : (buffer.size() <= 24) ? 24 : 32; + while (buffer.size() < desiredSize) + { + buffer.insert(buffer.end(), keyBufferBegin, keyBufferEnd); + } + buffer.resize(desiredSize); + + auto algo = SymmetricKeyAlgorithmProvider::OpenAlgorithm(SymmetricAlgorithmNames::AesEcbPkcs7()); + return algo.CreateSymmetricKey(CryptographicBuffer::CreateFromByteArray(buffer)); +} + +// The encrypt function encrypts a given message using a specified key string. +inline streams::IBuffer encrypt(const winrt::hstring& message, const winrt::hstring& keyString) +{ + using namespace winrt::Windows::Security::Cryptography; + using namespace winrt::Windows::Security::Cryptography::Core; + + auto msgBuffer = CryptographicBuffer::ConvertStringToBinary(message, BinaryStringEncoding::Utf8); + auto key = create_key(keyString); + return CryptographicEngine::Encrypt(key, msgBuffer, nullptr); +} + +// The decrypt function decrypts a given encrypted buffer using a specified key string. +inline winrt::hstring decrypt(const streams::IBuffer& encryptedBuffer, const winrt::hstring& keyString) +{ + using namespace winrt::Windows::Security::Cryptography; + using namespace winrt::Windows::Security::Cryptography::Core; + + auto key = create_key(keyString); + auto decryptedBuffer = CryptographicEngine::Decrypt(key, encryptedBuffer, nullptr); + return CryptographicBuffer::ConvertBinaryToString(BinaryStringEncoding::Utf8, decryptedBuffer); +} diff --git a/dev/OAuth/LockableMap.h b/dev/OAuth/LockableMap.h new file mode 100644 index 0000000000..e8b15315f9 --- /dev/null +++ b/dev/OAuth/LockableMap.h @@ -0,0 +1,303 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once + +#include +#include + +namespace impl +{ + template + inline T default_value() + { + if constexpr (std::is_constructible_v) + { + // Handles classes where we'd otherwise get activation + return T{ nullptr }; + } + else + { + return T{}; + } + } + + template + struct KeyValuePair : winrt::implements, collections::IKeyValuePair> + { + KeyValuePair(KeyT key, ValueT value) : + m_key(std::move(key)), + m_value(std::move(value)) + { + } + + KeyT Key() + { + return m_key; + } + + ValueT Value() + { + return m_value; + } + + private: + + KeyT m_key; + ValueT m_value; + }; + + template + struct LockableMapIterator : winrt::implements, + collections::IIterator>> + { + LockableMapIterator(winrt::com_ptr map, std::size_t version) : m_map(std::move(map)), m_version(version) + { + m_itr = m_map->m_map.begin(); + } + + // IIterator + collections::IKeyValuePair Current() + { + std::shared_lock guard{ m_map->m_mutex }; + check_version(); + if (m_itr == m_map->m_map.end()) + { + throw winrt::hresult_out_of_bounds(); + } + + return winrt::make>(m_itr->first, m_itr->second); + } + + bool HasCurrent() + { + std::shared_lock guard{ m_map->m_mutex }; + check_version(); + return m_itr != m_map->m_map.end(); + } + + std::uint32_t GetMany(winrt::array_view> items) + { + std::shared_lock guard{ m_map->m_mutex }; + check_version(); + + auto end = m_map->m_map.end(); + std::uint32_t result = 0; + for (; (m_itr != end) && (result < items.size()); ++m_itr) + { + items[result++] = winrt::make>(m_itr->first, m_itr->second); + } + + return result; + } + + bool MoveNext() + { + std::shared_lock guard{ m_map->m_mutex }; + check_version(); + + auto end = m_map->m_map.end(); + if (m_itr != end) + { + ++m_itr; + } + + return m_itr != end; + } + + private: + void check_version() + { + if (m_version != m_map->m_version) + { + throw winrt::hresult_changed_state(); + } + } + + winrt::com_ptr m_map; + std::size_t m_version; + typename std::map::const_iterator m_itr; + }; + + template + struct LockableMapView : winrt::implements, + collections::IMapView, + collections::IIterable>> + { + LockableMapView(winrt::com_ptr map, std::size_t version) : m_map(std::move(map)), m_version(version) + { + } + + // IMapView + std::uint32_t Size() + { + std::shared_lock guard{ m_map->m_mutex }; + check_version(); + return static_cast(m_map->m_map.size()); + } + + bool HasKey(const KeyT& key) + { + std::shared_lock guard{ m_map->m_mutex }; + check_version(); + return m_map->m_map.find(key) != m_map->m_map.end(); + } + + ValueT Lookup(const KeyT& key) + { + std::shared_lock guard{ m_map->m_mutex }; + check_version(); + auto itr = m_map->m_map.find(key); + if (itr == m_map->m_map.end()) + { + throw winrt::hresult_out_of_bounds(); + } + + return itr->second; + } + + void Split(collections::IMapView& lhs, collections::IMapView& rhs) + { + // NOTE: Follows C++/WinRT implementation + lhs = nullptr; + rhs = nullptr; + } + + // IIterable + collections::IIterator> First() + { + std::shared_lock guard{ m_map->m_mutex }; + return winrt::make>(m_map, m_version); + } + + private: + void check_version() + { + if (m_version != m_map->m_version) + { + throw winrt::hresult_changed_state(); + } + } + + winrt::com_ptr m_map; + std::size_t m_version; + }; +} + +// Here, "lock" means "prevent further modification." Of course, the objects contained within the map can modified, but +// nothing can be added/removed from the map +template +struct LockableMap : winrt::implements, + collections::IMap, + collections::IIterable>> +{ + friend struct impl::LockableMapIterator; + friend struct impl::LockableMapView; + + // IMap + std::uint32_t Size() + { + std::shared_lock guard{ m_mutex }; + return static_cast(m_map.size()); + } + + void Clear() + { + std::map oldValues; // Release outside of lock + + std::lock_guard guard{ m_mutex }; + check_not_locked(); + m_map.swap(oldValues); + ++m_version; + } + + collections::IMapView GetView() + { + std::shared_lock guard{ m_mutex }; + return winrt::make>(this->get_strong(), m_version); + } + + bool HasKey(const KeyT& value) + { + std::shared_lock guard{ m_mutex }; + return m_map.find(value) != m_map.end(); + } + + bool Insert(const KeyT& key, const ValueT& value) + { + auto removedValue = impl::default_value(); + + std::lock_guard guard{ m_mutex }; + check_not_locked(); + auto [itr, added] = m_map.emplace(key, value); + if (!added) + { + std::swap(removedValue, itr->second); + itr->second = value; + } + ++m_version; + + return !added; + } + + ValueT Lookup(const KeyT& key) + { + std::shared_lock guard{ m_mutex }; + auto itr = m_map.find(key); + if (itr == m_map.end()) + { + throw winrt::hresult_out_of_bounds(); + } + return itr->second; + } + + void Remove(const KeyT& key) + { + typename std::map::node_type node; // Destroy outside of lock + + { + std::lock_guard guard{ m_mutex }; + check_not_locked(); + node = m_map.extract(key); + ++m_version; + } + + if (!node) + { + throw winrt::hresult_out_of_bounds(); + } + } + + // IIterable + collections::IIterator> First() + { + std::shared_lock guard{ m_mutex }; + return winrt::make>(this->get_strong(), m_version); + } + + // Implementation Functions + void lock() + { + std::lock_guard guard{ m_mutex }; + if (m_locked) + { + throw winrt::hresult_illegal_method_call(L"Map has already been locked from modification"); + } + + m_locked = true; + } + +private: + void check_not_locked() + { + // NOTE: Lock should be held when calling + if (m_locked) + { + throw winrt::hresult_illegal_method_call(L"Map has been locked from modification"); + } + } + + std::shared_mutex m_mutex; + std::size_t m_version = 0; + bool m_locked = false; + std::map m_map; +}; diff --git a/dev/OAuth/OAuth.idl b/dev/OAuth/OAuth.idl new file mode 100644 index 0000000000..c73522b7c4 --- /dev/null +++ b/dev/OAuth/OAuth.idl @@ -0,0 +1,480 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. + +#include + +namespace Microsoft.Security.Authentication.OAuth +{ + [contractversion(1)] + apicontract OAuthContract {}; + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + runtimeclass ClientAuthentication + { + ClientAuthentication(); + ClientAuthentication(Windows.Web.Http.Headers.HttpCredentialsHeaderValue authorization); + + static ClientAuthentication CreateForBasicAuthorization(String clientId, String clientSecret); + + // Specifies the 'Authorization' header of the HTTP POST request when requesting a token + Windows.Web.Http.Headers.HttpCredentialsHeaderValue Authorization { get; set; }; + + // Specifies the 'Proxy-Authorization' header of the HTTP POST request when requesting a token + Windows.Web.Http.Headers.HttpCredentialsHeaderValue ProxyAuthorization { get; set; }; + + // Specifies additional header values of the HTTP POST request when requesting a token + Windows.Foundation.Collections.IMap AdditionalHeaders { get; }; + } + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + static runtimeclass OAuth2Manager + { + // Initiates an authorization request in the user's default browser as described by RFC 6749 section 3.1. The + // returned 'IAsyncOperation' will remain in the 'Started' state until it is either cancelled or completed by a + // call to 'CompleteAuthRequest'. This performs authorization of response_type="token". + static Windows.Foundation.IAsyncOperation RequestAuthAsync( + Microsoft.UI.WindowId parentWindowId, + Windows.Foundation.Uri completeAuthEndpoint, + Windows.Foundation.Uri redirectUri); + + // Initiates an authorization request in the user's default browser as described by RFC 6749 section 3.1. The + // returned 'IAsyncOperation' will remain in the 'Started' state until it is either cancelled or completed by a + // call to 'CompleteAuthRequest'.This performs authorization of response_type="token". + static Windows.Foundation.IAsyncOperation RequestAuthAsync( + Microsoft.UI.WindowId parentWindowId, + Windows.Foundation.Uri completeAuthEndpoint); + + // Initiates an authorization request in the user's default browser as described by RFC 6749 section 3.1. The + // returned 'IAsyncOperation' will remain in the 'Started' state until it is either cancelled or completed by a + // call to 'CompleteAuthRequest'. + static Windows.Foundation.IAsyncOperation RequestAuthWithParamsAsync( + Microsoft.UI.WindowId parentWindowId, + Windows.Foundation.Uri authEndpoint, + AuthRequestParams params); + + // Called by the application when the user agent completes an auth request via a redirect Uri. Return value is + // true if an appropriate request could be found and completed. Otherwise returns false indicating that the + // response went unhandled and the application may respond as appropriate. + static Boolean CompleteAuthRequest(Windows.Foundation.Uri responseUri); + + // Initiates an access token request as described by RFC 6749 section 3.2. + static Windows.Foundation.IAsyncOperation RequestTokenAsync( + Windows.Foundation.Uri tokenEndpoint, + TokenRequestParams params); + + // Initiates an access token request as described by RFC 6749 section 3.2. + static Windows.Foundation.IAsyncOperation RequestTokenAsync( + Windows.Foundation.Uri tokenEndpoint, + TokenRequestParams params, + ClientAuthentication clientAuth); + } + + // Correlates to the 'code_challenge_method' as described by section 4.3 of RFC 7636: Proof Key for Code Exchange by + // OAuth Public Clients (https://www.rfc-editor.org/rfc/rfc7636.html#section-4.3) + [contract(OAuthContract, 1), feature(Feature_OAuth)] + enum CodeChallengeMethodKind + { + // Suppresses the use of a code verifier. An error will be thrown if a code challenge string is set when this + // option is used + None = 0, + // Challenge method of "S256" (i.e. SHA256). This is the default unless explicitly set + S256 = 1, + // Challenge method of "plain" (i.e. send as plain text) + Plain = 2, + }; + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + runtimeclass AuthRequestParams + { + // Construct with required parameters + AuthRequestParams(String responseType, String clientId); + // Construct with required parameters as well as a redirect URI, which is frequently specified + AuthRequestParams(String responseType, String clientId, Windows.Foundation.Uri redirectUri); + + // Helper method to create for an authorization code grant request ("code" response type) with required + // parameters, per RFC 6749 section 4.1.1. + static AuthRequestParams CreateForAuthorizationCodeRequest(String clientId); + // Helper method to create for an authorization code grant request ("code" response type) with required + // parameters as well as a redirect URI, which is frequently specified. + static AuthRequestParams CreateForAuthorizationCodeRequest(String clientId, Windows.Foundation.Uri redirectUri); + + // Helper method to create for an implicit grant request ("token" response type) with required parameters, per + // RFC 6749 section 4.2.1. + static AuthRequestParams CreateForImplicitRequest(String clientId); + // Helper method to create for an implicit grant request ("token" response type) with required parameters as + // well as a redirect URI, which is frequently specified. + static AuthRequestParams CreateForImplicitRequest(String clientId, Windows.Foundation.Uri redirectUri); + + // Specifies the required "response_type" parameter of the authorization request. This property is initialized + // by the creation function used ("code" for 'CreateForAuthorizationCodeRequest' and "token" for + // 'CreateForImplicitRequest'). + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.1 and 4.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.1 + String ResponseType { get; set; }; + + // Specifies the required "client_id" parameter of the authorization request. This property is initialized by + // the value provided in the creation function call. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.1 and 4.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.1 + String ClientId { get; set; }; + + // Specifies the optional "redirect_uri" parameter of the authorization request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.1 and 4.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.1 + Windows.Foundation.Uri RedirectUri { get; set; }; + + // Specifies the recommended "state" parameter of the authorization request. Note that although this is not + // required by the OAuth standard, a state value will always be set to correlate requests and responses. This + // parameter can be manually specified, in which case it must be globally unique across the entire system, + // otherwise an error will be thrown. It is therefore recommended to let the API select a value for you as it + // will guarantee that a unique value will be used. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.1 and 4.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.1 + String State { get; set; }; + + // Specifies the optional "scope" parameter of the authorization request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.1 and 4.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.1 + String Scope { get; set; }; + + // Used as the PKCE code verifier. Either this value or a hash of this value will be used to specify the + // "code_challenge" parameter of the authorization request, depending on the value of 'CodeChallengeMethod'. If + // this value is not specified and 'CodeChallengeMethod' is not 'None', a random value will be generated for + // this property. The code verifier will persist all the way through to the token request. + // + // Defined by RFC 7636: Proof Key for Code Exchange by OAuth Public Clients, section 4.1 + // https://www.rfc-editor.org/rfc/rfc7636#section-4.1 + String CodeChallenge{ get; set; }; + + // Specifies the optional "code_challenge_method" parameter of the authorization request. For authorization code + // requests, this value defaults to 'S256'. For implicit requests, this value defaults to 'None' and cannot be + // changed. + // + // Defined by RFC 7636: Proof Key for Code Exchange by OAuth Public Clients, section 4.3 + // https://www.rfc-editor.org/rfc/rfc7636#section-4.3 + CodeChallengeMethodKind CodeChallengeMethod { get; set; }; + + // Additional parameters passed along in the query string of the request URL. + Windows.Foundation.Collections.IMap AdditionalParams { get; }; + } + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + runtimeclass AuthResponse + { + // From the "state" parameter of the authorization response. This property will always be set because a state + // value is always sent with the request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.2 and 4.2.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2 + String State { get; }; + + // From the "code" parameter of the authorization response. Set only if the request was an authorization code + // request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.1.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.2 + String Code { get; }; + + // From the "access_token" parameter of the authorization response. Set only if the request was an implicit + // request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.2.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2 + String AccessToken { get; }; + + // From the "token_type" parameter of the authorization response. Set only if the request was an implicit + // request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.2.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2 + String TokenType { get; }; + + // From the "expires_in" parameter of the authorization response. An optional parameter that may be set only if + // the request was an implicit request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.2.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2 + String ExpiresIn { get; }; // TODO: DateTime? + + // From the "scope" parameter of the authorization response. An optional parameter that may be set only if the + // request was an implicit request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.2.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2 + String Scope { get; }; + + // Additional parameters set by the authorization server in the response URI. + Windows.Foundation.Collections.IMapView AdditionalParams { get; }; + } + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + runtimeclass AuthFailure + { + // From the "error" parameter of the error response. The value of this property will map to a well known string + // specified in RFC 6749 sections 4.1.2.1 and 4.2.2.1, or approved extensions. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.2.1 and 4.2.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1 + String Error { get; }; + + // From the "error_description" parameter of the error response. An optional parameter that, when set, provides + // additional human-readable information intended to assist the developer in understanding the error. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.2.1 and 4.2.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1 + String ErrorDescription { get; }; + + // From the "error_uri" parameter of the error response. An optional parameter that, when set, specifies a URI + // identifying a human-readable webpage intended to assist the developer in understanding the error. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.2.1 and 4.2.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1 + Windows.Foundation.Uri ErrorUri { get; }; + + // From the "state" parameter of the error response. This property will always be set because a state value is + // always sent with the request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.2.1 and 4.2.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.2.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1 + String State { get; }; + + // Additional parameters set by the authorization server in the response URI. + Windows.Foundation.Collections.IMapView AdditionalParams { get; }; + } + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + runtimeclass AuthRequestResult + { + // The raw URI that was used to complete the request. + Windows.Foundation.Uri ResponseUri { get; }; + + // Non-null if the server's response indicates success, otherwise null + AuthResponse Response { get; }; + + // Non-null if the server's response indicates failure, otherwise null + AuthFailure Failure { get; }; + } + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + runtimeclass TokenRequestParams + { + // Construct with required parameters + TokenRequestParams(String grantType); + + // Helper method to create for an authorization code grant request ("authorization_code" grant type), + // initialized with the required parameters extracted from the authorization response, per RFC 6749 section + // 4.1.3. + static TokenRequestParams CreateForAuthorizationCodeRequest(AuthResponse authResponse); + + // Helper method to create for a resource owner password credentials grant request ("password" grant type), + // initialized with the required parameters, per RFC 6749 section 4.3.2. + static TokenRequestParams CreateForResourceOwnerPasswordCredentials(String username, String password); + + // Helper method to create for a client credentials grant request ("client_credentials" grant type), initialized + // with the required parameters, per RFC 6749 section 4.4.2. + static TokenRequestParams CreateForClientCredentials(); + + // Helper method to create for an extension grant request, using the provided URI for the grant type, per RFC + // 6749 section 4.5. + static TokenRequestParams CreateForExtension(Windows.Foundation.Uri extensionUri); + + // Helper method to create for an access token refresh request ("refresh_token" grant type), initialized with + // the required parameters, per RFC 6749 section 6. + static TokenRequestParams CreateForRefreshToken(String refreshToken); + + // Specifies the required "grant_type" parameter of the token request. This property is initialized by the + // creation function used ("authorization_code" for 'CreateForAuthorizationCodeRequest', "password" for + // 'CreateForResourceOwnerPasswordCredentials', "client_credentials" for 'CreateForClientCredentials', + // "refresh_token" for 'CreateForRefreshToken', or the specified URI for 'CreateForExtension'). + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.1.3, 4.3.2, 4.4.2, 4.5, and 6 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.3 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.3.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.4.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.5 + // https://www.rfc-editor.org/rfc/rfc6749#section-6 + String GrantType { get; set; }; + + // Specifies the "code" parameter of the token request. This property is required when the grant type is + // "authorization_code" and is initialized by 'CreateForAuthorizationCodeRequest'. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.1.3 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.3 + String Code { get; set; }; + + // Specifies the "redirect_uri" parameter of the token request. This property is required when the grant type is + // "authorization_code" and a redirect URI was included in the authorization request. This property is + // initialized by 'CreateForAuthorizationCodeRequest'. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.1.3 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.3 + Windows.Foundation.Uri RedirectUri { get; set; }; + + // Specifies the "code_verifier" parameter of the token request. This property is required when the grant type + // is "authorization_code" and a code challenge was included in the authorization request. This property is + // initialized by 'CreateForAuthorizationCodeRequest'. + // + // Defined by RFC 7636: Proof Key for Code Exchange by OAuth Public Clients, section 4.5 + // https://www.rfc-editor.org/rfc/rfc7636#section-4.5 + String CodeVerifier { get; set; }; + + // Specifies the "client_id" parameter of the token request. This property is required when the grant type is + // "authorization_code" and no alternative client authentication is specified. This property is initiated by + // 'CreateForAuthorizationCodeRequest'. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.1.3 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.1.3 + String ClientId { get; set; }; + + // Specifies the "username" parameter of the token request. This property is required when the grant type is + // "password" and is initialized by 'CreateForResourceOwnerPasswordCredentials'. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.3.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.3.2 + String Username { get; set; }; + + // Specifies the "password" parameter of the token request. This property is required when the grant type is + // "password" and is initialized by 'CreateForResourceOwnerPasswordCredentials'. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 4.3.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.3.2 + String Password { get; set; }; + + // Specifies the "scope" parameter of the token request. This property is valid only when the grant type is + // "password", "client_credentials", or "refresh_token". + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, sections 4.3.2, 4.4.2, and 6 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.3.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-4.4.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-6 + String Scope { get; set; }; + + // Specifies the "refresh_token" parameter of the token request. This property is required when the grant type + // is "refresh_token" and is initialized by 'CreateForRefreshToken'. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 6 + // https://www.rfc-editor.org/rfc/rfc6749#section-6 + String RefreshToken { get; set; }; + + // Additional parameters passed along in the HTTP request entity-body. + Windows.Foundation.Collections.IMap AdditionalParams { get; }; + } + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + runtimeclass TokenResponse + { + // From the "access_token" parameter of the token response. A required property that should always be set. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 5.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-5.1 + String AccessToken { get; }; + + // From the "token_type" parameter of the token response. A required property that should always be set. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 5.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-5.1 + String TokenType { get; }; + + // From the "expires_in" parameter of the token response. An optional property that, when set, specifies the + // lifetime of the access token in seconds. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 5.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-5.1 + Double ExpiresIn { get; }; // TODO: DateTime? + + // From the "refresh_token" parameter of the token response. An optional property that, when set, can be used to + // obtain new access tokens using the same authorization grant provided during the request. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 5.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-5.1 + String RefreshToken { get; }; + + // From the "scope" parameter of the token response. An optional property that, when set, describes the scope of + // the access token issued by the authorization server. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 5.1 + // https://www.rfc-editor.org/rfc/rfc6749#section-5.1 + String Scope { get; }; + + // Additional parameters set by the authorization server in the token response. + Windows.Foundation.Collections.IMapView AdditionalParams { get; }; + } + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + enum TokenFailureKind + { + // The server responded with an error response as described by RFC 6749 section 5.2. This means that the failure + // object has an 'Error' string and possibly other specified properties. + ErrorResponse = 0, + + // The HTTP POST request failed. See the 'ErrorCode' property for more details as to why. + HttpFailure = 1, + + // The server responded, but its response was improperly formatted. This could be that the server did not send + // the response as JSON, the response JSON string was improperly formatted, or the response JSON contained + // unexpected object types (e.g. a number when a string is expected, etc.). + InvalidResponse = 2, + }; + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + runtimeclass TokenFailure + { + // Indicates the type of failure that this object describes, which will indicate which properties might be set. + TokenFailureKind Kind { get; }; + + // If 'Kind' was anything other than 'ErrorResponse', + HRESULT ErrorCode { get; }; + + // From the "error" parameter of the error response. The value of this property will map to a well known string + // specified in RFC 6749 section 5.2, or approved extensions. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 5.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + String Error { get; }; + + // From the "error_description" parameter of the error response. An optional parameter that, when set, provides + // additional human-readable information intended to assist the developer in understanding the error. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 5.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + String ErrorDescription { get; }; + + // From the "error_uri" parameter of the error response. An optional parameter that, when set, specifies a URI + // identifying a human-readable webpage intended to assist the developer in understanding the error. + // + // Defined by RFC 6749: The OAuth 2.0 Authorization Framework, section 5.2 + // https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + Windows.Foundation.Uri ErrorUri { get; }; + + // Additional parameters set by the authorization server in the token response. + Windows.Foundation.Collections.IMapView AdditionalParams { get; }; + } + + [contract(OAuthContract, 1), feature(Feature_OAuth)] + runtimeclass TokenRequestResult + { + // The raw HTTP response that was used to complete the request + Windows.Web.Http.HttpResponseMessage ResponseMessage { get; }; + + // Non-null if the server's response indicates success, otherwise null + TokenResponse Response { get; }; + + // Non-null if the server's response indicates failure, otherwise null + TokenFailure Failure { get; }; + } +} diff --git a/dev/OAuth/OAuth.vcxitems b/dev/OAuth/OAuth.vcxitems new file mode 100644 index 0000000000..de6bd8a280 --- /dev/null +++ b/dev/OAuth/OAuth.vcxitems @@ -0,0 +1,89 @@ + + + + $(MSBuildAllProjects);$(MSBuildThisFileFullPath) + true + {3E7FD510-8B66-40E7-A80B-780CB8972F83} + + + + %(AdditionalIncludeDirectories);$(MSBuildThisFileDirectory) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/dev/OAuth/OAuth2Manager.cpp b/dev/OAuth/OAuth2Manager.cpp new file mode 100644 index 0000000000..4ae3e16018 --- /dev/null +++ b/dev/OAuth/OAuth2Manager.cpp @@ -0,0 +1,455 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" +#include "windows.h" +#include + +#include "OAuth2Manager.h" +#include "AuthRequestParams.h" +#include "TokenFailure.h" +#include "TokenRequestParams.h" +#include "TokenRequestResult.h" +#include "TokenResponse.h" +#include "OAuth2ManagerTelemetry.h" + +#include + +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Data::Json; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; +using namespace winrt::Windows::Web::Http; + +namespace winrt::Microsoft::Security::Authentication::OAuth::factory_implementation +{ + IAsyncOperation OAuth2Manager::RequestAuthAsync(winrt::Microsoft::UI::WindowId const& parentWindowId, + const Uri& completeAuthEndpoint, + const Uri& redirectUri) + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + + bool isAppPackaged = m_telemetryHelper.IsPackagedApp(); + PCWSTR appName = m_telemetryHelper.GetAppName().c_str(); + OAuth2ManagerTelemetry::RequestAuthAsyncTriggered(isAppPackaged, appName, true); + + winrt::hstring state; + auto asyncOp = winrt::make_self(state); + + { + std::lock_guard guard{ m_mutex }; + m_pendingAuthRequests.push_back(AuthRequestState{ state, asyncOp }); + } + + try + { + // Pipe server has been successfully set up. Initiate the launch + auto url = create_implicit_url(completeAuthEndpoint, state, redirectUri); + + // Launch browser + execute_shell(parentWindowId, url); + } + catch (...) + { + try_remove(asyncOp.get()); + throw; + } + + return *asyncOp; + } + + IAsyncOperation OAuth2Manager::RequestAuthAsync(winrt::Microsoft::UI::WindowId const& parentWindowId, + const Uri& completeAuthEndpoint) + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + + + bool isAppPackaged = m_telemetryHelper.IsPackagedApp(); + PCWSTR appName = m_telemetryHelper.GetAppName().c_str(); + OAuth2ManagerTelemetry::RequestAuthAsyncTriggered(isAppPackaged, appName, false); + + winrt::hstring state; + auto asyncOp = winrt::make_self(state); + + { + std::lock_guard guard{ m_mutex }; + m_pendingAuthRequests.push_back(AuthRequestState{ state, asyncOp }); + } + + try + { + // Pipe server has been successfully set up. Initiate the launch + auto url = create_implicit_url(completeAuthEndpoint, state, nullptr); + + // Launch browser + execute_shell(parentWindowId, url); + } + catch (...) + { + try_remove(asyncOp.get()); + throw; + } + return *asyncOp; + } + + IAsyncOperation OAuth2Manager::RequestAuthWithParamsAsync(winrt::Microsoft::UI::WindowId const& parentWindowId, + const Uri& authEndpoint, + const oauth::AuthRequestParams& params) + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + + + bool isAppPackaged = m_telemetryHelper.IsPackagedApp(); + PCWSTR appName = m_telemetryHelper.GetAppName().c_str(); + + auto paramsImpl = winrt::get_self(params); + auto asyncOp = winrt::make_self(paramsImpl); + OAuth2ManagerTelemetry::RequestAuthWithParamsAsyncTriggered(isAppPackaged, appName, paramsImpl->ResponseType().c_str()); + + + { + std::lock_guard guard{ m_mutex }; + m_pendingAuthRequests.push_back(AuthRequestState{ params.State(), asyncOp }); + } + + try + { + // Pipe server has been successfully set up. Initiate the launch + auto url = paramsImpl->create_url(authEndpoint); + + // Launch browser + execute_shell(parentWindowId, url); + } + catch (...) + { + try_remove(asyncOp.get()); + throw; + } + + return *asyncOp; + } + + bool OAuth2Manager::CompleteAuthRequest(const Uri& responseUri) + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + + bool isAppPackaged = m_telemetryHelper.IsPackagedApp(); + PCWSTR appName = m_telemetryHelper.GetAppName().c_str(); + OAuth2ManagerTelemetry::CompleteAuthRequestTriggered(isAppPackaged, appName); + // We need to extract the state in order to find the original request + winrt::hstring state; + auto tryFindState = [&](const winrt::hstring& str) + { + if (str.empty()) + { + return; // Avoid unnecessary construction/activation + } + + for (auto&& entry : WwwFormUrlDecoder(str)) + { + if (entry.Name() == L"state") + { + state = entry.Value(); + break; + } + } + }; + + tryFindState(responseUri.Query()); + if (state.empty()) + { + tryFindState(fragment_component(responseUri)); + + // Don't throw an error. It could be the case that the application just blindly calls this function first + if (state.empty()) + { + return false; + } + } + + // First check in our local pending list + if (try_complete_local(state, responseUri)) + { + return true; + } + + // Not found locally; we need to check to see if the request originated in another process + auto pipeName = request_pipe_name(state); + + // We encrypt the URI using the state as the key. This accomplishes a couple things: (1) it helps protect the + // server from another process attaching and sending bogus data, and (2) it helps protect against sending the + // authorization grant information to the wrong client. Both of these points of course become moot if the bad + // party intercepts the state value, and because the state value is somewhat exposed through the browser launch/ + // URL, these steps are intended more as a defense in depth. Other features such as PKCE should be used to + // ensure that codes/tokens are safe in the event that the state is compromised. + auto encryptedUri = encrypt(responseUri.RawUri(), state); + + // When we create the named pipe, we only allow a single pipe instance. This should be fine under normal + // circumstances, however it might be the case that another process attaches to the pipe. This may be + // innocuous - e.g. the browser did multiple redirects - or it could be a bad actor - e.g. a process sending + // random garbage to any pipe it can open or another process specifically targeting oauth. Therefore we make + // multiple attempts to connect to the pipe + wil::unique_handle pipe; + while (true) // TODO: Bound this? Need to remember to return false if we do + { + pipe.reset(::CreateFileW(pipeName.c_str(), GENERIC_WRITE, 0, nullptr, OPEN_EXISTING, 0, nullptr)); + if (pipe) break; + + if (auto err = ::GetLastError(); err != ERROR_PIPE_BUSY) + { + // The pipe no longer exist; e.g. flow already completed, client cancelled, etc. + return false; + } + + if (!::WaitNamedPipeW(pipeName.c_str(), 100)) + { + // 100ms should be enough time to wrap up any business. So either the system is bogged down (perhaps too + // many requests to open the pipe), the pipe was closed, or the pipe was closed and opened by another + // process who isn't being responsive. + return false; + } + } + + ULONG serverPid = 0; + if (::GetNamedPipeServerProcessId(pipe.get(), &serverPid)) + { + ::AllowSetForegroundWindow(serverPid); + + // TODO: We can also possibly verify other things about the server process (exe path, etc.) + } + + DWORD bytesToWrite = encryptedUri.Length(); + DWORD bytesWritten = 0; + if (!::WriteFile(pipe.get(), encryptedUri.data(), bytesToWrite, &bytesWritten, nullptr) || + (bytesWritten != bytesToWrite)) + { + // TODO: Actual error? This could be because the server timed us out... + return false; + } + + // The client should have the URI and the operation should be considered handled + return true; + } + + IAsyncOperation OAuth2Manager::RequestTokenAsync(Uri tokenEndpoint, + oauth::TokenRequestParams params) + { + return RequestTokenAsync(std::move(tokenEndpoint), std::move(params), nullptr); + } + + IAsyncOperation OAuth2Manager::RequestTokenAsync(Uri tokenEndpoint, + oauth::TokenRequestParams params, oauth::ClientAuthentication clientAuth) + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + + bool isAppPackaged = m_telemetryHelper.IsPackagedApp(); + PCWSTR appName = m_telemetryHelper.GetAppName().c_str(); + + auto paramsImpl = winrt::get_self(params); + paramsImpl->finalize(); + OAuth2ManagerTelemetry::RequestTokenAsyncTriggered(isAppPackaged, appName, paramsImpl->GrantType().c_str(), clientAuth ? true : false); + HttpResponseMessage response{ nullptr }; + winrt::hstring responseString; + try + { + HttpClient httpClient; + HttpFormUrlEncodedContent content(winrt::single_threaded_map(paramsImpl->params())); + HttpRequestMessage request(HttpMethod::Post(), tokenEndpoint); + request.Content(content); + + auto headers = request.Headers(); + headers.Accept().ParseAdd(L"application/json"); + + if (clientAuth) + { + if (auto auth = clientAuth.Authorization()) + { + headers.Authorization(auth); + } + + if (auto proxyAuth = clientAuth.ProxyAuthorization()) + { + headers.ProxyAuthorization(proxyAuth); + } + + if (auto map = clientAuth.AdditionalHeaders()) + { + for (auto&& pair : map) + { + if (!headers.TryAppendWithoutValidation(pair.Key(), pair.Value())) + { + // TODO? Why might this fail? Throw? + } + } + } + } + + auto lifetime{ get_strong() }; + auto cancellation = co_await winrt::get_cancellation_token(); + cancellation.enable_propagation(); + + response = co_await httpClient.SendRequestAsync(request); + // TODO: Check status code? + if (!response.IsSuccessStatusCode()) + { + auto status = response.StatusCode(); + HRESULT hr = MAKE_HRESULT(SEVERITY_ERROR, FACILITY_HTTP, static_cast(status)); + co_return implementation::TokenRequestResult::MakeFailure(std::move(response), + TokenFailureKind::HttpFailure, hr); + } + + auto responseContentType = response.Content().Headers().ContentType().MediaType(); + if (responseContentType != L"application/json") + { + co_return implementation::TokenRequestResult::MakeFailure(std::move(response), + TokenFailureKind::InvalidResponse, WEB_E_UNSUPPORTED_FORMAT); + } + + responseString = co_await response.Content().ReadAsStringAsync(); + } + catch (...) + { + LOG_CAUGHT_EXCEPTION(); + co_return implementation::TokenRequestResult::MakeFailure(std::move(response), + TokenFailureKind::HttpFailure, winrt::to_hresult()); + } + + JsonObject jsonObject{ nullptr }; + if (!JsonObject::TryParse(responseString, jsonObject)) + { + co_return implementation::TokenRequestResult::MakeFailure(std::move(response), + TokenFailureKind::InvalidResponse, WEB_E_INVALID_JSON_STRING); + } + else + { + try + { + // Determine if it's a success or error response based on the presence of 'error' + if (jsonObject.HasKey(L"error")) + { + auto failure = winrt::make(jsonObject); + co_return winrt::make(std::move(response), nullptr, + std::move(failure)); + } + else + { + auto success = winrt::make(jsonObject); + co_return winrt::make(std::move(response), std::move(success), + nullptr); + } + } + catch (...) + { + LOG_CAUGHT_EXCEPTION(); + co_return implementation::TokenRequestResult::MakeFailure(std::move(response), + TokenFailureKind::InvalidResponse, winrt::to_hresult()); + } + } + } + + bool OAuth2Manager::try_complete_local(const winrt::hstring& state, const foundation::Uri& responseUri) + { + AuthRequestState requestState; + { + std::lock_guard guard{ m_mutex }; + auto itr = std::find_if(m_pendingAuthRequests.begin(), m_pendingAuthRequests.end(), + [&](auto&& entry) { return entry.state == state; }); + + if (itr != m_pendingAuthRequests.end()) + { + requestState = std::move(*itr); + *itr = std::move(m_pendingAuthRequests.back()); + m_pendingAuthRequests.pop_back(); + } + } + + if (requestState.async_op) + { + // Found locally + requestState.async_op->complete(responseUri); + return true; + } + + return false; + } + + void OAuth2Manager::cancel(AuthRequestAsyncOperation* op) + { + auto requestState = try_remove(op); + if (requestState.async_op) + { + requestState.async_op->cancel(); + } + } + + void OAuth2Manager::error(AuthRequestAsyncOperation* op, winrt::hresult hr) + { + auto requestState = try_remove(op); + if (requestState.async_op) + { + requestState.async_op->error(hr); + } + } + + AuthRequestState OAuth2Manager::try_remove(AuthRequestAsyncOperation* op) + { + std::lock_guard guard{ m_mutex }; + auto itr = std::find_if(m_pendingAuthRequests.begin(), m_pendingAuthRequests.end(), + [&](auto&& entry) { return entry.async_op.get() == op; }); + + AuthRequestState result; + if (itr != m_pendingAuthRequests.end()) + { + result = std::move(*itr); + *itr = std::move(m_pendingAuthRequests.back()); + m_pendingAuthRequests.pop_back(); + } + + return result; + } + + std::wstring OAuth2Manager::create_implicit_url(const foundation::Uri& completeAuthEndpoint, const winrt::hstring& state, const foundation::Uri& redirectUri) + { + std::lock_guard guard{ m_mutex }; + // Per RFC 6749 section 3.1, the auth endpoint URI *MAY* contain a query string, which must be retained + std::wstring result{ completeAuthEndpoint.RawUri() }; + if (completeAuthEndpoint.Query().empty()) + { + result += L"?state="; + } + else + { + result += L"&state="; + } + result += Uri::EscapeComponent(state); + result += L"&response_type=token"; + + if (redirectUri) + { + result += L"&redirect_uri="; + result += Uri::EscapeComponent(redirectUri.RawUri()); + } + return result; + } + + void OAuth2Manager::execute_shell(winrt::Microsoft::UI::WindowId const& parentWindowId, const std::wstring& url) + { + // Convert parentWindowId to HWND + HWND hwndParent = reinterpret_cast(parentWindowId.Value); + + SHELLEXECUTEINFO sei = { sizeof(sei) }; + sei.fMask = SEE_MASK_NOCLOSEPROCESS; + sei.hwnd = hwndParent; + sei.lpVerb = L"open"; + sei.lpFile = url.c_str(); + sei.lpParameters = nullptr; + sei.lpDirectory = nullptr; + sei.nShow = SW_SHOWDEFAULT; + sei.hInstApp = nullptr; + + if (!ShellExecuteExW(&sei)) + { + throw winrt::hresult_error(HRESULT_FROM_WIN32(::GetLastError()), L"Failed to launch browser"); + } + } +} diff --git a/dev/OAuth/OAuth2Manager.h b/dev/OAuth/OAuth2Manager.h new file mode 100644 index 0000000000..fc1f7012ac --- /dev/null +++ b/dev/OAuth/OAuth2Manager.h @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +#include "AuthRequestAsyncOperation.h" +#include "TelemetryHelper.h" + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct OAuth2Manager; +} + +namespace winrt::Microsoft::Security::Authentication::OAuth::factory_implementation +{ + struct AuthRequestState + { + winrt::hstring state; + winrt::com_ptr async_op; + }; + + struct OAuth2Manager : OAuth2ManagerT + { + foundation::IAsyncOperation RequestAuthAsync( + winrt::Microsoft::UI::WindowId const& parentWindowId, const foundation::Uri& completeAuthEndpoint, const foundation::Uri& redirectUri); + foundation::IAsyncOperation RequestAuthAsync( + winrt::Microsoft::UI::WindowId const& parentWindowId, const foundation::Uri& completeAuthEndpoint); + foundation::IAsyncOperation RequestAuthWithParamsAsync( + winrt::Microsoft::UI::WindowId const& parentWindowId, const foundation::Uri& authEndpoint, const oauth::AuthRequestParams& params); + bool CompleteAuthRequest(const foundation::Uri& responseUri); + foundation::IAsyncOperation RequestTokenAsync(foundation::Uri tokenEndpoint, + oauth::TokenRequestParams params); + foundation::IAsyncOperation RequestTokenAsync(foundation::Uri tokenEndpoint, + oauth::TokenRequestParams params, oauth::ClientAuthentication clientAuth); + + // Implementation functions + bool try_complete_local(const winrt::hstring& state, const foundation::Uri& responseUri); + void cancel(AuthRequestAsyncOperation* op); + void error(AuthRequestAsyncOperation* op, winrt::hresult hr); + + private: + AuthRequestState try_remove(AuthRequestAsyncOperation* op); + + std::wstring create_implicit_url(const foundation::Uri& completeAuthEndpoint, const winrt::hstring& state, const foundation::Uri& redirectUri); + void execute_shell(winrt::Microsoft::UI::WindowId const& parentWindowId, const std::wstring& url); + std::shared_mutex m_mutex; + TelemetryHelper m_telemetryHelper; + std::vector m_pendingAuthRequests; + }; +} + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct OAuth2Manager + { + static foundation::IAsyncOperation RequestAuthAsync( + winrt::Microsoft::UI::WindowId const& parentWindowId, + foundation::Uri completeAuthEndpoint, foundation::Uri redirectUri) + { + return winrt::make_self()->RequestAuthAsync(parentWindowId, + completeAuthEndpoint, + redirectUri); + } + + static foundation::IAsyncOperation RequestAuthAsync( + winrt::Microsoft::UI::WindowId const& parentWindowId, + foundation::Uri completeAuthEndpoint) + { + return winrt::make_self()->RequestAuthAsync(parentWindowId, + completeAuthEndpoint); + } + + static foundation::IAsyncOperation RequestAuthWithParamsAsync( + winrt::Microsoft::UI::WindowId const& parentWindowId, + foundation::Uri authEndpoint, oauth::AuthRequestParams params) + { + return winrt::make_self()->RequestAuthWithParamsAsync(parentWindowId, + authEndpoint, + params); + } + + static bool CompleteAuthRequest(const foundation::Uri& responseUri) + { + return winrt::make_self()->CompleteAuthRequest(responseUri); + } + + static foundation::IAsyncOperation RequestTokenAsync(foundation::Uri tokenEndpoint, + oauth::TokenRequestParams params) + { + return winrt::make_self()->RequestTokenAsync(std::move(tokenEndpoint), + std::move(params)); + } + + static foundation::IAsyncOperation RequestTokenAsync(foundation::Uri tokenEndpoint, + oauth::TokenRequestParams params, oauth::ClientAuthentication clientAuth) + { + return winrt::make_self()->RequestTokenAsync(std::move(tokenEndpoint), + std::move(params), std::move(clientAuth)); + } + }; +} diff --git a/dev/OAuth/OAuth2ManagerTelemetry.h b/dev/OAuth/OAuth2ManagerTelemetry.h new file mode 100644 index 0000000000..b2e3d3a5cc --- /dev/null +++ b/dev/OAuth/OAuth2ManagerTelemetry.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#pragma once +#include "..\WindowsAppRuntime_Insights\WindowsAppRuntimeInsights.h" + +class OAuth2ManagerTelemetry : public wil::TraceLoggingProvider +{ + IMPLEMENT_TRACELOGGING_CLASS_WITH_MICROSOFT_TELEMETRY(OAuth2ManagerTelemetry, "Microsoft.WindowsAppSDK.OAuth2ManagerTelemetry", + (0x27d8ee3f, 0xd704, 0x45d6, 0xb6, 0x6c, 0x1d, 0xad, 0x95, 0x79, 0x5c, 0xe5)); + //{27d8ee3f-d704-45d6-b66c-1dad95795ce5} +public: + DEFINE_COMPLIANT_MEASURES_EVENT_PARAM3(RequestAuthAsyncTriggered, PDT_ProductAndServicePerformance, + bool, IsAppPackaged, PCWSTR, AppName, bool, IsRedirectURIPassed); + + DEFINE_COMPLIANT_MEASURES_EVENT_PARAM3(RequestAuthWithParamsAsyncTriggered, PDT_ProductAndServicePerformance, + bool, IsAppPackaged, PCWSTR, AppName, PCWSTR, ResponseType); + + DEFINE_COMPLIANT_MEASURES_EVENT_PARAM2(CompleteAuthRequestTriggered, PDT_ProductAndServiceUsage, + bool, IsAppPackaged, PCWSTR, AppName); + DEFINE_COMPLIANT_MEASURES_EVENT_PARAM4(RequestTokenAsyncTriggered, PDT_ProductAndServiceUsage, + bool, IsAppPackaged, PCWSTR, AppName, PCWSTR, GrantType, bool, IsClientAuthPassed); +}; diff --git a/dev/OAuth/TokenFailure.cpp b/dev/OAuth/TokenFailure.cpp new file mode 100644 index 0000000000..dc5665885c --- /dev/null +++ b/dev/OAuth/TokenFailure.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "TokenFailure.h" + +#include + +using namespace std::literals; +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Data::Json; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + TokenFailure::TokenFailure(TokenFailureKind kind, winrt::hresult code) : m_kind(kind), m_errorCode(code) {} + + TokenFailure::TokenFailure(const JsonObject& jsonObject) : + m_kind(TokenFailureKind::ErrorResponse), + m_errorCode(E_FAIL) + { + std::map additionalParams; + + // NOTE: Functions like 'GetString' will throw if the value is not the requested type, so the calling code must + // be ready to handle such failures + for (auto&& pair : jsonObject) + { + auto name = pair.Key(); + if (name == L"error"sv) + { + m_error = pair.Value().GetString(); + // TODO: Use the error string to set a more accurate HRESULT? + } + else if (name == L"error_description"sv) + { + m_errorDescription = pair.Value().GetString(); + } + else if (name == L"error_uri"sv) + { + m_errorUri = Uri(pair.Value().GetString()); + } + else + { + additionalParams.emplace(std::move(name), pair.Value()); + } + } + + m_additionalParams = winrt::single_threaded_map(std::move(additionalParams)).GetView(); + } + + TokenFailureKind TokenFailure::Kind() + { + return m_kind; + } + + winrt::hresult TokenFailure::ErrorCode() + { + return m_errorCode; + } + + winrt::hstring TokenFailure::Error() + { + return m_error; + } + + winrt::hstring TokenFailure::ErrorDescription() + { + return m_errorDescription; + } + + Uri TokenFailure::ErrorUri() + { + return m_errorUri; + } + + IMapView TokenFailure::AdditionalParams() + { + return m_additionalParams; + } +} diff --git a/dev/OAuth/TokenFailure.h b/dev/OAuth/TokenFailure.h new file mode 100644 index 0000000000..0c26fb873e --- /dev/null +++ b/dev/OAuth/TokenFailure.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct TokenFailure : TokenFailureT + { + TokenFailure(TokenFailureKind kind, winrt::hresult code); + TokenFailure(const json::JsonObject& jsonObject); + + TokenFailureKind Kind(); + winrt::hresult ErrorCode(); + winrt::hstring Error(); + winrt::hstring ErrorDescription(); + foundation::Uri ErrorUri(); + collections::IMapView AdditionalParams(); + + private: + TokenFailureKind m_kind; + winrt::hresult m_errorCode; + winrt::hstring m_error; + winrt::hstring m_errorDescription; + foundation::Uri m_errorUri{ nullptr }; + collections::IMapView m_additionalParams; + }; +} diff --git a/dev/OAuth/TokenRequestParams.cpp b/dev/OAuth/TokenRequestParams.cpp new file mode 100644 index 0000000000..bfb7682f94 --- /dev/null +++ b/dev/OAuth/TokenRequestParams.cpp @@ -0,0 +1,240 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "AuthResponse.h" +#include "TokenRequestParams.h" + +#include + +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Foundation; +using namespace Collections; + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + TokenRequestParams::TokenRequestParams(const winrt::hstring& grantType) : m_grantType(grantType) + { + THROW_HR_IF(E_NOTIMPL, !::Microsoft::Security::Authentication::OAuth::Feature_OAuth::IsEnabled()); + } + + oauth::TokenRequestParams TokenRequestParams::CreateForAuthorizationCodeRequest( + const oauth::AuthResponse& authResponse) + { + auto result = winrt::make_self(L"authorization_code"); + result->m_code = authResponse.Code(); + + auto implResponse = winrt::get_self(authResponse); + auto requestParams = implResponse->request_params(); + if (auto redirectUri = requestParams->RedirectUri()) + { + result->m_redirectUri = std::move(redirectUri); + } + + if (auto clientId = requestParams->ClientId(); !clientId.empty()) + { + result->m_clientId = std::move(clientId); + } + + if (auto codeVerifier = requestParams->CodeChallenge(); !codeVerifier.empty()) + { + result->m_codeVerifier = std::move(codeVerifier); + } + + return *result; + } + + oauth::TokenRequestParams TokenRequestParams::CreateForResourceOwnerPasswordCredentials( + const winrt::hstring& username, const winrt::hstring& password) + { + auto result = winrt::make_self(L"password"); + result->m_username = username; + result->m_password = password; + + return *result; + } + + oauth::TokenRequestParams TokenRequestParams::CreateForClientCredentials() + { + return winrt::make(L"client_credentials"); + } + + oauth::TokenRequestParams TokenRequestParams::CreateForExtension(const Uri& extensionUri) + { + return winrt::make(extensionUri.RawUri()); + } + + oauth::TokenRequestParams TokenRequestParams::CreateForRefreshToken(const winrt::hstring& refreshToken) + { + auto result = winrt::make_self(L"refresh_token"); + result->m_refreshToken = refreshToken; + + return *result; + } + + winrt::hstring TokenRequestParams::GrantType() + { + std::shared_lock guard{ m_mutex }; + return m_grantType; + } + + void TokenRequestParams::GrantType(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_grantType = value; + } + + winrt::hstring TokenRequestParams::Code() + { + std::shared_lock guard{ m_mutex }; + return m_code; + } + + void TokenRequestParams::Code(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_code = value; + } + + Uri TokenRequestParams::RedirectUri() + { + std::shared_lock guard{ m_mutex }; + return m_redirectUri; + } + + void TokenRequestParams::RedirectUri(const Uri& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_redirectUri = value; + } + + winrt::hstring TokenRequestParams::CodeVerifier() + { + std::shared_lock guard{ m_mutex }; + return m_codeVerifier; + } + + void TokenRequestParams::CodeVerifier(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_codeVerifier = value; + } + + winrt::hstring TokenRequestParams::ClientId() + { + std::shared_lock guard{ m_mutex }; + return m_clientId; + } + + void TokenRequestParams::ClientId(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_clientId = value; + } + + winrt::hstring TokenRequestParams::Username() + { + std::shared_lock guard{ m_mutex }; + return m_username; + } + + void TokenRequestParams::Username(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_username = value; + } + + winrt::hstring TokenRequestParams::Password() + { + std::shared_lock guard{ m_mutex }; + return m_password; + } + + void TokenRequestParams::Password(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_password = value; + } + + winrt::hstring TokenRequestParams::Scope() + { + std::shared_lock guard{ m_mutex }; + return m_scope; + } + + void TokenRequestParams::Scope(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_scope = value; + } + + winrt::hstring TokenRequestParams::RefreshToken() + { + std::shared_lock guard{ m_mutex }; + return m_refreshToken; + } + + void TokenRequestParams::RefreshToken(const winrt::hstring& value) + { + std::lock_guard guard{ m_mutex }; + check_not_finalized(); + m_refreshToken = value; + } + + IMap TokenRequestParams::AdditionalParams() + { + std::shared_lock guard{ m_mutex }; + return *m_additionalParams; + } + + void TokenRequestParams::finalize() + { + std::lock_guard guard{ m_mutex }; + if (m_finalized) + { + throw winrt::hresult_illegal_method_call(L"TokenRequestParams can only be used for a single request call"); + } + + m_finalized = true; + m_additionalParams->lock(); + } + + std::map TokenRequestParams::params() + { + // HttpFormUrlEncodedContent requires an IIterable> as input. In theory we can + // make the TokenRequestParams implement this type to save on some work, however this may be a little tricky + std::map result; + auto addIfSet = [&](std::wstring_view key, const winrt::hstring& value) { + if (!value.empty()) + { + result.emplace(key, value); + } + }; + + std::shared_lock guard{ m_mutex }; + addIfSet(L"grant_type", m_grantType); + addIfSet(L"code", m_code); + if (m_redirectUri) result.emplace(L"redirect_uri", m_redirectUri.RawUri()); + addIfSet(L"code_verifier", m_codeVerifier); + addIfSet(L"client_id", m_clientId); + addIfSet(L"username", m_username); + addIfSet(L"password", m_password); + addIfSet(L"scope", m_scope); + addIfSet(L"refresh_token", m_refreshToken); + for (auto&& pair : IMap{ *m_additionalParams }) + { + result.emplace(pair.Key(), pair.Value()); + } + + return result; + } +} diff --git a/dev/OAuth/TokenRequestParams.h b/dev/OAuth/TokenRequestParams.h new file mode 100644 index 0000000000..a8b434b9c0 --- /dev/null +++ b/dev/OAuth/TokenRequestParams.h @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +#include + +#include "LockableMap.h" + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct TokenRequestParams : TokenRequestParamsT + { + TokenRequestParams() = default; + TokenRequestParams(const winrt::hstring& grantType); + + static oauth::TokenRequestParams CreateForAuthorizationCodeRequest(const oauth::AuthResponse& authResponse); + static oauth::TokenRequestParams CreateForResourceOwnerPasswordCredentials(const winrt::hstring& username, + const winrt::hstring& password); + static oauth::TokenRequestParams CreateForClientCredentials(); + static oauth::TokenRequestParams CreateForExtension(const foundation::Uri& extensionUri); + static oauth::TokenRequestParams CreateForRefreshToken(const winrt::hstring& refreshToken); + + winrt::hstring GrantType(); + void GrantType(const winrt::hstring& value); + winrt::hstring Code(); + void Code(const winrt::hstring& value); + foundation::Uri RedirectUri(); + void RedirectUri(const foundation::Uri& value); + winrt::hstring CodeVerifier(); + void CodeVerifier(const winrt::hstring& value); + winrt::hstring ClientId(); + void ClientId(const winrt::hstring& value); + winrt::hstring Username(); + void Username(const winrt::hstring& value); + winrt::hstring Password(); + void Password(const winrt::hstring& value); + winrt::hstring Scope(); + void Scope(const winrt::hstring& value); + winrt::hstring RefreshToken(); + void RefreshToken(const winrt::hstring& value); + collections::IMap AdditionalParams(); + + // Implementation functions + void finalize(); + std::map params(); + + private: + void check_not_finalized() + { + // NOTE: Lock should be held when calling + if (m_finalized) + { + throw winrt::hresult_illegal_method_call( + L"TokenRequestParams object cannot be modified after being used to initiate a request"); + } + } + + std::shared_mutex m_mutex; + bool m_finalized = false; + winrt::hstring m_grantType; + winrt::hstring m_code; + foundation::Uri m_redirectUri{ nullptr }; + winrt::hstring m_codeVerifier; + winrt::hstring m_clientId; + winrt::hstring m_username; + winrt::hstring m_password; + winrt::hstring m_scope; + winrt::hstring m_refreshToken; + winrt::com_ptr> m_additionalParams = + winrt::make_self>(); + }; +} + +namespace winrt::Microsoft::Security::Authentication::OAuth::factory_implementation +{ + struct TokenRequestParams : TokenRequestParamsT + { + }; +} diff --git a/dev/OAuth/TokenRequestResult.cpp b/dev/OAuth/TokenRequestResult.cpp new file mode 100644 index 0000000000..8fc61d1ed8 --- /dev/null +++ b/dev/OAuth/TokenRequestResult.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "TokenFailure.h" +#include "TokenRequestResult.h" +#include "TokenResponse.h" + +#include + +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Data::Json; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; +using namespace winrt::Windows::Web::Http; + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + TokenRequestResult::TokenRequestResult(HttpResponseMessage responseMessage, oauth::TokenResponse response, + oauth::TokenFailure failure) : + m_responseMessage(std::move(responseMessage)), + m_response(std::move(response)), + m_failure(std::move(failure)) + { + } + + oauth::TokenRequestResult TokenRequestResult::MakeFailure(HttpResponseMessage response, + TokenFailureKind failureKind, winrt::hresult failureCode) + { + return winrt::make(std::move(response), nullptr, + winrt::make(failureKind, failureCode)); + } + + HttpResponseMessage TokenRequestResult::ResponseMessage() + { + return m_responseMessage; + } + + oauth::TokenResponse TokenRequestResult::Response() + { + return m_response; + } + + oauth::TokenFailure TokenRequestResult::Failure() + { + return m_failure; + } +} diff --git a/dev/OAuth/TokenRequestResult.h b/dev/OAuth/TokenRequestResult.h new file mode 100644 index 0000000000..9edb8fcd15 --- /dev/null +++ b/dev/OAuth/TokenRequestResult.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct TokenRequestResult : TokenRequestResultT + { + TokenRequestResult(http::HttpResponseMessage responseMessage, oauth::TokenResponse resposne, + oauth::TokenFailure failure); + + static oauth::TokenRequestResult MakeFailure(http::HttpResponseMessage response, TokenFailureKind failureKind, + winrt::hresult failureCode); + + http::HttpResponseMessage ResponseMessage(); + oauth::TokenResponse Response(); + oauth::TokenFailure Failure(); + + private: + http::HttpResponseMessage m_responseMessage; + oauth::TokenResponse m_response{ nullptr }; + oauth::TokenFailure m_failure{ nullptr }; + }; +} diff --git a/dev/OAuth/TokenResponse.cpp b/dev/OAuth/TokenResponse.cpp new file mode 100644 index 0000000000..471a6d03c5 --- /dev/null +++ b/dev/OAuth/TokenResponse.cpp @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#include +#include "common.h" + +#include "TokenResponse.h" + +#include + +using namespace winrt::Microsoft::Security::Authentication::OAuth; +using namespace winrt::Windows::Data::Json; +using namespace winrt::Windows::Foundation; +using namespace winrt::Windows::Foundation::Collections; + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + TokenResponse::TokenResponse(const json::JsonObject& jsonObject) + { + std::map additionalParams; + + // NOTE: Functions like 'GetString' will throw if the value is not the requested type. It might be worth + // revisiting this in the future + for (auto&& pair : jsonObject) + { + auto name = pair.Key(); + if (name == L"access_token") + { + m_accessToken = pair.Value().GetString(); + } + else if (name == L"token_type") + { + m_tokenType = pair.Value().GetString(); + } + else if (name == L"expires_in") + { + m_expiresIn = pair.Value().GetNumber(); + } + else if (name == L"refresh_token") + { + m_refreshToken = pair.Value().GetString(); + } + else if (name == L"scope") + { + m_scope = pair.Value().GetString(); + } + else + { + additionalParams.emplace(std::move(name), pair.Value()); + } + } + + m_additionalParams = winrt::single_threaded_map(std::move(additionalParams)).GetView(); + } + + winrt::hstring TokenResponse::AccessToken() + { + return m_accessToken; + } + + winrt::hstring TokenResponse::TokenType() + { + return m_tokenType; + } + + double TokenResponse::ExpiresIn() + { + return m_expiresIn; + } + + winrt::hstring TokenResponse::RefreshToken() + { + return m_refreshToken; + } + + winrt::hstring TokenResponse::Scope() + { + return m_scope; + } + + IMapView TokenResponse::AdditionalParams() + { + return m_additionalParams; + } +} diff --git a/dev/OAuth/TokenResponse.h b/dev/OAuth/TokenResponse.h new file mode 100644 index 0000000000..f5cff92ae8 --- /dev/null +++ b/dev/OAuth/TokenResponse.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once +#include + +namespace winrt::Microsoft::Security::Authentication::OAuth::implementation +{ + struct TokenResponse : TokenResponseT + { + TokenResponse(const json::JsonObject& jsonObject); + + winrt::hstring AccessToken(); + winrt::hstring TokenType(); + double ExpiresIn(); + winrt::hstring RefreshToken(); + winrt::hstring Scope(); + collections::IMapView AdditionalParams(); + + private: + winrt::hstring m_accessToken; + winrt::hstring m_tokenType; + double m_expiresIn; + winrt::hstring m_refreshToken; + winrt::hstring m_scope; + collections::IMapView m_additionalParams; + }; +} diff --git a/dev/OAuth/common.h b/dev/OAuth/common.h new file mode 100644 index 0000000000..882516f4e0 --- /dev/null +++ b/dev/OAuth/common.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See LICENSE in the project root for license information. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace collections = winrt::Windows::Foundation::Collections; +namespace crypto = winrt::Windows::Security::Cryptography; +namespace foundation = winrt::Windows::Foundation; +namespace http = winrt::Windows::Web::Http; +namespace json = winrt::Windows::Data::Json; +namespace oauth = winrt::Microsoft::Security::Authentication::OAuth; +namespace streams = winrt::Windows::Storage::Streams; + +#include "Crypto.h" + +inline winrt::hstring fragment_component(const foundation::Uri& uri) +{ + auto fragment = uri.Fragment(); + std::wstring_view fragmentStr = fragment; + if (!fragmentStr.empty()) + { + WINRT_ASSERT(fragmentStr.front() == '#'); + fragmentStr = fragmentStr.substr(1); + } + + return winrt::hstring(fragmentStr); +} diff --git a/dev/Projections/CS/Microsoft.Security.Authentication.OAuth/Microsoft.Security.Authentication.OAuth.Projection.csproj b/dev/Projections/CS/Microsoft.Security.Authentication.OAuth/Microsoft.Security.Authentication.OAuth.Projection.csproj new file mode 100644 index 0000000000..3d96ec2b93 --- /dev/null +++ b/dev/Projections/CS/Microsoft.Security.Authentication.OAuth/Microsoft.Security.Authentication.OAuth.Projection.csproj @@ -0,0 +1,55 @@ + + + net6.0-windows10.0.17763.0 + 10.0.17763.0 + x64;x86;arm64 + AnyCPU + false + + + + true + true + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + Microsoft.Security.Authentication.OAuth + 10.0.17763.0 + false + + + + + pdbonly + true + + + + + + + + + + + + + $(OutDir)..\WindowsAppRuntime_DLL\StrippedWinMD\Microsoft.Security.Authentication.OAuth.winmd + true + + + + \ No newline at end of file diff --git a/dev/WindowsAppRuntime_DLL/WindowsAppRuntime_DLL.vcxproj b/dev/WindowsAppRuntime_DLL/WindowsAppRuntime_DLL.vcxproj index d9c2fbb6a0..7d852a3d6c 100644 --- a/dev/WindowsAppRuntime_DLL/WindowsAppRuntime_DLL.vcxproj +++ b/dev/WindowsAppRuntime_DLL/WindowsAppRuntime_DLL.vcxproj @@ -97,6 +97,7 @@ + @@ -324,4 +325,4 @@ - \ No newline at end of file +