From 6a26426c70e0cf56af829181abdc4d12fc2648dc Mon Sep 17 00:00:00 2001 From: Mauro Leggieri Date: Wed, 6 Nov 2019 12:18:30 -0300 Subject: [PATCH] Fixed sorted list management and cookie parsing. (#1) * Store error only on close * Fixed sorted list management and cookie parsing. --- Include/ArrayList.h | 165 +++++++++++++++++------------------ Include/Comm/IpcCommon.h | 29 ++++-- Include/Http/HttpCookie.h | 9 ++ Include/Http/HttpServer.h | 14 +-- Include/LinkedList.h | 48 +++++----- Include/PropertyBag.h | 11 ++- Include/RedBlackTree.h | 120 ++++++++++++------------- Source/Comm/HostResolver.cpp | 33 ++++--- Source/Comm/IpcCommon.cpp | 22 ++++- Source/Comm/NamedPipes.cpp | 6 +- Source/Comm/Sockets.cpp | 6 +- Source/Http/HttpClient.cpp | 67 ++++++++------ Source/Http/HttpCookie.cpp | 47 ++++++++++ Source/Http/HttpServer.cpp | 13 ++- Source/Http/Url.cpp | 16 ++-- Source/PropertyBag.cpp | 10 --- Source/TimedEvent.cpp | 37 +++++--- Source/WaitableObjects.cpp | 2 +- Test/TestHttpClient.cpp | 7 +- 19 files changed, 381 insertions(+), 281 deletions(-) diff --git a/Include/ArrayList.h b/Include/ArrayList.h index 4ed1dceee..7b10e52b4 100644 --- a/Include/ArrayList.h +++ b/Include/ArrayList.h @@ -72,44 +72,39 @@ class TArrayList : public virtual CBaseMemObj virtual BOOL SortedInsert(_In_ TType elem) { - SIZE_T nIndex, nMin, nMax; + SIZE_T nIndex; - nMin = 1; //shifted by one to avoid problems with negative indexes - nMax = nCount; //if count == 0, loop will not enter - while (nMin <= nMax) + if (SetSize(nCount + 1) == FALSE) + return FALSE; + nIndex = nCount; + while (nIndex > 0 && elem < lpItems[nIndex - 1]) { - nIndex = nMin + (nMax - nMin) / 2; - if (elem == lpItems[nIndex - 1]) - { - nMin = nIndex; - break; - } - if (elem < lpItems[nIndex - 1]) - nMax = nIndex - 1; - else - nMin = nIndex + 1; + lpItems[nIndex] = lpItems[nIndex - 1]; + nIndex--; } - return InsertElementAt(elem, nMin - 1); + lpItems[nIndex] = elem; + nCount++; + return TRUE; }; template BOOL SortedInsert(_In_ TType elem, _In_ _Comparator lpCompareFunc, _In_opt_ LPVOID lpContext = NULL, _In_opt_ BOOL bDontInsertDuplicates = FALSE, _Out_opt_ LPBOOL lpbAlreadyOnList = NULL) { - SIZE_T nIndex, nMin, nMax; - int res; + SIZE_T nIndex; if (lpbAlreadyOnList != NULL) *lpbAlreadyOnList = FALSE; if (lpCompareFunc == NULL) return FALSE; - nMin = 1; //shifted by one to avoid problems with negative indexes - nMax = nCount; //if count == 0, loop will not enter - while (nMin <= nMax) + + if (SetSize(nCount + 1) == FALSE) + return FALSE; + nIndex = nCount; + while (nIndex > 0) { - nIndex = nMin + (nMax - nMin) / 2; - res = lpCompareFunc(lpContext, &elem, &lpItems[nIndex - 1]); - if (res == 0) + int comp = lpCompareFunc(lpContext, &elem, &lpItems[nIndex - 1]); + if (comp == 0) { if (bDontInsertDuplicates != FALSE) { @@ -117,44 +112,46 @@ class TArrayList : public virtual CBaseMemObj *lpbAlreadyOnList = TRUE; return TRUE; } - nMin = nIndex; - break; } - if (res < 0) - nMax = nIndex - 1; - else - nMin = nIndex + 1; + if (comp >= 0) + break; + lpItems[nIndex] = lpItems[nIndex - 1]; + nIndex--; } - return InsertElementAt(elem, nMin - 1); + lpItems[nIndex] = elem; + nCount++; + return TRUE; }; template - TType* BinarySearchPtr(_In_ _KeyType lpKey, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) + TType* BinarySearchPtr(_In_ _KeyType _key, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) { - SIZE_T nIndex = BinarySearch(lpKey, lpSearchFunc, lpContext); + SIZE_T nIndex = BinarySearch(_key, lpSearchFunc, lpContext); return (nIndex != (SIZE_T)-1) ? lpItems + nIndex : NULL; }; template - SIZE_T BinarySearch(_In_ _KeyType lpKey, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) + SIZE_T BinarySearch(_In_ _KeyType _key, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) { - SIZE_T nMid, nMin, nMax; - int res; + SIZE_T nBase, n; - if (lpKey == NULL || lpSearchFunc == NULL) + if (lpSearchFunc == NULL) return (SIZE_T)-1; - nMin = 1; //shifted by one to avoid problems with negative indexes - nMax = nCount; //if count == 0, loop will not enter - while (nMin <= nMax) - { - nMid = nMin + (nMax - nMin) / 2; - res = lpSearchFunc(lpContext, lpKey, &lpItems[nMid - 1]); - if (res == 0) - return nMid - 1; - if (res < 0) - nMax = nMid - 1; - else - nMin = nMid + 1; + nBase = 0; + n = nCount; + while (n > 0) + { + SIZE_T nMid = nBase + (n >> 1); + + int comp = lpSearchFunc(lpContext, _key, &lpItems[nMid]); + if (comp == 0) + return nMid; + if (comp > 0) + { + nBase = nMid + 1; + n--; + } + n >>= 1; } return (SIZE_T)-1; }; @@ -484,22 +481,20 @@ class TArrayList4Structs : public virtual CBaseMemObj BOOL SortedInsert(_In_ TType *lpElem, _In_ _Comparator lpCompareFunc, _In_opt_ LPVOID lpContext = NULL, _In_opt_ BOOL bDontInsertDuplicates = FALSE, _Out_opt_ LPBOOL lpbAlreadyOnList = NULL) { - SIZE_T nIndex, nMin, nMax; - int res; + SIZE_T nIndex; if (lpbAlreadyOnList != NULL) *lpbAlreadyOnList = FALSE; if (lpElem == NULL && lpCompareFunc == NULL) return FALSE; - if (nCount == 0) - return InsertElementAt(lpElem); - nMin = 1; //shifted by one to avoid problems with negative indexes - nMax = nCount; - while (nMin <= nMax) - { - nIndex = (nMin + nMax) / 2; - res = lpCompareFunc(lpContext, lpElem, &lpItems[nIndex - 1]); - if (res == 0) + + if (SetSize(nCount + 1) == FALSE) + return FALSE; + nIndex = nCount; + while (nIndex > 0) + { + int comp = lpCompareFunc(lpContext, lpElem, &lpItems[nIndex - 1]); + if (comp == 0) { if (bDontInsertDuplicates != FALSE) { @@ -507,44 +502,46 @@ class TArrayList4Structs : public virtual CBaseMemObj *lpbAlreadyOnList = TRUE; return TRUE; } - nMin = nIndex; - break; } - if (res < 0) - nMax = nIndex - 1; - else - nMin = nIndex + 1; + if (comp >= 0) + break; + MxMemCopy(&lpItems[nIndex], &lpItems[nIndex - 1], sizeof(TType)); + nIndex--; } - return InsertElementAt(lpElem, nMin - 1); + MxMemCopy(&lpItems[nIndex], lpElem, sizeof(TType)); + nCount++; + return TRUE; }; template - TType* BinarySearchPtr(_In_ _KeyType lpKey, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) + TType* BinarySearchPtr(_In_ _KeyType _key, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) { - SIZE_T nIndex = BinarySearch(lpKey, lpSearchFunc, lpContext); + SIZE_T nIndex = BinarySearch(_key, lpSearchFunc, lpContext); return (nIndex != (SIZE_T)-1) ? lpItems + nIndex : NULL; }; template - SIZE_T BinarySearch(_In_ _KeyType lpKey, _In_ _Comparator lpCompareFunc, _In_opt_ LPVOID lpContext = NULL) + SIZE_T BinarySearch(_In_ _KeyType _key, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) { - SIZE_T nMid, nMin, nMax; - int res; + SIZE_T nBase, n; - if (lpKey == NULL || lpCompareFunc == NULL) + if (_key == NULL || lpSearchFunc == NULL) return (SIZE_T)-1; - nMin = 1; //shifted by one to avoid problems with negative indexes - nMax = nCount; //if count == 0, loop will not enter - while (nMin <= nMax) - { - nMid = nMin + (nMax - nMin) / 2; - res = lpCompareFunc(lpContext, lpKey, &lpItems[nMid - 1]); - if (res == 0) - return nMid - 1; - if (res < 0) - nMax = nMid - 1; - else - nMin = nMid + 1; + nBase = 0; + n = nCount; + while (n > 0) + { + SIZE_T nMid = nBase + (n >> 1); + + int comp = lpSearchFunc(lpContext, _key, &lpItems[nMid]); + if (comp == 0) + return nMid; + if (comp > 0) + { + nBase = nMid + 1; + n--; + } + n >>= 1; } return (SIZE_T)-1; }; diff --git a/Include/Comm/IpcCommon.h b/Include/Comm/IpcCommon.h index b5c5fc0d5..5130e8f26 100644 --- a/Include/Comm/IpcCommon.h +++ b/Include/Comm/IpcCommon.h @@ -597,8 +597,7 @@ class MX_NOVTABLE CIpc : public virtual CBaseMemObj, public CLoggable protected: friend class CLayer; - class MX_NOVTABLE CConnectionBase : public virtual TRefCounted, - public TRedBlackTreeNode + class MX_NOVTABLE CConnectionBase : public virtual TRefCounted, public TRedBlackTreeNode { MX_DISABLE_COPY_CONSTRUCTOR(CConnectionBase); protected: @@ -606,11 +605,6 @@ class MX_NOVTABLE CIpc : public virtual CBaseMemObj, public CLoggable public: virtual ~CConnectionBase(); - SIZE_T GetNodeKey() const - { - return (SIZE_T)this; - }; - virtual VOID ShutdownLink(_In_ BOOL bAbortive); HRESULT SendMsg(_In_ LPCVOID lpData, _In_ SIZE_T nDataSize, _In_opt_ CLayer *lpToLayer); @@ -661,6 +655,25 @@ class MX_NOVTABLE CIpc : public virtual CBaseMemObj, public CLoggable VOID UpdateStats(_In_ BOOL bRead, _In_ DWORD dwBytesTransferred); VOID GetStats(_In_ BOOL bRead, _Out_ PULONGLONG lpullBytesTransferred, _Out_opt_ float *lpnThroughputKbps); + protected: + static int InsertCompareFunc(_In_ LPVOID lpContext, _In_ CConnectionBase *lpConn1, _In_ CConnectionBase *lpConn2) + { + if ((SIZE_T)lpConn1 < (SIZE_T)lpConn2) + return -1; + if ((SIZE_T)lpConn1 > (SIZE_T)lpConn2) + return 1; + return 0; + }; + + static int SearchCompareFunc(_In_ LPVOID lpContext, _In_ SIZE_T key, _In_ CConnectionBase *lpConn) + { + if (key < (SIZE_T)lpConn) + return -1; + if (key > (SIZE_T)lpConn) + return 1; + return 0; + }; + protected: friend class CIpc; @@ -743,7 +756,7 @@ class MX_NOVTABLE CIpc : public virtual CBaseMemObj, public CLoggable OnEngineErrorCallback cEngineErrorCallback; struct { LONG volatile nRwMutex; - TRedBlackTree cTree; + TRedBlackTree cTree; } sConnections; CPacketList cFreePacketsList32768; CPacketList cFreePacketsList4096; diff --git a/Include/Http/HttpCookie.h b/Include/Http/HttpCookie.h index 55a89c144..bb55876cd 100644 --- a/Include/Http/HttpCookie.h +++ b/Include/Http/HttpCookie.h @@ -39,6 +39,11 @@ namespace MX { class CHttpCookie : public virtual CBaseMemObj { +public: + typedef enum { + SameSiteNone, SameSiteLax, SameSiteStrict + } eSameSite; + public: CHttpCookie(); CHttpCookie(_In_ const CHttpCookie& cSrc) throw(...); @@ -85,6 +90,9 @@ class CHttpCookie : public virtual CBaseMemObj VOID SetHttpOnlyFlag(_In_ BOOL bIsHttpOnly); BOOL GetHttpOnlyFlag() const; + HRESULT SetSameSite(_In_ eSameSite nSameSite); + eSameSite GetSameSite() const; + HRESULT ToString(_Inout_ CStringA& cStrDestA, _In_ BOOL bAddAttributes=TRUE); HRESULT DoesDomainMatch(_In_z_ LPCSTR szDomainToMatchA); @@ -101,6 +109,7 @@ class CHttpCookie : public virtual CBaseMemObj CStringA cStrDomainA; CStringA cStrPathA; CDateTime cExpiresDt; + eSameSite nSameSite; }; //----------------------------------------------------------- diff --git a/Include/Http/HttpServer.h b/Include/Http/HttpServer.h index f8df7e649..f3ebe3797 100644 --- a/Include/Http/HttpServer.h +++ b/Include/Http/HttpServer.h @@ -384,24 +384,24 @@ class CHttpServer : public virtual CThread, public CLoggable, private CCriticalS HRESULT OnDownloadStarted(_Out_ LPHANDLE lphFile, _In_z_ LPCWSTR szFileNameW, _In_ LPVOID lpUserParam); private: - class CRequestLimiter : public TRedBlackTreeNode + class CRequestLimiter : public virtual CBaseMemObj, public TRedBlackTreeNode { public: - CRequestLimiter(_In_ PSOCKADDR_INET lpAddr) : TRedBlackTreeNode() + CRequestLimiter(_In_ PSOCKADDR_INET lpAddr) : CBaseMemObj(), TRedBlackTreeNode() { MxMemCopy(&sAddr, lpAddr, sizeof(SOCKADDR_INET)); _InterlockedExchange(&nCount, 1); return; }; - virtual PSOCKADDR_INET GetNodeKey() const + static int InsertCompareFunc(_In_ LPVOID lpContext, _In_ CRequestLimiter *lpLim1, _In_ CRequestLimiter *lpLim2) { - return &(const_cast(this)->sAddr); + return ::MxMemCompare(&(lpLim1->sAddr), &(lpLim2->sAddr), sizeof(SOCKADDR_INET)); }; - virtual int CompareKeys(_In_ PSOCKADDR_INET key) const + static int SearchCompareFunc(_In_ LPVOID lpContext, _In_ PSOCKADDR_INET key, _In_ CRequestLimiter *lpLim) { - return ::MxMemCompare(key, GetNodeKey(), sizeof(SOCKADDR_INET)); + return ::MxMemCompare(key, &(lpLim->sAddr), sizeof(SOCKADDR_INET)); }; public: @@ -446,7 +446,7 @@ class CHttpServer : public virtual CThread, public CLoggable, private CCriticalS struct { LONG volatile nRwMutex; - TRedBlackTree cTree; + TRedBlackTree cTree; } sRequestLimiter; }; diff --git a/Include/LinkedList.h b/Include/LinkedList.h index d28cfb89c..479b0eb30 100644 --- a/Include/LinkedList.h +++ b/Include/LinkedList.h @@ -26,15 +26,15 @@ namespace MX { -template +template class TLnkLst; -template +template class TLnkLstNode { public: - typedef TLnkLst _LnkLstList; - typedef TLnkLstNode _LnkLstNode; + typedef TLnkLst _LnkLstList; + typedef TLnkLstNode _LnkLstNode; TLnkLstNode() { @@ -53,12 +53,12 @@ class TLnkLstNode return lpPrev; }; - _inline classOrStruct* GetNextEntry() + _inline T* GetNextEntry() { return (lpNext != NULL) ? (lpNext->GetEntry()) : NULL; }; - _inline classOrStruct* GetPrevEntry() + _inline T* GetPrevEntry() { return (lpPrev != NULL) ? (lpPrev->GetEntry()) : NULL; }; @@ -68,9 +68,9 @@ class TLnkLstNode return lpList; }; - _inline classOrStruct* GetEntry() + _inline T* GetEntry() { - return static_cast(this); + return static_cast(this); }; _inline VOID RemoveNode() @@ -81,7 +81,7 @@ class TLnkLstNode }; private: - template + template friend class TLnkLst; _LnkLstList *lpList; @@ -90,14 +90,14 @@ class TLnkLstNode //----------------------------------------------------------- -template -class TLnkLst +template +class TLnkLst : public virtual CBaseMemObj { public: - typedef TLnkLst _LnkLstList; - typedef TLnkLstNode _LnkLstNode; + typedef TLnkLst _LnkLstList; + typedef TLnkLstNode _LnkLstNode; - TLnkLst() + TLnkLst() : CBaseMemObj() { lpHead = lpTail = NULL; #ifdef _DEBUG @@ -220,17 +220,17 @@ class TLnkLst return; }; - _inline classOrStruct* GetHead() + _inline T* GetHead() { return (lpHead != NULL) ? (lpHead->GetEntry()) : NULL; }; - _inline classOrStruct* GetTail() + _inline T* GetTail() { return (lpTail != NULL) ? (lpTail->GetEntry()) : NULL; }; - _inline classOrStruct* PopHead() + _inline T* PopHead() { _LnkLstNode *lpNode; @@ -240,7 +240,7 @@ class TLnkLst return lpNode->GetEntry(); }; - _inline classOrStruct* PopTail() + _inline T* PopTail() { _LnkLstNode *lpNode; @@ -269,18 +269,18 @@ class TLnkLst class Iterator { public: - classOrStruct* Begin(_In_ _LnkLstList &_list) + T* Begin(_In_ _LnkLstList &_list) { lpNextCursor = _list.lpHead; return Next(); }; - classOrStruct* Begin(_In_ const _LnkLstList &_list) + T* Begin(_In_ const _LnkLstList &_list) { return Begin(const_cast<_LnkLstList&>(_list)); }; - classOrStruct* Next() + T* Next() { lpCursor = lpNextCursor; if (lpCursor == NULL) @@ -298,18 +298,18 @@ class TLnkLst class IteratorRev { public: - classOrStruct* Begin(_In_ _LnkLstList &_list) + T* Begin(_In_ _LnkLstList &_list) { lpNextCursor = _list.lpTail; return Next(); }; - classOrStruct* Begin(_In_ const _LnkLstList &_list) + T* Begin(_In_ const _LnkLstList &_list) { return Begin(const_cast<_LnkLstList&>(_list)); }; - classOrStruct* Next() + T* Next() { lpCursor = lpNextCursor; if (lpCursor == NULL) diff --git a/Include/PropertyBag.h b/Include/PropertyBag.h index 1e26df081..9c595d290 100644 --- a/Include/PropertyBag.h +++ b/Include/PropertyBag.h @@ -90,8 +90,15 @@ class CPropertyBag : public virtual CBaseMemObj BOOL Insert(_In_ PROPERTY *lpNewProp); SIZE_T Find(_In_z_ LPCSTR szNameA); - static int InsertCompareFunc(_In_ LPVOID lpContext, _In_ PROPERTY **lpItem1, _In_ PROPERTY **lpItem2); - static int SearchCompareFunc(_In_ LPVOID lpContext, _In_ LPCVOID lpKey, _In_ PROPERTY **lpItem); + static int InsertCompareFunc(_In_ LPVOID lpContext, _In_ PROPERTY **lpItem1, _In_ PROPERTY **lpItem2) + { + return StrCompareA((*lpItem1)->szNameA, (*lpItem2)->szNameA, TRUE); + }; + + static int SearchCompareFunc(_In_ LPVOID lpContext, _In_ LPCVOID lpKey, _In_ PROPERTY **lpItem) + { + return StrCompareA((LPCSTR)lpKey, (*lpItem)->szNameA, TRUE); + }; private: TArrayListWithFree cPropertiesList; diff --git a/Include/RedBlackTree.h b/Include/RedBlackTree.h index d9ec15855..1add7ae0d 100644 --- a/Include/RedBlackTree.h +++ b/Include/RedBlackTree.h @@ -26,23 +26,23 @@ namespace MX { -template +template class TRedBlackTree; -template -class MX_NOVTABLE TRedBlackTreeNode : public virtual CBaseMemObj +template +class TRedBlackTreeNode { public: - typedef TRedBlackTree _RbTree; - typedef TRedBlackTreeNode _RbTreeNode; + typedef TRedBlackTree _RbTree; + typedef TRedBlackTreeNode _RbTreeNode; - template + template friend class TRedBlackTree; - friend class TRedBlackTree; + friend class TRedBlackTree; public: - TRedBlackTreeNode() : CBaseMemObj() + TRedBlackTreeNode() { bRed = FALSE; lpTree = NULL; @@ -50,20 +50,6 @@ class MX_NOVTABLE TRedBlackTreeNode : public virtual CBaseMemObj return; }; - virtual KeyType GetNodeKey() const = 0; - - //Returns -1 if key is less than "this" node's key, 1 if greater or 0 if equal - virtual int CompareKeys(_In_ KeyType key) const - { - KeyType this_key = GetNodeKey(); - - if (key < this_key) - return -1; - if (key > this_key) - return 1; - return 0; - }; - _RbTreeNode* GetNextNode() { _RbTreeNode *lpSucc, *lpNode2; @@ -125,19 +111,19 @@ class MX_NOVTABLE TRedBlackTreeNode : public virtual CBaseMemObj return (this) ? lpParent : NULL; }; - _inline classOrStruct* GetNextEntry() + _inline T* GetNextEntry() { _RbTreeNode *lpNode = GetNextNode(); return (lpNode != NULL) ? (lpNode->GetEntry()) : NULL; }; - _inline classOrStruct* GetPrevEntry() + _inline T* GetPrevEntry() { _RbTreeNode *lpNode = GetPrevNode(); return (lpNode != NULL) ? (lpNode->GetEntry()) : NULL; }; - _inline classOrStruct* GetParentEntry() + _inline T* GetParentEntry() { return (lpParent != NULL) ? (lpParent->GetEntry()) : NULL; }; @@ -147,9 +133,9 @@ class MX_NOVTABLE TRedBlackTreeNode : public virtual CBaseMemObj return lpTree; }; - _inline classOrStruct* GetEntry() + _inline T* GetEntry() { - return static_cast(this); + return static_cast(this); }; _inline VOID RemoveNode() @@ -187,17 +173,17 @@ class MX_NOVTABLE TRedBlackTreeNode : public virtual CBaseMemObj //----------------------------------------------------------- -template -class TRedBlackTree +template +class TRedBlackTree : public virtual CBaseMemObj { public: - typedef TRedBlackTree _RbTree; - typedef TRedBlackTreeNode _RbTreeNode; + typedef TRedBlackTree _RbTree; + typedef TRedBlackTreeNode _RbTreeNode; MX_DISABLE_COPY_CONSTRUCTOR(_RbTree); public: - TRedBlackTree() + TRedBlackTree() : CBaseMemObj() { lpRoot = NULL; nItemsCount = 0; @@ -214,14 +200,14 @@ class TRedBlackTree return nItemsCount; }; - _inline classOrStruct* Find(_In_ KeyType _key) + template + _inline T* Find(_In_ _KeyType _key, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) { _RbTreeNode *lpNode = lpRoot; - int comp; while (lpNode != NULL) { - comp = lpNode->CompareKeys(_key); + int comp = lpSearchFunc(lpContext, _key, lpNode->GetEntry()); if (comp == 0) return lpNode->GetEntry(); lpNode = (comp < 0) ? lpNode->lpLeft : lpNode->lpRight; @@ -229,7 +215,7 @@ class TRedBlackTree return NULL; }; - _inline classOrStruct* GetFirst() + _inline T* GetFirst() { _RbTreeNode *lpNode; @@ -241,7 +227,7 @@ class TRedBlackTree return lpNode->GetEntry(); }; - _inline classOrStruct* GetLast() + _inline T* GetLast() { _RbTreeNode *lpNode; @@ -254,14 +240,15 @@ class TRedBlackTree }; //Try to get entry with key greater or equal to the specified one. Else get nearest less. - _inline classOrStruct* GetCeiling(_In_ KeyType _key) + template + _inline T* GetCeiling(_In_ _KeyType _key, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) { _RbTreeNode *lpNode = lpRoot, *lpParent; int comp; while (lpNode != NULL) { - comp = lpNode->CompareKeys(_key); + int comp = lpSearchFunc(lpContext, _key, lpNode->GetEntry()); if (comp == 0) return lpNode->GetEntry(); if (comp < 0) @@ -292,14 +279,14 @@ class TRedBlackTree }; //Try to get entry with key less or equal to the specified one. Else get nearest greater. - _inline classOrStruct* GetFloor(_In_ KeyType _key) + template + _inline T* GetFloor(_In_ _KeyType _key, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) { _RbTreeNode *lpNode = lpRoot, *lpParent; - int comp; while (lpNode != NULL) { - comp = lpNode->CompareKeys(_key); + int comp = lpSearchFunc(lpContext, _key, lpNode->GetEntry()); if (comp == 0) return lpNode->GetEntry(); if (comp > 0) @@ -329,14 +316,14 @@ class TRedBlackTree }; //Try to get entry with key greater or equal to the specified one. Else null - _inline classOrStruct* GetHigher(_In_ KeyType _key) + template + _inline T* GetHigher(_In_ _KeyType _key, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) { _RbTreeNode *lpNode = lpRoot, *lpParent; - int comp; while (lpNode != NULL) { - comp = lpNode->CompareKeys(_key); + int comp = lpSearchFunc(lpContext, _key, lpNode->GetEntry()); if (comp < 0) { if (lpNode->lpLeft == NULL) @@ -365,14 +352,14 @@ class TRedBlackTree }; //Try to get entry with key less or equal to the specified one. Else null - _inline classOrStruct* GetLower(_In_ KeyType _key) + template + _inline T* GetLower(_In_ _KeyType _key, _In_ _Comparator lpSearchFunc, _In_opt_ LPVOID lpContext = NULL) { _RbTreeNode *lpNode = lpRoot, *lpParent; - int comp; while (lpNode != NULL) { - comp = lpNode->CompareKeys(_key); + int comp = lpSearchFunc(lpContext, _key, lpNode->GetEntry()); if (comp > 0) { if (lpNode->lpRight == NULL) @@ -400,18 +387,19 @@ class TRedBlackTree return NULL; }; - BOOL Insert(_In_ _RbTreeNode *lpNewNode, _In_opt_ BOOL bAllowDuplicates = FALSE, - _In_opt_ _RbTreeNode **lplpMatchingNode = NULL) + template + BOOL Insert(_In_ _RbTreeNode *lpNewNode, _In_ _Comparator lpCompareFunc, _In_opt_ BOOL bAllowDuplicates = FALSE, + _In_opt_ T **lplpMatchingEntry = NULL, _In_opt_ LPVOID lpContext = NULL) { _RbTreeNode *lpNode, *lpParent, *lpUncle; - int res; + int comp; MX_ASSERT(lpNewNode != NULL); MX_ASSERT(lpNewNode->lpParent == NULL); MX_ASSERT(lpNewNode->lpLeft == NULL && lpNewNode->lpRight == NULL); - if (lplpMatchingNode != NULL) - *lplpMatchingNode = NULL; + if (lplpMatchingEntry != NULL) + *lplpMatchingEntry = NULL; if (lpRoot == NULL) { lpNewNode->lpParent = NULL; @@ -431,19 +419,19 @@ class TRedBlackTree do { lpParent = lpNode; - res = lpNode->CompareKeys(lpNewNode->GetNodeKey()); - if (res == 0 && bAllowDuplicates == FALSE) + comp = lpCompareFunc(lpContext, lpNewNode->GetEntry(), lpNode->GetEntry()); + if (comp == 0 && bAllowDuplicates == FALSE) { - if (lplpMatchingNode != NULL) - *lplpMatchingNode = lpNode; + if (lplpMatchingEntry != NULL) + *lplpMatchingEntry = lpNode->GetEntry(); return FALSE; } - lpNode = (res < 0) ? lpNode->lpLeft : lpNode->lpRight; + lpNode = (comp < 0) ? lpNode->lpLeft : lpNode->lpRight; } while (lpNode != NULL); lpNewNode->lpLeft = lpNewNode->lpRight = NULL; lpNewNode->lpParent = lpParent; - if (res < 0) + if (comp < 0) lpParent->lpLeft = lpNewNode; else lpParent->lpRight = lpNewNode; @@ -861,18 +849,18 @@ class TRedBlackTree class Iterator { public: - classOrStruct* Begin(_In_ _RbTree &_tree) + T* Begin(_In_ _RbTree &_tree) { lpNextCursor = _tree.GetFirstNode(); return Next(); }; - classOrStruct* Begin(_In_ const _RbTree &_tree) + T* Begin(_In_ const _RbTree &_tree) { return Begin(const_cast<_RbTree&>(_tree)); }; - classOrStruct* Next() + T* Next() { _RbTreeNode *lpCursor = lpNextCursor; if (lpCursor == NULL) @@ -890,18 +878,18 @@ class TRedBlackTree class IteratorRev { public: - classOrStruct* Begin(_In_ _RbTree &_tree) + T* Begin(_In_ _RbTree &_tree) { lpNextCursor = _tree.GetLast(); return Next(); }; - classOrStruct* Begin(_In_ const _RbTree &_tree) + T* Begin(_In_ const _RbTree &_tree) { return Begin(const_cast<_RbTree&>(_tree)); }; - classOrStruct* Next() + T* Next() { _RbTreeNode *lpCursor = lpNextCursor; if (lpCursor == NULL) @@ -915,7 +903,7 @@ class TRedBlackTree }; private: - template + template friend class TRedBlackTreeNode; _RbTreeNode *lpRoot; diff --git a/Source/Comm/HostResolver.cpp b/Source/Comm/HostResolver.cpp index b396b7cc4..10c656a46 100644 --- a/Source/Comm/HostResolver.cpp +++ b/Source/Comm/HostResolver.cpp @@ -82,14 +82,14 @@ class CHostResolver : public TRefCounted { MX_DISABLE_COPY_CONSTRUCTOR(CHostResolver); public: - class CAsyncItem : public TRedBlackTreeNode + class CAsyncItem : public virtual CBaseMemObj, public TRedBlackTreeNode { MX_DISABLE_COPY_CONSTRUCTOR(CAsyncItem); public: CAsyncItem(_In_ CHostResolver *_lpResolver, _In_z_ LPCWSTR szHostNameW, _In_ int nDesiredFamily, _In_ DWORD dwTimeoutMs, _In_ PSOCKADDR_INET lpSockAddr, _In_ HostResolver::OnResultCallback cCallback, _In_ LPVOID lpUserData, - _In_ lpfnFreeAddrInfoExW _fnFreeAddrInfoExW) : TRedBlackTreeNode() + _In_ lpfnFreeAddrInfoExW _fnFreeAddrInfoExW) : CBaseMemObj(), TRedBlackTreeNode() { lpResolver = _lpResolver; fnFreeAddrInfoExW = _fnFreeAddrInfoExW; @@ -106,11 +106,6 @@ class CHostResolver : public TRefCounted return; }; - virtual ULONG GetNodeKey() const - { - return (ULONG)nId; - }; - VOID Setup(_In_z_ LPCWSTR szHostNameW, _In_ int _nDesiredFamily, _In_ DWORD _dwTimeoutMs, _In_ PSOCKADDR_INET _lpSockAddr, _In_ HostResolver::OnResultCallback _cCallback, _In_ LPVOID _lpUserData) { @@ -153,6 +148,24 @@ class CHostResolver : public TRefCounted return; }; + static int InsertCompareFunc(_In_ LPVOID lpContext, _In_ CAsyncItem *lpAsyncItem1, _In_ CAsyncItem *lpAsyncItem2) + { + if ((ULONG)(lpAsyncItem1->nId) < (ULONG)(lpAsyncItem2->nId)) + return -1; + if ((ULONG)(lpAsyncItem1->nId) > (ULONG)(lpAsyncItem2->nId)) + return 1; + return 0; + }; + + static int SearchCompareFunc(_In_ LPVOID lpContext, _In_ ULONG key, _In_ CAsyncItem *lpAsyncItem) + { + if (key < (ULONG)(lpAsyncItem->nId)) + return -1; + if (key > (ULONG)(lpAsyncItem->nId)) + return 1; + return 0; + }; + public: CHostResolver *lpResolver; lpfnFreeAddrInfoExW fnFreeAddrInfoExW; @@ -210,7 +223,7 @@ class CHostResolver : public TRefCounted lpfnGetAddrInfoExOverlappedResult fnGetAddrInfoExOverlappedResult; struct { LONG volatile nMutex; - TRedBlackTree cTree; + TRedBlackTree cTree; } sAsyncTasks; struct { LONG volatile nMutex; @@ -1171,7 +1184,7 @@ HRESULT CHostResolver::AddResolverCommon(_Out_ LONG volatile *lpnResolverId, _In _InterlockedExchange(lpnResolverId, lpNewAsyncItem->nId); - sAsyncTasks.cTree.Insert(lpNewAsyncItem, TRUE); + sAsyncTasks.cTree.Insert(lpNewAsyncItem, &CAsyncItem::InsertCompareFunc, TRUE); tv.tv_sec = (long)(lpNewAsyncItem->dwTimeoutMs / 1000); tv.tv_usec = (long)(lpNewAsyncItem->dwTimeoutMs % 1000); @@ -1206,7 +1219,7 @@ VOID CHostResolver::RemoveResolver(_Out_ LONG volatile *lpnResolverId) { CFastLock cQueueLock(&(sAsyncTasks.nMutex)); - lpAsyncItem = sAsyncTasks.cTree.Find(nResolver); + lpAsyncItem = sAsyncTasks.cTree.Find(nResolver, &CAsyncItem::SearchCompareFunc); if (lpAsyncItem != NULL) { //only running async tasks are in the tree and can be canceled diff --git a/Source/Comm/IpcCommon.cpp b/Source/Comm/IpcCommon.cpp index 4d6cf2e73..be5a992c5 100644 --- a/Source/Comm/IpcCommon.cpp +++ b/Source/Comm/IpcCommon.cpp @@ -644,7 +644,7 @@ VOID CIpc::InternalFinalize() { { CAutoSlimRWLShared cConnListLock(&(sConnections.nRwMutex)); - TRedBlackTree::Iterator it; + TRedBlackTree::Iterator it; for (lpConn = it.Begin(sConnections.cTree); lpConn != NULL; lpConn = it.Next()) { @@ -860,7 +860,7 @@ CIpc::CConnectionBase* CIpc::CheckAndGetConnection(_In_opt_ HANDLE h) CAutoSlimRWLShared cConnListLock(&(sConnections.nRwMutex)); CConnectionBase *lpConn; - lpConn = sConnections.cTree.Find((SIZE_T)h); + lpConn = sConnections.cTree.Find((SIZE_T)h, &CConnectionBase::SearchCompareFunc); if (lpConn != NULL) { if (lpConn->SafeAddRef() > 0) @@ -1410,7 +1410,7 @@ BOOL CIpc::OnPreprocessPacket(_In_ DWORD dwBytes, _In_ CPacketBase *lpPacket, _I //----------------------------------------------------------- CIpc::CConnectionBase::CConnectionBase(_In_ CIpc *_lpIpc, _In_ CIpc::eConnectionClass _nClass) : - TRefCounted(), TRedBlackTreeNode() + TRefCounted(), TRedBlackTreeNode() { lpIpc = _lpIpc; nClass = _nClass; @@ -1680,7 +1680,10 @@ VOID CIpc::CConnectionBase::Close(_In_ HRESULT hRes) nOrigVal = _InterlockedCompareExchange(&nFlags, nNewVal, nInitVal); } while (nOrigVal != nInitVal); - _InterlockedCompareExchange(&hrErrorCode, hRes, S_OK); + if ((nInitVal & FLAG_Closed) == 0) + { + _InterlockedCompareExchange(&hrErrorCode, hRes, S_OK); + } if ((hRes & 0x0F000000) != 0) { SIZE_T nStack[10]; @@ -1833,6 +1836,8 @@ HRESULT CIpc::CConnectionBase::DoZeroRead(_In_ SIZE_T nPacketsCount, _Inout_ CPa while (nPacketsCount > 0) { + if (IsClosed() != FALSE) + return MX_E_Cancelled; lpPacket = lpIpc->GetPacket(this, CPacketBase::TypeZeroRead, 4096, FALSE); if (lpPacket == NULL) return E_OUTOFMEMORY; @@ -1888,6 +1893,12 @@ HRESULT CIpc::CConnectionBase::DoRead(_In_ SIZE_T nPacketsCount, _In_opt_ CPacke while (nPacketsCount > 0) { + if (IsClosed() != FALSE) + { + if (lpReusePacket != NULL) + FreePacket(lpReusePacket); + return MX_E_Cancelled; + } if (lpReusePacket != NULL) { lpReusePacket->Reset(CPacketBase::TypeRead, this); @@ -2006,6 +2017,9 @@ HRESULT CIpc::CConnectionBase::SendPackets(_Inout_ CPacketBase **lplpFirstPacket while ((*lpnChainLength) > 0) { + if (IsClosed() != FALSE) + return MX_E_Cancelled; + //extract packets for this round lpChainStart = lpPrevPacket = (*lplpFirstPacket); lpCurrPacket = lpChainStart->GetChainedPacket(); diff --git a/Source/Comm/NamedPipes.cpp b/Source/Comm/NamedPipes.cpp index 0389b3c79..d62396c8b 100644 --- a/Source/Comm/NamedPipes.cpp +++ b/Source/Comm/NamedPipes.cpp @@ -261,7 +261,7 @@ HRESULT CNamedPipes::ConnectToServer(_In_z_ LPCWSTR szServerNameW, _In_ OnCreate { CAutoSlimRWLExclusive cConnListLock(&(sConnections.nRwMutex)); - sConnections.cTree.Insert(lpNewConn); + sConnections.cTree.Insert(lpNewConn, &CConnectionBase::InsertCompareFunc); } hRes = FireOnCreate(lpNewConn); if (SUCCEEDED(hRes)) @@ -378,7 +378,7 @@ HRESULT CNamedPipes::CreateRemoteClientConnection(_In_ HANDLE hProc, _Out_ HANDL { CAutoSlimRWLExclusive cConnListLock(&(sConnections.nRwMutex)); - sConnections.cTree.Insert(lpNewConn); + sConnections.cTree.Insert(lpNewConn, &CConnectionBase::InsertCompareFunc); } hRes = FireOnCreate(lpNewConn); if (SUCCEEDED(hRes)) @@ -444,7 +444,7 @@ HRESULT CNamedPipes::CreateServerConnection(_In_ CServerInfo *lpServerInfo, _In_ { CAutoSlimRWLExclusive cConnListLock(&(sConnections.nRwMutex)); - sConnections.cTree.Insert(lpNewConn); + sConnections.cTree.Insert(lpNewConn, &CConnectionBase::InsertCompareFunc); } hRes = FireOnCreate(lpNewConn); if (SUCCEEDED(hRes)) diff --git a/Source/Comm/Sockets.cpp b/Source/Comm/Sockets.cpp index 92f66ae49..18884f1ca 100644 --- a/Source/Comm/Sockets.cpp +++ b/Source/Comm/Sockets.cpp @@ -292,7 +292,7 @@ HRESULT CSockets::CreateListener(_In_ eFamily nFamily, _In_ int nPort, _In_ OnCr { CAutoSlimRWLExclusive cConnListLock(&(sConnections.nRwMutex)); - sConnections.cTree.Insert(cConn.Get()); + sConnections.cTree.Insert(cConn.Get(), &CConnectionBase::InsertCompareFunc); } hRes = FireOnCreate(cConn.Get()); if (SUCCEEDED(hRes)) @@ -357,7 +357,7 @@ HRESULT CSockets::ConnectToServer(_In_ eFamily nFamily, _In_z_ LPCSTR szAddressA { CAutoSlimRWLExclusive cConnListLock(&(sConnections.nRwMutex)); - sConnections.cTree.Insert(cConn.Get()); + sConnections.cTree.Insert(cConn.Get(), &CConnectionBase::InsertCompareFunc); } hRes = FireOnCreate(cConn.Get()); if (SUCCEEDED(hRes)) @@ -573,7 +573,7 @@ HRESULT CSockets::CreateServerConnection(_In_ CConnection *lpListenConn) { CAutoSlimRWLExclusive cConnListLock(&(sConnections.nRwMutex)); - sConnections.cTree.Insert(cIncomingConn.Get()); + sConnections.cTree.Insert(cIncomingConn.Get(), &CConnectionBase::InsertCompareFunc); } cIncomingConn->AddRef(); diff --git a/Source/Http/HttpClient.cpp b/Source/Http/HttpClient.cpp index 201f2f0f9..fba7b2b14 100644 --- a/Source/Http/HttpClient.cpp +++ b/Source/Http/HttpClient.cpp @@ -1025,6 +1025,7 @@ HRESULT CHttpClient::InternalOpen(_In_ CUrl &cUrl) return E_INVALIDARG; if (cUrl.GetHost()[0] == 0) return E_INVALIDARG; + if (cUrl.GetPort() >= 0) nUrlPort = cUrl.GetPort(); else if (cUrl.GetSchemeCode() == CUrl::SchemeHttps) @@ -1074,7 +1075,7 @@ HRESULT CHttpClient::InternalOpen(_In_ CUrl &cUrl) if (sResponse.cHttpCmn.GetParserState() != CHttpCommon::StateDone || StrCompareW(sRequest.cUrl.GetScheme(), cUrl.GetScheme()) != 0 || StrCompareW(sRequest.cUrl.GetHost(), szConnectHostW) != 0 || - sRequest.cUrl.GetPort() != nConnectPort) + nUrlPort != nConnectPort) { cSocketMgr.Close(hConn, S_OK); hConn = NULL; @@ -1099,7 +1100,6 @@ HRESULT CHttpClient::InternalOpen(_In_ CUrl &cUrl) { GenerateRequestBoundary(); - sRequest.cUrl.SetPort(nUrlPort); hRes = cSocketMgr.ConnectToServer(CSockets::FamilyIPv4, szConnectHostW, nConnectPort, MX_BIND_MEMBER_CALLBACK(&CHttpClient::OnSocketCreate, this), NULL, &hConn); } @@ -1487,12 +1487,9 @@ HRESULT CHttpClient::OnSocketDataReceived(_In_ CIpc *lpIpc, _In_ HANDLE h, _In_ CHttpHeaderRespLocation *lpHeader = sResponse.cHttpCmn.GetHeader(); if (lpHeader != NULL) { + hRes = cUrlTemp.ParseFromString(lpHeader->GetLocation()); if (SUCCEEDED(hRes)) - { - hRes = cUrlTemp.ParseFromString(lpHeader->GetLocation()); - if (SUCCEEDED(hRes)) - hRes = sRedirectOrRetryAuth.cUrl.Merge(cUrlTemp); - } + hRes = sRedirectOrRetryAuth.cUrl.Merge(cUrlTemp); } else { @@ -1800,28 +1797,28 @@ VOID CHttpClient::OnRedirectOrRetryAuth(_In_ LONG nTimerId, _In_ LPVOID lpUserDa VOID CHttpClient::OnResponseHeadersTimeout(_In_ LONG nTimerId, _In_ LPVOID lpUserData) { CAutoRundownProtection cAutoRundownProt(&nRundownLock); + HRESULT hRes; if (cAutoRundownProt.IsAcquired() != FALSE) + return; + + hRes = S_OK; + { - HRESULT hRes; + CCriticalSection::CAutoLock cLock(cMutex); - hRes = S_OK; + if (_InterlockedCompareExchange(&(sResponse.nTimerId), 0, nTimerId) == nTimerId) { - CCriticalSection::CAutoLock cLock(cMutex); - - if (_InterlockedCompareExchange(&(sResponse.nTimerId), 0, nTimerId) == nTimerId) + if (nState == StateReceivingResponseHeaders || nState == StateReceivingResponseBody || + nState == StateWaitingProxyTunnelConnectionResponse) { - if (nState == StateReceivingResponseHeaders || nState == StateReceivingResponseBody || - nState == StateWaitingProxyTunnelConnectionResponse) - { - SetErrorOnRequestAndClose(hRes = MX_E_Timeout); - } + SetErrorOnRequestAndClose(hRes = MX_E_Timeout); } } - //raise error event if any - if (FAILED(hRes) && cErrorCallback) - cErrorCallback(this, hRes); } + //raise error event if any + if (FAILED(hRes) && cErrorCallback) + cErrorCallback(this, hRes); return; } @@ -1899,7 +1896,9 @@ VOID CHttpClient::OnAfterSendRequestHeaders(_In_ CIpc *lpIpc, _In_ HANDLE h, _In nState = StateReceivingResponseHeaders; //set response headers timeout + cMutex.Unlock(); hRes = SetupResponseHeadersTimeout(); + cMutex.Lock(); } else { @@ -1925,7 +1924,7 @@ VOID CHttpClient::OnAfterSendRequestBody(_In_ CIpc *lpIpc, _In_ HANDLE h, _In_ L _In_ CIpc::CUserData *lpUserData) { CAutoRundownProtection cAutoRundownProt(&nRundownLock); - HRESULT hRes = S_OK; + BOOL bSetTimeouts; if (cAutoRundownProt.IsAcquired() == FALSE) return; @@ -1933,17 +1932,27 @@ VOID CHttpClient::OnAfterSendRequestBody(_In_ CIpc *lpIpc, _In_ HANDLE h, _In_ L { CCriticalSection::CAutoLock cLock(cMutex); - if (h == hConn && nState == StateReceivingResponseHeaders) + bSetTimeouts = (h == hConn && nState == StateReceivingResponseHeaders) ? TRUE : FALSE; + } + //set response headers timeout + if (bSetTimeouts != FALSE) + { + HRESULT hRes; + + hRes = SetupResponseHeadersTimeout(); + if (FAILED(hRes)) { - //set response headers timeout - hRes = SetupResponseHeadersTimeout(); + { + CCriticalSection::CAutoLock cLock(cMutex); + + SetErrorOnRequestAndClose(hRes); + } + + //raise error event if any + if (cErrorCallback) + cErrorCallback(this, hRes); } - if (FAILED(hRes)) - SetErrorOnRequestAndClose(hRes); } - //raise error event if any - if (FAILED(hRes) && cErrorCallback) - cErrorCallback(this, hRes); return; } diff --git a/Source/Http/HttpCookie.cpp b/Source/Http/HttpCookie.cpp index f2aea5ef7..dc21eb6d7 100644 --- a/Source/Http/HttpCookie.cpp +++ b/Source/Http/HttpCookie.cpp @@ -47,12 +47,14 @@ namespace MX { CHttpCookie::CHttpCookie() : CBaseMemObj() { nFlags = 0; + nSameSite = SameSiteNone; return; } CHttpCookie::CHttpCookie(_In_ const CHttpCookie& cSrc) throw(...) : CBaseMemObj() { nFlags = 0; + nSameSite = SameSiteNone; operator=(cSrc); return; } @@ -78,6 +80,7 @@ CHttpCookie& CHttpCookie::operator=(_In_ const CHttpCookie& cSrc) throw(...) throw (LONG)E_OUTOFMEMORY; nFlags = cSrc.nFlags; + nSameSite = cSrc.nSameSite; cStrNameA.Attach(cStrTempNameA.Detach()); cStrValueA.Attach(cStrTempValueA.Detach()); cStrDomainA.Attach(cStrTempDomainA.Detach()); @@ -95,6 +98,7 @@ VOID CHttpCookie::Clear() cStrPathA.Empty(); cExpiresDt.Clear(); nFlags = 0; + nSameSite = SameSiteNone; return; } @@ -326,6 +330,19 @@ BOOL CHttpCookie::GetHttpOnlyFlag() const return ((nFlags & COOKIE_FLAG_HTTPONLY) != 0) ? TRUE : FALSE; } +HRESULT CHttpCookie::SetSameSite(_In_ eSameSite _nSameSite) +{ + if (_nSameSite != SameSiteNone && _nSameSite != SameSiteLax && _nSameSite != SameSiteStrict) + return E_INVALIDARG; + nSameSite = _nSameSite; + return S_OK; +} + +CHttpCookie::eSameSite CHttpCookie::GetSameSite() const +{ + return nSameSite; +} + HRESULT CHttpCookie::ToString(_Inout_ CStringA& cStrDestA, _In_ BOOL bAddAttributes) { CDateTime cTempDt, *lpDt; @@ -392,6 +409,17 @@ HRESULT CHttpCookie::ToString(_Inout_ CStringA& cStrDestA, _In_ BOOL bAddAttribu if (cStrDestA.Concat("; HttpOnly") == FALSE) return E_OUTOFMEMORY; } + switch (nSameSite) + { + case SameSiteLax: + if (cStrDestA.Concat("; SameSite=Lax") == FALSE) + return E_OUTOFMEMORY; + break; + case SameSiteStrict: + if (cStrDestA.Concat("; SameSite=Strict") == FALSE) + return E_OUTOFMEMORY; + break; + } } return S_OK; } @@ -576,6 +604,25 @@ HRESULT CHttpCookie::ParseFromResponseHeader(_In_z_ LPCSTR szSrcA, _In_opt_ SIZE if (FAILED(hRes)) return hRes; } + else if (StrCompareA((LPSTR)cStrTempNameA, "SameSite", TRUE) == 0) + { + if (cStrTempValueA.IsEmpty() != FALSE || StrCompareA((LPCSTR)cStrTempValueA, "none", TRUE) == 0) + { + nSameSite = SameSiteNone; + } + else if (StrCompareA((LPCSTR)cStrTempValueA, "lax", TRUE) == 0) + { + nSameSite = SameSiteLax; + } + else if (StrCompareA((LPCSTR)cStrTempValueA, "strict", TRUE) == 0) + { + nSameSite = SameSiteStrict; + } + else + { + return MX_E_InvalidData; + } + } else { return MX_E_InvalidData; diff --git a/Source/Http/HttpServer.cpp b/Source/Http/HttpServer.cpp index 2d17687ef..23c841b48 100644 --- a/Source/Http/HttpServer.cpp +++ b/Source/Http/HttpServer.cpp @@ -753,14 +753,14 @@ VOID CHttpServer::OnSocketDestroy(_In_ CIpc *lpIpc, _In_ HANDLE h, _In_ CIpc::CU CAutoSlimRWLShared cLimiterLock(&(sRequestLimiter.nRwMutex)); CRequestLimiter *lpLimiter; - lpLimiter = sRequestLimiter.cTree.Find(&(lpRequest->sPeerAddr)); + lpLimiter = sRequestLimiter.cTree.Find(&(lpRequest->sPeerAddr), &CRequestLimiter::SearchCompareFunc); if (lpLimiter != NULL) { if ((DWORD)_InterlockedDecrement(&(lpLimiter->nCount)) == 0) { cLimiterLock.UpgradeToExclusive(); - lpLimiter = sRequestLimiter.cTree.Find(&(lpRequest->sPeerAddr)); + lpLimiter = sRequestLimiter.cTree.Find(&(lpRequest->sPeerAddr), &CRequestLimiter::SearchCompareFunc); if (lpLimiter != NULL) { if (__InterlockedRead(&(lpLimiter->nCount)) == 0) @@ -803,7 +803,7 @@ HRESULT CHttpServer::OnSocketConnect(_In_ CIpc *lpIpc, _In_ HANDLE h, _In_ CIpc: CAutoSlimRWLShared cLimiterLock(&(sRequestLimiter.nRwMutex)); CRequestLimiter *lpLimiter; - lpLimiter = sRequestLimiter.cTree.Find(&(lpNewRequest->sPeerAddr)); + lpLimiter = sRequestLimiter.cTree.Find(&(lpNewRequest->sPeerAddr), &CRequestLimiter::SearchCompareFunc); if (lpLimiter != NULL) { if ((DWORD)_InterlockedIncrement(&(lpLimiter->nCount)) > dwMaxConnectionsPerIp) @@ -811,17 +811,16 @@ HRESULT CHttpServer::OnSocketConnect(_In_ CIpc *lpIpc, _In_ HANDLE h, _In_ CIpc: } else { - MX::TRedBlackTreeNode *_lpMatchingLimiter; + MX::CHttpServer::CRequestLimiter *lpMatchingLimiter; lpLimiter = MX_DEBUG_NEW CRequestLimiter(&(lpNewRequest->sPeerAddr)); if (lpLimiter == NULL) return E_OUTOFMEMORY; cLimiterLock.UpgradeToExclusive(); - if (sRequestLimiter.cTree.Insert(lpLimiter, FALSE, &_lpMatchingLimiter) == FALSE) + if (sRequestLimiter.cTree.Insert(lpLimiter, &CRequestLimiter::InsertCompareFunc, FALSE, + &lpMatchingLimiter) == FALSE) { - CRequestLimiter *lpMatchingLimiter = static_cast(_lpMatchingLimiter); - delete lpLimiter; if ((DWORD)_InterlockedIncrement(&(lpMatchingLimiter->nCount)) > dwMaxConnectionsPerIp) diff --git a/Source/Http/Url.cpp b/Source/Http/Url.cpp index 64d52c0e1..b5926114f 100644 --- a/Source/Http/Url.cpp +++ b/Source/Http/Url.cpp @@ -1288,16 +1288,12 @@ HRESULT CUrl::ParseFromString(_In_z_ LPCSTR szUrlA, _In_opt_ SIZE_T nSrcLen) } //find the = sign i = FindChar(szUrlA, nPos, "=", NULL); - if (i == (SIZE_T)-1) - i = nPos; - if (i > 0) - { - //from 0 to 'i' we have the name - //from 'i+1' to 'nPos' we have the value - hRes = AddQueryString(szUrlA, szUrlA+(i+1), nPos-(i+1), i); - if (FAILED(hRes)) - goto done; - } + if (i != (SIZE_T)-1) + hRes = AddQueryString(szUrlA, szUrlA + (i + 1), nPos - (i + 1), i); + else + hRes = AddQueryString(szUrlA, "", 0, nPos); + if (FAILED(hRes)) + goto done; szUrlA += nPos; nSrcLen -= nPos; } diff --git a/Source/PropertyBag.cpp b/Source/PropertyBag.cpp index 345964ddf..ef63dd02e 100644 --- a/Source/PropertyBag.cpp +++ b/Source/PropertyBag.cpp @@ -429,14 +429,4 @@ SIZE_T CPropertyBag::Find(_In_z_ LPCSTR szNameA) return cPropertiesList.BinarySearch(szNameA, &CPropertyBag::SearchCompareFunc, NULL); } -int CPropertyBag::InsertCompareFunc(_In_ LPVOID lpContext, _In_ PROPERTY **lpItem1, _In_ PROPERTY **lpItem2) -{ - return StrCompareA((*lpItem1)->szNameA, (*lpItem2)->szNameA, TRUE); -} - -int CPropertyBag::SearchCompareFunc(_In_ LPVOID lpContext, _In_ LPCVOID lpKey, _In_ PROPERTY **lpItem) -{ - return StrCompareA((LPCSTR)lpKey, (*lpItem)->szNameA, TRUE); -} - } //namespace MX diff --git a/Source/TimedEvent.cpp b/Source/TimedEvent.cpp index b1a6af9b9..976cb7a23 100644 --- a/Source/TimedEvent.cpp +++ b/Source/TimedEvent.cpp @@ -46,12 +46,12 @@ class CTimerHandler : public TRefCounted { MX_DISABLE_COPY_CONSTRUCTOR(CTimerHandler); public: - class CTimer : public TRedBlackTreeNode + class CTimer : public virtual CBaseMemObj, public TRedBlackTreeNode { MX_DISABLE_COPY_CONSTRUCTOR(CTimer); public: CTimer(_In_ DWORD dwTimeoutMs, _In_ MX::TimedEvent::OnTimeoutCallback cCallback, _In_ LPVOID lpUserData, - _In_ BOOL bOneShot) : TRedBlackTreeNode() + _In_ BOOL bOneShot) : CBaseMemObj(), TRedBlackTreeNode() { Setup(dwTimeoutMs, cCallback, lpUserData, bOneShot); return; @@ -70,11 +70,6 @@ class CTimerHandler : public TRefCounted return; }; - __inline ULONGLONG GetNodeKey() const - { - return nDueTime; - }; - __inline VOID CalculateDueTime(_In_opt_ PULARGE_INTEGER lpuliCurrTime) { ULARGE_INTEGER uliCurrTime; @@ -117,6 +112,15 @@ class CTimerHandler : public TRefCounted return TRUE; }; + static int InsertCompareFunc(_In_ LPVOID lpContext, _In_ CTimer *lpTimer1, _In_ CTimer *lpTimer2) + { + if (lpTimer1->nDueTime < lpTimer2->nDueTime) + return -1; + if (lpTimer1->nDueTime > lpTimer2->nDueTime) + return 1; + return 0; + }; + public: LONG nId; DWORD dwTimeoutMs; @@ -138,7 +142,7 @@ class CTimerHandler : public TRefCounted HRESULT AddTimer(_Out_ LONG volatile *lpnTimerId, _In_ DWORD dwTimeoutMs, _In_ MX::TimedEvent::OnTimeoutCallback cCallback, _In_ LPVOID lpUserData, _In_ BOOL bOneShot); - VOID RemoveTimer(_Out_ LONG volatile *lpnTimerId); + VOID RemoveTimer(_Inout_ LONG volatile *lpnTimerId); private: VOID ThreadProc(); @@ -160,7 +164,7 @@ class CTimerHandler : public TRefCounted LONG volatile nMutex; CWindowsEvent cChangedEvent; TArrayList cSortedByIdList; - TRedBlackTree cTree; + TRedBlackTree cTree; } sQueue; struct { LONG volatile nMutex; @@ -377,7 +381,7 @@ HRESULT CTimerHandler::AddTimer(_Out_ LONG volatile *lpnTimerId, _In_ DWORD dwTi } _InterlockedExchange(lpnTimerId, cNewTimer->nId); - sQueue.cTree.Insert(cNewTimer.Detach(), TRUE); + sQueue.cTree.Insert(cNewTimer.Detach(), &CTimer::InsertCompareFunc, TRUE); } sQueue.cChangedEvent.Set(); @@ -385,7 +389,7 @@ HRESULT CTimerHandler::AddTimer(_Out_ LONG volatile *lpnTimerId, _In_ DWORD dwTi return S_OK; } -VOID CTimerHandler::RemoveTimer(_Out_ LONG volatile *lpnTimerId) +VOID CTimerHandler::RemoveTimer(_Inout_ LONG volatile *lpnTimerId) { if (lpnTimerId != NULL) { @@ -499,12 +503,21 @@ DWORD CTimerHandler::ProcessQueue() lpTimer->RemoveNode(); if ((nFlags & _FLAG_OneShot) != 0) { + SIZE_T nIndex; + + nIndex = sQueue.cSortedByIdList.BinarySearch(&(lpTimer->nId), &CTimerHandler::SearchByIdCompareFunc, NULL); + MX_ASSERT(nIndex != (SIZE_T)-1); + if (nIndex != (SIZE_T)-1) + { + sQueue.cSortedByIdList.RemoveElementAt(nIndex); + } + FreeTimer(lpTimer); } else { lpTimer->CalculateDueTime(&uliCurrTime); - sQueue.cTree.Insert(lpTimer, TRUE); + sQueue.cTree.Insert(lpTimer, &CTimer::InsertCompareFunc, TRUE); _InterlockedAnd(&(lpTimer->nFlags), ~_FLAG_Running); } diff --git a/Source/WaitableObjects.cpp b/Source/WaitableObjects.cpp index da595106f..c21294368 100644 --- a/Source/WaitableObjects.cpp +++ b/Source/WaitableObjects.cpp @@ -451,7 +451,7 @@ BOOL RundownProt_Acquire(_In_ LONG volatile *lpnValue) initVal = newVal; if ((initVal & 0x80000000L) != 0) return FALSE; - newVal = _InterlockedCompareExchange(lpnValue, initVal+1, initVal); + newVal = _InterlockedCompareExchange(lpnValue, initVal + 1, initVal); } while (newVal != initVal); return TRUE; diff --git a/Test/TestHttpClient.cpp b/Test/TestHttpClient.cpp index 1651e11be..55db824ad 100644 --- a/Test/TestHttpClient.cpp +++ b/Test/TestHttpClient.cpp @@ -20,7 +20,7 @@ #include "TestHttpClient.h" #include -#define SIMPLE_TEST +//#define SIMPLE_TEST //----------------------------------------------------------- @@ -230,6 +230,9 @@ static HRESULT SimpleTest1(_In_ MX::CSockets *lpSckMgr, _In_ MX::CSslCertificate //cHttpClient.SetHeadersReceivedCallback(MX_BIND_CALLBACK(&OnResponseHeadersReceived)); cHttpClient.SetHeadersReceivedCallback(MX_BIND_CALLBACK(&OnResponseHeadersReceived_BigDownload)); + hRes = cHttpClient.Open("http://www.sitepoint.com/forums/showthread.php?" + "390414-Reading-from-socket-connection-SLOW"); + /* //hRes = cHttpClient.SetAuthCredentials(L"guest", L"guest"); hRes = S_OK; if (SUCCEEDED(hRes)) @@ -244,6 +247,8 @@ static HRESULT SimpleTest1(_In_ MX::CSockets *lpSckMgr, _In_ MX::CSslCertificate //hRes = cHttpClient.Open("https://jigsaw.w3.org/HTTP/Basic/"); //hRes = cHttpClient.Open("https://jigsaw.w3.org/HTTP/Digest/"); } + */ + if (SUCCEEDED(hRes)) { while (cHttpClient.IsDocumentComplete() == FALSE && cHttpClient.IsClosed() == FALSE)