Skip to content

Commit

Permalink
Merge pull request #29 from arduino/recursion_loops
Browse files Browse the repository at this point in the history
Detection of recursion loops in ReadDirRecursive* methods
  • Loading branch information
cmaglie authored Jan 12, 2024
2 parents b3ef88d + b49819c commit dcc3db3
Show file tree
Hide file tree
Showing 15 changed files with 122 additions and 50 deletions.
5 changes: 2 additions & 3 deletions paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ package paths
import (
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
Expand Down Expand Up @@ -418,14 +417,14 @@ func (p *Path) Chtimes(atime, mtime time.Time) error {

// ReadFile reads the file named by filename and returns the contents
func (p *Path) ReadFile() ([]byte, error) {
return ioutil.ReadFile(p.path)
return os.ReadFile(p.path)
}

// WriteFile writes data to a file named by filename. If the file
// does not exist, WriteFile creates it otherwise WriteFile truncates
// it before writing.
func (p *Path) WriteFile(data []byte) error {
return ioutil.WriteFile(p.path, data, os.FileMode(0644))
return os.WriteFile(p.path, data, os.FileMode(0644))
}

// WriteToTempFile writes data to a newly generated temporary file.
Expand Down
89 changes: 42 additions & 47 deletions readdir.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
package paths

import (
"io/ioutil"
"errors"
"os"
"strings"
)

Expand All @@ -41,7 +42,7 @@ type ReadDirFilter func(file *Path) bool
// ReadDir returns a PathList containing the content of the directory
// pointed by the current Path. The resulting list is filtered by the given filters chained.
func (p *Path) ReadDir(filters ...ReadDirFilter) (PathList, error) {
infos, err := ioutil.ReadDir(p.path)
infos, err := os.ReadDir(p.path)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -69,27 +70,7 @@ func (p *Path) ReadDir(filters ...ReadDirFilter) (PathList, error) {
// ReadDirRecursive returns a PathList containing the content of the directory
// and its subdirectories pointed by the current Path
func (p *Path) ReadDirRecursive() (PathList, error) {
infos, err := ioutil.ReadDir(p.path)
if err != nil {
return nil, err
}
paths := PathList{}
for _, info := range infos {
path := p.Join(info.Name())
paths.Add(path)

if isDir, err := path.IsDirCheck(); err != nil {
return nil, err
} else if isDir {
subPaths, err := path.ReadDirRecursive()
if err != nil {
return nil, err
}
paths.AddAll(subPaths)
}

}
return paths, nil
return p.ReadDirRecursiveFiltered(nil)
}

// ReadDirRecursiveFiltered returns a PathList containing the content of the directory
Expand All @@ -101,41 +82,55 @@ func (p *Path) ReadDirRecursive() (PathList, error) {
// - `filters` are the filters that are checked to determine if the entry should be
// added to the resulting PathList
func (p *Path) ReadDirRecursiveFiltered(recursionFilter ReadDirFilter, filters ...ReadDirFilter) (PathList, error) {
infos, err := ioutil.ReadDir(p.path)
if err != nil {
return nil, err
}
var search func(*Path) (PathList, error)

accept := func(p *Path) bool {
for _, filter := range filters {
if !filter(p) {
return false
}
explored := map[string]bool{}
search = func(currPath *Path) (PathList, error) {
canonical := currPath.Canonical().path
if explored[canonical] {
return nil, errors.New("directories symlink loop detected")
}
return true
}
explored[canonical] = true
defer delete(explored, canonical)

paths := PathList{}
for _, info := range infos {
path := p.Join(info.Name())
infos, err := os.ReadDir(currPath.path)
if err != nil {
return nil, err
}

if accept(path) {
paths.Add(path)
accept := func(p *Path) bool {
for _, filter := range filters {
if !filter(p) {
return false
}
}
return true
}

if recursionFilter == nil || recursionFilter(path) {
if isDir, err := path.IsDirCheck(); err != nil {
return nil, err
} else if isDir {
subPaths, err := path.ReadDirRecursiveFiltered(recursionFilter, filters...)
if err != nil {
paths := PathList{}
for _, info := range infos {
path := currPath.Join(info.Name())

if accept(path) {
paths.Add(path)
}

if recursionFilter == nil || recursionFilter(path) {
if isDir, err := path.IsDirCheck(); err != nil {
return nil, err
} else if isDir {
subPaths, err := search(path)
if err != nil {
return nil, err
}
paths.AddAll(subPaths)
}
paths.AddAll(subPaths)
}
}
return paths, nil
}
return paths, nil

return search(p)
}

// FilterDirectories is a ReadDirFilter that accepts only directories
Expand Down
69 changes: 69 additions & 0 deletions readdir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"fmt"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -245,3 +246,71 @@ func TestReadDirRecursiveFiltered(t *testing.T) {
pathEqualsTo(t, "testdata/fileset/test.txt", l[7])
pathEqualsTo(t, "testdata/fileset/test.txt.gz", l[8])
}

func TestReadDirRecursiveLoopDetection(t *testing.T) {
loopsPath := New("testdata", "loops")
unbuondedReaddir := func(testdir string) (PathList, error) {
// This is required to unbound the recursion, otherwise it will stop
// when the paths becomes too long due to the symlink loop: this is not
// what we want, we are looking for an early detection of the loop.
skipBrokenLinks := func(p *Path) bool {
_, err := p.Stat()
return err == nil
}

var files PathList
var err error
done := make(chan bool)
go func() {
files, err = loopsPath.Join(testdir).ReadDirRecursiveFiltered(
skipBrokenLinks,
)
done <- true
}()
require.Eventually(
t,
func() bool {
select {
case <-done:
return true
default:
return false
}
},
5*time.Second,
10*time.Millisecond,
"Infinite symlink loop while loading sketch",
)
return files, err
}

for _, dir := range []string{"loop_1", "loop_2", "loop_3", "loop_4"} {
l, err := unbuondedReaddir(dir)
require.EqualError(t, err, "directories symlink loop detected", "loop not detected in %s", dir)
require.Nil(t, l)
}

{
l, err := unbuondedReaddir("regular_1")
require.NoError(t, err)
require.Len(t, l, 4)
l.Sort()
pathEqualsTo(t, "testdata/loops/regular_1/dir1", l[0])
pathEqualsTo(t, "testdata/loops/regular_1/dir1/file1", l[1])
pathEqualsTo(t, "testdata/loops/regular_1/dir2", l[2])
pathEqualsTo(t, "testdata/loops/regular_1/dir2/file1", l[3])
}

{
l, err := unbuondedReaddir("regular_2")
require.NoError(t, err)
require.Len(t, l, 6)
l.Sort()
pathEqualsTo(t, "testdata/loops/regular_2/dir1", l[0])
pathEqualsTo(t, "testdata/loops/regular_2/dir1/file1", l[1])
pathEqualsTo(t, "testdata/loops/regular_2/dir2", l[2])
pathEqualsTo(t, "testdata/loops/regular_2/dir2/dir1", l[3])
pathEqualsTo(t, "testdata/loops/regular_2/dir2/dir1/file1", l[4])
pathEqualsTo(t, "testdata/loops/regular_2/dir2/file2", l[5])
}
}
1 change: 1 addition & 0 deletions testdata/loops/loop_1/dir1/loop
1 change: 1 addition & 0 deletions testdata/loops/loop_2/dir1/loop2
1 change: 1 addition & 0 deletions testdata/loops/loop_2/dir2/loop1
1 change: 1 addition & 0 deletions testdata/loops/loop_3/dir1/loop2
1 change: 1 addition & 0 deletions testdata/loops/loop_3/dir2/dir3/loop2
1 change: 1 addition & 0 deletions testdata/loops/loop_4/dir1/dir2/loop2
1 change: 1 addition & 0 deletions testdata/loops/loop_4/dir1/dir3/dir4/loop1
Empty file.
1 change: 1 addition & 0 deletions testdata/loops/regular_1/dir2
Empty file.
1 change: 1 addition & 0 deletions testdata/loops/regular_2/dir2/dir1
Empty file.

0 comments on commit dcc3db3

Please sign in to comment.