From 2233fdec212261b0d4cc2b8ddec21bffc98315ca Mon Sep 17 00:00:00 2001 From: Graeme Gill Date: Wed, 25 Oct 2023 22:51:34 +1100 Subject: [PATCH] Fix for daynix/UsbDk #124, "UsbDk_StartRedirect hangs if no normal USB driver". Add filter driver to device with no normal USB driver that sets Raw mode. --- UsbDk/ControlDevice.cpp | 395 ++++++++++++++++++++++++++++++++++++ UsbDk/ControlDevice.h | 49 ++++- UsbDk/DeviceAccess.cpp | 74 ++++++- UsbDk/DeviceAccess.h | 6 +- UsbDk/FilterDevice.cpp | 85 +++++++- UsbDk/FilterDevice.h | 15 ++ UsbDk/HiderStrategy.h | 17 ++ UsbDk/RawFilterStrategy.cpp | 125 ++++++++++++ UsbDk/Registry.cpp | 40 ++++ UsbDk/Registry.h | 17 +- UsbDk/Trace.h | 1 + UsbDk/UsbDk.vcxproj | 1 + UsbDk/UsbDk.vcxproj.filters | 3 + UsbDk/UsbDkUtil.cpp | 112 ++++++++++ UsbDk/UsbDkUtil.h | 62 ++++++ 15 files changed, 987 insertions(+), 15 deletions(-) create mode 100644 UsbDk/RawFilterStrategy.cpp diff --git a/UsbDk/ControlDevice.cpp b/UsbDk/ControlDevice.cpp index a244a31..ce9169b 100644 --- a/UsbDk/ControlDevice.cpp +++ b/UsbDk/ControlDevice.cpp @@ -375,6 +375,170 @@ bool CUsbDkControlDevice::ShouldHide(const USB_DK_DEVICE_ID &DevId) return b; } +/* Ideally we would like to use IoGetDeviceProperty(DevicePropertyInstallState) to see if there */ +/* is a driver that will get installed, or IoOpenDeviceRegistryKey(Device) to look for a Device */ +/* value to see if the Device has a Driver assigned, but unfortunately nothing like this */ +/* is possible within the IRP_MN_QUERY_DEVICE_RELATIONS because the PDO isn't actually valid */ +/* for doing much at that point in time. Instead we create and maintain a list built */ +/* outside the IRP that has the associated registry key path in it for a VidPid/PortHub. This */ +/* isn't a perfect solution, since the registry state will lag when a device is plugged in for */ +/* the first time. The error is usually benign, and will be corrected */ +/* the second time the device is plugged in. */ +bool CUsbDkControlDevice::ShouldRawFiltDevice(CUsbDkChildDevice &Device, bool Is2ndCall) +{ + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, "%!FUNC! Checking against %S, %S Is2ndCall %d",Device.DeviceID(), Device.LocationID(), Is2ndCall); + + ULONG VidPid, PortHub; + + /* Format is "USB\VID_XXXX&PID_XXXX" */ + auto status = EightHexToInteger(Device.DeviceID(), 8, 17, &VidPid); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to Convert DeviceID string into ULONG (status %!STATUS!)", status); + return false; + } + + /* Format is "Port_#XXXX.Hub_#XXXX" */ + status = EightHexToInteger(Device.LocationID(), 6, 16, &PortHub); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to Convert Location string into ULONG (status %!STATUS!)", status); + return false; + } + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! ChildDevice Ident = 0x%08x 0x%08x'",VidPid, PortHub); + + /* Make sure that the "has a function driver" Set is initialized */ + if (!m_FDriverInited) + { + ReloadHasDriverList(); + } + + /* See if this Device has an entry in the Set */ + CUsbDkFDriverRule *Entry = nullptr; + bool hasentry = false; + bool hasfdriv = true; /* default to not adding RawFilter */ + const auto &FiltVisitor = [VidPid, PortHub, &Entry, &hasentry, &hasfdriv, this](CUsbDkFDriverRule *e) -> bool + { + if (e->Match(VidPid, PortHub)) + { + Entry = e; + + /* Check if there is a Driver value */ + CRegKey regkey; + + auto status = regkey.Open(*e->KeyName()); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to open Key '%wZ' registry key",e->KeyName()); + hasentry = false; /* Do ReloadHasDriverList() */ + return false; /* Terminate ForEach() */ + } + + hasentry = true; /* Don't ReloadHasDriverList() */ + + CStringHolder DriverNameHolder; + status = DriverNameHolder.Attach(TEXT("Driver")); + ASSERT(NT_SUCCESS(status)); + + CWdmMemoryBuffer Buffer; + status = regkey.QueryValueInfo(*DriverNameHolder, KeyValuePartialInformation, Buffer); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to read value '%wZ' (status %!STATUS!)", DriverNameHolder, status); + hasfdriv = false; + } else { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Was able to read value '%wZ' (status %!STATUS!)", DriverNameHolder, status); + hasfdriv = true; + } + return false; /* Terminate ForEach() */ + } + return true; /* Continue ForEach() */ + }; + const_cast(&m_FDriversRules)->ForEach(FiltVisitor); + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Check FDriverRulesSet returned hasentry %d, hasfdriv %d",hasentry, hasfdriv); + + /* If there's no Set entry, then try and create one and check again. */ + if (!hasentry) { + + if (!Is2ndCall) /* Try re-creating the list */ + { + ReloadHasDriverList(); + + /* Check again */ + hasentry = false; + hasfdriv = true; /* default to not adding RawFilter. */ + + const_cast(&m_FDriversRules)->ForEach(FiltVisitor); + } + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Check 2 FDriverRulesSet got Set count %d, hasentry %d, hasfdriv %d", + m_FDriversRules.GetCount(),hasentry, hasfdriv); + + /* If this fails (!hasentry) then either our assumptions in creating the Set */ + /* are wrong (i.e. Microsoft have changed the Registry layout for Enum\USB, in which */ + /* case GetCount() will be 0), or this is the first time the device has been plugged */ + /* in to the port, in which case we don't have a way of knowing if it has a driver. */ + /* We default to adding a Raw Filter so that the common case of plugging in a device */ + /* with no driver works with UsbDk, and take the (hopefully small) risk that this */ + /* won't disturb any device that has a driver. In the unlikely event this happens, */ + /* then re-plugging the device or redirecting via UsbDk should fix it. */ + if (m_FDriversRules.GetCount() > 0 && !hasentry) + { + hasfdriv = false; + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Check 2 FDriverRulesSet has no information so defaulting hasfdriv %d",hasfdriv); + } + } + + /* If a Function Driver for a device is installed after UsbDk and after a device is first */ + /* plugged in, then an Enum/USB entry will have been created for it, and PnP will ignore */ + /* the new Function Driver and assume that the device will continue to be driven in Raw mode. */ + /* This will create difficulty for the user, who then has to uninstall the device manually */ + /* using Device Manager to make it work with its Function Driver as well as UsbDk. We can */ + /* avoid this problem if we set the CONFIGFLAG_REINSTALL in the Enum/USB ConfigFlags, so */ + /* that PnP checks for a Function Driver on a Raw device each time it is plugged in. */ + if (Is2ndCall && Entry != nullptr && hasentry && !hasfdriv) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + "%!FUNC! Setting m_SetReinstall flag on Key '%wZ'",Entry->KeyName()); + status = Device.SetRawDeviceToReinstall(Entry->KeyName()); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to SetRawDeviceToReinstall (status %!STATUS!)",status); + } + } + return !hasfdriv; /* Add RawFilter if there is no function driver */ +} + +bool CUsbDkControlDevice::ShouldRawFilt(const USB_DK_DEVICE_ID &DevId) +{ + bool b = false; + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! About to call ShouldRawFiltDevice()"); + + EnumUsbDevicesByID(DevId, + [&b, this](CUsbDkChildDevice *Child) -> bool + { + b = ShouldRawFiltDevice(*Child, true); + return false; + }); + + return b; +} + bool CUsbDkControlDevice::EnumerateDevices(USB_DK_DEVICE_INFO *outBuff, size_t numberAllocatedDevices, size_t &numberExistingDevices) { numberExistingDevices = 0; @@ -1293,3 +1457,234 @@ NTSTATUS CUsbDkControlDevice::AddPersistentHideRule(const USB_DK_HIDE_RULE &UsbD } return STATUS_INVALID_PARAMETER; } + +/* Create or re-create the "Has Function Driver" registry key Set */ +NTSTATUS CUsbDkControlDevice::ReloadHasDriverList() +{ + /* See if we need to figure out the Registry root of ...\Enum\USB */ + if (!m_FDriverInited) + { + /* Find the first filter */ + CUsbDkFilterDevice *ffilter = nullptr; + m_FilterDevices.ForEach([&ffilter](CUsbDkFilterDevice *Filter) + { + ffilter = Filter; + return false; + }); + + if (ffilter == nullptr) { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_WDFDEVICE, "%!FUNC! No filters"); + STATUS_SUCCESS; /* Ignore error ? */ + } + + /* Get the Hub Driver HW key */ + WDFKEY hwkeyh; + auto status = WdfDeviceOpenRegistryKey(ffilter->WdfObject(), PLUGPLAY_REGKEY_DEVICE, KEY_READ, + WDF_NO_OBJECT_ATTRIBUTES, &hwkeyh); + if (!NT_SUCCESS(status)) { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_WDFDEVICE, "%!FUNC! Failed to get Control HWKey %!STATUS!",status); + return status; + } + + CRegKey regkey; + regkey.Acquire(WdfRegistryWdmGetHandle(hwkeyh)); + + // Could also use ObReferenceObjectByHandle(), ObQueryObjectName(), ObDereferenceObject() + CWdmMemoryBuffer InfoBuffer; + status = regkey.QueryKeyInfo(KeyNameInformation, InfoBuffer); + if (!NT_SUCCESS(status)) { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_WDFDEVICE, "%!FUNC! Failed to get Key path %!STATUS!",status); + WdfRegistryClose(hwkeyh); + return status; + } + + auto NameInfo = reinterpret_cast(InfoBuffer.Ptr()); + CStringHolder RootHolder; + RootHolder.Attach(NameInfo->Name, static_cast(NameInfo->NameLength)); + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_WDFDEVICE, "%!FUNC! Got Hub HW reg path '%wZ'",RootHolder); + + WdfRegistryClose(hwkeyh); + + /* Now we truncate m_RootName at the end of "\Enum\USB\" */ + if (!RootHolder.TruncateAfter(TEXT("\\Enum\\USB\\"))) { + return STATUS_FILE_NOT_AVAILABLE; /* Hmm. */ + } + + m_RootName.Create(RootHolder); + TraceEvents(TRACE_LEVEL_ERROR, TRACE_WDFDEVICE, "%!FUNC! Got Base HW reg path '%wZ'", + m_RootName); + m_FDriverInited = true; + } + + /* Open our root key */ + CRegKey rootkey; + auto status2 = rootkey.Open(*m_RootName); + if (!NT_SUCCESS(status2)) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to open RootKey '%wZ' registry key",m_RootName); + return status2; + } + + FDriverRulesSet Set; + + CStringHolder LocationNameHolder; + auto status = LocationNameHolder.Attach(TEXT("LocationInformation")); + ASSERT(NT_SUCCESS(status)); + + /* Search the sub keys for "VID_????&PID_????" */ + status = rootkey.ForEachSubKey([rootkey, &LocationNameHolder, &Set, this] + (CStringHolder &Sub1Name) + { + if (!Sub1Name.WCMatch(TEXT("VID_????&PID_????"))) + return; + + //TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + // "%!FUNC! Found matching Sub1Name '%wZ'",Sub1Name); + + /* Open the sub-key */ + CRegKey sub1key; + auto status = sub1key.Open(rootkey, *Sub1Name); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to open Sub1Key '%wZ' %!STATUS!",Sub1Name,status); + return; + } + + /* Search the instance sub keys */ + status = sub1key.ForEachSubKey([sub1key, &Sub1Name, &LocationNameHolder, &Set, this] + (CStringHolder &Sub2Name) + { + //TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + // "%!FUNC! Searching instance '%wZ'",Sub2Name); + + /* Open the instance sub-key */ + CRegKey sub2key; + auto status = sub2key.Open(sub1key, *Sub2Name); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to open instance '%wZ' %!STATUS!",Sub2Name,status); + return; + } + + /* Get the a 'LocationInformation' value */ + CWdmMemoryBuffer Buffer; + status = sub2key.QueryValueInfo(*LocationNameHolder, KeyValuePartialInformation, Buffer); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to read value %wZ (status %!STATUS!)", LocationNameHolder, status); + return; + } + //TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + // "%!FUNC! Found subkey 'LocationInfo'"); + + auto Info = reinterpret_cast(Buffer.Ptr()); + + //TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + // "%!FUNC! Info Type = %d, Length = %d", Info->Type,Info->DataLength); + if (Info->Type != REG_SZ + || Info->DataLength > (21 * sizeof(WCHAR))) + return; + + CStringHolder LocationValueHolder; + status = LocationValueHolder.Attach(reinterpret_cast(&Info->Data[0]), + static_cast(Info->DataLength)); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to Attach to Location value (status %!STATUS!)", status); + return; + } + + //TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + // "%!FUNC! Location = '%wZ'",LocationValueHolder); + + if (!LocationValueHolder.WCMatch(TEXT("Port_#????.Hub_#????"))) + return; + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Function Driver Ident = %wZ %wZ'",Sub1Name, LocationValueHolder); + + /* Form the overall device sub-key name */ + CString KeyName; + status = KeyName.Append(m_RootName); + if (NT_SUCCESS(status)) status = KeyName.Append(Sub1Name); + if (NT_SUCCESS(status)) status = KeyName.Append(TEXT("\\")); + if (NT_SUCCESS(status)) status = KeyName.Append(Sub2Name); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to create overall sub-key name (status %!STATUS!)", status); + return; + } + TraceEvents(TRACE_LEVEL_ERROR, TRACE_WDFDEVICE, "%!FUNC! Overall sub-key name '%wZ'", + KeyName); + + /* Get the device ID as integers */ + ULONG VidPid, PortHub; + + status = EightHexToInteger(Sub1Name, 4, 13, &VidPid); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to Convert VidPid string into ULONG (status %!STATUS!)", status); + return; + } + + status = EightHexToInteger(LocationValueHolder, 6, 16, &PortHub); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to Convert PortHub string into ULONG (status %!STATUS!)", status); + return; + } + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Function Driver Ident = 0x%08x 0x%08x'",VidPid, PortHub); + + /* (Note KeyName will be empty after the constructor due to swap.) */ + CObjHolder NewRule(new CUsbDkFDriverRule(VidPid, PortHub, KeyName)); + if (!NewRule) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to allocate new FDriver rule"); + return; + } + + if(!Set.Add(NewRule)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! failed. FDriver rule already present."); + return; + } + + NewRule.detach(); + }); + + if (status == STATUS_OBJECT_NAME_NOT_FOUND) + status = STATUS_SUCCESS; + }); + + if (status == STATUS_OBJECT_NAME_NOT_FOUND) + status = STATUS_SUCCESS; + + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to open Sub1Keys of '%wZ' registry key",PCUNICODE_STRING(m_RootName)); + return status; + } + + /* Overwrite m_FDriversRules with the temporary set */ + Set.MoveList(m_FDriversRules); + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! We now have %d entries in DFDriversRules",m_FDriversRules.GetCount()); + + return STATUS_SUCCESS; +} + diff --git a/UsbDk/ControlDevice.h b/UsbDk/ControlDevice.h index 9223036..08209f3 100644 --- a/UsbDk/ControlDevice.h +++ b/UsbDk/ControlDevice.h @@ -193,7 +193,7 @@ class CUsbDkRedirection : public CAllocatable, pub bool MatchProcess(ULONG pid); NTSTATUS WaitForAttachment() - { return m_RedirectionCreated.Wait(true, -SecondsTo100Nanoseconds(120)); } + { return m_RedirectionCreated.Wait(true, -SecondsTo100Nanoseconds(15)); } bool WaitForDetachment(); @@ -225,6 +225,41 @@ class CUsbDkRedirection : public CAllocatable, pub DECLARE_CWDMLIST_ENTRY(CUsbDkRedirection); }; +class CUsbDkFDriverRule : public CAllocatable < USBDK_NON_PAGED_POOL, 'FDRR' > +{ +public: + + CUsbDkFDriverRule(ULONG VidPid, ULONG PortHub, CString &KeyName) + : m_VidPid(VidPid) + , m_PortHub(PortHub) + { + m_KeyName.Swap(KeyName); + } + + bool Match(const ULONG VidPid, const ULONG PortHub) const + { + return m_VidPid == VidPid && m_PortHub == PortHub; + } + + bool operator ==(const CUsbDkFDriverRule &Other) const + { + return m_VidPid == Other.m_VidPid && + m_PortHub == Other.m_PortHub; + } + + void Dump(LONG traceLevel = m_defaultDumpLevel) const; + + CString& KeyName() { return m_KeyName; } + +private: + ULONG m_VidPid; + ULONG m_PortHub; + CString m_KeyName; /* HKLM/CCS/Enum/USB/VidPid/Location Registry key path */ + + static LONG m_defaultDumpLevel; + DECLARE_CWDMLIST_ENTRY(CUsbDkFDriverRule); +}; + class CDriverParamsRegistryPath final { public: @@ -301,6 +336,9 @@ class CUsbDkControlDevice : private CWdfControlDevice, public CAllocatable void NotifyRedirectionRemoved(const TDevID &Dev) const { @@ -312,6 +350,8 @@ class CUsbDkControlDevice : private CWdfControlDevice, public CAllocatable FDriverRulesSet; + typedef CWdmSet FDriverRulesSet; + FDriverRulesSet m_FDriversRules; + CString m_RootName; + bool m_FDriverInited = false; /* Set after m_RootName created. */ + template bool UsbDevicesForEachIf(TPredicate Predicate, TFunctor Functor) { return m_FilterDevices.ForEach([&](CUsbDkFilterDevice* Dev){ return Dev->EnumerateChildrenIf(Predicate, Functor); }); } diff --git a/UsbDk/DeviceAccess.cpp b/UsbDk/DeviceAccess.cpp index 6810a18..aa722d3 100644 --- a/UsbDk/DeviceAccess.cpp +++ b/UsbDk/DeviceAccess.cpp @@ -156,6 +156,41 @@ NTSTATUS CWdmDeviceAccess::QueryCapabilities(DEVICE_CAPABILITIES &Capabilities) return status; } +/* txType can be DeviceTextDescription or DeviceTextLocationInformation */ +PWCHAR CWdmDeviceAccess::QueryDeviceText(DEVICE_TEXT_TYPE txType) +{ + CIrp irp; + + auto status = irp.Create(m_DevObj); + + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Error %!STATUS! during IRP creation", status); + return nullptr; + } + + irp.Configure([txType] (PIO_STACK_LOCATION s) + { + s->MajorFunction = IRP_MJ_PNP; + s->MinorFunction = IRP_MN_QUERY_DEVICE_TEXT; + s->Parameters.QueryDeviceText.DeviceTextType = txType; + }); + + status = irp.SendSynchronously(); + + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Error %!STATUS! during %!devtx! query", status, txType); + return nullptr; + } + + PWCHAR txData; + irp.ReadResult([&txData](ULONG_PTR information) + { txData = reinterpret_cast(information); }); + + return (txData != nullptr) ? MakeNonPagedDuplicateSz(txData) : nullptr; +} + SIZE_T CWdmDeviceAccess::GetIdBufferLength(BUS_QUERY_ID_TYPE idType, PWCHAR idData) { switch (idType) @@ -186,6 +221,7 @@ bool CWdmDeviceAccess::QueryPowerData(CM_POWER_DATA& powerData) #endif } +#if 0 static void PowerRequestCompletion( _In_ PDEVICE_OBJECT DeviceObject, _In_ UCHAR MinorFunction, @@ -202,6 +238,7 @@ static void PowerRequestCompletion( TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_DEVACCESS, "%!FUNC! -> D%d", PowerState.DeviceState - 1); pev->Set(); } +#endif PWCHAR CWdmDeviceAccess::MakeNonPagedDuplicate(BUS_QUERY_ID_TYPE idType, PWCHAR idData) { @@ -221,6 +258,24 @@ PWCHAR CWdmDeviceAccess::MakeNonPagedDuplicate(BUS_QUERY_ID_TYPE idType, PWCHAR return static_cast(newIdData); } +PWCHAR CWdmDeviceAccess::MakeNonPagedDuplicateSz(PWCHAR txData) +{ + auto bufferLength = CRegSz::GetBufferLength(txData); + + auto newIdData = ExAllocatePoolWithTag(USBDK_NON_PAGED_POOL, bufferLength, 'IDHR'); + if (newIdData != nullptr) + { + RtlCopyMemory(newIdData, txData, bufferLength); + } + else + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! Failed to allocate non-paged buffer for Sz"); + } + + ExFreePool(txData); + return static_cast(newIdData); +} + NTSTATUS CWdmDeviceAccess::QueryForInterface(const GUID &guid, __out INTERFACE &intf, USHORT intfSize, USHORT intfVer, __in_opt PVOID intfCtx) { @@ -253,6 +308,7 @@ NTSTATUS CWdmDeviceAccess::QueryForInterface(const GUID &guid, __out INTERFACE & NTSTATUS CWdmUsbDeviceAccess::Reset(bool ForceD0) { CIoControlIrp Irp; +#if 0 // #115 reports that this can cause a WDF_VIOLATION (10d) error code with some devices. CM_POWER_DATA powerData; if (ForceD0 && QueryPowerData(powerData) && powerData.PD_MostRecentPowerState != PowerDeviceD0) { @@ -266,6 +322,11 @@ NTSTATUS CWdmUsbDeviceAccess::Reset(bool ForceD0) Event.Wait(); } } +#else + ForceD0; +#endif + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! About to sed IOCTL_INTERNAL_USB_CYCLE_PORT"); auto status = Irp.Create(m_DevObj, IOCTL_INTERNAL_USB_CYCLE_PORT); @@ -362,7 +423,8 @@ USB_DK_DEVICE_SPEED UsbDkWdmUsbDeviceGetSpeed(PDEVICE_OBJECT DevObj, PDRIVER_OBJ bool UsbDkGetWdmDeviceIdentity(const PDEVICE_OBJECT PDO, CObjHolder *DeviceID, - CObjHolder *InstanceID) + CObjHolder *InstanceID, + CObjHolder *LocationID) { CWdmDeviceAccess pdoAccess(PDO); @@ -386,6 +448,16 @@ bool UsbDkGetWdmDeviceIdentity(const PDEVICE_OBJECT PDO, } } + if (LocationID != nullptr) + { + *LocationID = pdoAccess.GetLocationID(); + if (!(*LocationID) || (*LocationID)->empty()) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_DEVACCESS, "%!FUNC! No Location ID read"); + return false; + } + } + return true; } diff --git a/UsbDk/DeviceAccess.h b/UsbDk/DeviceAccess.h index 0749102..06bf160 100644 --- a/UsbDk/DeviceAccess.h +++ b/UsbDk/DeviceAccess.h @@ -52,6 +52,7 @@ class CWdmDeviceAccess ULONG GetAddress(); CRegText *GetDeviceID() { return new CRegSz(QueryBusID(BusQueryDeviceID)); } CRegText *GetInstanceID() { return new CRegSz(QueryBusID(BusQueryInstanceID)); } + CRegText *GetLocationID() { return new CRegSz(QueryDeviceText(DeviceTextLocationInformation)); } bool QueryPowerData(CM_POWER_DATA& powerData); protected: PDEVICE_OBJECT m_DevObj; @@ -59,8 +60,10 @@ class CWdmDeviceAccess private: PWCHAR QueryBusID(BUS_QUERY_ID_TYPE idType); NTSTATUS QueryCapabilities(DEVICE_CAPABILITIES &Capabilities); + PWCHAR QueryDeviceText(DEVICE_TEXT_TYPE txType); static PWCHAR MakeNonPagedDuplicate(BUS_QUERY_ID_TYPE idType, PWCHAR idData); + static PWCHAR MakeNonPagedDuplicateSz(PWCHAR txData); static SIZE_T GetIdBufferLength(BUS_QUERY_ID_TYPE idType, PWCHAR idData); }; @@ -118,7 +121,8 @@ class CWdmUSBD bool UsbDkGetWdmDeviceIdentity(const PDEVICE_OBJECT PDO, CObjHolder *DeviceID, - CObjHolder *InstanceID = nullptr); + CObjHolder *InstanceID = nullptr, + CObjHolder *LocationID = nullptr); USB_DK_DEVICE_SPEED UsbDkWdmUsbDeviceGetSpeed(PDEVICE_OBJECT PDO, PDRIVER_OBJECT DriverObject); diff --git a/UsbDk/FilterDevice.cpp b/UsbDk/FilterDevice.cpp index 22d4383..41df826 100644 --- a/UsbDk/FilterDevice.cpp +++ b/UsbDk/FilterDevice.cpp @@ -26,6 +26,7 @@ #include "trace.h" #include "FilterDevice.tmh" #include "DeviceAccess.h" +#include "Registry.h" #include "ControlDevice.h" #include "UsbDkNames.h" @@ -288,8 +289,9 @@ void CUsbDkHubFilterStrategy::DropRemovedDevices(const CDeviceRelations &Relatio }); ToBeDeleted.ForEach([this](CUsbDkChildDevice *Device) -> bool { - m_ControlDevice->NotifyRedirectionRemoved(*Device); - return true; + Device->MarkRawDeviceToReinstall(); + m_ControlDevice->NotifyRedirectionRemoved(*Device); + return true; }); } @@ -327,7 +329,8 @@ void CUsbDkHubFilterStrategy::RegisterNewChild(PDEVICE_OBJECT PDO) CObjHolder DevID; CObjHolder InstanceID; - if (!UsbDkGetWdmDeviceIdentity(PDO, &DevID, &InstanceID)) + CObjHolder LocationID; + if (!UsbDkGetWdmDeviceIdentity(PDO, &DevID, &InstanceID, &LocationID)) { TraceEvents(TRACE_LEVEL_ERROR, TRACE_FILTERDEVICE, "%!FUNC! Cannot query device identity"); return; @@ -336,6 +339,7 @@ void CUsbDkHubFilterStrategy::RegisterNewChild(PDEVICE_OBJECT PDO) TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "%!FUNC! Registering new child (PDO: %p):", PDO); DevID->Dump(); InstanceID->Dump(); + LocationID->Dump(); // Not a USB device -> do not register if (!DevID->MatchPrefix(L"USB\\")) @@ -389,8 +393,8 @@ void CUsbDkHubFilterStrategy::RegisterNewChild(PDEVICE_OBJECT PDO) return; } - CUsbDkChildDevice *Device = new CUsbDkChildDevice(DevID, InstanceID, Port, Speed, DevDescriptor, - CfgDescriptors, *m_Owner, PDO); + CUsbDkChildDevice *Device = new CUsbDkChildDevice(DevID, InstanceID, LocationID, Port, Speed, + DevDescriptor, CfgDescriptors, *m_Owner, PDO); if (Device == nullptr) { @@ -400,6 +404,7 @@ void CUsbDkHubFilterStrategy::RegisterNewChild(PDEVICE_OBJECT PDO) DevID.detach(); InstanceID.detach(); + LocationID.detach(); Children().PushBack(Device); @@ -443,7 +448,8 @@ bool CUsbDkHubFilterStrategy::FetchConfigurationDescriptors(CWdmUsbDeviceAccess void CUsbDkHubFilterStrategy::ApplyRedirectionPolicy(CUsbDkChildDevice &Device) { if (m_ControlDevice->ShouldRedirect(Device) || - m_ControlDevice->ShouldHideDevice(Device)) + m_ControlDevice->ShouldHideDevice(Device) || + m_ControlDevice->ShouldRawFiltDevice(Device, false)) { if (Device.AttachToDeviceStack()) { @@ -589,7 +595,7 @@ bool CUsbDkFilterDevice::CStrategist::SelectStrategy(PDEVICE_OBJECT DevObj) // Get device ID CObjHolder DevID; - if (!UsbDkGetWdmDeviceIdentity(DevObj, &DevID, nullptr)) + if (!UsbDkGetWdmDeviceIdentity(DevObj, &DevID, nullptr, nullptr)) { TraceEvents(TRACE_LEVEL_ERROR, TRACE_FILTERDEVICE, "%!FUNC! Cannot query device ID"); return false; @@ -619,7 +625,7 @@ bool CUsbDkFilterDevice::CStrategist::SelectStrategy(PDEVICE_OBJECT DevObj) // Get instance ID CObjHolder InstanceID; - if (!UsbDkGetWdmDeviceIdentity(DevObj, nullptr, &InstanceID)) + if (!UsbDkGetWdmDeviceIdentity(DevObj, nullptr, &InstanceID, nullptr)) { TraceEvents(TRACE_LEVEL_ERROR, TRACE_FILTERDEVICE, "%!FUNC! Cannot query instance ID"); return false; @@ -675,6 +681,15 @@ bool CUsbDkFilterDevice::CStrategist::SelectStrategy(PDEVICE_OBJECT DevObj) return true; } + // Mark as Raw strategy to allow Redirect CYCLE_PORT to work. + if (m_Strategy->GetControlDevice()->ShouldRawFilt(ID)) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "%!FUNC! Assigning Raw filer device strategy"); + m_Strategy->Delete(); + m_Strategy = &m_RawFilterStrategy; + return true; + } + // No strategy TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "%!FUNC! Do not redirect or already redirected device, no strategy assigned"); @@ -799,3 +814,57 @@ void CUsbDkChildDevice::DetermineDeviceClasses() #endif TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_FILTERDEVICE, "Class mask %08X", m_ClassMaskForExtHider); } + +#ifndef CONFIGFLAG_REINSTALL /* This is in um/RegStr.h */ +# define CONFIGFLAG_REINSTALL 0x00000020 // Redo install +#endif + +void CUsbDkChildDevice::MarkRawDeviceToReinstall() +{ + if (!m_SetReinstall) + return; + + TraceEvents(TRACE_LEVEL_ERROR, TRACE_FILTERDEVICE, "%!FUNC! Found m_SetReinstall flag " + "on Child PDO 0x%p",PDO()); + CRegKey regkey; + + auto status = regkey.Open(*m_HwKeyPath); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to open Key '%wZ' registry key",m_HwKeyPath); + return; + } + CStringHolder ConfigFlagsNameHolder; + status = ConfigFlagsNameHolder.Attach(TEXT("ConfigFlags")); + ASSERT(NT_SUCCESS(status)); + + CWdmMemoryBuffer Buffer; + status = regkey.QueryValueInfo(*ConfigFlagsNameHolder, KeyValuePartialInformation, Buffer); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to read value '%wZ\\ConfigFlags' (status %!STATUS!)", + m_HwKeyPath, status); + return; + } + auto Info = reinterpret_cast(Buffer.Ptr()); + + if (Info->Type != REG_DWORD + || Info->DataLength != sizeof(DWORD32)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Wrong data type/length for value %wZ\\ConfigFlags", m_HwKeyPath); + return; + } + DWORD32 *flags = reinterpret_cast(Info->Data); + *flags |= CONFIGFLAG_REINSTALL; + + status = regkey.SetValueInfo(*ConfigFlagsNameHolder, Info); + if (!NT_SUCCESS(status)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_CONTROLDEVICE, + "%!FUNC! Failed to update %wZ\\ConfigFlags (status %!STATUS!)", + m_HwKeyPath, status); + } +} diff --git a/UsbDk/FilterDevice.h b/UsbDk/FilterDevice.h index 76b38ca..748c3e1 100644 --- a/UsbDk/FilterDevice.h +++ b/UsbDk/FilterDevice.h @@ -51,6 +51,7 @@ class CUsbDkChildDevice : public CAllocatable CUsbDkChildDevice(CRegText *DeviceID, CRegText *InstanceID, + CRegText *LocationID, ULONG Port, USB_DK_DEVICE_SPEED Speed, USB_DEVICE_DESCRIPTOR &DevDescriptor, @@ -59,6 +60,7 @@ class CUsbDkChildDevice : public CAllocatable PDEVICE_OBJECT PDO) : m_DeviceID(DeviceID) , m_InstanceID(InstanceID) + , m_LocationID(LocationID) , m_Port(Port) , m_Speed(Speed) , m_DevDescriptor(DevDescriptor) @@ -72,6 +74,7 @@ class CUsbDkChildDevice : public CAllocatable ULONG ParentID() const; PCWCHAR DeviceID() const { return *m_DeviceID->begin(); } PCWCHAR InstanceID() const { return *m_InstanceID->begin(); } + PCWCHAR LocationID() const { return *m_LocationID->begin(); } ULONG Port() const { return m_Port; } USB_DK_DEVICE_SPEED Speed() const @@ -102,9 +105,17 @@ class CUsbDkChildDevice : public CAllocatable void Dump(); + NTSTATUS SetRawDeviceToReinstall(CString &keypath) + { + m_SetReinstall = true; + return m_HwKeyPath.Create(keypath); + } + void MarkRawDeviceToReinstall(); + private: CObjHolder m_DeviceID; CObjHolder m_InstanceID; + CObjHolder m_LocationID; ULONG m_Port; USB_DK_DEVICE_SPEED m_Speed; USB_DEVICE_DESCRIPTOR m_DevDescriptor; @@ -115,6 +126,9 @@ class CUsbDkChildDevice : public CAllocatable CUsbDkChildDevice(const CUsbDkChildDevice&) = delete; CUsbDkChildDevice& operator= (const CUsbDkChildDevice&) = delete; + bool m_SetReinstall = false; /* Set CONFIGFLAG_REINSTALL on removal */ + CString m_HwKeyPath; + void DetermineDeviceClasses(); DECLARE_CWDMLIST_ENTRY(CUsbDkChildDevice); @@ -206,6 +220,7 @@ class CUsbDkFilterDevice : public CWdfDevice, CUsbDkNullFilterStrategy m_NullStrategy; CUsbDkHubFilterStrategy m_HubStrategy; CUsbDkHiderStrategy m_HiderStrategy; + CUsbDkRawFilterStrategy m_RawFilterStrategy; CUsbDkRedirectorStrategy m_DevStrategy; } m_Strategy; diff --git a/UsbDk/HiderStrategy.h b/UsbDk/HiderStrategy.h index bb6e59f..ddfa328 100644 --- a/UsbDk/HiderStrategy.h +++ b/UsbDk/HiderStrategy.h @@ -44,3 +44,20 @@ class CUsbDkHiderStrategy : public CUsbDkNullFilterStrategy CStopWatch m_RemovalStopWatch; }; + +class CUsbDkRawFilterStrategy : public CUsbDkNullFilterStrategy +{ +public: + virtual NTSTATUS Create(CUsbDkFilterDevice *Owner) override; + virtual void Delete() override; + + virtual NTSTATUS PNPPreProcess(PIRP Irp) override; + virtual NTSTATUS MakeAvailable() override + { + return STATUS_SUCCESS; + } + +private: + void PatchDeviceID(PIRP Irp); + NTSTATUS PatchDeviceText(PIRP Irp); +}; diff --git a/UsbDk/RawFilterStrategy.cpp b/UsbDk/RawFilterStrategy.cpp new file mode 100644 index 0000000..e73f659 --- /dev/null +++ b/UsbDk/RawFilterStrategy.cpp @@ -0,0 +1,125 @@ +/********************************************************************** +* Copyright (c) 2013-2014 Red Hat, Inc. +* +* Developed by Daynix Computing LTD. +* +* Authors: +* Dmitry Fleytman +* Pavel Gurvich +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**********************************************************************/ + +/* This filter strategy simply sets the RawDeviceOK flag. */ +/* As a bug work-around, it also supplies a DeviceText if */ +/* the device itself doesn't supply one, just like the Hide strategy. */ + +#include "stdafx.h" + +#include "HiderStrategy.h" +#include "trace.h" +#include "RawFilterStrategy.tmh" +#include "FilterDevice.h" +#include "ControlDevice.h" +#include "UsbDkNames.h" + +NTSTATUS CUsbDkRawFilterStrategy::Create(CUsbDkFilterDevice *Owner) +{ + auto status = CUsbDkNullFilterStrategy::Create(Owner); + if (NT_SUCCESS(status)) + { + m_ControlDevice->RegisterHiddenDevice(*Owner); + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_HIDER, "%!FUNC! Serial number for this device is %lu", Owner->GetSerialNumber()); + + } + + return status; +} + +void CUsbDkRawFilterStrategy::Delete() +{ + if (m_ControlDevice != nullptr) + { + m_ControlDevice->UnregisterHiddenDevice(*m_Owner); + } + + CUsbDkNullFilterStrategy::Delete(); +} + +NTSTATUS CUsbDkRawFilterStrategy::PatchDeviceText(PIRP Irp) +{ + static const WCHAR UsbDkDeviceText[] = USBDK_DRIVER_NAME L" device"; + + const WCHAR *Buffer = nullptr; + SIZE_T Size = 0; +#if 0 + TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_HIDER, "%!FUNC! Entry"); +#endif + PIO_STACK_LOCATION irpStack = IoGetCurrentIrpStackLocation(Irp); + switch (irpStack->Parameters.QueryDeviceText.DeviceTextType) + { + case DeviceTextDescription: + if (Irp->IoStatus.Information == 0) { /* leave original name if it exists */ + Buffer = &UsbDkDeviceText[0]; + Size = sizeof(UsbDkDeviceText); + } + break; + default: + break; + } + + if (Buffer != nullptr) + { + auto Result = DuplicateStaticBuffer(Buffer, Size); + if (Result != nullptr) + { + if (Irp->IoStatus.Information != 0) + ExFreePool(reinterpret_cast(Irp->IoStatus.Information)); + + Irp->IoStatus.Information = reinterpret_cast(Result); + Irp->IoStatus.Status = STATUS_SUCCESS; + } + } + return Irp->IoStatus.Status; +} + +NTSTATUS CUsbDkRawFilterStrategy::PNPPreProcess(PIRP Irp) +{ + PIO_STACK_LOCATION irpStack = IoGetCurrentIrpStackLocation(Irp); + +TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_HIDER, "%!FUNC! Got Irp 0x%x",irpStack->MinorFunction); + + switch (irpStack->MinorFunction) + { + case IRP_MN_QUERY_CAPABILITIES: + return PostProcessOnSuccess(Irp, + [](PIRP Irp) -> NTSTATUS + { + auto irpStack = IoGetCurrentIrpStackLocation(Irp); + irpStack->Parameters.DeviceCapabilities.Capabilities->RawDeviceOK = 1; + irpStack->Parameters.DeviceCapabilities.Capabilities->Removable = 0; +TraceEvents(TRACE_LEVEL_INFORMATION, TRACE_HIDER, "%!FUNC! Set RawDeviceOK"); + return STATUS_SUCCESS; + }); + case IRP_MN_QUERY_DEVICE_TEXT: + return PostProcess(Irp, + [this](PIRP Irp, NTSTATUS Status) -> NTSTATUS + { + UNREFERENCED_PARAMETER(Status); + return PatchDeviceText(Irp); + }); + default: + return CUsbDkNullFilterStrategy::PNPPreProcess(Irp); + } +} diff --git a/UsbDk/Registry.cpp b/UsbDk/Registry.cpp index 0e344c3..6a03e05 100644 --- a/UsbDk/Registry.cpp +++ b/UsbDk/Registry.cpp @@ -46,6 +46,34 @@ NTSTATUS CRegKey::Open(const CRegKey &ParentKey, const UNICODE_STRING &RegPath) return status; } +NTSTATUS CRegKey::QueryKeyInfo(KEY_INFORMATION_CLASS InfoClass, + CWdmMemoryBuffer &Buffer) +{ + ULONG BytesNeeded = 0; + NTSTATUS status = STATUS_BUFFER_TOO_SMALL; + + while ((status == STATUS_BUFFER_OVERFLOW) || (status == STATUS_BUFFER_TOO_SMALL)) + { + status = Buffer.Recreate(BytesNeeded, PagedPool); + if (NT_SUCCESS(status)) + { + status = ZwQueryKey(m_Key, + InfoClass, + Buffer.Ptr(), + static_cast(Buffer.Size()), + &BytesNeeded); + } + } + + if (!NT_SUCCESS(status) && (status != STATUS_NO_MORE_ENTRIES)) + { + TraceEvents(TRACE_LEVEL_ERROR, TRACE_REGISTRY, + "%!FUNC! Failed to query key info, info class %d (status: %!STATUS!)", InfoClass, status); + } + + return status; +} + NTSTATUS CRegKey::QuerySubkeyInfo(ULONG Index, KEY_INFORMATION_CLASS InfoClass, CWdmMemoryBuffer &Buffer) @@ -105,3 +133,15 @@ NTSTATUS CRegKey::QueryValueInfo(const UNICODE_STRING &ValueName, return status; } + +NTSTATUS CRegKey::SetValueInfo(const UNICODE_STRING &ValueName, + PKEY_VALUE_PARTIAL_INFORMATION Info) +{ + auto status = ZwSetValueKey(m_Key, + const_cast(&ValueName), + NULL, + Info->Type, + Info->Data, + Info->DataLength); + return status; +} diff --git a/UsbDk/Registry.h b/UsbDk/Registry.h index cd57804..d1a934f 100644 --- a/UsbDk/Registry.h +++ b/UsbDk/Registry.h @@ -30,6 +30,8 @@ class CRegKey NTSTATUS Open(const UNICODE_STRING &RegPath) { return Open(CRegKey(), RegPath); } + void Acquire(HANDLE key) { m_Key = key; } + template NTSTATUS ForEachSubKey(TFunctor Functor) { @@ -64,14 +66,21 @@ class CRegKey } } -protected: - NTSTATUS QuerySubkeyInfo(ULONG Index, - KEY_INFORMATION_CLASS InfoClass, - CWdmMemoryBuffer &Buffer); + NTSTATUS QueryKeyInfo(KEY_INFORMATION_CLASS InfoClass, + CWdmMemoryBuffer &Buffer); NTSTATUS QueryValueInfo(const UNICODE_STRING &ValueName, KEY_VALUE_INFORMATION_CLASS InfoClass, CWdmMemoryBuffer &Buffer); + NTSTATUS SetValueInfo(const UNICODE_STRING &ValueName, + PKEY_VALUE_PARTIAL_INFORMATION Info); + +protected: + + NTSTATUS QuerySubkeyInfo(ULONG Index, + KEY_INFORMATION_CLASS InfoClass, + CWdmMemoryBuffer &Buffer); + HANDLE m_Key = nullptr; }; diff --git a/UsbDk/Trace.h b/UsbDk/Trace.h index cd6ecb4..54939c6 100644 --- a/UsbDk/Trace.h +++ b/UsbDk/Trace.h @@ -73,6 +73,7 @@ // FUNC TraceEvents(LEVEL, FLAGS, MSG, ...); // CUSTOM_TYPE(devprop, ItemEnum(DEVICE_REGISTRY_PROPERTY)); // CUSTOM_TYPE(devid, ItemEnum(BUS_QUERY_ID_TYPE)); +// CUSTOM_TYPE(devtx, ItemEnum(DEVICE_TEXT_TYPE)); // CUSTOM_TYPE(pipetype, ItemEnum(_WDF_USB_PIPE_TYPE)); // CUSTOM_TYPE(usbdktransfertype, ItemEnum(USB_DK_TRANSFER_TYPE)); // CUSTOM_TYPE(usbdktransferdirection, ItemEnum(UsbDkTransferDirection)); diff --git a/UsbDk/UsbDk.vcxproj b/UsbDk/UsbDk.vcxproj index 94c01a6..960d84d 100644 --- a/UsbDk/UsbDk.vcxproj +++ b/UsbDk/UsbDk.vcxproj @@ -411,6 +411,7 @@ Use Use + Use Use diff --git a/UsbDk/UsbDk.vcxproj.filters b/UsbDk/UsbDk.vcxproj.filters index 9dc7a39..ebea8a1 100644 --- a/UsbDk/UsbDk.vcxproj.filters +++ b/UsbDk/UsbDk.vcxproj.filters @@ -165,6 +165,9 @@ Source Files + + Source Files + diff --git a/UsbDk/UsbDkUtil.cpp b/UsbDk/UsbDkUtil.cpp index afde765..2211577 100644 --- a/UsbDk/UsbDkUtil.cpp +++ b/UsbDk/UsbDkUtil.cpp @@ -143,6 +143,57 @@ size_t CStringBase::ToWSTR(PWCHAR Buffer, size_t SizeBytes) const return BytesToCopy + sizeof(Buffer[0]); } +/* Truncate at the end of the pattern string. */ +/* Return true on success */ +bool CStringBase::TruncateAfter(PCWSTR patn) +{ + size_t off, len, plen; + + const int bpwc = sizeof(WCHAR); + len = m_String.Length/bpwc; + plen = wcslen(patn); + + if (plen == 0 || len < plen) + return false; + + for (off = 0; len >= plen; off++, len--) { + + /* Look for first char match */ + if (patn[0] != m_String.Buffer[off]) + continue; + + /* See if the strings match */ + if (memcmp(&m_String.Buffer[off], patn, bpwc * plen) == 0) { + m_String.Length = static_cast(bpwc * (off + plen)); /* Truncate */ + return true; + } + } + return false; +} + +/* Do a string match with possible character '?' wilcards in pattern */ +/* Return true on match */ +bool CStringBase::WCMatch(PCWSTR patn) +{ + size_t off, len, plen; + + len = m_String.Length/sizeof(WCHAR); + plen = wcslen(patn); + + if (len > 0 && m_String.Buffer[len-1] == L'\0') + len--; + + if (len != plen) + return false; + + for (off = 0; off < len; off++) { + if (patn[off] != m_String.Buffer[off] + && patn[off] != L'?') + return false; + } + return true; +} + PVOID DuplicateStaticBuffer(const void *Buffer, SIZE_T Length, POOL_TYPE PoolType) { ASSERT(Buffer != nullptr); @@ -193,3 +244,64 @@ LONGLONG CStopWatch::Time100Ns() const KeQueryTickCount(&Now); return (Now.QuadPart - m_StartTime.QuadPart) * m_TimeIncrement; } + +/* Convert four Unicode Hex characters at an offset in a string, into an 16 bit value */ +static NTSTATUS FourHexToInteger( + PCUNICODE_STRING String, /* String to read from */ + USHORT Off, /* Character offset into string of 4 characters */ + PULONG Value /* Output */ +) { + const int bpwc = sizeof(WCHAR); + UNICODE_STRING SubString; + + SubString.Length = bpwc * 4; + SubString.MaximumLength = String->Length - bpwc * Off; + SubString.Buffer = String->Buffer + Off; + + if (SubString.MaximumLength < 4) + return STATUS_INVALID_PARAMETER; + + return RtlUnicodeStringToInteger(&SubString, 16, Value); +} + +/* Convert 2x4 digit hex unicode characters into a 32 bit value */ +NTSTATUS EightHexToInteger( + PCUNICODE_STRING String, /* String to read from */ + USHORT MsOff, /* Char offset into string of MS characters */ + USHORT LsOff, /* Char offset into string of LS characters */ + PULONG Value /* Output */ +) { + NTSTATUS status; + ULONG tval1, tval2; + + if ((status = FourHexToInteger(String, MsOff, &tval1)) != STATUS_SUCCESS) + return status; + if ((status = FourHexToInteger(String, LsOff, &tval2)) != STATUS_SUCCESS) + return status; + + *Value = (tval1 << 16) + tval2; + + return status; +} + +/* Convert 2x4 digit hex unicode characters into a 32 bit value, Sz version. */ +NTSTATUS EightHexToInteger( + PCWCHAR String, /* String to read from */ + USHORT MsOff, /* Char offset into string of MS characters */ + USHORT LsOff, /* Char offset into string of LS characters */ + PULONG Value /* Output */ +) { + const int bpwc = sizeof(WCHAR); + UNICODE_STRING UCString; + size_t slen = wcslen(String); + + if (slen > (bpwc * ((1 << (sizeof(USHORT) * 8))-1))) + return STATUS_INVALID_PARAMETER; + + UCString.Length = static_cast(bpwc * slen); + UCString.MaximumLength = UCString.Length; + UCString.Buffer = const_cast(String); /* UCString becomes const, so safe. */ + + return EightHexToInteger(&UCString, MsOff, LsOff, Value); +} + diff --git a/UsbDk/UsbDkUtil.h b/UsbDk/UsbDkUtil.h index ff599ea..da841e3 100644 --- a/UsbDk/UsbDkUtil.h +++ b/UsbDk/UsbDkUtil.h @@ -215,6 +215,7 @@ class CCountingObject void CounterIncrement() { m_Counter++; } void CounterDecrement() { m_Counter--; } ULONG GetCount() { return m_Counter; } + void ResetCount() { m_Counter = 0; } private: ULONG m_Counter = 0; }; @@ -224,6 +225,7 @@ class CNonCountingObject public: void CounterIncrement() { } void CounterDecrement() { } + void ResetCount() {} protected: ULONG GetCount() { return 0; } }; @@ -449,6 +451,25 @@ class CWdmSet : private TAccessStrategy, public TCountingStrategy CWdmSet(const CWdmSet&) = delete; CWdmSet& operator= (const CWdmSet&) = delete; + + /* Move from this Set to the other Set */ + void MoveList(CWdmSet &OtherList) + { + CLockedContext LockedContextOther(OtherList); + TInternalList &OtherInternalList = OtherList.m_Objects; + OtherInternalList.Clear(); + OtherList.ResetCount(); + + CLockedContext LockedContextThis(*this); + m_Objects.ForEachDetached([this, &OtherList, &OtherInternalList](TEntryType *ExistingEntry) + { + OtherInternalList.PushBack(ExistingEntry); + OtherList.CounterIncrement(); + CounterDecrement(); + return true; + }); + } + private: void SwapLists(TInternalList &OtherList) { @@ -541,6 +562,12 @@ class CStringBase size_t ToWSTR(PWCHAR Buffer, size_t SizeBytes) const; + /* Return true on success */ + bool TruncateAfter(PCWSTR patn); + + /* Character wilcard match. Return true on match */ + bool WCMatch(PCWSTR patn); + protected: CStringBase(const CStringBase&) = delete; CStringBase& operator= (const CStringBase&) = delete; @@ -549,6 +576,8 @@ class CStringBase UNICODE_STRING m_String = {}; }; + +/* Holds a UNICODE_STRING referencing a string stored elsewhere. */ class CStringHolder : public CStringBase { public: @@ -571,6 +600,7 @@ class CStringHolder : public CStringBase CStringHolder& operator= (const CStringHolder&) = delete; }; +/* Stores a UNICODE_STRING it allocates. */ class CString : public CStringBase { public: @@ -593,6 +623,21 @@ class CString : public CStringBase NTSTATUS Append(ULONG Num, ULONG Base = 10); void Destroy(); + NTSTATUS Append(NTSTRSAFE_PCWSTR String) + { + CStringHolder StringH; + auto status = StringH.Attach(String); + if (!NT_SUCCESS(status)) + return status; + return Append(StringH); + } + + void Swap(CString &OtherString) { + UNICODE_STRING TempString = m_String; + m_String = OtherString.m_String; + OtherString.m_String = TempString; + } + CString() { } @@ -663,3 +708,20 @@ class CStopWatch NTSTATUS UsbDkCreateCurrentProcessHandle(HANDLE &Handle); + +/* Convert 2x4 digit hex unicode characters into a 32 bit value */ +NTSTATUS EightHexToInteger( + PCUNICODE_STRING String, /* String to read from */ + USHORT MsOff, /* Char offset into string of MS characters */ + USHORT LsOff, /* Char offset into string of LS characters */ + PULONG Value /* Output */ +); + +/* Convert 2x4 digit hex unicode characters into a 32 bit value, Sz version. */ +NTSTATUS EightHexToInteger( + PCWCHAR String, /* String to read from */ + USHORT MsOff, /* Char offset into string of MS characters */ + USHORT LsOff, /* Char offset into string of LS characters */ + PULONG Value /* Output */ +); +