Skip to content

Commit

Permalink
Use pathlib, shutil and os instead of py.path
Browse files Browse the repository at this point in the history
  • Loading branch information
gjover committed Feb 25, 2022
1 parent 0839a7b commit 9c42adf
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions pytest_datafiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
Module containing a 'datafiles' fixture for pytest Tests.
"""
import os
import shutil
from functools import partial

from py import path # pylint: disable=E0611
from pathlib import Path
import pytest


Expand All @@ -16,12 +17,12 @@ def _copy(src, target):
if not src.exists():
raise ValueError("'%s' does not exist!" % src)

if src.isdir():
src.copy(target / src.basename)
elif src.islink():
(target / src.basename).mksymlinkto(src.realpath())
if src.is_dir():
shutil.copytree(src, target / src.name)
elif src.is_simlink():
os.symlink(os.readlink(src), target / src.name)
else: # file
src.copy(target)
shutil.copy(src, target)


def _copy_all(entry_list, target_dir, on_duplicate):
Expand All @@ -31,7 +32,7 @@ def _copy_all(entry_list, target_dir, on_duplicate):
an entry already exists: raise an exception, overwrite it or ignore it).
"""
for entry in entry_list:
target_entry = target_dir / entry.basename
target_entry = target_dir / entry.name
if not target_entry.exists() or on_duplicate == 'overwrite':
_copy(entry, target_dir)
elif on_duplicate == 'exception':
Expand All @@ -53,22 +54,22 @@ def _get_all_entries(entry_list, keep_top_dir):
"""
all_files = []

entry_list = [path.local(entry) for entry in entry_list]
entry_list = [Path(entry) for entry in entry_list]

if keep_top_dir:
all_files = entry_list
else:
for entry in entry_list:
if entry.isdir():
all_files.extend(entry.listdir())
if entry.is_dir():
all_files.extend(list(entry.glob('*')))
else:
all_files.append(entry)
return all_files


class DataFilesPlugin:
def __init__(self, root=""):
self.root = root
self.root = Path(root)

@pytest.fixture
def datafiles(self, request, tmpdir):
Expand All @@ -82,7 +83,7 @@ def datafiles(self, request, tmpdir):
"on_duplicate": "exception", # ignore, overwrite
}
for mark in request.node.iter_markers("datafiles"):
entries = map(partial(os.path.join, self.root), mark.args)
entries = [self.root / entry for entry in mark.args]
entry_list.extend(entries)
options.update(mark.kwargs)

Expand All @@ -97,7 +98,7 @@ def datafiles(self, request, tmpdir):
)

all_entries = _get_all_entries(entry_list, keep_top_dir)
_copy_all(all_entries, tmpdir, on_duplicate)
_copy_all(all_entries, Path(tmpdir), on_duplicate)
return tmpdir


Expand Down

0 comments on commit 9c42adf

Please sign in to comment.