Skip to content

Commit

Permalink
Merge pull request #3571 from ipfs/feat/better-enum-async
Browse files Browse the repository at this point in the history
rewrite enumerate children async to be less fragile
  • Loading branch information
whyrusleeping authored Jan 17, 2017
2 parents ea36c38 + 397a35a commit 75cce80
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 79 deletions.
145 changes: 67 additions & 78 deletions merkledag/merkledag.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (n *dagService) Remove(nd node.Node) error {

// FetchGraph fetches all nodes that are children of the given node
func FetchGraph(ctx context.Context, c *cid.Cid, serv DAGService) error {
return EnumerateChildrenAsync(ctx, serv, c, cid.NewSet().Visit)
return EnumerateChildren(ctx, serv, c, cid.NewSet().Visit, false)
}

// FindLinks searches this nodes links for the given key,
Expand Down Expand Up @@ -389,103 +389,92 @@ func EnumerateChildren(ctx context.Context, ds LinkService, root *cid.Cid, visit
return nil
}

func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visit func(*cid.Cid) bool) error {
toprocess := make(chan []*cid.Cid, 8)
nodes := make(chan *NodeOption, 8)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
defer close(toprocess)
// FetchGraphConcurrency is total number of concurrent fetches that
// 'fetchNodes' will start at a time
var FetchGraphConcurrency = 8

go fetchNodes(ctx, ds, toprocess, nodes)
func EnumerateChildrenAsync(ctx context.Context, ds DAGService, c *cid.Cid, visit func(*cid.Cid) bool) error {
if !visit(c) {
return nil
}

root, err := ds.Get(ctx, c)
if err != nil {
return err
}

nodes <- &NodeOption{Node: root}
live := 1

for {
select {
case opt, ok := <-nodes:
if !ok {
return nil
}

if opt.Err != nil {
return opt.Err
}

nd := opt.Node

// a node has been fetched
live--

var cids []*cid.Cid
for _, lnk := range nd.Links() {
c := lnk.Cid
if visit(c) {
live++
cids = append(cids, c)
feed := make(chan node.Node)
out := make(chan *NodeOption)
done := make(chan struct{})

var setlk sync.Mutex

for i := 0; i < FetchGraphConcurrency; i++ {
go func() {
for n := range feed {
links := n.Links()
cids := make([]*cid.Cid, 0, len(links))
for _, l := range links {
setlk.Lock()
unseen := visit(l.Cid)
setlk.Unlock()
if unseen {
cids = append(cids, l.Cid)
}
}
}

if live == 0 {
return nil
}

if len(cids) > 0 {
for nopt := range ds.GetMany(ctx, cids) {
select {
case out <- nopt:
case <-ctx.Done():
return
}
}
select {
case toprocess <- cids:
case done <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}
}
case <-ctx.Done():
return ctx.Err()
}
}()
}
}
defer close(feed)

// FetchGraphConcurrency is total number of concurrenct fetches that
// 'fetchNodes' will start at a time
var FetchGraphConcurrency = 8

func fetchNodes(ctx context.Context, ds DAGService, in <-chan []*cid.Cid, out chan<- *NodeOption) {
var wg sync.WaitGroup
defer func() {
// wait for all 'get' calls to complete so we don't accidentally send
// on a closed channel
wg.Wait()
close(out)
}()
send := feed
var todobuffer []node.Node
var inProgress int

rateLimit := make(chan struct{}, FetchGraphConcurrency)
next := root
for {
select {
case send <- next:
inProgress++
if len(todobuffer) > 0 {
next = todobuffer[0]
todobuffer = todobuffer[1:]
} else {
next = nil
send = nil
}
case <-done:
inProgress--
if inProgress == 0 && next == nil {
return nil
}
case nc := <-out:
if nc.Err != nil {
return nc.Err
}

get := func(ks []*cid.Cid) {
defer wg.Done()
defer func() {
<-rateLimit
}()
nodes := ds.GetMany(ctx, ks)
for opt := range nodes {
select {
case out <- opt:
case <-ctx.Done():
return
if next == nil {
next = nc.Node
send = feed
} else {
todobuffer = append(todobuffer, nc.Node)
}
}
}

for ks := range in {
select {
case rateLimit <- struct{}{}:
case <-ctx.Done():
return
return ctx.Err()
}
wg.Add(1)
go get(ks)
}

}
43 changes: 43 additions & 0 deletions merkledag/merkledag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,3 +504,46 @@ func TestCidRawDoesnNeedData(t *testing.T) {
t.Fatal("raw node shouldn't have any links")
}
}

func TestEnumerateAsyncFailsNotFound(t *testing.T) {
a := NodeWithData([]byte("foo1"))
b := NodeWithData([]byte("foo2"))
c := NodeWithData([]byte("foo3"))
d := NodeWithData([]byte("foo4"))

ds := dstest.Mock()
for _, n := range []node.Node{a, b, c} {
_, err := ds.Add(n)
if err != nil {
t.Fatal(err)
}
}

parent := new(ProtoNode)
if err := parent.AddNodeLinkClean("a", a); err != nil {
t.Fatal(err)
}

if err := parent.AddNodeLinkClean("b", b); err != nil {
t.Fatal(err)
}

if err := parent.AddNodeLinkClean("c", c); err != nil {
t.Fatal(err)
}

if err := parent.AddNodeLinkClean("d", d); err != nil {
t.Fatal(err)
}

pcid, err := ds.Add(parent)
if err != nil {
t.Fatal(err)
}

cset := cid.NewSet()
err = EnumerateChildrenAsync(context.Background(), ds, pcid, cset.Visit)
if err == nil {
t.Fatal("this should have failed")
}
}
2 changes: 1 addition & 1 deletion test/sharness/t0081-repo-pinning.sh
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ test_expect_success "some are no longer there" '
test_expect_success "recursive pin fails without objects" '
ipfs pin rm -r=false "$HASH_DIR1" &&
test_must_fail ipfs pin add -r "$HASH_DIR1" 2>err_expected8 &&
grep "pin: failed to fetch all nodes" err_expected8 ||
grep "pin: merkledag: not found" err_expected8 ||
test_fsh cat err_expected8
'

Expand Down

0 comments on commit 75cce80

Please sign in to comment.