diff --git a/src/Common/Commands.Common.Authentication.Test/AzureSessionTests.cs b/src/Common/Commands.Common.Authentication.Test/AzureSessionTests.cs new file mode 100644 index 000000000000..4b819456204b --- /dev/null +++ b/src/Common/Commands.Common.Authentication.Test/AzureSessionTests.cs @@ -0,0 +1,74 @@ +// ---------------------------------------------------------------------------------- +// +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +using Microsoft.Azure.Commands.Common.Authentication; +using Microsoft.Azure.Commands.Common.Authentication.Factories; +using Microsoft.Azure.Commands.Common.Authentication.Models; +using Microsoft.WindowsAzure.Commands.Test.Utilities.Common; +using System; +using System.Collections.Generic; +using Microsoft.WindowsAzure.Commands.ScenarioTest; +using Xunit; +using Microsoft.Azure.Commands.Common.Authentication.Abstractions; +using System.IO; + +namespace Common.Authentication.Test +{ + public class AzureSessionTests + { + [Fact] + [Trait(Category.AcceptanceType, Category.CheckIn)] + public void InitializerCreatesTokenCacheFile() + { + IAzureSession oldSession = null; + try + { + oldSession = AzureSession.Instance; + } + catch { } + try + { + var store = new MemoryDataStore(); + AzureSessionInitializer.CreateOrReplaceSession(store); + var session = AzureSession.Instance; + var tokenCacheFile = Path.Combine(session.ProfileDirectory, session.TokenCacheFile); + Assert.True(store.FileExists(tokenCacheFile)); + + } + finally + { + AzureSession.Initialize(() => oldSession, true); + } + } + + [Fact] + [Trait(Category.AcceptanceType, Category.CheckIn)] + public void TokenCacheIgnoresInvalidData() + { + var store = new AzureTokenCache { CacheData = new byte[] { 3, 0, 0, 0, 0, 0, 0, 0 } }; + var cache = new AuthenticationStoreTokenCache(store); + Assert.NotEqual(cache.CacheData, store.CacheData); + } + + [Fact] + [Trait(Category.AcceptanceType, Category.CheckIn)] + public void TokenCacheUsesValidData() + { + var store = new AzureTokenCache { CacheData = new byte[] { 2, 0, 0, 0, 0, 0, 0, 0 } }; + var cache = new AuthenticationStoreTokenCache(store); + Assert.Equal(cache.CacheData, store.CacheData); + } + + } +} diff --git a/src/Common/Commands.Common.Authentication.Test/Commands.Common.Authentication.Test.csproj b/src/Common/Commands.Common.Authentication.Test/Commands.Common.Authentication.Test.csproj index 869a9e1e1668..52e15dee4217 100644 --- a/src/Common/Commands.Common.Authentication.Test/Commands.Common.Authentication.Test.csproj +++ b/src/Common/Commands.Common.Authentication.Test/Commands.Common.Authentication.Test.csproj @@ -120,6 +120,7 @@ + diff --git a/src/Common/Commands.Common.Authentication/Authentication/AuthenticationStoreTokenCache.cs b/src/Common/Commands.Common.Authentication/Authentication/AuthenticationStoreTokenCache.cs index e8eff6c684c8..0281ca631f89 100644 --- a/src/Common/Commands.Common.Authentication/Authentication/AuthenticationStoreTokenCache.cs +++ b/src/Common/Commands.Common.Authentication/Authentication/AuthenticationStoreTokenCache.cs @@ -22,19 +22,17 @@ namespace Microsoft.Azure.Commands.Common.Authentication [Serializable] public class AuthenticationStoreTokenCache : TokenCache, IAzureTokenCache, IDisposable { - AzureTokenCache _tokenStore; - + IAzureTokenCache _store = new AzureTokenCache(); public byte[] CacheData { get { - return _tokenStore.CacheData; + return Serialize(); } set { - this.Clear(); - _tokenStore.CacheData = value; + this.Deserialize(value); } } @@ -45,10 +43,9 @@ public AuthenticationStoreTokenCache(AzureTokenCache store) : base() throw new ArgumentNullException("store"); } - _tokenStore = store; - if (_tokenStore != null && _tokenStore.CacheData != null && _tokenStore.CacheData.Length > 0) + if (store.CacheData != null && store.CacheData.Length > 0) { - base.Deserialize(_tokenStore.CacheData); + CacheData = store.CacheData; } AfterAccess += HandleAfterAccess; @@ -59,30 +56,22 @@ public AuthenticationStoreTokenCache(AzureTokenCache store) : base() /// /// The cache to copy /// The store to use for persisting state - public AuthenticationStoreTokenCache(TokenCache cache, AzureTokenCache store) : this(store) + public AuthenticationStoreTokenCache(TokenCache cache) : base() { if (null == cache) { throw new ArgumentNullException("Cache"); } - Deserialize(cache.Serialize()); - } - - /// - /// Create a token cache, copying any data from the given token cache - /// - /// The cache to copy - public AuthenticationStoreTokenCache(TokenCache cache) : this(cache, new AzureTokenCache()) - { + CacheData = cache.Serialize(); + AfterAccess += HandleAfterAccess; } - public void HandleAfterAccess(TokenCacheNotificationArgs args) { if (HasStateChanged) { - _tokenStore.CacheData = Serialize(); + _store.CacheData = Serialize(); } } @@ -90,10 +79,10 @@ protected virtual void Dispose(bool disposing) { if (disposing) { - var cache = Interlocked.Exchange(ref _tokenStore, null); + var cache = Interlocked.Exchange(ref _store, null); if (cache != null) { - cache.CacheData = base.Serialize(); + cache.CacheData = Serialize(); } } } diff --git a/src/Common/Commands.Common.Authentication/Authentication/ProtectedFileTokenCache.cs b/src/Common/Commands.Common.Authentication/Authentication/ProtectedFileTokenCache.cs index a8a3103764d5..c190bcf49e2f 100644 --- a/src/Common/Commands.Common.Authentication/Authentication/ProtectedFileTokenCache.cs +++ b/src/Common/Commands.Common.Authentication/Authentication/ProtectedFileTokenCache.cs @@ -37,6 +37,8 @@ public class ProtectedFileTokenCache : TokenCache, IAzureTokenCache private static readonly Lazy instance = new Lazy(() => new ProtectedFileTokenCache()); + IDataStore _store; + public byte[] CacheData { get @@ -54,25 +56,28 @@ public byte[] CacheData // If the file is already present, it loads its content in the ADAL cache private ProtectedFileTokenCache() { + _store = AzureSession.Instance.DataStore; Initialize(CacheFileName); } - public ProtectedFileTokenCache(byte[] inputData) + public ProtectedFileTokenCache(byte[] inputData, IDataStore store = null) : this(CacheFileName, store) { - AfterAccess = AfterAccessNotification; - BeforeAccess = BeforeAccessNotification; CacheData = inputData; } + public ProtectedFileTokenCache(string cacheFile, IDataStore store = null) + { + _store = store ?? AzureSession.Instance.DataStore; + Initialize(cacheFile); + } + private void Initialize(string fileName) { - AfterAccess = AfterAccessNotification; - BeforeAccess = BeforeAccessNotification; lock (fileLock) { - if (AzureSession.Instance.DataStore.FileExists(fileName)) + if (_store.FileExists(fileName)) { - var existingData = AzureSession.Instance.DataStore.ReadFileAsBytes(fileName); + var existingData = _store.ReadFileAsBytes(fileName); if (existingData != null) { try @@ -81,25 +86,26 @@ private void Initialize(string fileName) } catch (CryptographicException) { - AzureSession.Instance.DataStore.DeleteFile(fileName); + _store.DeleteFile(fileName); } } } + + // Create the file to start with + _store.WriteFile(CacheFileName, ProtectedData.Protect(Serialize(), null, DataProtectionScope.CurrentUser)); } - } - public ProtectedFileTokenCache(string cacheFile) - { - Initialize(cacheFile); + AfterAccess = AfterAccessNotification; + BeforeAccess = BeforeAccessNotification; } // Empties the persistent store. public override void Clear() { base.Clear(); - if (AzureSession.Instance.DataStore.FileExists(CacheFileName)) + if (_store.FileExists(CacheFileName)) { - AzureSession.Instance.DataStore.DeleteFile(CacheFileName); + _store.DeleteFile(CacheFileName); } } @@ -109,9 +115,9 @@ void BeforeAccessNotification(TokenCacheNotificationArgs args) { lock (fileLock) { - if (AzureSession.Instance.DataStore.FileExists(CacheFileName)) + if (_store.FileExists(CacheFileName)) { - var existingData = AzureSession.Instance.DataStore.ReadFileAsBytes(CacheFileName); + var existingData = _store.ReadFileAsBytes(CacheFileName); if (existingData != null) { try @@ -120,7 +126,7 @@ void BeforeAccessNotification(TokenCacheNotificationArgs args) } catch (CryptographicException) { - AzureSession.Instance.DataStore.DeleteFile(CacheFileName); + _store.DeleteFile(CacheFileName); } } } @@ -131,13 +137,18 @@ void BeforeAccessNotification(TokenCacheNotificationArgs args) void AfterAccessNotification(TokenCacheNotificationArgs args) { // if the access operation resulted in a cache update - if (HasStateChanged) + EnsureStateSaved(); + } + + void EnsureStateSaved() + { + lock (fileLock) { - lock (fileLock) + if (HasStateChanged) { // reflect changes in the persistent store - AzureSession.Instance.DataStore.WriteFile(CacheFileName, - ProtectedData.Protect(Serialize(), null, DataProtectionScope.CurrentUser)); + _store.WriteFile(CacheFileName, + ProtectedData.Protect(Serialize(), null, DataProtectionScope.CurrentUser)); // once the write operation took place, restore the HasStateChanged bit to false HasStateChanged = false; } diff --git a/src/Common/Commands.Common.Authentication/AzureSessionInitializer.cs b/src/Common/Commands.Common.Authentication/AzureSessionInitializer.cs index f8a0de638f53..90b64bceb854 100644 --- a/src/Common/Commands.Common.Authentication/AzureSessionInitializer.cs +++ b/src/Common/Commands.Common.Authentication/AzureSessionInitializer.cs @@ -35,16 +35,32 @@ public static class AzureSessionInitializer /// public static void InitializeAzureSession() { - AzureSession.Initialize(CreateInstance); + AzureSession.Initialize(() => CreateInstance()); } - static IAzureSession CreateInstance() + /// + /// Create a new session and replace any existing session + /// + public static void CreateOrReplaceSession() + { + CreateOrReplaceSession(new DiskDataStore()); + } + + /// + /// Create a new session and replace any existing session + /// + public static void CreateOrReplaceSession(IDataStore dataStore) + { + AzureSession.Initialize(() => CreateInstance(dataStore), true); + } + + static IAzureSession CreateInstance(IDataStore dataStore = null) { var session = new AdalSession { ClientFactory = new ClientFactory(), AuthenticationFactory = new AuthenticationFactory(), - DataStore = new DiskDataStore(), + DataStore = dataStore?? new DiskDataStore(), OldProfileFile = "WindowsAzureProfile.xml", OldProfileFileBackup = "WindowsAzureProfile.xml.bak", ProfileDirectory = Path.Combine( @@ -59,17 +75,7 @@ static IAzureSession CreateInstance() FileUtilities.EnsureDirectoryExists(session.ProfileDirectory); var cacheFile = Path.Combine(session.ProfileDirectory, session.TokenCacheFile); var contents = new byte[0]; - if (session.DataStore.FileExists(cacheFile)) - { - contents = session.DataStore.ReadFileAsBytes(cacheFile); - } - - if (contents != null && contents.Length > 0) - { - contents = ProtectedData.Unprotect(contents, null, DataProtectionScope.CurrentUser); - } - - session.TokenCache = new ProtectedFileTokenCache(contents); + session.TokenCache = new ProtectedFileTokenCache(cacheFile, dataStore); } catch { diff --git a/src/ResourceManager/Profile/Commands.Profile.Test/AzureRMProfileTests.cs b/src/ResourceManager/Profile/Commands.Profile.Test/AzureRMProfileTests.cs index 533fff135bba..7369956095f9 100644 --- a/src/ResourceManager/Profile/Commands.Profile.Test/AzureRMProfileTests.cs +++ b/src/ResourceManager/Profile/Commands.Profile.Test/AzureRMProfileTests.cs @@ -751,7 +751,7 @@ public void SavingProfileWorks() }, ""VersionProfile"": null, ""TokenCache"": { - ""CacheData"": ""AQIDBAUGCAkA"" + ""CacheData"": ""AgAAAAAAAAA="" }, ""ExtendedProperties"": {} }