Compare commits

...

42 commits

Author SHA1 Message Date
Clinton Hall
bfbf1fb4c1 support for qbittorrent v5.0 (#2001)
* support for qbittorrent v5.0

* Remove py3.8 tests

* Add py 3.13 tests

* Update mediafile.py for Py3.13

* Create filetype.py

* Update link for NZBGet
2024-11-08 07:29:55 +13:00
Clinton Hall
470f611240 Merge branch 'master' into nightly 2024-04-26 12:12:53 +12:00
Clinton Hall
97df874d36 class not added put into debug logging 2024-04-26 12:09:41 +12:00
Matt Park
e9fbbf540c added global ignore flag for bytecode cleanup
Resolves #1867
2024-04-26 12:09:41 +12:00
Clinton Hall
39f5c31486 fix warnings (#1990) 2024-04-26 12:09:41 +12:00
Clinton Hall
cbc2090b0b always return imdbid and dirname 2024-04-26 12:09:41 +12:00
Clinton Hall
cc109bcc0b Add Python 3.12 and fix Radarr handling (#1989)
* Added Python3.12 and future 3.13

* Fix Radarr result handling

* remove py2.7 and py3.7 support
2024-04-26 12:09:41 +12:00
Matt Park
4c512051f7 Update movies.py
Check for an updated dir_name in case IMDB id was appended.
2024-04-26 12:09:41 +12:00
Matt Park
0c564243c2 Update identification.py
Return updated dir_name if needed
2024-04-26 12:09:41 +12:00
Clinton Hall
9ea322111c class not added put into debug logging 2024-03-20 14:23:32 +13:00
Matt Park
e14bc6c733 added global ignore flag for bytecode cleanup
Resolves #1867
2024-03-05 10:59:37 +13:00
Clinton Hall
27df8a4d8e
fix warnings (#1990) 2024-03-01 18:25:19 +13:00
Clinton Hall
b7d6ad8c07
always return imdbid and dirname 2024-02-29 07:01:23 +13:00
Clinton Hall
f98d6fff65
Add Python 3.12 and fix Radarr handling (#1989)
* Added Python3.12 and future 3.13

* Fix Radarr result handling

* remove py2.7 and py3.7 support
2024-02-28 15:47:04 +13:00
Clinton Hall
b802aca7e1
Merge pull request #1982 from MattPark/last-resort-movie-id
Last resort movie identification
2023-12-16 09:32:35 +13:00
Matt Park
836df51d14
Update movies.py
Check for an updated dir_name in case IMDB id was appended.
2023-10-02 15:15:01 -04:00
Matt Park
c6292d5390
Update identification.py
Return updated dir_name if needed
2023-10-02 15:13:36 -04:00
Clinton Hall
558970c212
Merge pull request #1980 from clinton-hall/nightly
Merge Nightly
2023-08-10 21:23:42 +12:00
Clinton Hall
38c628d605
Merge pull request #1979 from clinton-hall/clinton-hall-patch-1
Remove Py2.7 tests
2023-08-10 21:14:47 +12:00
Clinton Hall
029b58b2a6
Remove Py2.7 tests
This is no longer supported in azure pipelines.
2023-08-10 21:09:25 +12:00
Clinton Hall
2885461a12
Merge pull request #1978 from clinton-hall/remove_group
Initialize remove_groups #1973
2023-08-10 21:01:31 +12:00
Clinton Hall
ad73e597e4
Initialize remove_groups #1973
This parameter was not being loaded and therefore was ignored.
2023-08-09 22:50:25 +12:00
clinton-hall
6c2f7c75d4 update to v 12.1.12 2023-07-03 17:41:15 +12:00
clinton-hall
95e22d7af4 Merge branch 'master' into nightly 2023-07-03 17:21:31 +12:00
kandarz
e72c0b9228
Add 'dvb_subtitle' codec to list of ignored codecs when using 'mov_text' (#1974)
Add 'dvb_subtitle' codec to list of ignored codecs when using 'mov_text'. DVB subtitles are bitmap based.
2023-07-03 16:59:24 +12:00
Clinton Hall
c4cc554ea1 update to sonarr api v3 2023-05-22 22:51:28 +12:00
Labrys of Knossos
3078da31af Fix posix_ownership. 2023-05-22 22:51:28 +12:00
Labrys of Knossos
1fdfd128ba Add comments. 2023-05-22 22:51:28 +12:00
Labrys of Knossos
d3100f6178 Add database permissions logging upon failed access. 2023-05-22 22:51:28 +12:00
Clinton Hall
01bb239cdf
Merge pull request #1969 from clinton-hall/Sonarr-apiv3
update to sonarr api v3
2023-05-22 22:43:15 +12:00
Clinton Hall
d0b555c251
update to sonarr api v3 2023-04-18 20:59:28 +12:00
Labrys of Knossos
0c5f7be263
Merge pull request #1955 from clinton-hall/permitted
Fix permissions for posix and add comments
2023-01-01 06:03:08 -05:00
Labrys of Knossos
19d9e27c43 Fix posix_ownership. 2022-12-31 22:26:19 -05:00
Labrys of Knossos
1046c50778
Merge pull request #1954 from clinton-hall/permitted
Add database permissions logging upon failed access.
2022-12-31 18:34:20 -05:00
Labrys of Knossos
2c2d7f24b1 Add comments. 2022-12-31 18:21:33 -05:00
Labrys of Knossos
6e52bb2b33 Add database permissions logging upon failed access. 2022-12-31 17:56:38 -05:00
Clinton Hall
bd9c91ff5e
Merge pull request #1936 from clinton-hall/nightly
update to V12.1.11
2022-12-12 20:24:01 +13:00
Clinton Hall
b8482bed0e
Remove Py3.6 tests.
No longer available for pipeline tests.
2022-12-12 20:18:18 +13:00
Labrys of Knossos
2b6a7add72
Merge pull request #1919 from clinton-hall/hello-friend
Add new Python versions to tests.
2022-12-02 22:32:51 -05:00
Labrys of Knossos
55c1091efa Add new Python versions to classifiers. 2022-12-02 22:25:50 -05:00
Labrys of Knossos
8b409a5716 Add new Python versions to tests. 2022-12-02 22:25:37 -05:00
Clinton Hall
7436ba7716
Merge pull request #1896 from clinton-hall/nightly
Nightly
2022-08-18 16:31:35 +12:00
205 changed files with 20480 additions and 21128 deletions

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 12.1.11 current_version = 12.1.13
commit = True commit = True
tag = False tag = False

2
.github/README.md vendored

@ -2,7 +2,7 @@ nzbToMedia
========== ==========
Provides an [efficient](https://github.com/clinton-hall/nzbToMedia/wiki/Efficient-on-demand-post-processing) way to handle postprocessing for [CouchPotatoServer](https://couchpota.to/ "CouchPotatoServer") and [SickBeard](http://sickbeard.com/ "SickBeard") (and its [forks](https://github.com/clinton-hall/nzbToMedia/wiki/Failed-Download-Handling-%28FDH%29#sick-beard-and-its-forks)) Provides an [efficient](https://github.com/clinton-hall/nzbToMedia/wiki/Efficient-on-demand-post-processing) way to handle postprocessing for [CouchPotatoServer](https://couchpota.to/ "CouchPotatoServer") and [SickBeard](http://sickbeard.com/ "SickBeard") (and its [forks](https://github.com/clinton-hall/nzbToMedia/wiki/Failed-Download-Handling-%28FDH%29#sick-beard-and-its-forks))
when using one of the popular NZB download clients like [SABnzbd](http://sabnzbd.org/ "SABnzbd") and [NZBGet](http://nzbget.sourceforge.net/ "NZBGet") on low performance systems like a NAS. when using one of the popular NZB download clients like [SABnzbd](http://sabnzbd.org/ "SABnzbd") and [NZBGet](https://nzbget.com/ "NZBGet") on low performance systems like a NAS.
This script is based on sabToSickBeard (written by Nic Wolfe and supplied with SickBeard), with the support for NZBGet being added by [thorli](https://github.com/thorli "thorli") and further contributions by [schumi2004](https://github.com/schumi2004 "schumi2004") and [hugbug](https://sourceforge.net/apps/phpbb/nzbget/memberlist.php?mode=viewprofile&u=67 "hugbug"). This script is based on sabToSickBeard (written by Nic Wolfe and supplied with SickBeard), with the support for NZBGet being added by [thorli](https://github.com/thorli "thorli") and further contributions by [schumi2004](https://github.com/schumi2004 "schumi2004") and [hugbug](https://sourceforge.net/apps/phpbb/nzbget/memberlist.php?mode=viewprofile&u=67 "hugbug").
Torrent suport added by [jkaberg](https://github.com/jkaberg "jkaberg") and [berkona](https://github.com/berkona "berkona") Torrent suport added by [jkaberg](https://github.com/jkaberg "jkaberg") and [berkona](https://github.com/berkona "berkona")
Corrupt video checking, auto SickBeard fork determination and a whole lot of code improvement was done by [echel0n](https://github.com/echel0n "echel0n") Corrupt video checking, auto SickBeard fork determination and a whole lot of code improvement was done by [echel0n](https://github.com/echel0n "echel0n")

@ -13,18 +13,16 @@ jobs:
vmImage: 'Ubuntu-latest' vmImage: 'Ubuntu-latest'
strategy: strategy:
matrix: matrix:
Python27:
python.version: '2.7'
Python36:
python.version: '3.6'
Python37:
python.version: '3.7'
Python38:
python.version: '3.8'
Python39: Python39:
python.version: '3.9' python.version: '3.9'
Python310: Python310:
python.version: '3.10' python.version: '3.10'
Python311:
python.version: '3.11'
Python312:
python.version: '3.12'
Python313:
python.version: '3.13'
maxParallel: 3 maxParallel: 3
steps: steps:
@ -70,5 +68,7 @@ jobs:
versionSpec: '3.x' versionSpec: '3.x'
architecture: 'x64' architecture: 'x64'
- script: python setup.py sdist - script: |
python -m pip install setuptools
python setup.py sdist
displayName: 'Build sdist' displayName: 'Build sdist'

@ -116,6 +116,7 @@ def clean_bytecode():
result = git_clean( result = git_clean(
remove_directories=True, remove_directories=True,
force=True, force=True,
ignore_rules=True,
exclude=[ exclude=[
'*.*', # exclude everything '*.*', # exclude everything
'!*.py[co]', # except bytecode '!*.py[co]', # except bytecode

@ -83,7 +83,7 @@ from core.utils import (
wake_up, wake_up,
) )
__version__ = '12.1.11' __version__ = '12.1.13'
# Client Agents # Client Agents
NZB_CLIENTS = ['sabnzbd', 'nzbget', 'manual'] NZB_CLIENTS = ['sabnzbd', 'nzbget', 'manual']
@ -1047,6 +1047,7 @@ def initialize(section=None):
configure_utility_locations() configure_utility_locations()
configure_sections(section) configure_sections(section)
configure_torrent_class() configure_torrent_class()
configure_groups()
__INITIALIZED__ = True __INITIALIZED__ = True

@ -311,7 +311,7 @@ class InitSickBeard(object):
# Create the fork object and pass self (SickBeardInit) to it for all the data, like Config. # Create the fork object and pass self (SickBeardInit) to it for all the data, like Config.
self.fork_obj = mapped_forks[self.fork](self) self.fork_obj = mapped_forks[self.fork](self)
else: else:
logger.info('{section}:{category} Could not create a fork object for {fork}. Probaly class not added yet.'.format( logger.debug('{section}:{category} Could not create a fork object for {fork}. Probaly class not added yet.'.format(
section=self.section, category=self.input_category, fork=self.fork) section=self.section, category=self.input_category, fork=self.fork)
) )

@ -66,7 +66,7 @@ def process(section, dir_name, input_name=None, status=0, client_agent='manual',
else: else:
extract = int(cfg.get('extract', 0)) extract = int(cfg.get('extract', 0))
imdbid = find_imdbid(dir_name, input_name, omdbapikey) imdbid, dir_name = find_imdbid(dir_name, input_name, omdbapikey)
if section == 'CouchPotato': if section == 'CouchPotato':
base_url = '{0}{1}:{2}{3}/api/{4}/'.format(protocol, host, port, web_root, apikey) base_url = '{0}{1}:{2}{3}/api/{4}/'.format(protocol, host, port, web_root, apikey)
if section == 'Radarr': if section == 'Radarr':
@ -260,7 +260,10 @@ def process(section, dir_name, input_name=None, status=0, client_agent='manual',
) )
elif section == 'Radarr': elif section == 'Radarr':
try: try:
scan_id = int(result['id']) if isinstance(result, list):
scan_id = int(result[0]['id'])
else:
scan_id = int(result['id'])
logger.debug('Scan started with id: {0}'.format(scan_id), section) logger.debug('Scan started with id: {0}'.format(scan_id), section)
except Exception as e: except Exception as e:
logger.warning('No scan id was returned due to: {0}'.format(e), section) logger.warning('No scan id was returned due to: {0}'.format(e), section)
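
The Radarr hunk above accepts either a list or a single object as the command result. A minimal sketch of that normalization as a standalone helper follows; `extract_scan_id` is a hypothetical name that does not exist in nzbToMedia, and the sample payloads are invented for illustration.

```python
def extract_scan_id(result):
    """Return the scan id from a command response (hypothetical helper).

    Assumes the response is either a dict with an 'id' key or a
    non-empty list whose first element is such a dict.
    """
    if isinstance(result, list):
        # List response: take the id from the first entry.
        result = result[0]
    return int(result['id'])


print(extract_scan_id({'id': 12}))      # single-object response
print(extract_scan_id([{'id': '34'}]))  # list response
```

An empty list or a missing `id` key would raise here, which matches the diff's approach of catching the exception and logging a warning.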

@ -317,9 +317,9 @@ def process(section, dir_name, input_name=None, failed=False, client_agent='manu
else: else:
url = '{0}{1}:{2}{3}/api/v{4}/{5}/'.format(protocol, host, port, web_root, api_version, apikey) url = '{0}{1}:{2}{3}/api/v{4}/{5}/'.format(protocol, host, port, web_root, api_version, apikey)
elif section == 'NzbDrone': elif section == 'NzbDrone':
url = '{0}{1}:{2}{3}/api/command'.format(protocol, host, port, web_root) url = '{0}{1}:{2}{3}/api/v3/command'.format(protocol, host, port, web_root)
url2 = '{0}{1}:{2}{3}/api/config/downloadClient'.format(protocol, host, port, web_root) url2 = '{0}{1}:{2}{3}/api/v3/config/downloadClient'.format(protocol, host, port, web_root)
headers = {'X-Api-Key': apikey} headers = {'X-Api-Key': apikey, "Content-Type": "application/json"}
# params = {'sortKey': 'series.title', 'page': 1, 'pageSize': 1, 'sortDir': 'asc'} # params = {'sortKey': 'series.title', 'page': 1, 'pageSize': 1, 'sortDir': 'asc'}
if remote_path: if remote_path:
logger.debug('remote_path: {0}'.format(remote_dir(dir_name)), section) logger.debug('remote_path: {0}'.format(remote_dir(dir_name)), section)
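
For reference, the v3 endpoint change in this hunk can be sketched in isolation. The function name `sonarr_v3_endpoints` and the example host, port, and key are assumptions made for illustration; only the URL layout and header names come from the diff.

```python
def sonarr_v3_endpoints(protocol, host, port, web_root, apikey):
    """Build the command URL, download-client URL, and headers (sketch)."""
    base = '{0}{1}:{2}{3}/api/v3'.format(protocol, host, port, web_root)
    url = base + '/command'
    url2 = base + '/config/downloadClient'
    # The diff adds an explicit JSON content type next to the API key.
    headers = {'X-Api-Key': apikey, 'Content-Type': 'application/json'}
    return url, url2, headers


url, url2, headers = sonarr_v3_endpoints(
    'http://', 'localhost', 8989, '/sonarr', 'EXAMPLE_KEY',
)
print(url)
print(url2)
```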

@ -7,14 +7,17 @@ from __future__ import (
unicode_literals, unicode_literals,
) )
import os.path
import re import re
import sqlite3 import sqlite3
import sys
import time import time
from six import text_type, PY2 from six import text_type, PY2
import core import core
from core import logger from core import logger
from core import permissions
if PY2: if PY2:
class Row(sqlite3.Row, object): class Row(sqlite3.Row, object):
@ -60,10 +63,29 @@ def db_filename(filename='nzbtomedia.db', suffix=None):
class DBConnection(object): class DBConnection(object):
def __init__(self, filename='nzbtomedia.db', suffix=None, row_type=None): def __init__(self, filename='nzbtomedia.db', suffix=None, row_type=None):
self.filename = filename self.filename = filename
self.connection = sqlite3.connect(db_filename(filename), 20) path = db_filename(filename)
self.connection.row_factory = Row try:
self.connection = sqlite3.connect(path, 20)
except sqlite3.OperationalError as error:
if os.path.exists(path):
logger.error('Please check permissions on database: {0}'.format(path))
else:
logger.error('Database file does not exist')
logger.error('Please check permissions on directory: {0}'.format(path))
path = os.path.dirname(path)
mode = permissions.mode(path)
owner, group = permissions.ownership(path)
logger.error(
"=== PERMISSIONS ===========================\n"
" Path : {0}\n"
" Mode : {1}\n"
" Owner: {2}\n"
" Group: {3}\n"
"===========================================".format(path, mode, owner, group),
)
else:
self.connection.row_factory = Row
def check_db_version(self): def check_db_version(self):
result = None result = None
@ -256,7 +278,11 @@ class DBSanityCheck(object):
def upgrade_database(connection, schema): def upgrade_database(connection, schema):
logger.log(u'Checking database structure...', logger.MESSAGE) logger.log(u'Checking database structure...', logger.MESSAGE)
_process_upgrade(connection, schema) try:
_process_upgrade(connection, schema)
except Exception as error:
logger.error(error)
sys.exit(1)
def pretty_name(class_name): def pretty_name(class_name):

88
core/permissions.py Normal file

@ -0,0 +1,88 @@
import os
import sys
import logging

log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler())

WINDOWS = sys.platform == 'win32'
POSIX = not WINDOWS

try:
    import pwd
    import grp
except ImportError:
    if POSIX:
        raise

try:
    from win32security import GetNamedSecurityInfo
    from win32security import LookupAccountSid
    from win32security import GROUP_SECURITY_INFORMATION
    from win32security import OWNER_SECURITY_INFORMATION
    from win32security import SE_FILE_OBJECT
except ImportError:
    if WINDOWS:
        raise


def mode(path):
    """Get permissions."""
    stat_result = os.stat(path)  # Get information from path
    permissions_mask = 0o777  # Set mask for permissions info

    # Get only the permissions part of st_mode as an integer
    int_mode = stat_result.st_mode & permissions_mask
    oct_mode = oct(int_mode)  # Convert to octal representation

    return oct_mode[2:]  # Return mode but strip octal prefix


def nt_ownership(path):
    """Get the owner and group for a file or directory."""

    def fully_qualified_name(sid):
        """Return a fully qualified account name."""
        # Look up the account information for the given SID
        # https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-lookupaccountsida
        name, domain, _acct_type = LookupAccountSid(None, sid)
        # Return account information formatted as DOMAIN\ACCOUNT_NAME
        return '{}\\{}'.format(domain, name)

    # Get the Windows security descriptor for the path
    # https://learn.microsoft.com/en-us/windows/win32/api/aclapi/nf-aclapi-getnamedsecurityinfoa
    security_descriptor = GetNamedSecurityInfo(
        path,  # Name of the item to query
        SE_FILE_OBJECT,  # Type of item to query (file or directory)
        # Add OWNER and GROUP security information to result
        OWNER_SECURITY_INFORMATION | GROUP_SECURITY_INFORMATION,
    )

    # Get the Security Identifier for the owner and group from the security descriptor
    # https://learn.microsoft.com/en-us/windows/win32/api/securitybaseapi/nf-securitybaseapi-getsecuritydescriptorowner
    # https://learn.microsoft.com/en-us/windows/win32/api/securitybaseapi/nf-securitybaseapi-getsecuritydescriptorgroup
    owner_sid = security_descriptor.GetSecurityDescriptorOwner()
    group_sid = security_descriptor.GetSecurityDescriptorGroup()

    # Get the fully qualified account name (e.g. DOMAIN\ACCOUNT_NAME)
    owner = fully_qualified_name(owner_sid)
    group = fully_qualified_name(group_sid)

    return owner, group


def posix_ownership(path):
    """Get the owner and group for a file or directory."""
    # Get path information
    stat_result = os.stat(path)

    # Get account name from path stat result
    owner = pwd.getpwuid(stat_result.st_uid).pw_name
    group = grp.getgrgid(stat_result.st_gid).gr_name

    return owner, group


# Select the ownership function appropriate for the platform
if WINDOWS:
    ownership = nt_ownership
else:
    ownership = posix_ownership
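
As an illustration of what `mode` reports, here is a small self-contained sketch that can be run on a POSIX system. It is not the code from `core/permissions.py`: the name `mode_string` is hypothetical, and it formats the bits with `format(..., 'o')` instead of slicing the output of `oct()`, an alternative that does not depend on the length of the octal prefix.

```python
import os
import tempfile


def mode_string(path):
    """Return the permission bits of path as an octal string, e.g. '640'."""
    permission_bits = os.stat(path).st_mode & 0o777
    # 'o' formats the integer in octal without any prefix.
    return format(permission_bits, 'o')


# Usage sketch: create a scratch file, set its mode, and read it back.
fd, scratch_path = tempfile.mkstemp()
os.close(fd)
os.chmod(scratch_path, 0o640)
print(mode_string(scratch_path))
os.remove(scratch_path)
```

On Windows `os.chmod` only toggles the read-only flag, so the value read back there will generally differ from the one requested.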

@ -479,7 +479,7 @@ def build_commands(file, new_dir, movie_name, bitbucket):
burnt = 1 burnt = 1
if not core.ALLOWSUBS: if not core.ALLOWSUBS:
break break
if sub['codec_name'] in ['dvd_subtitle', 'VobSub'] and core.SCODEC == 'mov_text': # We can't convert these. if sub['codec_name'] in ['dvd_subtitle', 'dvb_subtitle', 'VobSub'] and core.SCODEC == 'mov_text': # We can't convert these.
continue continue
map_cmd.extend(['-map', '0:{index}'.format(index=sub['index'])]) map_cmd.extend(['-map', '0:{index}'.format(index=sub['index'])])
s_mapped.extend([sub['index']]) s_mapped.extend([sub['index']])
@ -490,7 +490,7 @@ def build_commands(file, new_dir, movie_name, bitbucket):
break break
if sub['index'] in s_mapped: if sub['index'] in s_mapped:
continue continue
if sub['codec_name'] in ['dvd_subtitle', 'VobSub'] and core.SCODEC == 'mov_text': # We can't convert these. if sub['codec_name'] in ['dvd_subtitle', 'dvb_subtitle', 'VobSub'] and core.SCODEC == 'mov_text': # We can't convert these.
continue continue
map_cmd.extend(['-map', '0:{index}'.format(index=sub['index'])]) map_cmd.extend(['-map', '0:{index}'.format(index=sub['index'])])
s_mapped.extend([sub['index']]) s_mapped.extend([sub['index']])
@ -516,7 +516,7 @@ def build_commands(file, new_dir, movie_name, bitbucket):
continue continue
if core.SCODEC == 'mov_text': if core.SCODEC == 'mov_text':
subcode = [stream['codec_name'] for stream in sub_details['streams']] subcode = [stream['codec_name'] for stream in sub_details['streams']]
if set(subcode).intersection(['dvd_subtitle', 'VobSub']): # We can't convert these. if set(subcode).intersection(['dvd_subtitle', 'dvb_subtitle', 'VobSub']): # We can't convert these.
continue continue
command.extend(['-i', subfile]) command.extend(['-i', subfile])
lan = os.path.splitext(os.path.splitext(subfile)[0])[1][1:].split('-')[0] lan = os.path.splitext(os.path.splitext(subfile)[0])[1][1:].split('-')[0]
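
The three hunks above repeat the same codec check. A minimal sketch of that rule as one reusable filter is shown below; `BITMAP_SUBTITLE_CODECS` and `convertible_streams` are hypothetical names, and the stream dicts only mimic the `codec_name` and `index` keys used in the diff.

```python
# Bitmap-based subtitle codecs that cannot be converted to mov_text (from the diff).
BITMAP_SUBTITLE_CODECS = {'dvd_subtitle', 'dvb_subtitle', 'VobSub'}


def convertible_streams(streams, target_codec):
    """Drop bitmap subtitle streams when the target codec is mov_text."""
    if target_codec != 'mov_text':
        return list(streams)
    return [
        stream for stream in streams
        if stream['codec_name'] not in BITMAP_SUBTITLE_CODECS
    ]


streams = [
    {'index': 2, 'codec_name': 'subrip'},
    {'index': 3, 'codec_name': 'dvb_subtitle'},
]
print(convertible_streams(streams, 'mov_text'))
```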

@ -11,7 +11,7 @@ import shutil
import stat import stat
import time import time
import beets.mediafile import mediafile as mediafiletool
import guessit import guessit
from six import text_type from six import text_type
@ -28,7 +28,7 @@ def move_file(mediafile, path, link):
file_ext = os.path.splitext(mediafile)[1] file_ext = os.path.splitext(mediafile)[1]
try: try:
if file_ext in core.AUDIO_CONTAINER: if file_ext in core.AUDIO_CONTAINER:
f = beets.mediafile.MediaFile(mediafile) f = mediafiletool.MediaFile(mediafile)
# get artist and album info # get artist and album info
artist = f.artist artist = f.artist

@ -27,14 +27,14 @@ def find_imdbid(dir_name, input_name, omdb_api_key):
if m: if m:
imdbid = m.group(1) imdbid = m.group(1)
logger.info('Found imdbID [{0}]'.format(imdbid)) logger.info('Found imdbID [{0}]'.format(imdbid))
return imdbid return imdbid, dir_name
if os.path.isdir(dir_name): if os.path.isdir(dir_name):
for file in os.listdir(text_type(dir_name)): for file in os.listdir(text_type(dir_name)):
m = re.search(r'\b(tt\d{7,8})\b', file) m = re.search(r'\b(tt\d{7,8})\b', file)
if m: if m:
imdbid = m.group(1) imdbid = m.group(1)
logger.info('Found imdbID [{0}] via file name'.format(imdbid)) logger.info('Found imdbID [{0}] via file name'.format(imdbid))
return imdbid return imdbid, dir_name
if 'NZBPR__DNZB_MOREINFO' in os.environ: if 'NZBPR__DNZB_MOREINFO' in os.environ:
dnzb_more_info = os.environ.get('NZBPR__DNZB_MOREINFO', '') dnzb_more_info = os.environ.get('NZBPR__DNZB_MOREINFO', '')
if dnzb_more_info != '': if dnzb_more_info != '':
@ -43,7 +43,7 @@ def find_imdbid(dir_name, input_name, omdb_api_key):
if m: if m:
imdbid = m.group(1) imdbid = m.group(1)
logger.info('Found imdbID [{0}] from DNZB-MoreInfo'.format(imdbid)) logger.info('Found imdbID [{0}] from DNZB-MoreInfo'.format(imdbid))
return imdbid return imdbid, dir_name
logger.info('Searching IMDB for imdbID ...') logger.info('Searching IMDB for imdbID ...')
try: try:
guess = guessit.guessit(input_name) guess = guessit.guessit(input_name)
@ -64,7 +64,7 @@ def find_imdbid(dir_name, input_name, omdb_api_key):
if not omdb_api_key: if not omdb_api_key:
logger.info('Unable to determine imdbID: No api key provided for omdbapi.com.') logger.info('Unable to determine imdbID: No api key provided for omdbapi.com.')
return return imdbid, dir_name
logger.debug('Opening URL: {0}'.format(url)) logger.debug('Opening URL: {0}'.format(url))
@ -73,7 +73,7 @@ def find_imdbid(dir_name, input_name, omdb_api_key):
verify=False, timeout=(60, 300)) verify=False, timeout=(60, 300))
except requests.ConnectionError: except requests.ConnectionError:
logger.error('Unable to open URL {0}'.format(url)) logger.error('Unable to open URL {0}'.format(url))
return return imdbid, dir_name
try: try:
results = r.json() results = r.json()
@ -87,10 +87,12 @@ def find_imdbid(dir_name, input_name, omdb_api_key):
if imdbid: if imdbid:
logger.info('Found imdbID [{0}]'.format(imdbid)) logger.info('Found imdbID [{0}]'.format(imdbid))
return imdbid new_dir_name = '{}.cp({})'.format(dir_name, imdbid)
os.rename(dir_name, new_dir_name)
return imdbid, new_dir_name
logger.warning('Unable to find a imdbID for {0}'.format(input_name)) logger.warning('Unable to find a imdbID for {0}'.format(input_name))
return imdbid return imdbid, dir_name
def category_search(input_directory, input_name, input_category, root, categories): def category_search(input_directory, input_name, input_category, root, categories):
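
With this change `find_imdbid` returns a pair, so every caller has to unpack two values. The renaming step on its own can be sketched as follows; `append_imdbid` is a hypothetical helper written for this example, and the IMDb id is only a sample value.

```python
import os
import tempfile


def append_imdbid(dir_name, imdbid):
    """Rename dir_name to carry a '.cp(<imdbid>)' suffix and return both values."""
    new_dir_name = '{}.cp({})'.format(dir_name, imdbid)
    os.rename(dir_name, new_dir_name)
    return imdbid, new_dir_name


# Usage sketch against a scratch directory.
scratch_dir = tempfile.mkdtemp()
imdbid, dir_name = append_imdbid(scratch_dir, 'tt0111161')
print(dir_name)
os.rmdir(dir_name)
```

Because the directory is renamed on disk, any code that keeps using the old `dir_name` afterwards would point at a path that no longer exists, which is why the updated name is returned alongside the id.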

4
eol.py

@ -28,6 +28,8 @@ def date(string, fmt='%Y-%m-%d'):
# https://devguide.python.org/ # https://devguide.python.org/
# https://devguide.python.org/devcycle/#devcycle # https://devguide.python.org/devcycle/#devcycle
PYTHON_EOL = { PYTHON_EOL = {
(3, 13): date('2029-10-1'),
(3, 12): date('2028-10-1'),
(3, 11): date('2027-10-1'), (3, 11): date('2027-10-1'),
(3, 10): date('2026-10-01'), (3, 10): date('2026-10-01'),
(3, 9): date('2025-10-05'), (3, 9): date('2025-10-05'),
@ -99,7 +101,7 @@ def check(version=None, grace_period=0):
:return: None :return: None
""" """
try: try:
warn_for_status(version, grace_period) raise_for_status(version, grace_period)
except LifetimeError as error: except LifetimeError as error:
print('Please use a newer version of Python.') print('Please use a newer version of Python.')
print_statuses() print_statuses()
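
To show how an end-of-life table like `PYTHON_EOL` can be queried, here is a simplified sketch. It is not the logic of `eol.py`: the helper `is_end_of_life` is hypothetical, it ignores the grace period, and treating the listed date itself as already end-of-life is an assumption. The dates are copied from the hunk above.

```python
from datetime import date

# End-of-life dates as listed in the diff.
PYTHON_EOL = {
    (3, 13): date(2029, 10, 1),
    (3, 12): date(2028, 10, 1),
    (3, 11): date(2027, 10, 1),
    (3, 10): date(2026, 10, 1),
    (3, 9): date(2025, 10, 5),
}


def is_end_of_life(version, today):
    """Return True if `version` (major, minor) has reached its listed date."""
    return today >= PYTHON_EOL[version]


print(is_end_of_life((3, 9), date(2025, 10, 6)))
print(is_end_of_life((3, 13), date(2025, 1, 1)))
```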

1
libs/common/__init__.py Normal file

@ -0,0 +1 @@

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from .filetype import * # noqa
from .helpers import * # noqa
from .match import * # noqa
# Current package semver version
__version__ = version = '1.2.0'

@ -0,0 +1,41 @@
import glob
from itertools import chain
from os.path import isfile

import filetype


def guess(path):
    kind = filetype.guess(path)
    if kind is None:
        print('{}: File type determination failure.'.format(path))
    else:
        print('{}: {} ({})'.format(path, kind.extension, kind.mime))


def main():
    import argparse

    parser = argparse.ArgumentParser(
        prog='filetype', description='Determine type of FILEs.'
    )
    parser.add_argument(
        'file', nargs='+',
        help='files, wildcard is supported'
    )
    parser.add_argument(
        '-v', '--version', action='version',
        version=f'%(prog)s {filetype.version}',
        help='output version information and exit'
    )

    args = parser.parse_args()
    items = chain.from_iterable(map(glob.iglob, args.file))
    files = filter(isfile, items)

    for file in files:
        guess(file)


if __name__ == '__main__':
    main()

@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from .match import match
from .types import TYPES, Type
# Expose supported matchers types
types = TYPES
def guess(obj):
"""
Infers the type of the given input.
Function is overloaded to accept multiple types in input
and perform the needed type inference based on it.
Args:
obj: path to file, bytes or bytearray.
Returns:
The matched type instance. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
return match(obj) if obj else None
def guess_mime(obj):
"""
Infers the file type of the given input
and returns its MIME type.
Args:
obj: path to file, bytes or bytearray.
Returns:
The matched MIME type as string. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
kind = guess(obj)
return kind.mime if kind else kind
def guess_extension(obj):
"""
Infers the file type of the given input
and returns its RFC file extension.
Args:
obj: path to file, bytes or bytearray.
Returns:
The matched file extension as string. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
kind = guess(obj)
return kind.extension if kind else kind
def get_type(mime=None, ext=None):
"""
Returns the file type instance searching by
MIME type or file extension.
Args:
ext: file extension string. E.g: jpg, png, mp4, mp3
mime: MIME string. E.g: image/jpeg, video/mpeg
Returns:
The matched file type instance. Otherwise None.
"""
for kind in types:
if kind.extension == ext or kind.mime == mime:
return kind
return None
def add_type(instance):
"""
Adds a new type matcher instance to the supported types.
Args:
instance: Type inherited instance.
Returns:
None
"""
if not isinstance(instance, Type):
raise TypeError('instance must inherit from filetype.types.Type')
types.insert(0, instance)

@ -0,0 +1,140 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from .types import TYPES
from .match import (
image_match, font_match, document_match,
video_match, audio_match, archive_match
)
def is_extension_supported(ext):
"""
Checks if the given extension string is
one of the supported by the file matchers.
Args:
ext (str): file extension string. E.g: jpg, png, mp4, mp3
Returns:
True if the file extension is supported.
Otherwise False.
"""
for kind in TYPES:
if kind.extension == ext:
return True
return False
def is_mime_supported(mime):
"""
Checks if the given MIME type string is
one of the supported by the file matchers.
Args:
mime (str): MIME string. E.g: image/jpeg, video/mpeg
Returns:
True if the MIME type is supported.
Otherwise False.
"""
for kind in TYPES:
if kind.mime == mime:
return True
return False
def is_image(obj):
"""
Checks if a given input is a supported type image.
Args:
obj: path to file, bytes or bytearray.
Returns:
True if obj is a valid image. Otherwise False.
Raises:
TypeError: if obj is not a supported type.
"""
return image_match(obj) is not None
def is_archive(obj):
"""
Checks if a given input is a supported type archive.
Args:
obj: path to file, bytes or bytearray.
Returns:
True if obj is a valid archive. Otherwise False.
Raises:
TypeError: if obj is not a supported type.
"""
return archive_match(obj) is not None
def is_audio(obj):
"""
Checks if a given input is a supported type audio.
Args:
obj: path to file, bytes or bytearray.
Returns:
True if obj is a valid audio. Otherwise False.
Raises:
TypeError: if obj is not a supported type.
"""
return audio_match(obj) is not None
def is_video(obj):
"""
Checks if a given input is a supported type video.
Args:
obj: path to file, bytes or bytearray.
Returns:
True if obj is a valid video. Otherwise False.
Raises:
TypeError: if obj is not a supported type.
"""
return video_match(obj) is not None
def is_font(obj):
"""
Checks if a given input is a supported type font.
Args:
obj: path to file, bytes or bytearray.
Returns:
True if obj is a valid font. Otherwise False.
Raises:
TypeError: if obj is not a supported type.
"""
return font_match(obj) is not None
def is_document(obj):
"""
Checks if a given input is a supported type document.
Args:
obj: path to file, bytes or bytearray.
Returns:
True if obj is a valid document. Otherwise False.
Raises:
TypeError: if obj is not a supported type.
"""
return document_match(obj) is not None

@ -0,0 +1,155 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from .types import ARCHIVE as archive_matchers
from .types import AUDIO as audio_matchers
from .types import APPLICATION as application_matchers
from .types import DOCUMENT as document_matchers
from .types import FONT as font_matchers
from .types import IMAGE as image_matchers
from .types import VIDEO as video_matchers
from .types import TYPES
from .utils import get_bytes
def match(obj, matchers=TYPES):
"""
Matches the given input against the available
file type matchers.
Args:
obj: path to file, bytes or bytearray.
Returns:
Type instance if type matches. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
buf = get_bytes(obj)
for matcher in matchers:
if matcher.match(buf):
return matcher
return None
def image_match(obj):
"""
Matches the given input against the available
image type matchers.
Args:
obj: path to file, bytes or bytearray.
Returns:
Type instance if matches. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
return match(obj, image_matchers)
def font_match(obj):
"""
Matches the given input against the available
font type matchers.
Args:
obj: path to file, bytes or bytearray.
Returns:
Type instance if matches. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
return match(obj, font_matchers)
def video_match(obj):
"""
Matches the given input against the available
video type matchers.
Args:
obj: path to file, bytes or bytearray.
Returns:
Type instance if matches. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
return match(obj, video_matchers)
def audio_match(obj):
"""
Matches the given input against the available
autio type matchers.
Args:
obj: path to file, bytes or bytearray.
Returns:
Type instance if matches. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
return match(obj, audio_matchers)
def archive_match(obj):
"""
Matches the given input against the available
archive type matchers.
Args:
obj: path to file, bytes or bytearray.
Returns:
Type instance if matches. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
return match(obj, archive_matchers)
def application_match(obj):
"""
Matches the given input against the available
application type matchers.
Args:
obj: path to file, bytes or bytearray.
Returns:
Type instance if matches. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
return match(obj, application_matchers)
def document_match(obj):
"""
Matches the given input against the available
document type matchers.
Args:
obj: path to file, bytes or bytearray.
Returns:
Type instance if matches. Otherwise None.
Raises:
TypeError: if obj is not a supported type.
"""
return match(obj, document_matchers)

@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from . import archive
from . import audio
from . import application
from . import document
from . import font
from . import image
from . import video
from .base import Type # noqa
# Supported image types
IMAGE = (
image.Dwg(),
image.Xcf(),
image.Jpeg(),
image.Jpx(),
image.Jxl(),
image.Apng(),
image.Png(),
image.Gif(),
image.Webp(),
image.Tiff(),
image.Cr2(),
image.Bmp(),
image.Jxr(),
image.Psd(),
image.Ico(),
image.Heic(),
image.Dcm(),
image.Avif(),
image.Qoi(),
image.Dds(),
)
# Supported video types
VIDEO = (
video.M3gp(),
video.Mp4(),
video.M4v(),
video.Mkv(),
video.Mov(),
video.Avi(),
video.Wmv(),
video.Mpeg(),
video.Webm(),
video.Flv(),
)
# Supported audio types
AUDIO = (
audio.Aac(),
audio.Midi(),
audio.Mp3(),
audio.M4a(),
audio.Ogg(),
audio.Flac(),
audio.Wav(),
audio.Amr(),
audio.Aiff(),
)
# Supported font types
FONT = (font.Woff(), font.Woff2(), font.Ttf(), font.Otf())
# Supported archive container types
ARCHIVE = (
archive.Br(),
archive.Rpm(),
archive.Dcm(),
archive.Epub(),
archive.Zip(),
archive.Tar(),
archive.Rar(),
archive.Gz(),
archive.Bz2(),
archive.SevenZ(),
archive.Pdf(),
archive.Exe(),
archive.Swf(),
archive.Rtf(),
archive.Nes(),
archive.Crx(),
archive.Cab(),
archive.Eot(),
archive.Ps(),
archive.Xz(),
archive.Sqlite(),
archive.Deb(),
archive.Ar(),
archive.Z(),
archive.Lzop(),
archive.Lz(),
archive.Elf(),
archive.Lz4(),
archive.Zstd(),
)
# Supported application types
APPLICATION = (
application.Wasm(),
)
# Supported document types
DOCUMENT = (
document.Doc(),
document.Docx(),
document.Odt(),
document.Xls(),
document.Xlsx(),
document.Ods(),
document.Ppt(),
document.Pptx(),
document.Odp(),
)
# Expose supported type matchers
TYPES = list(IMAGE + AUDIO + VIDEO + FONT + DOCUMENT + ARCHIVE + APPLICATION)
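
The order of these tuples matters when matchers are tried in sequence: `Epub` is listed before `Zip`, and `DOCUMENT` is concatenated before `ARCHIVE`, because ZIP-based formats also satisfy the generic ZIP signature. A small sketch of that effect, using simplified stand-in checks (`looks_like_zip`, `looks_like_epub`, and `first_label` are hypothetical helpers, not part of this change):

```python
def looks_like_zip(buf):
    # First four bytes of a ZIP local file header.
    return buf[:4] == b'PK\x03\x04'


def looks_like_epub(buf):
    # An EPUB is a ZIP whose first entry is named "mimetype" (offset 30)
    # and whose stored content begins with "application/epub+zip".
    return (looks_like_zip(buf) and
            buf[30:58] == b'mimetypeapplication/epub+zip')


def first_label(buf, ordered_checks):
    """Return the label of the first check that accepts buf."""
    for label, check in ordered_checks:
        if check(buf):
            return label
    return None


# 26 filler bytes stand in for the rest of the 30-byte local file header.
epub_head = b'PK\x03\x04' + b'\x00' * 26 + b'mimetypeapplication/epub+zip'

specific_first = [('epub', looks_like_epub), ('zip', looks_like_zip)]
generic_first = [('zip', looks_like_zip), ('epub', looks_like_epub)]

print(first_label(epub_head, specific_first))  # epub
print(first_label(epub_head, generic_first))   # zip
```

With the generic check first, the more specific type is never reached, which is why specific matchers need to come earlier in the list.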

@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import

from .base import Type


class Wasm(Type):
    """Implements the Wasm application type matcher."""

    MIME = 'application/wasm'
    EXTENSION = 'wasm'

    def __init__(self):
        super(Wasm, self).__init__(
            mime=Wasm.MIME,
            extension=Wasm.EXTENSION
        )

    def match(self, buf):
        return buf[:8] == bytearray([0x00, 0x61, 0x73, 0x6d,
                                     0x01, 0x00, 0x00, 0x00])
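
The eight bytes compared by the Wasm matcher are the `\0asm` magic followed by the binary format version, stored as a 4-byte little-endian integer with value 1. A minimal sketch that separates the two fields; `wasm_version` is a hypothetical helper written for this illustration:

```python
import struct

WASM_MAGIC = b'\x00asm'


def wasm_version(buf):
    """Return the module version if buf starts with a Wasm header, else None."""
    if len(buf) < 8 or buf[:4] != WASM_MAGIC:
        return None
    # The version field is a 4-byte little-endian unsigned integer.
    return struct.unpack('<I', bytes(buf[4:8]))[0]


header = WASM_MAGIC + struct.pack('<I', 1)
print(wasm_version(header))      # 1
print(wasm_version(b'\x00as'))   # None: too short to hold a header
```

Because the matcher compares all eight bytes at once, it accepts only version 1 modules; a helper like this one would be needed to recognise other version values.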

@ -0,0 +1,688 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import struct
from .base import Type
class Epub(Type):
"""
Implements the EPUB archive type matcher.
"""
MIME = 'application/epub+zip'
EXTENSION = 'epub'
def __init__(self):
super(Epub, self).__init__(
mime=Epub.MIME,
extension=Epub.EXTENSION
)
def match(self, buf):
return (len(buf) > 57 and
buf[0] == 0x50 and buf[1] == 0x4B and
buf[2] == 0x3 and buf[3] == 0x4 and
buf[30] == 0x6D and buf[31] == 0x69 and
buf[32] == 0x6D and buf[33] == 0x65 and
buf[34] == 0x74 and buf[35] == 0x79 and
buf[36] == 0x70 and buf[37] == 0x65 and
buf[38] == 0x61 and buf[39] == 0x70 and
buf[40] == 0x70 and buf[41] == 0x6C and
buf[42] == 0x69 and buf[43] == 0x63 and
buf[44] == 0x61 and buf[45] == 0x74 and
buf[46] == 0x69 and buf[47] == 0x6F and
buf[48] == 0x6E and buf[49] == 0x2F and
buf[50] == 0x65 and buf[51] == 0x70 and
buf[52] == 0x75 and buf[53] == 0x62 and
buf[54] == 0x2B and buf[55] == 0x7A and
buf[56] == 0x69 and buf[57] == 0x70)
class Zip(Type):
"""
Implements the Zip archive type matcher.
"""
MIME = 'application/zip'
EXTENSION = 'zip'
def __init__(self):
super(Zip, self).__init__(
mime=Zip.MIME,
extension=Zip.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x50 and buf[1] == 0x4B and
(buf[2] == 0x3 or buf[2] == 0x5 or
buf[2] == 0x7) and
(buf[3] == 0x4 or buf[3] == 0x6 or
buf[3] == 0x8))
class Tar(Type):
"""
Implements the Tar archive type matcher.
"""
MIME = 'application/x-tar'
EXTENSION = 'tar'
def __init__(self):
super(Tar, self).__init__(
mime=Tar.MIME,
extension=Tar.EXTENSION
)
def match(self, buf):
return (len(buf) > 261 and
buf[257] == 0x75 and
buf[258] == 0x73 and
buf[259] == 0x74 and
buf[260] == 0x61 and
buf[261] == 0x72)
class Rar(Type):
"""
Implements the RAR archive type matcher.
"""
MIME = 'application/x-rar-compressed'
EXTENSION = 'rar'
def __init__(self):
super(Rar, self).__init__(
mime=Rar.MIME,
extension=Rar.EXTENSION
)
def match(self, buf):
return (len(buf) > 6 and
buf[0] == 0x52 and
buf[1] == 0x61 and
buf[2] == 0x72 and
buf[3] == 0x21 and
buf[4] == 0x1A and
buf[5] == 0x7 and
(buf[6] == 0x0 or
buf[6] == 0x1))
class Gz(Type):
"""
Implements the GZ archive type matcher.
"""
MIME = 'application/gzip'
EXTENSION = 'gz'
def __init__(self):
super(Gz, self).__init__(
mime=Gz.MIME,
extension=Gz.EXTENSION
)
def match(self, buf):
return (len(buf) > 2 and
buf[0] == 0x1F and
buf[1] == 0x8B and
buf[2] == 0x8)
class Bz2(Type):
"""
Implements the BZ2 archive type matcher.
"""
MIME = 'application/x-bzip2'
EXTENSION = 'bz2'
def __init__(self):
super(Bz2, self).__init__(
mime=Bz2.MIME,
extension=Bz2.EXTENSION
)
def match(self, buf):
return (len(buf) > 2 and
buf[0] == 0x42 and
buf[1] == 0x5A and
buf[2] == 0x68)
class SevenZ(Type):
"""
Implements the SevenZ (7z) archive type matcher.
"""
MIME = 'application/x-7z-compressed'
EXTENSION = '7z'
def __init__(self):
super(SevenZ, self).__init__(
mime=SevenZ.MIME,
extension=SevenZ.EXTENSION
)
def match(self, buf):
return (len(buf) > 5 and
buf[0] == 0x37 and
buf[1] == 0x7A and
buf[2] == 0xBC and
buf[3] == 0xAF and
buf[4] == 0x27 and
buf[5] == 0x1C)
class Pdf(Type):
"""
Implements the PDF archive type matcher.
"""
MIME = 'application/pdf'
EXTENSION = 'pdf'
def __init__(self):
super(Pdf, self).__init__(
mime=Pdf.MIME,
extension=Pdf.EXTENSION
)
def match(self, buf):
# Detect BOM and skip first 3 bytes
if (len(buf) > 3 and
buf[0] == 0xEF and
buf[1] == 0xBB and
buf[2] == 0xBF): # noqa E129
buf = buf[3:]
return (len(buf) > 3 and
buf[0] == 0x25 and
buf[1] == 0x50 and
buf[2] == 0x44 and
buf[3] == 0x46)
class Exe(Type):
"""
Implements the EXE archive type matcher.
"""
MIME = 'application/x-msdownload'
EXTENSION = 'exe'
def __init__(self):
super(Exe, self).__init__(
mime=Exe.MIME,
extension=Exe.EXTENSION
)
def match(self, buf):
return (len(buf) > 1 and
buf[0] == 0x4D and
buf[1] == 0x5A)
class Swf(Type):
"""
Implements the SWF archive type matcher.
"""
MIME = 'application/x-shockwave-flash'
EXTENSION = 'swf'
def __init__(self):
super(Swf, self).__init__(
mime=Swf.MIME,
extension=Swf.EXTENSION
)
def match(self, buf):
return (len(buf) > 2 and
(buf[0] == 0x46 or
buf[0] == 0x43 or
buf[0] == 0x5A) and
buf[1] == 0x57 and
buf[2] == 0x53)
class Rtf(Type):
"""
Implements the RTF archive type matcher.
"""
MIME = 'application/rtf'
EXTENSION = 'rtf'
def __init__(self):
super(Rtf, self).__init__(
mime=Rtf.MIME,
extension=Rtf.EXTENSION
)
def match(self, buf):
return (len(buf) > 4 and
buf[0] == 0x7B and
buf[1] == 0x5C and
buf[2] == 0x72 and
buf[3] == 0x74 and
buf[4] == 0x66)
class Nes(Type):
"""
Implements the NES archive type matcher.
"""
MIME = 'application/x-nintendo-nes-rom'
EXTENSION = 'nes'
def __init__(self):
super(Nes, self).__init__(
mime=Nes.MIME,
extension=Nes.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x4E and
buf[1] == 0x45 and
buf[2] == 0x53 and
buf[3] == 0x1A)
class Crx(Type):
"""
Implements the CRX archive type matcher.
"""
MIME = 'application/x-google-chrome-extension'
EXTENSION = 'crx'
def __init__(self):
super(Crx, self).__init__(
mime=Crx.MIME,
extension=Crx.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x43 and
buf[1] == 0x72 and
buf[2] == 0x32 and
buf[3] == 0x34)
class Cab(Type):
"""
Implements the CAB archive type matcher.
"""
MIME = 'application/vnd.ms-cab-compressed'
EXTENSION = 'cab'
def __init__(self):
super(Cab, self).__init__(
mime=Cab.MIME,
extension=Cab.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
((buf[0] == 0x4D and
buf[1] == 0x53 and
buf[2] == 0x43 and
buf[3] == 0x46) or
(buf[0] == 0x49 and
buf[1] == 0x53 and
buf[2] == 0x63 and
buf[3] == 0x28)))
class Eot(Type):
"""
Implements the EOT archive type matcher.
"""
MIME = 'application/octet-stream'
EXTENSION = 'eot'
def __init__(self):
super(Eot, self).__init__(
mime=Eot.MIME,
extension=Eot.EXTENSION
)
def match(self, buf):
return (len(buf) > 35 and
buf[34] == 0x4C and
buf[35] == 0x50 and
((buf[8] == 0x02 and
buf[9] == 0x00 and
buf[10] == 0x01) or
(buf[8] == 0x01 and
buf[9] == 0x00 and
buf[10] == 0x00) or
(buf[8] == 0x02 and
buf[9] == 0x00 and
buf[10] == 0x02)))
class Ps(Type):
"""
Implements the PS archive type matcher.
"""
MIME = 'application/postscript'
EXTENSION = 'ps'
def __init__(self):
super(Ps, self).__init__(
mime=Ps.MIME,
extension=Ps.EXTENSION
)
def match(self, buf):
return (len(buf) > 1 and
buf[0] == 0x25 and
buf[1] == 0x21)
class Xz(Type):
"""
Implements the XZ archive type matcher.
"""
MIME = 'application/x-xz'
EXTENSION = 'xz'
def __init__(self):
super(Xz, self).__init__(
mime=Xz.MIME,
extension=Xz.EXTENSION
)
def match(self, buf):
return (len(buf) > 5 and
buf[0] == 0xFD and
buf[1] == 0x37 and
buf[2] == 0x7A and
buf[3] == 0x58 and
buf[4] == 0x5A and
buf[5] == 0x00)
class Sqlite(Type):
"""
Implements the Sqlite DB archive type matcher.
"""
MIME = 'application/x-sqlite3'
EXTENSION = 'sqlite'
def __init__(self):
super(Sqlite, self).__init__(
mime=Sqlite.MIME,
extension=Sqlite.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x53 and
buf[1] == 0x51 and
buf[2] == 0x4C and
buf[3] == 0x69)
class Deb(Type):
"""
Implements the DEB archive type matcher.
"""
MIME = 'application/x-deb'
EXTENSION = 'deb'
def __init__(self):
super(Deb, self).__init__(
mime=Deb.MIME,
extension=Deb.EXTENSION
)
def match(self, buf):
return (len(buf) > 20 and
buf[0] == 0x21 and
buf[1] == 0x3C and
buf[2] == 0x61 and
buf[3] == 0x72 and
buf[4] == 0x63 and
buf[5] == 0x68 and
buf[6] == 0x3E and
buf[7] == 0x0A and
buf[8] == 0x64 and
buf[9] == 0x65 and
buf[10] == 0x62 and
buf[11] == 0x69 and
buf[12] == 0x61 and
buf[13] == 0x6E and
buf[14] == 0x2D and
buf[15] == 0x62 and
buf[16] == 0x69 and
buf[17] == 0x6E and
buf[18] == 0x61 and
buf[19] == 0x72 and
buf[20] == 0x79)
class Ar(Type):
"""
Implements the AR archive type matcher.
"""
MIME = 'application/x-unix-archive'
EXTENSION = 'ar'
def __init__(self):
super(Ar, self).__init__(
mime=Ar.MIME,
extension=Ar.EXTENSION
)
def match(self, buf):
return (len(buf) > 6 and
buf[0] == 0x21 and
buf[1] == 0x3C and
buf[2] == 0x61 and
buf[3] == 0x72 and
buf[4] == 0x63 and
buf[5] == 0x68 and
buf[6] == 0x3E)
class Z(Type):
"""
Implements the Z archive type matcher.
"""
MIME = 'application/x-compress'
EXTENSION = 'Z'
def __init__(self):
super(Z, self).__init__(
mime=Z.MIME,
extension=Z.EXTENSION
)
def match(self, buf):
return (len(buf) > 1 and
((buf[0] == 0x1F and
buf[1] == 0xA0) or
(buf[0] == 0x1F and
buf[1] == 0x9D)))
class Lzop(Type):
"""
Implements the Lzop archive type matcher.
"""
MIME = 'application/x-lzop'
EXTENSION = 'lzo'
def __init__(self):
super(Lzop, self).__init__(
mime=Lzop.MIME,
extension=Lzop.EXTENSION
)
def match(self, buf):
return (len(buf) > 7 and
buf[0] == 0x89 and
buf[1] == 0x4C and
buf[2] == 0x5A and
buf[3] == 0x4F and
buf[4] == 0x00 and
buf[5] == 0x0D and
buf[6] == 0x0A and
buf[7] == 0x1A)
class Lz(Type):
"""
Implements the Lz archive type matcher.
"""
MIME = 'application/x-lzip'
EXTENSION = 'lz'
def __init__(self):
super(Lz, self).__init__(
mime=Lz.MIME,
extension=Lz.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x4C and
buf[1] == 0x5A and
buf[2] == 0x49 and
buf[3] == 0x50)
class Elf(Type):
"""
Implements the Elf archive type matcher.
"""
MIME = 'application/x-executable'
EXTENSION = 'elf'
def __init__(self):
super(Elf, self).__init__(
mime=Elf.MIME,
extension=Elf.EXTENSION
)
def match(self, buf):
return (len(buf) > 52 and
buf[0] == 0x7F and
buf[1] == 0x45 and
buf[2] == 0x4C and
buf[3] == 0x46)
class Lz4(Type):
"""
Implements the Lz4 archive type matcher.
"""
MIME = 'application/x-lz4'
EXTENSION = 'lz4'
def __init__(self):
super(Lz4, self).__init__(
mime=Lz4.MIME,
extension=Lz4.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x04 and
buf[1] == 0x22 and
buf[2] == 0x4D and
buf[3] == 0x18)
class Br(Type):
"""Implements the Br archive type matcher."""
MIME = 'application/x-brotli'
EXTENSION = 'br'
def __init__(self):
super(Br, self).__init__(
mime=Br.MIME,
extension=Br.EXTENSION
)
def match(self, buf):
return buf[:4] == bytearray([0xce, 0xb2, 0xcf, 0x81])
class Dcm(Type):
"""Implements the Dcm archive type matcher."""
MIME = 'application/dicom'
EXTENSION = 'dcm'
def __init__(self):
super(Dcm, self).__init__(
mime=Dcm.MIME,
extension=Dcm.EXTENSION
)
def match(self, buf):
return buf[128:132] == bytearray([0x44, 0x49, 0x43, 0x4d])
class Rpm(Type):
"""Implements the Rpm archive type matcher."""
MIME = 'application/x-rpm'
EXTENSION = 'rpm'
def __init__(self):
super(Rpm, self).__init__(
mime=Rpm.MIME,
extension=Rpm.EXTENSION
)
def match(self, buf):
return buf[:4] == bytearray([0xed, 0xab, 0xee, 0xdb])
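
The slice-based matchers (`Br`, `Dcm`, `Rpm`) rely on two properties of Python slicing: the end index is exclusive, and slicing past the end of a buffer returns a shorter sequence instead of raising. A short sketch of both, using a hypothetical `magic_at` helper that derives the slice end from the magic length:

```python
def magic_at(buf, offset, magic):
    """True if buf contains magic starting at offset."""
    # The slice end is exclusive, so it must be offset + len(magic).
    return buf[offset:offset + len(magic)] == magic


dicom_like = bytearray(128) + b'DICM' + bytearray(16)

print(len(dicom_like[128:131]))            # 3: one byte short of a 4-byte magic
print(magic_at(dicom_like, 128, b'DICM'))  # True
print(magic_at(b'short', 128, b'DICM'))    # False: the slice is empty
```

Computing the end index from `len(magic)` avoids off-by-one slices, and no explicit length guard is needed because an out-of-range slice simply compares unequal.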
class Zstd(Type):
"""
Implements the Zstd archive type matcher.
https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
"""
MIME = 'application/zstd'
EXTENSION = 'zst'
MAGIC_SKIPPABLE_START = 0x184D2A50
MAGIC_SKIPPABLE_MASK = 0xFFFFFFF0
def __init__(self):
super(Zstd, self).__init__(
mime=Zstd.MIME,
extension=Zstd.EXTENSION
)
@staticmethod
def _to_little_endian_int(buf):
# return int.from_bytes(buf, byteorder='little')
return struct.unpack('<L', buf)[0]
def match(self, buf):
# Zstandard compressed data is made of one or more frames.
# There are two frame formats defined by Zstandard:
# Zstandard frames and Skippable frames.
# See more details from
# https://tools.ietf.org/id/draft-kucherawy-dispatch-zstd-00.html#rfc.section.2
is_zstd = (
len(buf) > 3 and
buf[0] in (0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28) and
buf[1] == 0xb5 and
buf[2] == 0x2f and
buf[3] == 0xfd)
if is_zstd:
return True
# skippable frames
if len(buf) < 8:
return False
magic = self._to_little_endian_int(buf[:4]) & Zstd.MAGIC_SKIPPABLE_MASK
if magic == Zstd.MAGIC_SKIPPABLE_START:
user_data_len = self._to_little_endian_int(buf[4:8])
if len(buf) < 8 + user_data_len:
return False
next_frame = buf[8 + user_data_len:]
return self.match(next_frame)
return False
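
The skippable-frame handling above can also be written as a loop. The sketch below is an iterative variant under simplifying assumptions: it checks only the current standard magic `0xFD2FB528` (bytes `28 B5 2F FD`), whereas the matcher above also accepts first bytes `0x22` through `0x27`. The function names are hypothetical:

```python
import struct

ZSTD_MAGIC = 0xFD2FB528        # stored little-endian: 28 B5 2F FD
SKIPPABLE_START = 0x184D2A50   # skippable magics span 0x184D2A50..0x184D2A5F
SKIPPABLE_MASK = 0xFFFFFFF0


def skip_skippable_frames(buf):
    """Return the offset just past any leading skippable frames."""
    offset = 0
    while len(buf) >= offset + 8:
        magic, user_data_len = struct.unpack_from('<LL', buf, offset)
        if (magic & SKIPPABLE_MASK) != SKIPPABLE_START:
            break
        # 4-byte magic + 4-byte length + user data
        offset += 8 + user_data_len
    return offset


def is_zstd(buf):
    """True if a standard Zstandard frame follows any skippable frames."""
    offset = skip_skippable_frames(buf)
    if len(buf) < offset + 4:
        return False
    return struct.unpack_from('<L', buf, offset)[0] == ZSTD_MAGIC


skippable = struct.pack('<LL', 0x184D2A53, 3) + b'abc'
frame = struct.pack('<L', ZSTD_MAGIC) + b'\x00' * 4
print(is_zstd(skippable + frame))  # True
print(is_zstd(skippable))          # False: nothing follows the skippable frame
```

A loop avoids one level of recursion per skippable frame, which could matter for input that chains many of them.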

@ -0,0 +1,221 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from .base import Type
class Midi(Type):
"""
Implements the Midi audio type matcher.
"""
MIME = 'audio/midi'
EXTENSION = 'midi'
def __init__(self):
super(Midi, self).__init__(
mime=Midi.MIME,
extension=Midi.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x4D and
buf[1] == 0x54 and
buf[2] == 0x68 and
buf[3] == 0x64)
class Mp3(Type):
"""
Implements the MP3 audio type matcher.
"""
MIME = 'audio/mpeg'
EXTENSION = 'mp3'
def __init__(self):
super(Mp3, self).__init__(
mime=Mp3.MIME,
extension=Mp3.EXTENSION
)
def match(self, buf):
if len(buf) > 2:
if (
buf[0] == 0x49 and
buf[1] == 0x44 and
buf[2] == 0x33
):
return True
if buf[0] == 0xFF:
if (
buf[1] == 0xE2 or # MPEG 2.5 with error protection
buf[1] == 0xE3 or # MPEG 2.5 w/o error protection
buf[1] == 0xF2 or # MPEG 2 with error protection
buf[1] == 0xF3 or # MPEG 2 w/o error protection
buf[1] == 0xFA or # MPEG 1 with error protection
buf[1] == 0xFB # MPEG 1 w/o error protection
):
return True
return False
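
The six second-byte values accepted by the MP3 matcher are not arbitrary: after the `0xFF` byte, the next byte carries the last three sync bits, a 2-bit MPEG version, a 2-bit layer, and a protection bit. A sketch that decodes those fields; `describe_sync` is a hypothetical helper for illustration only:

```python
MPEG_VERSIONS = {0b00: 'MPEG 2.5', 0b10: 'MPEG 2', 0b11: 'MPEG 1'}


def describe_sync(second_byte):
    """Decode the second byte of an MPEG audio frame header.

    Returns (version, is_layer3, has_crc), or None when the byte does not
    complete the 11-bit frame sync or uses the reserved version value.
    """
    if (second_byte & 0xE0) != 0xE0:       # top three bits finish the sync word
        return None
    version = MPEG_VERSIONS.get((second_byte >> 3) & 0b11)
    if version is None:
        return None
    is_layer3 = ((second_byte >> 1) & 0b11) == 0b01
    has_crc = (second_byte & 0b1) == 0     # protection bit 0 means a CRC follows
    return version, is_layer3, has_crc


for value in (0xE2, 0xE3, 0xF2, 0xF3, 0xFA, 0xFB):
    print(hex(value), describe_sync(value))
```

All six values decode as Layer III, which is what distinguishes them from the `0xF1` and `0xF9` bytes used by the AAC matcher in the same file.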
class M4a(Type):
"""
Implements the M4A audio type matcher.
"""
MIME = 'audio/mp4'
EXTENSION = 'm4a'
def __init__(self):
super(M4a, self).__init__(
mime=M4a.MIME,
extension=M4a.EXTENSION
)
def match(self, buf):
return (len(buf) > 10 and
((buf[4] == 0x66 and
buf[5] == 0x74 and
buf[6] == 0x79 and
buf[7] == 0x70 and
buf[8] == 0x4D and
buf[9] == 0x34 and
buf[10] == 0x41) or
(buf[0] == 0x4D and
buf[1] == 0x34 and
buf[2] == 0x41 and
buf[3] == 0x20)))
class Ogg(Type):
"""
Implements the OGG audio type matcher.
"""
MIME = 'audio/ogg'
EXTENSION = 'ogg'
def __init__(self):
super(Ogg, self).__init__(
mime=Ogg.MIME,
extension=Ogg.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x4F and
buf[1] == 0x67 and
buf[2] == 0x67 and
buf[3] == 0x53)
class Flac(Type):
"""
Implements the FLAC audio type matcher.
"""
MIME = 'audio/x-flac'
EXTENSION = 'flac'
def __init__(self):
super(Flac, self).__init__(
mime=Flac.MIME,
extension=Flac.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x66 and
buf[1] == 0x4C and
buf[2] == 0x61 and
buf[3] == 0x43)
class Wav(Type):
"""
Implements the WAV audio type matcher.
"""
MIME = 'audio/x-wav'
EXTENSION = 'wav'
def __init__(self):
super(Wav, self).__init__(
mime=Wav.MIME,
extension=Wav.EXTENSION
)
def match(self, buf):
return (len(buf) > 11 and
buf[0] == 0x52 and
buf[1] == 0x49 and
buf[2] == 0x46 and
buf[3] == 0x46 and
buf[8] == 0x57 and
buf[9] == 0x41 and
buf[10] == 0x56 and
buf[11] == 0x45)
class Amr(Type):
"""
Implements the AMR audio type matcher.
"""
MIME = 'audio/amr'
EXTENSION = 'amr'
def __init__(self):
super(Amr, self).__init__(
mime=Amr.MIME,
extension=Amr.EXTENSION
)
def match(self, buf):
return (len(buf) > 11 and
buf[0] == 0x23 and
buf[1] == 0x21 and
buf[2] == 0x41 and
buf[3] == 0x4D and
buf[4] == 0x52 and
buf[5] == 0x0A)
class Aac(Type):
"""Implements the Aac audio type matcher."""
MIME = 'audio/aac'
EXTENSION = 'aac'
def __init__(self):
super(Aac, self).__init__(
mime=Aac.MIME,
extension=Aac.EXTENSION
)
def match(self, buf):
return (buf[:2] == bytearray([0xff, 0xf1]) or
buf[:2] == bytearray([0xff, 0xf9]))
class Aiff(Type):
"""
Implements the AIFF audio type matcher.
"""
MIME = 'audio/x-aiff'
EXTENSION = 'aiff'
def __init__(self):
super(Aiff, self).__init__(
mime=Aiff.MIME,
extension=Aiff.EXTENSION
)
def match(self, buf):
return (len(buf) > 11 and
buf[0] == 0x46 and
buf[1] == 0x4F and
buf[2] == 0x52 and
buf[3] == 0x4D and
buf[8] == 0x41 and
buf[9] == 0x49 and
buf[10] == 0x46 and
buf[11] == 0x46)
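
The WAV and AIFF matchers share one layout: a 4-byte container identifier, a 4-byte chunk size, and a 4-byte form type at offset 8. A minimal sketch of reading that header generically; `container_form` is a hypothetical helper, and the sizes used in the sample headers are arbitrary:

```python
def container_form(buf):
    """Return (container_id, form_type) for RIFF/FORM headers, else None."""
    if len(buf) < 12:
        return None
    container_id = bytes(buf[0:4])
    if container_id not in (b'RIFF', b'FORM'):
        return None
    # Bytes 4..7 hold the chunk size; the form type follows at offset 8.
    return container_id, bytes(buf[8:12])


# RIFF stores the size little-endian, IFF "FORM" stores it big-endian.
wav_head = b'RIFF' + (36).to_bytes(4, 'little') + b'WAVE'
aiff_head = b'FORM' + (46).to_bytes(4, 'big') + b'AIFF'

print(container_form(wav_head))   # (b'RIFF', b'WAVE')
print(container_form(aiff_head))  # (b'FORM', b'AIFF')
```

Skipping bytes 4 to 7 is deliberate, since the size field differs for every file and cannot be part of a signature.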

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
class Type(object):
    """
    Represents the file type object inherited by
    specific file type matchers.
    Provides convenient accessor and helper methods.
    """

    def __init__(self, mime, extension):
        self.__mime = mime
        self.__extension = extension

    @property
    def mime(self):
        return self.__mime

    @property
    def extension(self):
        return self.__extension

    def is_extension(self, extension):
        # Compare by value; identity checks on strings are unreliable.
        return self.__extension == extension

    def is_mime(self, mime):
        return self.__mime == mime

    def match(self, buf):
        raise NotImplementedError
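
A new matcher only has to subclass the base type and override `match`. The sketch below uses a stand-in `Type` so that it runs on its own, and an invented format: the `Foo` class, its MIME type, and its `FOO1` magic are all hypothetical. The stand-in compares extensions by value rather than identity:

```python
class Type(object):
    """Stand-in for the base class, kept minimal for this example."""

    def __init__(self, mime, extension):
        self._mime = mime
        self._extension = extension

    @property
    def mime(self):
        return self._mime

    @property
    def extension(self):
        return self._extension

    def is_extension(self, extension):
        # Value comparison; identity (`is`) depends on string interning.
        return self._extension == extension

    def match(self, buf):
        raise NotImplementedError


class Foo(Type):
    """Hypothetical matcher for an invented format starting with b'FOO1'."""

    def __init__(self):
        super(Foo, self).__init__(mime='application/x-foo', extension='foo')

    def match(self, buf):
        return buf[:4] == b'FOO1'


matcher = Foo()
runtime_ext = ''.join(['f', 'oo'])  # equal to 'foo', but built at runtime
print(matcher.match(b'FOO1 payload'), matcher.is_extension(runtime_ext))
```

The runtime-built string shows why value comparison is the safer choice: two equal strings are not guaranteed to be the same object.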

@ -0,0 +1,265 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from .base import Type
class ZippedDocumentBase(Type):
def match(self, buf):
# start by checking for ZIP local file header signature
idx = self.search_signature(buf, 0, 6000)
if idx != 0:
return
return self.match_document(buf)
def match_document(self, buf):
raise NotImplementedError
def compare_bytes(self, buf, subslice, start_offset):
sl = len(subslice)
if start_offset + sl > len(buf):
return False
return buf[start_offset:start_offset + sl] == subslice
def search_signature(self, buf, start, rangeNum):
signature = b"PK\x03\x04"
length = len(buf)
end = start + rangeNum
end = length if end > length else end
if start >= end:
return -1
try:
return buf.index(signature, start, end)
except ValueError:
return -1
class OpenDocument(ZippedDocumentBase):
def match_document(self, buf):
# Check if first file in archive is the identifying file
if not self.compare_bytes(buf, b"mimetype", 0x1E):
return
# Check content of mimetype file if it matches current mime
return self.compare_bytes(buf, bytes(self.mime, "ASCII"), 0x26)
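
The offsets `0x1E` and `0x26` follow from the ZIP local file header: 30 fixed bytes, then the file name, then the stored data when no extra field is present. A sketch that builds a tiny OpenDocument-style archive in memory with the standard `zipfile` module and inspects those offsets; the `content.xml` payload is a placeholder, not a valid document:

```python
import io
import zipfile

ODT_MIME = b'application/vnd.oasis.opendocument.text'

archive = io.BytesIO()
with zipfile.ZipFile(archive, 'w', compression=zipfile.ZIP_STORED) as zf:
    # OpenDocument packages store an uncompressed "mimetype" entry first.
    zf.writestr('mimetype', ODT_MIME)
    zf.writestr('content.xml', b'<office:document-content/>')

buf = archive.getvalue()
name_offset = 0x1E                             # 30-byte fixed local file header
data_offset = name_offset + len(b'mimetype')   # no extra field here, so 0x26

print(buf[:4])
print(buf[name_offset:data_offset])
print(buf[data_offset:data_offset + len(ODT_MIME)])
```

The check only works because the `mimetype` entry is stored uncompressed and comes first; an archive that compresses or reorders it would not be recognised by offset-based matching.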
class OfficeOpenXml(ZippedDocumentBase):
def match_document(self, buf):
# Check if first file in archive is the identifying file
ft = self.match_filename(buf, 0x1E)
if ft:
return ft
# Otherwise check that the first file is one of these
if (
not self.compare_bytes(buf, b"[Content_Types].xml", 0x1E)
and not self.compare_bytes(buf, b"_rels/.rels", 0x1E)
and not self.compare_bytes(buf, b"docProps", 0x1E)
):
return
# Loop through the next 4 file headers and check if they match
# NOTE: OpenOffice/Libreoffice orders ZIP entry differently, so check the 4th file
# https://github.com/h2non/filetype/blob/d730d98ad5c990883148485b6fd5adbdd378364a/matchers/document.go#L134
idx = 0
for i in range(4):
# Search for next file header
idx = self.search_signature(buf, idx + 4, 6000)
if idx == -1:
return
# Filename is at file header + 30
ft = self.match_filename(buf, idx + 30)
if ft:
return ft
def match_filename(self, buf, offset):
if self.compare_bytes(buf, b"word/", offset):
return (
self.mime
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
)
if self.compare_bytes(buf, b"ppt/", offset):
return (
self.mime
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
)
if self.compare_bytes(buf, b"xl/", offset):
return (
self.mime
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
class Doc(Type):
"""
Implements the Microsoft Word (Office 97-2003) document type matcher.
"""
MIME = "application/msword"
EXTENSION = "doc"
def __init__(self):
super(Doc, self).__init__(mime=Doc.MIME, extension=Doc.EXTENSION)
def match(self, buf):
if len(buf) > 515 and buf[0:8] == b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1":
if buf[512:516] == b"\xEC\xA5\xC1\x00":
return True
if (
len(buf) > 2142
and (
b"\x00\x0A\x00\x00\x00MSWordDoc\x00\x10\x00\x00\x00Word.Document.8\x00\xF49\xB2q"
in buf[2075:2142]
or b"W\0o\0r\0d\0D\0o\0c\0u\0m\0e\0n\0t\0"
in buf[0x580:0x598]
)
):
return True
if (
len(buf) > 663 and buf[512:531] == b"R\x00o\x00o\x00t\x00 \x00E\x00n\x00t\x00r\x00y"
and buf[640:663] == b"W\x00o\x00r\x00d\x00D\x00o\x00c\x00u\x00m\x00e\x00n\x00t"
):
return True
return False
class Docx(OfficeOpenXml):
"""
Implements the Microsoft Word OOXML (Office 2007+) document type matcher.
"""
MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
EXTENSION = "docx"
def __init__(self):
super(Docx, self).__init__(mime=Docx.MIME, extension=Docx.EXTENSION)
class Odt(OpenDocument):
"""
Implements the OpenDocument Text document type matcher.
"""
MIME = "application/vnd.oasis.opendocument.text"
EXTENSION = "odt"
def __init__(self):
super(Odt, self).__init__(mime=Odt.MIME, extension=Odt.EXTENSION)
class Xls(Type):
"""
Implements the Microsoft Excel (Office 97-2003) document type matcher.
"""
MIME = "application/vnd.ms-excel"
EXTENSION = "xls"
def __init__(self):
super(Xls, self).__init__(mime=Xls.MIME, extension=Xls.EXTENSION)
def match(self, buf):
if len(buf) > 520 and buf[0:8] == b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1":
if buf[512:516] == b"\xFD\xFF\xFF\xFF" and (
buf[518] == 0x00 or buf[518] == 0x02
):
return True
if buf[512:520] == b"\x09\x08\x10\x00\x00\x06\x05\x00":
return True
if (
len(buf) > 2095
and b"\xE2\x00\x00\x00\x5C\x00\x70\x00\x04\x00\x00Calc"
in buf[1568:2095]
):
return True
return False
class Xlsx(OfficeOpenXml):
"""
Implements the Microsoft Excel OOXML (Office 2007+) document type matcher.
"""
MIME = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
EXTENSION = "xlsx"
def __init__(self):
super(Xlsx, self).__init__(mime=Xlsx.MIME, extension=Xlsx.EXTENSION)
class Ods(OpenDocument):
"""
Implements the OpenDocument Spreadsheet document type matcher.
"""
MIME = "application/vnd.oasis.opendocument.spreadsheet"
EXTENSION = "ods"
def __init__(self):
super(Ods, self).__init__(mime=Ods.MIME, extension=Ods.EXTENSION)
class Ppt(Type):
"""
Implements the Microsoft PowerPoint (Office 97-2003) document type matcher.
"""
MIME = "application/vnd.ms-powerpoint"
EXTENSION = "ppt"
def __init__(self):
super(Ppt, self).__init__(mime=Ppt.MIME, extension=Ppt.EXTENSION)
def match(self, buf):
if len(buf) > 524 and buf[0:8] == b"\xD0\xCF\x11\xE0\xA1\xB1\x1A\xE1":
if buf[512:516] == b"\xA0\x46\x1D\xF0":
return True
if buf[512:516] == b"\x00\x6E\x1E\xF0":
return True
if buf[512:516] == b"\x0F\x00\xE8\x03":
return True
if buf[512:516] == b"\xFD\xFF\xFF\xFF" and buf[522:524] == b"\x00\x00":
return True
if (
len(buf) > 2096
and buf[2072:2096]
== b"\x00\xB9\x29\xE8\x11\x00\x00\x00MS PowerPoint 97"
):
return True
return False
class Pptx(OfficeOpenXml):
"""
Implements the Microsoft PowerPoint OOXML (Office 2007+) document type matcher.
"""
MIME = "application/vnd.openxmlformats-officedocument.presentationml.presentation"
EXTENSION = "pptx"
def __init__(self):
super(Pptx, self).__init__(mime=Pptx.MIME, extension=Pptx.EXTENSION)
class Odp(OpenDocument):
"""
Implements the OpenDocument Presentation document type matcher.
"""
MIME = "application/vnd.oasis.opendocument.presentation"
EXTENSION = "odp"
def __init__(self):
super(Odp, self).__init__(mime=Odp.MIME, extension=Odp.EXTENSION)

@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from .base import Type
class Woff(Type):
"""
Implements the WOFF font type matcher.
"""
MIME = 'application/font-woff'
EXTENSION = 'woff'
def __init__(self):
super(Woff, self).__init__(
mime=Woff.MIME,
extension=Woff.EXTENSION
)
def match(self, buf):
return (len(buf) > 7 and
buf[0] == 0x77 and
buf[1] == 0x4F and
buf[2] == 0x46 and
buf[3] == 0x46 and
((buf[4] == 0x00 and
buf[5] == 0x01 and
buf[6] == 0x00 and
buf[7] == 0x00) or
(buf[4] == 0x4F and
buf[5] == 0x54 and
buf[6] == 0x54 and
buf[7] == 0x4F) or
(buf[4] == 0x74 and
buf[5] == 0x72 and
buf[6] == 0x75 and
buf[7] == 0x65)))
class Woff2(Type):
"""
Implements the WOFF2 font type matcher.
"""
MIME = 'application/font-woff'
EXTENSION = 'woff2'
def __init__(self):
super(Woff2, self).__init__(
mime=Woff2.MIME,
extension=Woff2.EXTENSION
)
def match(self, buf):
return (len(buf) > 7 and
buf[0] == 0x77 and
buf[1] == 0x4F and
buf[2] == 0x46 and
buf[3] == 0x32 and
((buf[4] == 0x00 and
buf[5] == 0x01 and
buf[6] == 0x00 and
buf[7] == 0x00) or
(buf[4] == 0x4F and
buf[5] == 0x54 and
buf[6] == 0x54 and
buf[7] == 0x4F) or
(buf[4] == 0x74 and
buf[5] == 0x72 and
buf[6] == 0x75 and
buf[7] == 0x65)))
class Ttf(Type):
"""
Implements the TTF font type matcher.
"""
MIME = 'application/font-sfnt'
EXTENSION = 'ttf'
def __init__(self):
super(Ttf, self).__init__(
mime=Ttf.MIME,
extension=Ttf.EXTENSION
)
def match(self, buf):
return (len(buf) > 4 and
buf[0] == 0x00 and
buf[1] == 0x01 and
buf[2] == 0x00 and
buf[3] == 0x00 and
buf[4] == 0x00)
class Otf(Type):
"""
Implements the OTF font type matcher.
"""
MIME = 'application/font-sfnt'
EXTENSION = 'otf'
def __init__(self):
super(Otf, self).__init__(
mime=Otf.MIME,
extension=Otf.EXTENSION
)
def match(self, buf):
return (len(buf) > 4 and
buf[0] == 0x4F and
buf[1] == 0x54 and
buf[2] == 0x54 and
buf[3] == 0x4F and
buf[4] == 0x00)

@ -0,0 +1,453 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from .base import Type
from .isobmff import IsoBmff
class Jpeg(Type):
"""
Implements the JPEG image type matcher.
"""
MIME = 'image/jpeg'
EXTENSION = 'jpg'
def __init__(self):
super(Jpeg, self).__init__(
mime=Jpeg.MIME,
extension=Jpeg.EXTENSION
)
def match(self, buf):
return (len(buf) > 2 and
buf[0] == 0xFF and
buf[1] == 0xD8 and
buf[2] == 0xFF)
class Jpx(Type):
"""
Implements the JPEG2000 image type matcher.
"""
MIME = "image/jpx"
EXTENSION = "jpx"
def __init__(self):
super(Jpx, self).__init__(mime=Jpx.MIME, extension=Jpx.EXTENSION)
def match(self, buf):
return (
len(buf) > 50
and buf[0] == 0x00
and buf[1] == 0x00
and buf[2] == 0x00
and buf[3] == 0x0C
and buf[16:24] == b"ftypjp2 "
)
class Jxl(Type):
"""
Implements the JPEG XL image type matcher.
"""
MIME = "image/jxl"
EXTENSION = "jxl"
def __init__(self):
super(Jxl, self).__init__(mime=Jxl.MIME, extension=Jxl.EXTENSION)
def match(self, buf):
return (
(len(buf) > 1 and
buf[0] == 0xFF and
buf[1] == 0x0A) or
(len(buf) > 11 and
buf[0] == 0x00 and
buf[1] == 0x00 and
buf[2] == 0x00 and
buf[3] == 0x00 and
buf[4] == 0x0C and
buf[5] == 0x4A and
buf[6] == 0x58 and
buf[7] == 0x4C and
buf[8] == 0x20 and
buf[9] == 0x0D and
buf[10] == 0x87 and
buf[11] == 0x0A)
)
class Apng(Type):
"""
Implements the APNG image type matcher.
"""
MIME = 'image/apng'
EXTENSION = 'apng'
def __init__(self):
super(Apng, self).__init__(
mime=Apng.MIME,
extension=Apng.EXTENSION
)
def match(self, buf):
if (len(buf) > 8 and
buf[:8] == bytearray([0x89, 0x50, 0x4e, 0x47,
0x0d, 0x0a, 0x1a, 0x0a])):
# cursor into buf; skip the 8 signature bytes already read
i = 8
while len(buf) > i:
data_length = int.from_bytes(buf[i:i+4], byteorder="big")
i += 4
chunk_type = buf[i:i+4].decode("ascii", errors='ignore')
i += 4
# acTL chunk in APNG must appear before IDAT
# IEND is end of PNG
if (chunk_type == "IDAT" or chunk_type == "IEND"):
return False
if (chunk_type == "acTL"):
return True
# move to the next chunk by skipping data and crc (4 bytes)
i += data_length + 4
return False
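
The APNG check walks the PNG chunk sequence: each chunk is a 4-byte big-endian length, a 4-byte type, the data, and a 4-byte CRC. The sketch below builds a synthetic chunk stream and lists the chunk types in order. `make_chunk` and `chunk_types` are hypothetical helpers, and the chunk payloads are zero-filled placeholders rather than valid image data:

```python
import struct
import zlib

PNG_SIGNATURE = b'\x89PNG\r\n\x1a\n'


def make_chunk(chunk_type, data=b''):
    """Serialize one PNG chunk: length, type, data, CRC-32 of type + data."""
    crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF
    return (struct.pack('>I', len(data)) + chunk_type + data +
            struct.pack('>I', crc))


def chunk_types(buf):
    """List chunk type names in order, stopping at truncated input."""
    names = []
    i = len(PNG_SIGNATURE)
    while i + 8 <= len(buf):
        (length,) = struct.unpack_from('>I', buf, i)
        names.append(bytes(buf[i + 4:i + 8]).decode('ascii', errors='replace'))
        i += 8 + length + 4  # chunk header + data + CRC
    return names


animated = (PNG_SIGNATURE + make_chunk(b'IHDR', bytes(13)) +
            make_chunk(b'acTL', bytes(8)) + make_chunk(b'IDAT') +
            make_chunk(b'IEND'))
print(chunk_types(animated))  # ['IHDR', 'acTL', 'IDAT', 'IEND']
```

An `acTL` chunk that appears before the first `IDAT` is what marks the file as animated, so the walk can stop as soon as it sees either one.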
class Png(Type):
"""
Implements the PNG image type matcher.
"""
MIME = 'image/png'
EXTENSION = 'png'
def __init__(self):
super(Png, self).__init__(
mime=Png.MIME,
extension=Png.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x89 and
buf[1] == 0x50 and
buf[2] == 0x4E and
buf[3] == 0x47)
class Gif(Type):
"""
Implements the GIF image type matcher.
"""
MIME = 'image/gif'
EXTENSION = 'gif'
def __init__(self):
super(Gif, self).__init__(
mime=Gif.MIME,
extension=Gif.EXTENSION,
)
def match(self, buf):
return (len(buf) > 2 and
buf[0] == 0x47 and
buf[1] == 0x49 and
buf[2] == 0x46)
class Webp(Type):
"""
Implements the WEBP image type matcher.
"""
MIME = 'image/webp'
EXTENSION = 'webp'
def __init__(self):
super(Webp, self).__init__(
mime=Webp.MIME,
extension=Webp.EXTENSION,
)
def match(self, buf):
return (len(buf) > 13 and
buf[0] == 0x52 and
buf[1] == 0x49 and
buf[2] == 0x46 and
buf[3] == 0x46 and
buf[8] == 0x57 and
buf[9] == 0x45 and
buf[10] == 0x42 and
buf[11] == 0x50 and
buf[12] == 0x56 and
buf[13] == 0x50)
class Cr2(Type):
"""
Implements the CR2 image type matcher.
"""
MIME = 'image/x-canon-cr2'
EXTENSION = 'cr2'
def __init__(self):
super(Cr2, self).__init__(
mime=Cr2.MIME,
extension=Cr2.EXTENSION,
)
def match(self, buf):
return (len(buf) > 9 and
((buf[0] == 0x49 and buf[1] == 0x49 and
buf[2] == 0x2A and buf[3] == 0x0) or
(buf[0] == 0x4D and buf[1] == 0x4D and
buf[2] == 0x0 and buf[3] == 0x2A)) and
buf[8] == 0x43 and buf[9] == 0x52)
class Tiff(Type):
"""
Implements the TIFF image type matcher.
"""
MIME = 'image/tiff'
EXTENSION = 'tif'
def __init__(self):
super(Tiff, self).__init__(
mime=Tiff.MIME,
extension=Tiff.EXTENSION,
)
def match(self, buf):
return (len(buf) > 9 and
((buf[0] == 0x49 and buf[1] == 0x49 and
buf[2] == 0x2A and buf[3] == 0x0) or
(buf[0] == 0x4D and buf[1] == 0x4D and
buf[2] == 0x0 and buf[3] == 0x2A))
and not (buf[8] == 0x43 and buf[9] == 0x52))
class Bmp(Type):
"""
Implements the BMP image type matcher.
"""
MIME = 'image/bmp'
EXTENSION = 'bmp'
def __init__(self):
super(Bmp, self).__init__(
mime=Bmp.MIME,
extension=Bmp.EXTENSION,
)
def match(self, buf):
return (len(buf) > 1 and
buf[0] == 0x42 and
buf[1] == 0x4D)
class Jxr(Type):
"""
Implements the JXR image type matcher.
"""
MIME = 'image/vnd.ms-photo'
EXTENSION = 'jxr'
def __init__(self):
super(Jxr, self).__init__(
mime=Jxr.MIME,
extension=Jxr.EXTENSION,
)
def match(self, buf):
return (len(buf) > 2 and
buf[0] == 0x49 and
buf[1] == 0x49 and
buf[2] == 0xBC)
class Psd(Type):
"""
Implements the PSD image type matcher.
"""
MIME = 'image/vnd.adobe.photoshop'
EXTENSION = 'psd'
def __init__(self):
super(Psd, self).__init__(
mime=Psd.MIME,
extension=Psd.EXTENSION,
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x38 and
buf[1] == 0x42 and
buf[2] == 0x50 and
buf[3] == 0x53)
class Ico(Type):
"""
Implements the ICO image type matcher.
"""
MIME = 'image/x-icon'
EXTENSION = 'ico'
def __init__(self):
super(Ico, self).__init__(
mime=Ico.MIME,
extension=Ico.EXTENSION,
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x00 and
buf[1] == 0x00 and
buf[2] == 0x01 and
buf[3] == 0x00)
class Heic(IsoBmff):
"""
Implements the HEIC image type matcher.
"""
MIME = 'image/heic'
EXTENSION = 'heic'
def __init__(self):
super(Heic, self).__init__(
mime=Heic.MIME,
extension=Heic.EXTENSION
)
def match(self, buf):
if not self._is_isobmff(buf):
return False
major_brand, minor_version, compatible_brands = self._get_ftyp(buf)
if major_brand == 'heic':
return True
if major_brand in ['mif1', 'msf1'] and 'heic' in compatible_brands:
return True
return False
class Dcm(Type):
MIME = 'application/dicom'
EXTENSION = 'dcm'
OFFSET = 128
def __init__(self):
super(Dcm, self).__init__(
mime=Dcm.MIME,
extension=Dcm.EXTENSION
)
def match(self, buf):
return (len(buf) > Dcm.OFFSET + 4 and
buf[Dcm.OFFSET + 0] == 0x44 and
buf[Dcm.OFFSET + 1] == 0x49 and
buf[Dcm.OFFSET + 2] == 0x43 and
buf[Dcm.OFFSET + 3] == 0x4D)
class Dwg(Type):
"""Implements the Dwg image type matcher."""
MIME = 'image/vnd.dwg'
EXTENSION = 'dwg'
def __init__(self):
super(Dwg, self).__init__(
mime=Dwg.MIME,
extension=Dwg.EXTENSION
)
def match(self, buf):
return buf[:4] == bytearray([0x41, 0x43, 0x31, 0x30])
class Xcf(Type):
"""Implements the Xcf image type matcher."""
MIME = 'image/x-xcf'
EXTENSION = 'xcf'
def __init__(self):
super(Xcf, self).__init__(
mime=Xcf.MIME,
extension=Xcf.EXTENSION
)
def match(self, buf):
return buf[:10] == bytearray([0x67, 0x69, 0x6d, 0x70, 0x20,
0x78, 0x63, 0x66, 0x20, 0x76])
class Avif(IsoBmff):
"""
Implements the AVIF image type matcher.
"""
MIME = 'image/avif'
EXTENSION = 'avif'
def __init__(self):
super(Avif, self).__init__(
mime=Avif.MIME,
extension=Avif.EXTENSION
)
def match(self, buf):
if not self._is_isobmff(buf):
return False
major_brand, minor_version, compatible_brands = self._get_ftyp(buf)
if major_brand in ['avif', 'avis']:
return True
if major_brand in ['mif1', 'msf1'] and 'avif' in compatible_brands:
return True
return False
class Qoi(Type):
"""
Implements the QOI image type matcher.
"""
MIME = 'image/qoi'
EXTENSION = 'qoi'
def __init__(self):
super(Qoi, self).__init__(
mime=Qoi.MIME,
extension=Qoi.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x71 and
buf[1] == 0x6F and
buf[2] == 0x69 and
buf[3] == 0x66)
class Dds(Type):
"""
Implements the DDS image type matcher.
"""
MIME = 'image/dds'
EXTENSION = 'dds'
def __init__(self):
super(Dds, self).__init__(
mime=Dds.MIME,
extension=Dds.EXTENSION
)
def match(self, buf):
return buf.startswith(b'\x44\x44\x53\x20')
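The matchers above all reduce to comparing the leading bytes of a buffer against a fixed magic number. A minimal standalone sketch of that idea follows; the `SIGNATURES` table and `guess_image_extension` are hypothetical names for illustration and are not part of the filetype package.

```python
# Hypothetical table: extension -> magic bytes, copied from the matchers above.
SIGNATURES = {
    'bmp': b'\x42\x4D',            # "BM"
    'psd': b'\x38\x42\x50\x53',    # "8BPS"
    'ico': b'\x00\x00\x01\x00',
    'qoi': b'\x71\x6F\x69\x66',    # "qoif"
    'dds': b'\x44\x44\x53\x20',    # "DDS "
}


def guess_image_extension(buf):
    """Return the first extension whose magic bytes prefix buf, else None."""
    data = bytes(buf)  # accepts bytes, bytearray or memoryview
    for extension, magic in SIGNATURES.items():
        if data.startswith(magic):
            return extension
    return None
```

Using `startswith` avoids the explicit length checks that the index-by-index comparisons need, since a too-short buffer simply fails to match.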


@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
import codecs
from .base import Type
class IsoBmff(Type):
"""
Implements the ISO-BMFF base type.
"""
def __init__(self, mime, extension):
super(IsoBmff, self).__init__(
mime=mime,
extension=extension
)
def _is_isobmff(self, buf):
if len(buf) < 16 or buf[4:8] != b'ftyp':
return False
if len(buf) < int(codecs.encode(buf[0:4], 'hex'), 16):
return False
return True
def _get_ftyp(self, buf):
ftyp_len = int(codecs.encode(buf[0:4], 'hex'), 16)
major_brand = buf[8:12].decode(errors='ignore')
minor_version = int(codecs.encode(buf[12:16], 'hex'), 16)
compatible_brands = []
for i in range(16, ftyp_len, 4):
compatible_brands.append(buf[i:i+4].decode(errors='ignore'))
return major_brand, minor_version, compatible_brands
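The `ftyp` parsing above can also be written with `struct`, which makes the big-endian 32-bit fields explicit. This is a sketch under assumptions: `parse_ftyp` is a hypothetical helper, and the extra `box_len < 16` guard is an addition that the class above does not have.

```python
import struct


def parse_ftyp(buf):
    """Parse a leading ISO-BMFF 'ftyp' box; return None if buf lacks one."""
    data = bytes(buf)
    if len(data) < 16 or data[4:8] != b'ftyp':
        return None
    (box_len,) = struct.unpack('>I', data[0:4])  # big-endian box size
    if box_len < 16 or len(data) < box_len:      # assumption: reject short boxes
        return None
    major_brand = data[8:12].decode('ascii', errors='ignore')
    (minor_version,) = struct.unpack('>I', data[12:16])
    # Compatible brands are 4-byte codes filling the rest of the box.
    compatible_brands = [
        data[i:i + 4].decode('ascii', errors='ignore')
        for i in range(16, box_len, 4)
    ]
    return major_brand, minor_version, compatible_brands


# A synthetic 24-byte box: size, 'ftyp', major brand, minor version, two brands.
sample = (struct.pack('>I', 24) + b'ftyp' + b'heic'
          + struct.pack('>I', 0) + b'mif1' + b'heic')
```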


@ -0,0 +1,230 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from .base import Type
from .isobmff import IsoBmff
class Mp4(IsoBmff):
"""
Implements the MP4 video type matcher.
"""
MIME = 'video/mp4'
EXTENSION = 'mp4'
def __init__(self):
super(Mp4, self).__init__(
mime=Mp4.MIME,
extension=Mp4.EXTENSION
)
def match(self, buf):
if not self._is_isobmff(buf):
return False
major_brand, minor_version, compatible_brands = self._get_ftyp(buf)
for brand in compatible_brands:
if brand in ['mp41', 'mp42', 'isom']:
return True
return major_brand in ['mp41', 'mp42', 'isom']
class M4v(Type):
"""
Implements the M4V video type matcher.
"""
MIME = 'video/x-m4v'
EXTENSION = 'm4v'
def __init__(self):
super(M4v, self).__init__(
mime=M4v.MIME,
extension=M4v.EXTENSION
)
def match(self, buf):
return (len(buf) > 10 and
buf[0] == 0x0 and buf[1] == 0x0 and
buf[2] == 0x0 and buf[3] == 0x1C and
buf[4] == 0x66 and buf[5] == 0x74 and
buf[6] == 0x79 and buf[7] == 0x70 and
buf[8] == 0x4D and buf[9] == 0x34 and
buf[10] == 0x56)
class Mkv(Type):
"""
Implements the MKV video type matcher.
"""
MIME = 'video/x-matroska'
EXTENSION = 'mkv'
def __init__(self):
super(Mkv, self).__init__(
mime=Mkv.MIME,
extension=Mkv.EXTENSION
)
def match(self, buf):
contains_ebml_element = buf.startswith(b'\x1A\x45\xDF\xA3')
contains_doctype_element = buf.find(b'\x42\x82\x88matroska') > -1
return contains_ebml_element and contains_doctype_element
class Webm(Type):
"""
Implements the WebM video type matcher.
"""
MIME = 'video/webm'
EXTENSION = 'webm'
def __init__(self):
super(Webm, self).__init__(
mime=Webm.MIME,
extension=Webm.EXTENSION
)
def match(self, buf):
contains_ebml_element = buf.startswith(b'\x1A\x45\xDF\xA3')
contains_doctype_element = buf.find(b'\x42\x82\x84webm') > -1
return contains_ebml_element and contains_doctype_element
class Mov(IsoBmff):
"""
Implements the MOV video type matcher.
"""
MIME = 'video/quicktime'
EXTENSION = 'mov'
def __init__(self):
super(Mov, self).__init__(
mime=Mov.MIME,
extension=Mov.EXTENSION
)
def match(self, buf):
if not self._is_isobmff(buf):
return False
major_brand, minor_version, compatible_brands = self._get_ftyp(buf)
return major_brand == 'qt '
class Avi(Type):
"""
Implements the AVI video type matcher.
"""
MIME = 'video/x-msvideo'
EXTENSION = 'avi'
def __init__(self):
super(Avi, self).__init__(
mime=Avi.MIME,
extension=Avi.EXTENSION
)
def match(self, buf):
return (len(buf) > 11 and
buf[0] == 0x52 and
buf[1] == 0x49 and
buf[2] == 0x46 and
buf[3] == 0x46 and
buf[8] == 0x41 and
buf[9] == 0x56 and
buf[10] == 0x49 and
buf[11] == 0x20)
class Wmv(Type):
"""
Implements the WMV video type matcher.
"""
MIME = 'video/x-ms-wmv'
EXTENSION = 'wmv'
def __init__(self):
super(Wmv, self).__init__(
mime=Wmv.MIME,
extension=Wmv.EXTENSION
)
def match(self, buf):
return (len(buf) > 9 and
buf[0] == 0x30 and
buf[1] == 0x26 and
buf[2] == 0xB2 and
buf[3] == 0x75 and
buf[4] == 0x8E and
buf[5] == 0x66 and
buf[6] == 0xCF and
buf[7] == 0x11 and
buf[8] == 0xA6 and
buf[9] == 0xD9)
class Flv(Type):
"""
Implements the FLV video type matcher.
"""
MIME = 'video/x-flv'
EXTENSION = 'flv'
def __init__(self):
super(Flv, self).__init__(
mime=Flv.MIME,
extension=Flv.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x46 and
buf[1] == 0x4C and
buf[2] == 0x56 and
buf[3] == 0x01)
class Mpeg(Type):
"""
Implements the MPEG video type matcher.
"""
MIME = 'video/mpeg'
EXTENSION = 'mpg'
def __init__(self):
super(Mpeg, self).__init__(
mime=Mpeg.MIME,
extension=Mpeg.EXTENSION
)
def match(self, buf):
return (len(buf) > 3 and
buf[0] == 0x0 and
buf[1] == 0x0 and
buf[2] == 0x1 and
buf[3] >= 0xb0 and
buf[3] <= 0xbf)
class M3gp(Type):
"""Implements the 3gp video type matcher."""
MIME = 'video/3gpp'
EXTENSION = '3gp'
def __init__(self):
super(M3gp, self).__init__(
mime=M3gp.MIME,
extension=M3gp.EXTENSION
)
def match(self, buf):
return (len(buf) > 10 and
buf[4] == 0x66 and
buf[5] == 0x74 and
buf[6] == 0x79 and
buf[7] == 0x70 and
buf[8] == 0x33 and
buf[9] == 0x67 and
buf[10] == 0x70)
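The `Mkv` and `Webm` matchers share the EBML header magic and differ only in the DocType element they look for (element ID `0x4282`, then a one-byte size, then the doc type string). A combined sketch, assuming a hypothetical helper name `ebml_doctype`:

```python
EBML_MAGIC = b'\x1A\x45\xDF\xA3'


def ebml_doctype(buf):
    """Return 'matroska' or 'webm' for an EBML buffer, else None."""
    data = bytes(buf)
    if not data.startswith(EBML_MAGIC):
        return None
    # 0x42 0x82 is the DocType element ID; 0x88 / 0x84 encode lengths 8 / 4.
    if b'\x42\x82\x88matroska' in data:
        return 'matroska'
    if b'\x42\x82\x84webm' in data:
        return 'webm'
    return None
```

Like the matchers above, this searches the whole buffer for the DocType element rather than walking the EBML structure, so it is a heuristic and not a parser.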


@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# Python 2.7 workaround
try:
import pathlib
except ImportError:
pass
_NUM_SIGNATURE_BYTES = 8192
def get_signature_bytes(path):
"""
Reads file from disk and returns the first 8192 bytes
of data representing the magic number header signature.
Args:
path: path string to file.
Returns:
First 8192 bytes of the file content as bytearray type.
"""
with open(path, 'rb') as fp:
return bytearray(fp.read(_NUM_SIGNATURE_BYTES))
def signature(array):
"""
Returns the first 8192 bytes of the given bytearray
as part of the file header signature.
Args:
array: bytearray to extract the header signature.
Returns:
First 8192 bytes of the file content as bytearray type.
"""
length = len(array)
index = _NUM_SIGNATURE_BYTES if length > _NUM_SIGNATURE_BYTES else length
return array[:index]
def get_bytes(obj):
"""
Infers the input type and reads the first 8192 bytes,
returning a sliced bytearray.
Args:
obj: path to a readable file, file-like object (with read() method),
bytes, bytearray or memoryview
Returns:
First 8192 bytes of the file content as bytearray type.
Raises:
TypeError: if obj is not a supported type.
"""
if isinstance(obj, bytearray):
return signature(obj)
if isinstance(obj, str):
return get_signature_bytes(obj)
if isinstance(obj, bytes):
return signature(obj)
if isinstance(obj, memoryview):
return bytearray(signature(obj).tolist())
if isinstance(obj, pathlib.PurePath):
return get_signature_bytes(obj)
if hasattr(obj, 'read'):
if hasattr(obj, 'tell') and hasattr(obj, 'seek'):
start_pos = obj.tell()
obj.seek(0)
magic_bytes = obj.read(_NUM_SIGNATURE_BYTES)
obj.seek(start_pos)
return get_bytes(magic_bytes)
return get_bytes(obj.read(_NUM_SIGNATURE_BYTES))
raise TypeError('Unsupported type as file input: %s' % type(obj))
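For seekable file-like objects, `get_bytes` reads the header from offset 0 and then restores the caller's position. The sketch below isolates that step; `peek_header` is a hypothetical name, and the `try`/`finally` is an addition so the position is restored even if the read fails.

```python
import io

NUM_SIGNATURE_BYTES = 8192  # same header size as the module above


def peek_header(stream, size=NUM_SIGNATURE_BYTES):
    """Read up to `size` bytes from the start of a seekable stream,
    leaving the stream position where the caller had it."""
    start = stream.tell()
    stream.seek(0)
    try:
        return bytearray(stream.read(size))
    finally:
        stream.seek(start)


stream = io.BytesIO(b'BM' + b'\x00' * 10000)
stream.seek(5)
header = peek_header(stream)
```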


@ -0,0 +1,17 @@
"""Read resources contained within a package."""
from ._common import (
as_file,
files,
Package,
)
from .abc import ResourceReader
__all__ = [
'Package',
'ResourceReader',
'as_file',
'files',
]


@ -0,0 +1,168 @@
from contextlib import suppress
from io import TextIOWrapper
from . import abc
class SpecLoaderAdapter:
"""
Wrap a package spec so that its underlying loader can be adapted.
"""
def __init__(self, spec, adapter=lambda spec: spec.loader):
self.spec = spec
self.loader = adapter(spec)
def __getattr__(self, name):
return getattr(self.spec, name)
class TraversableResourcesLoader:
"""
Adapt a loader to provide TraversableResources.
"""
def __init__(self, spec):
self.spec = spec
def get_resource_reader(self, name):
return CompatibilityFiles(self.spec)._native()
def _io_wrapper(file, mode='r', *args, **kwargs):
if mode == 'r':
return TextIOWrapper(file, *args, **kwargs)
elif mode == 'rb':
return file
raise ValueError(f"Invalid mode value '{mode}', only 'r' and 'rb' are supported")
class CompatibilityFiles:
"""
Adapter for an existing or non-existent resource reader
to provide a compatibility .files().
"""
class SpecPath(abc.Traversable):
"""
Path tied to a module spec.
Can be read and exposes the resource reader children.
"""
def __init__(self, spec, reader):
self._spec = spec
self._reader = reader
def iterdir(self):
if not self._reader:
return iter(())
return iter(
CompatibilityFiles.ChildPath(self._reader, path)
for path in self._reader.contents()
)
def is_file(self):
return False
is_dir = is_file
def joinpath(self, other):
if not self._reader:
return CompatibilityFiles.OrphanPath(other)
return CompatibilityFiles.ChildPath(self._reader, other)
@property
def name(self):
return self._spec.name
def open(self, mode='r', *args, **kwargs):
return _io_wrapper(self._reader.open_resource(None), mode, *args, **kwargs)
class ChildPath(abc.Traversable):
"""
Path tied to a resource reader child.
Can be read but doesn't expose any meaningful children.
"""
def __init__(self, reader, name):
self._reader = reader
self._name = name
def iterdir(self):
return iter(())
def is_file(self):
return self._reader.is_resource(self.name)
def is_dir(self):
return not self.is_file()
def joinpath(self, other):
return CompatibilityFiles.OrphanPath(self.name, other)
@property
def name(self):
return self._name
def open(self, mode='r', *args, **kwargs):
return _io_wrapper(
self._reader.open_resource(self.name), mode, *args, **kwargs
)
class OrphanPath(abc.Traversable):
"""
Orphan path, not tied to a module spec or resource reader.
Can't be read and doesn't expose any meaningful children.
"""
def __init__(self, *path_parts):
if len(path_parts) < 1:
raise ValueError('Need at least one path part to construct a path')
self._path = path_parts
def iterdir(self):
return iter(())
def is_file(self):
return False
is_dir = is_file
def joinpath(self, other):
return CompatibilityFiles.OrphanPath(*self._path, other)
@property
def name(self):
return self._path[-1]
def open(self, mode='r', *args, **kwargs):
raise FileNotFoundError("Can't open orphan path")
def __init__(self, spec):
self.spec = spec
@property
def _reader(self):
with suppress(AttributeError):
return self.spec.loader.get_resource_reader(self.spec.name)
def _native(self):
"""
Return the native reader if it supports files().
"""
reader = self._reader
return reader if hasattr(reader, 'files') else self
def __getattr__(self, attr):
return getattr(self._reader, attr)
def files(self):
return CompatibilityFiles.SpecPath(self.spec, self._reader)
def wrap_spec(package):
"""
Construct a package spec with traversable compatibility
on the spec/loader/reader.
"""
return SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader)


@ -0,0 +1,209 @@
import os
import pathlib
import tempfile
import functools
import contextlib
import types
import importlib
import inspect
import warnings
import itertools
from typing import Union, Optional, cast
from .abc import ResourceReader, Traversable
from .future.adapters import wrap_spec
Package = Union[types.ModuleType, str]
Anchor = Package
def package_to_anchor(func):
"""
Replace 'package' parameter as 'anchor' and warn about the change.
Other errors should fall through.
>>> files('a', 'b')
Traceback (most recent call last):
TypeError: files() takes from 0 to 1 positional arguments but 2 were given
Remove this compatibility in Python 3.14.
"""
undefined = object()
@functools.wraps(func)
def wrapper(anchor=undefined, package=undefined):
if package is not undefined:
if anchor is not undefined:
return func(anchor, package)
warnings.warn(
"First parameter to files is renamed to 'anchor'",
DeprecationWarning,
stacklevel=2,
)
return func(package)
elif anchor is undefined:
return func()
return func(anchor)
return wrapper
@package_to_anchor
def files(anchor: Optional[Anchor] = None) -> Traversable:
"""
Get a Traversable resource for an anchor.
"""
return from_package(resolve(anchor))
def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:
"""
Return the package's loader if it's a ResourceReader.
"""
# We can't use an issubclass() check here because abc's
# __subclasscheck__() hook apparently wants to create a weak reference to
# the object, but zipimport.zipimporter does not support weak references,
# resulting in a TypeError. That seems terrible.
spec = package.__spec__
reader = getattr(spec.loader, 'get_resource_reader', None) # type: ignore
if reader is None:
return None
return reader(spec.name) # type: ignore
@functools.singledispatch
def resolve(cand: Optional[Anchor]) -> types.ModuleType:
return cast(types.ModuleType, cand)
@resolve.register
def _(cand: str) -> types.ModuleType:
return importlib.import_module(cand)
@resolve.register
def _(cand: None) -> types.ModuleType:
return resolve(_infer_caller().f_globals['__name__'])
def _infer_caller():
"""
Walk the stack and find the frame of the first caller not in this module.
"""
def is_this_file(frame_info):
return frame_info.filename == __file__
def is_wrapper(frame_info):
return frame_info.function == 'wrapper'
not_this_file = itertools.filterfalse(is_this_file, inspect.stack())
# also exclude 'wrapper' due to singledispatch in the call stack
callers = itertools.filterfalse(is_wrapper, not_this_file)
return next(callers).frame
def from_package(package: types.ModuleType):
"""
Return a Traversable object for the given package.
"""
spec = wrap_spec(package)
reader = spec.loader.get_resource_reader(spec.name)
return reader.files()
@contextlib.contextmanager
def _tempfile(
reader,
suffix='',
# gh-93353: Keep a reference to call os.remove() in late Python
# finalization.
*,
_os_remove=os.remove,
):
# Not using tempfile.NamedTemporaryFile as it leads to deeper 'try'
# blocks due to the need to close the temporary file to work on Windows
# properly.
fd, raw_path = tempfile.mkstemp(suffix=suffix)
try:
try:
os.write(fd, reader())
finally:
os.close(fd)
del reader
yield pathlib.Path(raw_path)
finally:
try:
_os_remove(raw_path)
except FileNotFoundError:
pass
def _temp_file(path):
return _tempfile(path.read_bytes, suffix=path.name)
def _is_present_dir(path: Traversable) -> bool:
"""
Some Traversables implement ``is_dir()`` to raise an
exception (i.e. ``FileNotFoundError``) when the
directory doesn't exist. This function wraps that call
to always return a boolean and only return True
if there's a dir and it exists.
"""
with contextlib.suppress(FileNotFoundError):
return path.is_dir()
return False
@functools.singledispatch
def as_file(path):
"""
Given a Traversable object, return that object as a
path on the local file system in a context manager.
"""
return _temp_dir(path) if _is_present_dir(path) else _temp_file(path)
@as_file.register(pathlib.Path)
@contextlib.contextmanager
def _(path):
"""
Degenerate behavior for pathlib.Path objects.
"""
yield path
@contextlib.contextmanager
def _temp_path(dir: tempfile.TemporaryDirectory):
"""
Wrap tempfile.TemporaryDirectory to return a pathlib object.
"""
with dir as result:
yield pathlib.Path(result)
@contextlib.contextmanager
def _temp_dir(path):
"""
Given a traversable dir, recursively replicate the whole tree
to the file system in a context manager.
"""
assert path.is_dir()
with _temp_path(tempfile.TemporaryDirectory()) as temp_dir:
yield _write_contents(temp_dir, path)
def _write_contents(target, source):
child = target.joinpath(source.name)
if source.is_dir():
child.mkdir()
for item in source.iterdir():
_write_contents(child, item)
else:
child.write_bytes(source.read_bytes())
return child
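A short usage sketch of `files()` and `as_file()` may help here. It uses the standard library's `importlib.resources` (Python 3.9+), which exposes the same two functions as this vendored backport; the `json` package is only a convenient stdlib anchor for the example.

```python
from importlib.resources import as_file, files  # stdlib, Python 3.9+

# files() returns a Traversable anchored at the package.
resource = files('json').joinpath('__init__.py')

# as_file() yields a real filesystem path; for resources that already live
# on disk it is the path itself, otherwise a temporary copy is created and
# removed when the context exits.
with as_file(resource) as path:
    on_disk = path.is_file()
    source_text = path.read_text(encoding='utf-8')
```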


@ -0,0 +1,38 @@
# from more_itertools 9.0
def only(iterable, default=None, too_long=None):
"""If *iterable* has only one item, return it.
If it has zero items, return *default*.
If it has more than one item, raise the exception given by *too_long*,
which is ``ValueError`` by default.
>>> only([], default='missing')
'missing'
>>> only([1])
1
>>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError: Expected exactly one item in iterable, but got 1, 2,
and perhaps more.
>>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TypeError
Note that :func:`only` attempts to advance *iterable* twice to ensure there
is only one item. See :func:`spy` or :func:`peekable` to check
iterable contents less destructively.
"""
it = iter(iterable)
first_value = next(it, default)
try:
second_value = next(it)
except StopIteration:
pass
else:
msg = (
'Expected exactly one item in iterable, but got {!r}, {!r}, '
'and perhaps more.'.format(first_value, second_value)
)
raise too_long or ValueError(msg)
return first_value


@ -0,0 +1,171 @@
import abc
import io
import itertools
import pathlib
from typing import Any, BinaryIO, Iterable, Iterator, NoReturn, Text, Optional
from typing import runtime_checkable, Protocol
from .compat.py38 import StrPath
__all__ = ["ResourceReader", "Traversable", "TraversableResources"]
class ResourceReader(metaclass=abc.ABCMeta):
"""Abstract base class for loaders to provide resource reading support."""
@abc.abstractmethod
def open_resource(self, resource: Text) -> BinaryIO:
"""Return an opened, file-like object for binary reading.
The 'resource' argument is expected to represent only a file name.
If the resource cannot be found, FileNotFoundError is raised.
"""
# This deliberately raises FileNotFoundError instead of
# NotImplementedError so that if this method is accidentally called,
# it'll still do the right thing.
raise FileNotFoundError
@abc.abstractmethod
def resource_path(self, resource: Text) -> Text:
"""Return the file system path to the specified resource.
The 'resource' argument is expected to represent only a file name.
If the resource does not exist on the file system, raise
FileNotFoundError.
"""
# This deliberately raises FileNotFoundError instead of
# NotImplementedError so that if this method is accidentally called,
# it'll still do the right thing.
raise FileNotFoundError
@abc.abstractmethod
def is_resource(self, path: Text) -> bool:
"""Return True if the named 'path' is a resource.
Files are resources, directories are not.
"""
raise FileNotFoundError
@abc.abstractmethod
def contents(self) -> Iterable[str]:
"""Return an iterable of entries in `package`."""
raise FileNotFoundError
class TraversalError(Exception):
pass
@runtime_checkable
class Traversable(Protocol):
"""
An object with a subset of pathlib.Path methods suitable for
traversing directories and opening files.
Any exceptions that occur when accessing the backing resource
may propagate unaltered.
"""
@abc.abstractmethod
def iterdir(self) -> Iterator["Traversable"]:
"""
Yield Traversable objects in self
"""
def read_bytes(self) -> bytes:
"""
Read contents of self as bytes
"""
with self.open('rb') as strm:
return strm.read()
def read_text(self, encoding: Optional[str] = None) -> str:
"""
Read contents of self as text
"""
with self.open(encoding=encoding) as strm:
return strm.read()
@abc.abstractmethod
def is_dir(self) -> bool:
"""
Return True if self is a directory
"""
@abc.abstractmethod
def is_file(self) -> bool:
"""
Return True if self is a file
"""
def joinpath(self, *descendants: StrPath) -> "Traversable":
"""
Return Traversable resolved with any descendants applied.
Each descendant should be a path segment relative to self
and each may contain multiple levels separated by
``posixpath.sep`` (``/``).
"""
if not descendants:
return self
names = itertools.chain.from_iterable(
path.parts for path in map(pathlib.PurePosixPath, descendants)
)
target = next(names)
matches = (
traversable for traversable in self.iterdir() if traversable.name == target
)
try:
match = next(matches)
except StopIteration:
raise TraversalError(
"Target not found during traversal.", target, list(names)
)
return match.joinpath(*names)
def __truediv__(self, child: StrPath) -> "Traversable":
"""
Return Traversable child in self
"""
return self.joinpath(child)
@abc.abstractmethod
def open(self, mode='r', *args, **kwargs):
"""
mode may be 'r' or 'rb' to open as text or binary. Return a handle
suitable for reading (same as pathlib.Path.open).
When opening as text, accepts encoding parameters such as those
accepted by io.TextIOWrapper.
"""
@property
@abc.abstractmethod
def name(self) -> str:
"""
The base name of this object without any parent references.
"""
class TraversableResources(ResourceReader):
"""
The required interface for providing traversable
resources.
"""
@abc.abstractmethod
def files(self) -> "Traversable":
"""Return a Traversable object for the loaded package."""
def open_resource(self, resource: StrPath) -> io.BufferedReader:
return self.files().joinpath(resource).open('rb')
def resource_path(self, resource: Any) -> NoReturn:
raise FileNotFoundError(resource)
def is_resource(self, path: StrPath) -> bool:
return self.files().joinpath(path).is_file()
def contents(self) -> Iterator[str]:
return (item.name for item in self.files().iterdir())
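`Traversable.joinpath` flattens its arguments into individual path segments and then descends one segment at a time. The sketch below mirrors that logic on a nested dict, which is a hypothetical stand-in for a real Traversable; `split_descendants` and `walk` are illustrative names only.

```python
import itertools
import pathlib


def split_descendants(*descendants):
    """Flatten descendants such as 'a/b', 'c' into ['a', 'b', 'c']."""
    return list(itertools.chain.from_iterable(
        pathlib.PurePosixPath(descendant).parts for descendant in descendants
    ))


def walk(tree, *descendants):
    """Descend a nested dict one path segment at a time."""
    node = tree
    for part in split_descendants(*descendants):
        if not isinstance(node, dict) or part not in node:
            raise LookupError('Target not found during traversal: %r' % part)
        node = node[part]
    return node


tree = {'pkg': {'data': {'a.txt': b'one resource'}}}
```

Splitting with `PurePosixPath` means a single argument may contain several levels separated by `/`, which matches the docstring of `joinpath` above.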


@ -0,0 +1,11 @@
import os
import sys
from typing import Union
if sys.version_info >= (3, 9):
StrPath = Union[str, os.PathLike[str]]
else:
# PathLike is only subscriptable at runtime in 3.9+
StrPath = Union[str, "os.PathLike[str]"]


@ -0,0 +1,10 @@
import sys
__all__ = ['ZipPath']
if sys.version_info >= (3, 10):
from zipfile import Path as ZipPath # type: ignore
else:
from zipp import Path as ZipPath # type: ignore


@ -0,0 +1,46 @@
import pathlib
from contextlib import suppress
from types import SimpleNamespace
from .. import readers, _adapters
class TraversableResourcesLoader(_adapters.TraversableResourcesLoader):
"""
Adapt loaders to provide TraversableResources and other
compatibility.
Ensures the readers from importlib_resources are preferred
over stdlib readers.
"""
def get_resource_reader(self, name):
return self._standard_reader() or super().get_resource_reader(name)
def _standard_reader(self):
return self._zip_reader() or self._namespace_reader() or self._file_reader()
def _zip_reader(self):
with suppress(AttributeError):
return readers.ZipReader(self.spec.loader, self.spec.name)
def _namespace_reader(self):
with suppress(AttributeError, ValueError):
return readers.NamespaceReader(self.spec.submodule_search_locations)
def _file_reader(self):
try:
path = pathlib.Path(self.spec.origin)
except TypeError:
return None
if path.exists():
return readers.FileReader(SimpleNamespace(path=path))
def wrap_spec(package):
"""
Override _adapters.wrap_spec to use TraversableResourcesLoader
from above. Ensures that future behavior is always available on older
Pythons.
"""
return _adapters.SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader)


@ -0,0 +1,194 @@
import collections
import contextlib
import itertools
import pathlib
import operator
import re
import warnings
from . import abc
from ._itertools import only
from .compat.py39 import ZipPath
def remove_duplicates(items):
return iter(collections.OrderedDict.fromkeys(items))
class FileReader(abc.TraversableResources):
def __init__(self, loader):
self.path = pathlib.Path(loader.path).parent
def resource_path(self, resource):
"""
Return the file system path to prevent
`resources.path()` from creating a temporary
copy.
"""
return str(self.path.joinpath(resource))
def files(self):
return self.path
class ZipReader(abc.TraversableResources):
def __init__(self, loader, module):
_, _, name = module.rpartition('.')
self.prefix = loader.prefix.replace('\\', '/') + name + '/'
self.archive = loader.archive
def open_resource(self, resource):
try:
return super().open_resource(resource)
except KeyError as exc:
raise FileNotFoundError(exc.args[0])
def is_resource(self, path):
"""
Workaround for `zipfile.Path.is_file` returning true
for non-existent paths.
"""
target = self.files().joinpath(path)
return target.is_file() and target.exists()
def files(self):
return ZipPath(self.archive, self.prefix)
class MultiplexedPath(abc.Traversable):
"""
Given a series of Traversable objects, implement a merged
version of the interface across all objects. Useful for
namespace packages which may be multihomed at a single
name.
"""
def __init__(self, *paths):
self._paths = list(map(_ensure_traversable, remove_duplicates(paths)))
if not self._paths:
message = 'MultiplexedPath must contain at least one path'
raise FileNotFoundError(message)
if not all(path.is_dir() for path in self._paths):
raise NotADirectoryError('MultiplexedPath only supports directories')
def iterdir(self):
children = (child for path in self._paths for child in path.iterdir())
by_name = operator.attrgetter('name')
groups = itertools.groupby(sorted(children, key=by_name), key=by_name)
return map(self._follow, (locs for name, locs in groups))
def read_bytes(self):
raise FileNotFoundError(f'{self} is not a file')
def read_text(self, *args, **kwargs):
raise FileNotFoundError(f'{self} is not a file')
def is_dir(self):
return True
def is_file(self):
return False
def joinpath(self, *descendants):
try:
return super().joinpath(*descendants)
except abc.TraversalError:
# One of the paths did not resolve (a directory does not exist).
# Just return something that will not exist.
return self._paths[0].joinpath(*descendants)
@classmethod
def _follow(cls, children):
"""
Construct a MultiplexedPath if needed.
If children contains a sole element, return it.
Otherwise, return a MultiplexedPath of the items.
Unless one of the items is not a Directory, then return the first.
"""
subdirs, one_dir, one_file = itertools.tee(children, 3)
try:
return only(one_dir)
except ValueError:
try:
return cls(*subdirs)
except NotADirectoryError:
return next(one_file)
def open(self, *args, **kwargs):
raise FileNotFoundError(f'{self} is not a file')
@property
def name(self):
return self._paths[0].name
def __repr__(self):
paths = ', '.join(f"'{path}'" for path in self._paths)
return f'MultiplexedPath({paths})'
class NamespaceReader(abc.TraversableResources):
def __init__(self, namespace_path):
if 'NamespacePath' not in str(namespace_path):
raise ValueError('Invalid path')
self.path = MultiplexedPath(*map(self._resolve, namespace_path))
@classmethod
def _resolve(cls, path_str) -> abc.Traversable:
r"""
Given an item from a namespace path, resolve it to a Traversable.
path_str might be a directory on the filesystem or a path to a
zipfile plus the path within the zipfile, e.g. ``/foo/bar`` or
``/foo/baz.zip/inner_dir`` or ``foo\baz.zip\inner_dir\sub``.
"""
(dir,) = (cand for cand in cls._candidate_paths(path_str) if cand.is_dir())
return dir
@classmethod
def _candidate_paths(cls, path_str):
yield pathlib.Path(path_str)
yield from cls._resolve_zip_path(path_str)
@staticmethod
def _resolve_zip_path(path_str):
for match in reversed(list(re.finditer(r'[\\/]', path_str))):
with contextlib.suppress(
FileNotFoundError,
IsADirectoryError,
NotADirectoryError,
PermissionError,
):
inner = path_str[match.end() :].replace('\\', '/') + '/'
yield ZipPath(path_str[: match.start()], inner.lstrip('/'))
def resource_path(self, resource):
"""
Return the file system path to prevent
`resources.path()` from creating a temporary
copy.
"""
return str(self.path.joinpath(resource))
def files(self):
return self.path
def _ensure_traversable(path):
"""
Convert deprecated string arguments to traversables (pathlib.Path).
Remove with Python 3.15.
"""
if not isinstance(path, str):
return path
warnings.warn(
"String arguments are deprecated. Pass a Traversable instead.",
DeprecationWarning,
stacklevel=3,
)
return pathlib.Path(path)
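`MultiplexedPath.iterdir` merges the listings of several directories by sorting all children by name and grouping equal names together. A reduced sketch of that sort-then-group step, where the `Entry` tuple and `merge_listings` are hypothetical stand-ins for Traversable children:

```python
import itertools
import operator
from collections import namedtuple

# Hypothetical child record: a name plus the root directory it came from.
Entry = namedtuple('Entry', 'name root')


def merge_listings(*listings):
    """Group entries from several listings by name, keeping their roots."""
    entries = [entry for listing in listings for entry in listing]
    by_name = operator.attrgetter('name')
    # groupby only groups adjacent items, so the entries are sorted first.
    groups = itertools.groupby(sorted(entries, key=by_name), key=by_name)
    return {name: [entry.root for entry in group] for name, group in groups}


merged = merge_listings(
    [Entry('a.txt', 'site1'), Entry('sub', 'site1')],
    [Entry('b.txt', 'site2'), Entry('sub', 'site2')],
)
```

A name that appears under more than one root (here `sub`) is the case where `_follow` would build a nested `MultiplexedPath`.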


@ -0,0 +1,106 @@
"""
Interface adapters for low-level readers.
"""
import abc
import io
import itertools
from typing import BinaryIO, List
from .abc import Traversable, TraversableResources
class SimpleReader(abc.ABC):
"""
The minimum, low-level interface required from a resource
provider.
"""
@property
@abc.abstractmethod
def package(self) -> str:
"""
The name of the package for which this reader loads resources.
"""
@abc.abstractmethod
def children(self) -> List['SimpleReader']:
"""
Obtain an iterable of SimpleReader for available
child containers (e.g. directories).
"""
@abc.abstractmethod
def resources(self) -> List[str]:
"""
Obtain available named resources for this virtual package.
"""
@abc.abstractmethod
def open_binary(self, resource: str) -> BinaryIO:
"""
Obtain a File-like for a named resource.
"""
@property
def name(self):
return self.package.split('.')[-1]
class ResourceContainer(Traversable):
"""
Traversable container for a package's resources via its reader.
"""
def __init__(self, reader: SimpleReader):
self.reader = reader
def is_dir(self):
return True
def is_file(self):
return False
def iterdir(self):
files = (ResourceHandle(self, name) for name in self.reader.resources())
dirs = map(ResourceContainer, self.reader.children())
return itertools.chain(files, dirs)
def open(self, *args, **kwargs):
raise IsADirectoryError()
class ResourceHandle(Traversable):
"""
Handle to a named resource in a ResourceReader.
"""
def __init__(self, parent: ResourceContainer, name: str):
self.parent = parent
self.name = name # type: ignore
def is_file(self):
return True
def is_dir(self):
return False
def open(self, mode='r', *args, **kwargs):
stream = self.parent.reader.open_binary(self.name)
if 'b' not in mode:
stream = io.TextIOWrapper(stream, *args, **kwargs)
return stream
def joinpath(self, name):
raise RuntimeError("Cannot traverse into a resource")
class TraversableReader(TraversableResources, SimpleReader):
"""
A TraversableResources based on SimpleReader. Resource providers
may derive from this class to provide the TraversableResources
interface by supplying the SimpleReader interface.
"""
def files(self):
return ResourceContainer(self)


@ -0,0 +1,32 @@
import os
try:
from test.support import import_helper # type: ignore
except ImportError:
# Python 3.9 and earlier
class import_helper: # type: ignore
from test.support import (
modules_setup,
modules_cleanup,
DirsOnSysPath,
CleanImport,
)
try:
from test.support import os_helper # type: ignore
except ImportError:
# Python 3.9 compat
class os_helper: # type:ignore
from test.support import temp_dir
try:
# Python 3.10
from test.support.os_helper import unlink
except ImportError:
from test.support import unlink as _unlink
def unlink(target):
return _unlink(os.fspath(target))


@ -0,0 +1,56 @@
import pathlib
import functools

from typing import Dict, Union


####
# from jaraco.path 3.4.1

FilesSpec = Dict[str, Union[str, bytes, 'FilesSpec']]  # type: ignore


def build(spec: FilesSpec, prefix=pathlib.Path()):
    """
    Build a set of files/directories, as described by the spec.

    Each key represents a pathname, and the value represents
    the content. Content may be a nested directory.

    >>> spec = {
    ...     'README.txt': "A README file",
    ...     "foo": {
    ...         "__init__.py": "",
    ...         "bar": {
    ...             "__init__.py": "",
    ...         },
    ...         "baz.py": "# Some code",
    ...     }
    ... }
    >>> target = getfixture('tmp_path')
    >>> build(spec, target)
    >>> target.joinpath('foo/baz.py').read_text(encoding='utf-8')
    '# Some code'
    """
    for name, contents in spec.items():
        create(contents, pathlib.Path(prefix) / name)


@functools.singledispatch
def create(content: Union[str, bytes, FilesSpec], path):
    path.mkdir(exist_ok=True)
    build(content, prefix=path)  # type: ignore


@create.register
def _(content: bytes, path):
    path.write_bytes(content)


@create.register
def _(content: str, path):
    path.write_text(content, encoding='utf-8')


# end from jaraco.path
####
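
The `create` function above relies on `functools.singledispatch` to pick a handler from the type of the content: `str` and `bytes` get their own handlers, and anything else falls through to the directory case. The following is a minimal sketch of that dispatch mechanism only; the function name `kind` and its return strings are hypothetical and are not part of jaraco.path.

```python
import functools


@functools.singledispatch
def kind(content):
    # Fallback handler: any type without a registered handler lands here,
    # which is how a nested dict ends up treated as a directory.
    return 'directory'


@kind.register
def _(content: bytes):
    # Chosen when the argument is a bytes instance.
    return 'binary file'


@kind.register
def _(content: str):
    # Chosen when the argument is a str instance.
    return 'text file'
```

Registering by type annotation keeps each case in its own small function, so supporting another content type would only need one more `register` call.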

@ -0,0 +1 @@
Hello, UTF-8 world!

@ -0,0 +1 @@
one resource

@ -0,0 +1 @@
two resource

@ -0,0 +1 @@
Hello, UTF-8 world!

@ -0,0 +1,104 @@
import io
import unittest
import importlib_resources as resources
from importlib_resources._adapters import (
CompatibilityFiles,
wrap_spec,
)
from . import util
class CompatibilityFilesTests(unittest.TestCase):
@property
def package(self):
bytes_data = io.BytesIO(b'Hello, world!')
return util.create_package(
file=bytes_data,
path='some_path',
contents=('a', 'b', 'c'),
)
@property
def files(self):
return resources.files(self.package)
def test_spec_path_iter(self):
self.assertEqual(
sorted(path.name for path in self.files.iterdir()),
['a', 'b', 'c'],
)
def test_child_path_iter(self):
self.assertEqual(list((self.files / 'a').iterdir()), [])
def test_orphan_path_iter(self):
self.assertEqual(list((self.files / 'a' / 'a').iterdir()), [])
self.assertEqual(list((self.files / 'a' / 'a' / 'a').iterdir()), [])
def test_spec_path_is(self):
self.assertFalse(self.files.is_file())
self.assertFalse(self.files.is_dir())
def test_child_path_is(self):
self.assertTrue((self.files / 'a').is_file())
self.assertFalse((self.files / 'a').is_dir())
def test_orphan_path_is(self):
self.assertFalse((self.files / 'a' / 'a').is_file())
self.assertFalse((self.files / 'a' / 'a').is_dir())
self.assertFalse((self.files / 'a' / 'a' / 'a').is_file())
self.assertFalse((self.files / 'a' / 'a' / 'a').is_dir())
def test_spec_path_name(self):
self.assertEqual(self.files.name, 'testingpackage')
def test_child_path_name(self):
self.assertEqual((self.files / 'a').name, 'a')
def test_orphan_path_name(self):
self.assertEqual((self.files / 'a' / 'b').name, 'b')
self.assertEqual((self.files / 'a' / 'b' / 'c').name, 'c')
def test_spec_path_open(self):
self.assertEqual(self.files.read_bytes(), b'Hello, world!')
self.assertEqual(self.files.read_text(encoding='utf-8'), 'Hello, world!')
def test_child_path_open(self):
self.assertEqual((self.files / 'a').read_bytes(), b'Hello, world!')
self.assertEqual(
(self.files / 'a').read_text(encoding='utf-8'), 'Hello, world!'
)
def test_orphan_path_open(self):
with self.assertRaises(FileNotFoundError):
(self.files / 'a' / 'b').read_bytes()
with self.assertRaises(FileNotFoundError):
(self.files / 'a' / 'b' / 'c').read_bytes()
def test_open_invalid_mode(self):
with self.assertRaises(ValueError):
self.files.open('0')
def test_orphan_path_invalid(self):
with self.assertRaises(ValueError):
CompatibilityFiles.OrphanPath()
def test_wrap_spec(self):
spec = wrap_spec(self.package)
self.assertIsInstance(spec.loader.get_resource_reader(None), CompatibilityFiles)
class CompatibilityFilesNoReaderTests(unittest.TestCase):
@property
def package(self):
return util.create_package_from_loader(None)
@property
def files(self):
return resources.files(self.package)
def test_spec_path_joinpath(self):
self.assertIsInstance(self.files / 'a', CompatibilityFiles.OrphanPath)

@ -0,0 +1,43 @@
import unittest
import importlib_resources as resources
from . import data01
from . import util
class ContentsTests:
expected = {
'__init__.py',
'binary.file',
'subdirectory',
'utf-16.file',
'utf-8.file',
}
def test_contents(self):
contents = {path.name for path in resources.files(self.data).iterdir()}
assert self.expected <= contents
class ContentsDiskTests(ContentsTests, unittest.TestCase):
def setUp(self):
self.data = data01
class ContentsZipTests(ContentsTests, util.ZipSetup, unittest.TestCase):
pass
class ContentsNamespaceTests(ContentsTests, unittest.TestCase):
expected = {
# no __init__ because of namespace design
'binary.file',
'subdirectory',
'utf-16.file',
'utf-8.file',
}
def setUp(self):
from . import namespacedata01
self.data = namespacedata01

@ -0,0 +1,47 @@
import unittest
import contextlib
import pathlib
import importlib_resources as resources
from .. import abc
from ..abc import TraversableResources, ResourceReader
from . import util
from ._compat import os_helper
class SimpleLoader:
"""
A simple loader that only implements a resource reader.
"""
def __init__(self, reader: ResourceReader):
self.reader = reader
def get_resource_reader(self, package):
return self.reader
class MagicResources(TraversableResources):
"""
Magically returns the resources at path.
"""
def __init__(self, path: pathlib.Path):
self.path = path
def files(self):
return self.path
class CustomTraversableResourcesTests(unittest.TestCase):
def setUp(self):
self.fixtures = contextlib.ExitStack()
self.addCleanup(self.fixtures.close)
def test_custom_loader(self):
temp_dir = pathlib.Path(self.fixtures.enter_context(os_helper.temp_dir()))
loader = SimpleLoader(MagicResources(temp_dir))
pkg = util.create_package_from_loader(loader)
files = resources.files(pkg)
assert isinstance(files, abc.Traversable)
assert list(files.iterdir()) == []

@ -0,0 +1,111 @@
import textwrap
import unittest
import warnings
import importlib
import contextlib
import importlib_resources as resources
from ..abc import Traversable
from . import data01
from . import util
from . import _path
from ._compat import os_helper, import_helper
@contextlib.contextmanager
def suppress_known_deprecation():
with warnings.catch_warnings(record=True) as ctx:
warnings.simplefilter('default', category=DeprecationWarning)
yield ctx
class FilesTests:
def test_read_bytes(self):
files = resources.files(self.data)
actual = files.joinpath('utf-8.file').read_bytes()
assert actual == b'Hello, UTF-8 world!\n'
def test_read_text(self):
files = resources.files(self.data)
actual = files.joinpath('utf-8.file').read_text(encoding='utf-8')
assert actual == 'Hello, UTF-8 world!\n'
def test_traversable(self):
assert isinstance(resources.files(self.data), Traversable)
def test_old_parameter(self):
"""
Files used to take a 'package' parameter. Make sure anyone
passing by name is still supported.
"""
with suppress_known_deprecation():
resources.files(package=self.data)
class OpenDiskTests(FilesTests, unittest.TestCase):
def setUp(self):
self.data = data01
class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase):
pass
class OpenNamespaceTests(FilesTests, unittest.TestCase):
def setUp(self):
from . import namespacedata01
self.data = namespacedata01
class OpenNamespaceZipTests(FilesTests, util.ZipSetup, unittest.TestCase):
ZIP_MODULE = 'namespacedata01'
class SiteDir:
def setUp(self):
self.fixtures = contextlib.ExitStack()
self.addCleanup(self.fixtures.close)
self.site_dir = self.fixtures.enter_context(os_helper.temp_dir())
self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir))
self.fixtures.enter_context(import_helper.CleanImport())
class ModulesFilesTests(SiteDir, unittest.TestCase):
def test_module_resources(self):
"""
A module can have resources found adjacent to the module.
"""
spec = {
'mod.py': '',
'res.txt': 'resources are the best',
}
_path.build(spec, self.site_dir)
import mod
actual = resources.files(mod).joinpath('res.txt').read_text(encoding='utf-8')
assert actual == spec['res.txt']
class ImplicitContextFilesTests(SiteDir, unittest.TestCase):
def test_implicit_files(self):
"""
Without any parameter, files() will infer the location as the caller.
"""
spec = {
'somepkg': {
'__init__.py': textwrap.dedent(
"""
import importlib_resources as res
val = res.files().joinpath('res.txt').read_text(encoding='utf-8')
"""
),
'res.txt': 'resources are the best',
},
}
_path.build(spec, self.site_dir)
assert importlib.import_module('somepkg').val == 'resources are the best'
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,89 @@
import unittest
import importlib_resources as resources
from . import data01
from . import util
class CommonBinaryTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
target = resources.files(package).joinpath(path)
with target.open('rb'):
pass
class CommonTextTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
target = resources.files(package).joinpath(path)
with target.open(encoding='utf-8'):
pass
class OpenTests:
def test_open_binary(self):
target = resources.files(self.data) / 'binary.file'
with target.open('rb') as fp:
result = fp.read()
self.assertEqual(result, bytes(range(4)))
def test_open_text_default_encoding(self):
target = resources.files(self.data) / 'utf-8.file'
with target.open(encoding='utf-8') as fp:
result = fp.read()
self.assertEqual(result, 'Hello, UTF-8 world!\n')
def test_open_text_given_encoding(self):
target = resources.files(self.data) / 'utf-16.file'
with target.open(encoding='utf-16', errors='strict') as fp:
result = fp.read()
self.assertEqual(result, 'Hello, UTF-16 world!\n')
def test_open_text_with_errors(self):
"""
Raises UnicodeError without the 'errors' argument.
"""
target = resources.files(self.data) / 'utf-16.file'
with target.open(encoding='utf-8', errors='strict') as fp:
self.assertRaises(UnicodeError, fp.read)
with target.open(encoding='utf-8', errors='ignore') as fp:
result = fp.read()
self.assertEqual(
result,
'H\x00e\x00l\x00l\x00o\x00,\x00 '
'\x00U\x00T\x00F\x00-\x001\x006\x00 '
'\x00w\x00o\x00r\x00l\x00d\x00!\x00\n\x00',
)
def test_open_binary_FileNotFoundError(self):
target = resources.files(self.data) / 'does-not-exist'
with self.assertRaises(FileNotFoundError):
target.open('rb')
def test_open_text_FileNotFoundError(self):
target = resources.files(self.data) / 'does-not-exist'
with self.assertRaises(FileNotFoundError):
target.open(encoding='utf-8')
class OpenDiskTests(OpenTests, unittest.TestCase):
def setUp(self):
self.data = data01
class OpenDiskNamespaceTests(OpenTests, unittest.TestCase):
def setUp(self):
from . import namespacedata01
self.data = namespacedata01
class OpenZipTests(OpenTests, util.ZipSetup, unittest.TestCase):
pass
class OpenNamespaceZipTests(OpenTests, util.ZipSetup, unittest.TestCase):
ZIP_MODULE = 'namespacedata01'
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,65 @@
import io
import pathlib
import unittest
import importlib_resources as resources
from . import data01
from . import util
class CommonTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
with resources.as_file(resources.files(package).joinpath(path)):
pass
class PathTests:
def test_reading(self):
"""
Path should be readable and a pathlib.Path instance.
"""
target = resources.files(self.data) / 'utf-8.file'
with resources.as_file(target) as path:
self.assertIsInstance(path, pathlib.Path)
self.assertTrue(path.name.endswith("utf-8.file"), repr(path))
self.assertEqual('Hello, UTF-8 world!\n', path.read_text(encoding='utf-8'))
class PathDiskTests(PathTests, unittest.TestCase):
data = data01
def test_natural_path(self):
"""
Guarantee the internal implementation detail that
file-system-backed resources do not get the tempdir
treatment.
"""
target = resources.files(self.data) / 'utf-8.file'
with resources.as_file(target) as path:
assert 'data' in str(path)
class PathMemoryTests(PathTests, unittest.TestCase):
def setUp(self):
file = io.BytesIO(b'Hello, UTF-8 world!\n')
self.addCleanup(file.close)
self.data = util.create_package(
file=file, path=FileNotFoundError("package exists only in memory")
)
self.data.__spec__.origin = None
self.data.__spec__.has_location = False
class PathZipTests(PathTests, util.ZipSetup, unittest.TestCase):
def test_remove_in_context_manager(self):
"""
It is not an error if the file that was temporarily stashed on the
file system is removed inside the `with` stanza.
"""
target = resources.files(self.data) / 'utf-8.file'
with resources.as_file(target) as path:
path.unlink()
if __name__ == '__main__':
unittest.main()
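
The path tests above exercise the `files()` plus `as_file()` pattern: locate a resource as a traversable, then ask for a real file system path for the duration of a `with` block. Below is a minimal sketch of that pattern. It assumes the standard library `importlib.resources` (Python 3.9 or later) rather than the `importlib_resources` backport used in these tests, and the helper name `resource_on_disk` is hypothetical.

```python
import importlib.resources as resources
import pathlib


def resource_on_disk(package, name):
    """Report what as_file() yields for one resource of a package."""
    target = resources.files(package).joinpath(name)
    with resources.as_file(target) as path:
        # For file-system-backed packages this is the real file; for other
        # loaders it may be a temporary copy whose name ends with `name`.
        return (
            isinstance(path, pathlib.Path),
            path.name.endswith(name),
            path.is_file(),
        )
```

Checking the name with `endswith` mirrors `test_reading`, since a temporary copy is not guaranteed to carry exactly the original file name.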

@ -0,0 +1,97 @@
import unittest
import importlib_resources as resources
from . import data01
from . import util
from importlib import import_module
class CommonBinaryTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
resources.files(package).joinpath(path).read_bytes()
class CommonTextTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
resources.files(package).joinpath(path).read_text(encoding='utf-8')
class ReadTests:
def test_read_bytes(self):
result = resources.files(self.data).joinpath('binary.file').read_bytes()
self.assertEqual(result, bytes(range(4)))
def test_read_text_default_encoding(self):
result = (
resources.files(self.data)
.joinpath('utf-8.file')
.read_text(encoding='utf-8')
)
self.assertEqual(result, 'Hello, UTF-8 world!\n')
def test_read_text_given_encoding(self):
result = (
resources.files(self.data)
.joinpath('utf-16.file')
.read_text(encoding='utf-16')
)
self.assertEqual(result, 'Hello, UTF-16 world!\n')
def test_read_text_with_errors(self):
"""
Raises UnicodeError without the 'errors' argument.
"""
target = resources.files(self.data) / 'utf-16.file'
self.assertRaises(UnicodeError, target.read_text, encoding='utf-8')
result = target.read_text(encoding='utf-8', errors='ignore')
self.assertEqual(
result,
'H\x00e\x00l\x00l\x00o\x00,\x00 '
'\x00U\x00T\x00F\x00-\x001\x006\x00 '
'\x00w\x00o\x00r\x00l\x00d\x00!\x00\n\x00',
)
class ReadDiskTests(ReadTests, unittest.TestCase):
data = data01
class ReadZipTests(ReadTests, util.ZipSetup, unittest.TestCase):
def test_read_submodule_resource(self):
submodule = import_module('data01.subdirectory')
result = resources.files(submodule).joinpath('binary.file').read_bytes()
self.assertEqual(result, bytes(range(4, 8)))
def test_read_submodule_resource_by_name(self):
result = (
resources.files('data01.subdirectory').joinpath('binary.file').read_bytes()
)
self.assertEqual(result, bytes(range(4, 8)))
class ReadNamespaceTests(ReadTests, unittest.TestCase):
def setUp(self):
from . import namespacedata01
self.data = namespacedata01
class ReadNamespaceZipTests(ReadTests, util.ZipSetup, unittest.TestCase):
ZIP_MODULE = 'namespacedata01'
def test_read_submodule_resource(self):
submodule = import_module('namespacedata01.subdirectory')
result = resources.files(submodule).joinpath('binary.file').read_bytes()
self.assertEqual(result, bytes(range(12, 16)))
def test_read_submodule_resource_by_name(self):
result = (
resources.files('namespacedata01.subdirectory')
.joinpath('binary.file')
.read_bytes()
)
self.assertEqual(result, bytes(range(12, 16)))
if __name__ == '__main__':
unittest.main()
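
The `test_read_text_with_errors` case above depends on how UTF-16 bytes behave when decoded as UTF-8: strict decoding fails, while `errors='ignore'` drops the invalid bytes and keeps the interleaved NUL characters. The sketch below reproduces that behavior on an in-memory byte string. The helper name `decode_both_ways` is hypothetical, and the little-endian byte order mark is written out explicitly as an assumption about how the fixture file was encoded.

```python
BOM_LE = b'\xff\xfe'  # UTF-16 little-endian byte order mark


def decode_both_ways(text):
    """Encode text as UTF-16-LE with a BOM, then decode it as UTF-8."""
    raw = BOM_LE + text.encode('utf-16-le')
    try:
        raw.decode('utf-8')
        strict_failed = False
    except UnicodeError:
        # The BOM bytes are not valid UTF-8, so strict decoding raises.
        strict_failed = True
    # With errors='ignore' the two BOM bytes are dropped; each ASCII
    # character survives, followed by its NUL high byte.
    return strict_failed, raw.decode('utf-8', errors='ignore')
```

For example, `decode_both_ways('Hi\n')` returns `(True, 'H\x00i\x00\n\x00')`, the same shape as the expected string in the test.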

@ -0,0 +1,145 @@
import os.path
import sys
import pathlib
import unittest
from importlib import import_module
from importlib_resources.readers import MultiplexedPath, NamespaceReader
class MultiplexedPathTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.folder = pathlib.Path(__file__).parent / 'namespacedata01'
def test_init_no_paths(self):
with self.assertRaises(FileNotFoundError):
MultiplexedPath()
def test_init_file(self):
with self.assertRaises(NotADirectoryError):
MultiplexedPath(self.folder / 'binary.file')
def test_iterdir(self):
contents = {path.name for path in MultiplexedPath(self.folder).iterdir()}
try:
contents.remove('__pycache__')
except (KeyError, ValueError):
pass
self.assertEqual(
contents, {'subdirectory', 'binary.file', 'utf-16.file', 'utf-8.file'}
)
def test_iterdir_duplicate(self):
data01 = pathlib.Path(__file__).parent.joinpath('data01')
contents = {
path.name for path in MultiplexedPath(self.folder, data01).iterdir()
}
for remove in ('__pycache__', '__init__.pyc'):
try:
contents.remove(remove)
except (KeyError, ValueError):
pass
self.assertEqual(
contents,
{'__init__.py', 'binary.file', 'subdirectory', 'utf-16.file', 'utf-8.file'},
)
def test_is_dir(self):
self.assertEqual(MultiplexedPath(self.folder).is_dir(), True)
def test_is_file(self):
self.assertEqual(MultiplexedPath(self.folder).is_file(), False)
def test_open_file(self):
path = MultiplexedPath(self.folder)
with self.assertRaises(FileNotFoundError):
path.read_bytes()
with self.assertRaises(FileNotFoundError):
path.read_text()
with self.assertRaises(FileNotFoundError):
path.open()
def test_join_path(self):
data01 = pathlib.Path(__file__).parent.joinpath('data01')
prefix = str(data01.parent)
path = MultiplexedPath(self.folder, data01)
self.assertEqual(
str(path.joinpath('binary.file'))[len(prefix) + 1 :],
os.path.join('namespacedata01', 'binary.file'),
)
sub = path.joinpath('subdirectory')
assert isinstance(sub, MultiplexedPath)
assert 'namespacedata01' in str(sub)
assert 'data01' in str(sub)
self.assertEqual(
str(path.joinpath('imaginary'))[len(prefix) + 1 :],
os.path.join('namespacedata01', 'imaginary'),
)
self.assertEqual(path.joinpath(), path)
def test_join_path_compound(self):
path = MultiplexedPath(self.folder)
assert not path.joinpath('imaginary/foo.py').exists()
def test_join_path_common_subdir(self):
data01 = pathlib.Path(__file__).parent.joinpath('data01')
data02 = pathlib.Path(__file__).parent.joinpath('data02')
prefix = str(data01.parent)
path = MultiplexedPath(data01, data02)
self.assertIsInstance(path.joinpath('subdirectory'), MultiplexedPath)
self.assertEqual(
str(path.joinpath('subdirectory', 'subsubdir'))[len(prefix) + 1 :],
os.path.join('data02', 'subdirectory', 'subsubdir'),
)
def test_repr(self):
self.assertEqual(
repr(MultiplexedPath(self.folder)),
f"MultiplexedPath('{self.folder}')",
)
def test_name(self):
self.assertEqual(
MultiplexedPath(self.folder).name,
os.path.basename(self.folder),
)
class NamespaceReaderTest(unittest.TestCase):
site_dir = str(pathlib.Path(__file__).parent)
@classmethod
def setUpClass(cls):
sys.path.append(cls.site_dir)
@classmethod
def tearDownClass(cls):
sys.path.remove(cls.site_dir)
def test_init_error(self):
with self.assertRaises(ValueError):
NamespaceReader(['path1', 'path2'])
def test_resource_path(self):
namespacedata01 = import_module('namespacedata01')
reader = NamespaceReader(namespacedata01.__spec__.submodule_search_locations)
root = os.path.abspath(os.path.join(__file__, '..', 'namespacedata01'))
self.assertEqual(
reader.resource_path('binary.file'), os.path.join(root, 'binary.file')
)
self.assertEqual(
reader.resource_path('imaginary'), os.path.join(root, 'imaginary')
)
def test_files(self):
namespacedata01 = import_module('namespacedata01')
reader = NamespaceReader(namespacedata01.__spec__.submodule_search_locations)
root = os.path.abspath(os.path.join(__file__, '..', 'namespacedata01'))
self.assertIsInstance(reader.files(), MultiplexedPath)
self.assertEqual(repr(reader.files()), f"MultiplexedPath('{root}')")
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,241 @@
import sys
import unittest
import importlib_resources as resources
import pathlib
from . import data01
from . import util
from importlib import import_module
class ResourceTests:
# Subclasses are expected to set the `data` attribute.
def test_is_file_exists(self):
target = resources.files(self.data) / 'binary.file'
self.assertTrue(target.is_file())
def test_is_file_missing(self):
target = resources.files(self.data) / 'not-a-file'
self.assertFalse(target.is_file())
def test_is_dir(self):
target = resources.files(self.data) / 'subdirectory'
self.assertFalse(target.is_file())
self.assertTrue(target.is_dir())
class ResourceDiskTests(ResourceTests, unittest.TestCase):
def setUp(self):
self.data = data01
class ResourceZipTests(ResourceTests, util.ZipSetup, unittest.TestCase):
pass
def names(traversable):
return {item.name for item in traversable.iterdir()}
class ResourceLoaderTests(unittest.TestCase):
def test_resource_contents(self):
package = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C']
)
self.assertEqual(names(resources.files(package)), {'A', 'B', 'C'})
def test_is_file(self):
package = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
)
self.assertTrue(resources.files(package).joinpath('B').is_file())
def test_is_dir(self):
package = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
)
self.assertTrue(resources.files(package).joinpath('D').is_dir())
def test_resource_missing(self):
package = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
)
self.assertFalse(resources.files(package).joinpath('Z').is_file())
class ResourceCornerCaseTests(unittest.TestCase):
def test_package_has_no_reader_fallback(self):
"""
Test odd ball packages which:
# 1. Do not have a ResourceReader as a loader
# 2. Are not on the file system
# 3. Are not in a zip file
"""
module = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C']
)
# Give the module a dummy loader.
module.__loader__ = object()
# Give the module a dummy origin.
module.__file__ = '/path/which/shall/not/be/named'
module.__spec__.loader = module.__loader__
module.__spec__.origin = module.__file__
self.assertFalse(resources.files(module).joinpath('A').is_file())
class ResourceFromZipsTest01(util.ZipSetupBase, unittest.TestCase):
ZIP_MODULE = 'data01'
def test_is_submodule_resource(self):
submodule = import_module('data01.subdirectory')
self.assertTrue(resources.files(submodule).joinpath('binary.file').is_file())
def test_read_submodule_resource_by_name(self):
self.assertTrue(
resources.files('data01.subdirectory').joinpath('binary.file').is_file()
)
def test_submodule_contents(self):
submodule = import_module('data01.subdirectory')
self.assertEqual(
names(resources.files(submodule)), {'__init__.py', 'binary.file'}
)
def test_submodule_contents_by_name(self):
self.assertEqual(
names(resources.files('data01.subdirectory')),
{'__init__.py', 'binary.file'},
)
def test_as_file_directory(self):
with resources.as_file(resources.files('data01')) as data:
assert data.name == 'data01'
assert data.is_dir()
assert data.joinpath('subdirectory').is_dir()
assert len(list(data.iterdir()))
assert not data.parent.exists()
class ResourceFromZipsTest02(util.ZipSetupBase, unittest.TestCase):
ZIP_MODULE = 'data02'
def test_unrelated_contents(self):
"""
Test that a zip with two unrelated subpackages returns
distinct resources. Ref python/importlib_resources#44.
"""
self.assertEqual(
names(resources.files('data02.one')),
{'__init__.py', 'resource1.txt'},
)
self.assertEqual(
names(resources.files('data02.two')),
{'__init__.py', 'resource2.txt'},
)
class DeletingZipsTest(util.ZipSetupBase, unittest.TestCase):
"""Having accessed resources in a zip file should not keep an open
reference to the zip.
"""
def test_iterdir_does_not_keep_open(self):
[item.name for item in resources.files('data01').iterdir()]
def test_is_file_does_not_keep_open(self):
resources.files('data01').joinpath('binary.file').is_file()
def test_is_file_failure_does_not_keep_open(self):
resources.files('data01').joinpath('not-present').is_file()
@unittest.skip("Desired but not supported.")
def test_as_file_does_not_keep_open(self): # pragma: no cover
resources.as_file(resources.files('data01') / 'binary.file')
def test_entered_path_does_not_keep_open(self):
"""
Mimic what certifi does on import to make its bundle
available for the process duration.
"""
resources.as_file(resources.files('data01') / 'binary.file').__enter__()
def test_read_binary_does_not_keep_open(self):
resources.files('data01').joinpath('binary.file').read_bytes()
def test_read_text_does_not_keep_open(self):
resources.files('data01').joinpath('utf-8.file').read_text(encoding='utf-8')
class ResourceFromNamespaceTests:
def test_is_submodule_resource(self):
self.assertTrue(
resources.files(import_module('namespacedata01'))
.joinpath('binary.file')
.is_file()
)
def test_read_submodule_resource_by_name(self):
self.assertTrue(
resources.files('namespacedata01').joinpath('binary.file').is_file()
)
def test_submodule_contents(self):
contents = names(resources.files(import_module('namespacedata01')))
try:
contents.remove('__pycache__')
except KeyError:
pass
self.assertEqual(
contents, {'subdirectory', 'binary.file', 'utf-8.file', 'utf-16.file'}
)
def test_submodule_contents_by_name(self):
contents = names(resources.files('namespacedata01'))
try:
contents.remove('__pycache__')
except KeyError:
pass
self.assertEqual(
contents, {'subdirectory', 'binary.file', 'utf-8.file', 'utf-16.file'}
)
def test_submodule_sub_contents(self):
contents = names(resources.files(import_module('namespacedata01.subdirectory')))
try:
contents.remove('__pycache__')
except KeyError:
pass
self.assertEqual(contents, {'binary.file'})
def test_submodule_sub_contents_by_name(self):
contents = names(resources.files('namespacedata01.subdirectory'))
try:
contents.remove('__pycache__')
except KeyError:
pass
self.assertEqual(contents, {'binary.file'})
class ResourceFromNamespaceDiskTests(ResourceFromNamespaceTests, unittest.TestCase):
site_dir = str(pathlib.Path(__file__).parent)
@classmethod
def setUpClass(cls):
sys.path.append(cls.site_dir)
@classmethod
def tearDownClass(cls):
sys.path.remove(cls.site_dir)
class ResourceFromNamespaceZipTests(
util.ZipSetupBase,
ResourceFromNamespaceTests,
unittest.TestCase,
):
ZIP_MODULE = 'namespacedata01'
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,165 @@
import abc
import importlib
import io
import sys
import types
import pathlib
import contextlib
from . import data01
from ..abc import ResourceReader
from ._compat import import_helper, os_helper
from . import zip as zip_
from importlib.machinery import ModuleSpec
class Reader(ResourceReader):
def __init__(self, **kwargs):
vars(self).update(kwargs)
def get_resource_reader(self, package):
return self
def open_resource(self, path):
self._path = path
if isinstance(self.file, Exception):
raise self.file
return self.file
def resource_path(self, path_):
self._path = path_
if isinstance(self.path, Exception):
raise self.path
return self.path
def is_resource(self, path_):
self._path = path_
if isinstance(self.path, Exception):
raise self.path
def part(entry):
return entry.split('/')
return any(
len(parts) == 1 and parts[0] == path_ for parts in map(part, self._contents)
)
def contents(self):
if isinstance(self.path, Exception):
raise self.path
yield from self._contents
def create_package_from_loader(loader, is_package=True):
name = 'testingpackage'
module = types.ModuleType(name)
spec = ModuleSpec(name, loader, origin='does-not-exist', is_package=is_package)
module.__spec__ = spec
module.__loader__ = loader
return module
def create_package(file=None, path=None, is_package=True, contents=()):
return create_package_from_loader(
Reader(file=file, path=path, _contents=contents),
is_package,
)
class CommonTests(metaclass=abc.ABCMeta):
"""
Tests shared by test_open, test_path, and test_read.
"""
@abc.abstractmethod
def execute(self, package, path):
"""
Call the pertinent legacy API function (e.g. open_text, path)
on package and path.
"""
def test_package_name(self):
"""
Passing in the package name should succeed.
"""
self.execute(data01.__name__, 'utf-8.file')
def test_package_object(self):
"""
Passing in the package itself should succeed.
"""
self.execute(data01, 'utf-8.file')
def test_string_path(self):
"""
Passing in a string for the path should succeed.
"""
path = 'utf-8.file'
self.execute(data01, path)
def test_pathlib_path(self):
"""
Passing in a pathlib.PurePath object for the path should succeed.
"""
path = pathlib.PurePath('utf-8.file')
self.execute(data01, path)
def test_importing_module_as_side_effect(self):
"""
The anchor package can already be imported.
"""
del sys.modules[data01.__name__]
self.execute(data01.__name__, 'utf-8.file')
def test_missing_path(self):
"""
Attempting to open or read or request the path for a
non-existent path should succeed if open_resource
can return a viable data stream.
"""
bytes_data = io.BytesIO(b'Hello, world!')
package = create_package(file=bytes_data, path=FileNotFoundError())
self.execute(package, 'utf-8.file')
self.assertEqual(package.__loader__._path, 'utf-8.file')
def test_extant_path(self):
# Attempting to open or read or request the path when the
# path does exist should still succeed. Does not assert
# anything about the result.
bytes_data = io.BytesIO(b'Hello, world!')
# any path that exists
path = __file__
package = create_package(file=bytes_data, path=path)
self.execute(package, 'utf-8.file')
self.assertEqual(package.__loader__._path, 'utf-8.file')
def test_useless_loader(self):
package = create_package(file=FileNotFoundError(), path=FileNotFoundError())
with self.assertRaises(FileNotFoundError):
self.execute(package, 'utf-8.file')
class ZipSetupBase:
ZIP_MODULE = 'data01'
def setUp(self):
self.fixtures = contextlib.ExitStack()
self.addCleanup(self.fixtures.close)
modules = import_helper.modules_setup()
self.addCleanup(import_helper.modules_cleanup, *modules)
temp_dir = self.fixtures.enter_context(os_helper.temp_dir())
modules = pathlib.Path(temp_dir) / 'zipped modules.zip'
src_path = pathlib.Path(__file__).parent.joinpath(self.ZIP_MODULE)
self.fixtures.enter_context(
import_helper.DirsOnSysPath(str(zip_.make_zip_file(src_path, modules)))
)
self.data = importlib.import_module(self.ZIP_MODULE)
class ZipSetup(ZipSetupBase):
pass

@ -0,0 +1,32 @@
"""
Generate zip test data files.
"""

import contextlib
import os
import pathlib
import zipfile

import zipp


def make_zip_file(src, dst):
    """
    Zip the files in src into a new zipfile at dst.
    """
    with zipfile.ZipFile(dst, 'w') as zf:
        for src_path, rel in walk(src):
            dst_name = src.name / pathlib.PurePosixPath(rel.as_posix())
            zf.write(src_path, dst_name)
        zipp.CompleteDirs.inject(zf)
    return dst


def walk(datapath):
    for dirpath, dirnames, filenames in os.walk(datapath):
        with contextlib.suppress(ValueError):
            dirnames.remove('__pycache__')
        for filename in filenames:
            res = pathlib.Path(dirpath) / filename
            rel = res.relative_to(datapath)
            yield res, rel
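
As a rough illustration of what `make_zip_file` produces, the sketch below zips a small throwaway package using only the standard library and lists the archive members. It is a simplified variant, not the helper above: the names `zip_tree` and `demo` are hypothetical, and it skips the `zipp.CompleteDirs.inject` step, so no implied directory entries are added.

```python
import pathlib
import tempfile
import zipfile


def zip_tree(src, dst):
    """Zip the files under src into dst, prefixed with src's own name."""
    src = pathlib.Path(src)
    with zipfile.ZipFile(dst, 'w') as zf:
        for path in sorted(src.rglob('*')):
            rel = path.relative_to(src)
            if path.is_file() and '__pycache__' not in rel.parts:
                # Archive names always use forward slashes.
                arcname = pathlib.PurePosixPath(src.name, *rel.parts)
                zf.write(path, str(arcname))
    return dst


def demo():
    """Build a tiny package in a temp dir, zip it, and list the members."""
    with tempfile.TemporaryDirectory() as tmp:
        pkg = pathlib.Path(tmp, 'pkg')
        pkg.joinpath('sub').mkdir(parents=True)
        pkg.joinpath('__init__.py').write_text('', encoding='utf-8')
        pkg.joinpath('sub', 'data.txt').write_text('hi', encoding='utf-8')
        archive = zip_tree(pkg, pathlib.Path(tmp, 'out.zip'))
        with zipfile.ZipFile(archive) as zf:
            return sorted(zf.namelist())
```

Keeping the top-level directory name inside the archive is what lets the zip file itself be placed on `sys.path` and the package imported from it, as `ZipSetupBase` does.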

@ -0,0 +1,288 @@
import os
import subprocess
import contextlib
import functools
import tempfile
import shutil
import operator
import warnings
@contextlib.contextmanager
def pushd(dir):
"""
>>> tmp_path = getfixture('tmp_path')
>>> with pushd(tmp_path):
... assert os.getcwd() == os.fspath(tmp_path)
>>> assert os.getcwd() != os.fspath(tmp_path)
"""
orig = os.getcwd()
os.chdir(dir)
try:
yield dir
finally:
os.chdir(orig)
@contextlib.contextmanager
def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
"""
Get a tarball, extract it, change to that directory, yield, then
clean up.
`runner` is the function to invoke commands.
`pushd` is a context manager for changing the directory.
"""
if target_dir is None:
target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
if runner is None:
runner = functools.partial(subprocess.check_call, shell=True)
else:
warnings.warn("runner parameter is deprecated", DeprecationWarning)
# In the tar command, use --strip-components=1 to strip the first path and
# then use -C to cause the files to be extracted to {target_dir}. This
# ensures that we always know where the files were extracted.
runner('mkdir {target_dir}'.format(**vars()))
try:
getter = 'wget {url} -O -'
extract = 'tar x{compression} --strip-components=1 -C {target_dir}'
cmd = ' | '.join((getter, extract))
runner(cmd.format(compression=infer_compression(url), **vars()))
with pushd(target_dir):
yield target_dir
finally:
runner('rm -Rf {target_dir}'.format(**vars()))


def infer_compression(url):
    """
    Given a URL or filename, infer the compression code for tar.

    >>> infer_compression('http://foo/bar.tar.gz')
    'z'
    >>> infer_compression('http://foo/bar.tgz')
    'z'
    >>> infer_compression('file.bz')
    'j'
    >>> infer_compression('file.xz')
    'J'
    """
    # cheat and just assume it's the last two characters
    compression_indicator = url[-2:]
    mapping = dict(gz='z', bz='j', xz='J')
    # Assume 'z' (gzip) if no match
    return mapping.get(compression_indicator, 'z')


@contextlib.contextmanager
def temp_dir(remover=shutil.rmtree):
    """
    Create a temporary directory context. Pass a custom remover
    to override the removal behavior.

    >>> import pathlib
    >>> with temp_dir() as the_dir:
    ...     assert os.path.isdir(the_dir)
    ...     _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents')
    >>> assert not os.path.exists(the_dir)
    """
    temp_dir = tempfile.mkdtemp()
    try:
        yield temp_dir
    finally:
        remover(temp_dir)


@contextlib.contextmanager
def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
    """
    Check out the repo indicated by url.

    If dest_ctx is supplied, it should be a context manager
    to yield the target directory for the check out.
    """
    exe = 'git' if 'git' in url else 'hg'
    with dest_ctx() as repo_dir:
        cmd = [exe, 'clone', url, repo_dir]
        if branch:
            cmd.extend(['--branch', branch])
        # subprocess.DEVNULL discards output without leaving an unclosed
        # file handle behind.
        stdout = subprocess.DEVNULL if quiet else None
        subprocess.check_call(cmd, stdout=stdout)
        yield repo_dir


@contextlib.contextmanager
def null():
    """
    A null context suitable to stand in for a meaningful context.

    >>> with null() as value:
    ...     assert value is None
    """
    yield


class ExceptionTrap:
    """
    A context manager that will catch certain exceptions and provide an
    indication they occurred.

    >>> with ExceptionTrap() as trap:
    ...     raise Exception()
    >>> bool(trap)
    True

    >>> with ExceptionTrap() as trap:
    ...     pass
    >>> bool(trap)
    False

    >>> with ExceptionTrap(ValueError) as trap:
    ...     raise ValueError("1 + 1 is not 3")
    >>> bool(trap)
    True
    >>> trap.value
    ValueError('1 + 1 is not 3')
    >>> trap.tb
    <traceback object at ...>

    >>> with ExceptionTrap(ValueError) as trap:
    ...     raise Exception()
    Traceback (most recent call last):
    ...
    Exception

    >>> bool(trap)
    False
    """

    exc_info = None, None, None

    def __init__(self, exceptions=(Exception,)):
        self.exceptions = exceptions

    def __enter__(self):
        return self

    @property
    def type(self):
        return self.exc_info[0]

    @property
    def value(self):
        return self.exc_info[1]

    @property
    def tb(self):
        return self.exc_info[2]

    def __exit__(self, *exc_info):
        type = exc_info[0]
        matches = type and issubclass(type, self.exceptions)
        if matches:
            self.exc_info = exc_info
        return matches

    def __bool__(self):
        return bool(self.type)

    def raises(self, func, *, _test=bool):
        """
        Wrap func and replace the result with the truth
        value of the trap (True if an exception occurred).

        First, give the decorator an alias to support Python 3.8
        Syntax.

        >>> raises = ExceptionTrap(ValueError).raises

        Now decorate a function that always fails.

        >>> @raises
        ... def fail():
        ...     raise ValueError('failed')
        >>> fail()
        True
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with ExceptionTrap(self.exceptions) as trap:
                func(*args, **kwargs)
            return _test(trap)

        return wrapper

    def passes(self, func):
        """
        Wrap func and replace the result with the truth
        value of the trap (True if no exception).

        First, give the decorator an alias to support Python 3.8
        Syntax.

        >>> passes = ExceptionTrap(ValueError).passes

        Now decorate a function that always fails.

        >>> @passes
        ... def fail():
        ...     raise ValueError('failed')
        >>> fail()
        False
        """
        return self.raises(func, _test=operator.not_)


class suppress(contextlib.suppress, contextlib.ContextDecorator):
    """
    A version of contextlib.suppress with decorator support.

    >>> @suppress(KeyError)
    ... def key_error():
    ...     {}['']
    >>> key_error()
    """


class on_interrupt(contextlib.ContextDecorator):
    """
    Replace a KeyboardInterrupt with SystemExit(1)

    >>> def do_interrupt():
    ...     raise KeyboardInterrupt()
    >>> on_interrupt('error')(do_interrupt)()
    Traceback (most recent call last):
    ...
    SystemExit: 1
    >>> on_interrupt('error', code=255)(do_interrupt)()
    Traceback (most recent call last):
    ...
    SystemExit: 255
    >>> on_interrupt('suppress')(do_interrupt)()
    >>> with __import__('pytest').raises(KeyboardInterrupt):
    ...     on_interrupt('ignore')(do_interrupt)()
    """

    def __init__(
        self,
        action='error',
        # py3.7 compat
        # /,
        code=1,
    ):
        self.action = action
        self.code = code

    def __enter__(self):
        return self

    def __exit__(self, exctype, excinst, exctb):
        if exctype is not KeyboardInterrupt or self.action == 'ignore':
            return
        elif self.action == 'error':
            raise SystemExit(self.code) from excinst
        return self.action == 'suppress'
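
Review note: `ExceptionTrap` in this file works because a context manager suppresses an exception when its `__exit__` returns a truthy value, and `passes` simply negates the trap's truth value. The sketch below is a reduced, hypothetical stand-in (`Trap` and the module-level `passes` helper are names invented here, not the vendored API) showing that mechanism in isolation.

```python
import functools


class Trap:
    """Hypothetical, reduced stand-in for the ExceptionTrap shown in this diff."""

    def __init__(self, exceptions=(Exception,)):
        self.exceptions = exceptions
        self.exc_info = (None, None, None)

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        exc_type = exc_info[0]
        matches = exc_type is not None and issubclass(exc_type, self.exceptions)
        if matches:
            self.exc_info = exc_info
        # A truthy return value suppresses the exception; falsy lets it propagate.
        return matches

    def __bool__(self):
        return self.exc_info[0] is not None


def passes(exceptions):
    """Decorator: return True when func raises none of the given exceptions."""

    def decorate(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with Trap(exceptions) as trap:
                func(*args, **kwargs)
            return not trap

        return wrapper

    return decorate


@passes(UnicodeDecodeError)
def is_decodable(value):
    value.decode()
```

With this sketch, `is_decodable(b'2')` is true and `is_decodable(b'\xff')` is false, while an exception type outside the trapped set would still propagate to the caller.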

View file

@ -0,0 +1,633 @@
import collections.abc
import functools
import inspect
import itertools
import operator
import time
import types
import warnings
import more_itertools
def compose(*funcs):
"""
Compose any number of unary functions into a single unary function.
>>> import textwrap
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
>>> strip_and_dedent = compose(str.strip, textwrap.dedent)
>>> strip_and_dedent(compose.__doc__) == expected
True
Compose also allows the innermost function to take arbitrary arguments.
>>> round_three = lambda x: round(x, ndigits=3)
>>> f = compose(round_three, int.__truediv__)
>>> [f(3*x, x+1) for x in range(1,10)]
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
"""
def compose_two(f1, f2):
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
return functools.reduce(compose_two, funcs)
def once(func):
"""
Decorate func so it's only ever called the first time.
This decorator can ensure that an expensive or non-idempotent function
will not be expensive on subsequent calls and is idempotent.
>>> add_three = once(lambda a: a+3)
>>> add_three(3)
6
>>> add_three(9)
6
>>> add_three('12')
6
To reset the stored value, simply clear the property ``saved_result``.
>>> del add_three.saved_result
>>> add_three(9)
12
>>> add_three(8)
12
Or invoke 'reset()' on it.
>>> add_three.reset()
>>> add_three(-3)
0
>>> add_three(0)
0
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not hasattr(wrapper, 'saved_result'):
wrapper.saved_result = func(*args, **kwargs)
return wrapper.saved_result
wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result')
return wrapper
def method_cache(method, cache_wrapper=functools.lru_cache()):
"""
Wrap lru_cache to support storing the cache data in the object instances.
Abstracts the common paradigm where the method explicitly saves an
underscore-prefixed protected property on first call and returns that
subsequently.
>>> class MyClass:
... calls = 0
...
... @method_cache
... def method(self, value):
... self.calls += 1
... return value
>>> a = MyClass()
>>> a.method(3)
3
>>> for x in range(75):
... res = a.method(x)
>>> a.calls
75
Note that the apparent behavior will be exactly like that of lru_cache
except that the cache is stored on each instance, so values in one
instance will not flush values from another, and when an instance is
deleted, so are the cached values for that instance.
>>> b = MyClass()
>>> for x in range(35):
... res = b.method(x)
>>> b.calls
35
>>> a.method(0)
0
>>> a.calls
75
Note that if method had been decorated with ``functools.lru_cache()``,
a.calls would have been 76 (due to the cached value of 0 having been
flushed by the 'b' instance).
Clear the cache with ``.cache_clear()``
>>> a.method.cache_clear()
Same for a method that hasn't yet been called.
>>> c = MyClass()
>>> c.method.cache_clear()
Another cache wrapper may be supplied:
>>> cache = functools.lru_cache(maxsize=2)
>>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache)
>>> a = MyClass()
>>> a.method2()
3
Caution - do not subsequently wrap the method with another decorator, such
as ``@property``, which changes the semantics of the function.
See also
http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
for another implementation and additional justification.
"""
def wrapper(self, *args, **kwargs):
# it's the first call, replace the method with a cached, bound method
bound_method = types.MethodType(method, self)
cached_method = cache_wrapper(bound_method)
setattr(self, method.__name__, cached_method)
return cached_method(*args, **kwargs)
# Support cache clear even before cache has been created.
wrapper.cache_clear = lambda: None
return _special_method_cache(method, cache_wrapper) or wrapper
def _special_method_cache(method, cache_wrapper):
"""
Because Python treats special methods differently, it's not
possible to use instance attributes to implement the cached
methods.
Instead, install the wrapper method under a different name
and return a simple proxy to that wrapper.
https://github.com/jaraco/jaraco.functools/issues/5
"""
name = method.__name__
special_names = '__getattr__', '__getitem__'
if name not in special_names:
return None
wrapper_name = '__cached' + name
def proxy(self, /, *args, **kwargs):
if wrapper_name not in vars(self):
bound = types.MethodType(method, self)
cache = cache_wrapper(bound)
setattr(self, wrapper_name, cache)
else:
cache = getattr(self, wrapper_name)
return cache(*args, **kwargs)
return proxy
def apply(transform):
"""
Decorate a function with a transform function that is
invoked on results returned from the decorated function.
>>> @apply(reversed)
... def get_numbers(start):
... "doc for get_numbers"
... return range(start, start+3)
>>> list(get_numbers(4))
[6, 5, 4]
>>> get_numbers.__doc__
'doc for get_numbers'
"""
def wrap(func):
return functools.wraps(func)(compose(transform, func))
return wrap
def result_invoke(action):
r"""
Decorate a function with an action function that is
invoked on the results returned from the decorated
function (for its side effect), then return the original
result.
>>> @result_invoke(print)
... def add_two(a, b):
... return a + b
>>> x = add_two(2, 3)
5
>>> x
5
"""
def wrap(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
action(result)
return result
return wrapper
return wrap
def invoke(f, /, *args, **kwargs):
"""
Call a function for its side effect after initialization.
The benefit of using the decorator instead of simply invoking a function
after defining it is that it makes explicit the author's intent for the
function to be called immediately. Whereas if one simply calls the
function immediately, it's less obvious if that was intentional or
incidental. It also avoids repeating the name - the two actions, defining
the function and calling it immediately are modeled separately, but linked
by the decorator construct.
The benefit of having a function construct (opposed to just invoking some
behavior inline) is to serve as a scope in which the behavior occurs. It
avoids polluting the global namespace with local variables, provides an
anchor on which to attach documentation (docstring), keeps the behavior
logically separated (instead of conceptually separated or not separated at
all), and provides potential to re-use the behavior for testing or other
purposes.
This function is named as a pithy way to communicate, "call this function
primarily for its side effect", or "while defining this function, also
take it aside and call it". It exists because there's no Python construct
for "define and call" (nor should there be, as decorators serve this need
just fine). The behavior happens immediately and synchronously.
>>> @invoke
... def func(): print("called")
called
>>> func()
called
Use functools.partial to pass parameters to the initial call
>>> @functools.partial(invoke, name='bingo')
... def func(name): print('called with', name)
called with bingo
"""
f(*args, **kwargs)
return f
class Throttler:
"""Rate-limit a function (or other callable)."""
def __init__(self, func, max_rate=float('Inf')):
if isinstance(func, Throttler):
func = func.func
self.func = func
self.max_rate = max_rate
self.reset()
def reset(self):
self.last_called = 0
def __call__(self, *args, **kwargs):
self._wait()
return self.func(*args, **kwargs)
def _wait(self):
"""Ensure at least 1/max_rate seconds from last call."""
elapsed = time.time() - self.last_called
must_wait = 1 / self.max_rate - elapsed
time.sleep(max(0, must_wait))
self.last_called = time.time()
def __get__(self, obj, owner=None):
return first_invoke(self._wait, functools.partial(self.func, obj))
def first_invoke(func1, func2):
"""
Return a function that when invoked will invoke func1 without
any parameters (for its side effect) and then invoke func2
with whatever parameters were passed, returning its result.
"""
def wrapper(*args, **kwargs):
func1()
return func2(*args, **kwargs)
return wrapper
method_caller = first_invoke(
lambda: warnings.warn(
'`jaraco.functools.method_caller` is deprecated, '
'use `operator.methodcaller` instead',
DeprecationWarning,
stacklevel=3,
),
operator.methodcaller,
)
def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
"""
Given a callable func, trap the indicated exceptions
for up to 'retries' times, invoking cleanup on the
exception. On the final attempt, allow any exceptions
to propagate.
"""
attempts = itertools.count() if retries == float('inf') else range(retries)
for _ in attempts:
try:
return func()
except trap:
cleanup()
return func()
def retry(*r_args, **r_kwargs):
"""
Decorator wrapper for retry_call. Accepts arguments to retry_call
except func and then returns a decorator for the decorated function.
Ex:
>>> @retry(retries=3)
... def my_func(a, b):
... "this is my funk"
... print(a, b)
>>> my_func.__doc__
'this is my funk'
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*f_args, **f_kwargs):
bound = functools.partial(func, *f_args, **f_kwargs)
return retry_call(bound, *r_args, **r_kwargs)
return wrapper
return decorate
def print_yielded(func):
"""
Convert a generator into a function that prints all yielded elements.
>>> @print_yielded
... def x():
... yield 3; yield None
>>> x()
3
None
"""
print_all = functools.partial(map, print)
print_results = compose(more_itertools.consume, print_all, func)
return functools.wraps(func)(print_results)
def pass_none(func):
"""
Wrap func so it's not called if its first param is None.
>>> print_text = pass_none(print)
>>> print_text('text')
text
>>> print_text(None)
"""
@functools.wraps(func)
def wrapper(param, /, *args, **kwargs):
if param is not None:
return func(param, *args, **kwargs)
return None
return wrapper
def assign_params(func, namespace):
"""
Assign parameters from namespace where func solicits.
>>> def func(x, y=3):
... print(x, y)
>>> assigned = assign_params(func, dict(x=2, z=4))
>>> assigned()
2 3
The usual errors are raised if a function doesn't receive
its required parameters:
>>> assigned = assign_params(func, dict(y=3, z=4))
>>> assigned()
Traceback (most recent call last):
TypeError: func() ...argument...
It even works on methods:
>>> class Handler:
... def meth(self, arg):
... print(arg)
>>> assign_params(Handler().meth, dict(arg='crystal', foo='clear'))()
crystal
"""
sig = inspect.signature(func)
params = sig.parameters.keys()
call_ns = {k: namespace[k] for k in params if k in namespace}
return functools.partial(func, **call_ns)
def save_method_args(method):
"""
Wrap a method such that when it is called, the args and kwargs are
saved on the method.
>>> class MyClass:
... @save_method_args
... def method(self, a, b):
... print(a, b)
>>> my_ob = MyClass()
>>> my_ob.method(1, 2)
1 2
>>> my_ob._saved_method.args
(1, 2)
>>> my_ob._saved_method.kwargs
{}
>>> my_ob.method(a=3, b='foo')
3 foo
>>> my_ob._saved_method.args
()
>>> my_ob._saved_method.kwargs == dict(a=3, b='foo')
True
The arguments are stored on the instance, allowing for
different instances to save different args.
>>> your_ob = MyClass()
>>> your_ob.method({str('x'): 3}, b=[4])
{'x': 3} [4]
>>> your_ob._saved_method.args
({'x': 3},)
>>> my_ob._saved_method.args
()
"""
args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')
@functools.wraps(method)
def wrapper(self, /, *args, **kwargs):
attr_name = '_saved_' + method.__name__
attr = args_and_kwargs(args, kwargs)
setattr(self, attr_name, attr)
return method(self, *args, **kwargs)
return wrapper
def except_(*exceptions, replace=None, use=None):
"""
Replace the indicated exceptions, if raised, with the indicated
literal replacement or evaluated expression (if present).
>>> safe_int = except_(ValueError)(int)
>>> safe_int('five')
>>> safe_int('5')
5
Specify a literal replacement with ``replace``.
>>> safe_int_r = except_(ValueError, replace=0)(int)
>>> safe_int_r('five')
0
Provide an expression to ``use`` to pass through particular parameters.
>>> safe_int_pt = except_(ValueError, use='args[0]')(int)
>>> safe_int_pt('five')
'five'
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except exceptions:
try:
return eval(use)
except TypeError:
return replace
return wrapper
return decorate
def identity(x):
"""
Return the argument.
>>> o = object()
>>> identity(o) is o
True
"""
return x
def bypass_when(check, *, _op=identity):
"""
Decorate a function to return its parameter when ``check``.
>>> bypassed = [] # False
>>> @bypass_when(bypassed)
... def double(x):
... return x * 2
>>> double(2)
4
>>> bypassed[:] = [object()] # True
>>> double(2)
2
"""
def decorate(func):
@functools.wraps(func)
def wrapper(param, /):
return param if _op(check) else func(param)
return wrapper
return decorate
def bypass_unless(check):
"""
Decorate a function to return its parameter unless ``check``.
>>> enabled = [object()] # True
>>> @bypass_unless(enabled)
... def double(x):
... return x * 2
>>> double(2)
4
>>> del enabled[:] # False
>>> double(2)
2
"""
return bypass_when(check, _op=operator.not_)
@functools.singledispatch
def _splat_inner(args, func):
"""Splat args to func."""
return func(*args)
@_splat_inner.register
def _(args: collections.abc.Mapping, func):
"""Splat args to func as kwargs."""
return func(**args)
def splat(func):
"""
Wrap func to expect its parameters to be passed positionally in a tuple.
Has a similar effect to that of ``itertools.starmap`` over
simple ``map``.
>>> pairs = [(-1, 1), (0, 2)]
>>> more_itertools.consume(itertools.starmap(print, pairs))
-1 1
0 2
>>> more_itertools.consume(map(splat(print), pairs))
-1 1
0 2
The approach generalizes to other iterators that don't have a "star"
equivalent, such as a "starfilter".
>>> list(filter(splat(operator.add), pairs))
[(0, 2)]
Splat also accepts a mapping argument.
>>> def is_nice(msg, code):
... return "smile" in msg or code == 0
>>> msgs = [
... dict(msg='smile!', code=20),
... dict(msg='error :(', code=1),
... dict(msg='unknown', code=0),
... ]
>>> for msg in filter(splat(is_nice), msgs):
... print(msg)
{'msg': 'smile!', 'code': 20}
{'msg': 'unknown', 'code': 0}
"""
return functools.wraps(func)(functools.partial(_splat_inner, func=func))
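
Review note: the `retry_call` docstring in this file says exceptions are trapped "for up to 'retries' times" and that the final attempt propagates, which means the callable runs at most `retries + 1` times. The sketch below copies the `retry_call` body from this diff and pairs it with a hypothetical helper, `make_flaky`, that is not part of the library and exists only to make the attempt count observable.

```python
import itertools


def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
    """Same logic as retry_call in this diff, reproduced to run standalone."""
    attempts = itertools.count() if retries == float('inf') else range(retries)
    for _ in attempts:
        try:
            return func()
        except trap:
            cleanup()
    # Final attempt: anything raised here reaches the caller.
    return func()


def make_flaky(failures):
    """Hypothetical helper: fail with ValueError `failures` times, then return 'ok'."""
    state = {'calls': 0}

    def flaky():
        state['calls'] += 1
        if state['calls'] <= failures:
            raise ValueError('transient')
        return 'ok'

    flaky.state = state
    return flaky
```

For example, `retry_call(make_flaky(2), retries=2, trap=ValueError)` succeeds on the third call, whereas `retries=1` would let the second `ValueError` escape.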

View file

@ -0,0 +1,128 @@
from collections.abc import Callable, Hashable, Iterator
from functools import partial
from operator import methodcaller
import sys
from typing import (
Any,
Generic,
Protocol,
TypeVar,
overload,
)
if sys.version_info >= (3, 10):
from typing import Concatenate, ParamSpec
else:
from typing_extensions import Concatenate, ParamSpec
_P = ParamSpec('_P')
_R = TypeVar('_R')
_T = TypeVar('_T')
_R1 = TypeVar('_R1')
_R2 = TypeVar('_R2')
_V = TypeVar('_V')
_S = TypeVar('_S')
_R_co = TypeVar('_R_co', covariant=True)
class _OnceCallable(Protocol[_P, _R]):
saved_result: _R
reset: Callable[[], None]
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
class _ProxyMethodCacheWrapper(Protocol[_R_co]):
cache_clear: Callable[[], None]
def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ...
class _MethodCacheWrapper(Protocol[_R_co]):
def cache_clear(self) -> None: ...
def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ...
# `compose()` overloads below will cover most use cases.
@overload
def compose(
__func1: Callable[[_R], _T],
__func2: Callable[_P, _R],
/,
) -> Callable[_P, _T]: ...
@overload
def compose(
__func1: Callable[[_R], _T],
__func2: Callable[[_R1], _R],
__func3: Callable[_P, _R1],
/,
) -> Callable[_P, _T]: ...
@overload
def compose(
__func1: Callable[[_R], _T],
__func2: Callable[[_R2], _R],
__func3: Callable[[_R1], _R2],
__func4: Callable[_P, _R1],
/,
) -> Callable[_P, _T]: ...
def once(func: Callable[_P, _R]) -> _OnceCallable[_P, _R]: ...
def method_cache(
method: Callable[..., _R],
cache_wrapper: Callable[[Callable[..., _R]], _MethodCacheWrapper[_R]] = ...,
) -> _MethodCacheWrapper[_R] | _ProxyMethodCacheWrapper[_R]: ...
def apply(
transform: Callable[[_R], _T]
) -> Callable[[Callable[_P, _R]], Callable[_P, _T]]: ...
def result_invoke(
action: Callable[[_R], Any]
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
def invoke(
f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ...
def call_aside(
f: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ...
class Throttler(Generic[_R]):
last_called: float
func: Callable[..., _R]
max_rate: float
def __init__(
self, func: Callable[..., _R] | Throttler[_R], max_rate: float = ...
) -> None: ...
def reset(self) -> None: ...
def __call__(self, *args: Any, **kwargs: Any) -> _R: ...
def __get__(self, obj: Any, owner: type[Any] | None = ...) -> Callable[..., _R]: ...
def first_invoke(
func1: Callable[..., Any], func2: Callable[_P, _R]
) -> Callable[_P, _R]: ...
method_caller: Callable[..., methodcaller]
def retry_call(
func: Callable[..., _R],
cleanup: Callable[..., None] = ...,
retries: int | float = ...,
trap: type[BaseException] | tuple[type[BaseException], ...] = ...,
) -> _R: ...
def retry(
cleanup: Callable[..., None] = ...,
retries: int | float = ...,
trap: type[BaseException] | tuple[type[BaseException], ...] = ...,
) -> Callable[[Callable[..., _R]], Callable[..., _R]]: ...
def print_yielded(func: Callable[_P, Iterator[Any]]) -> Callable[_P, None]: ...
def pass_none(
func: Callable[Concatenate[_T, _P], _R]
) -> Callable[Concatenate[_T, _P], _R]: ...
def assign_params(
func: Callable[..., _R], namespace: dict[str, Any]
) -> partial[_R]: ...
def save_method_args(
method: Callable[Concatenate[_S, _P], _R]
) -> Callable[Concatenate[_S, _P], _R]: ...
def except_(
*exceptions: type[BaseException], replace: Any = ..., use: Any = ...
) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]: ...
def identity(x: _T) -> _T: ...
def bypass_when(
check: _V, *, _op: Callable[[_V], Any] = ...
) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ...
def bypass_unless(
check: Any,
) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ...

View file


@ -0,0 +1,2 @@
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus magna felis sollicitudin mauris. Integer in mauris eu nibh euismod gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue, eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis, neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis, molestie eu, feugiat in, orci. In hac habitasse platea dictumst.

View file

@ -0,0 +1,624 @@
import re
import itertools
import textwrap
import functools
try:
from importlib.resources import files # type: ignore
except ImportError: # pragma: nocover
from importlib_resources import files # type: ignore
from jaraco.functools import compose, method_cache
from jaraco.context import ExceptionTrap
def substitution(old, new):
"""
Return a function that will perform a substitution on a string
"""
return lambda s: s.replace(old, new)
def multi_substitution(*substitutions):
"""
Take a sequence of pairs specifying substitutions, and create
a function that performs those substitutions.
>>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo')
'baz'
"""
substitutions = itertools.starmap(substitution, substitutions)
# compose function applies last function first, so reverse the
# substitutions to get the expected order.
substitutions = reversed(tuple(substitutions))
return compose(*substitutions)
class FoldedCase(str):
"""
A case insensitive string class; behaves just like str
except compares equal when the only variation is case.
>>> s = FoldedCase('hello world')
>>> s == 'Hello World'
True
>>> 'Hello World' == s
True
>>> s != 'Hello World'
False
>>> s.index('O')
4
>>> s.split('O')
['hell', ' w', 'rld']
>>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta']))
['alpha', 'Beta', 'GAMMA']
Sequence membership is straightforward.
>>> "Hello World" in [s]
True
>>> s in ["Hello World"]
True
Allows testing for set inclusion, but candidate and elements
must both be folded.
>>> FoldedCase("Hello World") in {s}
True
>>> s in {FoldedCase("Hello World")}
True
String inclusion works as long as the FoldedCase object
is on the right.
>>> "hello" in FoldedCase("Hello World")
True
But not if the FoldedCase object is on the left:
>>> FoldedCase('hello') in 'Hello World'
False
In that case, use ``in_``:
>>> FoldedCase('hello').in_('Hello World')
True
>>> FoldedCase('hello') > FoldedCase('Hello')
False
>>> FoldedCase('ß') == FoldedCase('ss')
True
"""
def __lt__(self, other):
return self.casefold() < other.casefold()
def __gt__(self, other):
return self.casefold() > other.casefold()
def __eq__(self, other):
return self.casefold() == other.casefold()
def __ne__(self, other):
return self.casefold() != other.casefold()
def __hash__(self):
return hash(self.casefold())
def __contains__(self, other):
return super().casefold().__contains__(other.casefold())
def in_(self, other):
"Does self appear in other?"
return self in FoldedCase(other)
# cache casefold since it's likely to be called frequently.
@method_cache
def casefold(self):
return super().casefold()
def index(self, sub):
return self.casefold().index(sub.casefold())
def split(self, splitter=' ', maxsplit=0):
pattern = re.compile(re.escape(splitter), re.I)
return pattern.split(self, maxsplit)
# Python 3.8 compatibility
_unicode_trap = ExceptionTrap(UnicodeDecodeError)
@_unicode_trap.passes
def is_decodable(value):
r"""
Return True if the supplied value is decodable (using the default
encoding).
>>> is_decodable(b'\xff')
False
>>> is_decodable(b'\x32')
True
"""
value.decode()
def is_binary(value):
r"""
Return True if the value appears to be binary (that is, it's a byte
string and isn't decodable).
>>> is_binary(b'\xff')
True
>>> is_binary('\xff')
False
"""
return isinstance(value, bytes) and not is_decodable(value)
def trim(s):
r"""
Trim something like a docstring to remove the whitespace that
is common due to indentation and formatting.
>>> trim("\n\tfoo = bar\n\t\tbar = baz\n")
'foo = bar\n\tbar = baz'
"""
return textwrap.dedent(s).strip()
def wrap(s):
"""
Wrap lines of text, retaining existing newlines as
paragraph markers.
>>> print(wrap(lorem_ipsum))
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do
eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad
minim veniam, quis nostrud exercitation ullamco laboris nisi ut
aliquip ex ea commodo consequat. Duis aute irure dolor in
reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla
pariatur. Excepteur sint occaecat cupidatat non proident, sunt in
culpa qui officia deserunt mollit anim id est laborum.
<BLANKLINE>
Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam
varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus
magna felis sollicitudin mauris. Integer in mauris eu nibh euismod
gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis
risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue,
eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas
fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla
a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis,
neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing
sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque
nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus
quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis,
molestie eu, feugiat in, orci. In hac habitasse platea dictumst.
"""
paragraphs = s.splitlines()
wrapped = ('\n'.join(textwrap.wrap(para)) for para in paragraphs)
return '\n\n'.join(wrapped)
def unwrap(s):
r"""
Given a multi-line string, return an unwrapped version.
>>> wrapped = wrap(lorem_ipsum)
>>> wrapped.count('\n')
20
>>> unwrapped = unwrap(wrapped)
>>> unwrapped.count('\n')
1
>>> print(unwrapped)
Lorem ipsum dolor sit amet, consectetur adipiscing ...
Curabitur pretium tincidunt lacus. Nulla gravida orci ...
"""
paragraphs = re.split(r'\n\n+', s)
cleaned = (para.replace('\n', ' ') for para in paragraphs)
return '\n'.join(cleaned)
lorem_ipsum: str = (
files(__name__).joinpath('Lorem ipsum.txt').read_text(encoding='utf-8')
)
class Splitter:
"""object that will split a string with the given arguments for each call
>>> s = Splitter(',')
>>> s('hello, world, this is your, master calling')
['hello', ' world', ' this is your', ' master calling']
"""
def __init__(self, *args):
self.args = args
def __call__(self, s):
return s.split(*self.args)
def indent(string, prefix=' ' * 4):
"""
>>> indent('foo')
'    foo'
"""
return prefix + string
class WordSet(tuple):
"""
Given an identifier, return the words that identifier represents,
whether in camel case, underscore-separated, etc.
>>> WordSet.parse("camelCase")
('camel', 'Case')
>>> WordSet.parse("under_sep")
('under', 'sep')
Acronyms should be retained
>>> WordSet.parse("firstSNL")
('first', 'SNL')
>>> WordSet.parse("you_and_I")
('you', 'and', 'I')
>>> WordSet.parse("A simple test")
('A', 'simple', 'test')
Multiple caps should not interfere with the first cap of another word.
>>> WordSet.parse("myABCClass")
('my', 'ABC', 'Class')
The result is a WordSet, providing access to other forms.
>>> WordSet.parse("myABCClass").underscore_separated()
'my_ABC_Class'
>>> WordSet.parse('a-command').camel_case()
'ACommand'
>>> WordSet.parse('someIdentifier').lowered().space_separated()
'some identifier'
Slices of the result should return another WordSet.
>>> WordSet.parse('taken-out-of-context')[1:].underscore_separated()
'out_of_context'
>>> WordSet.from_class_name(WordSet()).lowered().space_separated()
'word set'
>>> example = WordSet.parse('figured it out')
>>> example.headless_camel_case()
'figuredItOut'
>>> example.dash_separated()
'figured-it-out'
"""
_pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))')
def capitalized(self):
return WordSet(word.capitalize() for word in self)
def lowered(self):
return WordSet(word.lower() for word in self)
def camel_case(self):
return ''.join(self.capitalized())
def headless_camel_case(self):
words = iter(self)
first = next(words).lower()
new_words = itertools.chain((first,), WordSet(words).camel_case())
return ''.join(new_words)
def underscore_separated(self):
return '_'.join(self)
def dash_separated(self):
return '-'.join(self)
def space_separated(self):
return ' '.join(self)
def trim_right(self, item):
"""
Remove the item from the end of the set.
>>> WordSet.parse('foo bar').trim_right('foo')
('foo', 'bar')
>>> WordSet.parse('foo bar').trim_right('bar')
('foo',)
>>> WordSet.parse('').trim_right('bar')
()
"""
return self[:-1] if self and self[-1] == item else self
def trim_left(self, item):
"""
Remove the item from the beginning of the set.
>>> WordSet.parse('foo bar').trim_left('foo')
('bar',)
>>> WordSet.parse('foo bar').trim_left('bar')
('foo', 'bar')
>>> WordSet.parse('').trim_left('bar')
()
"""
return self[1:] if self and self[0] == item else self
def trim(self, item):
"""
>>> WordSet.parse('foo bar').trim('foo')
('bar',)
"""
return self.trim_left(item).trim_right(item)
def __getitem__(self, item):
result = super().__getitem__(item)
if isinstance(item, slice):
result = WordSet(result)
return result
@classmethod
def parse(cls, identifier):
matches = cls._pattern.finditer(identifier)
return WordSet(match.group(0) for match in matches)
@classmethod
def from_class_name(cls, subject):
return cls.parse(subject.__class__.__name__)
# for backward compatibility
words = WordSet.parse
def simple_html_strip(s):
r"""
Remove HTML from the string `s`.
>>> str(simple_html_strip(''))
''
>>> print(simple_html_strip('A <bold>stormy</bold> day in paradise'))
A stormy day in paradise
>>> print(simple_html_strip('Somebody <!-- do not --> tell the truth.'))
Somebody tell the truth.
>>> print(simple_html_strip('What about<br/>\nmultiple lines?'))
What about
multiple lines?
"""
html_stripper = re.compile('(<!--.*?-->)|(<[^>]*>)|([^<]+)', re.DOTALL)
texts = (match.group(3) or '' for match in html_stripper.finditer(s))
return ''.join(texts)
class SeparatedValues(str):
"""
A string separated by a separator. Overrides __iter__ for getting
the values.
>>> list(SeparatedValues('a,b,c'))
['a', 'b', 'c']
Whitespace is stripped and empty values are discarded.
>>> list(SeparatedValues(' a, b , c, '))
['a', 'b', 'c']
"""
separator = ','
def __iter__(self):
parts = self.split(self.separator)
return filter(None, (part.strip() for part in parts))
class Stripper:
r"""
Given a series of lines, find the common prefix and strip it from them.
>>> lines = [
... 'abcdefg\n',
... 'abc\n',
... 'abcde\n',
... ]
>>> res = Stripper.strip_prefix(lines)
>>> res.prefix
'abc'
>>> list(res.lines)
['defg\n', '\n', 'de\n']
If no prefix is common, nothing should be stripped.
>>> lines = [
... 'abcd\n',
... '1234\n',
... ]
>>> res = Stripper.strip_prefix(lines)
>>> res.prefix
''
>>> list(res.lines)
['abcd\n', '1234\n']
"""
def __init__(self, prefix, lines):
self.prefix = prefix
self.lines = map(self, lines)
@classmethod
def strip_prefix(cls, lines):
prefix_lines, lines = itertools.tee(lines)
prefix = functools.reduce(cls.common_prefix, prefix_lines)
return cls(prefix, lines)
def __call__(self, line):
if not self.prefix:
return line
null, prefix, rest = line.partition(self.prefix)
return rest
@staticmethod
def common_prefix(s1, s2):
"""
Return the common prefix of two lines.
"""
index = min(len(s1), len(s2))
while s1[:index] != s2[:index]:
index -= 1
return s1[:index]
def remove_prefix(text, prefix):
"""
Remove the prefix from the text if it exists.
>>> remove_prefix('underwhelming performance', 'underwhelming ')
'performance'
>>> remove_prefix('something special', 'sample')
'something special'
"""
null, prefix, rest = text.rpartition(prefix)
return rest
def remove_suffix(text, suffix):
"""
Remove the suffix from the text if it exists.
>>> remove_suffix('name.git', '.git')
'name'
>>> remove_suffix('something special', 'sample')
'something special'
"""
rest, suffix, null = text.partition(suffix)
return rest
def normalize_newlines(text):
r"""
Replace alternate newlines with the canonical newline.
>>> normalize_newlines('Lorem Ipsum\u2029')
'Lorem Ipsum\n'
>>> normalize_newlines('Lorem Ipsum\r\n')
'Lorem Ipsum\n'
>>> normalize_newlines('Lorem Ipsum\x85')
'Lorem Ipsum\n'
"""
newlines = ['\r\n', '\r', '\n', '\u0085', '\u2028', '\u2029']
pattern = '|'.join(newlines)
return re.sub(pattern, '\n', text)
def _nonblank(str):
return str and not str.startswith('#')
@functools.singledispatch
def yield_lines(iterable):
r"""
Yield valid lines of a string or iterable.
>>> list(yield_lines(''))
[]
>>> list(yield_lines(['foo', 'bar']))
['foo', 'bar']
>>> list(yield_lines('foo\nbar'))
['foo', 'bar']
>>> list(yield_lines('\nfoo\n#bar\nbaz #comment'))
['foo', 'baz #comment']
>>> list(yield_lines(['foo\nbar', 'baz', 'bing\n\n\n']))
['foo', 'bar', 'baz', 'bing']
"""
return itertools.chain.from_iterable(map(yield_lines, iterable))
@yield_lines.register(str)
def _(text):
return filter(_nonblank, map(str.strip, text.splitlines()))
def drop_comment(line):
"""
Drop comments.
>>> drop_comment('foo # bar')
'foo'
A hash without a space may be in a URL.
>>> drop_comment('http://example.com/foo#bar')
'http://example.com/foo#bar'
"""
return line.partition(' #')[0]
def join_continuation(lines):
r"""
Join lines continued by a trailing backslash.
>>> list(join_continuation(['foo \\', 'bar', 'baz']))
['foobar', 'baz']
>>> list(join_continuation(['foo \\', 'bar \\', 'baz']))
['foobarbaz']
Not sure why, but...
The character preceding the backslash is also elided.
>>> list(join_continuation(['goo\\', 'dly']))
['godly']
A terrible idea, but...
If no line is available to continue, suppress the lines.
>>> list(join_continuation(['foo', 'bar\\', 'baz\\']))
['foo']
"""
lines = iter(lines)
for item in lines:
while item.endswith('\\'):
try:
item = item[:-2].strip() + next(lines)
except StopIteration:
return
yield item
def read_newlines(filename, limit=1024):
r"""
>>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\n', newline='', encoding='utf-8')
>>> read_newlines(filename)
'\n'
>>> _ = filename.write_text('foo\r\n', newline='', encoding='utf-8')
>>> read_newlines(filename)
'\r\n'
>>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='', encoding='utf-8')
>>> read_newlines(filename)
('\r', '\n', '\r\n')
"""
with open(filename, encoding='utf-8') as fp:
fp.read(limit)
return fp.newlines
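The word-splitting pattern behind `WordSet.parse` can be tried on its own. The following is a minimal sketch, not part of the vendored module: only the regular expression is copied from the listing, while `split_words` and `to_snake` are hypothetical helper names introduced for illustration.

```python
import re

# Pattern copied from WordSet._pattern: either an optionally capitalized
# lowercase word, or a run of capitals that is not followed by a lowercase
# letter (so the last capital can start the next word).
WORD_PATTERN = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))')


def split_words(identifier):
    """Return the words found in an identifier (hypothetical helper)."""
    return tuple(match.group(0) for match in WORD_PATTERN.finditer(identifier))


def to_snake(identifier):
    """Join the lowercased words with underscores (hypothetical helper)."""
    return '_'.join(word.lower() for word in split_words(identifier))


if __name__ == '__main__':
    print(split_words('myABCClass'))
    print(to_snake('taken-out-of-context'))
```

Because the pattern only matches letters, separators such as `-`, `_`, and spaces are skipped rather than handled explicitly.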

@ -0,0 +1,25 @@
qwerty = "-=qwertyuiop[]asdfghjkl;'zxcvbnm,./_+QWERTYUIOP{}ASDFGHJKL:\"ZXCVBNM<>?"
dvorak = "[]',.pyfgcrl/=aoeuidhtns-;qjkxbmwvz{}\"<>PYFGCRL?+AOEUIDHTNS_:QJKXBMWVZ"
to_dvorak = str.maketrans(qwerty, dvorak)
to_qwerty = str.maketrans(dvorak, qwerty)
def translate(input, translation):
"""
>>> translate('dvorak', to_dvorak)
'ekrpat'
>>> translate('qwerty', to_qwerty)
'x,dokt'
"""
return input.translate(translation)
def _translate_stream(stream, translation):
"""
>>> import io
>>> _translate_stream(io.StringIO('foo'), to_dvorak)
urr
"""
print(translate(stream.read(), translation))
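The layout tables rely on `str.maketrans` and `str.translate`. As a rough illustration, the sketch below uses only the top letter row of each layout (an assumption made to keep the example short; the tables in the listing cover the whole keyboard), and `round_trip` is a hypothetical helper.

```python
# Top letter row only; taken from the corresponding positions of the
# qwerty and dvorak strings in the listing.
QWERTY_ROW = "qwertyuiop"
DVORAK_ROW = "',.pyfgcrl"

# Each table maps the character at position i in one row to the character
# at position i in the other row.
to_dvorak = str.maketrans(QWERTY_ROW, DVORAK_ROW)
to_qwerty = str.maketrans(DVORAK_ROW, QWERTY_ROW)


def round_trip(text):
    """Translate to Dvorak and back again (hypothetical helper)."""
    return text.translate(to_dvorak).translate(to_qwerty)


if __name__ == '__main__':
    print('wire'.translate(to_dvorak))
    print(round_trip('wire'))
```

With this reduced table the round trip is only guaranteed for text drawn from `QWERTY_ROW`; characters outside the row pass through `translate` unchanged.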

@ -0,0 +1,33 @@
import autocommand
import inflect
from more_itertools import always_iterable
import jaraco.text
def report_newlines(filename):
r"""
Report the newlines in the indicated file.
>>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\n', newline='', encoding='utf-8')
>>> report_newlines(filename)
newline is '\n'
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\r\n', newline='', encoding='utf-8')
>>> report_newlines(filename)
newlines are ('\n', '\r\n')
"""
newlines = jaraco.text.read_newlines(filename)
count = len(tuple(always_iterable(newlines)))
engine = inflect.engine()
print(
engine.plural_noun("newline", count),
engine.plural_verb("is", count),
repr(newlines),
)
autocommand.autocommand(__name__)(report_newlines)
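The report depends on the `newlines` attribute that a text-mode file object exposes after reading. A self-contained sketch of that mechanism, using only the standard library, might look like this; `detect_newlines` and `write_raw` are hypothetical names, and the temporary file stands in for the `tmp_path` fixture used in the doctest.

```python
import os
import tempfile


def detect_newlines(path, limit=1024):
    """Return the line endings seen in the first `limit` characters."""
    # Text mode with the default newline=None enables universal newlines,
    # so the file object records which line endings it has encountered.
    with open(path, encoding='utf-8') as fp:
        fp.read(limit)
        return fp.newlines


def write_raw(text):
    """Write text without newline translation and return the path."""
    fd, path = tempfile.mkstemp(suffix='.txt')
    with os.fdopen(fd, 'w', newline='', encoding='utf-8') as fp:
        fp.write(text)
    return path


if __name__ == '__main__':
    path = write_raw('foo\r\nbar\n')
    print(repr(detect_newlines(path)))
    os.remove(path)
```

The result is `None` when no newline was read, a single string when one kind was seen, and a tuple when several kinds were seen, which is why the script wraps it in `always_iterable` before counting.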

@ -0,0 +1,21 @@
import sys
import autocommand
from jaraco.text import Stripper
def strip_prefix():
r"""
Strip any common prefix from stdin.
>>> import io, pytest
>>> getfixture('monkeypatch').setattr('sys.stdin', io.StringIO('abcdef\nabc123'))
>>> strip_prefix()
def
123
"""
sys.stdout.writelines(Stripper.strip_prefix(sys.stdin).lines)
autocommand.autocommand(__name__)(strip_prefix)
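For comparison, the same prefix-stripping idea can be sketched without the `Stripper` class. This version is an assumption-laden simplification: it is eager rather than lazy, it uses the standard library's `os.path.commonprefix` (which compares strings character by character) instead of `functools.reduce`, and `strip_common_prefix` is a hypothetical name.

```python
import os.path


def strip_common_prefix(lines):
    """Return (prefix, stripped_lines) for an iterable of strings."""
    lines = list(lines)
    # commonprefix works on arbitrary strings, not only on paths, and
    # returns '' for an empty list or when nothing is shared.
    prefix = os.path.commonprefix(lines)
    return prefix, [line[len(prefix):] for line in lines]


if __name__ == '__main__':
    prefix, stripped = strip_common_prefix(['abcdef\n', 'abc123'])
    print(repr(prefix), stripped)
```

Slicing by the prefix length only ever removes the leading occurrence, whereas `Stripper.__call__` uses `str.partition`, which splits at the first place the prefix appears in the line.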

@ -0,0 +1,6 @@
import sys
from . import layouts
__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_dvorak)

@ -0,0 +1,6 @@
import sys
from . import layouts
__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_qwerty)

@ -33,8 +33,6 @@ Internally ``MediaFile`` uses ``MediaField`` descriptors to access the
 data from the tags. In turn ``MediaField`` uses a number of
 ``StorageStyle`` strategies to handle format specific logic.
 """
-from __future__ import division, absolute_import, print_function
 import mutagen
 import mutagen.id3
 import mutagen.mp3
@ -48,18 +46,17 @@ import binascii
 import codecs
 import datetime
 import enum
+import filetype
 import functools
-import imghdr
 import logging
 import math
 import os
 import re
-import six
 import struct
 import traceback
-__version__ = '0.10.1'
+__version__ = '0.13.0'
 __all__ = ['UnreadableFileError', 'FileTypeError', 'MediaFile']
 log = logging.getLogger(__name__)
@ -81,8 +78,6 @@ TYPES = {
 'wav': 'WAVE',
 }
-PREFERRED_IMAGE_EXTENSIONS = {'jpeg': 'jpg'}
 # Exceptions.
@ -136,8 +131,8 @@ def mutagen_call(action, filename, func, *args, **kwargs):
 try:
 return func(*args, **kwargs)
 except mutagen.MutagenError as exc:
-log.debug(u'%s failed: %s', action, six.text_type(exc))
-raise UnreadableFileError(filename, six.text_type(exc))
+log.debug(u'%s failed: %s', action, str(exc))
+raise UnreadableFileError(filename, str(exc))
 except UnreadableFileError:
 # Reraise our errors without changes.
 # Used in case of decorating functions (e.g. by `loadfile`).
@ -202,8 +197,8 @@ def _safe_cast(out_type, val):
 # Process any other type as a string.
 if isinstance(val, bytes):
 val = val.decode('utf-8', 'ignore')
-elif not isinstance(val, six.string_types):
-val = six.text_type(val)
+elif not isinstance(val, str):
+val = str(val)
 # Get a number from the front of the string.
 match = re.match(r'[\+-]?[0-9]+', val.strip())
 return int(match.group(0)) if match else 0
@ -215,13 +210,13 @@ def _safe_cast(out_type, val):
 except ValueError:
 return False
-elif out_type == six.text_type:
+elif out_type == str:
 if isinstance(val, bytes):
 return val.decode('utf-8', 'ignore')
-elif isinstance(val, six.text_type):
+elif isinstance(val, str):
 return val
 else:
-return six.text_type(val)
+return str(val)
 elif out_type == float:
 if isinstance(val, int) or isinstance(val, float):
@ -230,7 +225,7 @@ def _safe_cast(out_type, val):
 if isinstance(val, bytes):
 val = val.decode('utf-8', 'ignore')
 else:
-val = six.text_type(val)
+val = str(val)
 match = re.match(r'[\+-]?([0-9]+\.?[0-9]*|[0-9]*\.[0-9]+)',
 val.strip())
 if match:
@ -289,7 +284,7 @@ def _sc_decode(soundcheck):
 """
 # We decode binary data. If one of the formats gives us a text
 # string, interpret it as UTF-8.
-if isinstance(soundcheck, six.text_type):
+if isinstance(soundcheck, str):
 soundcheck = soundcheck.encode('utf-8')
 # SoundCheck tags consist of 10 numbers, each represented by 8
@ -349,52 +344,15 @@ def _sc_encode(gain, peak):
 # Cover art and other images.
-def _imghdr_what_wrapper(data):
-"""A wrapper around imghdr.what to account for jpeg files that can only be
-identified as such using their magic bytes
-See #1545
-See https://github.com/file/file/blob/master/magic/Magdir/jpeg#L12
-"""
-# imghdr.what returns none for jpegs with only the magic bytes, so
-# _wider_test_jpeg is run in that case. It still returns None if it didn't
-# match such a jpeg file.
-return imghdr.what(None, h=data) or _wider_test_jpeg(data)
-def _wider_test_jpeg(data):
-"""Test for a jpeg file following the UNIX file implementation which
-uses the magic bytes rather than just looking for the bytes that
-represent 'JFIF' or 'EXIF' at a fixed position.
-"""
-if data[:2] == b'\xff\xd8':
-return 'jpeg'
 def image_mime_type(data):
 """Return the MIME type of the image data (a bytestring).
 """
-# This checks for a jpeg file with only the magic bytes (unrecognized by
-# imghdr.what). imghdr.what returns none for that type of file, so
-# _wider_test_jpeg is run in that case. It still returns None if it didn't
-# match such a jpeg file.
-kind = _imghdr_what_wrapper(data)
-if kind in ['gif', 'jpeg', 'png', 'tiff', 'bmp']:
-return 'image/{0}'.format(kind)
-elif kind == 'pgm':
-return 'image/x-portable-graymap'
-elif kind == 'pbm':
-return 'image/x-portable-bitmap'
-elif kind == 'ppm':
-return 'image/x-portable-pixmap'
-elif kind == 'xbm':
-return 'image/x-xbitmap'
-else:
-return 'image/x-{0}'.format(kind)
+return filetype.guess_mime(data)
 def image_extension(data):
-ext = _imghdr_what_wrapper(data)
-return PREFERRED_IMAGE_EXTENSIONS.get(ext, ext)
+return filetype.guess_extension(data)
 class ImageType(enum.Enum):
@ -437,7 +395,7 @@ class Image(object):
 def __init__(self, data, desc=None, type=None):
 assert isinstance(data, bytes)
 if desc is not None:
-assert isinstance(desc, six.text_type)
+assert isinstance(desc, str)
 self.data = data
 self.desc = desc
 if isinstance(type, int):
@ -495,7 +453,7 @@ class StorageStyle(object):
 """List of mutagen classes the StorageStyle can handle.
 """
-def __init__(self, key, as_type=six.text_type, suffix=None,
+def __init__(self, key, as_type=str, suffix=None,
 float_places=2, read_only=False):
 """Create a basic storage strategy. Parameters:
@ -520,8 +478,8 @@ class StorageStyle(object):
 self.read_only = read_only
 # Convert suffix to correct string type.
-if self.suffix and self.as_type is six.text_type \
-and not isinstance(self.suffix, six.text_type):
+if self.suffix and self.as_type is str \
+and not isinstance(self.suffix, str):
 self.suffix = self.suffix.decode('utf-8')
 # Getter.
@ -544,7 +502,7 @@ class StorageStyle(object):
 """Given a raw value stored on a Mutagen object, decode and
 return the represented value.
 """
-if self.suffix and isinstance(mutagen_value, six.text_type) \
+if self.suffix and isinstance(mutagen_value, str) \
 and mutagen_value.endswith(self.suffix):
 return mutagen_value[:-len(self.suffix)]
 else:
@ -566,17 +524,17 @@ class StorageStyle(object):
 """Convert the external Python value to a type that is suitable for
 storing in a Mutagen file object.
 """
-if isinstance(value, float) and self.as_type is six.text_type:
+if isinstance(value, float) and self.as_type is str:
 value = u'{0:.{1}f}'.format(value, self.float_places)
 value = self.as_type(value)
-elif self.as_type is six.text_type:
+elif self.as_type is str:
 if isinstance(value, bool):
 # Store bools as 1/0 instead of True/False.
-value = six.text_type(int(bool(value)))
+value = str(int(bool(value)))
 elif isinstance(value, bytes):
 value = value.decode('utf-8', 'ignore')
 else:
-value = six.text_type(value)
+value = str(value)
 else:
 value = self.as_type(value)
@ -600,8 +558,8 @@ class ListStorageStyle(StorageStyle):
 object to each.
 Subclasses may overwrite ``fetch`` and ``store``. ``fetch`` must
-return a (possibly empty) list and ``store`` receives a serialized
-list of values as the second argument.
+return a (possibly empty) list or `None` if the tag does not exist.
+``store`` receives a serialized list of values as the second argument.
 The `serialize` and `deserialize` methods (from the base
 `StorageStyle`) are still called with individual values. This class
@ -610,15 +568,23 @@ class ListStorageStyle(StorageStyle):
 def get(self, mutagen_file):
 """Get the first value in the field's value list.
 """
+values = self.get_list(mutagen_file)
+if values is None:
+return None
 try:
-return self.get_list(mutagen_file)[0]
+return values[0]
 except IndexError:
 return None
 def get_list(self, mutagen_file):
 """Get a list of all values for the field using this style.
 """
-return [self.deserialize(item) for item in self.fetch(mutagen_file)]
+raw_values = self.fetch(mutagen_file)
+if raw_values is None:
+return None
+return [self.deserialize(item) for item in raw_values]
 def fetch(self, mutagen_file):
 """Get the list of raw (serialized) values.
@ -626,19 +592,27 @@ class ListStorageStyle(StorageStyle):
 try:
 return mutagen_file[self.key]
 except KeyError:
-return []
+return None
 def set(self, mutagen_file, value):
 """Set an individual value as the only value for the field using
 this style.
 """
-self.set_list(mutagen_file, [value])
+if value is None:
+self.store(mutagen_file, None)
+else:
+self.set_list(mutagen_file, [value])
 def set_list(self, mutagen_file, values):
 """Set all values for the field using this style. `values`
 should be an iterable.
 """
-self.store(mutagen_file, [self.serialize(value) for value in values])
+if values is None:
+self.delete(mutagen_file)
+else:
+self.store(
+mutagen_file, [self.serialize(value) for value in values]
+)
 def store(self, mutagen_file, values):
 """Set the list of all raw (serialized) values for this field.
@ -686,7 +660,7 @@ class MP4StorageStyle(StorageStyle):
 def serialize(self, value):
 value = super(MP4StorageStyle, self).serialize(value)
-if self.key.startswith('----:') and isinstance(value, six.text_type):
+if self.key.startswith('----:') and isinstance(value, str):
 value = value.encode('utf-8')
 return value
@ -865,7 +839,7 @@ class MP3UFIDStorageStyle(MP3StorageStyle):
 def store(self, mutagen_file, value):
 # This field type stores text data as encoded data.
-assert isinstance(value, six.text_type)
+assert isinstance(value, str)
 value = value.encode('utf-8')
 frames = mutagen_file.tags.getall(self.key)
@ -889,7 +863,7 @@ class MP3DescStorageStyle(MP3StorageStyle):
 """
 def __init__(self, desc=u'', key='TXXX', attr='text', multispec=True,
 **kwargs):
-assert isinstance(desc, six.text_type)
+assert isinstance(desc, str)
 self.description = desc
 self.attr = attr
 self.multispec = multispec
@ -978,7 +952,7 @@ class MP3SlashPackStorageStyle(MP3StorageStyle):
 def _fetch_unpacked(self, mutagen_file):
 data = self.fetch(mutagen_file)
 if data:
-items = six.text_type(data).split('/')
+items = str(data).split('/')
 else:
 items = []
 packing_length = 2
@ -994,7 +968,7 @@ class MP3SlashPackStorageStyle(MP3StorageStyle):
 items[0] = ''
 if items[1] is None:
 items.pop() # Do not store last value
-self.store(mutagen_file, '/'.join(map(six.text_type, items)))
+self.store(mutagen_file, '/'.join(map(str, items)))
 def delete(self, mutagen_file):
 if self.pack_pos == 0:
@ -1261,7 +1235,7 @@ class MediaField(object):
 getting this property.
 """
-self.out_type = kwargs.get('out_type', six.text_type)
+self.out_type = kwargs.get('out_type', str)
 self._styles = styles
 def styles(self, mutagen_file):
@ -1301,7 +1275,7 @@ class MediaField(object):
 return 0.0
 elif self.out_type == bool:
 return False
-elif self.out_type == six.text_type:
+elif self.out_type == str:
 return u''
@ -1317,7 +1291,7 @@ class ListMediaField(MediaField):
 values = style.get_list(mediafile.mgfile)
 if values:
 return [_safe_cast(self.out_type, value) for value in values]
-return []
+return None
 def __set__(self, mediafile, values):
 for style in self.styles(mediafile.mgfile):
@ -1384,9 +1358,9 @@ class DateField(MediaField):
 """
 # Get the underlying data and split on hyphens and slashes.
 datestring = super(DateField, self).__get__(mediafile, None)
-if isinstance(datestring, six.string_types):
-datestring = re.sub(r'[Tt ].*$', '', six.text_type(datestring))
-items = re.split('[-/]', six.text_type(datestring))
+if isinstance(datestring, str):
+datestring = re.sub(r'[Tt ].*$', '', str(datestring))
+items = re.split('[-/]', str(datestring))
 else:
 items = []
@ -1423,7 +1397,7 @@ class DateField(MediaField):
 date.append(u'{0:02d}'.format(int(month)))
 if month and day:
 date.append(u'{0:02d}'.format(int(day)))
-date = map(six.text_type, date)
+date = map(str, date)
 super(DateField, self).__set__(mediafile, u'-'.join(date))
 if hasattr(self, '_year_field'):
@ -2071,6 +2045,7 @@ class MediaFile(object):
 original_date = DateField(
 MP3StorageStyle('TDOR'),
 MP4StorageStyle('----:com.apple.iTunes:ORIGINAL YEAR'),
+MP4StorageStyle('----:com.apple.iTunes:ORIGINALDATE'),
 StorageStyle('ORIGINALDATE'),
 ASFStorageStyle('WM/OriginalReleaseYear'))
@ -2085,12 +2060,36 @@ class MediaFile(object):
 StorageStyle('ARTIST_CREDIT'),
 ASFStorageStyle('beets/Artist Credit'),
 )
+artists_credit = ListMediaField(
+MP3ListDescStorageStyle(desc=u'ARTISTS_CREDIT'),
+MP4ListStorageStyle('----:com.apple.iTunes:ARTISTS_CREDIT'),
+ListStorageStyle('ARTISTS_CREDIT'),
+ASFStorageStyle('beets/ArtistsCredit'),
+)
+artists_sort = ListMediaField(
+MP3ListDescStorageStyle(desc=u'ARTISTS_SORT'),
+MP4ListStorageStyle('----:com.apple.iTunes:ARTISTS_SORT'),
+ListStorageStyle('ARTISTS_SORT'),
+ASFStorageStyle('beets/ArtistsSort'),
+)
 albumartist_credit = MediaField(
 MP3DescStorageStyle(u'Album Artist Credit'),
 MP4StorageStyle('----:com.apple.iTunes:Album Artist Credit'),
 StorageStyle('ALBUMARTIST_CREDIT'),
 ASFStorageStyle('beets/Album Artist Credit'),
 )
+albumartists_credit = ListMediaField(
+MP3ListDescStorageStyle(desc=u'ALBUMARTISTS_CREDIT'),
+MP4ListStorageStyle('----:com.apple.iTunes:ALBUMARTISTS_CREDIT'),
+ListStorageStyle('ALBUMARTISTS_CREDIT'),
+ASFStorageStyle('beets/AlbumArtistsCredit'),
+)
+albumartists_sort = ListMediaField(
+MP3ListDescStorageStyle(desc=u'ALBUMARTISTS_SORT'),
+MP4ListStorageStyle('----:com.apple.iTunes:ALBUMARTISTS_SORT'),
+ListStorageStyle('ALBUMARTISTS_SORT'),
+ASFStorageStyle('beets/AlbumArtistsSort'),
+)
 # Legacy album art field
 art = CoverArtField()
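One behavioural change in this file is that a missing list tag is now reported as `None` instead of an empty list. The sketch below illustrates that distinction with a deliberately simplified, hypothetical `ToyListStyle` class operating on a plain dict; it is not the `ListStorageStyle` API itself and it omits serialization.

```python
class ToyListStyle:
    """Hypothetical stand-in for a list storage style; `tags` is a dict."""

    def __init__(self, key):
        self.key = key

    def fetch(self, tags):
        # A missing tag is reported as None rather than as an empty list.
        try:
            return tags[self.key]
        except KeyError:
            return None

    def get_list(self, tags):
        raw_values = self.fetch(tags)
        if raw_values is None:
            return None
        return [str(item) for item in raw_values]

    def get(self, tags):
        values = self.get_list(tags)
        if values is None:
            return None
        try:
            return values[0]
        except IndexError:
            return None


if __name__ == '__main__':
    style = ToyListStyle('ARTISTS')
    print(style.get_list({}))               # tag absent
    print(style.get_list({'ARTISTS': []}))  # tag present but empty
```

Callers that previously iterated over the result unconditionally would need a `None` check under this convention.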

@ -1,21 +1,15 @@
 # This file is dual licensed under the terms of the Apache License, Version
 # 2.0, and the BSD License. See the LICENSE file in the root of this repository
 # for complete details.
-from __future__ import absolute_import, division, print_function
-__all__ = [
-"__title__", "__summary__", "__uri__", "__version__", "__author__",
-"__email__", "__license__", "__copyright__",
-]
 __title__ = "packaging"
 __summary__ = "Core utilities for Python packages"
 __uri__ = "https://github.com/pypa/packaging"
-__version__ = "16.8"
+__version__ = "23.2"
 __author__ = "Donald Stufft and individual contributors"
 __email__ = "donald@stufft.io"
-__license__ = "BSD or Apache License, Version 2.0"
-__copyright__ = "Copyright 2014-2016 %s" % __author__
+__license__ = "BSD-2-Clause or Apache-2.0"
+__copyright__ = "2014 %s" % __author__

@ -0,0 +1,108 @@
"""
ELF file parser.
This provides a class ``ELFFile`` that parses an ELF executable in a similar
interface to ``ZipFile``. Only the read interface is implemented.
Based on: https://gist.github.com/lyssdod/f51579ae8d93c8657a5564aefc2ffbca
ELF header: https://refspecs.linuxfoundation.org/elf/gabi4+/ch4.eheader.html
"""
import enum
import os
import struct
from typing import IO, Optional, Tuple
class ELFInvalid(ValueError):
pass
class EIClass(enum.IntEnum):
C32 = 1
C64 = 2
class EIData(enum.IntEnum):
Lsb = 1
Msb = 2
class EMachine(enum.IntEnum):
I386 = 3
S390 = 22
Arm = 40
X8664 = 62
AArc64 = 183
class ELFFile:
"""
Representation of an ELF executable.
"""
def __init__(self, f: IO[bytes]) -> None:
self._f = f
try:
ident = self._read("16B")
except struct.error:
raise ELFInvalid("unable to parse identification")
magic = bytes(ident[:4])
if magic != b"\x7fELF":
raise ELFInvalid(f"invalid magic: {magic!r}")
self.capacity = ident[4] # EI_CLASS: file class (32- or 64-bit).
self.encoding = ident[5] # EI_DATA: data encoding (endianness).
try:
# e_fmt: Format for the ELF header fields that follow e_ident.
# p_fmt: Format for a program header entry.
# p_idx: Indexes to find p_type, p_offset, and p_filesz.
e_fmt, self._p_fmt, self._p_idx = {
(1, 1): ("<HHIIIIIHHH", "<IIIIIIII", (0, 1, 4)), # 32-bit LSB.
(1, 2): (">HHIIIIIHHH", ">IIIIIIII", (0, 1, 4)), # 32-bit MSB.
(2, 1): ("<HHIQQQIHHH", "<IIQQQQQQ", (0, 2, 5)), # 64-bit LSB.
(2, 2): (">HHIQQQIHHH", ">IIQQQQQQ", (0, 2, 5)), # 64-bit MSB.
}[(self.capacity, self.encoding)]
except KeyError:
raise ELFInvalid(
f"unrecognized capacity ({self.capacity}) or "
f"encoding ({self.encoding})"
)
try:
(
_,
self.machine, # Architecture type.
_,
_,
self._e_phoff, # Offset of the program header table.
_,
self.flags, # Processor-specific flags.
_,
self._e_phentsize, # Size of one program header entry.
self._e_phnum, # Number of program header entries.
) = self._read(e_fmt)
except struct.error as e:
raise ELFInvalid("unable to parse machine and section information") from e
def _read(self, fmt: str) -> Tuple[int, ...]:
return struct.unpack(fmt, self._f.read(struct.calcsize(fmt)))
@property
def interpreter(self) -> Optional[str]:
"""
The path recorded in the ``PT_INTERP`` program header.
"""
for index in range(self._e_phnum):
self._f.seek(self._e_phoff + self._e_phentsize * index)
try:
data = self._read(self._p_fmt)
except struct.error:
continue
if data[self._p_idx[0]] != 3: # Not PT_INTERP.
continue
self._f.seek(data[self._p_idx[1]])
return os.fsdecode(self._f.read(data[self._p_idx[2]])).strip("\0")
return None
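The first step of `ELFFile.__init__`, reading and validating the 16-byte identification block, can be exercised against an in-memory buffer. This is a minimal sketch under stated assumptions: `read_ident` is a hypothetical function, `fake_ident` is a synthetic header built for the example rather than a real binary, and `ValueError` stands in for `ELFInvalid`.

```python
import io
import struct


def read_ident(stream):
    """Return (ei_class, ei_data) from the 16-byte ELF identification."""
    raw = stream.read(struct.calcsize("16B"))
    try:
        ident = struct.unpack("16B", raw)
    except struct.error:
        # Fewer than 16 bytes were available.
        raise ValueError("unable to parse identification") from None
    magic = bytes(ident[:4])
    if magic != b"\x7fELF":
        raise ValueError(f"invalid magic: {magic!r}")
    return ident[4], ident[5]


# Synthetic identification: magic, class 2 (64-bit), data 1 (LSB),
# version 1, then nine padding bytes to reach 16 bytes in total.
fake_ident = b"\x7fELF" + bytes([2, 1, 1]) + bytes(9)

if __name__ == '__main__':
    print(read_ident(io.BytesIO(fake_ident)))
```

The returned pair corresponds to the `(capacity, encoding)` key that the listing uses to choose its `struct` format strings.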

@ -0,0 +1,252 @@
import collections
import contextlib
import functools
import os
import re
import sys
import warnings
from typing import Dict, Generator, Iterator, NamedTuple, Optional, Sequence, Tuple
from ._elffile import EIClass, EIData, ELFFile, EMachine
EF_ARM_ABIMASK = 0xFF000000
EF_ARM_ABI_VER5 = 0x05000000
EF_ARM_ABI_FLOAT_HARD = 0x00000400
# `os.PathLike` not a generic type until Python 3.9, so sticking with `str`
# as the type for `path` until then.
@contextlib.contextmanager
def _parse_elf(path: str) -> Generator[Optional[ELFFile], None, None]:
try:
with open(path, "rb") as f:
yield ELFFile(f)
except (OSError, TypeError, ValueError):
yield None
def _is_linux_armhf(executable: str) -> bool:
# hard-float ABI can be detected from the ELF header of the running
# process
# https://static.docs.arm.com/ihi0044/g/aaelf32.pdf
with _parse_elf(executable) as f:
return (
f is not None
and f.capacity == EIClass.C32
and f.encoding == EIData.Lsb
and f.machine == EMachine.Arm
and f.flags & EF_ARM_ABIMASK == EF_ARM_ABI_VER5
and f.flags & EF_ARM_ABI_FLOAT_HARD == EF_ARM_ABI_FLOAT_HARD
)
def _is_linux_i686(executable: str) -> bool:
with _parse_elf(executable) as f:
return (
f is not None
and f.capacity == EIClass.C32
and f.encoding == EIData.Lsb
and f.machine == EMachine.I386
)
def _have_compatible_abi(executable: str, archs: Sequence[str]) -> bool:
if "armv7l" in archs:
return _is_linux_armhf(executable)
if "i686" in archs:
return _is_linux_i686(executable)
allowed_archs = {"x86_64", "aarch64", "ppc64", "ppc64le", "s390x", "loongarch64"}
return any(arch in allowed_archs for arch in archs)
# If glibc ever changes its major version, we need to know what the last
# minor version was, so we can build the complete list of all versions.
# For now, guess what the highest minor version might be, assume it will
# be 50 for testing. Once this actually happens, update the dictionary
# with the actual value.
_LAST_GLIBC_MINOR: Dict[int, int] = collections.defaultdict(lambda: 50)
class _GLibCVersion(NamedTuple):
major: int
minor: int
def _glibc_version_string_confstr() -> Optional[str]:
"""
Primary implementation of glibc_version_string using os.confstr.
"""
# os.confstr is quite a bit faster than ctypes.CDLL. It's also less likely
# to be broken or missing. This strategy is used in the standard library
# platform module.
# https://github.com/python/cpython/blob/fcf1d003bf4f0100c/Lib/platform.py#L175-L183
try:
# Should be a string like "glibc 2.17".
version_string: str = getattr(os, "confstr")("CS_GNU_LIBC_VERSION")
assert version_string is not None
_, version = version_string.rsplit()
except (AssertionError, AttributeError, OSError, ValueError):
# os.confstr() or CS_GNU_LIBC_VERSION not available (or a bad value)...
return None
return version
def _glibc_version_string_ctypes() -> Optional[str]:
"""
Fallback implementation of glibc_version_string using ctypes.
"""
try:
import ctypes
except ImportError:
return None
# ctypes.CDLL(None) internally calls dlopen(NULL), and as the dlopen
# manpage says, "If filename is NULL, then the returned handle is for the
# main program". This way we can let the linker do the work to figure out
# which libc our process is actually using.
#
# We must also handle the special case where the executable is not a
# dynamically linked executable. This can occur when using musl libc,
# for example. In this situation, dlopen() will error, leading to an
# OSError. Interestingly, at least in the case of musl, there is no
# errno set on the OSError. The single string argument used to construct
# OSError comes from libc itself and is therefore not portable to
# hard code here. In any case, failure to call dlopen() means we
# cannot proceed, so we bail on our attempt.
try:
process_namespace = ctypes.CDLL(None)
except OSError:
return None
try:
gnu_get_libc_version = process_namespace.gnu_get_libc_version
except AttributeError:
# Symbol doesn't exist -> therefore, we are not linked to
# glibc.
return None
# Call gnu_get_libc_version, which returns a string like "2.5"
gnu_get_libc_version.restype = ctypes.c_char_p
version_str: str = gnu_get_libc_version()
# py2 / py3 compatibility:
if not isinstance(version_str, str):
version_str = version_str.decode("ascii")
return version_str
def _glibc_version_string() -> Optional[str]:
"""Returns glibc version string, or None if not using glibc."""
return _glibc_version_string_confstr() or _glibc_version_string_ctypes()
def _parse_glibc_version(version_str: str) -> Tuple[int, int]:
"""Parse glibc version.
We use a regexp instead of str.split because we want to discard any
random junk that might come after the minor version -- this might happen
in patched/forked versions of glibc (e.g. Linaro's version of glibc
uses version strings like "2.20-2014.11"). See gh-3588.
"""
m = re.match(r"(?P<major>[0-9]+)\.(?P<minor>[0-9]+)", version_str)
if not m:
warnings.warn(
f"Expected glibc version with 2 components major.minor,"
f" got: {version_str}",
RuntimeWarning,
)
return -1, -1
return int(m.group("major")), int(m.group("minor"))
@functools.lru_cache()
def _get_glibc_version() -> Tuple[int, int]:
version_str = _glibc_version_string()
if version_str is None:
return (-1, -1)
return _parse_glibc_version(version_str)
# From PEP 513, PEP 600
def _is_compatible(arch: str, version: _GLibCVersion) -> bool:
sys_glibc = _get_glibc_version()
if sys_glibc < version:
return False
# Check for presence of _manylinux module.
try:
import _manylinux # noqa
except ImportError:
return True
if hasattr(_manylinux, "manylinux_compatible"):
result = _manylinux.manylinux_compatible(version[0], version[1], arch)
if result is not None:
return bool(result)
return True
if version == _GLibCVersion(2, 5):
if hasattr(_manylinux, "manylinux1_compatible"):
return bool(_manylinux.manylinux1_compatible)
if version == _GLibCVersion(2, 12):
if hasattr(_manylinux, "manylinux2010_compatible"):
return bool(_manylinux.manylinux2010_compatible)
if version == _GLibCVersion(2, 17):
if hasattr(_manylinux, "manylinux2014_compatible"):
return bool(_manylinux.manylinux2014_compatible)
return True
_LEGACY_MANYLINUX_MAP = {
# CentOS 7 w/ glibc 2.17 (PEP 599)
(2, 17): "manylinux2014",
# CentOS 6 w/ glibc 2.12 (PEP 571)
(2, 12): "manylinux2010",
# CentOS 5 w/ glibc 2.5 (PEP 513)
(2, 5): "manylinux1",
}
def platform_tags(archs: Sequence[str]) -> Iterator[str]:
"""Generate manylinux tags compatible with the current platform.
:param archs: Sequence of compatible architectures.
The first one shall be the closest to the actual architecture and be the part of
platform tag after the ``linux_`` prefix, e.g. ``x86_64``.
The ``linux_`` prefix is assumed as a prerequisite for the current platform to
be manylinux-compatible.
:returns: An iterator of compatible manylinux tags.
"""
if not _have_compatible_abi(sys.executable, archs):
return
# Oldest glibc to be supported regardless of architecture is (2, 17).
too_old_glibc2 = _GLibCVersion(2, 16)
if set(archs) & {"x86_64", "i686"}:
# On x86_64/i686, the oldest glibc to be supported is (2, 5).
too_old_glibc2 = _GLibCVersion(2, 4)
current_glibc = _GLibCVersion(*_get_glibc_version())
glibc_max_list = [current_glibc]
# We can assume compatibility across glibc major versions.
# https://sourceware.org/bugzilla/show_bug.cgi?id=24636
#
# Build a list of maximum glibc versions so that we can
# output the canonical list of all glibc from current_glibc
# down to too_old_glibc2, including all intermediary versions.
for glibc_major in range(current_glibc.major - 1, 1, -1):
glibc_minor = _LAST_GLIBC_MINOR[glibc_major]
glibc_max_list.append(_GLibCVersion(glibc_major, glibc_minor))
for arch in archs:
for glibc_max in glibc_max_list:
if glibc_max.major == too_old_glibc2.major:
min_minor = too_old_glibc2.minor
else:
# For other glibc major versions oldest supported is (x, 0).
min_minor = -1
for glibc_minor in range(glibc_max.minor, min_minor, -1):
glibc_version = _GLibCVersion(glibc_max.major, glibc_minor)
tag = "manylinux_{}_{}".format(*glibc_version)
if _is_compatible(arch, glibc_version):
yield f"{tag}_{arch}"
# Handle the legacy manylinux1, manylinux2010, manylinux2014 tags.
if glibc_version in _LEGACY_MANYLINUX_MAP:
legacy_tag = _LEGACY_MANYLINUX_MAP[glibc_version]
if _is_compatible(arch, glibc_version):
yield f"{legacy_tag}_{arch}"
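
The tag enumeration above can be sketched in isolation. This is a minimal, self-contained illustration and not the vendored implementation: `parse_glibc_version`, `sketch_manylinux_tags`, and `LEGACY_ALIASES` are hypothetical names, and the sketch assumes a single glibc major version and skips both the ELF ABI check and the `_manylinux` override module.

```python
import re
from typing import Iterator, Tuple

# Hypothetical stand-in for the legacy alias table shown above.
LEGACY_ALIASES = {
    (2, 17): "manylinux2014",
    (2, 12): "manylinux2010",
    (2, 5): "manylinux1",
}


def parse_glibc_version(version_str: str) -> Tuple[int, int]:
    """Extract (major, minor), ignoring any trailing vendor suffix."""
    m = re.match(r"(?P<major>[0-9]+)\.(?P<minor>[0-9]+)", version_str)
    if not m:
        return (-1, -1)
    return int(m.group("major")), int(m.group("minor"))


def sketch_manylinux_tags(
    glibc: Tuple[int, int], arch: str, oldest_minor: int
) -> Iterator[str]:
    """Yield tags from the current glibc minor down to oldest_minor (inclusive)."""
    major, minor = glibc
    for candidate in range(minor, oldest_minor - 1, -1):
        yield f"manylinux_{major}_{candidate}_{arch}"
        # Emit the legacy alias right after its PEP 600 equivalent.
        alias = LEGACY_ALIASES.get((major, candidate))
        if alias is not None:
            yield f"{alias}_{arch}"


tags = list(sketch_manylinux_tags(parse_glibc_version("2.18"), "x86_64", 17))
```

The newest tag comes first, so a consumer that walks the list in order prefers the most specific wheel that the system glibc can still load.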


@@ -0,0 +1,83 @@
"""PEP 656 support.
This module implements logic to detect if the currently running Python is
linked against musl, and what musl version is used.
"""
import functools
import re
import subprocess
import sys
from typing import Iterator, NamedTuple, Optional, Sequence
from ._elffile import ELFFile
class _MuslVersion(NamedTuple):
major: int
minor: int
def _parse_musl_version(output: str) -> Optional[_MuslVersion]:
lines = [n for n in (n.strip() for n in output.splitlines()) if n]
if len(lines) < 2 or lines[0][:4] != "musl":
return None
m = re.match(r"Version (\d+)\.(\d+)", lines[1])
if not m:
return None
return _MuslVersion(major=int(m.group(1)), minor=int(m.group(2)))
@functools.lru_cache()
def _get_musl_version(executable: str) -> Optional[_MuslVersion]:
"""Detect currently-running musl runtime version.
This is done by checking the specified executable's dynamic linking
information, and invoking the loader to parse its output for a version
string. If the loader is musl, the output would be something like::
musl libc (x86_64)
Version 1.2.2
Dynamic Program Loader
"""
try:
with open(executable, "rb") as f:
ld = ELFFile(f).interpreter
except (OSError, TypeError, ValueError):
return None
if ld is None or "musl" not in ld:
return None
proc = subprocess.run([ld], stderr=subprocess.PIPE, text=True)
return _parse_musl_version(proc.stderr)
def platform_tags(archs: Sequence[str]) -> Iterator[str]:
"""Generate musllinux tags compatible with the current platform.
:param archs: Sequence of compatible architectures.
The first one shall be the closest to the actual architecture and be the part of
platform tag after the ``linux_`` prefix, e.g. ``x86_64``.
The ``linux_`` prefix is assumed as a prerequisite for the current platform to
be musllinux-compatible.
:returns: An iterator of compatible musllinux tags.
"""
sys_musl = _get_musl_version(sys.executable)
if sys_musl is None: # Python not dynamically linked against musl.
return
for arch in archs:
for minor in range(sys_musl.minor, -1, -1):
yield f"musllinux_{sys_musl.major}_{minor}_{arch}"
if __name__ == "__main__": # pragma: no cover
import sysconfig
plat = sysconfig.get_platform()
assert plat.startswith("linux-"), "not linux"
print("plat:", plat)
print("musl:", _get_musl_version(sys.executable))
print("tags:", end=" ")
for t in platform_tags([re.sub(r"[.-]", "_", plat.split("-", 1)[-1])]):
print(t, end="\n ")
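
The version detection above relies on the text that the musl loader prints when run without arguments. The following sketch replays that parsing against a canned sample instead of invoking a loader; `MuslVersion`, `parse_musl_version`, `sketch_musllinux_tags`, and `SAMPLE_OUTPUT` are hypothetical names used only for illustration.

```python
import re
from typing import Iterator, NamedTuple, Optional


class MuslVersion(NamedTuple):
    major: int
    minor: int


def parse_musl_version(output: str) -> Optional[MuslVersion]:
    """Parse loader output; return None when it does not look like musl."""
    lines = [line.strip() for line in output.splitlines() if line.strip()]
    if len(lines) < 2 or not lines[0].startswith("musl"):
        return None
    m = re.match(r"Version (\d+)\.(\d+)", lines[1])
    if not m:
        return None
    return MuslVersion(major=int(m.group(1)), minor=int(m.group(2)))


def sketch_musllinux_tags(version: MuslVersion, arch: str) -> Iterator[str]:
    """Yield tags for every minor version from the current one down to 0."""
    for minor in range(version.minor, -1, -1):
        yield f"musllinux_{version.major}_{minor}_{arch}"


# Canned text in the shape described by the docstring above.
SAMPLE_OUTPUT = """\
musl libc (x86_64)
Version 1.2.2
Dynamic Program Loader
"""

version = parse_musl_version(SAMPLE_OUTPUT)
```

Only the major and minor components are kept, so a patch-level difference in the loader does not change the generated tags.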


@@ -0,0 +1,359 @@
"""Handwritten parser of dependency specifiers.
The docstring for each _parse_* function contains EBNF-inspired grammar representing
the implementation.
"""
import ast
from typing import Any, List, NamedTuple, Optional, Tuple, Union
from ._tokenizer import DEFAULT_RULES, Tokenizer
class Node:
def __init__(self, value: str) -> None:
self.value = value
def __str__(self) -> str:
return self.value
def __repr__(self) -> str:
return f"<{self.__class__.__name__}('{self}')>"
def serialize(self) -> str:
raise NotImplementedError
class Variable(Node):
def serialize(self) -> str:
return str(self)
class Value(Node):
def serialize(self) -> str:
return f'"{self}"'
class Op(Node):
def serialize(self) -> str:
return str(self)
MarkerVar = Union[Variable, Value]
MarkerItem = Tuple[MarkerVar, Op, MarkerVar]
# MarkerAtom = Union[MarkerItem, List["MarkerAtom"]]
# MarkerList = List[Union["MarkerList", MarkerAtom, str]]
# mypy does not support recursive type definition
# https://github.com/python/mypy/issues/731
MarkerAtom = Any
MarkerList = List[Any]
class ParsedRequirement(NamedTuple):
name: str
url: str
extras: List[str]
specifier: str
marker: Optional[MarkerList]
# --------------------------------------------------------------------------------------
# Recursive descent parser for dependency specifier
# --------------------------------------------------------------------------------------
def parse_requirement(source: str) -> ParsedRequirement:
return _parse_requirement(Tokenizer(source, rules=DEFAULT_RULES))
def _parse_requirement(tokenizer: Tokenizer) -> ParsedRequirement:
"""
requirement = WS? IDENTIFIER WS? extras WS? requirement_details
"""
tokenizer.consume("WS")
name_token = tokenizer.expect(
"IDENTIFIER", expected="package name at the start of dependency specifier"
)
name = name_token.text
tokenizer.consume("WS")
extras = _parse_extras(tokenizer)
tokenizer.consume("WS")
url, specifier, marker = _parse_requirement_details(tokenizer)
tokenizer.expect("END", expected="end of dependency specifier")
return ParsedRequirement(name, url, extras, specifier, marker)
def _parse_requirement_details(
tokenizer: Tokenizer,
) -> Tuple[str, str, Optional[MarkerList]]:
"""
requirement_details = AT URL (WS requirement_marker?)?
| specifier WS? (requirement_marker)?
"""
specifier = ""
url = ""
marker = None
if tokenizer.check("AT"):
tokenizer.read()
tokenizer.consume("WS")
url_start = tokenizer.position
url = tokenizer.expect("URL", expected="URL after @").text
if tokenizer.check("END", peek=True):
return (url, specifier, marker)
tokenizer.expect("WS", expected="whitespace after URL")
# The input might end after whitespace.
if tokenizer.check("END", peek=True):
return (url, specifier, marker)
marker = _parse_requirement_marker(
tokenizer, span_start=url_start, after="URL and whitespace"
)
else:
specifier_start = tokenizer.position
specifier = _parse_specifier(tokenizer)
tokenizer.consume("WS")
if tokenizer.check("END", peek=True):
return (url, specifier, marker)
marker = _parse_requirement_marker(
tokenizer,
span_start=specifier_start,
after=(
"version specifier"
if specifier
else "name and no valid version specifier"
),
)
return (url, specifier, marker)
def _parse_requirement_marker(
tokenizer: Tokenizer, *, span_start: int, after: str
) -> MarkerList:
"""
requirement_marker = SEMICOLON marker WS?
"""
if not tokenizer.check("SEMICOLON"):
tokenizer.raise_syntax_error(
f"Expected end or semicolon (after {after})",
span_start=span_start,
)
tokenizer.read()
marker = _parse_marker(tokenizer)
tokenizer.consume("WS")
return marker
def _parse_extras(tokenizer: Tokenizer) -> List[str]:
"""
extras = (LEFT_BRACKET wsp* extras_list? wsp* RIGHT_BRACKET)?
"""
if not tokenizer.check("LEFT_BRACKET", peek=True):
return []
with tokenizer.enclosing_tokens(
"LEFT_BRACKET",
"RIGHT_BRACKET",
around="extras",
):
tokenizer.consume("WS")
extras = _parse_extras_list(tokenizer)
tokenizer.consume("WS")
return extras
def _parse_extras_list(tokenizer: Tokenizer) -> List[str]:
"""
extras_list = identifier (wsp* ',' wsp* identifier)*
"""
extras: List[str] = []
if not tokenizer.check("IDENTIFIER"):
return extras
extras.append(tokenizer.read().text)
while True:
tokenizer.consume("WS")
if tokenizer.check("IDENTIFIER", peek=True):
tokenizer.raise_syntax_error("Expected comma between extra names")
elif not tokenizer.check("COMMA"):
break
tokenizer.read()
tokenizer.consume("WS")
extra_token = tokenizer.expect("IDENTIFIER", expected="extra name after comma")
extras.append(extra_token.text)
return extras
def _parse_specifier(tokenizer: Tokenizer) -> str:
"""
specifier = LEFT_PARENTHESIS WS? version_many WS? RIGHT_PARENTHESIS
| WS? version_many WS?
"""
with tokenizer.enclosing_tokens(
"LEFT_PARENTHESIS",
"RIGHT_PARENTHESIS",
around="version specifier",
):
tokenizer.consume("WS")
parsed_specifiers = _parse_version_many(tokenizer)
tokenizer.consume("WS")
return parsed_specifiers
def _parse_version_many(tokenizer: Tokenizer) -> str:
"""
version_many = (SPECIFIER (WS? COMMA WS? SPECIFIER)*)?
"""
parsed_specifiers = ""
while tokenizer.check("SPECIFIER"):
span_start = tokenizer.position
parsed_specifiers += tokenizer.read().text
if tokenizer.check("VERSION_PREFIX_TRAIL", peek=True):
tokenizer.raise_syntax_error(
".* suffix can only be used with `==` or `!=` operators",
span_start=span_start,
span_end=tokenizer.position + 1,
)
if tokenizer.check("VERSION_LOCAL_LABEL_TRAIL", peek=True):
tokenizer.raise_syntax_error(
"Local version label can only be used with `==` or `!=` operators",
span_start=span_start,
span_end=tokenizer.position,
)
tokenizer.consume("WS")
if not tokenizer.check("COMMA"):
break
parsed_specifiers += tokenizer.read().text
tokenizer.consume("WS")
return parsed_specifiers
# --------------------------------------------------------------------------------------
# Recursive descent parser for marker expression
# --------------------------------------------------------------------------------------
def parse_marker(source: str) -> MarkerList:
return _parse_full_marker(Tokenizer(source, rules=DEFAULT_RULES))
def _parse_full_marker(tokenizer: Tokenizer) -> MarkerList:
retval = _parse_marker(tokenizer)
tokenizer.expect("END", expected="end of marker expression")
return retval
def _parse_marker(tokenizer: Tokenizer) -> MarkerList:
"""
marker = marker_atom (BOOLOP marker_atom)*
"""
expression = [_parse_marker_atom(tokenizer)]
while tokenizer.check("BOOLOP"):
token = tokenizer.read()
expr_right = _parse_marker_atom(tokenizer)
expression.extend((token.text, expr_right))
return expression
def _parse_marker_atom(tokenizer: Tokenizer) -> MarkerAtom:
"""
marker_atom = WS? LEFT_PARENTHESIS WS? marker WS? RIGHT_PARENTHESIS WS?
| WS? marker_item WS?
"""
tokenizer.consume("WS")
if tokenizer.check("LEFT_PARENTHESIS", peek=True):
with tokenizer.enclosing_tokens(
"LEFT_PARENTHESIS",
"RIGHT_PARENTHESIS",
around="marker expression",
):
tokenizer.consume("WS")
marker: MarkerAtom = _parse_marker(tokenizer)
tokenizer.consume("WS")
else:
marker = _parse_marker_item(tokenizer)
tokenizer.consume("WS")
return marker
def _parse_marker_item(tokenizer: Tokenizer) -> MarkerItem:
"""
marker_item = WS? marker_var WS? marker_op WS? marker_var WS?
"""
tokenizer.consume("WS")
marker_var_left = _parse_marker_var(tokenizer)
tokenizer.consume("WS")
marker_op = _parse_marker_op(tokenizer)
tokenizer.consume("WS")
marker_var_right = _parse_marker_var(tokenizer)
tokenizer.consume("WS")
return (marker_var_left, marker_op, marker_var_right)
def _parse_marker_var(tokenizer: Tokenizer) -> MarkerVar:
"""
marker_var = VARIABLE | QUOTED_STRING
"""
if tokenizer.check("VARIABLE"):
return process_env_var(tokenizer.read().text.replace(".", "_"))
elif tokenizer.check("QUOTED_STRING"):
return process_python_str(tokenizer.read().text)
else:
tokenizer.raise_syntax_error(
message="Expected a marker variable or quoted string"
)
def process_env_var(env_var: str) -> Variable:
if (
env_var == "platform_python_implementation"
or env_var == "python_implementation"
):
return Variable("platform_python_implementation")
else:
return Variable(env_var)
def process_python_str(python_str: str) -> Value:
value = ast.literal_eval(python_str)
return Value(str(value))
def _parse_marker_op(tokenizer: Tokenizer) -> Op:
"""
marker_op = IN | NOT IN | OP
"""
if tokenizer.check("IN"):
tokenizer.read()
return Op("in")
elif tokenizer.check("NOT"):
tokenizer.read()
tokenizer.expect("WS", expected="whitespace after 'not'")
tokenizer.expect("IN", expected="'in' after 'not'")
return Op("not in")
elif tokenizer.check("OP"):
return Op(tokenizer.read().text)
else:
return tokenizer.raise_syntax_error(
"Expected marker operator, one of "
"<=, <, !=, ==, >=, >, ~=, ===, in, not in"
)
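
The shape of the nested marker list produced by the parser can be illustrated with a much smaller recursive descent routine. This sketch is an assumption-laden simplification: it works on a pre-split token list rather than the real `Tokenizer`, keeps operands as plain strings instead of `Variable`/`Value`/`Op` nodes, and the names `sketch_parse_marker` and `sketch_parse_atom` are hypothetical.

```python
from typing import Any, List


def sketch_parse_marker(tokens: List[str]) -> List[Any]:
    """marker = marker_atom (BOOLOP marker_atom)*"""
    expression: List[Any] = [sketch_parse_atom(tokens)]
    while tokens and tokens[0] in ("and", "or"):
        boolop = tokens.pop(0)
        expression.extend((boolop, sketch_parse_atom(tokens)))
    return expression


def sketch_parse_atom(tokens: List[str]) -> Any:
    """marker_atom = '(' marker ')' | lhs op rhs"""
    if tokens[0] == "(":
        tokens.pop(0)
        inner = sketch_parse_marker(tokens)
        if not tokens or tokens.pop(0) != ")":
            raise SyntaxError("Expected matching ) after marker expression")
        return inner  # a parenthesised group becomes a nested list
    # A plain item is a 3-tuple; this sketch does not validate its parts.
    lhs, op, rhs = tokens.pop(0), tokens.pop(0), tokens.pop(0)
    return (lhs, op, rhs)


parsed = sketch_parse_marker(
    ["python_version", ">", "'3.6'", "or",
     "(", "python_version", "==", "'3.6'", "and", "os_name", "==", "'unix'", ")"]
)
```

Boolean operators stay in the list as bare strings between items, while parentheses introduce a nested list. The sketch consumes its input list in place, so pass a copy if the tokens are needed afterwards.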


@@ -0,0 +1,61 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
class InfinityType:
def __repr__(self) -> str:
return "Infinity"
def __hash__(self) -> int:
return hash(repr(self))
def __lt__(self, other: object) -> bool:
return False
def __le__(self, other: object) -> bool:
return False
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__)
def __gt__(self, other: object) -> bool:
return True
def __ge__(self, other: object) -> bool:
return True
def __neg__(self: object) -> "NegativeInfinityType":
return NegativeInfinity
Infinity = InfinityType()
class NegativeInfinityType:
def __repr__(self) -> str:
return "-Infinity"
def __hash__(self) -> int:
return hash(repr(self))
def __lt__(self, other: object) -> bool:
return True
def __le__(self, other: object) -> bool:
return True
def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__)
def __gt__(self, other: object) -> bool:
return False
def __ge__(self, other: object) -> bool:
return False
def __neg__(self: object) -> InfinityType:
return Infinity
NegativeInfinity = NegativeInfinityType()


@@ -0,0 +1,192 @@
import contextlib
import re
from dataclasses import dataclass
from typing import Dict, Iterator, NoReturn, Optional, Tuple, Union
from .specifiers import Specifier
@dataclass
class Token:
name: str
text: str
position: int
class ParserSyntaxError(Exception):
"""The provided source text could not be parsed correctly."""
def __init__(
self,
message: str,
*,
source: str,
span: Tuple[int, int],
) -> None:
self.span = span
self.message = message
self.source = source
super().__init__()
def __str__(self) -> str:
marker = " " * self.span[0] + "~" * (self.span[1] - self.span[0]) + "^"
return "\n ".join([self.message, self.source, marker])
DEFAULT_RULES: "Dict[str, Union[str, re.Pattern[str]]]" = {
"LEFT_PARENTHESIS": r"\(",
"RIGHT_PARENTHESIS": r"\)",
"LEFT_BRACKET": r"\[",
"RIGHT_BRACKET": r"\]",
"SEMICOLON": r";",
"COMMA": r",",
"QUOTED_STRING": re.compile(
r"""
(
('[^']*')
|
("[^"]*")
)
""",
re.VERBOSE,
),
"OP": r"(===|==|~=|!=|<=|>=|<|>)",
"BOOLOP": r"\b(or|and)\b",
"IN": r"\bin\b",
"NOT": r"\bnot\b",
"VARIABLE": re.compile(
r"""
\b(
python_version
|python_full_version
|os[._]name
|sys[._]platform
|platform_(release|system)
|platform[._](version|machine|python_implementation)
|python_implementation
|implementation_(name|version)
|extra
)\b
""",
re.VERBOSE,
),
"SPECIFIER": re.compile(
Specifier._operator_regex_str + Specifier._version_regex_str,
re.VERBOSE | re.IGNORECASE,
),
"AT": r"\@",
"URL": r"[^ \t]+",
"IDENTIFIER": r"\b[a-zA-Z0-9][a-zA-Z0-9._-]*\b",
"VERSION_PREFIX_TRAIL": r"\.\*",
"VERSION_LOCAL_LABEL_TRAIL": r"\+[a-z0-9]+(?:[-_\.][a-z0-9]+)*",
"WS": r"[ \t]+",
"END": r"$",
}
class Tokenizer:
"""Context-sensitive token parsing.
Provides methods to examine the input stream to check whether the next token
matches.
"""
def __init__(
self,
source: str,
*,
rules: "Dict[str, Union[str, re.Pattern[str]]]",
) -> None:
self.source = source
self.rules: Dict[str, re.Pattern[str]] = {
name: re.compile(pattern) for name, pattern in rules.items()
}
self.next_token: Optional[Token] = None
self.position = 0
def consume(self, name: str) -> None:
"""Move beyond provided token name, if at current position."""
if self.check(name):
self.read()
def check(self, name: str, *, peek: bool = False) -> bool:
"""Check whether the next token has the provided name.
By default, if the check succeeds, the token *must* be read before
another check. If `peek` is set to `True`, the token is not loaded and
would need to be checked again.
"""
assert (
self.next_token is None
), f"Cannot check for {name!r}, already have {self.next_token!r}"
assert name in self.rules, f"Unknown token name: {name!r}"
expression = self.rules[name]
match = expression.match(self.source, self.position)
if match is None:
return False
if not peek:
self.next_token = Token(name, match[0], self.position)
return True
def expect(self, name: str, *, expected: str) -> Token:
"""Expect a certain token name next, failing with a syntax error otherwise.
On success, the token is read and returned.
"""
if not self.check(name):
raise self.raise_syntax_error(f"Expected {expected}")
return self.read()
def read(self) -> Token:
"""Consume the next token and return it."""
token = self.next_token
assert token is not None
self.position += len(token.text)
self.next_token = None
return token
def raise_syntax_error(
self,
message: str,
*,
span_start: Optional[int] = None,
span_end: Optional[int] = None,
) -> NoReturn:
"""Raise ParserSyntaxError at the given position."""
span = (
self.position if span_start is None else span_start,
self.position if span_end is None else span_end,
)
raise ParserSyntaxError(
message,
source=self.source,
span=span,
)
@contextlib.contextmanager
def enclosing_tokens(
self, open_token: str, close_token: str, *, around: str
) -> Iterator[None]:
if self.check(open_token):
open_position = self.position
self.read()
else:
open_position = None
yield
if open_position is None:
return
if not self.check(close_token):
self.raise_syntax_error(
f"Expected matching {close_token} for {open_token}, after {around}",
span_start=open_position,
)
self.read()
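
The check/read protocol of the tokenizer can be reduced to a few lines. The following sketch assumes a hypothetical `MiniTokenizer` class and a hypothetical four-rule table (`RULES`); it omits error reporting and `enclosing_tokens`, and exists only to show how `check` stages a token, how `peek=True` avoids staging it, and how `read` advances the position.

```python
import re
from typing import Dict, Optional, Tuple


class MiniTokenizer:
    """Hypothetical, reduced version of the check/read protocol."""

    def __init__(self, source: str, rules: Dict[str, str]) -> None:
        self.source = source
        self.rules = {name: re.compile(pattern) for name, pattern in rules.items()}
        self.position = 0
        self.pending: Optional[Tuple[str, str]] = None

    def check(self, name: str, *, peek: bool = False) -> bool:
        # Patterns are anchored at the current position, not searched for.
        match = self.rules[name].match(self.source, self.position)
        if match is None:
            return False
        if not peek:
            self.pending = (name, match[0])  # staged until read() is called
        return True

    def read(self) -> Tuple[str, str]:
        assert self.pending is not None, "call check() before read()"
        name, text = self.pending
        self.position += len(text)
        self.pending = None
        return name, text


RULES = {
    "IDENTIFIER": r"[a-zA-Z0-9][a-zA-Z0-9._-]*",
    "WS": r"[ \t]+",
    "AT": r"@",
    "END": r"$",
}

tok = MiniTokenizer("name  @", RULES)
found_name = tok.check("IDENTIFIER")
first = tok.read()
found_ws = tok.check("WS")
tok.read()
peeked_at = tok.check("AT", peek=True)  # does not stage a token
position_after_peek = tok.position
at_end = tok.check("END")
```

Which rule to try next is decided by the caller, which is what makes the tokenizer context-sensitive: the same text can match `IDENTIFIER` in one grammar position and `URL` in another.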


@@ -0,0 +1,252 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
import operator
import os
import platform
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from ._parser import (
MarkerAtom,
MarkerList,
Op,
Value,
Variable,
parse_marker as _parse_marker,
)
from ._tokenizer import ParserSyntaxError
from .specifiers import InvalidSpecifier, Specifier
from .utils import canonicalize_name
__all__ = [
"InvalidMarker",
"UndefinedComparison",
"UndefinedEnvironmentName",
"Marker",
"default_environment",
]
Operator = Callable[[str, str], bool]
class InvalidMarker(ValueError):
"""
An invalid marker was found, users should refer to PEP 508.
"""
class UndefinedComparison(ValueError):
"""
An invalid operation was attempted on a value that doesn't support it.
"""
class UndefinedEnvironmentName(ValueError):
"""
An attempt was made to use a name that does not exist inside of the
environment.
"""
def _normalize_extra_values(results: Any) -> Any:
"""
Normalize extra values.
"""
if isinstance(results[0], tuple):
lhs, op, rhs = results[0]
if isinstance(lhs, Variable) and lhs.value == "extra":
normalized_extra = canonicalize_name(rhs.value)
rhs = Value(normalized_extra)
elif isinstance(rhs, Variable) and rhs.value == "extra":
normalized_extra = canonicalize_name(lhs.value)
lhs = Value(normalized_extra)
results[0] = lhs, op, rhs
return results
def _format_marker(
marker: Union[List[str], MarkerAtom, str], first: Optional[bool] = True
) -> str:
assert isinstance(marker, (list, tuple, str))
# Sometimes we have a structure like [[...]] which is a single item list
# where the single item is itself its own list. In that case we want to skip
# the rest of this function so that we don't get extraneous () on the
# outside.
if (
isinstance(marker, list)
and len(marker) == 1
and isinstance(marker[0], (list, tuple))
):
return _format_marker(marker[0])
if isinstance(marker, list):
inner = (_format_marker(m, first=False) for m in marker)
if first:
return " ".join(inner)
else:
return "(" + " ".join(inner) + ")"
elif isinstance(marker, tuple):
return " ".join([m.serialize() for m in marker])
else:
return marker
_operators: Dict[str, Operator] = {
"in": lambda lhs, rhs: lhs in rhs,
"not in": lambda lhs, rhs: lhs not in rhs,
"<": operator.lt,
"<=": operator.le,
"==": operator.eq,
"!=": operator.ne,
">=": operator.ge,
">": operator.gt,
}
def _eval_op(lhs: str, op: Op, rhs: str) -> bool:
try:
spec = Specifier("".join([op.serialize(), rhs]))
except InvalidSpecifier:
pass
else:
return spec.contains(lhs, prereleases=True)
oper: Optional[Operator] = _operators.get(op.serialize())
if oper is None:
raise UndefinedComparison(f"Undefined {op!r} on {lhs!r} and {rhs!r}.")
return oper(lhs, rhs)
def _normalize(*values: str, key: str) -> Tuple[str, ...]:
# PEP 685 Comparison of extra names for optional distribution dependencies
# https://peps.python.org/pep-0685/
# > When comparing extra names, tools MUST normalize the names being
# > compared using the semantics outlined in PEP 503 for names
if key == "extra":
return tuple(canonicalize_name(v) for v in values)
# other environment markers don't have such standards
return values
def _evaluate_markers(markers: MarkerList, environment: Dict[str, str]) -> bool:
groups: List[List[bool]] = [[]]
for marker in markers:
assert isinstance(marker, (list, tuple, str))
if isinstance(marker, list):
groups[-1].append(_evaluate_markers(marker, environment))
elif isinstance(marker, tuple):
lhs, op, rhs = marker
if isinstance(lhs, Variable):
environment_key = lhs.value
lhs_value = environment[environment_key]
rhs_value = rhs.value
else:
lhs_value = lhs.value
environment_key = rhs.value
rhs_value = environment[environment_key]
lhs_value, rhs_value = _normalize(lhs_value, rhs_value, key=environment_key)
groups[-1].append(_eval_op(lhs_value, op, rhs_value))
else:
assert marker in ["and", "or"]
if marker == "or":
groups.append([])
return any(all(item) for item in groups)
def format_full_version(info: "sys._version_info") -> str:
version = "{0.major}.{0.minor}.{0.micro}".format(info)
kind = info.releaselevel
if kind != "final":
version += kind[0] + str(info.serial)
return version
def default_environment() -> Dict[str, str]:
iver = format_full_version(sys.implementation.version)
implementation_name = sys.implementation.name
return {
"implementation_name": implementation_name,
"implementation_version": iver,
"os_name": os.name,
"platform_machine": platform.machine(),
"platform_release": platform.release(),
"platform_system": platform.system(),
"platform_version": platform.version(),
"python_full_version": platform.python_version(),
"platform_python_implementation": platform.python_implementation(),
"python_version": ".".join(platform.python_version_tuple()[:2]),
"sys_platform": sys.platform,
}
class Marker:
def __init__(self, marker: str) -> None:
# Note: We create a Marker object without calling this constructor in
# packaging.requirements.Requirement. If any additional logic is
# added here, make sure to mirror/adapt Requirement.
try:
self._markers = _normalize_extra_values(_parse_marker(marker))
# The attribute `_markers` can be described in terms of a recursive type:
# MarkerList = List[Union[Tuple[Node, ...], str, MarkerList]]
#
# For example, the following expression:
# python_version > "3.6" or (python_version == "3.6" and os_name == "unix")
#
# is parsed into:
# [
# (<Variable('python_version')>, <Op('>')>, <Value('3.6')>),
# 'or',
# [
# (<Variable('python_version')>, <Op('==')>, <Value('3.6')>),
# 'and',
# (<Variable('os_name')>, <Op('==')>, <Value('unix')>)
# ]
# ]
except ParserSyntaxError as e:
raise InvalidMarker(str(e)) from e
def __str__(self) -> str:
return _format_marker(self._markers)
def __repr__(self) -> str:
return f"<Marker('{self}')>"
def __hash__(self) -> int:
return hash((self.__class__.__name__, str(self)))
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Marker):
return NotImplemented
return str(self) == str(other)
def evaluate(self, environment: Optional[Dict[str, str]] = None) -> bool:
"""Evaluate a marker.
Return the boolean from evaluating the given marker against the
environment. environment is an optional argument to override all or
part of the determined environment.
The environment is determined from the current Python process.
"""
current_environment = default_environment()
current_environment["extra"] = ""
if environment is not None:
current_environment.update(environment)
# The API used to allow setting extra to None. We need to handle this
# case for backwards compatibility.
if current_environment["extra"] is None:
current_environment["extra"] = ""
return _evaluate_markers(self._markers, current_environment)
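
The grouping logic used to evaluate a parsed marker list can be shown with already-evaluated booleans in place of comparison tuples. In this sketch, `sketch_evaluate` is a hypothetical name and the input is assumed to be a list of `bool`, nested lists, and the strings `"and"`/`"or"`: each `"or"` starts a new group, `"and"` keeps extending the current one, and the result is true if any group is entirely true.

```python
from typing import Any, List


def sketch_evaluate(items: List[Any]) -> bool:
    """Evaluate a flat and/or sequence, giving 'and' precedence over 'or'."""
    groups: List[List[bool]] = [[]]
    for item in items:
        if isinstance(item, list):
            # A nested list is a parenthesised sub-expression.
            groups[-1].append(sketch_evaluate(item))
        elif isinstance(item, bool):
            groups[-1].append(item)
        else:
            assert item in ("and", "or"), f"unexpected item: {item!r}"
            if item == "or":
                groups.append([])
    return any(all(group) for group in groups)


result_nested = sketch_evaluate([False, "or", [True, "and", True]])
result_flat = sketch_evaluate([True, "and", False, "or", False])
```

Splitting on `"or"` and requiring every member of a group to hold yields the usual precedence without building an expression tree.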


@@ -0,0 +1,822 @@
import email.feedparser
import email.header
import email.message
import email.parser
import email.policy
import sys
import typing
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from . import requirements, specifiers, utils, version as version_module
T = typing.TypeVar("T")
if sys.version_info[:2] >= (3, 8): # pragma: no cover
from typing import Literal, TypedDict
else: # pragma: no cover
if typing.TYPE_CHECKING:
from typing_extensions import Literal, TypedDict
else:
try:
from typing_extensions import Literal, TypedDict
except ImportError:
class Literal:
def __init_subclass__(*_args, **_kwargs):
pass
class TypedDict:
def __init_subclass__(*_args, **_kwargs):
pass
try:
ExceptionGroup = __builtins__.ExceptionGroup # type: ignore[attr-defined]
except AttributeError:
class ExceptionGroup(Exception): # type: ignore[no-redef] # noqa: N818
"""A minimal implementation of :external:exc:`ExceptionGroup` from Python 3.11.
If :external:exc:`ExceptionGroup` is already defined by Python itself,
that version is used instead.
"""
message: str
exceptions: List[Exception]
def __init__(self, message: str, exceptions: List[Exception]) -> None:
self.message = message
self.exceptions = exceptions
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.message!r}, {self.exceptions!r})"
class InvalidMetadata(ValueError):
"""A metadata field contains invalid data."""
field: str
"""The name of the field that contains invalid data."""
def __init__(self, field: str, message: str) -> None:
self.field = field
super().__init__(message)
# The RawMetadata class attempts to make as few assumptions about the underlying
# serialization formats as possible. The idea is that as long as a serialization
# format offers some very basic primitives in *some* way then we can support
# serializing to and from that format.
class RawMetadata(TypedDict, total=False):
"""A dictionary of raw core metadata.
Each field in core metadata maps to a key of this dictionary (when data is
provided). The key is lower-case and underscores are used instead of dashes
compared to the equivalent core metadata field. Any core metadata field that
can be specified multiple times or can hold multiple values in a single
field has a key with a plural name. See :class:`Metadata`, whose attributes
match the keys of this dictionary.
Core metadata fields that can be specified multiple times are stored as a
list or dict depending on which is appropriate for the field. Any fields
which hold multiple values in a single field are stored as a list.
"""
# Metadata 1.0 - PEP 241
metadata_version: str
name: str
version: str
platforms: List[str]
summary: str
description: str
keywords: List[str]
home_page: str
author: str
author_email: str
license: str
# Metadata 1.1 - PEP 314
supported_platforms: List[str]
download_url: str
classifiers: List[str]
requires: List[str]
provides: List[str]
obsoletes: List[str]
# Metadata 1.2 - PEP 345
maintainer: str
maintainer_email: str
requires_dist: List[str]
provides_dist: List[str]
obsoletes_dist: List[str]
requires_python: str
requires_external: List[str]
project_urls: Dict[str, str]
# Metadata 2.0
# PEP 426 attempted to completely revamp the metadata format
# but got stuck without ever being able to build consensus on
# it and ultimately ended up withdrawn.
#
# However, a number of tools had started emitting METADATA with
# `2.0` Metadata-Version, so for historical reasons, this version
# was skipped.
# Metadata 2.1 - PEP 566
description_content_type: str
provides_extra: List[str]
# Metadata 2.2 - PEP 643
dynamic: List[str]
# Metadata 2.3 - PEP 685
# No new fields were added in PEP 685; some edge cases were
# tightened up to provide better interoperability.
_STRING_FIELDS = {
"author",
"author_email",
"description",
"description_content_type",
"download_url",
"home_page",
"license",
"maintainer",
"maintainer_email",
"metadata_version",
"name",
"requires_python",
"summary",
"version",
}
_LIST_FIELDS = {
"classifiers",
"dynamic",
"obsoletes",
"obsoletes_dist",
"platforms",
"provides",
"provides_dist",
"provides_extra",
"requires",
"requires_dist",
"requires_external",
"supported_platforms",
}
_DICT_FIELDS = {
"project_urls",
}
def _parse_keywords(data: str) -> List[str]:
"""Split a string of comma-separated keywords into a list of keywords."""
return [k.strip() for k in data.split(",")]
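# A minimal sketch of the keyword splitting above, using only the standard
# library. The name `split_keywords` is hypothetical and not part of the
# module. Note that an empty input yields a list holding one empty string,
# because `str.split` always returns at least one item.

```python
from typing import List


def split_keywords(data: str) -> List[str]:
    """Split a comma-separated string and strip whitespace around each item."""
    return [k.strip() for k in data.split(",")]


keywords = split_keywords("packaging, metadata ,pep-566")
empty = split_keywords("")
print(keywords)
print(empty)
```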
def _parse_project_urls(data: List[str]) -> Dict[str, str]:
"""Parse a list of label/URL string pairings separated by a comma."""
urls = {}
for pair in data:
# Our logic is slightly tricky here as we want to try and do
# *something* reasonable with malformed data.
#
# The main thing that we have to worry about, is data that does
# not have a ',' at all to split the label from the Value. There
# isn't a singular right answer here, and we will fail validation
# later on (if the caller is validating) so it doesn't *really*
# matter, but since the missing value has to be an empty str
# and our return value is dict[str, str], if we let the key
# be the missing value, then they'd have multiple '' values that
# overwrite each other in an accumulating dict.
#
# The other potential issue is that it's possible to have the
# same label multiple times in the metadata, with no solid "right"
# answer with what to do in that case. As such, we'll do the only
# thing we can, which is treat the field as unparseable and add it
# to our list of unparsed fields.
parts = [p.strip() for p in pair.split(",", 1)]
parts.extend([""] * (max(0, 2 - len(parts)))) # Ensure 2 items
# TODO: The spec doesn't say anything about if the keys should be
# considered case sensitive or not... logically they should
# be case-preserving and case-insensitive, but doing that
# would open up more cases where we might have duplicate
# entries.
label, url = parts
if label in urls:
# The label already exists in our set of urls, so this field
# is unparseable, and we can just add the whole thing to our
# unparseable data and stop processing it.
raise KeyError("duplicate labels in project urls")
urls[label] = url
return urls
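# The behaviour described in the comments above can be exercised with a
# standalone sketch. `parse_label_url_pairs` is a hypothetical copy of the
# same logic, written here only to show what happens for a malformed entry
# (no comma) and for a duplicated label.

```python
from typing import Dict, List


def parse_label_url_pairs(pairs: List[str]) -> Dict[str, str]:
    """Turn 'label, url' strings into a dict, rejecting duplicate labels."""
    urls: Dict[str, str] = {}
    for pair in pairs:
        # Split on the first comma only, so URLs containing commas survive.
        parts = [p.strip() for p in pair.split(",", 1)]
        # Pad to two items: a missing URL becomes an empty string.
        parts.extend([""] * max(0, 2 - len(parts)))
        label, url = parts
        if label in urls:
            raise KeyError("duplicate labels in project urls")
        urls[label] = url
    return urls


parsed_urls = parse_label_url_pairs(
    ["Docs, https://example.org/docs?a=1,2", "Homepage"]
)

try:
    parse_label_url_pairs(["Docs, https://a.example", "Docs, https://b.example"])
except KeyError:
    duplicate_rejected = True
else:
    duplicate_rejected = False
```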
def _get_payload(msg: email.message.Message, source: Union[bytes, str]) -> str:
"""Get the body of the message."""
# If our source is a str, then our caller has managed encodings for us,
# and we don't need to deal with it.
if isinstance(source, str):
payload: str = msg.get_payload()
return payload
# If our source is a bytes, then we're managing the encoding and we need
# to deal with it.
else:
bpayload: bytes = msg.get_payload(decode=True)
try:
return bpayload.decode("utf8", "strict")
except UnicodeDecodeError:
raise ValueError("payload in an invalid encoding")
# The various parse_FORMAT functions here are intended to be as lenient as
# possible in their parsing, while still returning a correctly typed
# RawMetadata.
#
# To aid in this, we also generally want to do as little touching of the
# data as possible, except where there are possibly some historic holdovers
# that make valid data awkward to work with.
#
# While this is a lower level, intermediate format than our ``Metadata``
# class, some light touch ups can make a massive difference in usability.
# Map METADATA fields to RawMetadata.
_EMAIL_TO_RAW_MAPPING = {
"author": "author",
"author-email": "author_email",
"classifier": "classifiers",
"description": "description",
"description-content-type": "description_content_type",
"download-url": "download_url",
"dynamic": "dynamic",
"home-page": "home_page",
"keywords": "keywords",
"license": "license",
"maintainer": "maintainer",
"maintainer-email": "maintainer_email",
"metadata-version": "metadata_version",
"name": "name",
"obsoletes": "obsoletes",
"obsoletes-dist": "obsoletes_dist",
"platform": "platforms",
"project-url": "project_urls",
"provides": "provides",
"provides-dist": "provides_dist",
"provides-extra": "provides_extra",
"requires": "requires",
"requires-dist": "requires_dist",
"requires-external": "requires_external",
"requires-python": "requires_python",
"summary": "summary",
"supported-platform": "supported_platforms",
"version": "version",
}
_RAW_TO_EMAIL_MAPPING = {raw: email for email, raw in _EMAIL_TO_RAW_MAPPING.items()}
def parse_email(data: Union[bytes, str]) -> Tuple[RawMetadata, Dict[str, List[str]]]:
"""Parse a distribution's metadata stored as email headers (e.g. from ``METADATA``).
This function returns a two-item tuple of dicts. The first dict is of
recognized fields from the core metadata specification. Fields that can be
parsed and translated into Python's built-in types are converted
appropriately. All other fields are left as-is. Fields that are allowed to
appear multiple times are stored as lists.
The second dict contains all other fields from the metadata. This includes
any unrecognized fields. It also includes any fields which are expected to
be parsed into a built-in type but were not formatted appropriately. Finally,
any fields that are expected to appear only once but are repeated are
included in this dict.
"""
raw: Dict[str, Union[str, List[str], Dict[str, str]]] = {}
unparsed: Dict[str, List[str]] = {}
if isinstance(data, str):
parsed = email.parser.Parser(policy=email.policy.compat32).parsestr(data)
else:
parsed = email.parser.BytesParser(policy=email.policy.compat32).parsebytes(data)
# We have to wrap parsed.keys() in a set, because in the case of multiple
# values for a key (a list), the key will appear multiple times in the
# list of keys, but we're avoiding that by using get_all().
for name in frozenset(parsed.keys()):
# Header names in RFC are case insensitive, so we'll normalize to all
# lower case to make comparisons easier.
name = name.lower()
# We use get_all() here, even for fields that aren't multiple use,
# because otherwise someone could have e.g. two Name fields, and we
# would just silently ignore it rather than doing something about it.
headers = parsed.get_all(name) or []
# The way the email module works when parsing bytes is that it
# unconditionally decodes the bytes as ascii using the surrogateescape
# handler. When you pull that data back out (such as with get_all() ),
# it looks to see if the str has any surrogate escapes, and if it does
# it wraps it in a Header object instead of returning the string.
#
# As such, we'll look for those Header objects, and fix up the encoding.
value = []
# Flag if we have run into any issues processing the headers, thus
# signalling that the data belongs in 'unparsed'.
valid_encoding = True
for h in headers:
# It's unclear if this can return more types than just a Header or
# a str, so we'll just assert here to make sure.
assert isinstance(h, (email.header.Header, str))
# If it's a header object, we need to do our little dance to get
# the real data out of it. In cases where there is invalid data
# we're going to end up with mojibake, but there's no obvious, good
# way around that without reimplementing parts of the Header object
# ourselves.
#
# That should be fine since, if mojibake happens, this key is
# going into the unparsed dict anyways.
if isinstance(h, email.header.Header):
# The Header object stores its data as chunks, and each chunk
# can be independently encoded, so we'll need to check each
# of them.
chunks: List[Tuple[bytes, Optional[str]]] = []
for bin, encoding in email.header.decode_header(h):
try:
bin.decode("utf8", "strict")
except UnicodeDecodeError:
# Enable mojibake.
encoding = "latin1"
valid_encoding = False
else:
encoding = "utf8"
chunks.append((bin, encoding))
# Turn our chunks back into a Header object, then let that
# Header object do the right thing to turn them into a
# string for us.
value.append(str(email.header.make_header(chunks)))
# This is already a string, so just add it.
else:
value.append(h)
# We've processed all of our values to get them into a list of str,
# but we may have mojibake data, in which case this is an unparsed
# field.
if not valid_encoding:
unparsed[name] = value
continue
raw_name = _EMAIL_TO_RAW_MAPPING.get(name)
if raw_name is None:
# This is a bit of a weird situation, we've encountered a key that
# we don't know what it means, so we don't know whether it's meant
# to be a list or not.
#
# Since we can't really tell one way or another, we'll just leave it
# as a list, even though it may be a single item list, because that's
# what makes the most sense for email headers.
unparsed[name] = value
continue
# If this is one of our string fields, then we'll check to see if our
# value is a list of a single item. If it is then we'll assume that
# it was emitted as a single string, and unwrap the str from inside
# the list.
#
# If it's any other kind of data, then we haven't the faintest clue
# what we should parse it as, and we have to just add it to our list
# of unparsed stuff.
if raw_name in _STRING_FIELDS and len(value) == 1:
raw[raw_name] = value[0]
# If this is one of our list of string fields, then we can just assign
# the value, since email *only* has strings, and our get_all() call
# above ensures that this is a list.
elif raw_name in _LIST_FIELDS:
raw[raw_name] = value
# Special Case: Keywords
# The keywords field is implemented in the metadata spec as a str,
# but it conceptually is a list of strings, and is serialized using
# ", ".join(keywords), so we'll do some light data massaging to turn
# this into what it logically is.
elif raw_name == "keywords" and len(value) == 1:
raw[raw_name] = _parse_keywords(value[0])
# Special Case: Project-URL
# The project urls is implemented in the metadata spec as a list of
# specially-formatted strings that represent a key and a value, which
# is fundamentally a mapping, however the email format doesn't support
# mappings in a sane way, so it was crammed into a list of strings
# instead.
#
# We will do a little light data massaging to turn this into a map as
# it logically should be.
elif raw_name == "project_urls":
try:
raw[raw_name] = _parse_project_urls(value)
except KeyError:
unparsed[name] = value
# Nothing that we've done has managed to parse this, so we'll just
# throw it in our unparseable data and move on.
else:
unparsed[name] = value
# We need to support getting the Description from the message payload in
# addition to getting it from the headers. This does mean, though, there
# is the possibility of it being set both ways, in which case we put both
# in 'unparsed' since we don't know which is right.
try:
payload = _get_payload(parsed, data)
except ValueError:
unparsed.setdefault("description", []).append(
parsed.get_payload(decode=isinstance(data, bytes))
)
else:
if payload:
# Check to see if we've already got a description, if so then both
# it, and this body move to unparseable.
if "description" in raw:
description_header = cast(str, raw.pop("description"))
unparsed.setdefault("description", []).extend(
[description_header, payload]
)
elif "description" in unparsed:
unparsed["description"].append(payload)
else:
raw["description"] = payload
# We need to cast our `raw` to a metadata, because a TypedDict only supports
# literal key names, whereas we're computing our key names on purpose. The
# way this function is implemented, however, our `TypedDict` can only have
# valid key names.
return cast(RawMetadata, raw), unparsed
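# To illustrate why parse_email() wraps the header names in a frozenset and
# uses get_all(), here is a small standard-library sketch. The sample
# METADATA text is invented for illustration. A repeated header shows up
# once per occurrence in keys(), while get_all() gathers every value, and
# the message body carries the description.

```python
import email.parser
import email.policy

sample = (
    "Metadata-Version: 2.1\n"
    "Name: example\n"
    "Classifier: A\n"
    "Classifier: B\n"
    "\n"
    "Long description body.\n"
)

msg = email.parser.Parser(policy=email.policy.compat32).parsestr(sample)

# 'Classifier' appears twice in keys(); the frozenset removes the repeat.
raw_keys = list(msg.keys())
names = frozenset(k.lower() for k in raw_keys)

# Header lookup is case-insensitive and get_all() returns every value.
classifiers = msg.get_all("classifier")

# Everything after the blank line is the payload (the description).
body = msg.get_payload()
```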
_NOT_FOUND = object()
# Keep the two values in sync.
_VALID_METADATA_VERSIONS = ["1.0", "1.1", "1.2", "2.1", "2.2", "2.3"]
_MetadataVersion = Literal["1.0", "1.1", "1.2", "2.1", "2.2", "2.3"]
_REQUIRED_ATTRS = frozenset(["metadata_version", "name", "version"])
class _Validator(Generic[T]):
"""Validate a metadata field.
All _process_*() methods correspond to a core metadata field. The method is
called with the field's raw value. If the raw value is valid it is returned
in its "enriched" form (e.g. ``version.Version`` for the ``Version`` field).
If the raw value is invalid, :exc:`InvalidMetadata` is raised (with a cause
as appropriate).
"""
name: str
raw_name: str
added: _MetadataVersion
def __init__(
self,
*,
added: _MetadataVersion = "1.0",
) -> None:
self.added = added
def __set_name__(self, _owner: "Metadata", name: str) -> None:
self.name = name
self.raw_name = _RAW_TO_EMAIL_MAPPING[name]
def __get__(self, instance: "Metadata", _owner: Type["Metadata"]) -> T:
# With Python 3.8, the caching can be replaced with functools.cached_property().
# No need to check the cache as attribute lookup will resolve into the
# instance's __dict__ before __get__ is called.
cache = instance.__dict__
try:
value = instance._raw[self.name] # type: ignore[literal-required]
except KeyError:
if self.name in _STRING_FIELDS:
value = ""
elif self.name in _LIST_FIELDS:
value = []
elif self.name in _DICT_FIELDS:
value = {}
else: # pragma: no cover
assert False
try:
converter: Callable[[Any], T] = getattr(self, f"_process_{self.name}")
except AttributeError:
pass
else:
value = converter(value)
cache[self.name] = value
try:
del instance._raw[self.name] # type: ignore[misc]
except KeyError:
pass
return cast(T, value)
def _invalid_metadata(
self, msg: str, cause: Optional[Exception] = None
) -> InvalidMetadata:
exc = InvalidMetadata(
self.raw_name, msg.format_map({"field": repr(self.raw_name)})
)
exc.__cause__ = cause
return exc
def _process_metadata_version(self, value: str) -> _MetadataVersion:
# Implicitly makes Metadata-Version required.
if value not in _VALID_METADATA_VERSIONS:
raise self._invalid_metadata(f"{value!r} is not a valid metadata version")
return cast(_MetadataVersion, value)
def _process_name(self, value: str) -> str:
if not value:
raise self._invalid_metadata("{field} is a required field")
# Validate the name as a side-effect.
try:
utils.canonicalize_name(value, validate=True)
except utils.InvalidName as exc:
raise self._invalid_metadata(
f"{value!r} is invalid for {{field}}", cause=exc
)
else:
return value
def _process_version(self, value: str) -> version_module.Version:
if not value:
raise self._invalid_metadata("{field} is a required field")
try:
return version_module.parse(value)
except version_module.InvalidVersion as exc:
raise self._invalid_metadata(
f"{value!r} is invalid for {{field}}", cause=exc
)
def _process_summary(self, value: str) -> str:
"""Check the field contains no newlines."""
if "\n" in value:
raise self._invalid_metadata("{field} must be a single line")
return value
def _process_description_content_type(self, value: str) -> str:
content_types = {"text/plain", "text/x-rst", "text/markdown"}
message = email.message.EmailMessage()
message["content-type"] = value
content_type, parameters = (
# Defaults to `text/plain` if parsing failed.
message.get_content_type().lower(),
message["content-type"].params,
)
# Check if content-type is valid or defaulted to `text/plain` and thus was
# not parseable.
if content_type not in content_types or content_type not in value.lower():
raise self._invalid_metadata(
f"{{field}} must be one of {list(content_types)}, not {value!r}"
)
charset = parameters.get("charset", "UTF-8")
if charset != "UTF-8":
raise self._invalid_metadata(
f"{{field}} can only specify the UTF-8 charset, not {charset!r}"
)
markdown_variants = {"GFM", "CommonMark"}
variant = parameters.get("variant", "GFM") # Use an acceptable default.
if content_type == "text/markdown" and variant not in markdown_variants:
raise self._invalid_metadata(
f"valid Markdown variants for {{field}} are {list(markdown_variants)}, "
f"not {variant!r}",
)
return value
def _process_dynamic(self, value: List[str]) -> List[str]:
for dynamic_field in map(str.lower, value):
if dynamic_field in {"name", "version", "metadata-version"}:
raise self._invalid_metadata(
f"{dynamic_field!r} is not allowed as a dynamic field"
)
elif dynamic_field not in _EMAIL_TO_RAW_MAPPING:
raise self._invalid_metadata(f"{dynamic_field!r} is not a valid dynamic field")
return list(map(str.lower, value))
def _process_provides_extra(
self,
value: List[str],
) -> List[utils.NormalizedName]:
normalized_names = []
try:
for name in value:
normalized_names.append(utils.canonicalize_name(name, validate=True))
except utils.InvalidName as exc:
raise self._invalid_metadata(
f"{name!r} is invalid for {{field}}", cause=exc
)
else:
return normalized_names
def _process_requires_python(self, value: str) -> specifiers.SpecifierSet:
try:
return specifiers.SpecifierSet(value)
except specifiers.InvalidSpecifier as exc:
raise self._invalid_metadata(
f"{value!r} is invalid for {{field}}", cause=exc
)
def _process_requires_dist(
self,
value: List[str],
) -> List[requirements.Requirement]:
reqs = []
try:
for req in value:
reqs.append(requirements.Requirement(req))
except requirements.InvalidRequirement as exc:
raise self._invalid_metadata(f"{req!r} is invalid for {{field}}", cause=exc)
else:
return reqs
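# The caching trick used by _Validator.__get__ relies on the descriptor
# being a non-data descriptor (it defines no __set__), so a value stored in
# the instance __dict__ shadows it on later lookups. The sketch below shows
# that mechanism in isolation. `CachedField` and `Record` are hypothetical
# names, and the strip/upper step merely stands in for a _process_*()
# converter.

```python
from typing import Any, Dict, Optional


class CachedField:
    """Non-data descriptor: an instance __dict__ entry shadows it."""

    def __init__(self) -> None:
        self.calls = 0  # Counts how often __get__ actually runs.

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name

    def __get__(self, instance: Any, owner: Optional[type] = None) -> Any:
        if instance is None:
            return self
        self.calls += 1
        # Convert the raw value, drop it from the raw dict, cache the result.
        value = instance._raw.pop(self.name, "").strip().upper()
        instance.__dict__[self.name] = value
        return value


class Record:
    title = CachedField()

    def __init__(self, raw: Dict[str, str]) -> None:
        self._raw = dict(raw)


record = Record({"title": "  hello  "})
first = record.title   # Runs __get__ and fills the cache.
second = record.title  # Served from record.__dict__, __get__ is skipped.
```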
class Metadata:
"""Representation of distribution metadata.
Compared to :class:`RawMetadata`, this class provides objects representing
metadata fields instead of only using built-in types. Any invalid metadata
will cause :exc:`InvalidMetadata` to be raised (with a
:py:attr:`~BaseException.__cause__` attribute as appropriate).
"""
_raw: RawMetadata
@classmethod
def from_raw(cls, data: RawMetadata, *, validate: bool = True) -> "Metadata":
"""Create an instance from :class:`RawMetadata`.
If *validate* is true, all metadata will be validated. All exceptions
related to validation will be gathered and raised as an :class:`ExceptionGroup`.
"""
ins = cls()
ins._raw = data.copy() # Mutations occur due to caching enriched values.
if validate:
exceptions: List[InvalidMetadata] = []
try:
metadata_version = ins.metadata_version
metadata_age = _VALID_METADATA_VERSIONS.index(metadata_version)
except InvalidMetadata as metadata_version_exc:
exceptions.append(metadata_version_exc)
metadata_version = None
# Make sure to check both the fields that are present and the
# required fields (so their absence can be reported).
fields_to_check = frozenset(ins._raw) | _REQUIRED_ATTRS
# Remove fields that have already been checked.
fields_to_check -= {"metadata_version"}
for key in fields_to_check:
try:
if metadata_version:
# Can't use getattr() as that triggers descriptor protocol which
# will fail due to no value for the instance argument.
try:
field_metadata_version = cls.__dict__[key].added
except KeyError:
exc = InvalidMetadata(key, f"unrecognized field: {key!r}")
exceptions.append(exc)
continue
field_age = _VALID_METADATA_VERSIONS.index(
field_metadata_version
)
if field_age > metadata_age:
field = _RAW_TO_EMAIL_MAPPING[key]
exc = InvalidMetadata(
field,
f"{field} introduced in metadata version "
f"{field_metadata_version}, not {metadata_version}",
)
exceptions.append(exc)
continue
getattr(ins, key)
except InvalidMetadata as exc:
exceptions.append(exc)
if exceptions:
raise ExceptionGroup("invalid metadata", exceptions)
return ins
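# The "field age" comparison in from_raw() can be sketched on its own: a
# field is rejected when the version that introduced it sits later in the
# ordered version list than the declared Metadata-Version. `VERSIONS`,
# `FIELD_ADDED` and `fields_too_new` are hypothetical names; the `added`
# values mirror those declared on the Metadata class in this file.

```python
from typing import Dict, List

VERSIONS: List[str] = ["1.0", "1.1", "1.2", "2.1", "2.2", "2.3"]

FIELD_ADDED: Dict[str, str] = {
    "summary": "1.0",
    "classifiers": "1.1",
    "requires_python": "1.2",
    "dynamic": "2.2",
}


def fields_too_new(declared_version: str, fields: List[str]) -> List[str]:
    """Return the fields introduced after the declared metadata version."""
    declared_age = VERSIONS.index(declared_version)
    return [
        field
        for field in fields
        if VERSIONS.index(FIELD_ADDED[field]) > declared_age
    ]


rejected = fields_too_new(
    "1.1", ["summary", "classifiers", "dynamic", "requires_python"]
)
print(rejected)
```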
@classmethod
def from_email(
cls, data: Union[bytes, str], *, validate: bool = True
) -> "Metadata":
"""Parse metadata from email headers.
If *validate* is true, the metadata will be validated. All exceptions
related to validation will be gathered and raised as an :class:`ExceptionGroup`.
"""
exceptions: List[InvalidMetadata] = []
raw, unparsed = parse_email(data)
if validate:
for unparsed_key in unparsed:
if unparsed_key in _EMAIL_TO_RAW_MAPPING:
message = f"{unparsed_key!r} has invalid data"
else:
message = f"unrecognized field: {unparsed_key!r}"
exceptions.append(InvalidMetadata(unparsed_key, message))
if exceptions:
raise ExceptionGroup("unparsed", exceptions)
try:
return cls.from_raw(raw, validate=validate)
except ExceptionGroup as exc_group:
exceptions.extend(exc_group.exceptions)
raise ExceptionGroup("invalid or unparsed metadata", exceptions) from None
metadata_version: _Validator[_MetadataVersion] = _Validator()
""":external:ref:`core-metadata-metadata-version`
(required; validated to be a valid metadata version)"""
name: _Validator[str] = _Validator()
""":external:ref:`core-metadata-name`
(required; validated using :func:`~packaging.utils.canonicalize_name` and its
*validate* parameter)"""
version: _Validator[version_module.Version] = _Validator()
""":external:ref:`core-metadata-version` (required)"""
dynamic: _Validator[List[str]] = _Validator(
added="2.2",
)
""":external:ref:`core-metadata-dynamic`
(validated against core metadata field names and lowercased)"""
platforms: _Validator[List[str]] = _Validator()
""":external:ref:`core-metadata-platform`"""
supported_platforms: _Validator[List[str]] = _Validator(added="1.1")
""":external:ref:`core-metadata-supported-platform`"""
summary: _Validator[str] = _Validator()
""":external:ref:`core-metadata-summary` (validated to contain no newlines)"""
description: _Validator[str] = _Validator() # TODO 2.1: can be in body
""":external:ref:`core-metadata-description`"""
description_content_type: _Validator[str] = _Validator(added="2.1")
""":external:ref:`core-metadata-description-content-type` (validated)"""
keywords: _Validator[List[str]] = _Validator()
""":external:ref:`core-metadata-keywords`"""
home_page: _Validator[str] = _Validator()
""":external:ref:`core-metadata-home-page`"""
download_url: _Validator[str] = _Validator(added="1.1")
""":external:ref:`core-metadata-download-url`"""
author: _Validator[str] = _Validator()
""":external:ref:`core-metadata-author`"""
author_email: _Validator[str] = _Validator()
""":external:ref:`core-metadata-author-email`"""
maintainer: _Validator[str] = _Validator(added="1.2")
""":external:ref:`core-metadata-maintainer`"""
maintainer_email: _Validator[str] = _Validator(added="1.2")
""":external:ref:`core-metadata-maintainer-email`"""
license: _Validator[str] = _Validator()
""":external:ref:`core-metadata-license`"""
classifiers: _Validator[List[str]] = _Validator(added="1.1")
""":external:ref:`core-metadata-classifier`"""
requires_dist: _Validator[List[requirements.Requirement]] = _Validator(added="1.2")
""":external:ref:`core-metadata-requires-dist`"""
requires_python: _Validator[specifiers.SpecifierSet] = _Validator(added="1.2")
""":external:ref:`core-metadata-requires-python`"""
# Because `Requires-External` allows for non-PEP 440 version specifiers, we
# don't do any processing on the values.
requires_external: _Validator[List[str]] = _Validator(added="1.2")
""":external:ref:`core-metadata-requires-external`"""
project_urls: _Validator[Dict[str, str]] = _Validator(added="1.2")
""":external:ref:`core-metadata-project-url`"""
# PEP 685 lets us raise an error if an extra doesn't pass `Name` validation
# regardless of metadata version.
provides_extra: _Validator[List[utils.NormalizedName]] = _Validator(
added="2.1",
)
""":external:ref:`core-metadata-provides-extra`"""
provides_dist: _Validator[List[str]] = _Validator(added="1.2")
""":external:ref:`core-metadata-provides-dist`"""
obsoletes_dist: _Validator[List[str]] = _Validator(added="1.2")
""":external:ref:`core-metadata-obsoletes-dist`"""
requires: _Validator[List[str]] = _Validator(added="1.1")
"""``Requires`` (deprecated)"""
provides: _Validator[List[str]] = _Validator(added="1.1")
"""``Provides`` (deprecated)"""
obsoletes: _Validator[List[str]] = _Validator(added="1.1")
"""``Obsoletes`` (deprecated)"""


@ -0,0 +1,90 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
from typing import Any, Iterator, Optional, Set
from ._parser import parse_requirement as _parse_requirement
from ._tokenizer import ParserSyntaxError
from .markers import Marker, _normalize_extra_values
from .specifiers import SpecifierSet
from .utils import canonicalize_name
class InvalidRequirement(ValueError):
"""
An invalid requirement was found; users should refer to PEP 508.
"""
class Requirement:
"""Parse a requirement.
Parse a given requirement string into its parts, such as name, specifier,
URL, and extras. Raises InvalidRequirement on a badly-formed requirement
string.
"""
# TODO: Can we test whether something is contained within a requirement?
# If so how do we do that? Do we need to test against the _name_ of
# the thing as well as the version? What about the markers?
# TODO: Can we normalize the name and extra name?
def __init__(self, requirement_string: str) -> None:
try:
parsed = _parse_requirement(requirement_string)
except ParserSyntaxError as e:
raise InvalidRequirement(str(e)) from e
self.name: str = parsed.name
self.url: Optional[str] = parsed.url or None
self.extras: Set[str] = set(parsed.extras if parsed.extras else [])
self.specifier: SpecifierSet = SpecifierSet(parsed.specifier)
self.marker: Optional[Marker] = None
if parsed.marker is not None:
self.marker = Marker.__new__(Marker)
self.marker._markers = _normalize_extra_values(parsed.marker)
def _iter_parts(self, name: str) -> Iterator[str]:
yield name
if self.extras:
formatted_extras = ",".join(sorted(self.extras))
yield f"[{formatted_extras}]"
if self.specifier:
yield str(self.specifier)
if self.url:
yield f"@ {self.url}"
if self.marker:
yield " "
if self.marker:
yield f"; {self.marker}"
def __str__(self) -> str:
return "".join(self._iter_parts(self.name))
def __repr__(self) -> str:
return f"<Requirement('{self}')>"
def __hash__(self) -> int:
return hash(
(
self.__class__.__name__,
*self._iter_parts(canonicalize_name(self.name)),
)
)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, Requirement):
return NotImplemented
return (
canonicalize_name(self.name) == canonicalize_name(other.name)
and self.extras == other.extras
and self.specifier == other.specifier
and self.url == other.url
and self.marker == other.marker
)
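# Requirement.__eq__ and __hash__ compare canonicalized names, so two
# spellings of the same project compare equal. canonicalize_name() is not
# shown in this section; the sketch below assumes it applies the PEP 503
# normalization (runs of '-', '_' and '.' collapse to '-', then lowercase).
# `normalize` and `requirement_key` are hypothetical helpers.

```python
import re
from typing import FrozenSet, Iterable, Tuple


def normalize(name: str) -> str:
    """PEP 503 style normalization of a project name."""
    return re.sub(r"[-_.]+", "-", name).lower()


def requirement_key(
    name: str, extras: Iterable[str], specifier: str
) -> Tuple[str, FrozenSet[str], str]:
    """Build a hashable key that ignores the spelling of the name."""
    return (normalize(name), frozenset(extras), specifier)


key_a = requirement_key("Typing_Extensions", {"test"}, ">=4.0")
key_b = requirement_key("typing-extensions", {"test"}, ">=4.0")
same = key_a == key_b and hash(key_a) == hash(key_b)
```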

File diff suppressed because it is too large


@ -0,0 +1,553 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
import logging
import platform
import struct
import subprocess
import sys
import sysconfig
from importlib.machinery import EXTENSION_SUFFIXES
from typing import (
Dict,
FrozenSet,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from . import _manylinux, _musllinux
logger = logging.getLogger(__name__)
PythonVersion = Sequence[int]
MacVersion = Tuple[int, int]
INTERPRETER_SHORT_NAMES: Dict[str, str] = {
"python": "py", # Generic.
"cpython": "cp",
"pypy": "pp",
"ironpython": "ip",
"jython": "jy",
}
_32_BIT_INTERPRETER = struct.calcsize("P") == 4
class Tag:
"""
A representation of the tag triple for a wheel.
Instances are considered immutable and thus are hashable. Equality checking
is also supported.
"""
__slots__ = ["_interpreter", "_abi", "_platform", "_hash"]
def __init__(self, interpreter: str, abi: str, platform: str) -> None:
self._interpreter = interpreter.lower()
self._abi = abi.lower()
self._platform = platform.lower()
# The __hash__ of every single element in a Set[Tag] will be evaluated each time
# that a set calls its `.isdisjoint()` method, which may be called hundreds of
# times when scanning a page of links for packages with tags matching that
# Set[Tag]. Pre-computing the value here produces significant speedups for
# downstream consumers.
self._hash = hash((self._interpreter, self._abi, self._platform))
@property
def interpreter(self) -> str:
return self._interpreter
@property
def abi(self) -> str:
return self._abi
@property
def platform(self) -> str:
return self._platform
def __eq__(self, other: object) -> bool:
if not isinstance(other, Tag):
return NotImplemented
return (
(self._hash == other._hash) # Short-circuit ASAP for perf reasons.
and (self._platform == other._platform)
and (self._abi == other._abi)
and (self._interpreter == other._interpreter)
)
def __hash__(self) -> int:
return self._hash
def __str__(self) -> str:
return f"{self._interpreter}-{self._abi}-{self._platform}"
def __repr__(self) -> str:
return f"<{self} @ {id(self)}>"
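# The precomputed-hash pattern used by Tag can be shown with a smaller
# class. `Point` is a hypothetical stand-in: the hash is computed once in
# __init__, __hash__ returns the stored value, and __eq__ compares hashes
# first as a cheap early exit before comparing the fields themselves.

```python
class Point:
    __slots__ = ("_x", "_y", "_hash")

    def __init__(self, x: int, y: int) -> None:
        self._x = x
        self._y = y
        # Computed once; set operations call __hash__ many times.
        self._hash = hash((self._x, self._y))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Point):
            return NotImplemented
        return (
            self._hash == other._hash  # Cheap check first.
            and self._x == other._x
            and self._y == other._y
        )

    def __hash__(self) -> int:
        return self._hash


supported = {Point(1, 2), Point(1, 2), Point(3, 4)}
candidate = {Point(5, 6)}
no_overlap = supported.isdisjoint(candidate)
```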
def parse_tag(tag: str) -> FrozenSet[Tag]:
"""
Parses the provided tag (e.g. `py3-none-any`) into a frozenset of Tag instances.
Returning a set is required due to the possibility that the tag is a
compressed tag set.
"""
tags = set()
interpreters, abis, platforms = tag.split("-")
for interpreter in interpreters.split("."):
for abi in abis.split("."):
for platform_ in platforms.split("."):
tags.add(Tag(interpreter, abi, platform_))
return frozenset(tags)
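# The expansion performed by parse_tag() is a Cartesian product of the
# dot-separated parts of each of the three components. A sketch with plain
# tuples instead of Tag instances (`expand_compressed_tag` is a
# hypothetical name) could look like this:

```python
from itertools import product
from typing import FrozenSet, Tuple


def expand_compressed_tag(tag: str) -> FrozenSet[Tuple[str, str, str]]:
    """Expand e.g. 'py2.py3-none-any' into one triple per combination."""
    interpreters, abis, platforms = tag.split("-")
    return frozenset(
        product(interpreters.split("."), abis.split("."), platforms.split("."))
    )


universal = expand_compressed_tag("py2.py3-none-any")
multi = expand_compressed_tag(
    "cp39.cp310-abi3-manylinux1_x86_64.manylinux2014_x86_64"
)
```

# As in parse_tag(), a tag that does not contain exactly two '-' separators
# makes the three-way unpacking raise ValueError.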
def _get_config_var(name: str, warn: bool = False) -> Union[int, str, None]:
value: Union[int, str, None] = sysconfig.get_config_var(name)
if value is None and warn:
logger.debug(
"Config variable '%s' is unset, Python ABI tag may be incorrect", name
)
return value
def _normalize_string(string: str) -> str:
return string.replace(".", "_").replace("-", "_").replace(" ", "_")
def _abi3_applies(python_version: PythonVersion) -> bool:
"""
Determine if the Python version supports abi3.
PEP 384 was first implemented in Python 3.2.
"""
return len(python_version) > 1 and tuple(python_version) >= (3, 2)
def _cpython_abis(py_version: PythonVersion, warn: bool = False) -> List[str]:
py_version = tuple(py_version) # To allow for version comparison.
abis = []
version = _version_nodot(py_version[:2])
debug = pymalloc = ucs4 = ""
with_debug = _get_config_var("Py_DEBUG", warn)
has_refcount = hasattr(sys, "gettotalrefcount")
# Windows doesn't set Py_DEBUG, so checking for support of debug-compiled
# extension modules is the best option.
# https://github.com/pypa/pip/issues/3383#issuecomment-173267692
has_ext = "_d.pyd" in EXTENSION_SUFFIXES
if with_debug or (with_debug is None and (has_refcount or has_ext)):
debug = "d"
if py_version < (3, 8):
with_pymalloc = _get_config_var("WITH_PYMALLOC", warn)
if with_pymalloc or with_pymalloc is None:
pymalloc = "m"
if py_version < (3, 3):
unicode_size = _get_config_var("Py_UNICODE_SIZE", warn)
if unicode_size == 4 or (
unicode_size is None and sys.maxunicode == 0x10FFFF
):
ucs4 = "u"
elif debug:
# Debug builds can also load "normal" extension modules.
# We can also assume no UCS-4 or pymalloc requirement.
abis.append(f"cp{version}")
abis.insert(
0,
"cp{version}{debug}{pymalloc}{ucs4}".format(
version=version, debug=debug, pymalloc=pymalloc, ucs4=ucs4
),
)
return abis
def cpython_tags(
python_version: Optional[PythonVersion] = None,
abis: Optional[Iterable[str]] = None,
platforms: Optional[Iterable[str]] = None,
*,
warn: bool = False,
) -> Iterator[Tag]:
"""
Yields the tags for a CPython interpreter.
The tags consist of:
- cp<python_version>-<abi>-<platform>
- cp<python_version>-abi3-<platform>
- cp<python_version>-none-<platform>
- cp<less than python_version>-abi3-<platform> # Older Python versions down to 3.2.
If python_version only specifies a major version then user-provided ABIs and
the 'none' ABI tag will be used.
If 'abi3' or 'none' are specified in 'abis' then they will be yielded at
their normal position and not at the beginning.
"""
if not python_version:
python_version = sys.version_info[:2]
interpreter = f"cp{_version_nodot(python_version[:2])}"
if abis is None:
if len(python_version) > 1:
abis = _cpython_abis(python_version, warn)
else:
abis = []
abis = list(abis)
# 'abi3' and 'none' are explicitly handled later.
for explicit_abi in ("abi3", "none"):
try:
abis.remove(explicit_abi)
except ValueError:
pass
platforms = list(platforms or platform_tags())
for abi in abis:
for platform_ in platforms:
yield Tag(interpreter, abi, platform_)
if _abi3_applies(python_version):
yield from (Tag(interpreter, "abi3", platform_) for platform_ in platforms)
yield from (Tag(interpreter, "none", platform_) for platform_ in platforms)
if _abi3_applies(python_version):
for minor_version in range(python_version[1] - 1, 1, -1):
for platform_ in platforms:
interpreter = "cp{version}".format(
version=_version_nodot((python_version[0], minor_version))
)
yield Tag(interpreter, "abi3", platform_)
def _generic_abi() -> List[str]:
    """
    Return the ABI tag based on EXT_SUFFIX.
    """
    # The following are examples of `EXT_SUFFIX`.
    # We want to keep the parts which are related to the ABI and remove the
    # parts which are related to the platform:
    # - linux:   '.cpython-310-x86_64-linux-gnu.so' => cp310
    # - mac:     '.cpython-310-darwin.so'           => cp310
    # - win:     '.cp310-win_amd64.pyd'             => cp310
    # - win:     '.pyd'                             => cp37 (uses _cpython_abis())
    # - pypy:    '.pypy38-pp73-x86_64-linux-gnu.so' => pypy38_pp73
    # - graalpy: '.graalpy-38-native-x86_64-darwin.dylib'
    #                                               => graalpy_38_native
    ext_suffix = _get_config_var("EXT_SUFFIX", warn=True)
    if not isinstance(ext_suffix, str) or ext_suffix[0] != ".":
        raise SystemError("invalid sysconfig.get_config_var('EXT_SUFFIX')")
    parts = ext_suffix.split(".")
    if len(parts) < 3:
        # CPython3.7 and earlier uses ".pyd" on Windows.
        return _cpython_abis(sys.version_info[:2])
    soabi = parts[1]
    if soabi.startswith("cpython"):
        # non-windows
        abi = "cp" + soabi.split("-")[1]
    elif soabi.startswith("cp"):
        # windows
        abi = soabi.split("-")[0]
    elif soabi.startswith("pypy"):
        abi = "-".join(soabi.split("-")[:2])
    elif soabi.startswith("graalpy"):
        abi = "-".join(soabi.split("-")[:3])
    elif soabi:
        # pyston, ironpython, others?
        abi = soabi
    else:
        return []
    return [_normalize_string(abi)]
def generic_tags(
    interpreter: Optional[str] = None,
    abis: Optional[Iterable[str]] = None,
    platforms: Optional[Iterable[str]] = None,
    *,
    warn: bool = False,
) -> Iterator[Tag]:
    """
    Yields the tags for a generic interpreter.

    The tags consist of:
    - <interpreter>-<abi>-<platform>

    The "none" ABI will be added if it was not explicitly provided.
    """
    if not interpreter:
        interp_name = interpreter_name()
        interp_version = interpreter_version(warn=warn)
        interpreter = "".join([interp_name, interp_version])
    if abis is None:
        abis = _generic_abi()
    else:
        abis = list(abis)
    platforms = list(platforms or platform_tags())
    if "none" not in abis:
        abis.append("none")
    for abi in abis:
        for platform_ in platforms:
            yield Tag(interpreter, abi, platform_)
def _py_interpreter_range(py_version: PythonVersion) -> Iterator[str]:
    """
    Yields Python versions in descending order.

    After the latest version, the major-only version will be yielded, and then
    all previous versions of that major version.
    """
    if len(py_version) > 1:
        yield f"py{_version_nodot(py_version[:2])}"
    yield f"py{py_version[0]}"
    if len(py_version) > 1:
        for minor in range(py_version[1] - 1, -1, -1):
            yield f"py{_version_nodot((py_version[0], minor))}"
def compatible_tags(
    python_version: Optional[PythonVersion] = None,
    interpreter: Optional[str] = None,
    platforms: Optional[Iterable[str]] = None,
) -> Iterator[Tag]:
    """
    Yields the sequence of tags that are compatible with a specific version of Python.

    The tags consist of:
    - py*-none-<platform>
    - <interpreter>-none-any  # ... if `interpreter` is provided.
    - py*-none-any
    """
    if not python_version:
        python_version = sys.version_info[:2]
    platforms = list(platforms or platform_tags())
    for version in _py_interpreter_range(python_version):
        for platform_ in platforms:
            yield Tag(version, "none", platform_)
    if interpreter:
        yield Tag(interpreter, "none", "any")
    for version in _py_interpreter_range(python_version):
        yield Tag(version, "none", "any")
def _mac_arch(arch: str, is_32bit: bool = _32_BIT_INTERPRETER) -> str:
    if not is_32bit:
        return arch
    if arch.startswith("ppc"):
        return "ppc"
    return "i386"
def _mac_binary_formats(version: MacVersion, cpu_arch: str) -> List[str]:
    formats = [cpu_arch]
    if cpu_arch == "x86_64":
        if version < (10, 4):
            return []
        formats.extend(["intel", "fat64", "fat32"])
    elif cpu_arch == "i386":
        if version < (10, 4):
            return []
        formats.extend(["intel", "fat32", "fat"])
    elif cpu_arch == "ppc64":
        # TODO: Need to care about 32-bit PPC for ppc64 through 10.2?
        if version > (10, 5) or version < (10, 4):
            return []
        formats.append("fat64")
    elif cpu_arch == "ppc":
        if version > (10, 6):
            return []
        formats.extend(["fat32", "fat"])
    if cpu_arch in {"arm64", "x86_64"}:
        formats.append("universal2")
    if cpu_arch in {"x86_64", "i386", "ppc64", "ppc", "intel"}:
        formats.append("universal")
    return formats
def mac_platforms(
    version: Optional[MacVersion] = None, arch: Optional[str] = None
) -> Iterator[str]:
    """
    Yields the platform tags for a macOS system.

    The `version` parameter is a two-item tuple specifying the macOS version to
    generate platform tags for. The `arch` parameter is the CPU architecture to
    generate platform tags for. Both parameters default to the appropriate value
    for the current system.
    """
    version_str, _, cpu_arch = platform.mac_ver()
    if version is None:
        version = cast("MacVersion", tuple(map(int, version_str.split(".")[:2])))
        if version == (10, 16):
            # When built against an older macOS SDK, Python will report macOS 10.16
            # instead of the real version.
            version_str = subprocess.run(
                [
                    sys.executable,
                    "-sS",
                    "-c",
                    "import platform; print(platform.mac_ver()[0])",
                ],
                check=True,
                env={"SYSTEM_VERSION_COMPAT": "0"},
                stdout=subprocess.PIPE,
                text=True,
            ).stdout
            version = cast("MacVersion", tuple(map(int, version_str.split(".")[:2])))
    else:
        version = version
    if arch is None:
        arch = _mac_arch(cpu_arch)
    else:
        arch = arch
    if (10, 0) <= version and version < (11, 0):
        # Prior to Mac OS 11, each yearly release of Mac OS bumped the
        # "minor" version number. The major version was always 10.
        for minor_version in range(version[1], -1, -1):
            compat_version = 10, minor_version
            binary_formats = _mac_binary_formats(compat_version, arch)
            for binary_format in binary_formats:
                yield "macosx_{major}_{minor}_{binary_format}".format(
                    major=10, minor=minor_version, binary_format=binary_format
                )
    if version >= (11, 0):
        # Starting with Mac OS 11, each yearly release bumps the major version
        # number. The minor versions are now the midyear updates.
        for major_version in range(version[0], 10, -1):
            compat_version = major_version, 0
            binary_formats = _mac_binary_formats(compat_version, arch)
            for binary_format in binary_formats:
                yield "macosx_{major}_{minor}_{binary_format}".format(
                    major=major_version, minor=0, binary_format=binary_format
                )
    if version >= (11, 0):
        # Mac OS 11 on x86_64 is compatible with binaries from previous releases.
        # Arm64 support was introduced in 11.0, so no Arm binaries from previous
        # releases exist.
        #
        # However, the "universal2" binary format can have a
        # macOS version earlier than 11.0 when the x86_64 part of the binary supports
        # that version of macOS.
        if arch == "x86_64":
            for minor_version in range(16, 3, -1):
                compat_version = 10, minor_version
                binary_formats = _mac_binary_formats(compat_version, arch)
                for binary_format in binary_formats:
                    yield "macosx_{major}_{minor}_{binary_format}".format(
                        major=compat_version[0],
                        minor=compat_version[1],
                        binary_format=binary_format,
                    )
        else:
            for minor_version in range(16, 3, -1):
                compat_version = 10, minor_version
                binary_format = "universal2"
                yield "macosx_{major}_{minor}_{binary_format}".format(
                    major=compat_version[0],
                    minor=compat_version[1],
                    binary_format=binary_format,
                )
def _linux_platforms(is_32bit: bool = _32_BIT_INTERPRETER) -> Iterator[str]:
    linux = _normalize_string(sysconfig.get_platform())
    if not linux.startswith("linux_"):
        # we should never be here, just yield the sysconfig one and return
        yield linux
        return
    if is_32bit:
        if linux == "linux_x86_64":
            linux = "linux_i686"
        elif linux == "linux_aarch64":
            linux = "linux_armv8l"
    _, arch = linux.split("_", 1)
    archs = {"armv8l": ["armv8l", "armv7l"]}.get(arch, [arch])
    yield from _manylinux.platform_tags(archs)
    yield from _musllinux.platform_tags(archs)
    for arch in archs:
        yield f"linux_{arch}"
def _generic_platforms() -> Iterator[str]:
    yield _normalize_string(sysconfig.get_platform())
def platform_tags() -> Iterator[str]:
    """
    Provides the platform tags for this installation.
    """
    if platform.system() == "Darwin":
        return mac_platforms()
    elif platform.system() == "Linux":
        return _linux_platforms()
    else:
        return _generic_platforms()
def interpreter_name() -> str:
    """
    Returns the name of the running interpreter.

    Some implementations have a reserved, two-letter abbreviation which will
    be returned when appropriate.
    """
    name = sys.implementation.name
    return INTERPRETER_SHORT_NAMES.get(name) or name
def interpreter_version(*, warn: bool = False) -> str:
    """
    Returns the version of the running interpreter.
    """
    version = _get_config_var("py_version_nodot", warn=warn)
    if version:
        version = str(version)
    else:
        version = _version_nodot(sys.version_info[:2])
    return version
def _version_nodot(version: PythonVersion) -> str:
    return "".join(map(str, version))
def sys_tags(*, warn: bool = False) -> Iterator[Tag]:
    """
    Returns the sequence of tag triples for the running interpreter.

    The order of the sequence corresponds to priority order for the
    interpreter, from most to least important.
    """
    interp_name = interpreter_name()
    if interp_name == "cp":
        yield from cpython_tags(warn=warn)
    else:
        yield from generic_tags()
    if interp_name == "pp":
        interp = "pp3"
    elif interp_name == "cp":
        interp = "cp" + interpreter_version(warn=warn)
    else:
        interp = None
    yield from compatible_tags(interpreter=interp)
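
The ordering described in the `_py_interpreter_range` docstring can be seen in a small standalone sketch. This is not part of the vendored file: `py_interpreter_range` is a hypothetical stdlib-only copy of the same logic, written here only to show which `py*` interpreter strings are produced and in what order.

```python
from typing import Iterator, Tuple


def py_interpreter_range(py_version: Tuple[int, ...]) -> Iterator[str]:
    """Hypothetical standalone copy of the logic in _py_interpreter_range."""
    # Exact major.minor first (only when a minor version is given).
    if len(py_version) > 1:
        yield f"py{py_version[0]}{py_version[1]}"
    # Then the major-only tag.
    yield f"py{py_version[0]}"
    # Then every earlier minor version of the same major, newest first.
    if len(py_version) > 1:
        for minor in range(py_version[1] - 1, -1, -1):
            yield f"py{py_version[0]}{minor}"


print(list(py_interpreter_range((3, 3))))
# → ['py33', 'py3', 'py32', 'py31', 'py30']
print(list(py_interpreter_range((3,))))
# → ['py3']
```

`compatible_tags` walks this sequence twice: once paired with every platform tag, and once paired with the `any` platform.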


@ -0,0 +1,172 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.
import re
from typing import FrozenSet, NewType, Tuple, Union, cast
from .tags import Tag, parse_tag
from .version import InvalidVersion, Version
BuildTag = Union[Tuple[()], Tuple[int, str]]
NormalizedName = NewType("NormalizedName", str)
class InvalidName(ValueError):
    """
    An invalid distribution name; users should refer to the packaging user guide.
    """
class InvalidWheelFilename(ValueError):
    """
    An invalid wheel filename was found, users should refer to PEP 427.
    """
class InvalidSdistFilename(ValueError):
    """
    An invalid sdist filename was found, users should refer to the packaging user guide.
    """
# Core metadata spec for `Name`
_validate_regex = re.compile(
    r"^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$", re.IGNORECASE
)
_canonicalize_regex = re.compile(r"[-_.]+")
_normalized_regex = re.compile(r"^([a-z0-9]|[a-z0-9]([a-z0-9-](?!--))*[a-z0-9])$")
# PEP 427: The build number must start with a digit.
_build_tag_regex = re.compile(r"(\d+)(.*)")
def canonicalize_name(name: str, *, validate: bool = False) -> NormalizedName:
    if validate and not _validate_regex.match(name):
        raise InvalidName(f"name is invalid: {name!r}")
    # This is taken from PEP 503.
    value = _canonicalize_regex.sub("-", name).lower()
    return cast(NormalizedName, value)
def is_normalized_name(name: str) -> bool:
    return _normalized_regex.match(name) is not None
def canonicalize_version(
    version: Union[Version, str], *, strip_trailing_zero: bool = True
) -> str:
    """
    This is very similar to Version.__str__, but has one subtle difference
    with the way it handles the release segment.
    """
    if isinstance(version, str):
        try:
            parsed = Version(version)
        except InvalidVersion:
            # Legacy versions cannot be normalized
            return version
    else:
        parsed = version
    parts = []
    # Epoch
    if parsed.epoch != 0:
        parts.append(f"{parsed.epoch}!")
    # Release segment
    release_segment = ".".join(str(x) for x in parsed.release)
    if strip_trailing_zero:
        # NB: This strips trailing '.0's to normalize
        release_segment = re.sub(r"(\.0)+$", "", release_segment)
    parts.append(release_segment)
    # Pre-release
    if parsed.pre is not None:
        parts.append("".join(str(x) for x in parsed.pre))
    # Post-release
    if parsed.post is not None:
        parts.append(f".post{parsed.post}")
    # Development release
    if parsed.dev is not None:
        parts.append(f".dev{parsed.dev}")
    # Local version segment
    if parsed.local is not None:
        parts.append(f"+{parsed.local}")
    return "".join(parts)
def parse_wheel_filename(
    filename: str,
) -> Tuple[NormalizedName, Version, BuildTag, FrozenSet[Tag]]:
    if not filename.endswith(".whl"):
        raise InvalidWheelFilename(
            f"Invalid wheel filename (extension must be '.whl'): {filename}"
        )
    filename = filename[:-4]
    dashes = filename.count("-")
    if dashes not in (4, 5):
        raise InvalidWheelFilename(
            f"Invalid wheel filename (wrong number of parts): {filename}"
        )
    parts = filename.split("-", dashes - 2)
    name_part = parts[0]
    # See PEP 427 for the rules on escaping the project name.
    if "__" in name_part or re.match(r"^[\w\d._]*$", name_part, re.UNICODE) is None:
        raise InvalidWheelFilename(f"Invalid project name: {filename}")
    name = canonicalize_name(name_part)
    try:
        version = Version(parts[1])
    except InvalidVersion as e:
        raise InvalidWheelFilename(
            f"Invalid wheel filename (invalid version): {filename}"
        ) from e
    if dashes == 5:
        build_part = parts[2]
        build_match = _build_tag_regex.match(build_part)
        if build_match is None:
            raise InvalidWheelFilename(
                f"Invalid build number: {build_part} in '{filename}'"
            )
        build = cast(BuildTag, (int(build_match.group(1)), build_match.group(2)))
    else:
        build = ()
    tags = parse_tag(parts[-1])
    return (name, version, build, tags)
def parse_sdist_filename(filename: str) -> Tuple[NormalizedName, Version]:
    if filename.endswith(".tar.gz"):
        file_stem = filename[: -len(".tar.gz")]
    elif filename.endswith(".zip"):
        file_stem = filename[: -len(".zip")]
    else:
        raise InvalidSdistFilename(
            f"Invalid sdist filename (extension must be '.tar.gz' or '.zip'):"
            f" {filename}"
        )
    # We are requiring a PEP 440 version, which cannot contain dashes,
    # so we split on the last dash.
    name_part, sep, version_part = file_stem.rpartition("-")
    if not sep:
        raise InvalidSdistFilename(f"Invalid sdist filename: {filename}")
    name = canonicalize_name(name_part)
    try:
        version = Version(version_part)
    except InvalidVersion as e:
        raise InvalidSdistFilename(
            f"Invalid sdist filename (invalid version): {filename}"
        ) from e
    return (name, version)
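
The string handling in `canonicalize_name`, `canonicalize_version` and `parse_wheel_filename` can be tried without the rest of the module. The sketch below is not part of the vendored file: `canonicalize`, `strip_trailing_zeros` and `split_wheel_filename` are hypothetical stdlib-only helpers that reuse the same regular expressions and the same dash-counting split, but skip the `Version` and `Tag` parsing and raise a plain `ValueError`.

```python
import re
from typing import Tuple

_canonicalize_regex = re.compile(r"[-_.]+")


def canonicalize(name: str) -> str:
    """PEP 503 style normalization: collapse runs of -, _ and . into one dash."""
    return _canonicalize_regex.sub("-", name).lower()


def strip_trailing_zeros(release: str) -> str:
    """Drop trailing '.0' components from a release segment."""
    return re.sub(r"(\.0)+$", "", release)


def split_wheel_filename(filename: str) -> Tuple[str, str, str, str]:
    """Split a wheel filename into (name, version, build, tag) strings."""
    if not filename.endswith(".whl"):
        raise ValueError(f"extension must be '.whl': {filename}")
    stem = filename[:-4]
    dashes = stem.count("-")
    if dashes not in (4, 5):
        raise ValueError(f"wrong number of parts: {filename}")
    # Limit the split so the three-part tag stays together as the last item.
    parts = stem.split("-", dashes - 2)
    build = parts[2] if dashes == 5 else ""
    return canonicalize(parts[0]), parts[1], build, parts[-1]


print(canonicalize("Foo_Bar.baz"))
# → foo-bar-baz
print(strip_trailing_zeros("1.10.0"))
# → 1.10
print(split_wheel_filename("Foo_Bar-1.0-1abc-py3-none-any.whl"))
# → ('foo-bar', '1.0', '1abc', 'py3-none-any')
```

Splitting with a maximum of `dashes - 2` is what keeps the interpreter, ABI and platform parts in one string, so the optional build tag is the only thing that changes the number of items.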

Some files were not shown because too many files have changed in this diff