diff --git a/authenticate_message.go b/authenticate_message.go index ab183db..ef446d3 100644 --- a/authenticate_message.go +++ b/authenticate_message.go @@ -4,27 +4,27 @@ import ( "bytes" "crypto/rand" "encoding/binary" - "encoding/hex" "errors" - "strings" + "fmt" "time" ) -type authenicateMessage struct { +const micFieldOffset = 72 +const micFieldLength = 16 + +type authenticateMessage struct { LmChallengeResponse []byte NtChallengeResponse []byte TargetName string UserName string - // only set if negotiateFlag_NTLMSSP_NEGOTIATE_KEY_EXCH - EncryptedRandomSessionKey []byte - - NegotiateFlags negotiateFlags - - MIC []byte + NegotiateFlags NegotiateFlags + Version } +type MIC [16]byte + type authenticateMessageFields struct { messageHeader LmChallengeResponse varField @@ -33,30 +33,27 @@ type authenticateMessageFields struct { UserName varField Workstation varField _ [8]byte - NegotiateFlags negotiateFlags + NegotiateFlags NegotiateFlags + Version + MIC } -func (m authenicateMessage) MarshalBinary() ([]byte, error) { - if !m.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE) { - return nil, errors.New("Only unicode is supported") - } - +func (m authenticateMessage) MarshalBinary() ([]byte, error) { target, user := toUnicode(m.TargetName), toUnicode(m.UserName) workstation := toUnicode("") ptr := binary.Size(&authenticateMessageFields{}) f := authenticateMessageFields{ messageHeader: newMessageHeader(3), - NegotiateFlags: m.NegotiateFlags, LmChallengeResponse: newVarField(&ptr, len(m.LmChallengeResponse)), NtChallengeResponse: newVarField(&ptr, len(m.NtChallengeResponse)), TargetName: newVarField(&ptr, len(target)), UserName: newVarField(&ptr, len(user)), Workstation: newVarField(&ptr, len(workstation)), + NegotiateFlags: m.NegotiateFlags, + Version: m.Version, } - f.NegotiateFlags.Unset(negotiateFlagNTLMSSPNEGOTIATEVERSION) - b := bytes.Buffer{} if err := binary.Write(&b, binary.LittleEndian, &f); err != nil { return nil, err @@ -77,12 +74,14 @@ func (m authenicateMessage) MarshalBinary() ([]byte, error) { return nil, err } - return b.Bytes(), nil + authenticateMessageData := b.Bytes() + + return authenticateMessageData, nil } //ProcessChallenge crafts an AUTHENTICATE message in response to the CHALLENGE message //that was received from the server -func ProcessChallenge(challengeMessageData []byte, user, password string, domainNeeded bool) ([]byte, error) { +func ProcessChallenge(negotiateMessageData, challengeMessageData []byte, user, password, domain, spn string, channelBinding []byte) ([]byte, error) { if user == "" && password == "" { return nil, errors.New("Anonymous authentication not supported") } @@ -98,90 +97,104 @@ func ProcessChallenge(challengeMessageData []byte, user, password string, domain if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) { return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)") } - - if !domainNeeded { - cm.TargetName = "" + + if !cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEUNICODE) { + return nil, errors.New("Only unicode is supported") } - am := authenicateMessage{ + flags := (defaultFlags & cm.NegotiateFlags) | negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY + + am := authenticateMessage{ UserName: user, - TargetName: cm.TargetName, - NegotiateFlags: cm.NegotiateFlags, + TargetName: domain, + NegotiateFlags: flags, } - timestamp := cm.TargetInfo[avIDMsvAvTimestamp] - if timestamp == nil { // no time sent, take current time - ft := uint64(time.Now().UnixNano()) / 100 - ft += 116444736000000000 // add time between unix & windows offset - timestamp = make([]byte, 8) - binary.LittleEndian.PutUint64(timestamp, ft) + targetInfo := cm.TargetInfo + + cbt, err := computeChannelBindingHash(channelBinding) + if err != nil { + return nil, fmt.Errorf("failed to compute channel binding token: %w", err) } - clientChallenge := make([]byte, 8) - rand.Reader.Read(clientChallenge) + targetInfo, serverTimestamp := updateTargetInfoAvPairs(targetInfo, cbt, spn) - ntlmV2Hash := getNtlmV2Hash(password, user, cm.TargetName) + timestamp := getTimestamp(serverTimestamp) - am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash, - cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw) + ntlmV2Hash := getNtlmV2Hash(password, am.UserName, am.TargetName) + + clientChallenge := getClientChallenge() + + targetInfoData, err := targetInfo.marshal() + if err != nil { + return nil, fmt.Errorf("failed to marshal TargetInfo AvPair struct: %w", err) + } + + NtChallengeResponse, sessionKey := computeNtlmV2Response(ntlmV2Hash, + cm.ServerChallenge[:], clientChallenge, timestamp, targetInfoData) + am.NtChallengeResponse = NtChallengeResponse if cm.TargetInfoRaw == nil { am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash, cm.ServerChallenge[:], clientChallenge) - } - return am.MarshalBinary() -} - -func ProcessChallengeWithHash(challengeMessageData []byte, user, hash string) ([]byte, error) { - if user == "" && hash == "" { - return nil, errors.New("Anonymous authentication not supported") + } else { + am.LmChallengeResponse = make([]byte, 24) } - var cm challengeMessage - if err := cm.UnmarshalBinary(challengeMessageData); err != nil { + authenticateMessageData, err := am.MarshalBinary() + if err != nil { return nil, err } - if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATELMKEY) { - return nil, errors.New("Only NTLM v2 is supported, but server requested v1 (NTLMSSP_NEGOTIATE_LM_KEY)") - } - if cm.NegotiateFlags.Has(negotiateFlagNTLMSSPNEGOTIATEKEYEXCH) { - return nil, errors.New("Key exchange requested but not supported (NTLMSSP_NEGOTIATE_KEY_EXCH)") - } + mic := computeMIC(sessionKey, negotiateMessageData, challengeMessageData, authenticateMessageData) + copy(authenticateMessageData[micFieldOffset:micFieldOffset+micFieldLength], mic) - am := authenicateMessage{ - UserName: user, - TargetName: cm.TargetName, - NegotiateFlags: cm.NegotiateFlags, - } + return authenticateMessageData, nil +} - timestamp := cm.TargetInfo[avIDMsvAvTimestamp] - if timestamp == nil { // no time sent, take current time +func getTimestamp(serverTimestamp []byte) []byte { + if serverTimestamp != nil { // no time sent, take current time + return serverTimestamp + } else { + // Prepares current timestamp in format specified in + // [MS-NLMP](https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/83f5e789-660d-4781-8491-5f8c6641f75e) + // A FILETIME structure ([MS-DTYP] section 2.3.3) in little-endian byte order that contains the server local + // time. This structure is always sent in the CHALLENGE_MESSAGE. ft := uint64(time.Now().UnixNano()) / 100 ft += 116444736000000000 // add time between unix & windows offset - timestamp = make([]byte, 8) + timestamp := make([]byte, 8) binary.LittleEndian.PutUint64(timestamp, ft) + return timestamp } +} +func getClientChallenge() []byte { clientChallenge := make([]byte, 8) rand.Reader.Read(clientChallenge) + return clientChallenge +} - hashParts := strings.Split(hash, ":") - if len(hashParts) > 1 { - hash = hashParts[1] - } - hashBytes, err := hex.DecodeString(hash) - if err != nil { - return nil, err - } - ntlmV2Hash := hmacMd5(hashBytes, toUnicode(strings.ToUpper(user)+cm.TargetName)) +func updateTargetInfoAvPairs(targetInfo AvPairs, channelBindingHash []byte, spn string) (AvPairs, []byte) { - am.NtChallengeResponse = computeNtlmV2Response(ntlmV2Hash, - cm.ServerChallenge[:], clientChallenge, timestamp, cm.TargetInfoRaw) + serverTimestamp := targetInfo[avIDMsvAvTimestamp] - if cm.TargetInfoRaw == nil { - am.LmChallengeResponse = computeLmV2Response(ntlmV2Hash, - cm.ServerChallenge[:], clientChallenge) + // update AvFlags - MIC present + { + flags := targetInfo[avIDMsvAvFlags] + if flags == nil { + flags = make([]byte, 4) + targetInfo[avIDMsvAvFlags] = flags + } + avFlags := AvFlags(binary.LittleEndian.Uint32(flags)) + avFlags.Set(AvFlagMICPresent) + binary.LittleEndian.PutUint32(flags, uint32(avFlags)) } - return am.MarshalBinary() + + // EPA support + { + targetInfo[avIDMsvChannelBindings] = channelBindingHash + targetInfo[avIDMsvAvTargetName] = toUnicode(spn) + } + + return targetInfo, serverTimestamp } diff --git a/avids.go b/avids.go index 196b5f1..4a83305 100644 --- a/avids.go +++ b/avids.go @@ -1,7 +1,19 @@ package ntlmssp +import ( + "bytes" + "encoding/binary" + "fmt" +) + type avID uint16 +type AvPairs map[avID][]byte + +func NewAvPairs() AvPairs { + return make(AvPairs) +} + const ( avIDMsvAvEOL avID = iota avIDMsvAvNbComputerName @@ -15,3 +27,75 @@ const ( avIDMsvAvTargetName avIDMsvChannelBindings ) + +func (pairs AvPairs) unmarshal(data []byte) error { + + r := bytes.NewReader(data) + for { + var id avID + var l uint16 + err := binary.Read(r, binary.LittleEndian, &id) + if err != nil { + return err + } + if id == avIDMsvAvEOL { + break + } + + err = binary.Read(r, binary.LittleEndian, &l) + if err != nil { + return err + } + value := make([]byte, l) + n, err := r.Read(value) + if err != nil { + return err + } + if n != int(l) { + return fmt.Errorf("Expected to read %d bytes, got only %d", l, n) + } + (pairs)[id] = value + } + return nil +} + +func (pairs AvPairs) marshal() ([]byte, error) { + buffer := bytes.NewBuffer(make([]byte, 0, 2)) + + for id := avIDMsvAvNbComputerName; id <= avIDMsvChannelBindings; id++ { + value := (pairs)[id] + if value != nil { + if err := binary.Write(buffer, binary.LittleEndian, id); err != nil { + return nil, err + } + if err := binary.Write(buffer, binary.LittleEndian, uint16(len(value))); err != nil { + return nil, err + } + _, err := buffer.Write(value) + if err != nil { + return nil, err + } + } + } + if err := binary.Write(buffer, binary.LittleEndian, avIDMsvAvEOL); err != nil { + return nil, err + } + _, err := buffer.Write([]byte{0, 0}) + if err != nil { + return nil, err + } + + return buffer.Bytes(), nil +} + +type AvFlags uint32 + +func (f *AvFlags) Set(flag AvFlags) { + *f = *f | flag +} + +const ( + AvFlagAuthenticationConstrained AvFlags = 0x00000001 // Indicates to the client that the account authentication is constrained. + AvFlagMICPresent AvFlags = 0x00000002 // Indicates that the client is providing message integrity in the MIC field (section 2.2.1.3) in the AUTHENTICATE_MESSAGE.<14> + AvFlagUntrustedSPN AvFlags = 0x00000004 // Indicates that the client is providing a target SPN generated from an untrusted source.<15> +) diff --git a/avids_test.go b/avids_test.go new file mode 100755 index 0000000..f1a86f6 --- /dev/null +++ b/avids_test.go @@ -0,0 +1,88 @@ +package ntlmssp + +import ( + "bytes" + "reflect" + "testing" +) + +func TestMarshalAVPairs(t *testing.T) { + tests := []struct { + name string + input AvPairs + expected []byte + }{ + {"empty", AvPairs{}, []byte{0x00, 0x00, 0x00, 0x00}}, // avIDMsvAvEOL, len(0) + {"with 2 pairs", + AvPairs{ + avIDMsvAvTargetName: []byte{0, 0}, + avIDMsvAvNbDomainName: []byte{1, 1, 1, 1}, + }, + []byte{ + 0x02, 0x00, 0x04, 0x00, 0x01, 0x01, 0x01, 0x01, // avIDMsvAvNbDomainName, len(4), 1, 1, 1, 1 + 0x09, 0x00, 0x02, 0x00, 0x00, 0x00, // avIDMsvAvTargetName, len(2), 0, 0 + 0x00, 0x00, 0x00, 0x00}, // avIDMsvAvEOL, len(0) + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + data, err := tc.input.marshal() + if err != nil { + t.Errorf("Expected no errors, but got %v", err) + } + if data == nil { + t.Fatalf("Expected written data to not be null") + } + if len(tc.expected) != len(data) { + t.Fatalf("Expected %d bytes, but got %d", len(tc.expected), len(data)) + } + + if !bytes.Equal(tc.expected, data) { + t.Errorf("Expected %v, but got %v", tc.expected, data) + } + }) + } +} + +func TestUnmarshalAVPairsWithTwoElements(t *testing.T) { + tests := []struct { + name string + input []byte + expected AvPairs + }{ + {"empty", + []byte{0x00, 0x00}, + NewAvPairs()}, // avIDMsvAvEOL1 + {"with 2 pairs", + []byte{ + 0x02, 0x00, 0x04, 0x00, 0x01, 0x01, 0x01, 0x01, // avIDMsvAvNbDomainName, len(4), 1, 1, 1, 1 + 0x09, 0x00, 0x02, 0x00, 0x00, 0x00, // avIDMsvAvTargetName, len(2), 0, 0 + 0x00, 0x00}, // avIDMsvAvEOL + func() AvPairs { + pairs := NewAvPairs() + pairs[avIDMsvAvTargetName] = []byte{0, 0} + pairs[avIDMsvAvNbDomainName] = []byte{1, 1, 1, 1} + return pairs + }(), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + + marshalled := tc.input + + result := NewAvPairs() + err := result.unmarshal(marshalled) + + if err != nil { + t.Fatalf("Expected read data to not be null") + } + if len(tc.expected) != len(result) { + t.Fatalf("Expected %d entries, but got %d", len(tc.expected), len(result)) + } + if !reflect.DeepEqual(tc.expected, result) { + t.Fatalf("Expected %v, but got %v", tc.expected, result) + } + }) + } +} diff --git a/challenge_message.go b/challenge_message.go index 053b55e..55046d9 100644 --- a/challenge_message.go +++ b/challenge_message.go @@ -9,7 +9,7 @@ import ( type challengeMessageFields struct { messageHeader TargetName varField - NegotiateFlags negotiateFlags + NegotiateFlags NegotiateFlags ServerChallenge [8]byte _ [8]byte TargetInfo varField @@ -22,7 +22,7 @@ func (m challengeMessageFields) IsValid() bool { type challengeMessage struct { challengeMessageFields TargetName string - TargetInfo map[avID][]byte + TargetInfo AvPairs TargetInfoRaw []byte } @@ -49,33 +49,15 @@ func (m *challengeMessage) UnmarshalBinary(data []byte) error { if err != nil { return err } - m.TargetInfo = make(map[avID][]byte) - r := bytes.NewReader(d) - for { - var id avID - var l uint16 - err = binary.Read(r, binary.LittleEndian, &id) - if err != nil { - return err - } - if id == avIDMsvAvEOL { - break - } - err = binary.Read(r, binary.LittleEndian, &l) - if err != nil { - return err - } - value := make([]byte, l) - n, err := r.Read(value) - if err != nil { - return err - } - if n != int(l) { - return fmt.Errorf("Expected to read %d bytes, got only %d", l, n) - } - m.TargetInfo[id] = value + targetInfo := NewAvPairs() + err = targetInfo.unmarshal(d) + if err != nil { + return err } + + m.TargetInfo = targetInfo + } return nil diff --git a/negotiate_flags.go b/negotiate_flags.go index 5905c02..0ef475a 100644 --- a/negotiate_flags.go +++ b/negotiate_flags.go @@ -1,9 +1,9 @@ package ntlmssp -type negotiateFlags uint32 +type NegotiateFlags uint32 const ( - /*A*/ negotiateFlagNTLMSSPNEGOTIATEUNICODE negotiateFlags = 1 << 0 + /*A*/ negotiateFlagNTLMSSPNEGOTIATEUNICODE NegotiateFlags = 1 << 0 /*B*/ negotiateFlagNTLMNEGOTIATEOEM = 1 << 1 /*C*/ negotiateFlagNTLMSSPREQUESTTARGET = 1 << 2 @@ -43,10 +43,16 @@ const ( /*W*/ negotiateFlagNTLMSSPNEGOTIATE56 = 1 << 31 ) -func (field negotiateFlags) Has(flags negotiateFlags) bool { +func (field NegotiateFlags) Has(flags NegotiateFlags) bool { return field&flags == flags } -func (field *negotiateFlags) Unset(flags negotiateFlags) { +func (field *NegotiateFlags) Unset(flags NegotiateFlags) { *field = *field ^ (*field & flags) } + +var defaultFlags = negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY | + negotiateFlagNTLMSSPNEGOTIATEALWAYSSIGN | + negotiateFlagNTLMSSPNEGOTIATENTLM | + negotiateFlagNTLMSSPREQUESTTARGET | + negotiateFlagNTLMSSPNEGOTIATEUNICODE diff --git a/negotiate_message.go b/negotiate_message.go index e466a98..041265a 100644 --- a/negotiate_message.go +++ b/negotiate_message.go @@ -11,7 +11,7 @@ const expMsgBodyLen = 40 type negotiateMessageFields struct { messageHeader - NegotiateFlags negotiateFlags + NegotiateFlags NegotiateFlags Domain varField Workstation varField @@ -19,12 +19,6 @@ type negotiateMessageFields struct { Version } -var defaultFlags = negotiateFlagNTLMSSPNEGOTIATETARGETINFO | - negotiateFlagNTLMSSPNEGOTIATE56 | - negotiateFlagNTLMSSPNEGOTIATE128 | - negotiateFlagNTLMSSPNEGOTIATEUNICODE | - negotiateFlagNTLMSSPNEGOTIATEEXTENDEDSESSIONSECURITY - //NewNegotiateMessage creates a new NEGOTIATE message with the //flags that this package supports. func NewNegotiateMessage(domainName, workstationName string) ([]byte, error) { @@ -39,12 +33,17 @@ func NewNegotiateMessage(domainName, workstationName string) ([]byte, error) { flags |= negotiateFlagNTLMSSPNEGOTIATEOEMWORKSTATIONSUPPLIED } + version := EmptyVersion() + if flags.Has(negotiateFlagNTLMSSPNEGOTIATEVERSION) { + version = DefaultVersion() + } + msg := negotiateMessageFields{ messageHeader: newMessageHeader(1), NegotiateFlags: flags, Domain: newVarField(&payloadOffset, len(domainName)), Workstation: newVarField(&payloadOffset, len(workstationName)), - Version: DefaultVersion(), + Version: version, } b := bytes.Buffer{} diff --git a/negotiator.go b/negotiator.go index cce4955..4d55c41 100644 --- a/negotiator.go +++ b/negotiator.go @@ -2,7 +2,11 @@ package ntlmssp import ( "bytes" + "crypto" + "crypto/tls" + "crypto/x509" "encoding/base64" + "errors" "io" "io/ioutil" "net/http" @@ -10,22 +14,15 @@ import ( ) // GetDomain : parse domain name from based on slashes in the input -// Need to check for upn as well -func GetDomain(user string) (string, string, bool) { +func GetDomain(user string) (string, string) { domain := "" - domainNeeded := false if strings.Contains(user, "\\") { ucomponents := strings.SplitN(user, "\\", 2) domain = ucomponents[0] user = ucomponents[1] - domainNeeded = true - } else if strings.Contains(user, "@") { - domainNeeded = false - } else { - domainNeeded = true } - return user, domain, domainNeeded + return user, domain } //Negotiator is a http.Roundtripper decorator that automatically @@ -97,8 +94,7 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) } // get domain from username - domain := "" - u, domain, domainNeeded := GetDomain(u) + u, domain := GetDomain(u) // send negotiate negotiateMessage, err := NewNegotiateMessage(domain, "") @@ -131,8 +127,18 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) io.Copy(ioutil.Discard, res.Body) res.Body.Close() + spn := getSpn(req.Host) + + var channelBinding []byte = nil + if res.TLS != nil { + channelBinding, err = makeChannelBinding(*res.TLS) + if err != nil { + return nil, errors.New("couldn't make TLS channel binding") + } + } + // send authenticate - authenticateMessage, err := ProcessChallenge(challengeMessage, u, p, domainNeeded) + authenticateMessage, err := ProcessChallenge(negotiateMessage, challengeMessage, u, p, domain, spn, channelBinding) if err != nil { return nil, err } @@ -149,3 +155,42 @@ func (l Negotiator) RoundTrip(req *http.Request) (res *http.Response, err error) return res, err } + +func makeChannelBinding(state tls.ConnectionState) ([]byte, error) { + + certificate := state.PeerCertificates[0] + prefix := []byte("tls-server-end-point:") + + if certificate == nil { + return nil, errors.New("TLS connection is missing server certificate") + } + + // choose the channel binding hash type + // Use the same hash type used for the certificate signature, except for MD5 and SHA-1 which + // use SHA256 + hashType := crypto.SHA256 + switch certificate.SignatureAlgorithm { + case x509.SHA384WithRSA, x509.ECDSAWithSHA384, x509.SHA384WithRSAPSS: + hashType = crypto.SHA384 + case x509.SHA512WithRSA, x509.ECDSAWithSHA512, x509.SHA512WithRSAPSS: + hashType = crypto.SHA512 + } + + hasher := hashType.New() + _, _ = hasher.Write(certificate.Raw) + data := hasher.Sum(nil) + + buf := bytes.NewBuffer(make([]byte, 0, len(prefix)+len(data))) + buf.Write(prefix) + buf.Write(data) + + return buf.Bytes(), nil +} + +func getSpn(host string) string { + spn := "" + if host != "" { + spn = "HTTP/" + strings.ToLower(host) + } + return spn +} diff --git a/nlmp.go b/nlmp.go index 1e65abe..4f6a164 100644 --- a/nlmp.go +++ b/nlmp.go @@ -8,8 +8,10 @@ package ntlmssp import ( + "bytes" "crypto/hmac" "crypto/md5" + "encoding/binary" "golang.org/x/crypto/md4" "strings" ) @@ -24,18 +26,19 @@ func getNtlmHash(password string) []byte { return hash.Sum(nil) } -func computeNtlmV2Response(ntlmV2Hash, serverChallenge, clientChallenge, - timestamp, targetInfo []byte) []byte { +func computeNtlmV2Response(ntlmV2Hash, serverChallenge, clientChallenge, timestamp, targetInfo []byte) ([]byte, []byte) { - temp := []byte{1, 1, 0, 0, 0, 0, 0, 0} - temp = append(temp, timestamp...) - temp = append(temp, clientChallenge...) - temp = append(temp, 0, 0, 0, 0) - temp = append(temp, targetInfo...) - temp = append(temp, 0, 0, 0, 0) + buf := bytes.NewBuffer([]byte{1, 1, 0, 0, 0, 0, 0, 0}) + buf.Write(timestamp) + buf.Write(clientChallenge) + buf.Write([]byte{0, 0, 0, 0}) + buf.Write(targetInfo) + buf.Write([]byte{0, 0, 0, 0}) - NTProofStr := hmacMd5(ntlmV2Hash, serverChallenge, temp) - return append(NTProofStr, temp...) + NTProofStr := hmacMd5(ntlmV2Hash, serverChallenge, buf.Bytes()) + sessionKey := hmacMd5(ntlmV2Hash, NTProofStr) + + return append(NTProofStr, buf.Bytes()...), sessionKey } func computeLmV2Response(ntlmV2Hash, serverChallenge, clientChallenge []byte) []byte { @@ -49,3 +52,53 @@ func hmacMd5(key []byte, data ...[]byte) []byte { } return mac.Sum(nil) } + +func md5sum(target []byte, data ...[]byte) []byte { + h := md5.New() + for _, d := range data { + h.Write(d) + } + return h.Sum(target) + +} + +func computeMIC(sessionKey []byte, messages ...[]byte) []byte { + return hmacMd5(sessionKey, messages...) +} + +type gssChannelBindingStructHeader struct { + _ [16]byte + tokenLength uint32 +} + +func computeChannelBindingHash(channelBinding []byte) ([]byte, error) { + + if channelBinding != nil { + + // Based on [MS-NLMP documentation](https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-nlmp/83f5e789-660d-4781-8491-5f8c6641f75e): + // Channel binding hash value contains an MD5 hash ([RFC4121] section 4.1.1.2) of a gss_channel_bindings_struct + // ([RFC2744](https://www.ietf.org/rfc/rfc2744.txt) section 3.11). An all-zero value of the hash is used to indicate + // absence of channel bindings. + cbtStruct := gssChannelBindingStructHeader{ + tokenLength: uint32(len(channelBinding)), + } + + size := binary.Size(&gssChannelBindingStructHeader{}) + + buf := bytes.NewBuffer(make([]byte, 0, size+len(channelBinding))) + if err := binary.Write(buf, binary.LittleEndian, &cbtStruct); err != nil { + return nil, err + } + _, err := buf.Write(channelBinding) + if err != nil { + return nil, err + } + + channelBindingHash := make([]byte, 0, 16) + channelBindingHash = md5sum(channelBindingHash, buf.Bytes()) + + return channelBindingHash, nil + } else { + return make([]byte, 16), nil + } +} diff --git a/nlmp_test.go b/nlmp_test.go index 0969733..bb520a0 100644 --- a/nlmp_test.go +++ b/nlmp_test.go @@ -21,50 +21,52 @@ func TestUsernameDomainWorkstation(t *testing.T) { // taking a username and workstation as input, check that the username, domain, workstation // and negotiate message bytes all match their expected values tables := []struct { - u string - w string - xu string - xd string - xb []byte + name string + u string + w string + xu string + xd string + xb []byte }{ - {username, "", username, "", []byte{ - 0x4e, 0x54, 0x4c, 0x4d, 0x53, 0x53, 0x50, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, - 0x88, 0xa0, 0x00, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x28, 0x00, 0x00, 0x00, 0x06, 0x01, 0xb1, 0x1d, 0x00, 0x00, 0x00, 0x0f}}, - {domain + "\\" + username, "", username, domain, []byte{ - 0x4e, 0x54, 0x4c, 0x4d, 0x53, 0x53, 0x50, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x10, - 0x88, 0xa0, 0x08, 0x00, 0x08, 0x00, 0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x30, 0x00, 0x00, 0x00, 0x06, 0x01, 0xb1, 0x1d, 0x00, 0x00, 0x00, 0x0f, 0x4d, 0x59, - 0x44, 0x4f, 0x4d, 0x41, 0x49, 0x4e}}, - {domain + "\\" + username, workstation, username, domain, []byte{ - 0x4e, 0x54, 0x4c, 0x4d, 0x53, 0x53, 0x50, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x30, - 0x88, 0xa0, 0x08, 0x00, 0x08, 0x00, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, - 0x30, 0x00, 0x00, 0x00, 0x06, 0x01, 0xb1, 0x1d, 0x00, 0x00, 0x00, 0x0f, 0x4d, 0x59, - 0x44, 0x4f, 0x4d, 0x41, 0x49, 0x4e, 0x4d, 0x59, 0x50, 0x43}}, - {username, workstation, username, "", []byte{ - 0x4e, 0x54, 0x4c, 0x4d, 0x53, 0x53, 0x50, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x20, - 0x88, 0xa0, 0x00, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, - 0x28, 0x00, 0x00, 0x00, 0x06, 0x01, 0xb1, 0x1d, 0x00, 0x00, 0x00, 0x0f, 0x4d, 0x59, - 0x50, 0x43}}, + {"username without domain and empty workstation", username, "", username, "", []byte{ + 0x4e, 0x54, 0x4c, 0x4d, 0x53, 0x53, 0x50, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x82, 0x08, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {"username with domain and empty workstation", domain + "\\" + username, "", username, domain, []byte{ + 0x4e, 0x54, 0x4c, 0x4d, 0x53, 0x53, 0x50, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x92, 0x08, 0x00, + 0x08, 0x00, 0x08, 0x00, 0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4d, 0x59, 0x44, 0x4f, 0x4d, 0x41, 0x49, 0x4e}}, + {"username with domain and non-empty workstation", domain + "\\" + username, workstation, username, domain, []byte{ + 0x4e, 0x54, 0x4c, 0x4d, 0x53, 0x53, 0x50, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0xb2, 0x08, 0x00, + 0x08, 0x00, 0x08, 0x00, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x30, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4d, 0x59, 0x44, 0x4f, 0x4d, 0x41, 0x49, 0x4e, + 0x4d, 0x59, 0x50, 0x43}}, + {"username without domain and non-empty workstation", username, workstation, username, "", []byte{ + 0x4e, 0x54, 0x4c, 0x4d, 0x53, 0x53, 0x50, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0xa2, 0x08, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x28, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x4d, 0x59, 0x50, 0x43}}, } for _, table := range tables { - tuser, tdomain := GetDomain(table.u) - if tuser != table.xu { - t.Fatalf("username not correct, expected %v got %v", tuser, table.xu) - } - if tdomain != table.xd { - t.Fatalf("domain not correct, expected %v got %v", tdomain, table.xd) - } - - tb, err := NewNegotiateMessage(tdomain, table.w) - if err != nil { - t.Fatalf("error creating new negotiate message with domain '%v' and workstation '%v'", tdomain, table.w) - } - - if !bytes.Equal(tb, table.xb) { - t.Fatalf("negotiate message bytes not correct, expected %v got %v", tb, table.xb) - } + t.Run(table.name, func(t *testing.T) { + tuser, tdomain := GetDomain(table.u) + if tuser != table.xu { + t.Fatalf("username not correct, expected %v got %v", tuser, table.xu) + } + if tdomain != table.xd { + t.Fatalf("domain not correct, expected %v got %v", tdomain, table.xd) + } + + tb, err := NewNegotiateMessage(tdomain, table.w) + if err != nil { + t.Fatalf("error creating new negotiate message with domain '%v' and workstation '%v'", tdomain, table.w) + } + + if !bytes.Equal(tb, table.xb) { + t.Fatalf("negotiate message bytes not correct, expected\n%v got\n%v", + hex.EncodeToString(table.xb), hex.EncodeToString(tb)) + } + }) } } @@ -75,7 +77,7 @@ func TestCalculateNTLMv2Response(t *testing.T) { Time := []byte{0x00, 0x90, 0xd3, 0x36, 0xb7, 0x34, 0xc3, 0x01} targetInfo := []byte{0x02, 0x00, 0x0c, 0x00, 0x44, 0x00, 0x4f, 0x00, 0x4d, 0x00, 0x41, 0x00, 0x49, 0x00, 0x4e, 0x00, 0x01, 0x00, 0x0c, 0x00, 0x53, 0x00, 0x45, 0x00, 0x52, 0x00, 0x56, 0x00, 0x45, 0x00, 0x52, 0x00, 0x04, 0x00, 0x14, 0x00, 0x64, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x61, 0x00, 0x69, 0x00, 0x6e, 0x00, 0x2e, 0x00, 0x63, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x03, 0x00, 0x22, 0x00, 0x73, 0x00, 0x65, 0x00, 0x72, 0x00, 0x76, 0x00, 0x65, 0x00, 0x72, 0x00, 0x2e, 0x00, 0x64, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x61, 0x00, 0x69, 0x00, 0x6e, 0x00, 0x2e, 0x00, 0x63, 0x00, 0x6f, 0x00, 0x6d, 0x00, 0x00, 0x00, 0x00, 0x00} - v := computeNtlmV2Response(NTLMv2Hash, challenge, ClientChallenge, Time, targetInfo) + v, _ := computeNtlmV2Response(NTLMv2Hash, challenge, ClientChallenge, Time, targetInfo) if expected := []byte{ 0xcb, 0xab, 0xbc, 0xa7, 0x13, 0xeb, 0x79, 0x5d, 0x04, 0xc9, 0x7a, 0xbc, 0x01, 0xee, 0x49, 0x83, @@ -105,7 +107,7 @@ func TestCalculateNTLMv2ResponseWithHash(t *testing.T) { NTLMv2Hash := hmacMd5(hashBytes, toUnicode(strings.ToUpper(username)+target)) - v := computeNtlmV2Response(NTLMv2Hash, challenge, ClientChallenge, Time, targetInfo) + v, _ := computeNtlmV2Response(NTLMv2Hash, challenge, ClientChallenge, Time, targetInfo) if expected := []byte{ 0xcb, 0xab, 0xbc, 0xa7, 0x13, 0xeb, 0x79, 0x5d, 0x04, 0xc9, 0x7a, 0xbc, 0x01, 0xee, 0x49, 0x83, @@ -156,3 +158,24 @@ func TestNTLMv2Hash(t *testing.T) { t.Fatalf("expected %v, got %v", expected, v) } } + +func TestComputeBindingChannelHash(t *testing.T) { + channelBindingDouble := getChannelBindingDouble() + + hash, _ := computeChannelBindingHash(channelBindingDouble) + if len(hash) != 16 { + t.Errorf("expected hash of len(%d), got len(%d)", 16, len(hash)) + } + allZeroes := make([]byte, 16) + if bytes.Equal(hash, allZeroes) { + t.Error("expected non-zero hash") + } +} + +func getChannelBindingDouble() []byte { + buf := bytes.NewBuffer(make([]byte, 0, 30)) + buf.WriteString("tls-server-end-point:") + buf.Write([]byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x8}) + channelBindingDouble := buf.Bytes() + return channelBindingDouble +} diff --git a/version.go b/version.go index 6d84892..2591aa4 100644 --- a/version.go +++ b/version.go @@ -18,3 +18,7 @@ func DefaultVersion() Version { NTLMRevisionCurrent: 15, } } + +func EmptyVersion() Version { + return Version{} +}