diff --git a/lib/importlib_resources/__init__.py b/lib/importlib_resources/__init__.py index 15f6b26b..34e3a995 100644 --- a/lib/importlib_resources/__init__.py +++ b/lib/importlib_resources/__init__.py @@ -17,7 +17,7 @@ from ._legacy import ( Resource, ) -from importlib_resources.abc import ResourceReader +from .abc import ResourceReader __all__ = [ diff --git a/lib/importlib_resources/_common.py b/lib/importlib_resources/_common.py index a12e2c75..9f19784d 100644 --- a/lib/importlib_resources/_common.py +++ b/lib/importlib_resources/_common.py @@ -5,25 +5,58 @@ import functools import contextlib import types import importlib +import inspect +import warnings +import itertools -from typing import Union, Optional +from typing import Union, Optional, cast from .abc import ResourceReader, Traversable from ._compat import wrap_spec Package = Union[types.ModuleType, str] +Anchor = Package -def files(package): - # type: (Package) -> Traversable +def package_to_anchor(func): """ - Get a Traversable resource from a package + Replace 'package' parameter as 'anchor' and warn about the change. + + Other errors should fall through. + + >>> files('a', 'b') + Traceback (most recent call last): + TypeError: files() takes from 0 to 1 positional arguments but 2 were given """ - return from_package(get_package(package)) + undefined = object() + + @functools.wraps(func) + def wrapper(anchor=undefined, package=undefined): + if package is not undefined: + if anchor is not undefined: + return func(anchor, package) + warnings.warn( + "First parameter to files is renamed to 'anchor'", + DeprecationWarning, + stacklevel=2, + ) + return func(package) + elif anchor is undefined: + return func() + return func(anchor) + + return wrapper -def get_resource_reader(package): - # type: (types.ModuleType) -> Optional[ResourceReader] +@package_to_anchor +def files(anchor: Optional[Anchor] = None) -> Traversable: + """ + Get a Traversable resource for an anchor. + """ + return from_package(resolve(anchor)) + + +def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]: """ Return the package's loader if it's a ResourceReader. """ @@ -39,24 +72,39 @@ def get_resource_reader(package): return reader(spec.name) # type: ignore -def resolve(cand): - # type: (Package) -> types.ModuleType - return cand if isinstance(cand, types.ModuleType) else importlib.import_module(cand) +@functools.singledispatch +def resolve(cand: Optional[Anchor]) -> types.ModuleType: + return cast(types.ModuleType, cand) -def get_package(package): - # type: (Package) -> types.ModuleType - """Take a package name or module object and return the module. +@resolve.register +def _(cand: str) -> types.ModuleType: + return importlib.import_module(cand) - Raise an exception if the resolved module is not a package. + +@resolve.register +def _(cand: None) -> types.ModuleType: + return resolve(_infer_caller().f_globals['__name__']) + + +def _infer_caller(): """ - resolved = resolve(package) - if wrap_spec(resolved).submodule_search_locations is None: - raise TypeError(f'{package!r} is not a package') - return resolved + Walk the stack and find the frame of the first caller not in this module. + """ + + def is_this_file(frame_info): + return frame_info.filename == __file__ + + def is_wrapper(frame_info): + return frame_info.function == 'wrapper' + + not_this_file = itertools.filterfalse(is_this_file, inspect.stack()) + # also exclude 'wrapper' due to singledispatch in the call stack + callers = itertools.filterfalse(is_wrapper, not_this_file) + return next(callers).frame -def from_package(package): +def from_package(package: types.ModuleType): """ Return a Traversable object for the given package. @@ -67,7 +115,14 @@ def from_package(package): @contextlib.contextmanager -def _tempfile(reader, suffix=''): +def _tempfile( + reader, + suffix='', + # gh-93353: Keep a reference to call os.remove() in late Python + # finalization. + *, + _os_remove=os.remove, +): # Not using tempfile.NamedTemporaryFile as it leads to deeper 'try' # blocks due to the need to close the temporary file to work on Windows # properly. @@ -81,18 +136,35 @@ def _tempfile(reader, suffix=''): yield pathlib.Path(raw_path) finally: try: - os.remove(raw_path) + _os_remove(raw_path) except FileNotFoundError: pass +def _temp_file(path): + return _tempfile(path.read_bytes, suffix=path.name) + + +def _is_present_dir(path: Traversable) -> bool: + """ + Some Traversables implement ``is_dir()`` to raise an + exception (i.e. ``FileNotFoundError``) when the + directory doesn't exist. This function wraps that call + to always return a boolean and only return True + if there's a dir and it exists. + """ + with contextlib.suppress(FileNotFoundError): + return path.is_dir() + return False + + @functools.singledispatch def as_file(path): """ Given a Traversable object, return that object as a path on the local file system in a context manager. """ - return _tempfile(path.read_bytes, suffix=path.name) + return _temp_dir(path) if _is_present_dir(path) else _temp_file(path) @as_file.register(pathlib.Path) @@ -102,3 +174,34 @@ def _(path): Degenerate behavior for pathlib.Path objects. """ yield path + + +@contextlib.contextmanager +def _temp_path(dir: tempfile.TemporaryDirectory): + """ + Wrap tempfile.TemporyDirectory to return a pathlib object. + """ + with dir as result: + yield pathlib.Path(result) + + +@contextlib.contextmanager +def _temp_dir(path): + """ + Given a traversable dir, recursively replicate the whole tree + to the file system in a context manager. + """ + assert path.is_dir() + with _temp_path(tempfile.TemporaryDirectory()) as temp_dir: + yield _write_contents(temp_dir, path) + + +def _write_contents(target, source): + child = target.joinpath(source.name) + if source.is_dir(): + child.mkdir() + for item in source.iterdir(): + _write_contents(child, item) + else: + child.open('wb').write(source.read_bytes()) + return child diff --git a/lib/importlib_resources/_legacy.py b/lib/importlib_resources/_legacy.py index 1d5d3f1f..b1ea8105 100644 --- a/lib/importlib_resources/_legacy.py +++ b/lib/importlib_resources/_legacy.py @@ -27,8 +27,7 @@ def deprecated(func): return wrapper -def normalize_path(path): - # type: (Any) -> str +def normalize_path(path: Any) -> str: """Normalize a path by ensuring it is a string. If the resulting string contains path separators, an exception is raised. diff --git a/lib/importlib_resources/abc.py b/lib/importlib_resources/abc.py index 5e38ba61..23b6aeaf 100644 --- a/lib/importlib_resources/abc.py +++ b/lib/importlib_resources/abc.py @@ -1,5 +1,7 @@ import abc import io +import itertools +import pathlib from typing import Any, BinaryIO, Iterable, Iterator, NoReturn, Text, Optional from ._compat import runtime_checkable, Protocol, StrPath @@ -50,6 +52,10 @@ class ResourceReader(metaclass=abc.ABCMeta): raise FileNotFoundError +class TraversalError(Exception): + pass + + @runtime_checkable class Traversable(Protocol): """ @@ -92,7 +98,6 @@ class Traversable(Protocol): Return True if self is a file """ - @abc.abstractmethod def joinpath(self, *descendants: StrPath) -> "Traversable": """ Return Traversable resolved with any descendants applied. @@ -101,6 +106,22 @@ class Traversable(Protocol): and each may contain multiple levels separated by ``posixpath.sep`` (``/``). """ + if not descendants: + return self + names = itertools.chain.from_iterable( + path.parts for path in map(pathlib.PurePosixPath, descendants) + ) + target = next(names) + matches = ( + traversable for traversable in self.iterdir() if traversable.name == target + ) + try: + match = next(matches) + except StopIteration: + raise TraversalError( + "Target not found during traversal.", target, list(names) + ) + return match.joinpath(*names) def __truediv__(self, child: StrPath) -> "Traversable": """ @@ -118,7 +139,8 @@ class Traversable(Protocol): accepted by io.TextIOWrapper. """ - @abc.abstractproperty + @property + @abc.abstractmethod def name(self) -> str: """ The base name of this object without any parent references. diff --git a/lib/importlib_resources/readers.py b/lib/importlib_resources/readers.py index f1190ca4..ab34db74 100644 --- a/lib/importlib_resources/readers.py +++ b/lib/importlib_resources/readers.py @@ -82,15 +82,13 @@ class MultiplexedPath(abc.Traversable): def is_file(self): return False - def joinpath(self, child): - # first try to find child in current paths - for file in self.iterdir(): - if file.name == child: - return file - # if it does not exist, construct it with the first path - return self._paths[0] / child - - __truediv__ = joinpath + def joinpath(self, *descendants): + try: + return super().joinpath(*descendants) + except abc.TraversalError: + # One of the paths did not resolve (a directory does not exist). + # Just return something that will not exist. + return self._paths[0].joinpath(*descendants) def open(self, *args, **kwargs): raise FileNotFoundError(f'{self} is not a file') diff --git a/lib/importlib_resources/simple.py b/lib/importlib_resources/simple.py index d0fbf237..7770c922 100644 --- a/lib/importlib_resources/simple.py +++ b/lib/importlib_resources/simple.py @@ -16,31 +16,28 @@ class SimpleReader(abc.ABC): provider. """ - @abc.abstractproperty - def package(self): - # type: () -> str + @property + @abc.abstractmethod + def package(self) -> str: """ The name of the package for which this reader loads resources. """ @abc.abstractmethod - def children(self): - # type: () -> List['SimpleReader'] + def children(self) -> List['SimpleReader']: """ Obtain an iterable of SimpleReader for available child containers (e.g. directories). """ @abc.abstractmethod - def resources(self): - # type: () -> List[str] + def resources(self) -> List[str]: """ Obtain available named resources for this virtual package. """ @abc.abstractmethod - def open_binary(self, resource): - # type: (str) -> BinaryIO + def open_binary(self, resource: str) -> BinaryIO: """ Obtain a File-like for a named resource. """ @@ -50,13 +47,35 @@ class SimpleReader(abc.ABC): return self.package.split('.')[-1] +class ResourceContainer(Traversable): + """ + Traversable container for a package's resources via its reader. + """ + + def __init__(self, reader: SimpleReader): + self.reader = reader + + def is_dir(self): + return True + + def is_file(self): + return False + + def iterdir(self): + files = (ResourceHandle(self, name) for name in self.reader.resources) + dirs = map(ResourceContainer, self.reader.children()) + return itertools.chain(files, dirs) + + def open(self, *args, **kwargs): + raise IsADirectoryError() + + class ResourceHandle(Traversable): """ Handle to a named resource in a ResourceReader. """ - def __init__(self, parent, name): - # type: (ResourceContainer, str) -> None + def __init__(self, parent: ResourceContainer, name: str): self.parent = parent self.name = name # type: ignore @@ -76,44 +95,6 @@ class ResourceHandle(Traversable): raise RuntimeError("Cannot traverse into a resource") -class ResourceContainer(Traversable): - """ - Traversable container for a package's resources via its reader. - """ - - def __init__(self, reader): - # type: (SimpleReader) -> None - self.reader = reader - - def is_dir(self): - return True - - def is_file(self): - return False - - def iterdir(self): - files = (ResourceHandle(self, name) for name in self.reader.resources) - dirs = map(ResourceContainer, self.reader.children()) - return itertools.chain(files, dirs) - - def open(self, *args, **kwargs): - raise IsADirectoryError() - - @staticmethod - def _flatten(compound_names): - for name in compound_names: - yield from name.split('/') - - def joinpath(self, *descendants): - if not descendants: - return self - names = self._flatten(descendants) - target = next(names) - return next( - traversable for traversable in self.iterdir() if traversable.name == target - ).joinpath(*names) - - class TraversableReader(TraversableResources, SimpleReader): """ A TraversableResources based on SimpleReader. Resource providers diff --git a/lib/importlib_resources/tests/_compat.py b/lib/importlib_resources/tests/_compat.py index 4c99cffd..e7bf06dd 100644 --- a/lib/importlib_resources/tests/_compat.py +++ b/lib/importlib_resources/tests/_compat.py @@ -6,7 +6,20 @@ try: except ImportError: # Python 3.9 and earlier class import_helper: # type: ignore - from test.support import modules_setup, modules_cleanup + from test.support import ( + modules_setup, + modules_cleanup, + DirsOnSysPath, + CleanImport, + ) + + +try: + from test.support import os_helper # type: ignore +except ImportError: + # Python 3.9 compat + class os_helper: # type:ignore + from test.support import temp_dir try: diff --git a/lib/importlib_resources/tests/_path.py b/lib/importlib_resources/tests/_path.py new file mode 100644 index 00000000..c630e4d3 --- /dev/null +++ b/lib/importlib_resources/tests/_path.py @@ -0,0 +1,50 @@ +import pathlib +import functools + + +#### +# from jaraco.path 3.4 + + +def build(spec, prefix=pathlib.Path()): + """ + Build a set of files/directories, as described by the spec. + + Each key represents a pathname, and the value represents + the content. Content may be a nested directory. + + >>> spec = { + ... 'README.txt': "A README file", + ... "foo": { + ... "__init__.py": "", + ... "bar": { + ... "__init__.py": "", + ... }, + ... "baz.py": "# Some code", + ... } + ... } + >>> tmpdir = getfixture('tmpdir') + >>> build(spec, tmpdir) + """ + for name, contents in spec.items(): + create(contents, pathlib.Path(prefix) / name) + + +@functools.singledispatch +def create(content, path): + path.mkdir(exist_ok=True) + build(content, prefix=path) # type: ignore + + +@create.register +def _(content: bytes, path): + path.write_bytes(content) + + +@create.register +def _(content: str, path): + path.write_text(content) + + +# end from jaraco.path +#### diff --git a/lib/importlib_resources/tests/test_files.py b/lib/importlib_resources/tests/test_files.py index 2676b49e..d258fb5f 100644 --- a/lib/importlib_resources/tests/test_files.py +++ b/lib/importlib_resources/tests/test_files.py @@ -1,10 +1,23 @@ import typing +import textwrap import unittest +import warnings +import importlib +import contextlib import importlib_resources as resources -from importlib_resources.abc import Traversable +from ..abc import Traversable from . import data01 from . import util +from . import _path +from ._compat import os_helper, import_helper + + +@contextlib.contextmanager +def suppress_known_deprecation(): + with warnings.catch_warnings(record=True) as ctx: + warnings.simplefilter('default', category=DeprecationWarning) + yield ctx class FilesTests: @@ -25,6 +38,14 @@ class FilesTests: def test_traversable(self): assert isinstance(resources.files(self.data), Traversable) + def test_old_parameter(self): + """ + Files used to take a 'package' parameter. Make sure anyone + passing by name is still supported. + """ + with suppress_known_deprecation(): + resources.files(package=self.data) + class OpenDiskTests(FilesTests, unittest.TestCase): def setUp(self): @@ -42,5 +63,50 @@ class OpenNamespaceTests(FilesTests, unittest.TestCase): self.data = namespacedata01 +class SiteDir: + def setUp(self): + self.fixtures = contextlib.ExitStack() + self.addCleanup(self.fixtures.close) + self.site_dir = self.fixtures.enter_context(os_helper.temp_dir()) + self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir)) + self.fixtures.enter_context(import_helper.CleanImport()) + + +class ModulesFilesTests(SiteDir, unittest.TestCase): + def test_module_resources(self): + """ + A module can have resources found adjacent to the module. + """ + spec = { + 'mod.py': '', + 'res.txt': 'resources are the best', + } + _path.build(spec, self.site_dir) + import mod + + actual = resources.files(mod).joinpath('res.txt').read_text() + assert actual == spec['res.txt'] + + +class ImplicitContextFilesTests(SiteDir, unittest.TestCase): + def test_implicit_files(self): + """ + Without any parameter, files() will infer the location as the caller. + """ + spec = { + 'somepkg': { + '__init__.py': textwrap.dedent( + """ + import importlib_resources as res + val = res.files().joinpath('res.txt').read_text() + """ + ), + 'res.txt': 'resources are the best', + }, + } + _path.build(spec, self.site_dir) + assert importlib.import_module('somepkg').val == 'resources are the best' + + if __name__ == '__main__': unittest.main() diff --git a/lib/importlib_resources/tests/test_reader.py b/lib/importlib_resources/tests/test_reader.py index 16841a50..1c8ebeeb 100644 --- a/lib/importlib_resources/tests/test_reader.py +++ b/lib/importlib_resources/tests/test_reader.py @@ -75,6 +75,11 @@ class MultiplexedPathTest(unittest.TestCase): str(path.joinpath('imaginary'))[len(prefix) + 1 :], os.path.join('namespacedata01', 'imaginary'), ) + self.assertEqual(path.joinpath(), path) + + def test_join_path_compound(self): + path = MultiplexedPath(self.folder) + assert not path.joinpath('imaginary/foo.py').exists() def test_repr(self): self.assertEqual( diff --git a/lib/importlib_resources/tests/test_resource.py b/lib/importlib_resources/tests/test_resource.py index 5affd8b0..82390271 100644 --- a/lib/importlib_resources/tests/test_resource.py +++ b/lib/importlib_resources/tests/test_resource.py @@ -111,6 +111,14 @@ class ResourceFromZipsTest01(util.ZipSetupBase, unittest.TestCase): {'__init__.py', 'binary.file'}, ) + def test_as_file_directory(self): + with resources.as_file(resources.files('ziptestdata')) as data: + assert data.name == 'ziptestdata' + assert data.is_dir() + assert data.joinpath('subdirectory').is_dir() + assert len(list(data.iterdir())) + assert not data.parent.exists() + class ResourceFromZipsTest02(util.ZipSetupBase, unittest.TestCase): ZIP_MODULE = zipdata02 # type: ignore diff --git a/lib/importlib_resources/tests/util.py b/lib/importlib_resources/tests/util.py index c6d83e4b..b596c0ce 100644 --- a/lib/importlib_resources/tests/util.py +++ b/lib/importlib_resources/tests/util.py @@ -3,7 +3,7 @@ import importlib import io import sys import types -from pathlib import Path, PurePath +import pathlib from . import data01 from . import zipdata01 @@ -94,7 +94,7 @@ class CommonTests(metaclass=abc.ABCMeta): def test_pathlib_path(self): # Passing in a pathlib.PurePath object for the path should succeed. - path = PurePath('utf-8.file') + path = pathlib.PurePath('utf-8.file') self.execute(data01, path) def test_importing_module_as_side_effect(self): @@ -102,17 +102,6 @@ class CommonTests(metaclass=abc.ABCMeta): del sys.modules[data01.__name__] self.execute(data01.__name__, 'utf-8.file') - def test_non_package_by_name(self): - # The anchor package cannot be a module. - with self.assertRaises(TypeError): - self.execute(__name__, 'utf-8.file') - - def test_non_package_by_package(self): - # The anchor package cannot be a module. - with self.assertRaises(TypeError): - module = sys.modules['importlib_resources.tests.util'] - self.execute(module, 'utf-8.file') - def test_missing_path(self): # Attempting to open or read or request the path for a # non-existent path should succeed if open_resource @@ -144,7 +133,7 @@ class ZipSetupBase: @classmethod def setUpClass(cls): - data_path = Path(cls.ZIP_MODULE.__file__) + data_path = pathlib.Path(cls.ZIP_MODULE.__file__) data_dir = data_path.parent cls._zip_path = str(data_dir / 'ziptestdata.zip') sys.path.append(cls._zip_path)