From 7c3ef3c823a6ef541af78c3ffd05ccfbeaa486ee Mon Sep 17 00:00:00 2001 From: Frans van Dorsselaer <17404029+dorssel@users.noreply.github.com> Date: Mon, 1 Apr 2024 03:23:53 +0200 Subject: [PATCH] Add rules --- Installer/Server.wxs | 4 + UnitTests/BusId_Tests.cs | 28 +++--- UnitTests/Parse_usbipd_Tests.cs | 4 +- Usbipd.Automation/BusId.cs | 10 +- Usbipd/CommandHandlersCli.cs | 60 +++++++++++- Usbipd/ICommandHandlers.cs | 3 + Usbipd/ParseResultExtensions.cs | 14 +++ Usbipd/Program.cs | 161 ++++++++++++++++++++++++++++++-- Usbipd/RegistryUtils.cs | 103 ++++++++++++++++++++ Usbipd/Rule.cs | 16 ++++ Usbipd/RuleBind.cs | 63 +++++++++++++ Usbipd/RulePermission.cs | 11 +++ Usbipd/RuleTYpe.cs | 10 ++ 13 files changed, 458 insertions(+), 29 deletions(-) create mode 100644 Usbipd/ParseResultExtensions.cs create mode 100644 Usbipd/Rule.cs create mode 100644 Usbipd/RuleBind.cs create mode 100644 Usbipd/RulePermission.cs create mode 100644 Usbipd/RuleTYpe.cs diff --git a/Installer/Server.wxs b/Installer/Server.wxs index 548fa457..dd65d263 100644 --- a/Installer/Server.wxs +++ b/Installer/Server.wxs @@ -73,6 +73,10 @@ SPDX-License-Identifier: GPL-3.0-only Key="Devices" ForceCreateOnInstall="yes" /> + Invalid "IncompatibleHub", "1-1", "1-2", - "1-65534", - "1-65535", + "1-98", + "1-99", "2-1", "2-2", - "2-65534", - "2-65535", - "65534-1", - "65534-2", - "65534-65534", - "65534-65535", - "65535-1", - "65535-2", - "65535-65534", - "65535-65535", + "2-98", + "2-99", + "98-1", + "98-2", + "98-98", + "98-99", + "99-1", + "99-2", + "99-98", + "99-99", ]; public static IEnumerable Valid => from value in _Valid select new string[] { value }; diff --git a/UnitTests/Parse_usbipd_Tests.cs b/UnitTests/Parse_usbipd_Tests.cs index 14d4b085..b2585c8c 100644 --- a/UnitTests/Parse_usbipd_Tests.cs +++ b/UnitTests/Parse_usbipd_Tests.cs @@ -9,9 +9,9 @@ sealed class Parse_usbipd_Tests : ParseTestBase { [TestMethod] - public void Success() + public void NoCommand() { - Test(ExitCode.Success); + Test(ExitCode.ParseError); } [TestMethod] diff --git a/Usbipd.Automation/BusId.cs b/Usbipd.Automation/BusId.cs index 0dda6f76..c4134992 100644 --- a/Usbipd.Automation/BusId.cs +++ b/Usbipd.Automation/BusId.cs @@ -22,11 +22,13 @@ public BusId(ushort bus, ushort port) { // Do not allow the explicit creation of the special IncompatibleHub value. // Instead, use the static IncompatibleHub field (preferrable) or "default". - if (bus == 0) + // USB supports up to 127 devices, but that would require multiple hubs; the "per hub" port will never be >99. + // And if you have more than 99 hubs on one system, then you win a prize! (but we're not going to support it...) + if (bus == 0 || bus > 99) { throw new ArgumentOutOfRangeException(nameof(bus)); } - if (port == 0) + if (port == 0 || port > 99) { throw new ArgumentOutOfRangeException(nameof(port)); } @@ -63,8 +65,8 @@ public static bool TryParse(string input, out BusId busId) } var match = Regex.Match(input, "^([1-9][0-9]*)-([1-9][0-9]*)$"); if (match.Success - && ushort.TryParse(match.Groups[1].Value, out var bus) - && ushort.TryParse(match.Groups[2].Value, out var port)) + && ushort.TryParse(match.Groups[1].Value, out var bus) && bus <= 99 + && ushort.TryParse(match.Groups[2].Value, out var port) && port <= 99) { busId = new(bus, port); return true; diff --git a/Usbipd/CommandHandlersCli.cs b/Usbipd/CommandHandlersCli.cs index 81eee713..f5c1cd7d 100644 --- a/Usbipd/CommandHandlersCli.cs +++ b/Usbipd/CommandHandlersCli.cs @@ -142,7 +142,6 @@ Task ICommandHandlers.List(bool usbids, IConsole console, Cancellation { state = "Not shared"; } - // NOTE: Strictly speaking, both Bus and Port can be > 99. If you have one of those, you win a prize! console.Write($"{(device.BusId.Value.IsIncompatibleHub ? string.Empty : device.BusId.Value),-5} "); console.Write($"{device.HardwareId,-9} "); console.WriteTruncated(GetDescription(device, usbids), 60, true); @@ -474,4 +473,63 @@ Task ICommandHandlers.State(IConsole console, CancellationToken cancel Console.Write(json); return Task.FromResult(ExitCode.Success); } + + Task ICommandHandlers.RuleAdd(Rule rule, IConsole console, CancellationToken cancellationToken) + { + if (RegistryUtils.GetRules().FirstOrDefault(r => r.Value == rule) is var existingRule && existingRule.Key != default) + { + console.ReportError($"Rule already exists with guid '{existingRule.Key:D}'."); + return Task.FromResult(ExitCode.Failure); + } + + if (!CheckWriteAccess(console)) + { + return Task.FromResult(ExitCode.AccessDenied); + } + + var guid = RegistryUtils.AddRule(rule); + console.ReportInfo($"Rule created with guid '{guid:D}'."); + return Task.FromResult(ExitCode.Success); + } + + Task ICommandHandlers.RuleList(IConsole console, CancellationToken cancellationToken) + { + var allRules = RegistryUtils.GetRules(); + console.WriteLine("Rules:"); + console.WriteLine($"{"GUID",-36} {"TYPE",-4} {"ACCESS",-6} {"BUSID",-5} {"VID:PID",-9}"); + foreach (var rule in allRules) + { + console.Write($"{rule.Key,-36} "); + console.Write($"{rule.Value.Type,-4} "); + console.Write($"{(rule.Value.Allow ? "Allow" : "Deny"),-6} "); + switch (rule.Value.Type) + { + case RuleType.Bind: + var ruleBind = (RuleBind)rule.Value; + console.Write($"{(ruleBind.BusId.HasValue ? ruleBind.BusId.Value : string.Empty),-5} "); + console.Write($"{(ruleBind.HardwareId.HasValue ? ruleBind.HardwareId.Value : string.Empty),-9}"); + break; + } + console.WriteLine(string.Empty); + } + console.WriteLine(string.Empty); + return Task.FromResult(ExitCode.Success); + } + + Task ICommandHandlers.RuleRemove(Guid guid, IConsole console, CancellationToken cancellationToken) + { + if (!RegistryUtils.GetRules().ContainsKey(guid)) + { + console.ReportError($"There is no rule with guid '{guid:D}'."); + return Task.FromResult(ExitCode.Failure); + } + + if (!CheckWriteAccess(console)) + { + return Task.FromResult(ExitCode.AccessDenied); + } + + RegistryUtils.RemoveRule(guid); + return Task.FromResult(ExitCode.Success); + } } diff --git a/Usbipd/ICommandHandlers.cs b/Usbipd/ICommandHandlers.cs index 8fdc7287..0766a731 100644 --- a/Usbipd/ICommandHandlers.cs +++ b/Usbipd/ICommandHandlers.cs @@ -26,4 +26,7 @@ interface ICommandHandlers public Task State(IConsole console, CancellationToken cancellationToken); public Task Install(IConsole console, CancellationToken cancellationToken); public Task Uninstall(IConsole console, CancellationToken cancellationToken); + public Task RuleAdd(Rule rule, IConsole console, CancellationToken cancellationToken); + public Task RuleList(IConsole console, CancellationToken cancellationToken); + public Task RuleRemove(Guid guid, IConsole console, CancellationToken cancellationToken); } diff --git a/Usbipd/ParseResultExtensions.cs b/Usbipd/ParseResultExtensions.cs new file mode 100644 index 00000000..9b15b065 --- /dev/null +++ b/Usbipd/ParseResultExtensions.cs @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: 2024 Frans van Dorsselaer +// +// SPDX-License-Identifier: GPL-3.0-only + +using System.CommandLine; +using System.CommandLine.Parsing; + +namespace Usbipd; + +static class ParseResultExtensions +{ + public static T? GetValueForOptionOrNull(this ParseResult parseResult, Option option) where T : struct + => parseResult.HasOption(option) ? parseResult.GetValueForOption(option) : null; +} diff --git a/Usbipd/Program.cs b/Usbipd/Program.cs index e8c805f5..25f42a3e 100644 --- a/Usbipd/Program.cs +++ b/Usbipd/Program.cs @@ -7,7 +7,6 @@ using System.CommandLine.Builder; using System.CommandLine.Completions; using System.CommandLine.Help; -using System.CommandLine.IO; using System.CommandLine.Parsing; using System.Diagnostics; using System.Reflection; @@ -63,6 +62,17 @@ static string OneOfRequiredText(params Option[] options) return $"Exactly one of the options {list} is required."; } + static string AtLeastOneOfRequiredText(params Option[] options) + { + Debug.Assert(options.Length >= 2); + + var names = options.Select(o => $"'--{o.Name}'").ToArray(); + var list = names.Length == 2 + ? $"{names[0]} or {names[1]}" + : string.Join(", ", names[0..(names.Length - 1)]) + ", or " + names[^1]; + return $"At least one of the options {list} is required."; + } + static void ValidateOneOf(CommandResult commandResult, params Option[] options) { Debug.Assert(options.Length >= 2); @@ -73,6 +83,16 @@ static void ValidateOneOf(CommandResult commandResult, params Option[] options) } } + static void ValidateAtLeastOneOf(CommandResult commandResult, params Option[] options) + { + Debug.Assert(options.Length >= 2); + + if (!options.Any(option => commandResult.FindResultFor(option) is not null)) + { + commandResult.ErrorMessage = AtLeastOneOfRequiredText(options); + } + } + internal static IEnumerable CompletionGuard(CompletionContext completionContext, Func?> complete) { try @@ -115,11 +135,6 @@ internal static int Main(params string[] args) internal static ExitCode Run(IConsole? optionalTestConsole, ICommandHandlers commandHandlers, params string[] args) { var rootCommand = new RootCommand("Shares locally connected USB devices to other machines, including Hyper-V guests and WSL 2."); - rootCommand.SetHandler(invocationContext => - { - invocationContext.HelpBuilder.Write(rootCommand, invocationContext.Console.Out.CreateTextWriter()); - }); - { // // attach [--auto-attach] @@ -169,7 +184,7 @@ internal static ExitCode Run(IConsole? optionalTestConsole, ICommandHandlers com }.AddCompletions(completionContext => CompletionGuard(completionContext, () => UsbDevice.GetAll().Where(d => d.BusId.HasValue).GroupBy(d => d.HardwareId).Select(g => g.Key.ToString()))); // - // wsl attach + // attach // var attachCommand = new Command("attach", "Attach a USB device to a client\0" + "Attaches a USB device to a client.\n" @@ -317,7 +332,7 @@ await commandHandlers.Bind(invocationContext.ParseResult.GetValueForOption(hardw }.AddCompletions(completionContext => CompletionGuard(completionContext, () => UsbDevice.GetAll().Where(d => d.BusId.HasValue).GroupBy(d => d.HardwareId).Select(g => g.Key.ToString()))); // - // wsl detach + // detach // var detachCommand = new Command("detach", "Detach a USB device from a client\0" + "Detaches one or more USB devices. The client sees this as a surprise " @@ -396,6 +411,136 @@ await commandHandlers.List(invocationContext.ParseResult.HasOption(usbidsOption) }); rootCommand.AddCommand(listCommand); } + { + // + // rule + // + var ruleCommand = new Command("rule", "Manage rules\0" + + "Rules allow or deny specific functionality, such as bind and attach.\n"); + { + // + // rule add --access + // + var accessOption = new Option( + aliases: ["--access", "-a"] + ) + { + ArgumentHelpName = "ACCESS", + Description = "Allow or Deny", + IsRequired = true, + }; + // + // rule add --type + // + var typeOption = new Option( + aliases: ["--type", "-t"] + ) + { + ArgumentHelpName = "TYPE", + Description = "Add a rule of type ", + IsRequired = true, + }; + // + // rule add [--busid ] + // + var busIdOption = new Option( + aliases: ["--busid", "-b"], + parseArgument: ParseCompatibleBusId + ) + { + ArgumentHelpName = "BUSID", + Description = "Share device having ", + }.AddCompletions(CompatibleBusIdCompletions); + // + // rule add [--hardware-id :] + // + var hardwareIdOption = new Option( + // NOTE: the alias '-h' is already for '--help' + aliases: ["--hardware-id", "-i"], + parseArgument: ParseVidPid + ) + { + ArgumentHelpName = "VID:PID", + Description = "Attach device having :", + }.AddCompletions(completionContext => CompletionGuard(completionContext, () => + UsbDevice.GetAll().GroupBy(d => d.HardwareId).Select(g => g.Key.ToString()))); + // + // rule add + // + var addCommand = new Command("add", "Add a rule\0" + + "Add a new rule. The resulting rule set will be effective immediately.\n" + + "\n" + + AtLeastOneOfRequiredText(busIdOption, hardwareIdOption)) + { + accessOption, + typeOption, + busIdOption, + hardwareIdOption, + }; + addCommand.AddValidator(commandResult => + { + ValidateAtLeastOneOf(commandResult, busIdOption, hardwareIdOption); + }); + addCommand.SetHandler(async invocationContext => + { + var ruleType = invocationContext.ParseResult.GetValueForOption(typeOption); + invocationContext.ExitCode = (int)await (ruleType switch + { + RuleType.Bind => + commandHandlers.RuleAdd(new RuleBind(invocationContext.ParseResult.GetValueForOption(accessOption) == RuleAccess.Allow, + invocationContext.ParseResult.GetValueForOptionOrNull(busIdOption), + invocationContext.ParseResult.GetValueForOptionOrNull(hardwareIdOption)), + invocationContext.Console, invocationContext.GetCancellationToken()), + _ => throw new UnexpectedResultException($"Unexpected rule type '{ruleType}'."), + }); + }); + ruleCommand.AddCommand(addCommand); + } + { + // + // rule list + // + var listCommand = new Command("list", "List rules\0" + + "List all rules.\n"); + listCommand.SetHandler(async invocationContext => + { + invocationContext.ExitCode = (int) + await commandHandlers.RuleList(invocationContext.Console, invocationContext.GetCancellationToken()); + }); + ruleCommand.AddCommand(listCommand); + } + { + // + // rule remove --guid + // + var guidOption = new Option( + aliases: ["--guid", "-g"], + parseArgument: ParseGuid + ) + { + ArgumentHelpName = "GUID", + Description = "Stop sharing persisted device having ", + IsRequired = true, + }.AddCompletions(completionContext => CompletionGuard(completionContext, () => + RegistryUtils.GetRules().Select(r => r.Key.ToString("D")))); + // + // rule remove + // + var removeCommand = new Command("remove", "Remove a rule\0" + + "Remove an existing rule. The resulting rule set will be effective immediately.\n") + { + guidOption, + }; + removeCommand.SetHandler(async invocationContext => + { + invocationContext.ExitCode = (int) + await commandHandlers.RuleRemove(invocationContext.ParseResult.GetValueForOption(guidOption), + invocationContext.Console, invocationContext.GetCancellationToken()); + }); + ruleCommand.AddCommand(removeCommand); + } + rootCommand.AddCommand(ruleCommand); + } { // // server [...] diff --git a/Usbipd/RegistryUtils.cs b/Usbipd/RegistryUtils.cs index 5217cf23..2cc849d8 100644 --- a/Usbipd/RegistryUtils.cs +++ b/Usbipd/RegistryUtils.cs @@ -34,6 +34,9 @@ static RegistryKey OpenBaseKey(bool writable) const string AttachedName = "Attached"; const string BusIdName = "BusId"; const string IPAddressName = "IPAddress"; + const string RulesName = "Rules"; + const string TypeName = "Type"; + const string AllowName = "Allow"; /// /// if not installed @@ -52,6 +55,18 @@ static RegistryKey GetDevicesKey(bool writable) return devicesKey.OpenSubKey(guid.ToString("B"), writable); } + static RegistryKey GetRulesKey(bool writable) + { + return BaseKey(writable).OpenSubKey(RulesName, writable) + ?? throw new UnexpectedResultException("Registry key not found; try reinstalling the software."); + } + + static RegistryKey? GetRuleKey(Guid guid, bool writable) + { + using var devicesKey = GetRulesKey(writable); + return devicesKey.OpenSubKey(guid.ToString("B"), writable); + } + public static void Persist(string instanceId, string description) { var guid = Guid.NewGuid(); @@ -224,4 +239,92 @@ public static bool HasWriteAccess() return false; } } + + public static Guid AddRule(Rule rule) + { + if (!rule.IsValid()) + { + throw new ArgumentException("Invalid rule", nameof(rule)); + } + if (GetRules().ContainsValue(rule)) + { + throw new ArgumentException("Duplicate rule", nameof(rule)); + } + var guid = Guid.NewGuid(); + using var ruleKey = GetRulesKey(true).CreateSubKey($"{guid:B}"); + ruleKey.SetValue(AllowName, rule.Allow ? 1 : 0); + ruleKey.SetValue(TypeName, rule.Type.ToString()); + rule.Save(ruleKey); + return guid; + } + + public static void RemoveRule(Guid guid) + { + using var rulesKey = GetRulesKey(true); + rulesKey.DeleteSubKeyTree(guid.ToString("B"), false); + } + + /// + /// Enumerates all rules. + /// + /// This retrieves the entire (valid) registry state; it ignores invalid rules, as well as any duplicates. + /// + /// + public static SortedDictionary GetRules() + { + var guids = new SortedSet(); + using var rulesKey = GetRulesKey(false); + foreach (var subKeyName in rulesKey.GetSubKeyNames()) + { + if (Guid.TryParseExact(subKeyName, "B", out var guid)) + { + // Sanitize uniqueness. + guids.Add(guid); + } + } + + var rules = new SortedDictionary(); + foreach (var guid in guids) + { + using var ruleKey = GetRuleKey(guid, false); + if (ruleKey is null) + { + continue; + } + if (ruleKey.GetValue(AllowName) is not int allowValue) + { + // Must exist and be a DWORD. + continue; + } + var allow = (allowValue != 0); + if (!Enum.TryParse(ruleKey.GetValue(TypeName) as string, true, out var ruleType)) + { + // Must exist and be a valid enum string. + continue; + } + Rule rule; + switch (ruleType) + { + case RuleType.Bind: + rule = RuleBind.Load(allow, ruleKey); + break; + default: + // Invalid, ignore. + continue; + } + if (!rule.IsValid()) + { + // Invalid, ignore. + continue; + } + if (rules.ContainsValue(rule)) + { + // Duplicate, ignore. + continue; + } + rules.Add(guid, rule); + } + // All unique and valid. + return rules; + } } diff --git a/Usbipd/Rule.cs b/Usbipd/Rule.cs new file mode 100644 index 00000000..1682b62b --- /dev/null +++ b/Usbipd/Rule.cs @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: 2024 Frans van Dorsselaer +// +// SPDX-License-Identifier: GPL-3.0-only + +using Microsoft.Win32; + +namespace Usbipd; + +abstract record Rule(bool Allow, RuleType Type) +{ + public abstract bool IsValid(); + + public abstract bool Matches(UsbDevice usbDevice); + + public abstract void Save(RegistryKey registryKey); +} diff --git a/Usbipd/RuleBind.cs b/Usbipd/RuleBind.cs new file mode 100644 index 00000000..e8dccbd1 --- /dev/null +++ b/Usbipd/RuleBind.cs @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: 2024 Frans van Dorsselaer +// +// SPDX-License-Identifier: GPL-3.0-only + +using Microsoft.Win32; +using Usbipd.Automation; + +namespace Usbipd; + +sealed record RuleBind(bool allow, BusId? BusId, VidPid? HardwareId) + : Rule(allow, RuleType.Bind) +{ + const string BusIdName = "BusId"; + const string HardwareIdName = "HardwareId"; + + public override bool IsValid() => BusId.HasValue || HardwareId.HasValue; + + public override bool Matches(UsbDevice usbDevice) + { + if (!IsValid()) + { + throw new InvalidOperationException("Invalid rule"); + } + + if (BusId.HasValue && BusId.Value != usbDevice.BusId) + { + return false; + } + if (HardwareId.HasValue && HardwareId.Value != usbDevice.HardwareId) + { + return false; + } + + return true; + } + + public override void Save(RegistryKey registryKey) + { + if (BusId.HasValue) + { + registryKey.SetValue(BusIdName, BusId.Value.ToString()); + } + if (HardwareId.HasValue) + { + registryKey.SetValue(HardwareIdName, HardwareId.Value.ToString()); + } + } + + public static RuleBind Load(bool allow, RegistryKey registryKey) + { + BusId? busId = null; + if (Automation.BusId.TryParse(registryKey.GetValue(BusIdName) as string ?? string.Empty, out var parsedBusId)) + { + busId = parsedBusId; + } + VidPid? hardwareId = null; + if (VidPid.TryParse(registryKey.GetValue(HardwareIdName) as string ?? string.Empty, out var parsedHardwareId)) + { + hardwareId = parsedHardwareId; + } + return new(allow, busId, hardwareId); + } +} diff --git a/Usbipd/RulePermission.cs b/Usbipd/RulePermission.cs new file mode 100644 index 00000000..83893cda --- /dev/null +++ b/Usbipd/RulePermission.cs @@ -0,0 +1,11 @@ +// SPDX-FileCopyrightText: 2024 Frans van Dorsselaer +// +// SPDX-License-Identifier: GPL-3.0-only + +namespace Usbipd; + +public enum RuleAccess +{ + Allow, + Deny, +} diff --git a/Usbipd/RuleTYpe.cs b/Usbipd/RuleTYpe.cs new file mode 100644 index 00000000..cdde196f --- /dev/null +++ b/Usbipd/RuleTYpe.cs @@ -0,0 +1,10 @@ +// SPDX-FileCopyrightText: 2024 Frans van Dorsselaer +// +// SPDX-License-Identifier: GPL-3.0-only + +namespace Usbipd; + +public enum RuleType +{ + Bind, +}