diff --git a/src/VisualStudio.Tests/VisualStudioInstanceExtensions.cs b/src/VisualStudio.Tests/VisualStudioInstanceExtensions.cs index fa930c6..57be6c1 100644 --- a/src/VisualStudio.Tests/VisualStudioInstanceExtensions.cs +++ b/src/VisualStudio.Tests/VisualStudioInstanceExtensions.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using VisualStudio; namespace vswhere @@ -30,7 +31,10 @@ public static VisualStudioInstance WithSku(this VisualStudioInstance vsInstance, public static VisualStudioInstance WithChannel(this VisualStudioInstance vsInstance, Channel channel) { - vsInstance.ChannelId = productIdByChannel[channel]; + vsInstance.ChannelId = productIdByChannel.TryGetValue(channel, out var channelId) ? + vsInstance.ChannelId = channelId : + throw new NotSupportedException("Cannot filter instances by the given channel."); + return vsInstance; } } diff --git a/src/VisualStudio.Tests/VisualStudioPredicateBuilderTests.cs b/src/VisualStudio.Tests/VisualStudioPredicateBuilderTests.cs index 1eb562a..860c1b8 100644 --- a/src/VisualStudio.Tests/VisualStudioPredicateBuilderTests.cs +++ b/src/VisualStudio.Tests/VisualStudioPredicateBuilderTests.cs @@ -1,4 +1,5 @@ -using System.Threading.Tasks; +using System; +using System.Threading.Tasks; using vswhere; using Xunit; diff --git a/src/VisualStudio/Channel.cs b/src/VisualStudio/Channel.cs index cb5ef51..60b74f7 100644 --- a/src/VisualStudio/Channel.cs +++ b/src/VisualStudio/Channel.cs @@ -5,6 +5,6 @@ public enum Channel Release, Preview, IntPreview, - Main + Main, } } diff --git a/src/VisualStudio/Commands/ModifyCommand.cs b/src/VisualStudio/Commands/ModifyCommand.cs index 5e640d7..911987f 100644 --- a/src/VisualStudio/Commands/ModifyCommand.cs +++ b/src/VisualStudio/Commands/ModifyCommand.cs @@ -39,7 +39,12 @@ public override async Task ExecuteAsync(TextWriter output) args.AddRange(Descriptor.ExtraArguments); - await installerService.ModifyAsync(instance.GetChannel(), instance.GetSku(), args, output); + // If the channel is not a built-in one, use the existing Uri for updates. + var channel = instance.GetChannel(); + if (channel != null) + await installerService.ModifyAsync(instance.GetChannel(), instance.GetSku(), args, output); + else + await installerService.ModifyAsync(instance.ChannelUri.Replace("/channel", ""), instance.GetSku(), args, output); } } } diff --git a/src/VisualStudio/Commands/UpdateCommand.cs b/src/VisualStudio/Commands/UpdateCommand.cs index efbabcd..a347127 100644 --- a/src/VisualStudio/Commands/UpdateCommand.cs +++ b/src/VisualStudio/Commands/UpdateCommand.cs @@ -39,7 +39,12 @@ public override async Task ExecuteAsync(TextWriter output) instance.InstallationPath }; - await installerService.UpdateAsync(instance.GetChannel(), instance.GetSku(), args, output); + // If the channel is not a built-in one, use the existing Uri for updates. + var channel = instance.GetChannel(); + if (channel != null) + await installerService.UpdateAsync(instance.GetChannel(), instance.GetSku(), args, output); + else + await installerService.UpdateAsync(instance.ChannelUri.Replace("/channel", ""), instance.GetSku(), args, output); } } } diff --git a/src/VisualStudio/InstallerService.cs b/src/VisualStudio/InstallerService.cs index fece206..1808e5e 100644 --- a/src/VisualStudio/InstallerService.cs +++ b/src/VisualStudio/InstallerService.cs @@ -4,31 +4,33 @@ using System.IO; using System.Linq; using System.Net.Http; -using System.Text; using System.Threading.Tasks; namespace VisualStudio { class InstallerService { - public Task InstallAsync(Channel? channel, Sku? sku, IEnumerable args, TextWriter output) => - RunAsync(string.Empty, channel, sku, args, output); + public Task InstallAsync(Channel? channel, Sku? sku, IEnumerable args, TextWriter output) + => RunAsync(string.Empty, channel, sku, args, output); - public Task UpdateAsync(Channel? channel, Sku? sku, IEnumerable args, TextWriter output) => - RunAsync("update", channel, sku, args, output); + public Task UpdateAsync(Channel? channel, Sku? sku, IEnumerable args, TextWriter output) + => RunAsync("update", channel, sku, args, output); - public Task ModifyAsync(Channel? channel, Sku? sku, IEnumerable args, TextWriter output) => - RunAsync("modify", channel, sku, args, output); + public Task ModifyAsync(Channel? channel, Sku? sku, IEnumerable args, TextWriter output) + => RunAsync("modify", channel, sku, args, output); - async Task RunAsync(string command, Channel? channel, Sku? sku, IEnumerable args, TextWriter output) - { - var uri = new StringBuilder("https://aka.ms/vs/16/"); - uri = uri.Append(MapChannel(channel)); - uri = uri.Append("/vs_"); - uri = uri.Append(MapSku(sku)); - uri = uri.Append(".exe"); + public Task UpdateAsync(string channelUri, Sku? sku, IEnumerable args, TextWriter output) + => RunAsync("update", channelUri, sku, args, output); + + public Task ModifyAsync(string channelUri, Sku? sku, IEnumerable args, TextWriter output) + => RunAsync("modify", channelUri, sku, args, output); - var bootstrapper = await DownloadAsync(uri.ToString(), output); + Task RunAsync(string command, Channel? channel, Sku? sku, IEnumerable args, TextWriter output) + => RunAsync(command, "https://aka.ms/vs/16/" + MapChannel(channel), sku, args, output); + + async Task RunAsync(string command, string channelUri, Sku? sku, IEnumerable args, TextWriter output) + { + var bootstrapper = await DownloadAsync($"{channelUri}/vs_{MapSku(sku)}.exe", output); var psi = new ProcessStartInfo(bootstrapper) { diff --git a/src/VisualStudio/VisualStudioInstanceExtensions.cs b/src/VisualStudio/VisualStudioInstanceExtensions.cs index dc2977c..85406cb 100644 --- a/src/VisualStudio/VisualStudioInstanceExtensions.cs +++ b/src/VisualStudio/VisualStudioInstanceExtensions.cs @@ -17,14 +17,14 @@ public static Sku GetSku(this VisualStudioInstance vsInstance) _ => throw new ArgumentException($"Invalid SKU {vsInstance.ProductId}. Must be one of {string.Join(", ", Enum.GetNames(typeof(Sku)).Select(x => x.ToLowerInvariant()))}.", "sku"), }; - public static Channel GetChannel(this VisualStudioInstance vsInstance) + public static Channel? GetChannel(this VisualStudioInstance vsInstance) => vsInstance.ChannelId switch { "VisualStudio.16.Release" => Channel.Release, "VisualStudio.16.Preview" => Channel.Preview, "VisualStudio.16.IntPreview" => Channel.IntPreview, "VisualStudio.16.int.main" => Channel.Main, - _ => throw new ArgumentException($"Invalid ChannelId {vsInstance.ChannelId}. Must be one of {string.Join(", ", Enum.GetNames(typeof(Channel)).Select(x => x.ToLowerInvariant()))}.", "sku"), + _ => null, }; } }