From 81ab75de72571e5aa98fc06a88cff7d12b1e0e89 Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Thu, 1 Jun 2023 12:06:21 +0900 Subject: [PATCH 01/10] feat: add search with channels inspired by https://github.com/go-ldap/ldap/pull/319 --- examples_test.go | 29 ++++++++++ ldap_test.go | 62 +++++++++++++++++++++ search.go | 128 ++++++++++++++++++++++++++++++++++++++++++++ v3/examples_test.go | 29 ++++++++++ v3/ldap_test.go | 62 +++++++++++++++++++++ v3/search.go | 128 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 438 insertions(+) diff --git a/examples_test.go b/examples_test.go index d788e4f5..59fd6071 100644 --- a/examples_test.go +++ b/examples_test.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -50,6 +51,34 @@ func ExampleConn_Search() { } } +// This example demonstrates how to search with channel +func ExampleConn_SearchWithChannel() { + l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + "dc=example,dc=com", // The base dn to search + ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, + "(&(objectClass=organizationalPerson))", // The filter to apply + []string{"dn", "cn"}, // A list attributes to retrieve + nil, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := l.SearchWithChannel(ctx, searchRequest) + for res := range ch { + if res.Error != nil { + log.Fatalf("Error searching: %s", res.Error) + } + fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN) + } +} + // This example demonstrates how to start a TLS connection func ExampleConn_StartTLS() { l, err := DialURL("ldap://ldap.example.com:389") diff --git a/ldap_test.go b/ldap_test.go index 61417fd5..85efc8af 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "testing" @@ -344,3 +345,64 @@ func TestEscapeDN(t *testing.T) { }) } } + +func TestSearchWithChannel(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + nil) + + srs := make([]*Entry, 0) + ctx := context.Background() + for sr := range l.SearchWithChannel(ctx, searchRequest) { + if sr.Error != nil { + t.Fatal(err) + } + srs = append(srs, sr.Entry) + } + + t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) +} + +func TestSearchWithChannelAndCancel(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + nil) + + cancelNum := 10 + srs := make([]*Entry, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for sr := range l.SearchWithChannel(ctx, searchRequest) { + if sr.Error != nil { + t.Fatal(err) + } + srs = append(srs, sr.Entry) + if len(srs) == cancelNum { + cancel() + } + } + if len(srs) > cancelNum+2 { + // The cancel process is asynchronous, + // so a few entries after it canceled might be received + t.Errorf("Got entries %d, expected less than %d", len(srs), cancelNum+2) + } + t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) +} diff --git a/search.go b/search.go index ef3119b9..7176336e 100644 --- a/search.go +++ b/search.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "errors" "fmt" "reflect" @@ -375,6 +376,28 @@ func (s *SearchResult) appendTo(r *SearchResult) { r.Controls = append(r.Controls, s.Controls...) } +// SearchSingleResult holds the server's single response to a search request +type SearchSingleResult struct { + // Entry is the returned entry + Entry *Entry + // Referral is the returned referral + Referral string + // Controls are the returned controls + Controls []Control + // Error is set when the search request was failed + Error error +} + +// Print outputs a human-readable description +func (s *SearchSingleResult) Print() { + s.Entry.Print() +} + +// PrettyPrint outputs a human-readable description with indenting +func (s *SearchSingleResult) PrettyPrint(indent int) { + s.Entry.PrettyPrint(indent) +} + // SearchRequest represents a search request to send to the server type SearchRequest struct { BaseDN string @@ -559,6 +582,111 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } +// SearchWithChannel performs a search request and returns all search results +// via the returned channel as soon as they are received. This means you get +// all results until an error happens (or the search successfully finished), +// e.g. for size / time limited requests all are recieved via the channel +// until the limit is reached. +func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) chan *SearchSingleResult { + ch := make(chan *SearchSingleResult) + go func() { + defer close(ch) + if l.IsClosing() { + return + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + // encode search request + err := searchRequest.appendTo(packet) + if err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + l.Debug.PrintPacket(packet) + + msgCtx, err := l.sendMessage(packet) + if err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + defer l.finishMessage(msgCtx) + + foundSearchSingleResultDone := false + for !foundSearchSingleResultDone { + select { + case <-ctx.Done(): + l.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) + return + default: + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + ch <- &SearchSingleResult{Error: err} + return + } + packet, err = packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case ApplicationSearchResultEntry: + entry := new(Entry) + entry.DN = packet.Children[1].Children[0].Value.(string) + for _, child := range packet.Children[1].Children[1].Children { + attr := new(EntryAttribute) + attr.Name = child.Children[0].Value.(string) + for _, value := range child.Children[1].Children { + attr.Values = append(attr.Values, value.Value.(string)) + attr.ByteValues = append(attr.ByteValues, value.ByteValue) + } + entry.Attributes = append(entry.Attributes, attr) + } + ch <- &SearchSingleResult{Entry: entry} + + case ApplicationSearchResultDone: + if err := GetLDAPError(packet); err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + if len(packet.Children) == 3 { + result := &SearchSingleResult{} + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + werr := fmt.Errorf("failed to decode child control: %w", err) + ch <- &SearchSingleResult{Error: werr} + return + } + result.Controls = append(result.Controls, decodedChild) + } + ch <- result + } + foundSearchSingleResultDone = true + + case ApplicationSearchResultReference: + ref := packet.Children[1].Children[0].Value.(string) + ch <- &SearchSingleResult{Referral: ref} + } + } + } + l.Debug.Printf("%d: returning", msgCtx.id) + }() + return ch +} + // unpackAttributes will extract all given LDAP attributes and it's values // from the ber.Packet func unpackAttributes(children []*ber.Packet) []*EntryAttribute { diff --git a/v3/examples_test.go b/v3/examples_test.go index d788e4f5..59fd6071 100644 --- a/v3/examples_test.go +++ b/v3/examples_test.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -50,6 +51,34 @@ func ExampleConn_Search() { } } +// This example demonstrates how to search with channel +func ExampleConn_SearchWithChannel() { + l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + "dc=example,dc=com", // The base dn to search + ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, + "(&(objectClass=organizationalPerson))", // The filter to apply + []string{"dn", "cn"}, // A list attributes to retrieve + nil, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ch := l.SearchWithChannel(ctx, searchRequest) + for res := range ch { + if res.Error != nil { + log.Fatalf("Error searching: %s", res.Error) + } + fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN) + } +} + // This example demonstrates how to start a TLS connection func ExampleConn_StartTLS() { l, err := DialURL("ldap://ldap.example.com:389") diff --git a/v3/ldap_test.go b/v3/ldap_test.go index 61417fd5..85efc8af 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "testing" @@ -344,3 +345,64 @@ func TestEscapeDN(t *testing.T) { }) } } + +func TestSearchWithChannel(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + nil) + + srs := make([]*Entry, 0) + ctx := context.Background() + for sr := range l.SearchWithChannel(ctx, searchRequest) { + if sr.Error != nil { + t.Fatal(err) + } + srs = append(srs, sr.Entry) + } + + t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) +} + +func TestSearchWithChannelAndCancel(t *testing.T) { + l, err := DialURL(ldapServer) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + baseDN, + ScopeWholeSubtree, DerefAlways, 0, 0, false, + filter[2], + attributes, + nil) + + cancelNum := 10 + srs := make([]*Entry, 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + for sr := range l.SearchWithChannel(ctx, searchRequest) { + if sr.Error != nil { + t.Fatal(err) + } + srs = append(srs, sr.Entry) + if len(srs) == cancelNum { + cancel() + } + } + if len(srs) > cancelNum+2 { + // The cancel process is asynchronous, + // so a few entries after it canceled might be received + t.Errorf("Got entries %d, expected less than %d", len(srs), cancelNum+2) + } + t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) +} diff --git a/v3/search.go b/v3/search.go index 9c0ccd07..72928c38 100644 --- a/v3/search.go +++ b/v3/search.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "errors" "fmt" "reflect" @@ -377,6 +378,28 @@ func (s *SearchResult) appendTo(r *SearchResult) { r.Controls = append(r.Controls, s.Controls...) } +// SearchSingleResult holds the server's single response to a search request +type SearchSingleResult struct { + // Entry is the returned entry + Entry *Entry + // Referral is the returned referral + Referral string + // Controls are the returned controls + Controls []Control + // Error is set when the search request was failed + Error error +} + +// Print outputs a human-readable description +func (s *SearchSingleResult) Print() { + s.Entry.Print() +} + +// PrettyPrint outputs a human-readable description with indenting +func (s *SearchSingleResult) PrettyPrint(indent int) { + s.Entry.PrettyPrint(indent) +} + // SearchRequest represents a search request to send to the server type SearchRequest struct { BaseDN string @@ -561,6 +584,111 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } +// SearchWithChannel performs a search request and returns all search results +// via the returned channel as soon as they are received. This means you get +// all results until an error happens (or the search successfully finished), +// e.g. for size / time limited requests all are recieved via the channel +// until the limit is reached. +func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) chan *SearchSingleResult { + ch := make(chan *SearchSingleResult) + go func() { + defer close(ch) + if l.IsClosing() { + return + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) + // encode search request + err := searchRequest.appendTo(packet) + if err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + l.Debug.PrintPacket(packet) + + msgCtx, err := l.sendMessage(packet) + if err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + defer l.finishMessage(msgCtx) + + foundSearchSingleResultDone := false + for !foundSearchSingleResultDone { + select { + case <-ctx.Done(): + l.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) + return + default: + l.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + ch <- &SearchSingleResult{Error: err} + return + } + packet, err = packetResponse.ReadPacket() + l.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case ApplicationSearchResultEntry: + entry := new(Entry) + entry.DN = packet.Children[1].Children[0].Value.(string) + for _, child := range packet.Children[1].Children[1].Children { + attr := new(EntryAttribute) + attr.Name = child.Children[0].Value.(string) + for _, value := range child.Children[1].Children { + attr.Values = append(attr.Values, value.Value.(string)) + attr.ByteValues = append(attr.ByteValues, value.ByteValue) + } + entry.Attributes = append(entry.Attributes, attr) + } + ch <- &SearchSingleResult{Entry: entry} + + case ApplicationSearchResultDone: + if err := GetLDAPError(packet); err != nil { + ch <- &SearchSingleResult{Error: err} + return + } + if len(packet.Children) == 3 { + result := &SearchSingleResult{} + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + werr := fmt.Errorf("failed to decode child control: %w", err) + ch <- &SearchSingleResult{Error: werr} + return + } + result.Controls = append(result.Controls, decodedChild) + } + ch <- result + } + foundSearchSingleResultDone = true + + case ApplicationSearchResultReference: + ref := packet.Children[1].Children[0].Value.(string) + ch <- &SearchSingleResult{Referral: ref} + } + } + } + l.Debug.Printf("%d: returning", msgCtx.id) + }() + return ch +} + // unpackAttributes will extract all given LDAP attributes and it's values // from the ber.Packet func unpackAttributes(children []*ber.Packet) []*EntryAttribute { From 1131fd567a432ca43e2b48f84e288a9dcfb37e48 Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Sun, 4 Jun 2023 11:33:49 +0900 Subject: [PATCH 02/10] refactor: fix to check proper test results #319 --- ldap_test.go | 10 +++++----- v3/ldap_test.go | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ldap_test.go b/ldap_test.go index 85efc8af..daf3581b 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -390,7 +390,9 @@ func TestSearchWithChannelAndCancel(t *testing.T) { srs := make([]*Entry, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - for sr := range l.SearchWithChannel(ctx, searchRequest) { + ch := l.SearchWithChannel(ctx, searchRequest) + for i := 0; i < 10; i++ { + sr := <-ch if sr.Error != nil { t.Fatal(err) } @@ -399,10 +401,8 @@ func TestSearchWithChannelAndCancel(t *testing.T) { cancel() } } - if len(srs) > cancelNum+2 { - // The cancel process is asynchronous, - // so a few entries after it canceled might be received - t.Errorf("Got entries %d, expected less than %d", len(srs), cancelNum+2) + if len(srs) != cancelNum { + t.Errorf("Got entries %d, expected %d", len(srs), cancelNum) } t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } diff --git a/v3/ldap_test.go b/v3/ldap_test.go index 85efc8af..daf3581b 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -390,7 +390,9 @@ func TestSearchWithChannelAndCancel(t *testing.T) { srs := make([]*Entry, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - for sr := range l.SearchWithChannel(ctx, searchRequest) { + ch := l.SearchWithChannel(ctx, searchRequest) + for i := 0; i < 10; i++ { + sr := <-ch if sr.Error != nil { t.Fatal(err) } @@ -399,10 +401,8 @@ func TestSearchWithChannelAndCancel(t *testing.T) { cancel() } } - if len(srs) > cancelNum+2 { - // The cancel process is asynchronous, - // so a few entries after it canceled might be received - t.Errorf("Got entries %d, expected less than %d", len(srs), cancelNum+2) + if len(srs) != cancelNum { + t.Errorf("Got entries %d, expected %d", len(srs), cancelNum) } t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } From 60bc0800855041c6579e7768f4bf73701e44b16f Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Sun, 4 Jun 2023 11:34:42 +0900 Subject: [PATCH 03/10] refactor: fix to use unpackAttributes() for Attributes #319 --- search.go | 16 +++++----------- v3/search.go | 16 +++++----------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/search.go b/search.go index 7176336e..11bed706 100644 --- a/search.go +++ b/search.go @@ -643,18 +643,12 @@ func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchReque switch packet.Children[1].Tag { case ApplicationSearchResultEntry: - entry := new(Entry) - entry.DN = packet.Children[1].Children[0].Value.(string) - for _, child := range packet.Children[1].Children[1].Children { - attr := new(EntryAttribute) - attr.Name = child.Children[0].Value.(string) - for _, value := range child.Children[1].Children { - attr.Values = append(attr.Values, value.Value.(string)) - attr.ByteValues = append(attr.ByteValues, value.ByteValue) - } - entry.Attributes = append(entry.Attributes, attr) + ch <- &SearchSingleResult{ + Entry: &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + }, } - ch <- &SearchSingleResult{Entry: entry} case ApplicationSearchResultDone: if err := GetLDAPError(packet); err != nil { diff --git a/v3/search.go b/v3/search.go index 72928c38..334ad6e0 100644 --- a/v3/search.go +++ b/v3/search.go @@ -645,18 +645,12 @@ func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchReque switch packet.Children[1].Tag { case ApplicationSearchResultEntry: - entry := new(Entry) - entry.DN = packet.Children[1].Children[0].Value.(string) - for _, child := range packet.Children[1].Children[1].Children { - attr := new(EntryAttribute) - attr.Name = child.Children[0].Value.(string) - for _, value := range child.Children[1].Children { - attr.Values = append(attr.Values, value.Value.(string)) - attr.ByteValues = append(attr.ByteValues, value.ByteValue) - } - entry.Attributes = append(entry.Attributes, attr) + ch <- &SearchSingleResult{ + Entry: &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + }, } - ch <- &SearchSingleResult{Entry: entry} case ApplicationSearchResultDone: if err := GetLDAPError(packet); err != nil { From 03e1d76e2ce1f4cc7779844e6246ac1e82d1a1fe Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Sun, 4 Jun 2023 11:51:21 +0900 Subject: [PATCH 04/10] refactor: returns receive-only channel to prevent closing it from the caller #319 --- search.go | 2 +- v3/search.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/search.go b/search.go index 11bed706..cc2c1efd 100644 --- a/search.go +++ b/search.go @@ -587,7 +587,7 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { // all results until an error happens (or the search successfully finished), // e.g. for size / time limited requests all are recieved via the channel // until the limit is reached. -func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) chan *SearchSingleResult { +func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) <-chan *SearchSingleResult { ch := make(chan *SearchSingleResult) go func() { defer close(ch) diff --git a/v3/search.go b/v3/search.go index 334ad6e0..ae8b837b 100644 --- a/v3/search.go +++ b/v3/search.go @@ -589,7 +589,7 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { // all results until an error happens (or the search successfully finished), // e.g. for size / time limited requests all are recieved via the channel // until the limit is reached. -func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) chan *SearchSingleResult { +func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) <-chan *SearchSingleResult { ch := make(chan *SearchSingleResult) go func() { defer close(ch) From 51a1a4f89073e12e0121b1ac9c01266f822e4855 Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Sun, 4 Jun 2023 12:10:44 +0900 Subject: [PATCH 05/10] refactor: pass channelSize to be able to controll buffered channel by the caller #319 --- examples_test.go | 2 +- ldap_test.go | 4 ++-- search.go | 9 +++++++-- v3/examples_test.go | 2 +- v3/ldap_test.go | 4 ++-- v3/search.go | 9 +++++++-- 6 files changed, 20 insertions(+), 10 deletions(-) diff --git a/examples_test.go b/examples_test.go index 59fd6071..1de6c9be 100644 --- a/examples_test.go +++ b/examples_test.go @@ -70,7 +70,7 @@ func ExampleConn_SearchWithChannel() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest) + ch := l.SearchWithChannel(ctx, searchRequest, 64) for res := range ch { if res.Error != nil { log.Fatalf("Error searching: %s", res.Error) diff --git a/ldap_test.go b/ldap_test.go index daf3581b..34199db2 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -362,7 +362,7 @@ func TestSearchWithChannel(t *testing.T) { srs := make([]*Entry, 0) ctx := context.Background() - for sr := range l.SearchWithChannel(ctx, searchRequest) { + for sr := range l.SearchWithChannel(ctx, searchRequest, 64) { if sr.Error != nil { t.Fatal(err) } @@ -390,7 +390,7 @@ func TestSearchWithChannelAndCancel(t *testing.T) { srs := make([]*Entry, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest) + ch := l.SearchWithChannel(ctx, searchRequest, 0) for i := 0; i < 10; i++ { sr := <-ch if sr.Error != nil { diff --git a/search.go b/search.go index cc2c1efd..37693f14 100644 --- a/search.go +++ b/search.go @@ -587,8 +587,13 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { // all results until an error happens (or the search successfully finished), // e.g. for size / time limited requests all are recieved via the channel // until the limit is reached. -func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) <-chan *SearchSingleResult { - ch := make(chan *SearchSingleResult) +func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult { + var ch chan *SearchSingleResult + if channelSize > 0 { + ch = make(chan *SearchSingleResult, channelSize) + } else { + ch = make(chan *SearchSingleResult) + } go func() { defer close(ch) if l.IsClosing() { diff --git a/v3/examples_test.go b/v3/examples_test.go index 59fd6071..1de6c9be 100644 --- a/v3/examples_test.go +++ b/v3/examples_test.go @@ -70,7 +70,7 @@ func ExampleConn_SearchWithChannel() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest) + ch := l.SearchWithChannel(ctx, searchRequest, 64) for res := range ch { if res.Error != nil { log.Fatalf("Error searching: %s", res.Error) diff --git a/v3/ldap_test.go b/v3/ldap_test.go index daf3581b..34199db2 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -362,7 +362,7 @@ func TestSearchWithChannel(t *testing.T) { srs := make([]*Entry, 0) ctx := context.Background() - for sr := range l.SearchWithChannel(ctx, searchRequest) { + for sr := range l.SearchWithChannel(ctx, searchRequest, 64) { if sr.Error != nil { t.Fatal(err) } @@ -390,7 +390,7 @@ func TestSearchWithChannelAndCancel(t *testing.T) { srs := make([]*Entry, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest) + ch := l.SearchWithChannel(ctx, searchRequest, 0) for i := 0; i < 10; i++ { sr := <-ch if sr.Error != nil { diff --git a/v3/search.go b/v3/search.go index ae8b837b..87582dfa 100644 --- a/v3/search.go +++ b/v3/search.go @@ -589,8 +589,13 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { // all results until an error happens (or the search successfully finished), // e.g. for size / time limited requests all are recieved via the channel // until the limit is reached. -func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest) <-chan *SearchSingleResult { - ch := make(chan *SearchSingleResult) +func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult { + var ch chan *SearchSingleResult + if channelSize > 0 { + ch = make(chan *SearchSingleResult, channelSize) + } else { + ch = make(chan *SearchSingleResult) + } go func() { defer close(ch) if l.IsClosing() { From 698ccfe697da4052a6b1751d331d8344590fdff0 Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Sun, 4 Jun 2023 13:01:31 +0900 Subject: [PATCH 06/10] fix: recover an asynchronouse closing timing issue #319 --- search.go | 8 +++++++- v3/search.go | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/search.go b/search.go index 37693f14..d2961947 100644 --- a/search.go +++ b/search.go @@ -595,7 +595,13 @@ func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchReque ch = make(chan *SearchSingleResult) } go func() { - defer close(ch) + defer func() { + close(ch) + if err := recover(); err != nil { + l.err = fmt.Errorf("ldap: recovered panic in SearchWithChannel: %v", err) + } + }() + if l.IsClosing() { return } diff --git a/v3/search.go b/v3/search.go index 87582dfa..f3edcd70 100644 --- a/v3/search.go +++ b/v3/search.go @@ -597,7 +597,13 @@ func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchReque ch = make(chan *SearchSingleResult) } go func() { - defer close(ch) + defer func() { + close(ch) + if err := recover(); err != nil { + l.err = fmt.Errorf("ldap: recovered panic in SearchWithChannel: %v", err) + } + }() + if l.IsClosing() { return } From 72797107f94d3406274fc3d357ec57db3598b42b Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Sun, 4 Jun 2023 14:09:27 +0900 Subject: [PATCH 07/10] fix: consume all entries from the channel to prevent blocking by the connection #319 --- ldap_test.go | 3 +++ v3/ldap_test.go | 3 +++ 2 files changed, 6 insertions(+) diff --git a/ldap_test.go b/ldap_test.go index 34199db2..bbeecc9d 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -401,6 +401,9 @@ func TestSearchWithChannelAndCancel(t *testing.T) { cancel() } } + for range ch { + t.Log("Consume all entries from the channel to prevent blocking by the connection") + } if len(srs) != cancelNum { t.Errorf("Got entries %d, expected %d", len(srs), cancelNum) } diff --git a/v3/ldap_test.go b/v3/ldap_test.go index 34199db2..bbeecc9d 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -401,6 +401,9 @@ func TestSearchWithChannelAndCancel(t *testing.T) { cancel() } } + for range ch { + t.Log("Consume all entries from the channel to prevent blocking by the connection") + } if len(srs) != cancelNum { t.Errorf("Got entries %d, expected %d", len(srs), cancelNum) } From a9daeebe787c8b355e1522e764fc436e78cc98ab Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Mon, 5 Jun 2023 11:48:15 +0900 Subject: [PATCH 08/10] feat: add initial search async function with channel #341 --- v3/client.go | 3 + v3/examples_test.go | 29 ++++++++ v3/response.go | 172 ++++++++++++++++++++++++++++++++++++++++++++ v3/search.go | 11 +++ 4 files changed, 215 insertions(+) create mode 100644 v3/response.go diff --git a/v3/client.go b/v3/client.go index b438d254..cef2d91b 100644 --- a/v3/client.go +++ b/v3/client.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "time" ) @@ -32,6 +33,8 @@ type Client interface { PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) Search(*SearchRequest) (*SearchResult, error) + SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response + SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error) } diff --git a/v3/examples_test.go b/v3/examples_test.go index 1de6c9be..46898abb 100644 --- a/v3/examples_test.go +++ b/v3/examples_test.go @@ -51,6 +51,35 @@ func ExampleConn_Search() { } } +// This example demonstrates how to search with channel +func ExampleConn_SearchAsync() { + l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + searchRequest := NewSearchRequest( + "dc=example,dc=com", // The base dn to search + ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, + "(&(objectClass=organizationalPerson))", // The filter to apply + []string{"dn", "cn"}, // A list attributes to retrieve + nil, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + entry := r.Entry() + fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN) + } + if err := r.Err(); err != nil { + log.Fatal(err) + } +} + // This example demonstrates how to search with channel func ExampleConn_SearchWithChannel() { l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) diff --git a/v3/response.go b/v3/response.go new file mode 100644 index 00000000..3bfef84e --- /dev/null +++ b/v3/response.go @@ -0,0 +1,172 @@ +package ldap + +import ( + "context" + "errors" + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// Response defines an interface to get data from an LDAP server +type Response interface { + Entry() *Entry + Referral() string + Controls() []Control + Err() error + Next() bool +} + +type searchResponse struct { + conn *Conn + ch chan *SearchSingleResult + + entry *Entry + referral string + controls []Control + err error +} + +// Entry returns an entry from the given search request +func (r *searchResponse) Entry() *Entry { + return r.entry +} + +// Referral returns a referral from the given search request +func (r *searchResponse) Referral() string { + return r.referral +} + +// Controls returns controls from the given search request +func (r *searchResponse) Controls() []Control { + return r.controls +} + +// Err returns an error when the given search request was failed +func (r *searchResponse) Err() error { + return r.err +} + +// Next returns whether next data exist or not +func (r *searchResponse) Next() bool { + res := <-r.ch + if res == nil { + return false + } + r.err = res.Error + if r.err != nil { + return false + } + r.err = r.conn.GetLastError() + if r.err != nil { + return false + } + r.entry = res.Entry + r.referral = res.Referral + r.controls = res.Controls + return true +} + +func (r *searchResponse) searchAsync( + ctx context.Context, searchRequest *SearchRequest, bufferSize int) { + if bufferSize > 0 { + r.ch = make(chan *SearchSingleResult, bufferSize) + } else { + r.ch = make(chan *SearchSingleResult) + } + go func() { + defer func() { + close(r.ch) + if err := recover(); err != nil { + r.conn.err = fmt.Errorf("ldap: recovered panic in searchAsync: %v", err) + } + }() + + if r.conn.IsClosing() { + return + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID")) + // encode search request + err := searchRequest.appendTo(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + r.conn.Debug.PrintPacket(packet) + + msgCtx, err := r.conn.sendMessage(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + defer r.conn.finishMessage(msgCtx) + + foundSearchSingleResultDone := false + for !foundSearchSingleResultDone { + select { + case <-ctx.Done(): + r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) + return + default: + r.conn.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + r.ch <- &SearchSingleResult{Error: err} + return + } + packet, err = packetResponse.ReadPacket() + r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + + if r.conn.Debug { + if err := addLDAPDescriptions(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case ApplicationSearchResultEntry: + r.ch <- &SearchSingleResult{ + Entry: &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + }, + } + + case ApplicationSearchResultDone: + if err := GetLDAPError(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + if len(packet.Children) == 3 { + result := &SearchSingleResult{} + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + werr := fmt.Errorf("failed to decode child control: %w", err) + r.ch <- &SearchSingleResult{Error: werr} + return + } + result.Controls = append(result.Controls, decodedChild) + } + r.ch <- result + } + foundSearchSingleResultDone = true + + case ApplicationSearchResultReference: + ref := packet.Children[1].Children[0].Value.(string) + r.ch <- &SearchSingleResult{Referral: ref} + } + } + } + r.conn.Debug.Printf("%d: returning", msgCtx.id) + }() +} diff --git a/v3/search.go b/v3/search.go index f3edcd70..2d8e13ad 100644 --- a/v3/search.go +++ b/v3/search.go @@ -584,6 +584,17 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } +// SearchAsync performs a search request and returns all search results asynchronously. +// This means you get all results until an error happens (or the search successfully finished), +// e.g. for size / time limited requests all are recieved until the limit is reached. +// To stop the search, call cancel function returned context. +func (l *Conn) SearchAsync( + ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response { + r := &searchResponse{conn: l} + r.searchAsync(ctx, searchRequest, bufferSize) + return r +} + // SearchWithChannel performs a search request and returns all search results // via the returned channel as soon as they are received. This means you get // all results until an error happens (or the search successfully finished), From 2f623f08afb7ee79c06388291ba173de2540647f Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Thu, 22 Jun 2023 10:02:05 +0900 Subject: [PATCH 09/10] feat: provide search async function and drop search with channels #319 #341 --- client.go | 2 + examples_test.go | 17 +++-- ldap_test.go | 41 +++++----- response.go | 182 ++++++++++++++++++++++++++++++++++++++++++++ search.go | 117 +++------------------------- v3/client.go | 1 - v3/examples_test.go | 30 +------- v3/ldap_test.go | 41 +++++----- v3/response.go | 28 ++++--- v3/search.go | 116 +--------------------------- 10 files changed, 267 insertions(+), 308 deletions(-) create mode 100644 response.go diff --git a/client.go b/client.go index b438d254..5799f39b 100644 --- a/client.go +++ b/client.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "time" ) @@ -32,6 +33,7 @@ type Client interface { PasswordModify(*PasswordModifyRequest) (*PasswordModifyResult, error) Search(*SearchRequest) (*SearchResult, error) + SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error) } diff --git a/examples_test.go b/examples_test.go index 1de6c9be..61f16197 100644 --- a/examples_test.go +++ b/examples_test.go @@ -51,8 +51,8 @@ func ExampleConn_Search() { } } -// This example demonstrates how to search with channel -func ExampleConn_SearchWithChannel() { +// This example demonstrates how to search asynchronously +func ExampleConn_SearchAsync() { l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) if err != nil { log.Fatal(err) @@ -70,12 +70,13 @@ func ExampleConn_SearchWithChannel() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest, 64) - for res := range ch { - if res.Error != nil { - log.Fatalf("Error searching: %s", res.Error) - } - fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN) + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + entry := r.Entry() + fmt.Printf("%s has DN %s\n", entry.GetAttributeValue("cn"), entry.DN) + } + if err := r.Err(); err != nil { + log.Fatal(err) } } diff --git a/ldap_test.go b/ldap_test.go index bbeecc9d..5b96e039 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -3,6 +3,7 @@ package ldap import ( "context" "crypto/tls" + "log" "testing" ber "github.com/go-asn1-ber/asn1-ber" @@ -346,7 +347,7 @@ func TestEscapeDN(t *testing.T) { } } -func TestSearchWithChannel(t *testing.T) { +func TestSearchAsync(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { t.Fatal(err) @@ -362,17 +363,18 @@ func TestSearchWithChannel(t *testing.T) { srs := make([]*Entry, 0) ctx := context.Background() - for sr := range l.SearchWithChannel(ctx, searchRequest, 64) { - if sr.Error != nil { - t.Fatal(err) - } - srs = append(srs, sr.Entry) + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + srs = append(srs, r.Entry()) + } + if err := r.Err(); err != nil { + log.Fatal(err) } - t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) + t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } -func TestSearchWithChannelAndCancel(t *testing.T) { +func TestSearchAsyncAndCancel(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { t.Fatal(err) @@ -390,22 +392,21 @@ func TestSearchWithChannelAndCancel(t *testing.T) { srs := make([]*Entry, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest, 0) - for i := 0; i < 10; i++ { - sr := <-ch - if sr.Error != nil { - t.Fatal(err) - } - srs = append(srs, sr.Entry) + r := l.SearchAsync(ctx, searchRequest, 0) + for r.Next() { + srs = append(srs, r.Entry()) if len(srs) == cancelNum { cancel() } } - for range ch { - t.Log("Consume all entries from the channel to prevent blocking by the connection") + if err := r.Err(); err != nil { + log.Fatal(err) } - if len(srs) != cancelNum { - t.Errorf("Got entries %d, expected %d", len(srs), cancelNum) + + if len(srs) > cancelNum+3 { + // the cancellation process is asynchronous, + // so it might get some entries after calling cancel() + t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3) } - t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) + t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } diff --git a/response.go b/response.go new file mode 100644 index 00000000..81d97d9b --- /dev/null +++ b/response.go @@ -0,0 +1,182 @@ +package ldap + +import ( + "context" + "errors" + "fmt" + + ber "github.com/go-asn1-ber/asn1-ber" +) + +// Response defines an interface to get data from an LDAP server +type Response interface { + Entry() *Entry + Referral() string + Controls() []Control + Err() error + Next() bool +} + +type searchResponse struct { + conn *Conn + ch chan *SearchSingleResult + + entry *Entry + referral string + controls []Control + err error +} + +// Entry returns an entry from the given search request +func (r *searchResponse) Entry() *Entry { + return r.entry +} + +// Referral returns a referral from the given search request +func (r *searchResponse) Referral() string { + return r.referral +} + +// Controls returns controls from the given search request +func (r *searchResponse) Controls() []Control { + return r.controls +} + +// Err returns an error when the given search request was failed +func (r *searchResponse) Err() error { + return r.err +} + +// Next returns whether next data exist or not +func (r *searchResponse) Next() bool { + res, ok := <-r.ch + if !ok { + return false + } + if res == nil { + return false + } + r.err = res.Error + if r.err != nil { + return false + } + r.err = r.conn.GetLastError() + if r.err != nil { + return false + } + r.entry = res.Entry + r.referral = res.Referral + r.controls = res.Controls + return true +} + +func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) { + go func() { + defer func() { + close(r.ch) + if err := recover(); err != nil { + r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err) + } + }() + + if r.conn.IsClosing() { + return + } + + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, r.conn.nextMessageID(), "MessageID")) + // encode search request + err := searchRequest.appendTo(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + r.conn.Debug.PrintPacket(packet) + + msgCtx, err := r.conn.sendMessage(packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + defer r.conn.finishMessage(msgCtx) + + foundSearchSingleResultDone := false + for !foundSearchSingleResultDone { + select { + case <-ctx.Done(): + r.conn.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) + return + default: + r.conn.Debug.Printf("%d: waiting for response", msgCtx.id) + packetResponse, ok := <-msgCtx.responses + if !ok { + err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) + r.ch <- &SearchSingleResult{Error: err} + return + } + packet, err = packetResponse.ReadPacket() + r.conn.Debug.Printf("%d: got response %p", msgCtx.id, packet) + if err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + + if r.conn.Debug { + if err := addLDAPDescriptions(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case ApplicationSearchResultEntry: + r.ch <- &SearchSingleResult{ + Entry: &Entry{ + DN: packet.Children[1].Children[0].Value.(string), + Attributes: unpackAttributes(packet.Children[1].Children[1].Children), + }, + } + + case ApplicationSearchResultDone: + if err := GetLDAPError(packet); err != nil { + r.ch <- &SearchSingleResult{Error: err} + return + } + if len(packet.Children) == 3 { + result := &SearchSingleResult{} + for _, child := range packet.Children[2].Children { + decodedChild, err := DecodeControl(child) + if err != nil { + werr := fmt.Errorf("failed to decode child control: %w", err) + r.ch <- &SearchSingleResult{Error: werr} + return + } + result.Controls = append(result.Controls, decodedChild) + } + r.ch <- result + } + foundSearchSingleResultDone = true + + case ApplicationSearchResultReference: + ref := packet.Children[1].Children[0].Value.(string) + r.ch <- &SearchSingleResult{Referral: ref} + } + } + } + r.conn.Debug.Printf("%d: returning", msgCtx.id) + }() +} + +func newSearchResponse(conn *Conn, bufferSize int) *searchResponse { + var ch chan *SearchSingleResult + if bufferSize > 0 { + ch = make(chan *SearchSingleResult, bufferSize) + } else { + ch = make(chan *SearchSingleResult) + } + return &searchResponse{ + conn: conn, + ch: ch, + } +} diff --git a/search.go b/search.go index d2961947..3d8d9e70 100644 --- a/search.go +++ b/search.go @@ -582,114 +582,15 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { } } -// SearchWithChannel performs a search request and returns all search results -// via the returned channel as soon as they are received. This means you get -// all results until an error happens (or the search successfully finished), -// e.g. for size / time limited requests all are recieved via the channel -// until the limit is reached. -func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult { - var ch chan *SearchSingleResult - if channelSize > 0 { - ch = make(chan *SearchSingleResult, channelSize) - } else { - ch = make(chan *SearchSingleResult) - } - go func() { - defer func() { - close(ch) - if err := recover(); err != nil { - l.err = fmt.Errorf("ldap: recovered panic in SearchWithChannel: %v", err) - } - }() - - if l.IsClosing() { - return - } - - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - // encode search request - err := searchRequest.appendTo(packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - defer l.finishMessage(msgCtx) - - foundSearchSingleResultDone := false - for !foundSearchSingleResultDone { - select { - case <-ctx.Done(): - l.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) - return - default: - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - ch <- &SearchSingleResult{Error: err} - return - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - ber.PrintPacket(packet) - } - - switch packet.Children[1].Tag { - case ApplicationSearchResultEntry: - ch <- &SearchSingleResult{ - Entry: &Entry{ - DN: packet.Children[1].Children[0].Value.(string), - Attributes: unpackAttributes(packet.Children[1].Children[1].Children), - }, - } - - case ApplicationSearchResultDone: - if err := GetLDAPError(packet); err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - if len(packet.Children) == 3 { - result := &SearchSingleResult{} - for _, child := range packet.Children[2].Children { - decodedChild, err := DecodeControl(child) - if err != nil { - werr := fmt.Errorf("failed to decode child control: %w", err) - ch <- &SearchSingleResult{Error: werr} - return - } - result.Controls = append(result.Controls, decodedChild) - } - ch <- result - } - foundSearchSingleResultDone = true - - case ApplicationSearchResultReference: - ref := packet.Children[1].Children[0].Value.(string) - ch <- &SearchSingleResult{Referral: ref} - } - } - } - l.Debug.Printf("%d: returning", msgCtx.id) - }() - return ch +// SearchAsync performs a search request and returns all search results asynchronously. +// This means you get all results until an error happens (or the search successfully finished), +// e.g. for size / time limited requests all are recieved until the limit is reached. +// To stop the search, call cancel function returned context. +func (l *Conn) SearchAsync( + ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response { + r := newSearchResponse(l, bufferSize) + r.start(ctx, searchRequest) + return r } // unpackAttributes will extract all given LDAP attributes and it's values diff --git a/v3/client.go b/v3/client.go index cef2d91b..5799f39b 100644 --- a/v3/client.go +++ b/v3/client.go @@ -34,7 +34,6 @@ type Client interface { Search(*SearchRequest) (*SearchResult, error) SearchAsync(ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response - SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, error) DirSync(searchRequest *SearchRequest, flags, maxAttrCount int64, cookie []byte) (*SearchResult, error) } diff --git a/v3/examples_test.go b/v3/examples_test.go index 46898abb..61f16197 100644 --- a/v3/examples_test.go +++ b/v3/examples_test.go @@ -51,7 +51,7 @@ func ExampleConn_Search() { } } -// This example demonstrates how to search with channel +// This example demonstrates how to search asynchronously func ExampleConn_SearchAsync() { l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) if err != nil { @@ -80,34 +80,6 @@ func ExampleConn_SearchAsync() { } } -// This example demonstrates how to search with channel -func ExampleConn_SearchWithChannel() { - l, err := DialURL(fmt.Sprintf("%s:%d", "ldap.example.com", 389)) - if err != nil { - log.Fatal(err) - } - defer l.Close() - - searchRequest := NewSearchRequest( - "dc=example,dc=com", // The base dn to search - ScopeWholeSubtree, NeverDerefAliases, 0, 0, false, - "(&(objectClass=organizationalPerson))", // The filter to apply - []string{"dn", "cn"}, // A list attributes to retrieve - nil, - ) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ch := l.SearchWithChannel(ctx, searchRequest, 64) - for res := range ch { - if res.Error != nil { - log.Fatalf("Error searching: %s", res.Error) - } - fmt.Printf("%s has DN %s\n", res.Entry.GetAttributeValue("cn"), res.Entry.DN) - } -} - // This example demonstrates how to start a TLS connection func ExampleConn_StartTLS() { l, err := DialURL("ldap://ldap.example.com:389") diff --git a/v3/ldap_test.go b/v3/ldap_test.go index bbeecc9d..5b96e039 100644 --- a/v3/ldap_test.go +++ b/v3/ldap_test.go @@ -3,6 +3,7 @@ package ldap import ( "context" "crypto/tls" + "log" "testing" ber "github.com/go-asn1-ber/asn1-ber" @@ -346,7 +347,7 @@ func TestEscapeDN(t *testing.T) { } } -func TestSearchWithChannel(t *testing.T) { +func TestSearchAsync(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { t.Fatal(err) @@ -362,17 +363,18 @@ func TestSearchWithChannel(t *testing.T) { srs := make([]*Entry, 0) ctx := context.Background() - for sr := range l.SearchWithChannel(ctx, searchRequest, 64) { - if sr.Error != nil { - t.Fatal(err) - } - srs = append(srs, sr.Entry) + r := l.SearchAsync(ctx, searchRequest, 64) + for r.Next() { + srs = append(srs, r.Entry()) + } + if err := r.Err(); err != nil { + log.Fatal(err) } - t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) + t.Logf("TestSearcAsync: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } -func TestSearchWithChannelAndCancel(t *testing.T) { +func TestSearchAsyncAndCancel(t *testing.T) { l, err := DialURL(ldapServer) if err != nil { t.Fatal(err) @@ -390,22 +392,21 @@ func TestSearchWithChannelAndCancel(t *testing.T) { srs := make([]*Entry, 0) ctx, cancel := context.WithCancel(context.Background()) defer cancel() - ch := l.SearchWithChannel(ctx, searchRequest, 0) - for i := 0; i < 10; i++ { - sr := <-ch - if sr.Error != nil { - t.Fatal(err) - } - srs = append(srs, sr.Entry) + r := l.SearchAsync(ctx, searchRequest, 0) + for r.Next() { + srs = append(srs, r.Entry()) if len(srs) == cancelNum { cancel() } } - for range ch { - t.Log("Consume all entries from the channel to prevent blocking by the connection") + if err := r.Err(); err != nil { + log.Fatal(err) } - if len(srs) != cancelNum { - t.Errorf("Got entries %d, expected %d", len(srs), cancelNum) + + if len(srs) > cancelNum+3 { + // the cancellation process is asynchronous, + // so it might get some entries after calling cancel() + t.Errorf("Got entries %d, expected < %d", len(srs), cancelNum+3) } - t.Logf("TestSearchWithChannel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) + t.Logf("TestSearchAsyncAndCancel: %s -> num of entries = %d", searchRequest.Filter, len(srs)) } diff --git a/v3/response.go b/v3/response.go index 3bfef84e..81d97d9b 100644 --- a/v3/response.go +++ b/v3/response.go @@ -49,7 +49,10 @@ func (r *searchResponse) Err() error { // Next returns whether next data exist or not func (r *searchResponse) Next() bool { - res := <-r.ch + res, ok := <-r.ch + if !ok { + return false + } if res == nil { return false } @@ -67,18 +70,12 @@ func (r *searchResponse) Next() bool { return true } -func (r *searchResponse) searchAsync( - ctx context.Context, searchRequest *SearchRequest, bufferSize int) { - if bufferSize > 0 { - r.ch = make(chan *SearchSingleResult, bufferSize) - } else { - r.ch = make(chan *SearchSingleResult) - } +func (r *searchResponse) start(ctx context.Context, searchRequest *SearchRequest) { go func() { defer func() { close(r.ch) if err := recover(); err != nil { - r.conn.err = fmt.Errorf("ldap: recovered panic in searchAsync: %v", err) + r.conn.err = fmt.Errorf("ldap: recovered panic in searchResponse: %v", err) } }() @@ -170,3 +167,16 @@ func (r *searchResponse) searchAsync( r.conn.Debug.Printf("%d: returning", msgCtx.id) }() } + +func newSearchResponse(conn *Conn, bufferSize int) *searchResponse { + var ch chan *SearchSingleResult + if bufferSize > 0 { + ch = make(chan *SearchSingleResult, bufferSize) + } else { + ch = make(chan *SearchSingleResult) + } + return &searchResponse{ + conn: conn, + ch: ch, + } +} diff --git a/v3/search.go b/v3/search.go index 2d8e13ad..afac768c 100644 --- a/v3/search.go +++ b/v3/search.go @@ -378,7 +378,7 @@ func (s *SearchResult) appendTo(r *SearchResult) { r.Controls = append(r.Controls, s.Controls...) } -// SearchSingleResult holds the server's single response to a search request +// SearchSingleResult holds the server's single entry response to a search request type SearchSingleResult struct { // Entry is the returned entry Entry *Entry @@ -590,121 +590,11 @@ func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, error) { // To stop the search, call cancel function returned context. func (l *Conn) SearchAsync( ctx context.Context, searchRequest *SearchRequest, bufferSize int) Response { - r := &searchResponse{conn: l} - r.searchAsync(ctx, searchRequest, bufferSize) + r := newSearchResponse(l, bufferSize) + r.start(ctx, searchRequest) return r } -// SearchWithChannel performs a search request and returns all search results -// via the returned channel as soon as they are received. This means you get -// all results until an error happens (or the search successfully finished), -// e.g. for size / time limited requests all are recieved via the channel -// until the limit is reached. -func (l *Conn) SearchWithChannel(ctx context.Context, searchRequest *SearchRequest, channelSize int) <-chan *SearchSingleResult { - var ch chan *SearchSingleResult - if channelSize > 0 { - ch = make(chan *SearchSingleResult, channelSize) - } else { - ch = make(chan *SearchSingleResult) - } - go func() { - defer func() { - close(ch) - if err := recover(); err != nil { - l.err = fmt.Errorf("ldap: recovered panic in SearchWithChannel: %v", err) - } - }() - - if l.IsClosing() { - return - } - - packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") - packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID")) - // encode search request - err := searchRequest.appendTo(packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - l.Debug.PrintPacket(packet) - - msgCtx, err := l.sendMessage(packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - defer l.finishMessage(msgCtx) - - foundSearchSingleResultDone := false - for !foundSearchSingleResultDone { - select { - case <-ctx.Done(): - l.Debug.Printf("%d: %s", msgCtx.id, ctx.Err().Error()) - return - default: - l.Debug.Printf("%d: waiting for response", msgCtx.id) - packetResponse, ok := <-msgCtx.responses - if !ok { - err := NewError(ErrorNetwork, errors.New("ldap: response channel closed")) - ch <- &SearchSingleResult{Error: err} - return - } - packet, err = packetResponse.ReadPacket() - l.Debug.Printf("%d: got response %p", msgCtx.id, packet) - if err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - - if l.Debug { - if err := addLDAPDescriptions(packet); err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - ber.PrintPacket(packet) - } - - switch packet.Children[1].Tag { - case ApplicationSearchResultEntry: - ch <- &SearchSingleResult{ - Entry: &Entry{ - DN: packet.Children[1].Children[0].Value.(string), - Attributes: unpackAttributes(packet.Children[1].Children[1].Children), - }, - } - - case ApplicationSearchResultDone: - if err := GetLDAPError(packet); err != nil { - ch <- &SearchSingleResult{Error: err} - return - } - if len(packet.Children) == 3 { - result := &SearchSingleResult{} - for _, child := range packet.Children[2].Children { - decodedChild, err := DecodeControl(child) - if err != nil { - werr := fmt.Errorf("failed to decode child control: %w", err) - ch <- &SearchSingleResult{Error: werr} - return - } - result.Controls = append(result.Controls, decodedChild) - } - ch <- result - } - foundSearchSingleResultDone = true - - case ApplicationSearchResultReference: - ref := packet.Children[1].Children[0].Value.(string) - ch <- &SearchSingleResult{Referral: ref} - } - } - } - l.Debug.Printf("%d: returning", msgCtx.id) - }() - return ch -} - // unpackAttributes will extract all given LDAP attributes and it's values // from the ber.Packet func unpackAttributes(children []*ber.Packet) []*EntryAttribute { From 9de41ced83456b5370af3765bf416db47052328f Mon Sep 17 00:00:00 2001 From: Tetsuya Morimoto Date: Thu, 22 Jun 2023 10:20:36 +0900 Subject: [PATCH 10/10] refactor: lock when to call GetLastError since it might be in communication --- conn.go | 2 ++ v3/conn.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/conn.go b/conn.go index d39213b4..474d494b 100644 --- a/conn.go +++ b/conn.go @@ -327,6 +327,8 @@ func (l *Conn) nextMessageID() int64 { // GetLastError returns the last recorded error from goroutines like processMessages and reader. // Only the last recorded error will be returned. func (l *Conn) GetLastError() error { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() return l.err } diff --git a/v3/conn.go b/v3/conn.go index 3ed80883..a42a9697 100644 --- a/v3/conn.go +++ b/v3/conn.go @@ -327,6 +327,8 @@ func (l *Conn) nextMessageID() int64 { // GetLastError returns the last recorded error from goroutines like processMessages and reader. // // Only the last recorded error will be returned. func (l *Conn) GetLastError() error { + l.messageMutex.Lock() + defer l.messageMutex.Unlock() return l.err }