Update subliminal to 2.0.5

Also updates:
- appdirs-1.4.3
- babelfish-0.5.5
- beautifulsoup4-4.6.3
- certifi-2018.11.29
- chardet-3.0.4
- click-7.0
- decorator-4.3.0
- dogpile.cache-0.7.1
- enzyme-0.4.1
- guessit-3.0.3
- idna-2.8
- pbr-5.1.1
- pysrt-1.1.1
- python-dateutil-2.7.5
- pytz-2018.7
- rarfile-3.0
- rebulk-1.0.0
- requests-2.21.0
- six-1.12.0
- stevedore-1.30.0
- urllib3-1.24.1
This commit is contained in:
Labrys of Knossos 2018-12-15 01:12:12 -05:00
commit f3fcb47427
761 changed files with 29015 additions and 1843 deletions

608
libs/appdirs.py Normal file
View file

@ -0,0 +1,608 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2005-2010 ActiveState Software Inc.
# Copyright (c) 2013 Eddy Petrișor
"""Utilities for determining application-specific dirs.
See <http://github.com/ActiveState/appdirs> for details and usage.
"""
# Dev Notes:
# - MSDN on where to store app data files:
# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
__version_info__ = (1, 4, 3)
__version__ = '.'.join(map(str, __version_info__))
import sys
import os
PY3 = sys.version_info[0] == 3
if PY3:
unicode = str
if sys.platform.startswith('java'):
import platform
os_name = platform.java_ver()[3][0]
if os_name.startswith('Windows'): # "Windows XP", "Windows 7", etc.
system = 'win32'
elif os_name.startswith('Mac'): # "Mac OS X", etc.
system = 'darwin'
else: # "Linux", "SunOS", "FreeBSD", etc.
# Setting this to "linux2" is not ideal, but only Windows or Mac
# are actually checked for and the rest of the module expects
# *sys.platform* style strings.
system = 'linux2'
else:
system = sys.platform
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
r"""Return full path to the user-specific data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"roaming" (boolean, default False) can be set True to use the Windows
roaming appdata directory. That means that for users on a Windows
network setup for roaming profiles, this user data will be
sync'd on login. See
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
for a discussion of issues.
Typical user data directories are:
Mac OS X: ~/Library/Application Support/<AppName>
Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
That means, by default "~/.local/share/<AppName>".
"""
if system == "win32":
if appauthor is None:
appauthor = appname
const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
path = os.path.normpath(_get_win_folder(const))
if appname:
if appauthor is not False:
path = os.path.join(path, appauthor, appname)
else:
path = os.path.join(path, appname)
elif system == 'darwin':
path = os.path.expanduser('~/Library/Application Support/')
if appname:
path = os.path.join(path, appname)
else:
path = os.getenv('XDG_DATA_HOME', os.path.expanduser("~/.local/share"))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
r"""Return full path to the user-shared data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"multipath" is an optional parameter only applicable to *nix
which indicates that the entire list of data dirs should be
returned. By default, the first item from XDG_DATA_DIRS is
returned, or '/usr/local/share/<AppName>',
if XDG_DATA_DIRS is not set
Typical site data directories are:
Mac OS X: /Library/Application Support/<AppName>
Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
For Unix, this is using the $XDG_DATA_DIRS[0] default.
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
"""
if system == "win32":
if appauthor is None:
appauthor = appname
path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
if appname:
if appauthor is not False:
path = os.path.join(path, appauthor, appname)
else:
path = os.path.join(path, appname)
elif system == 'darwin':
path = os.path.expanduser('/Library/Application Support')
if appname:
path = os.path.join(path, appname)
else:
# XDG default for $XDG_DATA_DIRS
# only first, if multipath is False
path = os.getenv('XDG_DATA_DIRS',
os.pathsep.join(['/usr/local/share', '/usr/share']))
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
if appname:
if version:
appname = os.path.join(appname, version)
pathlist = [os.sep.join([x, appname]) for x in pathlist]
if multipath:
path = os.pathsep.join(pathlist)
else:
path = pathlist[0]
return path
if appname and version:
path = os.path.join(path, version)
return path
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
r"""Return full path to the user-specific config dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"roaming" (boolean, default False) can be set True to use the Windows
roaming appdata directory. That means that for users on a Windows
network setup for roaming profiles, this user data will be
sync'd on login. See
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
for a discussion of issues.
Typical user config directories are:
Mac OS X: same as user_data_dir
Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
Win *: same as user_data_dir
For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
That means, by default "~/.config/<AppName>".
"""
if system in ["win32", "darwin"]:
path = user_data_dir(appname, appauthor, None, roaming)
else:
path = os.getenv('XDG_CONFIG_HOME', os.path.expanduser("~/.config"))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
r"""Return full path to the user-shared data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"multipath" is an optional parameter only applicable to *nix
which indicates that the entire list of config dirs should be
returned. By default, the first item from XDG_CONFIG_DIRS is
returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
Typical site config directories are:
Mac OS X: same as site_data_dir
Unix: /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
$XDG_CONFIG_DIRS
Win *: same as site_data_dir
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
"""
if system in ["win32", "darwin"]:
path = site_data_dir(appname, appauthor)
if appname and version:
path = os.path.join(path, version)
else:
# XDG default for $XDG_CONFIG_DIRS
# only first, if multipath is False
path = os.getenv('XDG_CONFIG_DIRS', '/etc/xdg')
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
if appname:
if version:
appname = os.path.join(appname, version)
pathlist = [os.sep.join([x, appname]) for x in pathlist]
if multipath:
path = os.pathsep.join(pathlist)
else:
path = pathlist[0]
return path
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
r"""Return full path to the user-specific cache dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"opinion" (boolean) can be False to disable the appending of
"Cache" to the base app data dir for Windows. See
discussion below.
Typical user cache directories are:
Mac OS X: ~/Library/Caches/<AppName>
Unix: ~/.cache/<AppName> (XDG default)
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
On Windows the only suggestion in the MSDN docs is that local settings go in
the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
app data dir (the default returned by `user_data_dir` above). Apps typically
put cache data somewhere *under* the given dir here. Some examples:
...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
...\Acme\SuperApp\Cache\1.0
OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
This can be disabled with the `opinion=False` option.
"""
if system == "win32":
if appauthor is None:
appauthor = appname
path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
if appname:
if appauthor is not False:
path = os.path.join(path, appauthor, appname)
else:
path = os.path.join(path, appname)
if opinion:
path = os.path.join(path, "Cache")
elif system == 'darwin':
path = os.path.expanduser('~/Library/Caches')
if appname:
path = os.path.join(path, appname)
else:
path = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
r"""Return full path to the user-specific state dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"roaming" (boolean, default False) can be set True to use the Windows
roaming appdata directory. That means that for users on a Windows
network setup for roaming profiles, this user data will be
sync'd on login. See
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
for a discussion of issues.
Typical user state directories are:
Mac OS X: same as user_data_dir
Unix: ~/.local/state/<AppName> # or in $XDG_STATE_HOME, if defined
Win *: same as user_data_dir
For Unix, we follow this Debian proposal <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
to extend the XDG spec and support $XDG_STATE_HOME.
That means, by default "~/.local/state/<AppName>".
"""
if system in ["win32", "darwin"]:
path = user_data_dir(appname, appauthor, None, roaming)
else:
path = os.getenv('XDG_STATE_HOME', os.path.expanduser("~/.local/state"))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
r"""Return full path to the user-specific log dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"opinion" (boolean) can be False to disable the appending of
"Logs" to the base app data dir for Windows, and "log" to the
base cache dir for Unix. See discussion below.
Typical user log directories are:
Mac OS X: ~/Library/Logs/<AppName>
Unix: ~/.cache/<AppName>/log # or under $XDG_CACHE_HOME if defined
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs
On Windows the only suggestion in the MSDN docs is that local settings
go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
examples of what some windows apps use for a logs dir.)
OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
value for Windows and appends "log" to the user cache dir for Unix.
This can be disabled with the `opinion=False` option.
"""
if system == "darwin":
path = os.path.join(
os.path.expanduser('~/Library/Logs'),
appname)
elif system == "win32":
path = user_data_dir(appname, appauthor, version)
version = False
if opinion:
path = os.path.join(path, "Logs")
else:
path = user_cache_dir(appname, appauthor, version)
version = False
if opinion:
path = os.path.join(path, "log")
if appname and version:
path = os.path.join(path, version)
return path
class AppDirs(object):
"""Convenience wrapper for getting application dirs."""
def __init__(self, appname=None, appauthor=None, version=None,
roaming=False, multipath=False):
self.appname = appname
self.appauthor = appauthor
self.version = version
self.roaming = roaming
self.multipath = multipath
@property
def user_data_dir(self):
return user_data_dir(self.appname, self.appauthor,
version=self.version, roaming=self.roaming)
@property
def site_data_dir(self):
return site_data_dir(self.appname, self.appauthor,
version=self.version, multipath=self.multipath)
@property
def user_config_dir(self):
return user_config_dir(self.appname, self.appauthor,
version=self.version, roaming=self.roaming)
@property
def site_config_dir(self):
return site_config_dir(self.appname, self.appauthor,
version=self.version, multipath=self.multipath)
@property
def user_cache_dir(self):
return user_cache_dir(self.appname, self.appauthor,
version=self.version)
@property
def user_state_dir(self):
return user_state_dir(self.appname, self.appauthor,
version=self.version)
@property
def user_log_dir(self):
return user_log_dir(self.appname, self.appauthor,
version=self.version)
#---- internal support stuff
def _get_win_folder_from_registry(csidl_name):
"""This is a fallback technique at best. I'm not sure if using the
registry for this guarantees us the correct answer for all CSIDL_*
names.
"""
if PY3:
import winreg as _winreg
else:
import _winreg
shell_folder_name = {
"CSIDL_APPDATA": "AppData",
"CSIDL_COMMON_APPDATA": "Common AppData",
"CSIDL_LOCAL_APPDATA": "Local AppData",
}[csidl_name]
key = _winreg.OpenKey(
_winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
)
dir, type = _winreg.QueryValueEx(key, shell_folder_name)
return dir
def _get_win_folder_with_pywin32(csidl_name):
from win32com.shell import shellcon, shell
dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
# Try to make this a unicode path because SHGetFolderPath does
# not return unicode strings when there is unicode data in the
# path.
try:
dir = unicode(dir)
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in dir:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
try:
import win32api
dir = win32api.GetShortPathName(dir)
except ImportError:
pass
except UnicodeError:
pass
return dir
def _get_win_folder_with_ctypes(csidl_name):
import ctypes
csidl_const = {
"CSIDL_APPDATA": 26,
"CSIDL_COMMON_APPDATA": 35,
"CSIDL_LOCAL_APPDATA": 28,
}[csidl_name]
buf = ctypes.create_unicode_buffer(1024)
ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in buf:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
buf2 = ctypes.create_unicode_buffer(1024)
if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
buf = buf2
return buf.value
def _get_win_folder_with_jna(csidl_name):
import array
from com.sun import jna
from com.sun.jna.platform import win32
buf_size = win32.WinDef.MAX_PATH * 2
buf = array.zeros('c', buf_size)
shell = win32.Shell32.INSTANCE
shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in dir:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
buf = array.zeros('c', buf_size)
kernel = win32.Kernel32.INSTANCE
if kernel.GetShortPathName(dir, buf, buf_size):
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
return dir
if system == "win32":
try:
import win32com.shell
_get_win_folder = _get_win_folder_with_pywin32
except ImportError:
try:
from ctypes import windll
_get_win_folder = _get_win_folder_with_ctypes
except ImportError:
try:
import com.sun.jna
_get_win_folder = _get_win_folder_with_jna
except ImportError:
_get_win_folder = _get_win_folder_from_registry
#---- self test code
if __name__ == "__main__":
appname = "MyApp"
appauthor = "MyCompany"
props = ("user_data_dir",
"user_config_dir",
"user_cache_dir",
"user_state_dir",
"user_log_dir",
"site_data_dir",
"site_config_dir")
print("-- app dirs %s --" % __version__)
print("-- app dirs (with optional 'version')")
dirs = AppDirs(appname, appauthor, version="1.0")
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
print("\n-- app dirs (without optional 'version')")
dirs = AppDirs(appname, appauthor)
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
print("\n-- app dirs (without optional 'appauthor')")
dirs = AppDirs(appname)
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
print("\n-- app dirs (with disabled 'appauthor')")
dirs = AppDirs(appname, appauthor=False)
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))

Binary file not shown.

BIN
libs/bin/pbr.exe Normal file

Binary file not shown.

BIN
libs/bin/srt.exe Normal file

Binary file not shown.

BIN
libs/bin/subliminal.exe Normal file

Binary file not shown.

View file

@ -21,14 +21,15 @@ http://www.crummy.com/software/BeautifulSoup/bs4/doc/
# found in the LICENSE file.
__author__ = "Leonard Richardson (leonardr@segfault.org)"
__version__ = "4.5.1"
__copyright__ = "Copyright (c) 2004-2016 Leonard Richardson"
__version__ = "4.6.3"
__copyright__ = "Copyright (c) 2004-2018 Leonard Richardson"
__license__ = "MIT"
__all__ = ['BeautifulSoup']
import os
import re
import sys
import traceback
import warnings
@ -50,7 +51,7 @@ from .element import (
# The very first thing we do is give a useful error if someone is
# running this code under Python 3 without converting it.
'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'<>'You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).'
'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'!='You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).'
class BeautifulSoup(Tag):
"""
@ -74,7 +75,7 @@ class BeautifulSoup(Tag):
like HTML's <br> tag), call handle_starttag and then
handle_endtag.
"""
ROOT_TAG_NAME = u'[document]'
ROOT_TAG_NAME = '[document]'
# If the end-user gives no indication which tree builder they
# want, look for one with these features.
@ -82,14 +83,46 @@ class BeautifulSoup(Tag):
ASCII_SPACES = '\x20\x0a\x09\x0c\x0d'
NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nThe code that caused this warning is on line %(line_number)s of the file %(filename)s. To get rid of this warning, change code that looks like this:\n\n BeautifulSoup([your markup])\n\nto this:\n\n BeautifulSoup([your markup], \"%(parser)s\")\n"
NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nThe code that caused this warning is on line %(line_number)s of the file %(filename)s. To get rid of this warning, pass the additional argument 'features=\"%(parser)s\"' to the BeautifulSoup constructor.\n"
def __init__(self, markup="", features=None, builder=None,
parse_only=None, from_encoding=None, exclude_encodings=None,
**kwargs):
"""The Soup object is initialized as the 'root tag', and the
provided markup (which can be a string or a file-like object)
is fed into the underlying parser."""
"""Constructor.
:param markup: A string or a file-like object representing
markup to be parsed.
:param features: Desirable features of the parser to be used. This
may be the name of a specific parser ("lxml", "lxml-xml",
"html.parser", or "html5lib") or it may be the type of markup
to be used ("html", "html5", "xml"). It's recommended that you
name a specific parser, so that Beautiful Soup gives you the
same results across platforms and virtual environments.
:param builder: A specific TreeBuilder to use instead of looking one
up based on `features`. You shouldn't need to use this.
:param parse_only: A SoupStrainer. Only parts of the document
matching the SoupStrainer will be considered. This is useful
when parsing part of a document that would otherwise be too
large to fit into memory.
:param from_encoding: A string indicating the encoding of the
document to be parsed. Pass this in if Beautiful Soup is
guessing wrongly about the document's encoding.
:param exclude_encodings: A list of strings indicating
encodings known to be wrong. Pass this in if you don't know
the document's encoding but you know Beautiful Soup's guess is
wrong.
:param kwargs: For backwards compatibility purposes, the
constructor accepts certain keyword arguments used in
Beautiful Soup 3. None of these arguments do anything in
Beautiful Soup 4 and there's no need to actually pass keyword
arguments into the constructor.
"""
if 'convertEntities' in kwargs:
warnings.warn(
@ -142,18 +175,18 @@ class BeautifulSoup(Tag):
from_encoding = from_encoding or deprecated_argument(
"fromEncoding", "from_encoding")
if from_encoding and isinstance(markup, unicode):
if from_encoding and isinstance(markup, str):
warnings.warn("You provided Unicode markup but also provided a value for from_encoding. Your from_encoding will be ignored.")
from_encoding = None
if len(kwargs) > 0:
arg = kwargs.keys().pop()
arg = list(kwargs.keys()).pop()
raise TypeError(
"__init__() got an unexpected keyword argument '%s'" % arg)
if builder is None:
original_features = features
if isinstance(features, basestring):
if isinstance(features, str):
features = [features]
if features is None or len(features) == 0:
features = self.DEFAULT_BUILDER_FEATURES
@ -171,14 +204,35 @@ class BeautifulSoup(Tag):
else:
markup_type = "HTML"
caller = traceback.extract_stack()[0]
filename = caller[0]
line_number = caller[1]
warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % dict(
filename=filename,
line_number=line_number,
parser=builder.NAME,
markup_type=markup_type))
# This code adapted from warnings.py so that we get the same line
# of code as our warnings.warn() call gets, even if the answer is wrong
# (as it may be in a multithreading situation).
caller = None
try:
caller = sys._getframe(1)
except ValueError:
pass
if caller:
globals = caller.f_globals
line_number = caller.f_lineno
else:
globals = sys.__dict__
line_number= 1
filename = globals.get('__file__')
if filename:
fnl = filename.lower()
if fnl.endswith((".pyc", ".pyo")):
filename = filename[:-1]
if filename:
# If there is no filename at all, the user is most likely in a REPL,
# and the warning is not necessary.
values = dict(
filename=filename,
line_number=line_number,
parser=builder.NAME,
markup_type=markup_type
)
warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % values, stacklevel=2)
self.builder = builder
self.is_xml = builder.is_xml
@ -191,13 +245,13 @@ class BeautifulSoup(Tag):
markup = markup.read()
elif len(markup) <= 256 and (
(isinstance(markup, bytes) and not b'<' in markup)
or (isinstance(markup, unicode) and not u'<' in markup)
or (isinstance(markup, str) and not '<' in markup)
):
# Print out warnings for a couple beginner problems
# involving passing non-markup to Beautiful Soup.
# Beautiful Soup will still parse the input as markup,
# just in case that's what the user really wants.
if (isinstance(markup, unicode)
if (isinstance(markup, str)
and not os.path.supports_unicode_filenames):
possible_filename = markup.encode("utf8")
else:
@ -205,18 +259,18 @@ class BeautifulSoup(Tag):
is_file = False
try:
is_file = os.path.exists(possible_filename)
except Exception, e:
except Exception as e:
# This is almost certainly a problem involving
# characters not valid in filenames on this
# system. Just let it go.
pass
if is_file:
if isinstance(markup, unicode):
if isinstance(markup, str):
markup = markup.encode("utf8")
warnings.warn(
'"%s" looks like a filename, not markup. You should'
'probably open this file and pass the filehandle into'
'Beautiful Soup.' % markup)
' probably open this file and pass the filehandle into'
' Beautiful Soup.' % markup)
self._check_markup_is_url(markup)
for (self.markup, self.original_encoding, self.declared_html_encoding,
@ -263,9 +317,9 @@ class BeautifulSoup(Tag):
if isinstance(markup, bytes):
space = b' '
cant_start_with = (b"http:", b"https:")
elif isinstance(markup, unicode):
space = u' '
cant_start_with = (u"http:", u"https:")
elif isinstance(markup, str):
space = ' '
cant_start_with = ("http:", "https:")
else:
return
@ -302,9 +356,10 @@ class BeautifulSoup(Tag):
self.preserve_whitespace_tag_stack = []
self.pushTag(self)
def new_tag(self, name, namespace=None, nsprefix=None, **attrs):
def new_tag(self, name, namespace=None, nsprefix=None, attrs={}, **kwattrs):
"""Create a new tag associated with this soup."""
return Tag(None, self.builder, name, namespace, nsprefix, attrs)
kwattrs.update(attrs)
return Tag(None, self.builder, name, namespace, nsprefix, kwattrs)
def new_string(self, s, subclass=NavigableString):
"""Create a new NavigableString associated with this soup."""
@ -336,7 +391,7 @@ class BeautifulSoup(Tag):
def endData(self, containerClass=NavigableString):
if self.current_data:
current_data = u''.join(self.current_data)
current_data = ''.join(self.current_data)
# If whitespace is not preserved, and this string contains
# nothing but ASCII spaces, replace it with a single space
# or newline.
@ -490,9 +545,9 @@ class BeautifulSoup(Tag):
encoding_part = ''
if eventual_encoding != None:
encoding_part = ' encoding="%s"' % eventual_encoding
prefix = u'<?xml version="1.0"%s?>\n' % encoding_part
prefix = '<?xml version="1.0"%s?>\n' % encoding_part
else:
prefix = u''
prefix = ''
if not pretty_print:
indent_level = None
else:
@ -526,4 +581,4 @@ class FeatureNotFound(ValueError):
if __name__ == '__main__':
import sys
soup = BeautifulSoup(sys.stdin)
print soup.prettify()
print(soup.prettify())

View file

@ -160,13 +160,13 @@ class TreeBuilder(object):
universal = self.cdata_list_attributes.get('*', [])
tag_specific = self.cdata_list_attributes.get(
tag_name.lower(), None)
for attr in attrs.keys():
for attr in list(attrs.keys()):
if attr in universal or (tag_specific and attr in tag_specific):
# We have a "class"-type attribute whose string
# value is a whitespace-separated list of
# values. Split it into a list.
value = attrs[attr]
if isinstance(value, basestring):
if isinstance(value, str):
values = whitespace_re.split(value)
else:
# html5lib sometimes calls setAttributes twice
@ -232,8 +232,19 @@ class HTMLTreeBuilder(TreeBuilder):
"""
preserve_whitespace_tags = HTMLAwareEntitySubstitution.preserve_whitespace_tags
empty_element_tags = set(['br' , 'hr', 'input', 'img', 'meta',
'spacer', 'link', 'frame', 'base'])
empty_element_tags = set([
# These are from HTML5.
'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'keygen', 'link', 'menuitem', 'meta', 'param', 'source', 'track', 'wbr',
# These are from earlier versions of HTML and are removed in HTML5.
'basefont', 'bgsound', 'command', 'frame', 'image', 'isindex', 'nextid', 'spacer'
])
# The HTML standard defines these as block-level elements. Beautiful
# Soup does not treat these elements differently from other elements,
# but it may do so eventually, and this information is available if
# you need to use it.
block_elements = set(["address", "article", "aside", "blockquote", "canvas", "dd", "div", "dl", "dt", "fieldset", "figcaption", "figure", "footer", "form", "h1", "h2", "h3", "h4", "h5", "h6", "header", "hr", "li", "main", "nav", "noscript", "ol", "output", "p", "pre", "section", "table", "tfoot", "ul", "video"])
# The HTML standard defines these attributes as containing a
# space-separated list of values, not a single value. That is,

View file

@ -6,6 +6,7 @@ __all__ = [
]
import warnings
import re
from bs4.builder import (
PERMISSIVE,
HTML,
@ -17,7 +18,10 @@ from bs4.element import (
whitespace_re,
)
import html5lib
from html5lib.constants import namespaces
from html5lib.constants import (
namespaces,
prefixes,
)
from bs4.element import (
Comment,
Doctype,
@ -29,7 +33,7 @@ try:
# Pre-0.99999999
from html5lib.treebuilders import _base as treebuilder_base
new_html5lib = False
except ImportError, e:
except ImportError as e:
# 0.99999999 and up
from html5lib.treebuilders import base as treebuilder_base
new_html5lib = True
@ -60,7 +64,7 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
parser = html5lib.HTMLParser(tree=self.create_treebuilder)
extra_kwargs = dict()
if not isinstance(markup, unicode):
if not isinstance(markup, str):
if new_html5lib:
extra_kwargs['override_encoding'] = self.user_specified_encoding
else:
@ -68,13 +72,13 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
doc = parser.parse(markup, **extra_kwargs)
# Set the character encoding detected by the tokenizer.
if isinstance(markup, unicode):
if isinstance(markup, str):
# We need to special-case this because html5lib sets
# charEncoding to UTF-8 if it gets Unicode input.
doc.original_encoding = None
else:
original_encoding = parser.tokenizer.stream.charEncoding[0]
if not isinstance(original_encoding, basestring):
if not isinstance(original_encoding, str):
# In 0.99999999 and up, the encoding is an html5lib
# Encoding object. We want to use a string for compatibility
# with other tree builders.
@ -83,18 +87,22 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
def create_treebuilder(self, namespaceHTMLElements):
self.underlying_builder = TreeBuilderForHtml5lib(
self.soup, namespaceHTMLElements)
namespaceHTMLElements, self.soup)
return self.underlying_builder
def test_fragment_to_document(self, fragment):
"""See `TreeBuilder`."""
return u'<html><head></head><body>%s</body></html>' % fragment
return '<html><head></head><body>%s</body></html>' % fragment
class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):
def __init__(self, soup, namespaceHTMLElements):
self.soup = soup
def __init__(self, namespaceHTMLElements, soup=None):
if soup:
self.soup = soup
else:
from bs4 import BeautifulSoup
self.soup = BeautifulSoup("", "html.parser")
super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements)
def documentClass(self):
@ -117,7 +125,8 @@ class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):
return TextNode(Comment(data), self.soup)
def fragmentClass(self):
self.soup = BeautifulSoup("")
from bs4 import BeautifulSoup
self.soup = BeautifulSoup("", "html.parser")
self.soup.name = "[document_fragment]"
return Element(self.soup, self.soup, None)
@ -131,6 +140,56 @@ class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):
def getFragment(self):
return treebuilder_base.TreeBuilder.getFragment(self).element
def testSerializer(self, element):
from bs4 import BeautifulSoup
rv = []
doctype_re = re.compile(r'^(.*?)(?: PUBLIC "(.*?)"(?: "(.*?)")?| SYSTEM "(.*?)")?$')
def serializeElement(element, indent=0):
if isinstance(element, BeautifulSoup):
pass
if isinstance(element, Doctype):
m = doctype_re.match(element)
if m:
name = m.group(1)
if m.lastindex > 1:
publicId = m.group(2) or ""
systemId = m.group(3) or m.group(4) or ""
rv.append("""|%s<!DOCTYPE %s "%s" "%s">""" %
(' ' * indent, name, publicId, systemId))
else:
rv.append("|%s<!DOCTYPE %s>" % (' ' * indent, name))
else:
rv.append("|%s<!DOCTYPE >" % (' ' * indent,))
elif isinstance(element, Comment):
rv.append("|%s<!-- %s -->" % (' ' * indent, element))
elif isinstance(element, NavigableString):
rv.append("|%s\"%s\"" % (' ' * indent, element))
else:
if element.namespace:
name = "%s %s" % (prefixes[element.namespace],
element.name)
else:
name = element.name
rv.append("|%s<%s>" % (' ' * indent, name))
if element.attrs:
attributes = []
for name, value in list(element.attrs.items()):
if isinstance(name, NamespacedAttribute):
name = "%s %s" % (prefixes[name.namespace], name.name)
if isinstance(value, list):
value = " ".join(value)
attributes.append((name, value))
for name, value in sorted(attributes):
rv.append('|%s%s="%s"' % (' ' * (indent + 2), name, value))
indent += 2
for child in element.children:
serializeElement(child, indent)
serializeElement(element, 0)
return "\n".join(rv)
class AttrList(object):
def __init__(self, element):
self.element = element
@ -170,7 +229,7 @@ class Element(treebuilder_base.Node):
def appendChild(self, node):
string_child = child = None
if isinstance(node, basestring):
if isinstance(node, str):
# Some other piece of code decided to pass in a string
# instead of creating a TextElement object to contain the
# string.
@ -182,10 +241,12 @@ class Element(treebuilder_base.Node):
child = node
elif node.element.__class__ == NavigableString:
string_child = child = node.element
node.parent = self
else:
child = node.element
node.parent = self
if not isinstance(child, basestring) and child.parent is not None:
if not isinstance(child, str) and child.parent is not None:
node.element.extract()
if (string_child and self.element.contents
@ -198,7 +259,7 @@ class Element(treebuilder_base.Node):
old_element.replace_with(new_element)
self.soup._most_recent_element = new_element
else:
if isinstance(node, basestring):
if isinstance(node, str):
# Create a brand new NavigableString from this string.
child = self.soup.new_string(node)
@ -221,6 +282,8 @@ class Element(treebuilder_base.Node):
most_recent_element=most_recent_element)
def getAttributes(self):
if isinstance(self.element, Comment):
return {}
return AttrList(self.element)
def setAttributes(self, attributes):
@ -236,7 +299,7 @@ class Element(treebuilder_base.Node):
self.soup.builder._replace_cdata_list_attribute_values(
self.name, attributes)
for name, value in attributes.items():
for name, value in list(attributes.items()):
self.element[name] = value
# The attributes may contain variables that need substitution.
@ -248,11 +311,11 @@ class Element(treebuilder_base.Node):
attributes = property(getAttributes, setAttributes)
def insertText(self, data, insertBefore=None):
text = TextNode(self.soup.new_string(data), self.soup)
if insertBefore:
text = TextNode(self.soup.new_string(data), self.soup)
self.insertBefore(data, insertBefore)
self.insertBefore(text, insertBefore)
else:
self.appendChild(data)
self.appendChild(text)
def insertBefore(self, node, refNode):
index = self.element.index(refNode.element)
@ -274,6 +337,7 @@ class Element(treebuilder_base.Node):
# print "MOVE", self.element.contents
# print "FROM", self.element
# print "TO", new_parent.element
element = self.element
new_parent_element = new_parent.element
# Determine what this tag's next_element will be once all the children
@ -292,7 +356,6 @@ class Element(treebuilder_base.Node):
new_parents_last_descendant_next_element = new_parent_element.next_element
to_append = element.contents
append_after = new_parent_element.contents
if len(to_append) > 0:
# Set the first child's previous_element and previous_sibling
# to elements within the new parent
@ -309,12 +372,19 @@ class Element(treebuilder_base.Node):
if new_parents_last_child:
new_parents_last_child.next_sibling = first_child
# Fix the last child's next_element and next_sibling
last_child = to_append[-1]
last_child.next_element = new_parents_last_descendant_next_element
# Find the very last element being moved. It is now the
# parent's last descendant. It has no .next_sibling and
# its .next_element is whatever the previous last
# descendant had.
last_childs_last_descendant = to_append[-1]._last_descendant(False, True)
last_childs_last_descendant.next_element = new_parents_last_descendant_next_element
if new_parents_last_descendant_next_element:
new_parents_last_descendant_next_element.previous_element = last_child
last_child.next_sibling = None
# TODO: This code has no test coverage and I'm not sure
# how to get html5lib to go through this path, but it's
# just the other side of the previous line.
new_parents_last_descendant_next_element.previous_element = last_childs_last_descendant
last_childs_last_descendant.next_sibling = None
for child in to_append:
child.parent = new_parent_element

View file

@ -1,3 +1,4 @@
# encoding: utf-8
"""Use the HTMLParser library to parse HTML files that aren't too bad."""
# Use of this source code is governed by a BSD-style license that can be
@ -7,11 +8,11 @@ __all__ = [
'HTMLParserTreeBuilder',
]
from HTMLParser import HTMLParser
from html.parser import HTMLParser
try:
from HTMLParser import HTMLParseError
except ImportError, e:
from html.parser import HTMLParseError
except ImportError as e:
# HTMLParseError is removed in Python 3.5. Since it can never be
# thrown in 3.5, we can just define our own class as a placeholder.
class HTMLParseError(Exception):
@ -52,7 +53,42 @@ from bs4.builder import (
HTMLPARSER = 'html.parser'
class BeautifulSoupHTMLParser(HTMLParser):
def handle_starttag(self, name, attrs):
def __init__(self, *args, **kwargs):
HTMLParser.__init__(self, *args, **kwargs)
# Keep a list of empty-element tags that were encountered
# without an explicit closing tag. If we encounter a closing tag
# of this type, we'll associate it with one of those entries.
#
# This isn't a stack because we don't care about the
# order. It's a list of closing tags we've already handled and
# will ignore, assuming they ever show up.
self.already_closed_empty_element = []
def error(self, msg):
"""In Python 3, HTMLParser subclasses must implement error(), although this
requirement doesn't appear to be documented.
In Python 2, HTMLParser implements error() as raising an exception.
In any event, this method is called only on very strange markup and our best strategy
is to pretend it didn't happen and keep going.
"""
warnings.warn(msg)
def handle_startendtag(self, name, attrs):
# This is only called when the markup looks like
# <tag/>.
# is_startend() tells handle_starttag not to close the tag
# just because its name matches a known empty-element tag. We
# know that this is an empty-element tag and we want to call
# handle_endtag ourselves.
tag = self.handle_starttag(name, attrs, handle_empty_element=False)
self.handle_endtag(name)
def handle_starttag(self, name, attrs, handle_empty_element=True):
# XXX namespace
attr_dict = {}
for key, value in attrs:
@ -62,10 +98,34 @@ class BeautifulSoupHTMLParser(HTMLParser):
value = ''
attr_dict[key] = value
attrvalue = '""'
self.soup.handle_starttag(name, None, None, attr_dict)
#print "START", name
tag = self.soup.handle_starttag(name, None, None, attr_dict)
if tag and tag.is_empty_element and handle_empty_element:
# Unlike other parsers, html.parser doesn't send separate end tag
# events for empty-element tags. (It's handled in
# handle_startendtag, but only if the original markup looked like
# <tag/>.)
#
# So we need to call handle_endtag() ourselves. Since we
# know the start event is identical to the end event, we
# don't want handle_endtag() to cross off any previous end
# events for tags of this name.
self.handle_endtag(name, check_already_closed=False)
def handle_endtag(self, name):
self.soup.handle_endtag(name)
# But we might encounter an explicit closing tag for this tag
# later on. If so, we want to ignore it.
self.already_closed_empty_element.append(name)
def handle_endtag(self, name, check_already_closed=True):
#print "END", name
if check_already_closed and name in self.already_closed_empty_element:
# This is a redundant end tag for an empty-element tag.
# We've already called handle_endtag() for it, so just
# check it off the list.
# print "ALREADY CLOSED", name
self.already_closed_empty_element.remove(name)
else:
self.soup.handle_endtag(name)
def handle_data(self, data):
self.soup.handle_data(data)
@ -81,11 +141,26 @@ class BeautifulSoupHTMLParser(HTMLParser):
else:
real_name = int(name)
try:
data = unichr(real_name)
except (ValueError, OverflowError), e:
data = u"\N{REPLACEMENT CHARACTER}"
data = None
if real_name < 256:
# HTML numeric entities are supposed to reference Unicode
# code points, but sometimes they reference code points in
# some other encoding (ahem, Windows-1252). E.g. &#147;
# instead of &#201; for LEFT DOUBLE QUOTATION MARK. This
# code tries to detect this situation and compensate.
for encoding in (self.soup.original_encoding, 'windows-1252'):
if not encoding:
continue
try:
data = bytearray([real_name]).decode(encoding)
except UnicodeDecodeError as e:
pass
if not data:
try:
data = chr(real_name)
except (ValueError, OverflowError) as e:
pass
data = data or "\N{REPLACEMENT CHARACTER}"
self.handle_data(data)
def handle_entityref(self, name):
@ -93,7 +168,12 @@ class BeautifulSoupHTMLParser(HTMLParser):
if character is not None:
data = character
else:
data = "&%s;" % name
# If this were XML, it would be ambiguous whether "&foo"
# was an character entity reference with a missing
# semicolon or the literal string "&foo". Since this is
# HTML, we have a complete list of all character entity references,
# and this one wasn't found, so assume it's the literal string "&foo".
data = "&%s" % name
self.handle_data(data)
def handle_comment(self, data):
@ -148,7 +228,7 @@ class HTMLParserTreeBuilder(HTMLTreeBuilder):
declared within markup, whether any characters had to be
replaced with REPLACEMENT CHARACTER).
"""
if isinstance(markup, unicode):
if isinstance(markup, str):
yield (markup, None, None, False)
return
@ -165,10 +245,12 @@ class HTMLParserTreeBuilder(HTMLTreeBuilder):
parser.soup = self.soup
try:
parser.feed(markup)
except HTMLParseError, e:
parser.close()
except HTMLParseError as e:
warnings.warn(RuntimeWarning(
"Python's built-in HTMLParser cannot parse the given document. This is not a bug in Beautiful Soup. The best solution is to install an external parser (lxml or html5lib), and use Beautiful Soup with that parser. See http://www.crummy.com/software/BeautifulSoup/bs4/doc/#installing-a-parser for help."))
raise e
parser.already_closed_empty_element = []
# Patch 3.2 versions of HTMLParser earlier than 3.2.3 to use some
# 3.2.3 code. This ensures they don't treat markup like <p></p> as a

View file

@ -5,9 +5,13 @@ __all__ = [
'LXMLTreeBuilder',
]
try:
from collections.abc import Callable # Python 3.6
except ImportError as e:
from collections import Callable
from io import BytesIO
from StringIO import StringIO
import collections
from io import StringIO
from lxml import etree
from bs4.element import (
Comment,
@ -58,7 +62,7 @@ class LXMLTreeBuilderForXML(TreeBuilder):
# Use the default parser.
parser = self.default_parser(encoding)
if isinstance(parser, collections.Callable):
if isinstance(parser, Callable):
# Instantiate the parser with default arguments
parser = parser(target=self, strip_cdata=False, encoding=encoding)
return parser
@ -101,12 +105,12 @@ class LXMLTreeBuilderForXML(TreeBuilder):
else:
self.processing_instruction_class = XMLProcessingInstruction
if isinstance(markup, unicode):
if isinstance(markup, str):
# We were given Unicode. Maybe lxml can parse Unicode on
# this system?
yield markup, None, document_declared_encoding, False
if isinstance(markup, unicode):
if isinstance(markup, str):
# No, apparently not. Convert the Unicode to UTF-8 and
# tell lxml to parse it as UTF-8.
yield (markup.encode("utf8"), "utf8",
@ -121,7 +125,7 @@ class LXMLTreeBuilderForXML(TreeBuilder):
def feed(self, markup):
if isinstance(markup, bytes):
markup = BytesIO(markup)
elif isinstance(markup, unicode):
elif isinstance(markup, str):
markup = StringIO(markup)
# Call feed() at least once, even if the markup is empty,
@ -136,7 +140,7 @@ class LXMLTreeBuilderForXML(TreeBuilder):
if len(data) != 0:
self.parser.feed(data)
self.parser.close()
except (UnicodeDecodeError, LookupError, etree.ParserError), e:
except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
raise ParserRejectedMarkup(str(e))
def close(self):
@ -147,19 +151,19 @@ class LXMLTreeBuilderForXML(TreeBuilder):
attrs = dict(attrs)
nsprefix = None
# Invert each namespace map as it comes in.
if len(self.nsmaps) > 1:
# There are no new namespaces for this tag, but
# non-default namespaces are in play, so we need a
# separate tag stack to know when they end.
self.nsmaps.append(None)
if len(nsmap) == 0 and len(self.nsmaps) > 1:
# There are no new namespaces for this tag, but
# non-default namespaces are in play, so we need a
# separate tag stack to know when they end.
self.nsmaps.append(None)
elif len(nsmap) > 0:
# A new namespace mapping has come into play.
inverted_nsmap = dict((value, key) for key, value in nsmap.items())
inverted_nsmap = dict((value, key) for key, value in list(nsmap.items()))
self.nsmaps.append(inverted_nsmap)
# Also treat the namespace mapping as a set of attributes on the
# tag, so we can recreate it later.
attrs = attrs.copy()
for prefix, namespace in nsmap.items():
for prefix, namespace in list(nsmap.items()):
attribute = NamespacedAttribute(
"xmlns", prefix, "http://www.w3.org/2000/xmlns/")
attrs[attribute] = namespace
@ -168,7 +172,7 @@ class LXMLTreeBuilderForXML(TreeBuilder):
# from lxml with namespaces attached to their names, and
# turn then into NamespacedAttribute objects.
new_attrs = {}
for attr, value in attrs.items():
for attr, value in list(attrs.items()):
namespace, attr = self._getNsTag(attr)
if namespace is None:
new_attrs[attr] = value
@ -228,7 +232,7 @@ class LXMLTreeBuilderForXML(TreeBuilder):
def test_fragment_to_document(self, fragment):
"""See `TreeBuilder`."""
return u'<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment
return '<?xml version="1.0" encoding="utf-8"?>\n%s' % fragment
class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
@ -249,10 +253,10 @@ class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
self.parser = self.parser_for(encoding)
self.parser.feed(markup)
self.parser.close()
except (UnicodeDecodeError, LookupError, etree.ParserError), e:
except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
raise ParserRejectedMarkup(str(e))
def test_fragment_to_document(self, fragment):
"""See `TreeBuilder`."""
return u'<html><body>%s</body></html>' % fragment
return '<html><body>%s</body></html>' % fragment

View file

@ -11,7 +11,7 @@ XML or HTML to reflect a new encoding; that's the tree builder's job.
__license__ = "MIT"
import codecs
from htmlentitydefs import codepoint2name
from html.entities import codepoint2name
import re
import logging
import string
@ -46,9 +46,9 @@ except ImportError:
pass
xml_encoding_re = re.compile(
'^<\?.*encoding=[\'"](.*?)[\'"].*\?>'.encode(), re.I)
'^<\\?.*encoding=[\'"](.*?)[\'"].*\\?>'.encode(), re.I)
html_meta_re = re.compile(
'<\s*meta[^>]+charset\s*=\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I)
'<\\s*meta[^>]+charset\\s*=\\s*["\']?([^>]*?)[ /;\'">]'.encode(), re.I)
class EntitySubstitution(object):
@ -59,7 +59,7 @@ class EntitySubstitution(object):
reverse_lookup = {}
characters_for_re = []
for codepoint, name in list(codepoint2name.items()):
character = unichr(codepoint)
character = chr(codepoint)
if codepoint != 34:
# There's no point in turning the quotation mark into
# &quot;, unless it happens within an attribute value, which
@ -82,7 +82,7 @@ class EntitySubstitution(object):
}
BARE_AMPERSAND_OR_BRACKET = re.compile("([<>]|"
"&(?!#\d+;|#x[0-9a-fA-F]+;|\w+;)"
"&(?!#\\d+;|#x[0-9a-fA-F]+;|\\w+;)"
")")
AMPERSAND_OR_BRACKET = re.compile("([<>&])")
@ -274,7 +274,7 @@ class EncodingDetector:
def strip_byte_order_mark(cls, data):
"""If a byte-order mark is present, strip it and return the encoding it implies."""
encoding = None
if isinstance(data, unicode):
if isinstance(data, str):
# Unicode data cannot have a byte-order mark.
return data, encoding
if (len(data) >= 4) and (data[:2] == b'\xfe\xff') \
@ -352,9 +352,9 @@ class UnicodeDammit:
markup, override_encodings, is_html, exclude_encodings)
# Short-circuit if the data is in Unicode to begin with.
if isinstance(markup, unicode) or markup == '':
if isinstance(markup, str) or markup == '':
self.markup = markup
self.unicode_markup = unicode(markup)
self.unicode_markup = str(markup)
self.original_encoding = None
return
@ -438,7 +438,7 @@ class UnicodeDammit:
def _to_unicode(self, data, encoding, errors="strict"):
'''Given a string and its encoding, decodes the string into Unicode.
%encoding is a string recognized by encodings.aliases'''
return unicode(data, encoding, errors)
return str(data, encoding, errors)
@property
def declared_html_encoding(self):
@ -736,7 +736,7 @@ class UnicodeDammit:
0xde : b'\xc3\x9e', # Þ
0xdf : b'\xc3\x9f', # ß
0xe0 : b'\xc3\xa0', # à
0xe1 : b'\xa1', # á
0xe1 : b'\xa1', # á
0xe2 : b'\xc3\xa2', # â
0xe3 : b'\xc3\xa3', # ã
0xe4 : b'\xc3\xa4', # ä

View file

@ -5,8 +5,8 @@
__license__ = "MIT"
import cProfile
from StringIO import StringIO
from HTMLParser import HTMLParser
from io import StringIO
from html.parser import HTMLParser
import bs4
from bs4 import BeautifulSoup, __version__
from bs4.builder import builder_registry
@ -22,8 +22,8 @@ import cProfile
def diagnose(data):
"""Diagnostic suite for isolating common problems."""
print "Diagnostic running on Beautiful Soup %s" % __version__
print "Python version %s" % sys.version
print("Diagnostic running on Beautiful Soup %s" % __version__)
print("Python version %s" % sys.version)
basic_parsers = ["html.parser", "html5lib", "lxml"]
for name in basic_parsers:
@ -32,16 +32,16 @@ def diagnose(data):
break
else:
basic_parsers.remove(name)
print (
print((
"I noticed that %s is not installed. Installing it may help." %
name)
name))
if 'lxml' in basic_parsers:
basic_parsers.append(["lxml", "xml"])
basic_parsers.append("lxml-xml")
try:
from lxml import etree
print "Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION))
except ImportError, e:
print("Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION)))
except ImportError as e:
print (
"lxml is not installed or couldn't be imported.")
@ -49,37 +49,43 @@ def diagnose(data):
if 'html5lib' in basic_parsers:
try:
import html5lib
print "Found html5lib version %s" % html5lib.__version__
except ImportError, e:
print("Found html5lib version %s" % html5lib.__version__)
except ImportError as e:
print (
"html5lib is not installed or couldn't be imported.")
if hasattr(data, 'read'):
data = data.read()
elif os.path.exists(data):
print '"%s" looks like a filename. Reading data from the file.' % data
with open(data) as fp:
data = fp.read()
elif data.startswith("http:") or data.startswith("https:"):
print '"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data
print "You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup."
print('"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data)
print("You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup.")
return
print
else:
try:
if os.path.exists(data):
print('"%s" looks like a filename. Reading data from the file.' % data)
with open(data) as fp:
data = fp.read()
except ValueError:
# This can happen on some platforms when the 'filename' is
# too long. Assume it's data and not a filename.
pass
print()
for parser in basic_parsers:
print "Trying to parse your markup with %s" % parser
print("Trying to parse your markup with %s" % parser)
success = False
try:
soup = BeautifulSoup(data, parser)
soup = BeautifulSoup(data, features=parser)
success = True
except Exception, e:
print "%s could not parse the markup." % parser
except Exception as e:
print("%s could not parse the markup." % parser)
traceback.print_exc()
if success:
print "Here's what %s did with the markup:" % parser
print soup.prettify()
print("Here's what %s did with the markup:" % parser)
print(soup.prettify())
print "-" * 80
print("-" * 80)
def lxml_trace(data, html=True, **kwargs):
"""Print out the lxml events that occur during parsing.
@ -89,7 +95,7 @@ def lxml_trace(data, html=True, **kwargs):
"""
from lxml import etree
for event, element in etree.iterparse(StringIO(data), html=html, **kwargs):
print("%s, %4s, %s" % (event, element.tag, element.text))
print(("%s, %4s, %s" % (event, element.tag, element.text)))
class AnnouncingParser(HTMLParser):
"""Announces HTMLParser parse events, without doing anything else."""
@ -171,9 +177,9 @@ def rdoc(num_elements=1000):
def benchmark_parsers(num_elements=100000):
"""Very basic head-to-head performance benchmark."""
print "Comparative parser benchmark on Beautiful Soup %s" % __version__
print("Comparative parser benchmark on Beautiful Soup %s" % __version__)
data = rdoc(num_elements)
print "Generated a large invalid HTML document (%d bytes)." % len(data)
print("Generated a large invalid HTML document (%d bytes)." % len(data))
for parser in ["lxml", ["lxml", "html"], "html5lib", "html.parser"]:
success = False
@ -182,24 +188,24 @@ def benchmark_parsers(num_elements=100000):
soup = BeautifulSoup(data, parser)
b = time.time()
success = True
except Exception, e:
print "%s could not parse the markup." % parser
except Exception as e:
print("%s could not parse the markup." % parser)
traceback.print_exc()
if success:
print "BS4+%s parsed the markup in %.2fs." % (parser, b-a)
print("BS4+%s parsed the markup in %.2fs." % (parser, b-a))
from lxml import etree
a = time.time()
etree.HTML(data)
b = time.time()
print "Raw lxml parsed the markup in %.2fs." % (b-a)
print("Raw lxml parsed the markup in %.2fs." % (b-a))
import html5lib
parser = html5lib.HTMLParser()
a = time.time()
parser.parse(data)
b = time.time()
print "Raw html5lib parsed the markup in %.2fs." % (b-a)
print("Raw html5lib parsed the markup in %.2fs." % (b-a))
def profile(num_elements=100000, parser="lxml"):

View file

@ -2,7 +2,10 @@
# found in the LICENSE file.
__license__ = "MIT"
import collections
try:
from collections.abc import Callable # Python 3.6
except ImportError as e:
from collections import Callable
import re
import shlex
import sys
@ -12,7 +15,7 @@ from bs4.dammit import EntitySubstitution
DEFAULT_OUTPUT_ENCODING = "utf-8"
PY3K = (sys.version_info[0] > 2)
whitespace_re = re.compile("\s+")
whitespace_re = re.compile(r"\s+")
def _alias(attr):
"""Alias one attribute name to another for backward compatibility"""
@ -26,22 +29,22 @@ def _alias(attr):
return alias
class NamespacedAttribute(unicode):
class NamespacedAttribute(str):
def __new__(cls, prefix, name, namespace=None):
if name is None:
obj = unicode.__new__(cls, prefix)
obj = str.__new__(cls, prefix)
elif prefix is None:
# Not really namespaced.
obj = unicode.__new__(cls, name)
obj = str.__new__(cls, name)
else:
obj = unicode.__new__(cls, prefix + ":" + name)
obj = str.__new__(cls, prefix + ":" + name)
obj.prefix = prefix
obj.name = name
obj.namespace = namespace
return obj
class AttributeValueWithCharsetSubstitution(unicode):
class AttributeValueWithCharsetSubstitution(str):
"""A stand-in object for a character encoding specified in HTML."""
class CharsetMetaAttributeValue(AttributeValueWithCharsetSubstitution):
@ -52,7 +55,7 @@ class CharsetMetaAttributeValue(AttributeValueWithCharsetSubstitution):
"""
def __new__(cls, original_value):
obj = unicode.__new__(cls, original_value)
obj = str.__new__(cls, original_value)
obj.original_value = original_value
return obj
@ -69,15 +72,15 @@ class ContentMetaAttributeValue(AttributeValueWithCharsetSubstitution):
The value of the 'content' attribute will be one of these objects.
"""
CHARSET_RE = re.compile("((^|;)\s*charset=)([^;]*)", re.M)
CHARSET_RE = re.compile(r"((^|;)\s*charset=)([^;]*)", re.M)
def __new__(cls, original_value):
match = cls.CHARSET_RE.search(original_value)
if match is None:
# No substitution necessary.
return unicode.__new__(unicode, original_value)
return str.__new__(str, original_value)
obj = unicode.__new__(cls, original_value)
obj = str.__new__(cls, original_value)
obj.original_value = original_value
return obj
@ -123,6 +126,41 @@ class HTMLAwareEntitySubstitution(EntitySubstitution):
return cls._substitute_if_appropriate(
ns, EntitySubstitution.substitute_xml)
class Formatter(object):
"""Contains information about how to format a parse tree."""
# By default, represent void elements as <tag/> rather than <tag>
void_element_close_prefix = '/'
def substitute_entities(self, *args, **kwargs):
"""Transform certain characters into named entities."""
raise NotImplementedError()
class HTMLFormatter(Formatter):
"""The default HTML formatter."""
def substitute(self, *args, **kwargs):
return HTMLAwareEntitySubstitution.substitute_html(*args, **kwargs)
class MinimalHTMLFormatter(Formatter):
"""A minimal HTML formatter."""
def substitute(self, *args, **kwargs):
return HTMLAwareEntitySubstitution.substitute_xml(*args, **kwargs)
class HTML5Formatter(HTMLFormatter):
"""An HTML formatter that omits the slash in a void tag."""
void_element_close_prefix = None
class XMLFormatter(Formatter):
"""Substitute only the essential XML entities."""
def substitute(self, *args, **kwargs):
return EntitySubstitution.substitute_xml(*args, **kwargs)
class HTMLXMLFormatter(Formatter):
"""Format XML using HTML rules."""
def substitute(self, *args, **kwargs):
return HTMLAwareEntitySubstitution.substitute_html(*args, **kwargs)
class PageElement(object):
"""Contains the navigational information for some part of the page
(either a tag or a piece of text)"""
@ -132,39 +170,48 @@ class PageElement(object):
#
# "html" - All Unicode characters with corresponding HTML entities
# are converted to those entities on output.
# "html5" - The same as "html", but empty void tags are represented as
# <tag> rather than <tag/>
# "minimal" - Bare ampersands and angle brackets are converted to
# XML entities: &amp; &lt; &gt;
# None - The null formatter. Unicode characters are never
# converted to entities. This is not recommended, but it's
# faster than "minimal".
# A function - This function will be called on every string that
# A callable function - it will be called on every string that needs to undergo entity substitution.
# A Formatter instance - Formatter.substitute(string) will be called on every string that
# needs to undergo entity substitution.
#
# In an HTML document, the default "html" and "minimal" functions
# will leave the contents of <script> and <style> tags alone. For
# an XML document, all tags will be given the same treatment.
# In an HTML document, the default "html", "html5", and "minimal"
# functions will leave the contents of <script> and <style> tags
# alone. For an XML document, all tags will be given the same
# treatment.
HTML_FORMATTERS = {
"html" : HTMLAwareEntitySubstitution.substitute_html,
"minimal" : HTMLAwareEntitySubstitution.substitute_xml,
"html" : HTMLFormatter(),
"html5" : HTML5Formatter(),
"minimal" : MinimalHTMLFormatter(),
None : None
}
XML_FORMATTERS = {
"html" : EntitySubstitution.substitute_html,
"minimal" : EntitySubstitution.substitute_xml,
"html" : HTMLXMLFormatter(),
"minimal" : XMLFormatter(),
None : None
}
def format_string(self, s, formatter='minimal'):
"""Format the given string using the given formatter."""
if not callable(formatter):
if isinstance(formatter, str):
formatter = self._formatter_for_name(formatter)
if formatter is None:
output = s
else:
output = formatter(s)
if callable(formatter):
# Backwards compatibility -- you used to pass in a formatting method.
output = formatter(s)
else:
output = formatter.substitute(s)
return output
@property
@ -194,11 +241,9 @@ class PageElement(object):
def _formatter_for_name(self, name):
"Look up a formatter function based on its name and the tree."
if self._is_xml:
return self.XML_FORMATTERS.get(
name, EntitySubstitution.substitute_xml)
return self.XML_FORMATTERS.get(name, XMLFormatter())
else:
return self.HTML_FORMATTERS.get(
name, HTMLAwareEntitySubstitution.substitute_xml)
return self.HTML_FORMATTERS.get(name, HTMLFormatter())
def setup(self, parent=None, previous_element=None, next_element=None,
previous_sibling=None, next_sibling=None):
@ -312,10 +357,18 @@ class PageElement(object):
raise ValueError("Cannot insert None into a tag.")
if new_child is self:
raise ValueError("Cannot insert a tag into itself.")
if (isinstance(new_child, basestring)
if (isinstance(new_child, str)
and not isinstance(new_child, NavigableString)):
new_child = NavigableString(new_child)
from bs4 import BeautifulSoup
if isinstance(new_child, BeautifulSoup):
# We don't want to end up with a situation where one BeautifulSoup
# object contains another. Insert the children one at a time.
for subchild in list(new_child.contents):
self.insert(position, subchild)
position += 1
return
position = min(position, len(self.contents))
if hasattr(new_child, 'parent') and new_child.parent is not None:
# We're 'inserting' an element that's already one
@ -533,11 +586,25 @@ class PageElement(object):
result = (element for element in generator
if isinstance(element, Tag))
return ResultSet(strainer, result)
elif isinstance(name, basestring):
elif isinstance(name, str):
# Optimization to find all tags with a given name.
if name.count(':') == 1:
# This is a name with a prefix. If this is a namespace-aware document,
# we need to match the local name against tag.name. If not,
# we need to match the fully-qualified name against tag.name.
prefix, local_name = name.split(':', 1)
else:
prefix = None
local_name = name
result = (element for element in generator
if isinstance(element, Tag)
and element.name == name)
and (
element.name == name
) or (
element.name == local_name
and (prefix is None or element.prefix == prefix)
)
)
return ResultSet(strainer, result)
results = ResultSet(strainer)
while True:
@ -684,7 +751,7 @@ class PageElement(object):
return self.parents
class NavigableString(unicode, PageElement):
class NavigableString(str, PageElement):
PREFIX = ''
SUFFIX = ''
@ -702,10 +769,10 @@ class NavigableString(unicode, PageElement):
passed in to the superclass's __new__ or the superclass won't know
how to handle non-ASCII characters.
"""
if isinstance(value, unicode):
u = unicode.__new__(cls, value)
if isinstance(value, str):
u = str.__new__(cls, value)
else:
u = unicode.__new__(cls, value, DEFAULT_OUTPUT_ENCODING)
u = str.__new__(cls, value, DEFAULT_OUTPUT_ENCODING)
u.setup()
return u
@ -716,7 +783,7 @@ class NavigableString(unicode, PageElement):
return type(self)(self)
def __getnewargs__(self):
return (unicode(self),)
return (str(self),)
def __getattr__(self, attr):
"""text.string gives you text. This is for backwards
@ -756,29 +823,29 @@ class PreformattedString(NavigableString):
class CData(PreformattedString):
PREFIX = u'<![CDATA['
SUFFIX = u']]>'
PREFIX = '<![CDATA['
SUFFIX = ']]>'
class ProcessingInstruction(PreformattedString):
"""A SGML processing instruction."""
PREFIX = u'<?'
SUFFIX = u'>'
PREFIX = '<?'
SUFFIX = '>'
class XMLProcessingInstruction(ProcessingInstruction):
"""An XML processing instruction."""
PREFIX = u'<?'
SUFFIX = u'?>'
PREFIX = '<?'
SUFFIX = '?>'
class Comment(PreformattedString):
PREFIX = u'<!--'
SUFFIX = u'-->'
PREFIX = '<!--'
SUFFIX = '-->'
class Declaration(PreformattedString):
PREFIX = u'<?'
SUFFIX = u'?>'
PREFIX = '<?'
SUFFIX = '?>'
class Doctype(PreformattedString):
@ -795,8 +862,8 @@ class Doctype(PreformattedString):
return Doctype(value)
PREFIX = u'<!DOCTYPE '
SUFFIX = u'>\n'
PREFIX = '<!DOCTYPE '
SUFFIX = '>\n'
class Tag(PageElement):
@ -863,7 +930,7 @@ class Tag(PageElement):
Its contents are a copy of the old Tag's contents.
"""
clone = type(self)(None, self.builder, self.name, self.namespace,
self.nsprefix, self.attrs, is_xml=self._is_xml)
self.prefix, self.attrs, is_xml=self._is_xml)
for attr in ('can_be_empty_element', 'hidden'):
setattr(clone, attr, getattr(self, attr))
for child in self.contents:
@ -935,7 +1002,7 @@ class Tag(PageElement):
for string in self._all_strings(True):
yield string
def get_text(self, separator=u"", strip=False,
def get_text(self, separator="", strip=False,
types=(NavigableString, CData)):
"""
Get all child strings, concatenated using the given separator.
@ -985,6 +1052,13 @@ class Tag(PageElement):
attribute."""
return self.attrs.get(key, default)
def get_attribute_list(self, key, default=None):
"""The same as get(), but always returns a list."""
value = self.get(key, default)
if not isinstance(value, list):
value = [value]
return value
def has_attr(self, key):
return key in self.attrs
@ -1007,7 +1081,7 @@ class Tag(PageElement):
def __contains__(self, x):
return x in self.contents
def __nonzero__(self):
def __bool__(self):
"A tag is non-None even if it has no contents."
return True
@ -1032,8 +1106,10 @@ class Tag(PageElement):
# BS3: soup.aTag -> "soup.find("a")
tag_name = tag[:-3]
warnings.warn(
'.%sTag is deprecated, use .find("%s") instead.' % (
tag_name, tag_name))
'.%(name)sTag is deprecated, use .find("%(name)s") instead. If you really were looking for a tag called %(name)sTag, use .find("%(name)sTag")' % dict(
name=tag_name
)
)
return self.find(tag_name)
# We special case contents to avoid recursion.
elif not tag.startswith("__") and not tag == "contents":
@ -1115,11 +1191,10 @@ class Tag(PageElement):
encoding.
"""
# First off, turn a string formatter into a function. This
# First off, turn a string formatter into a Formatter object. This
# will stop the lookup from happening over and over again.
if not callable(formatter):
if not isinstance(formatter, Formatter) and not callable(formatter):
formatter = self._formatter_for_name(formatter)
attrs = []
if self.attrs:
for key, val in sorted(self.attrs.items()):
@ -1128,8 +1203,8 @@ class Tag(PageElement):
else:
if isinstance(val, list) or isinstance(val, tuple):
val = ' '.join(val)
elif not isinstance(val, basestring):
val = unicode(val)
elif not isinstance(val, str):
val = str(val)
elif (
isinstance(val, AttributeValueWithCharsetSubstitution)
and eventual_encoding is not None):
@ -1137,7 +1212,7 @@ class Tag(PageElement):
text = self.format_string(val, formatter)
decoded = (
unicode(key) + '='
str(key) + '='
+ EntitySubstitution.quoted_attribute_value(text))
attrs.append(decoded)
close = ''
@ -1148,7 +1223,9 @@ class Tag(PageElement):
prefix = self.prefix + ":"
if self.is_empty_element:
close = '/'
close = ''
if isinstance(formatter, Formatter):
close = formatter.void_element_close_prefix or close
else:
closeTag = '</%s%s>' % (prefix, self.name)
@ -1219,9 +1296,9 @@ class Tag(PageElement):
:param formatter: The output formatter responsible for converting
entities to Unicode characters.
"""
# First off, turn a string formatter into a function. This
# First off, turn a string formatter into a Formatter object. This
# will stop the lookup from happening over and over again.
if not callable(formatter):
if not isinstance(formatter, Formatter) and not callable(formatter):
formatter = self._formatter_for_name(formatter)
pretty_print = (indent_level is not None)
@ -1334,15 +1411,29 @@ class Tag(PageElement):
# Handle grouping selectors if ',' exists, ie: p,a
if ',' in selector:
context = []
for partial_selector in selector.split(','):
partial_selector = partial_selector.strip()
selectors = [x.strip() for x in selector.split(",")]
# If a selector is mentioned multiple times we don't want
# to use it more than once.
used_selectors = set()
# We also don't want to select the same element more than once,
# if it's matched by multiple selectors.
selected_object_ids = set()
for partial_selector in selectors:
if partial_selector == '':
raise ValueError('Invalid group selection syntax: %s' % selector)
if partial_selector in used_selectors:
continue
used_selectors.add(partial_selector)
candidates = self.select(partial_selector, limit=limit)
for candidate in candidates:
if candidate not in context:
# This lets us distinguish between distinct tags that
# represent the same markup.
object_id = id(candidate)
if object_id not in selected_object_ids:
context.append(candidate)
selected_object_ids.add(object_id)
if limit and len(context) >= limit:
break
return context
@ -1354,7 +1445,7 @@ class Tag(PageElement):
'Final combinator "%s" is missing an argument.' % tokens[-1])
if self._select_debug:
print 'Running CSS selector "%s"' % selector
print('Running CSS selector "%s"' % selector)
for index, token in enumerate(tokens):
new_context = []
@ -1363,11 +1454,11 @@ class Tag(PageElement):
if tokens[index-1] in self._selector_combinators:
# This token was consumed by the previous combinator. Skip it.
if self._select_debug:
print ' Token was consumed by the previous combinator.'
print(' Token was consumed by the previous combinator.')
continue
if self._select_debug:
print ' Considering token "%s"' % token
print(' Considering token "%s"' % token)
recursive_candidate_generator = None
tag_name = None
@ -1404,7 +1495,7 @@ class Tag(PageElement):
if tag_name == '':
raise ValueError(
"A pseudo-class must be prefixed with a tag name.")
pseudo_attributes = re.match('([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo)
pseudo_attributes = re.match(r'([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo)
found = []
if pseudo_attributes is None:
pseudo_type = pseudo
@ -1474,14 +1565,14 @@ class Tag(PageElement):
next_token = tokens[index+1]
def recursive_select(tag):
if self._select_debug:
print ' Calling select("%s") recursively on %s %s' % (next_token, tag.name, tag.attrs)
print '-' * 40
print(' Calling select("%s") recursively on %s %s' % (next_token, tag.name, tag.attrs))
print('-' * 40)
for i in tag.select(next_token, recursive_candidate_generator):
if self._select_debug:
print '(Recursive select picked up candidate %s %s)' % (i.name, i.attrs)
print('(Recursive select picked up candidate %s %s)' % (i.name, i.attrs))
yield i
if self._select_debug:
print '-' * 40
print('-' * 40)
_use_candidate_generator = recursive_select
elif _candidate_generator is None:
# By default, a tag's candidates are all of its
@ -1492,7 +1583,7 @@ class Tag(PageElement):
check = "[any]"
else:
check = tag_name
print ' Default candidate generator, tag name="%s"' % check
print(' Default candidate generator, tag name="%s"' % check)
if self._select_debug:
# This is redundant with later code, but it stops
# a bunch of bogus tags from cluttering up the
@ -1513,8 +1604,8 @@ class Tag(PageElement):
count = 0
for tag in current_context:
if self._select_debug:
print " Running candidate generator on %s %s" % (
tag.name, repr(tag.attrs))
print(" Running candidate generator on %s %s" % (
tag.name, repr(tag.attrs)))
for candidate in _use_candidate_generator(tag):
if not isinstance(candidate, Tag):
continue
@ -1529,23 +1620,23 @@ class Tag(PageElement):
break
if checker is None or result:
if self._select_debug:
print " SUCCESS %s %s" % (candidate.name, repr(candidate.attrs))
print(" SUCCESS %s %s" % (candidate.name, repr(candidate.attrs)))
if id(candidate) not in new_context_ids:
# If a tag matches a selector more than once,
# don't include it in the context more than once.
new_context.append(candidate)
new_context_ids.add(id(candidate))
elif self._select_debug:
print " FAILURE %s %s" % (candidate.name, repr(candidate.attrs))
print(" FAILURE %s %s" % (candidate.name, repr(candidate.attrs)))
current_context = new_context
if limit and len(current_context) >= limit:
current_context = current_context[:limit]
if self._select_debug:
print "Final verdict:"
print("Final verdict:")
for i in current_context:
print " %s %s" % (i.name, i.attrs)
print(" %s %s" % (i.name, i.attrs))
return current_context
# Old names for backwards compatibility
@ -1589,7 +1680,7 @@ class SoupStrainer(object):
else:
attrs = kwargs
normalized_attrs = {}
for key, value in attrs.items():
for key, value in list(attrs.items()):
normalized_attrs[key] = self._normalize_search_value(value)
self.attrs = normalized_attrs
@ -1598,7 +1689,7 @@ class SoupStrainer(object):
def _normalize_search_value(self, value):
# Leave it alone if it's a Unicode string, a callable, a
# regular expression, a boolean, or None.
if (isinstance(value, unicode) or callable(value) or hasattr(value, 'match')
if (isinstance(value, str) or callable(value) or hasattr(value, 'match')
or isinstance(value, bool) or value is None):
return value
@ -1611,7 +1702,7 @@ class SoupStrainer(object):
new_value = []
for v in value:
if (hasattr(v, '__iter__') and not isinstance(v, bytes)
and not isinstance(v, unicode)):
and not isinstance(v, str)):
# This is almost certainly the user's mistake. In the
# interests of avoiding infinite loops, we'll let
# it through as-is rather than doing a recursive call.
@ -1623,7 +1714,7 @@ class SoupStrainer(object):
# Otherwise, convert it into a Unicode string.
# The unicode(str()) thing is so this will do the same thing on Python 2
# and Python 3.
return unicode(str(value))
return str(str(value))
def __str__(self):
if self.text:
@ -1638,7 +1729,7 @@ class SoupStrainer(object):
markup = markup_name
markup_attrs = markup
call_function_with_tag_data = (
isinstance(self.name, collections.Callable)
isinstance(self.name, Callable)
and not isinstance(markup_name, Tag))
if ((not self.name)
@ -1677,7 +1768,7 @@ class SoupStrainer(object):
found = None
# If given a list of items, scan it for a text element that
# matches.
if hasattr(markup, '__iter__') and not isinstance(markup, (Tag, basestring)):
if hasattr(markup, '__iter__') and not isinstance(markup, (Tag, str)):
for element in markup:
if isinstance(element, NavigableString) \
and self.search(element):
@ -1690,7 +1781,7 @@ class SoupStrainer(object):
found = self.search_tag(markup)
# If it's text, make sure the text matches.
elif isinstance(markup, NavigableString) or \
isinstance(markup, basestring):
isinstance(markup, str):
if not self.name and not self.attrs and self._matches(markup, self.text):
found = markup
else:
@ -1698,7 +1789,7 @@ class SoupStrainer(object):
"I don't know how to match against a %s" % markup.__class__)
return found
def _matches(self, markup, match_against):
def _matches(self, markup, match_against, already_tried=None):
# print u"Matching %s against %s" % (markup, match_against)
result = False
if isinstance(markup, list) or isinstance(markup, tuple):
@ -1718,11 +1809,12 @@ class SoupStrainer(object):
# True matches any non-None value.
return markup is not None
if isinstance(match_against, collections.Callable):
if isinstance(match_against, Callable):
return match_against(markup)
# Custom callables take the tag as an argument, but all
# other ways of matching match the tag name as a string.
original_markup = markup
if isinstance(markup, Tag):
markup = markup.name
@ -1733,18 +1825,51 @@ class SoupStrainer(object):
# None matches None, False, an empty string, an empty list, and so on.
return not match_against
if isinstance(match_against, unicode):
# Exact string match
return markup == match_against
if (hasattr(match_against, '__iter__')
and not isinstance(match_against, str)):
# We're asked to match against an iterable of items.
# The markup must be match at least one item in the
# iterable. We'll try each one in turn.
#
# To avoid infinite recursion we need to keep track of
# items we've already seen.
if not already_tried:
already_tried = set()
for item in match_against:
if item.__hash__:
key = item
else:
key = id(item)
if key in already_tried:
continue
else:
already_tried.add(key)
if self._matches(original_markup, item, already_tried):
return True
else:
return False
if hasattr(match_against, 'match'):
# Beyond this point we might need to run the test twice: once against
# the tag's name and once against its prefixed name.
match = False
if not match and isinstance(match_against, str):
# Exact string match
match = markup == match_against
if not match and hasattr(match_against, 'search'):
# Regexp match
return match_against.search(markup)
if hasattr(match_against, '__iter__'):
# The markup must be an exact match against something
# in the iterable.
return markup in match_against
if (not match
and isinstance(original_markup, Tag)
and original_markup.prefix):
# Try the whole thing again with the prefixed tag name.
return self._matches(
original_markup.prefix + ':' + original_markup.name, match_against
)
return match
class ResultSet(list):
@ -1753,3 +1878,8 @@ class ResultSet(list):
def __init__(self, source, result=()):
super(ResultSet, self).__init__(result)
self.source = source
def __getattr__(self, key):
raise AttributeError(
"ResultSet object has no attribute '%s'. You're probably treating a list of items like a single item. Did you call find_all() when you meant to call find()?" % key
)

View file

@ -1,3 +1,4 @@
# encoding: utf-8
"""Helper classes for tests."""
# Use of this source code is governed by a BSD-style license that can be
@ -69,6 +70,18 @@ class HTMLTreeBuilderSmokeTest(object):
markup in these tests, there's not much room for interpretation.
"""
def test_empty_element_tags(self):
"""Verify that all HTML4 and HTML5 empty element (aka void element) tags
are handled correctly.
"""
for name in [
'area', 'base', 'br', 'col', 'embed', 'hr', 'img', 'input', 'keygen', 'link', 'menuitem', 'meta', 'param', 'source', 'track', 'wbr',
'spacer', 'frame'
]:
soup = self.soup("")
new_tag = soup.new_tag(name)
self.assertEqual(True, new_tag.is_empty_element)
def test_pickle_and_unpickle_identity(self):
# Pickling a tree, then unpickling it, yields a tree identical
# to the original.
@ -138,12 +151,20 @@ class HTMLTreeBuilderSmokeTest(object):
soup.encode("utf-8").replace(b"\n", b""),
markup.replace(b"\n", b""))
def test_namespaced_html(self):
"""When a namespaced XML document is parsed as HTML it should
be treated as HTML with weird tag names.
"""
markup = b"""<ns1:foo>content</ns1:foo><ns1:foo/><ns2:foo/>"""
soup = self.soup(markup)
self.assertEqual(2, len(soup.find_all("ns1:foo")))
def test_processing_instruction(self):
# We test both Unicode and bytestring to verify that
# process_markup correctly sets processing_instruction_class
# even when the markup is already Unicode and there is no
# need to process anything.
markup = u"""<?PITarget PIContent?>"""
markup = """<?PITarget PIContent?>"""
soup = self.soup(markup)
self.assertEqual(markup, soup.decode())
@ -299,15 +320,35 @@ Hello, world!
def test_angle_brackets_in_attribute_values_are_escaped(self):
self.assertSoupEquals('<a b="<a>"></a>', '<a b="&lt;a&gt;"></a>')
def test_strings_resembling_character_entity_references(self):
# "&T" and "&p" look like incomplete character entities, but they are
# not.
self.assertSoupEquals(
"<p>&bull; AT&T is in the s&p 500</p>",
"<p>\u2022 AT&amp;T is in the s&amp;p 500</p>"
)
def test_entities_in_foreign_document_encoding(self):
# &#147; and &#148; are invalid numeric entities referencing
# Windows-1252 characters. &#45; references a character common
# to Windows-1252 and Unicode, and &#9731; references a
# character only found in Unicode.
#
# All of these entities should be converted to Unicode
# characters.
markup = "<p>&#147;Hello&#148; &#45;&#9731;</p>"
soup = self.soup(markup)
self.assertEqual("“Hello” -☃", soup.p.string)
def test_entities_in_attributes_converted_to_unicode(self):
expect = u'<p id="pi\N{LATIN SMALL LETTER N WITH TILDE}ata"></p>'
expect = '<p id="pi\N{LATIN SMALL LETTER N WITH TILDE}ata"></p>'
self.assertSoupEquals('<p id="pi&#241;ata"></p>', expect)
self.assertSoupEquals('<p id="pi&#xf1;ata"></p>', expect)
self.assertSoupEquals('<p id="pi&#Xf1;ata"></p>', expect)
self.assertSoupEquals('<p id="pi&ntilde;ata"></p>', expect)
def test_entities_in_text_converted_to_unicode(self):
expect = u'<p>pi\N{LATIN SMALL LETTER N WITH TILDE}ata</p>'
expect = '<p>pi\N{LATIN SMALL LETTER N WITH TILDE}ata</p>'
self.assertSoupEquals("<p>pi&#241;ata</p>", expect)
self.assertSoupEquals("<p>pi&#xf1;ata</p>", expect)
self.assertSoupEquals("<p>pi&#Xf1;ata</p>", expect)
@ -318,7 +359,7 @@ Hello, world!
'<p>I said "good day!"</p>')
def test_out_of_range_entity(self):
expect = u"\N{REPLACEMENT CHARACTER}"
expect = "\N{REPLACEMENT CHARACTER}"
self.assertSoupEquals("&#10000000000000;", expect)
self.assertSoupEquals("&#x10000000000000;", expect)
self.assertSoupEquals("&#1000000000;", expect)
@ -330,6 +371,13 @@ Hello, world!
self.assertEqual("p", soup.p.name)
self.assertConnectedness(soup)
def test_empty_element_tags(self):
"""Verify consistent handling of empty-element tags,
no matter how they come in through the markup.
"""
self.assertSoupEquals('<br/><br/><br/>', "<br/><br/><br/>")
self.assertSoupEquals('<br /><br /><br />', "<br/><br/><br/>")
def test_head_tag_between_head_and_body(self):
"Prevent recurrence of a bug in the html5lib treebuilder."
content = """<html><head></head>
@ -389,9 +437,9 @@ Hello, world!
# A seemingly innocuous document... but it's in Unicode! And
# it contains characters that can't be represented in the
# encoding found in the declaration! The horror!
markup = u'<html><head><meta encoding="euc-jp"></head><body>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</body>'
markup = '<html><head><meta encoding="euc-jp"></head><body>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</body>'
soup = self.soup(markup)
self.assertEqual(u'Sacr\xe9 bleu!', soup.body.string)
self.assertEqual('Sacr\xe9 bleu!', soup.body.string)
def test_soupstrainer(self):
"""Parsers should be able to work with SoupStrainers."""
@ -431,7 +479,7 @@ Hello, world!
# Both XML and HTML entities are converted to Unicode characters
# during parsing.
text = "<p>&lt;&lt;sacr&eacute;&#32;bleu!&gt;&gt;</p>"
expected = u"<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>"
expected = "<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>"
self.assertSoupEquals(text, expected)
def test_smart_quotes_converted_on_the_way_in(self):
@ -441,15 +489,15 @@ Hello, world!
soup = self.soup(quote)
self.assertEqual(
soup.p.string,
u"\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}")
"\N{LEFT SINGLE QUOTATION MARK}Foo\N{RIGHT SINGLE QUOTATION MARK}")
def test_non_breaking_spaces_converted_on_the_way_in(self):
soup = self.soup("<a>&nbsp;&nbsp;</a>")
self.assertEqual(soup.a.string, u"\N{NO-BREAK SPACE}" * 2)
self.assertEqual(soup.a.string, "\N{NO-BREAK SPACE}" * 2)
def test_entities_converted_on_the_way_out(self):
text = "<p>&lt;&lt;sacr&eacute;&#32;bleu!&gt;&gt;</p>"
expected = u"<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>".encode("utf-8")
expected = "<p>&lt;&lt;sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</p>".encode("utf-8")
soup = self.soup(text)
self.assertEqual(soup.p.encode("utf-8"), expected)
@ -458,7 +506,7 @@ Hello, world!
# easy-to-understand document.
# Here it is in Unicode. Note that it claims to be in ISO-Latin-1.
unicode_html = u'<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'
unicode_html = '<html><head><meta content="text/html; charset=ISO-Latin-1" http-equiv="Content-type"/></head><body><p>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</p></body></html>'
# That's because we're going to encode it into ISO-Latin-1, and use
# that to test.
@ -605,6 +653,17 @@ class XMLTreeBuilderSmokeTest(object):
self.assertEqual(
soup.encode("utf-8"), markup)
def test_nested_namespaces(self):
doc = b"""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">
<parent xmlns="http://ns1/">
<child xmlns="http://ns2/" xmlns:ns3="http://ns3/">
<grandchild ns3:attr="value" xmlns="http://ns4/"/>
</child>
</parent>"""
soup = self.soup(doc)
self.assertEqual(doc, soup.encode())
def test_formatter_processes_script_tag_for_xml_documents(self):
doc = """
<script type="text/javascript">
@ -618,15 +677,15 @@ class XMLTreeBuilderSmokeTest(object):
self.assertTrue(b"&lt; &lt; hey &gt; &gt;" in encoded)
def test_can_parse_unicode_document(self):
markup = u'<?xml version="1.0" encoding="euc-jp"><root>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</root>'
markup = '<?xml version="1.0" encoding="euc-jp"><root>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</root>'
soup = self.soup(markup)
self.assertEqual(u'Sacr\xe9 bleu!', soup.root.string)
self.assertEqual('Sacr\xe9 bleu!', soup.root.string)
def test_popping_namespaced_tag(self):
markup = '<rss xmlns:dc="foo"><dc:creator>b</dc:creator><dc:date>2012-07-02T20:33:42Z</dc:date><dc:rights>c</dc:rights><image>d</image></rss>'
soup = self.soup(markup)
self.assertEqual(
unicode(soup.rss), markup)
str(soup.rss), markup)
def test_docstring_includes_correct_encoding(self):
soup = self.soup("<root/>")
@ -657,17 +716,51 @@ class XMLTreeBuilderSmokeTest(object):
def test_closing_namespaced_tag(self):
markup = '<p xmlns:dc="http://purl.org/dc/elements/1.1/"><dc:date>20010504</dc:date></p>'
soup = self.soup(markup)
self.assertEqual(unicode(soup.p), markup)
self.assertEqual(str(soup.p), markup)
def test_namespaced_attributes(self):
markup = '<foo xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"><bar xsi:schemaLocation="http://www.example.com"/></foo>'
soup = self.soup(markup)
self.assertEqual(unicode(soup.foo), markup)
self.assertEqual(str(soup.foo), markup)
def test_namespaced_attributes_xml_namespace(self):
markup = '<foo xml:lang="fr">bar</foo>'
soup = self.soup(markup)
self.assertEqual(unicode(soup.foo), markup)
self.assertEqual(str(soup.foo), markup)
def test_find_by_prefixed_name(self):
doc = """<?xml version="1.0" encoding="utf-8"?>
<Document xmlns="http://example.com/ns0"
xmlns:ns1="http://example.com/ns1"
xmlns:ns2="http://example.com/ns2"
<ns1:tag>foo</ns1:tag>
<ns1:tag>bar</ns1:tag>
<ns2:tag key="value">baz</ns2:tag>
</Document>
"""
soup = self.soup(doc)
# There are three <tag> tags.
self.assertEqual(3, len(soup.find_all('tag')))
# But two of them are ns1:tag and one of them is ns2:tag.
self.assertEqual(2, len(soup.find_all('ns1:tag')))
self.assertEqual(1, len(soup.find_all('ns2:tag')))
self.assertEqual(1, len(soup.find_all('ns2:tag', key='value')))
self.assertEqual(3, len(soup.find_all(['ns1:tag', 'ns2:tag'])))
def test_copy_tag_preserves_namespace(self):
xml = """<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<w:document xmlns:w="http://example.com/ns0"/>"""
soup = self.soup(xml)
tag = soup.document
duplicate = copy.copy(tag)
# The two tags have the same namespace prefix.
self.assertEqual(tag.prefix, duplicate.prefix)
class HTML5TreeBuilderSmokeTest(HTMLTreeBuilderSmokeTest):
"""Smoke test for a tree builder that supports HTML5."""

View file

@ -5,7 +5,7 @@ import warnings
try:
from bs4.builder import HTML5TreeBuilder
HTML5LIB_PRESENT = True
except ImportError, e:
except ImportError as e:
HTML5LIB_PRESENT = False
from bs4.element import SoupStrainer
from bs4.testing import (
@ -74,14 +74,14 @@ class HTML5LibBuilderSmokeTest(SoupTest, HTML5TreeBuilderSmokeTest):
def test_reparented_markup(self):
markup = '<p><em>foo</p>\n<p>bar<a></a></em></p>'
soup = self.soup(markup)
self.assertEqual(u"<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p></body>", soup.body.decode())
self.assertEqual("<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p></body>", soup.body.decode())
self.assertEqual(2, len(soup.find_all('p')))
def test_reparented_markup_ends_with_whitespace(self):
markup = '<p><em>foo</p>\n<p>bar<a></a></em></p>\n'
soup = self.soup(markup)
self.assertEqual(u"<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p>\n</body>", soup.body.decode())
self.assertEqual("<body><p><em>foo</em></p><em>\n</em><p><em>bar<a></a></em></p>\n</body>", soup.body.decode())
self.assertEqual(2, len(soup.find_all('p')))
def test_reparented_markup_containing_identical_whitespace_nodes(self):
@ -95,6 +95,22 @@ class HTML5LibBuilderSmokeTest(SoupTest, HTML5TreeBuilderSmokeTest):
assert space1.next_element is tbody1
assert tbody2.next_element is space2
def test_reparented_markup_containing_children(self):
markup = '<div><a>aftermath<p><noscript>target</noscript>aftermath</a></p></div>'
soup = self.soup(markup)
noscript = soup.noscript
self.assertEqual("target", noscript.next_element)
target = soup.find(string='target')
# The 'aftermath' string was duplicated; we want the second one.
final_aftermath = soup.find_all(string='aftermath')[-1]
# The <noscript> tag was moved beneath a copy of the <a> tag,
# but the 'target' string within is still connected to the
# (second) 'aftermath' string.
self.assertEqual(final_aftermath, target.next_element)
self.assertEqual(target, final_aftermath.previous_element)
def test_processing_instruction(self):
"""Processing instructions become comments."""
markup = b"""<?PITarget PIContent?>"""
@ -107,3 +123,8 @@ class HTML5LibBuilderSmokeTest(SoupTest, HTML5TreeBuilderSmokeTest):
a1, a2 = soup.find_all('a')
self.assertEqual(a1, a2)
assert a1 is not a2
def test_foster_parenting(self):
markup = b"""<table><td></tbody>A"""
soup = self.soup(markup)
self.assertEqual("<body>A<table><tbody><tr><td></td></tr></tbody></table></body>", soup.body.decode())

View file

@ -5,6 +5,7 @@ from pdb import set_trace
import pickle
from bs4.testing import SoupTest, HTMLTreeBuilderSmokeTest
from bs4.builder import HTMLParserTreeBuilder
from bs4.builder._htmlparser import BeautifulSoupHTMLParser
class HTMLParserTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
@ -29,4 +30,20 @@ class HTMLParserTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
loaded = pickle.loads(dumped)
self.assertTrue(isinstance(loaded.builder, type(tree.builder)))
def test_redundant_empty_element_closing_tags(self):
self.assertSoupEquals('<br></br><br></br><br></br>', "<br/><br/><br/>")
self.assertSoupEquals('</br></br></br>', "")
def test_empty_element(self):
# This verifies that any buffered data present when the parser
# finishes working is handled.
self.assertSoupEquals("foo &# bar", "foo &amp;# bar")
class TestHTMLParserSubclass(SoupTest):
def test_error(self):
"""Verify that our HTMLParser subclass implements error() in a way
that doesn't cause a crash.
"""
parser = BeautifulSoupHTMLParser()
parser.error("don't crash")

View file

@ -7,7 +7,7 @@ try:
import lxml.etree
LXML_PRESENT = True
LXML_VERSION = lxml.etree.LXML_VERSION
except ImportError, e:
except ImportError as e:
LXML_PRESENT = False
LXML_VERSION = (0,)
@ -46,6 +46,12 @@ class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
self.assertSoupEquals(
"<p>foo&#1000000000;bar</p>", "<p>foobar</p>")
def test_entities_in_foreign_document_encoding(self):
# We can't implement this case correctly because by the time we
# hear about markup like "&#147;", it's been (incorrectly) converted into
# a string like u'\x93'
pass
# In lxml < 2.3.5, an empty doctype causes a segfault. Skip this
# test if an old version of lxml is installed.
@ -62,7 +68,7 @@ class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
# if one is installed.
with warnings.catch_warnings(record=True) as w:
soup = BeautifulStoneSoup("<b />")
self.assertEqual(u"<b/>", unicode(soup.b))
self.assertEqual("<b/>", str(soup.b))
self.assertTrue("BeautifulStoneSoup class is deprecated" in str(w[0].message))
@skipIf(

View file

@ -32,7 +32,7 @@ import warnings
try:
from bs4.builder import LXMLTreeBuilder, LXMLTreeBuilderForXML
LXML_PRESENT = True
except ImportError, e:
except ImportError as e:
LXML_PRESENT = False
PYTHON_3_PRE_3_2 = (sys.version_info[0] == 3 and sys.version_info < (3,2))
@ -40,17 +40,17 @@ PYTHON_3_PRE_3_2 = (sys.version_info[0] == 3 and sys.version_info < (3,2))
class TestConstructor(SoupTest):
def test_short_unicode_input(self):
data = u"<h1>éé</h1>"
data = "<h1>éé</h1>"
soup = self.soup(data)
self.assertEqual(u"éé", soup.h1.string)
self.assertEqual("éé", soup.h1.string)
def test_embedded_null(self):
data = u"<h1>foo\0bar</h1>"
data = "<h1>foo\0bar</h1>"
soup = self.soup(data)
self.assertEqual(u"foo\0bar", soup.h1.string)
self.assertEqual("foo\0bar", soup.h1.string)
def test_exclude_encodings(self):
utf8_data = u"Räksmörgås".encode("utf-8")
utf8_data = "Räksmörgås".encode("utf-8")
soup = self.soup(utf8_data, exclude_encodings=["utf-8"])
self.assertEqual("windows-1252", soup.original_encoding)
@ -129,7 +129,7 @@ class TestWarnings(SoupTest):
with warnings.catch_warnings(record=True) as warning_list:
# note - this url must differ from the bytes one otherwise
# python's warnings system swallows the second warning
soup = self.soup(u"http://www.crummyunicode.com/")
soup = self.soup("http://www.crummyunicode.com/")
self.assertTrue(any("looks like a URL" in str(w.message)
for w in warning_list))
@ -141,7 +141,7 @@ class TestWarnings(SoupTest):
def test_url_warning_with_unicode_and_space(self):
with warnings.catch_warnings(record=True) as warning_list:
soup = self.soup(u"http://www.crummyuncode.com/ is great")
soup = self.soup("http://www.crummyuncode.com/ is great")
self.assertFalse(any("looks like a URL" in str(w.message)
for w in warning_list))
@ -163,9 +163,9 @@ class TestEntitySubstitution(unittest.TestCase):
def test_simple_html_substitution(self):
# Unicode characters corresponding to named HTML entites
# are substituted, and no others.
s = u"foo\u2200\N{SNOWMAN}\u00f5bar"
s = "foo\u2200\N{SNOWMAN}\u00f5bar"
self.assertEqual(self.sub.substitute_html(s),
u"foo&forall;\N{SNOWMAN}&otilde;bar")
"foo&forall;\N{SNOWMAN}&otilde;bar")
def test_smart_quote_substitution(self):
# MS smart quotes are a common source of frustration, so we
@ -230,7 +230,7 @@ class TestEncodingConversion(SoupTest):
def setUp(self):
super(TestEncodingConversion, self).setUp()
self.unicode_data = u'<html><head><meta charset="utf-8"/></head><body><foo>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</foo></body></html>'
self.unicode_data = '<html><head><meta charset="utf-8"/></head><body><foo>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</foo></body></html>'
self.utf8_data = self.unicode_data.encode("utf-8")
# Just so you know what it looks like.
self.assertEqual(
@ -250,7 +250,7 @@ class TestEncodingConversion(SoupTest):
ascii = b"<foo>a</foo>"
soup_from_ascii = self.soup(ascii)
unicode_output = soup_from_ascii.decode()
self.assertTrue(isinstance(unicode_output, unicode))
self.assertTrue(isinstance(unicode_output, str))
self.assertEqual(unicode_output, self.document_for(ascii.decode()))
self.assertEqual(soup_from_ascii.original_encoding.lower(), "utf-8")
finally:
@ -262,7 +262,7 @@ class TestEncodingConversion(SoupTest):
# is not set.
soup_from_unicode = self.soup(self.unicode_data)
self.assertEqual(soup_from_unicode.decode(), self.unicode_data)
self.assertEqual(soup_from_unicode.foo.string, u'Sacr\xe9 bleu!')
self.assertEqual(soup_from_unicode.foo.string, 'Sacr\xe9 bleu!')
self.assertEqual(soup_from_unicode.original_encoding, None)
def test_utf8_in_unicode_out(self):
@ -270,7 +270,7 @@ class TestEncodingConversion(SoupTest):
# attribute is set.
soup_from_utf8 = self.soup(self.utf8_data)
self.assertEqual(soup_from_utf8.decode(), self.unicode_data)
self.assertEqual(soup_from_utf8.foo.string, u'Sacr\xe9 bleu!')
self.assertEqual(soup_from_utf8.foo.string, 'Sacr\xe9 bleu!')
def test_utf8_out(self):
# The internal data structures can be encoded as UTF-8.
@ -281,14 +281,14 @@ class TestEncodingConversion(SoupTest):
PYTHON_3_PRE_3_2,
"Bad HTMLParser detected; skipping test of non-ASCII characters in attribute name.")
def test_attribute_name_containing_unicode_characters(self):
markup = u'<div><a \N{SNOWMAN}="snowman"></a></div>'
markup = '<div><a \N{SNOWMAN}="snowman"></a></div>'
self.assertEqual(self.soup(markup).div.encode("utf8"), markup.encode("utf8"))
class TestUnicodeDammit(unittest.TestCase):
"""Standalone tests of UnicodeDammit."""
def test_unicode_input(self):
markup = u"I'm already Unicode! \N{SNOWMAN}"
markup = "I'm already Unicode! \N{SNOWMAN}"
dammit = UnicodeDammit(markup)
self.assertEqual(dammit.unicode_markup, markup)
@ -296,7 +296,7 @@ class TestUnicodeDammit(unittest.TestCase):
markup = b"<foo>\x91\x92\x93\x94</foo>"
dammit = UnicodeDammit(markup)
self.assertEqual(
dammit.unicode_markup, u"<foo>\u2018\u2019\u201c\u201d</foo>")
dammit.unicode_markup, "<foo>\u2018\u2019\u201c\u201d</foo>")
def test_smart_quotes_to_xml_entities(self):
markup = b"<foo>\x91\x92\x93\x94</foo>"
@ -320,14 +320,14 @@ class TestUnicodeDammit(unittest.TestCase):
utf8 = b"Sacr\xc3\xa9 bleu! \xe2\x98\x83"
dammit = UnicodeDammit(utf8)
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
self.assertEqual(dammit.unicode_markup, u'Sacr\xe9 bleu! \N{SNOWMAN}')
self.assertEqual(dammit.unicode_markup, 'Sacr\xe9 bleu! \N{SNOWMAN}')
def test_convert_hebrew(self):
hebrew = b"\xed\xe5\xec\xf9"
dammit = UnicodeDammit(hebrew, ["iso-8859-8"])
self.assertEqual(dammit.original_encoding.lower(), 'iso-8859-8')
self.assertEqual(dammit.unicode_markup, u'\u05dd\u05d5\u05dc\u05e9')
self.assertEqual(dammit.unicode_markup, '\u05dd\u05d5\u05dc\u05e9')
def test_dont_see_smart_quotes_where_there_are_none(self):
utf_8 = b"\343\202\261\343\203\274\343\202\277\343\202\244 Watch"
@ -336,19 +336,19 @@ class TestUnicodeDammit(unittest.TestCase):
self.assertEqual(dammit.unicode_markup.encode("utf-8"), utf_8)
def test_ignore_inappropriate_codecs(self):
utf8_data = u"Räksmörgås".encode("utf-8")
utf8_data = "Räksmörgås".encode("utf-8")
dammit = UnicodeDammit(utf8_data, ["iso-8859-8"])
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
def test_ignore_invalid_codecs(self):
utf8_data = u"Räksmörgås".encode("utf-8")
utf8_data = "Räksmörgås".encode("utf-8")
for bad_encoding in ['.utf8', '...', 'utF---16.!']:
dammit = UnicodeDammit(utf8_data, [bad_encoding])
self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
def test_exclude_encodings(self):
# This is UTF-8.
utf8_data = u"Räksmörgås".encode("utf-8")
utf8_data = "Räksmörgås".encode("utf-8")
# But if we exclude UTF-8 from consideration, the guess is
# Windows-1252.
@ -364,7 +364,7 @@ class TestUnicodeDammit(unittest.TestCase):
detected = EncodingDetector(
b'<?xml version="1.0" encoding="UTF-\xdb" ?>')
encodings = list(detected.encodings)
assert u'utf-\N{REPLACEMENT CHARACTER}' in encodings
assert 'utf-\N{REPLACEMENT CHARACTER}' in encodings
def test_detect_html5_style_meta_tag(self):
@ -404,7 +404,7 @@ class TestUnicodeDammit(unittest.TestCase):
bs4.dammit.chardet_dammit = noop
dammit = UnicodeDammit(doc)
self.assertEqual(True, dammit.contains_replacement_characters)
self.assertTrue(u"\ufffd" in dammit.unicode_markup)
self.assertTrue("\ufffd" in dammit.unicode_markup)
soup = BeautifulSoup(doc, "html.parser")
self.assertTrue(soup.contains_replacement_characters)
@ -416,17 +416,17 @@ class TestUnicodeDammit(unittest.TestCase):
# A document written in UTF-16LE will have its byte order marker stripped.
data = b'\xff\xfe<\x00a\x00>\x00\xe1\x00\xe9\x00<\x00/\x00a\x00>\x00'
dammit = UnicodeDammit(data)
self.assertEqual(u"<a>áé</a>", dammit.unicode_markup)
self.assertEqual("<a>áé</a>", dammit.unicode_markup)
self.assertEqual("utf-16le", dammit.original_encoding)
def test_detwingle(self):
# Here's a UTF8 document.
utf8 = (u"\N{SNOWMAN}" * 3).encode("utf8")
utf8 = ("\N{SNOWMAN}" * 3).encode("utf8")
# Here's a Windows-1252 document.
windows_1252 = (
u"\N{LEFT DOUBLE QUOTATION MARK}Hi, I like Windows!"
u"\N{RIGHT DOUBLE QUOTATION MARK}").encode("windows_1252")
"\N{LEFT DOUBLE QUOTATION MARK}Hi, I like Windows!"
"\N{RIGHT DOUBLE QUOTATION MARK}").encode("windows_1252")
# Through some unholy alchemy, they've been stuck together.
doc = utf8 + windows_1252 + utf8
@ -441,7 +441,7 @@ class TestUnicodeDammit(unittest.TestCase):
fixed = UnicodeDammit.detwingle(doc)
self.assertEqual(
u"☃☃☃“Hi, I like Windows!”☃☃☃", fixed.decode("utf8"))
"☃☃☃“Hi, I like Windows!”☃☃☃", fixed.decode("utf8"))
def test_detwingle_ignores_multibyte_characters(self):
# Each of these characters has a UTF-8 representation ending
@ -449,9 +449,9 @@ class TestUnicodeDammit(unittest.TestCase):
# Windows-1252. But our code knows to skip over multibyte
# UTF-8 characters, so they'll survive the process unscathed.
for tricky_unicode_char in (
u"\N{LATIN SMALL LIGATURE OE}", # 2-byte char '\xc5\x93'
u"\N{LATIN SUBSCRIPT SMALL LETTER X}", # 3-byte char '\xe2\x82\x93'
u"\xf0\x90\x90\x93", # This is a CJK character, not sure which one.
"\N{LATIN SMALL LIGATURE OE}", # 2-byte char '\xc5\x93'
"\N{LATIN SUBSCRIPT SMALL LETTER X}", # 3-byte char '\xe2\x82\x93'
"\xf0\x90\x90\x93", # This is a CJK character, not sure which one.
):
input = tricky_unicode_char.encode("utf8")
self.assertTrue(input.endswith(b'\x93'))

View file

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
"""Tests for Beautiful Soup's tree traversal methods.
@ -70,13 +71,13 @@ class TestFind(TreeTest):
self.assertEqual(soup.find("b").string, "2")
def test_unicode_text_find(self):
soup = self.soup(u'<h1>Räksmörgås</h1>')
self.assertEqual(soup.find(string=u'Räksmörgås'), u'Räksmörgås')
soup = self.soup('<h1>Räksmörgås</h1>')
self.assertEqual(soup.find(string='Räksmörgås'), 'Räksmörgås')
def test_unicode_attribute_find(self):
soup = self.soup(u'<h1 id="Räksmörgås">here it is</h1>')
soup = self.soup('<h1 id="Räksmörgås">here it is</h1>')
str(soup)
self.assertEqual("here it is", soup.find(id=u'Räksmörgås').text)
self.assertEqual("here it is", soup.find(id='Räksmörgås').text)
def test_find_everything(self):
@ -96,17 +97,17 @@ class TestFindAll(TreeTest):
"""You can search the tree for text nodes."""
soup = self.soup("<html>Foo<b>bar</b>\xbb</html>")
# Exact match.
self.assertEqual(soup.find_all(string="bar"), [u"bar"])
self.assertEqual(soup.find_all(text="bar"), [u"bar"])
self.assertEqual(soup.find_all(string="bar"), ["bar"])
self.assertEqual(soup.find_all(text="bar"), ["bar"])
# Match any of a number of strings.
self.assertEqual(
soup.find_all(text=["Foo", "bar"]), [u"Foo", u"bar"])
soup.find_all(text=["Foo", "bar"]), ["Foo", "bar"])
# Match a regular expression.
self.assertEqual(soup.find_all(text=re.compile('.*')),
[u"Foo", u"bar", u'\xbb'])
["Foo", "bar", '\xbb'])
# Match anything.
self.assertEqual(soup.find_all(text=True),
[u"Foo", u"bar", u'\xbb'])
["Foo", "bar", '\xbb'])
def test_find_all_limit(self):
"""You can limit the number of items returned by find_all."""
@ -234,6 +235,7 @@ class TestFindAllByName(TreeTest):
self.assertEqual('1', r3.string)
self.assertEqual('3', r4.string)
class TestFindAllByAttribute(TreeTest):
def test_find_all_by_attribute_name(self):
@ -248,8 +250,8 @@ class TestFindAllByAttribute(TreeTest):
["Matching a.", "Matching b."])
def test_find_all_by_utf8_attribute_value(self):
peace = u"םולש".encode("utf8")
data = u'<a title="םולש"></a>'.encode("utf8")
peace = "םולש".encode("utf8")
data = '<a title="םולש"></a>'.encode("utf8")
soup = self.soup(data)
self.assertEqual([soup.a], soup.find_all(title=peace))
self.assertEqual([soup.a], soup.find_all(title=peace.decode("utf8")))
@ -603,7 +605,7 @@ class SiblingTest(TreeTest):
</html>'''
# All that whitespace looks good but makes the tests more
# difficult. Get rid of it.
markup = re.compile("\n\s*").sub("", markup)
markup = re.compile(r"\n\s*").sub("", markup)
self.tree = self.soup(markup)
@ -701,10 +703,10 @@ class TestTagCreation(SoupTest):
"""Test the ability to create new tags."""
def test_new_tag(self):
soup = self.soup("")
new_tag = soup.new_tag("foo", bar="baz")
new_tag = soup.new_tag("foo", bar="baz", attrs={"name": "a name"})
self.assertTrue(isinstance(new_tag, Tag))
self.assertEqual("foo", new_tag.name)
self.assertEqual(dict(bar="baz"), new_tag.attrs)
self.assertEqual(dict(bar="baz", name="a name"), new_tag.attrs)
self.assertEqual(None, new_tag.parent)
def test_tag_inherits_self_closing_rules_from_builder(self):
@ -819,6 +821,26 @@ class TestTreeModification(SoupTest):
soup = self.soup(text)
self.assertRaises(ValueError, soup.a.insert, 0, soup.a)
def test_insert_beautifulsoup_object_inserts_children(self):
"""Inserting one BeautifulSoup object into another actually inserts all
of its children -- you'll never combine BeautifulSoup objects.
"""
soup = self.soup("<p>And now, a word:</p><p>And we're back.</p>")
text = "<p>p2</p><p>p3</p>"
to_insert = self.soup(text)
soup.insert(1, to_insert)
for i in soup.descendants:
assert not isinstance(i, BeautifulSoup)
p1, p2, p3, p4 = list(soup.children)
self.assertEqual("And now, a word:", p1.string)
self.assertEqual("p2", p2.string)
self.assertEqual("p3", p3.string)
self.assertEqual("And we're back.", p4.string)
def test_replace_with_maintains_next_element_throughout(self):
soup = self.soup('<p><a>one</a><b>three</b></p>')
a = soup.a
@ -1109,7 +1131,7 @@ class TestTreeModification(SoupTest):
<script>baz</script>
</html>""")
[soup.script.extract() for i in soup.find_all("script")]
self.assertEqual("<body>\n\n<a></a>\n</body>", unicode(soup.body))
self.assertEqual("<body>\n\n<a></a>\n</body>", str(soup.body))
def test_extract_works_when_element_is_surrounded_by_identical_strings(self):
@ -1184,7 +1206,7 @@ class TestElementObjects(SoupTest):
tag = soup.bTag
self.assertEqual(soup.b, tag)
self.assertEqual(
'.bTag is deprecated, use .find("b") instead.',
'.bTag is deprecated, use .find("b") instead. If you really were looking for a tag called bTag, use .find("bTag")',
str(w[0].message))
def test_has_attr(self):
@ -1284,6 +1306,10 @@ class TestCDAtaListAttributes(SoupTest):
soup = self.soup("<a class='foo\tbar'>")
self.assertEqual(b'<a class="foo bar"></a>', soup.a.encode())
def test_get_attribute_list(self):
soup = self.soup("<a id='abc def'>")
self.assertEqual(['abc def'], soup.a.get_attribute_list('id'))
def test_accept_charset(self):
soup = self.soup('<form accept-charset="ISO-8859-1 UTF-8">')
self.assertEqual(['ISO-8859-1', 'UTF-8'], soup.form['accept-charset'])
@ -1343,19 +1369,19 @@ class TestPersistence(SoupTest):
soup = BeautifulSoup(b'<p>&nbsp;</p>', 'html.parser')
encoding = soup.original_encoding
copy = soup.__copy__()
self.assertEqual(u"<p> </p>", unicode(copy))
self.assertEqual("<p> </p>", str(copy))
self.assertEqual(encoding, copy.original_encoding)
def test_unicode_pickle(self):
# A tree containing Unicode characters can be pickled.
html = u"<b>\N{SNOWMAN}</b>"
html = "<b>\N{SNOWMAN}</b>"
soup = self.soup(html)
dumped = pickle.dumps(soup, pickle.HIGHEST_PROTOCOL)
loaded = pickle.loads(dumped)
self.assertEqual(loaded.decode(), soup.decode())
def test_copy_navigablestring_is_not_attached_to_tree(self):
html = u"<b>Foo<a></a></b><b>Bar</b>"
html = "<b>Foo<a></a></b><b>Bar</b>"
soup = self.soup(html)
s1 = soup.find(string="Foo")
s2 = copy.copy(s1)
@ -1367,7 +1393,7 @@ class TestPersistence(SoupTest):
self.assertEqual(None, s2.previous_element)
def test_copy_navigablestring_subclass_has_same_type(self):
html = u"<b><!--Foo--></b>"
html = "<b><!--Foo--></b>"
soup = self.soup(html)
s1 = soup.string
s2 = copy.copy(s1)
@ -1375,19 +1401,19 @@ class TestPersistence(SoupTest):
self.assertTrue(isinstance(s2, Comment))
def test_copy_entire_soup(self):
html = u"<div><b>Foo<a></a></b><b>Bar</b></div>end"
html = "<div><b>Foo<a></a></b><b>Bar</b></div>end"
soup = self.soup(html)
soup_copy = copy.copy(soup)
self.assertEqual(soup, soup_copy)
def test_copy_tag_copies_contents(self):
html = u"<div><b>Foo<a></a></b><b>Bar</b></div>end"
html = "<div><b>Foo<a></a></b><b>Bar</b></div>end"
soup = self.soup(html)
div = soup.div
div_copy = copy.copy(div)
# The two tags look the same, and evaluate to equal.
self.assertEqual(unicode(div), unicode(div_copy))
self.assertEqual(str(div), str(div_copy))
self.assertEqual(div, div_copy)
# But they're not the same object.
@ -1403,67 +1429,75 @@ class TestPersistence(SoupTest):
class TestSubstitutions(SoupTest):
def test_default_formatter_is_minimal(self):
markup = u"<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
markup = "<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
soup = self.soup(markup)
decoded = soup.decode(formatter="minimal")
# The < is converted back into &lt; but the e-with-acute is left alone.
self.assertEqual(
decoded,
self.document_for(
u"<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"))
"<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"))
def test_formatter_html(self):
markup = u"<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
markup = "<br><b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
soup = self.soup(markup)
decoded = soup.decode(formatter="html")
self.assertEqual(
decoded,
self.document_for("<b>&lt;&lt;Sacr&eacute; bleu!&gt;&gt;</b>"))
self.document_for("<br/><b>&lt;&lt;Sacr&eacute; bleu!&gt;&gt;</b>"))
def test_formatter_html5(self):
markup = "<br><b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
soup = self.soup(markup)
decoded = soup.decode(formatter="html5")
self.assertEqual(
decoded,
self.document_for("<br><b>&lt;&lt;Sacr&eacute; bleu!&gt;&gt;</b>"))
def test_formatter_minimal(self):
markup = u"<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
markup = "<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
soup = self.soup(markup)
decoded = soup.decode(formatter="minimal")
# The < is converted back into &lt; but the e-with-acute is left alone.
self.assertEqual(
decoded,
self.document_for(
u"<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"))
"<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"))
def test_formatter_null(self):
markup = u"<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
markup = "<b>&lt;&lt;Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!&gt;&gt;</b>"
soup = self.soup(markup)
decoded = soup.decode(formatter=None)
# Neither the angle brackets nor the e-with-acute are converted.
# This is not valid HTML, but it's what the user wanted.
self.assertEqual(decoded,
self.document_for(u"<b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>"))
self.document_for("<b><<Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!>></b>"))
def test_formatter_custom(self):
markup = u"<b>&lt;foo&gt;</b><b>bar</b>"
markup = "<b>&lt;foo&gt;</b><b>bar</b><br/>"
soup = self.soup(markup)
decoded = soup.decode(formatter = lambda x: x.upper())
# Instead of normal entity conversion code, the custom
# callable is called on every string.
self.assertEqual(
decoded,
self.document_for(u"<b><FOO></b><b>BAR</b>"))
self.document_for("<b><FOO></b><b>BAR</b><br>"))
def test_formatter_is_run_on_attribute_values(self):
markup = u'<a href="http://a.com?a=b&c=é">e</a>'
markup = '<a href="http://a.com?a=b&c=é">e</a>'
soup = self.soup(markup)
a = soup.a
expect_minimal = u'<a href="http://a.com?a=b&amp;c=é">e</a>'
expect_minimal = '<a href="http://a.com?a=b&amp;c=é">e</a>'
self.assertEqual(expect_minimal, a.decode())
self.assertEqual(expect_minimal, a.decode(formatter="minimal"))
expect_html = u'<a href="http://a.com?a=b&amp;c=&eacute;">e</a>'
expect_html = '<a href="http://a.com?a=b&amp;c=&eacute;">e</a>'
self.assertEqual(expect_html, a.decode(formatter="html"))
self.assertEqual(markup, a.decode(formatter=None))
expect_upper = u'<a href="HTTP://A.COM?A=B&C=É">E</a>'
expect_upper = '<a href="HTTP://A.COM?A=B&C=É">E</a>'
self.assertEqual(expect_upper, a.decode(formatter=lambda x: x.upper()))
def test_formatter_skips_script_tag_for_html_documents(self):
@ -1489,24 +1523,24 @@ class TestSubstitutions(SoupTest):
# Everything outside the <pre> tag is reformatted, but everything
# inside is left alone.
self.assertEqual(
u'<div>\n foo\n <pre> \tbar\n \n </pre>\n baz\n</div>',
'<div>\n foo\n <pre> \tbar\n \n </pre>\n baz\n</div>',
soup.div.prettify())
def test_prettify_accepts_formatter(self):
def test_prettify_accepts_formatter_function(self):
soup = BeautifulSoup("<html><body>foo</body></html>", 'html.parser')
pretty = soup.prettify(formatter = lambda x: x.upper())
self.assertTrue("FOO" in pretty)
def test_prettify_outputs_unicode_by_default(self):
soup = self.soup("<a></a>")
self.assertEqual(unicode, type(soup.prettify()))
self.assertEqual(str, type(soup.prettify()))
def test_prettify_can_encode_data(self):
soup = self.soup("<a></a>")
self.assertEqual(bytes, type(soup.prettify("utf-8")))
def test_html_entity_substitution_off_by_default(self):
markup = u"<b>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</b>"
markup = "<b>Sacr\N{LATIN SMALL LETTER E WITH ACUTE} bleu!</b>"
soup = self.soup(markup)
encoded = soup.b.encode("utf-8")
self.assertEqual(encoded, markup.encode('utf-8'))
@ -1550,48 +1584,48 @@ class TestEncoding(SoupTest):
"""Test the ability to encode objects into strings."""
def test_unicode_string_can_be_encoded(self):
html = u"<b>\N{SNOWMAN}</b>"
html = "<b>\N{SNOWMAN}</b>"
soup = self.soup(html)
self.assertEqual(soup.b.string.encode("utf-8"),
u"\N{SNOWMAN}".encode("utf-8"))
"\N{SNOWMAN}".encode("utf-8"))
def test_tag_containing_unicode_string_can_be_encoded(self):
html = u"<b>\N{SNOWMAN}</b>"
html = "<b>\N{SNOWMAN}</b>"
soup = self.soup(html)
self.assertEqual(
soup.b.encode("utf-8"), html.encode("utf-8"))
def test_encoding_substitutes_unrecognized_characters_by_default(self):
html = u"<b>\N{SNOWMAN}</b>"
html = "<b>\N{SNOWMAN}</b>"
soup = self.soup(html)
self.assertEqual(soup.b.encode("ascii"), b"<b>&#9731;</b>")
def test_encoding_can_be_made_strict(self):
html = u"<b>\N{SNOWMAN}</b>"
html = "<b>\N{SNOWMAN}</b>"
soup = self.soup(html)
self.assertRaises(
UnicodeEncodeError, soup.encode, "ascii", errors="strict")
def test_decode_contents(self):
html = u"<b>\N{SNOWMAN}</b>"
html = "<b>\N{SNOWMAN}</b>"
soup = self.soup(html)
self.assertEqual(u"\N{SNOWMAN}", soup.b.decode_contents())
self.assertEqual("\N{SNOWMAN}", soup.b.decode_contents())
def test_encode_contents(self):
html = u"<b>\N{SNOWMAN}</b>"
html = "<b>\N{SNOWMAN}</b>"
soup = self.soup(html)
self.assertEqual(
u"\N{SNOWMAN}".encode("utf8"), soup.b.encode_contents(
"\N{SNOWMAN}".encode("utf8"), soup.b.encode_contents(
encoding="utf8"))
def test_deprecated_renderContents(self):
html = u"<b>\N{SNOWMAN}</b>"
html = "<b>\N{SNOWMAN}</b>"
soup = self.soup(html)
self.assertEqual(
u"\N{SNOWMAN}".encode("utf8"), soup.b.renderContents())
"\N{SNOWMAN}".encode("utf8"), soup.b.renderContents())
def test_repr(self):
html = u"<b>\N{SNOWMAN}</b>"
html = "<b>\N{SNOWMAN}</b>"
soup = self.soup(html)
if PY3K:
self.assertEqual(html, repr(soup))
@ -1714,7 +1748,7 @@ class TestSoupSelector(TreeTest):
els = self.soup.select('title')
self.assertEqual(len(els), 1)
self.assertEqual(els[0].name, 'title')
self.assertEqual(els[0].contents, [u'The title'])
self.assertEqual(els[0].contents, ['The title'])
def test_one_tag_many(self):
els = self.soup.select('div')
@ -1760,7 +1794,7 @@ class TestSoupSelector(TreeTest):
self.assertEqual(dashed[0]['id'], 'dash2')
def test_dashed_tag_text(self):
self.assertEqual(self.soup.select('body > custom-dashed-tag')[0].text, u'Hello there.')
self.assertEqual(self.soup.select('body > custom-dashed-tag')[0].text, 'Hello there.')
def test_select_dashed_matches_find_all(self):
self.assertEqual(self.soup.select('custom-dashed-tag'), self.soup.find_all('custom-dashed-tag'))
@ -1947,12 +1981,12 @@ class TestSoupSelector(TreeTest):
# Try to select first paragraph
els = self.soup.select('div#inner p:nth-of-type(1)')
self.assertEqual(len(els), 1)
self.assertEqual(els[0].string, u'Some text')
self.assertEqual(els[0].string, 'Some text')
# Try to select third paragraph
els = self.soup.select('div#inner p:nth-of-type(3)')
self.assertEqual(len(els), 1)
self.assertEqual(els[0].string, u'Another')
self.assertEqual(els[0].string, 'Another')
# Try to select (non-existent!) fourth paragraph
els = self.soup.select('div#inner p:nth-of-type(4)')
@ -1965,7 +1999,7 @@ class TestSoupSelector(TreeTest):
def test_nth_of_type_direct_descendant(self):
els = self.soup.select('div#inner > p:nth-of-type(1)')
self.assertEqual(len(els), 1)
self.assertEqual(els[0].string, u'Some text')
self.assertEqual(els[0].string, 'Some text')
def test_id_child_selector_nth_of_type(self):
self.assertSelects('#inner > p:nth-of-type(2)', ['p1'])
@ -2040,5 +2074,17 @@ class TestSoupSelector(TreeTest):
def test_multiple_select_nested(self):
self.assertSelects('body > div > x, y > z', ['xid', 'zidb'])
def test_select_duplicate_elements(self):
# When markup contains duplicate elements, a multiple select
# will find all of them.
markup = '<div class="c1"/><div class="c2"/><div class="c1"/>'
soup = BeautifulSoup(markup, 'html.parser')
selected = soup.select(".c1, .c2")
self.assertEqual(3, len(selected))
# Verify that find_all finds the same elements, though because
# of an implementation detail it finds them in a different
# order.
for element in soup.find_all(class_=['c1', 'c2']):
assert element in selected

97
libs/click/__init__.py Normal file
View file

@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
"""
click
~~~~~
Click is a simple Python module inspired by the stdlib optparse to make
writing command line scripts fun. Unlike other modules, it's based
around a simple API that does not come with too much magic and is
composable.
:copyright: © 2014 by the Pallets team.
:license: BSD, see LICENSE.rst for more details.
"""
# Core classes
from .core import Context, BaseCommand, Command, MultiCommand, Group, \
CommandCollection, Parameter, Option, Argument
# Globals
from .globals import get_current_context
# Decorators
from .decorators import pass_context, pass_obj, make_pass_decorator, \
command, group, argument, option, confirmation_option, \
password_option, version_option, help_option
# Types
from .types import ParamType, File, Path, Choice, IntRange, Tuple, \
DateTime, STRING, INT, FLOAT, BOOL, UUID, UNPROCESSED, FloatRange
# Utilities
from .utils import echo, get_binary_stream, get_text_stream, open_file, \
format_filename, get_app_dir, get_os_args
# Terminal functions
from .termui import prompt, confirm, get_terminal_size, echo_via_pager, \
progressbar, clear, style, unstyle, secho, edit, launch, getchar, \
pause
# Exceptions
from .exceptions import ClickException, UsageError, BadParameter, \
FileError, Abort, NoSuchOption, BadOptionUsage, BadArgumentUsage, \
MissingParameter
# Formatting
from .formatting import HelpFormatter, wrap_text
# Parsing
from .parser import OptionParser
__all__ = [
# Core classes
'Context', 'BaseCommand', 'Command', 'MultiCommand', 'Group',
'CommandCollection', 'Parameter', 'Option', 'Argument',
# Globals
'get_current_context',
# Decorators
'pass_context', 'pass_obj', 'make_pass_decorator', 'command', 'group',
'argument', 'option', 'confirmation_option', 'password_option',
'version_option', 'help_option',
# Types
'ParamType', 'File', 'Path', 'Choice', 'IntRange', 'Tuple',
'DateTime', 'STRING', 'INT', 'FLOAT', 'BOOL', 'UUID', 'UNPROCESSED',
'FloatRange',
# Utilities
'echo', 'get_binary_stream', 'get_text_stream', 'open_file',
'format_filename', 'get_app_dir', 'get_os_args',
# Terminal functions
'prompt', 'confirm', 'get_terminal_size', 'echo_via_pager',
'progressbar', 'clear', 'style', 'unstyle', 'secho', 'edit', 'launch',
'getchar', 'pause',
# Exceptions
'ClickException', 'UsageError', 'BadParameter', 'FileError',
'Abort', 'NoSuchOption', 'BadOptionUsage', 'BadArgumentUsage',
'MissingParameter',
# Formatting
'HelpFormatter', 'wrap_text',
# Parsing
'OptionParser',
]
# Controls if click should emit the warning about the use of unicode
# literals.
disable_unicode_literals_warning = False
__version__ = '7.0'

293
libs/click/_bashcomplete.py Normal file
View file

@ -0,0 +1,293 @@
import copy
import os
import re
from .utils import echo
from .parser import split_arg_string
from .core import MultiCommand, Option, Argument
from .types import Choice
try:
from collections import abc
except ImportError:
import collections as abc
WORDBREAK = '='
# Note, only BASH version 4.4 and later have the nosort option.
COMPLETION_SCRIPT_BASH = '''
%(complete_func)s() {
local IFS=$'\n'
COMPREPLY=( $( env COMP_WORDS="${COMP_WORDS[*]}" \\
COMP_CWORD=$COMP_CWORD \\
%(autocomplete_var)s=complete $1 ) )
return 0
}
%(complete_func)setup() {
local COMPLETION_OPTIONS=""
local BASH_VERSION_ARR=(${BASH_VERSION//./ })
# Only BASH version 4.4 and later have the nosort option.
if [ ${BASH_VERSION_ARR[0]} -gt 4 ] || ([ ${BASH_VERSION_ARR[0]} -eq 4 ] && [ ${BASH_VERSION_ARR[1]} -ge 4 ]); then
COMPLETION_OPTIONS="-o nosort"
fi
complete $COMPLETION_OPTIONS -F %(complete_func)s %(script_names)s
}
%(complete_func)setup
'''
COMPLETION_SCRIPT_ZSH = '''
%(complete_func)s() {
local -a completions
local -a completions_with_descriptions
local -a response
response=("${(@f)$( env COMP_WORDS=\"${words[*]}\" \\
COMP_CWORD=$((CURRENT-1)) \\
%(autocomplete_var)s=\"complete_zsh\" \\
%(script_names)s )}")
for key descr in ${(kv)response}; do
if [[ "$descr" == "_" ]]; then
completions+=("$key")
else
completions_with_descriptions+=("$key":"$descr")
fi
done
if [ -n "$completions_with_descriptions" ]; then
_describe -V unsorted completions_with_descriptions -U -Q
fi
if [ -n "$completions" ]; then
compadd -U -V unsorted -Q -a completions
fi
compstate[insert]="automenu"
}
compdef %(complete_func)s %(script_names)s
'''
_invalid_ident_char_re = re.compile(r'[^a-zA-Z0-9_]')
def get_completion_script(prog_name, complete_var, shell):
cf_name = _invalid_ident_char_re.sub('', prog_name.replace('-', '_'))
script = COMPLETION_SCRIPT_ZSH if shell == 'zsh' else COMPLETION_SCRIPT_BASH
return (script % {
'complete_func': '_%s_completion' % cf_name,
'script_names': prog_name,
'autocomplete_var': complete_var,
}).strip() + ';'
def resolve_ctx(cli, prog_name, args):
"""
Parse into a hierarchy of contexts. Contexts are connected through the parent variable.
:param cli: command definition
:param prog_name: the program that is running
:param args: full list of args
:return: the final context/command parsed
"""
ctx = cli.make_context(prog_name, args, resilient_parsing=True)
args = ctx.protected_args + ctx.args
while args:
if isinstance(ctx.command, MultiCommand):
if not ctx.command.chain:
cmd_name, cmd, args = ctx.command.resolve_command(ctx, args)
if cmd is None:
return ctx
ctx = cmd.make_context(cmd_name, args, parent=ctx,
resilient_parsing=True)
args = ctx.protected_args + ctx.args
else:
# Walk chained subcommand contexts saving the last one.
while args:
cmd_name, cmd, args = ctx.command.resolve_command(ctx, args)
if cmd is None:
return ctx
sub_ctx = cmd.make_context(cmd_name, args, parent=ctx,
allow_extra_args=True,
allow_interspersed_args=False,
resilient_parsing=True)
args = sub_ctx.args
ctx = sub_ctx
args = sub_ctx.protected_args + sub_ctx.args
else:
break
return ctx
def start_of_option(param_str):
"""
:param param_str: param_str to check
:return: whether or not this is the start of an option declaration (i.e. starts "-" or "--")
"""
return param_str and param_str[:1] == '-'
def is_incomplete_option(all_args, cmd_param):
"""
:param all_args: the full original list of args supplied
:param cmd_param: the current command paramter
:return: whether or not the last option declaration (i.e. starts "-" or "--") is incomplete and
corresponds to this cmd_param. In other words whether this cmd_param option can still accept
values
"""
if not isinstance(cmd_param, Option):
return False
if cmd_param.is_flag:
return False
last_option = None
for index, arg_str in enumerate(reversed([arg for arg in all_args if arg != WORDBREAK])):
if index + 1 > cmd_param.nargs:
break
if start_of_option(arg_str):
last_option = arg_str
return True if last_option and last_option in cmd_param.opts else False
def is_incomplete_argument(current_params, cmd_param):
"""
:param current_params: the current params and values for this argument as already entered
:param cmd_param: the current command parameter
:return: whether or not the last argument is incomplete and corresponds to this cmd_param. In
other words whether or not the this cmd_param argument can still accept values
"""
if not isinstance(cmd_param, Argument):
return False
current_param_values = current_params[cmd_param.name]
if current_param_values is None:
return True
if cmd_param.nargs == -1:
return True
if isinstance(current_param_values, abc.Iterable) \
and cmd_param.nargs > 1 and len(current_param_values) < cmd_param.nargs:
return True
return False
def get_user_autocompletions(ctx, args, incomplete, cmd_param):
"""
:param ctx: context associated with the parsed command
:param args: full list of args
:param incomplete: the incomplete text to autocomplete
:param cmd_param: command definition
:return: all the possible user-specified completions for the param
"""
results = []
if isinstance(cmd_param.type, Choice):
# Choices don't support descriptions.
results = [(c, None)
for c in cmd_param.type.choices if str(c).startswith(incomplete)]
elif cmd_param.autocompletion is not None:
dynamic_completions = cmd_param.autocompletion(ctx=ctx,
args=args,
incomplete=incomplete)
results = [c if isinstance(c, tuple) else (c, None)
for c in dynamic_completions]
return results
def get_visible_commands_starting_with(ctx, starts_with):
"""
:param ctx: context associated with the parsed command
:starts_with: string that visible commands must start with.
:return: all visible (not hidden) commands that start with starts_with.
"""
for c in ctx.command.list_commands(ctx):
if c.startswith(starts_with):
command = ctx.command.get_command(ctx, c)
if not command.hidden:
yield command
def add_subcommand_completions(ctx, incomplete, completions_out):
# Add subcommand completions.
if isinstance(ctx.command, MultiCommand):
completions_out.extend(
[(c.name, c.get_short_help_str()) for c in get_visible_commands_starting_with(ctx, incomplete)])
# Walk up the context list and add any other completion possibilities from chained commands
while ctx.parent is not None:
ctx = ctx.parent
if isinstance(ctx.command, MultiCommand) and ctx.command.chain:
remaining_commands = [c for c in get_visible_commands_starting_with(ctx, incomplete)
if c.name not in ctx.protected_args]
completions_out.extend([(c.name, c.get_short_help_str()) for c in remaining_commands])
def get_choices(cli, prog_name, args, incomplete):
"""
:param cli: command definition
:param prog_name: the program that is running
:param args: full list of args
:param incomplete: the incomplete text to autocomplete
:return: all the possible completions for the incomplete
"""
all_args = copy.deepcopy(args)
ctx = resolve_ctx(cli, prog_name, args)
if ctx is None:
return []
# In newer versions of bash long opts with '='s are partitioned, but it's easier to parse
# without the '='
if start_of_option(incomplete) and WORDBREAK in incomplete:
partition_incomplete = incomplete.partition(WORDBREAK)
all_args.append(partition_incomplete[0])
incomplete = partition_incomplete[2]
elif incomplete == WORDBREAK:
incomplete = ''
completions = []
if start_of_option(incomplete):
# completions for partial options
for param in ctx.command.params:
if isinstance(param, Option) and not param.hidden:
param_opts = [param_opt for param_opt in param.opts +
param.secondary_opts if param_opt not in all_args or param.multiple]
completions.extend([(o, param.help) for o in param_opts if o.startswith(incomplete)])
return completions
# completion for option values from user supplied values
for param in ctx.command.params:
if is_incomplete_option(all_args, param):
return get_user_autocompletions(ctx, all_args, incomplete, param)
# completion for argument values from user supplied values
for param in ctx.command.params:
if is_incomplete_argument(ctx.params, param):
return get_user_autocompletions(ctx, all_args, incomplete, param)
add_subcommand_completions(ctx, incomplete, completions)
# Sort before returning so that proper ordering can be enforced in custom types.
return sorted(completions)
def do_complete(cli, prog_name, include_descriptions):
cwords = split_arg_string(os.environ['COMP_WORDS'])
cword = int(os.environ['COMP_CWORD'])
args = cwords[1:cword]
try:
incomplete = cwords[cword]
except IndexError:
incomplete = ''
for item in get_choices(cli, prog_name, args, incomplete):
echo(item[0])
if include_descriptions:
# ZSH has trouble dealing with empty array parameters when returned from commands, so use a well defined character '_' to indicate no description is present.
echo(item[1] if item[1] else '_')
return True
def bashcomplete(cli, prog_name, complete_var, complete_instr):
if complete_instr.startswith('source'):
shell = 'zsh' if complete_instr == 'source_zsh' else 'bash'
echo(get_completion_script(prog_name, complete_var, shell))
return True
elif complete_instr == 'complete' or complete_instr == 'complete_zsh':
return do_complete(cli, prog_name, complete_instr == 'complete_zsh')
return False

703
libs/click/_compat.py Normal file
View file

@ -0,0 +1,703 @@
import re
import io
import os
import sys
import codecs
from weakref import WeakKeyDictionary
PY2 = sys.version_info[0] == 2
CYGWIN = sys.platform.startswith('cygwin')
# Determine local App Engine environment, per Google's own suggestion
APP_ENGINE = ('APPENGINE_RUNTIME' in os.environ and
'Development/' in os.environ['SERVER_SOFTWARE'])
WIN = sys.platform.startswith('win') and not APP_ENGINE
DEFAULT_COLUMNS = 80
_ansi_re = re.compile(r'\033\[((?:\d|;)*)([a-zA-Z])')
def get_filesystem_encoding():
return sys.getfilesystemencoding() or sys.getdefaultencoding()
def _make_text_stream(stream, encoding, errors,
force_readable=False, force_writable=False):
if encoding is None:
encoding = get_best_encoding(stream)
if errors is None:
errors = 'replace'
return _NonClosingTextIOWrapper(stream, encoding, errors,
line_buffering=True,
force_readable=force_readable,
force_writable=force_writable)
def is_ascii_encoding(encoding):
"""Checks if a given encoding is ascii."""
try:
return codecs.lookup(encoding).name == 'ascii'
except LookupError:
return False
def get_best_encoding(stream):
"""Returns the default stream encoding if not found."""
rv = getattr(stream, 'encoding', None) or sys.getdefaultencoding()
if is_ascii_encoding(rv):
return 'utf-8'
return rv
class _NonClosingTextIOWrapper(io.TextIOWrapper):
def __init__(self, stream, encoding, errors,
force_readable=False, force_writable=False, **extra):
self._stream = stream = _FixupStream(stream, force_readable,
force_writable)
io.TextIOWrapper.__init__(self, stream, encoding, errors, **extra)
# The io module is a place where the Python 3 text behavior
# was forced upon Python 2, so we need to unbreak
# it to look like Python 2.
if PY2:
def write(self, x):
if isinstance(x, str) or is_bytes(x):
try:
self.flush()
except Exception:
pass
return self.buffer.write(str(x))
return io.TextIOWrapper.write(self, x)
def writelines(self, lines):
for line in lines:
self.write(line)
def __del__(self):
try:
self.detach()
except Exception:
pass
def isatty(self):
# https://bitbucket.org/pypy/pypy/issue/1803
return self._stream.isatty()
class _FixupStream(object):
"""The new io interface needs more from streams than streams
traditionally implement. As such, this fix-up code is necessary in
some circumstances.
The forcing of readable and writable flags are there because some tools
put badly patched objects on sys (one such offender are certain version
of jupyter notebook).
"""
def __init__(self, stream, force_readable=False, force_writable=False):
self._stream = stream
self._force_readable = force_readable
self._force_writable = force_writable
def __getattr__(self, name):
return getattr(self._stream, name)
def read1(self, size):
f = getattr(self._stream, 'read1', None)
if f is not None:
return f(size)
# We only dispatch to readline instead of read in Python 2 as we
# do not want cause problems with the different implementation
# of line buffering.
if PY2:
return self._stream.readline(size)
return self._stream.read(size)
def readable(self):
if self._force_readable:
return True
x = getattr(self._stream, 'readable', None)
if x is not None:
return x()
try:
self._stream.read(0)
except Exception:
return False
return True
def writable(self):
if self._force_writable:
return True
x = getattr(self._stream, 'writable', None)
if x is not None:
return x()
try:
self._stream.write('')
except Exception:
try:
self._stream.write(b'')
except Exception:
return False
return True
def seekable(self):
x = getattr(self._stream, 'seekable', None)
if x is not None:
return x()
try:
self._stream.seek(self._stream.tell())
except Exception:
return False
return True
if PY2:
text_type = unicode
bytes = str
raw_input = raw_input
string_types = (str, unicode)
int_types = (int, long)
iteritems = lambda x: x.iteritems()
range_type = xrange
def is_bytes(x):
return isinstance(x, (buffer, bytearray))
_identifier_re = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
# For Windows, we need to force stdout/stdin/stderr to binary if it's
# fetched for that. This obviously is not the most correct way to do
# it as it changes global state. Unfortunately, there does not seem to
# be a clear better way to do it as just reopening the file in binary
# mode does not change anything.
#
# An option would be to do what Python 3 does and to open the file as
# binary only, patch it back to the system, and then use a wrapper
# stream that converts newlines. It's not quite clear what's the
# correct option here.
#
# This code also lives in _winconsole for the fallback to the console
# emulation stream.
#
# There are also Windows environments where the `msvcrt` module is not
# available (which is why we use try-catch instead of the WIN variable
# here), such as the Google App Engine development server on Windows. In
# those cases there is just nothing we can do.
def set_binary_mode(f):
return f
try:
import msvcrt
except ImportError:
pass
else:
def set_binary_mode(f):
try:
fileno = f.fileno()
except Exception:
pass
else:
msvcrt.setmode(fileno, os.O_BINARY)
return f
try:
import fcntl
except ImportError:
pass
else:
def set_binary_mode(f):
try:
fileno = f.fileno()
except Exception:
pass
else:
flags = fcntl.fcntl(fileno, fcntl.F_GETFL)
fcntl.fcntl(fileno, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)
return f
def isidentifier(x):
return _identifier_re.search(x) is not None
def get_binary_stdin():
return set_binary_mode(sys.stdin)
def get_binary_stdout():
_wrap_std_stream('stdout')
return set_binary_mode(sys.stdout)
def get_binary_stderr():
_wrap_std_stream('stderr')
return set_binary_mode(sys.stderr)
def get_text_stdin(encoding=None, errors=None):
rv = _get_windows_console_stream(sys.stdin, encoding, errors)
if rv is not None:
return rv
return _make_text_stream(sys.stdin, encoding, errors,
force_readable=True)
def get_text_stdout(encoding=None, errors=None):
_wrap_std_stream('stdout')
rv = _get_windows_console_stream(sys.stdout, encoding, errors)
if rv is not None:
return rv
return _make_text_stream(sys.stdout, encoding, errors,
force_writable=True)
def get_text_stderr(encoding=None, errors=None):
_wrap_std_stream('stderr')
rv = _get_windows_console_stream(sys.stderr, encoding, errors)
if rv is not None:
return rv
return _make_text_stream(sys.stderr, encoding, errors,
force_writable=True)
def filename_to_ui(value):
if isinstance(value, bytes):
value = value.decode(get_filesystem_encoding(), 'replace')
return value
else:
import io
text_type = str
raw_input = input
string_types = (str,)
int_types = (int,)
range_type = range
isidentifier = lambda x: x.isidentifier()
iteritems = lambda x: iter(x.items())
def is_bytes(x):
return isinstance(x, (bytes, memoryview, bytearray))
def _is_binary_reader(stream, default=False):
try:
return isinstance(stream.read(0), bytes)
except Exception:
return default
# This happens in some cases where the stream was already
# closed. In this case, we assume the default.
def _is_binary_writer(stream, default=False):
try:
stream.write(b'')
except Exception:
try:
stream.write('')
return False
except Exception:
pass
return default
return True
def _find_binary_reader(stream):
# We need to figure out if the given stream is already binary.
# This can happen because the official docs recommend detaching
# the streams to get binary streams. Some code might do this, so
# we need to deal with this case explicitly.
if _is_binary_reader(stream, False):
return stream
buf = getattr(stream, 'buffer', None)
# Same situation here; this time we assume that the buffer is
# actually binary in case it's closed.
if buf is not None and _is_binary_reader(buf, True):
return buf
def _find_binary_writer(stream):
# We need to figure out if the given stream is already binary.
# This can happen because the official docs recommend detatching
# the streams to get binary streams. Some code might do this, so
# we need to deal with this case explicitly.
if _is_binary_writer(stream, False):
return stream
buf = getattr(stream, 'buffer', None)
# Same situation here; this time we assume that the buffer is
# actually binary in case it's closed.
if buf is not None and _is_binary_writer(buf, True):
return buf
def _stream_is_misconfigured(stream):
"""A stream is misconfigured if its encoding is ASCII."""
# If the stream does not have an encoding set, we assume it's set
# to ASCII. This appears to happen in certain unittest
# environments. It's not quite clear what the correct behavior is
# but this at least will force Click to recover somehow.
return is_ascii_encoding(getattr(stream, 'encoding', None) or 'ascii')
def _is_compatible_text_stream(stream, encoding, errors):
stream_encoding = getattr(stream, 'encoding', None)
stream_errors = getattr(stream, 'errors', None)
# Perfect match.
if stream_encoding == encoding and stream_errors == errors:
return True
# Otherwise, it's only a compatible stream if we did not ask for
# an encoding.
if encoding is None:
return stream_encoding is not None
return False
def _force_correct_text_reader(text_reader, encoding, errors,
force_readable=False):
if _is_binary_reader(text_reader, False):
binary_reader = text_reader
else:
# If there is no target encoding set, we need to verify that the
# reader is not actually misconfigured.
if encoding is None and not _stream_is_misconfigured(text_reader):
return text_reader
if _is_compatible_text_stream(text_reader, encoding, errors):
return text_reader
# If the reader has no encoding, we try to find the underlying
# binary reader for it. If that fails because the environment is
# misconfigured, we silently go with the same reader because this
# is too common to happen. In that case, mojibake is better than
# exceptions.
binary_reader = _find_binary_reader(text_reader)
if binary_reader is None:
return text_reader
# At this point, we default the errors to replace instead of strict
# because nobody handles those errors anyways and at this point
# we're so fundamentally fucked that nothing can repair it.
if errors is None:
errors = 'replace'
return _make_text_stream(binary_reader, encoding, errors,
force_readable=force_readable)
def _force_correct_text_writer(text_writer, encoding, errors,
force_writable=False):
if _is_binary_writer(text_writer, False):
binary_writer = text_writer
else:
# If there is no target encoding set, we need to verify that the
# writer is not actually misconfigured.
if encoding is None and not _stream_is_misconfigured(text_writer):
return text_writer
if _is_compatible_text_stream(text_writer, encoding, errors):
return text_writer
# If the writer has no encoding, we try to find the underlying
# binary writer for it. If that fails because the environment is
# misconfigured, we silently go with the same writer because this
# is too common to happen. In that case, mojibake is better than
# exceptions.
binary_writer = _find_binary_writer(text_writer)
if binary_writer is None:
return text_writer
# At this point, we default the errors to replace instead of strict
# because nobody handles those errors anyways and at this point
# we're so fundamentally fucked that nothing can repair it.
if errors is None:
errors = 'replace'
return _make_text_stream(binary_writer, encoding, errors,
force_writable=force_writable)
def get_binary_stdin():
reader = _find_binary_reader(sys.stdin)
if reader is None:
raise RuntimeError('Was not able to determine binary '
'stream for sys.stdin.')
return reader
def get_binary_stdout():
writer = _find_binary_writer(sys.stdout)
if writer is None:
raise RuntimeError('Was not able to determine binary '
'stream for sys.stdout.')
return writer
def get_binary_stderr():
writer = _find_binary_writer(sys.stderr)
if writer is None:
raise RuntimeError('Was not able to determine binary '
'stream for sys.stderr.')
return writer
def get_text_stdin(encoding=None, errors=None):
rv = _get_windows_console_stream(sys.stdin, encoding, errors)
if rv is not None:
return rv
return _force_correct_text_reader(sys.stdin, encoding, errors,
force_readable=True)
def get_text_stdout(encoding=None, errors=None):
rv = _get_windows_console_stream(sys.stdout, encoding, errors)
if rv is not None:
return rv
return _force_correct_text_writer(sys.stdout, encoding, errors,
force_writable=True)
def get_text_stderr(encoding=None, errors=None):
rv = _get_windows_console_stream(sys.stderr, encoding, errors)
if rv is not None:
return rv
return _force_correct_text_writer(sys.stderr, encoding, errors,
force_writable=True)
def filename_to_ui(value):
if isinstance(value, bytes):
value = value.decode(get_filesystem_encoding(), 'replace')
else:
value = value.encode('utf-8', 'surrogateescape') \
.decode('utf-8', 'replace')
return value
def get_streerror(e, default=None):
if hasattr(e, 'strerror'):
msg = e.strerror
else:
if default is not None:
msg = default
else:
msg = str(e)
if isinstance(msg, bytes):
msg = msg.decode('utf-8', 'replace')
return msg
def open_stream(filename, mode='r', encoding=None, errors='strict',
atomic=False):
# Standard streams first. These are simple because they don't need
# special handling for the atomic flag. It's entirely ignored.
if filename == '-':
if any(m in mode for m in ['w', 'a', 'x']):
if 'b' in mode:
return get_binary_stdout(), False
return get_text_stdout(encoding=encoding, errors=errors), False
if 'b' in mode:
return get_binary_stdin(), False
return get_text_stdin(encoding=encoding, errors=errors), False
# Non-atomic writes directly go out through the regular open functions.
if not atomic:
if encoding is None:
return open(filename, mode), True
return io.open(filename, mode, encoding=encoding, errors=errors), True
# Some usability stuff for atomic writes
if 'a' in mode:
raise ValueError(
'Appending to an existing file is not supported, because that '
'would involve an expensive `copy`-operation to a temporary '
'file. Open the file in normal `w`-mode and copy explicitly '
'if that\'s what you\'re after.'
)
if 'x' in mode:
raise ValueError('Use the `overwrite`-parameter instead.')
if 'w' not in mode:
raise ValueError('Atomic writes only make sense with `w`-mode.')
# Atomic writes are more complicated. They work by opening a file
# as a proxy in the same folder and then using the fdopen
# functionality to wrap it in a Python file. Then we wrap it in an
# atomic file that moves the file over on close.
import tempfile
fd, tmp_filename = tempfile.mkstemp(dir=os.path.dirname(filename),
prefix='.__atomic-write')
if encoding is not None:
f = io.open(fd, mode, encoding=encoding, errors=errors)
else:
f = os.fdopen(fd, mode)
return _AtomicFile(f, tmp_filename, os.path.realpath(filename)), True
# Used in a destructor call, needs extra protection from interpreter cleanup.
if hasattr(os, 'replace'):
_replace = os.replace
_can_replace = True
else:
_replace = os.rename
_can_replace = not WIN
class _AtomicFile(object):
def __init__(self, f, tmp_filename, real_filename):
self._f = f
self._tmp_filename = tmp_filename
self._real_filename = real_filename
self.closed = False
@property
def name(self):
return self._real_filename
def close(self, delete=False):
if self.closed:
return
self._f.close()
if not _can_replace:
try:
os.remove(self._real_filename)
except OSError:
pass
_replace(self._tmp_filename, self._real_filename)
self.closed = True
def __getattr__(self, name):
return getattr(self._f, name)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, tb):
self.close(delete=exc_type is not None)
def __repr__(self):
return repr(self._f)
auto_wrap_for_ansi = None
colorama = None
get_winterm_size = None
def strip_ansi(value):
return _ansi_re.sub('', value)
def should_strip_ansi(stream=None, color=None):
if color is None:
if stream is None:
stream = sys.stdin
return not isatty(stream)
return not color
# If we're on Windows, we provide transparent integration through
# colorama. This will make ANSI colors through the echo function
# work automatically.
if WIN:
# Windows has a smaller terminal
DEFAULT_COLUMNS = 79
from ._winconsole import _get_windows_console_stream, _wrap_std_stream
def _get_argv_encoding():
import locale
return locale.getpreferredencoding()
if PY2:
def raw_input(prompt=''):
sys.stderr.flush()
if prompt:
stdout = _default_text_stdout()
stdout.write(prompt)
stdin = _default_text_stdin()
return stdin.readline().rstrip('\r\n')
try:
import colorama
except ImportError:
pass
else:
_ansi_stream_wrappers = WeakKeyDictionary()
def auto_wrap_for_ansi(stream, color=None):
"""This function wraps a stream so that calls through colorama
are issued to the win32 console API to recolor on demand. It
also ensures to reset the colors if a write call is interrupted
to not destroy the console afterwards.
"""
try:
cached = _ansi_stream_wrappers.get(stream)
except Exception:
cached = None
if cached is not None:
return cached
strip = should_strip_ansi(stream, color)
ansi_wrapper = colorama.AnsiToWin32(stream, strip=strip)
rv = ansi_wrapper.stream
_write = rv.write
def _safe_write(s):
try:
return _write(s)
except:
ansi_wrapper.reset_all()
raise
rv.write = _safe_write
try:
_ansi_stream_wrappers[stream] = rv
except Exception:
pass
return rv
def get_winterm_size():
win = colorama.win32.GetConsoleScreenBufferInfo(
colorama.win32.STDOUT).srWindow
return win.Right - win.Left, win.Bottom - win.Top
else:
def _get_argv_encoding():
return getattr(sys.stdin, 'encoding', None) or get_filesystem_encoding()
_get_windows_console_stream = lambda *x: None
_wrap_std_stream = lambda *x: None
def term_len(x):
return len(strip_ansi(x))
def isatty(stream):
try:
return stream.isatty()
except Exception:
return False
def _make_cached_stream_func(src_func, wrapper_func):
cache = WeakKeyDictionary()
def func():
stream = src_func()
try:
rv = cache.get(stream)
except Exception:
rv = None
if rv is not None:
return rv
rv = wrapper_func()
try:
stream = src_func() # In case wrapper_func() modified the stream
cache[stream] = rv
except Exception:
pass
return rv
return func
_default_text_stdin = _make_cached_stream_func(
lambda: sys.stdin, get_text_stdin)
_default_text_stdout = _make_cached_stream_func(
lambda: sys.stdout, get_text_stdout)
_default_text_stderr = _make_cached_stream_func(
lambda: sys.stderr, get_text_stderr)
binary_streams = {
'stdin': get_binary_stdin,
'stdout': get_binary_stdout,
'stderr': get_binary_stderr,
}
text_streams = {
'stdin': get_text_stdin,
'stdout': get_text_stdout,
'stderr': get_text_stderr,
}

621
libs/click/_termui_impl.py Normal file
View file

@ -0,0 +1,621 @@
# -*- coding: utf-8 -*-
"""
click._termui_impl
~~~~~~~~~~~~~~~~~~
This module contains implementations for the termui module. To keep the
import time of Click down, some infrequently used functionality is
placed in this module and only imported as needed.
:copyright: © 2014 by the Pallets team.
:license: BSD, see LICENSE.rst for more details.
"""
import os
import sys
import time
import math
import contextlib
from ._compat import _default_text_stdout, range_type, PY2, isatty, \
open_stream, strip_ansi, term_len, get_best_encoding, WIN, int_types, \
CYGWIN
from .utils import echo
from .exceptions import ClickException
if os.name == 'nt':
BEFORE_BAR = '\r'
AFTER_BAR = '\n'
else:
BEFORE_BAR = '\r\033[?25l'
AFTER_BAR = '\033[?25h\n'
def _length_hint(obj):
"""Returns the length hint of an object."""
try:
return len(obj)
except (AttributeError, TypeError):
try:
get_hint = type(obj).__length_hint__
except AttributeError:
return None
try:
hint = get_hint(obj)
except TypeError:
return None
if hint is NotImplemented or \
not isinstance(hint, int_types) or \
hint < 0:
return None
return hint
class ProgressBar(object):
def __init__(self, iterable, length=None, fill_char='#', empty_char=' ',
bar_template='%(bar)s', info_sep=' ', show_eta=True,
show_percent=None, show_pos=False, item_show_func=None,
label=None, file=None, color=None, width=30):
self.fill_char = fill_char
self.empty_char = empty_char
self.bar_template = bar_template
self.info_sep = info_sep
self.show_eta = show_eta
self.show_percent = show_percent
self.show_pos = show_pos
self.item_show_func = item_show_func
self.label = label or ''
if file is None:
file = _default_text_stdout()
self.file = file
self.color = color
self.width = width
self.autowidth = width == 0
if length is None:
length = _length_hint(iterable)
if iterable is None:
if length is None:
raise TypeError('iterable or length is required')
iterable = range_type(length)
self.iter = iter(iterable)
self.length = length
self.length_known = length is not None
self.pos = 0
self.avg = []
self.start = self.last_eta = time.time()
self.eta_known = False
self.finished = False
self.max_width = None
self.entered = False
self.current_item = None
self.is_hidden = not isatty(self.file)
self._last_line = None
self.short_limit = 0.5
def __enter__(self):
self.entered = True
self.render_progress()
return self
def __exit__(self, exc_type, exc_value, tb):
self.render_finish()
def __iter__(self):
if not self.entered:
raise RuntimeError('You need to use progress bars in a with block.')
self.render_progress()
return self.generator()
def is_fast(self):
return time.time() - self.start <= self.short_limit
def render_finish(self):
if self.is_hidden or self.is_fast():
return
self.file.write(AFTER_BAR)
self.file.flush()
@property
def pct(self):
if self.finished:
return 1.0
return min(self.pos / (float(self.length) or 1), 1.0)
@property
def time_per_iteration(self):
if not self.avg:
return 0.0
return sum(self.avg) / float(len(self.avg))
@property
def eta(self):
if self.length_known and not self.finished:
return self.time_per_iteration * (self.length - self.pos)
return 0.0
def format_eta(self):
if self.eta_known:
t = int(self.eta)
seconds = t % 60
t //= 60
minutes = t % 60
t //= 60
hours = t % 24
t //= 24
if t > 0:
days = t
return '%dd %02d:%02d:%02d' % (days, hours, minutes, seconds)
else:
return '%02d:%02d:%02d' % (hours, minutes, seconds)
return ''
def format_pos(self):
pos = str(self.pos)
if self.length_known:
pos += '/%s' % self.length
return pos
def format_pct(self):
return ('% 4d%%' % int(self.pct * 100))[1:]
def format_bar(self):
if self.length_known:
bar_length = int(self.pct * self.width)
bar = self.fill_char * bar_length
bar += self.empty_char * (self.width - bar_length)
elif self.finished:
bar = self.fill_char * self.width
else:
bar = list(self.empty_char * (self.width or 1))
if self.time_per_iteration != 0:
bar[int((math.cos(self.pos * self.time_per_iteration)
/ 2.0 + 0.5) * self.width)] = self.fill_char
bar = ''.join(bar)
return bar
def format_progress_line(self):
show_percent = self.show_percent
info_bits = []
if self.length_known and show_percent is None:
show_percent = not self.show_pos
if self.show_pos:
info_bits.append(self.format_pos())
if show_percent:
info_bits.append(self.format_pct())
if self.show_eta and self.eta_known and not self.finished:
info_bits.append(self.format_eta())
if self.item_show_func is not None:
item_info = self.item_show_func(self.current_item)
if item_info is not None:
info_bits.append(item_info)
return (self.bar_template % {
'label': self.label,
'bar': self.format_bar(),
'info': self.info_sep.join(info_bits)
}).rstrip()
def render_progress(self):
from .termui import get_terminal_size
if self.is_hidden:
return
buf = []
# Update width in case the terminal has been resized
if self.autowidth:
old_width = self.width
self.width = 0
clutter_length = term_len(self.format_progress_line())
new_width = max(0, get_terminal_size()[0] - clutter_length)
if new_width < old_width:
buf.append(BEFORE_BAR)
buf.append(' ' * self.max_width)
self.max_width = new_width
self.width = new_width
clear_width = self.width
if self.max_width is not None:
clear_width = self.max_width
buf.append(BEFORE_BAR)
line = self.format_progress_line()
line_len = term_len(line)
if self.max_width is None or self.max_width < line_len:
self.max_width = line_len
buf.append(line)
buf.append(' ' * (clear_width - line_len))
line = ''.join(buf)
# Render the line only if it changed.
if line != self._last_line and not self.is_fast():
self._last_line = line
echo(line, file=self.file, color=self.color, nl=False)
self.file.flush()
def make_step(self, n_steps):
self.pos += n_steps
if self.length_known and self.pos >= self.length:
self.finished = True
if (time.time() - self.last_eta) < 1.0:
return
self.last_eta = time.time()
# self.avg is a rolling list of length <= 7 of steps where steps are
# defined as time elapsed divided by the total progress through
# self.length.
if self.pos:
step = (time.time() - self.start) / self.pos
else:
step = time.time() - self.start
self.avg = self.avg[-6:] + [step]
self.eta_known = self.length_known
def update(self, n_steps):
self.make_step(n_steps)
self.render_progress()
def finish(self):
self.eta_known = 0
self.current_item = None
self.finished = True
def generator(self):
"""
Returns a generator which yields the items added to the bar during
construction, and updates the progress bar *after* the yielded block
returns.
"""
if not self.entered:
raise RuntimeError('You need to use progress bars in a with block.')
if self.is_hidden:
for rv in self.iter:
yield rv
else:
for rv in self.iter:
self.current_item = rv
yield rv
self.update(1)
self.finish()
self.render_progress()
def pager(generator, color=None):
"""Decide what method to use for paging through text."""
stdout = _default_text_stdout()
if not isatty(sys.stdin) or not isatty(stdout):
return _nullpager(stdout, generator, color)
pager_cmd = (os.environ.get('PAGER', None) or '').strip()
if pager_cmd:
if WIN:
return _tempfilepager(generator, pager_cmd, color)
return _pipepager(generator, pager_cmd, color)
if os.environ.get('TERM') in ('dumb', 'emacs'):
return _nullpager(stdout, generator, color)
if WIN or sys.platform.startswith('os2'):
return _tempfilepager(generator, 'more <', color)
if hasattr(os, 'system') and os.system('(less) 2>/dev/null') == 0:
return _pipepager(generator, 'less', color)
import tempfile
fd, filename = tempfile.mkstemp()
os.close(fd)
try:
if hasattr(os, 'system') and os.system('more "%s"' % filename) == 0:
return _pipepager(generator, 'more', color)
return _nullpager(stdout, generator, color)
finally:
os.unlink(filename)
def _pipepager(generator, cmd, color):
"""Page through text by feeding it to another program. Invoking a
pager through this might support colors.
"""
import subprocess
env = dict(os.environ)
# If we're piping to less we might support colors under the
# condition that
cmd_detail = cmd.rsplit('/', 1)[-1].split()
if color is None and cmd_detail[0] == 'less':
less_flags = os.environ.get('LESS', '') + ' '.join(cmd_detail[1:])
if not less_flags:
env['LESS'] = '-R'
color = True
elif 'r' in less_flags or 'R' in less_flags:
color = True
c = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE,
env=env)
encoding = get_best_encoding(c.stdin)
try:
for text in generator:
if not color:
text = strip_ansi(text)
c.stdin.write(text.encode(encoding, 'replace'))
except (IOError, KeyboardInterrupt):
pass
else:
c.stdin.close()
# Less doesn't respect ^C, but catches it for its own UI purposes (aborting
# search or other commands inside less).
#
# That means when the user hits ^C, the parent process (click) terminates,
# but less is still alive, paging the output and messing up the terminal.
#
# If the user wants to make the pager exit on ^C, they should set
# `LESS='-K'`. It's not our decision to make.
while True:
try:
c.wait()
except KeyboardInterrupt:
pass
else:
break
def _tempfilepager(generator, cmd, color):
"""Page through text by invoking a program on a temporary file."""
import tempfile
filename = tempfile.mktemp()
# TODO: This never terminates if the passed generator never terminates.
text = "".join(generator)
if not color:
text = strip_ansi(text)
encoding = get_best_encoding(sys.stdout)
with open_stream(filename, 'wb')[0] as f:
f.write(text.encode(encoding))
try:
os.system(cmd + ' "' + filename + '"')
finally:
os.unlink(filename)
def _nullpager(stream, generator, color):
"""Simply print unformatted text. This is the ultimate fallback."""
for text in generator:
if not color:
text = strip_ansi(text)
stream.write(text)
class Editor(object):
def __init__(self, editor=None, env=None, require_save=True,
extension='.txt'):
self.editor = editor
self.env = env
self.require_save = require_save
self.extension = extension
def get_editor(self):
if self.editor is not None:
return self.editor
for key in 'VISUAL', 'EDITOR':
rv = os.environ.get(key)
if rv:
return rv
if WIN:
return 'notepad'
for editor in 'vim', 'nano':
if os.system('which %s >/dev/null 2>&1' % editor) == 0:
return editor
return 'vi'
def edit_file(self, filename):
import subprocess
editor = self.get_editor()
if self.env:
environ = os.environ.copy()
environ.update(self.env)
else:
environ = None
try:
c = subprocess.Popen('%s "%s"' % (editor, filename),
env=environ, shell=True)
exit_code = c.wait()
if exit_code != 0:
raise ClickException('%s: Editing failed!' % editor)
except OSError as e:
raise ClickException('%s: Editing failed: %s' % (editor, e))
def edit(self, text):
import tempfile
text = text or ''
if text and not text.endswith('\n'):
text += '\n'
fd, name = tempfile.mkstemp(prefix='editor-', suffix=self.extension)
try:
if WIN:
encoding = 'utf-8-sig'
text = text.replace('\n', '\r\n')
else:
encoding = 'utf-8'
text = text.encode(encoding)
f = os.fdopen(fd, 'wb')
f.write(text)
f.close()
timestamp = os.path.getmtime(name)
self.edit_file(name)
if self.require_save \
and os.path.getmtime(name) == timestamp:
return None
f = open(name, 'rb')
try:
rv = f.read()
finally:
f.close()
return rv.decode('utf-8-sig').replace('\r\n', '\n')
finally:
os.unlink(name)
def open_url(url, wait=False, locate=False):
import subprocess
def _unquote_file(url):
try:
import urllib
except ImportError:
import urllib
if url.startswith('file://'):
url = urllib.unquote(url[7:])
return url
if sys.platform == 'darwin':
args = ['open']
if wait:
args.append('-W')
if locate:
args.append('-R')
args.append(_unquote_file(url))
null = open('/dev/null', 'w')
try:
return subprocess.Popen(args, stderr=null).wait()
finally:
null.close()
elif WIN:
if locate:
url = _unquote_file(url)
args = 'explorer /select,"%s"' % _unquote_file(
url.replace('"', ''))
else:
args = 'start %s "" "%s"' % (
wait and '/WAIT' or '', url.replace('"', ''))
return os.system(args)
elif CYGWIN:
if locate:
url = _unquote_file(url)
args = 'cygstart "%s"' % (os.path.dirname(url).replace('"', ''))
else:
args = 'cygstart %s "%s"' % (
wait and '-w' or '', url.replace('"', ''))
return os.system(args)
try:
if locate:
url = os.path.dirname(_unquote_file(url)) or '.'
else:
url = _unquote_file(url)
c = subprocess.Popen(['xdg-open', url])
if wait:
return c.wait()
return 0
except OSError:
if url.startswith(('http://', 'https://')) and not locate and not wait:
import webbrowser
webbrowser.open(url)
return 0
return 1
def _translate_ch_to_exc(ch):
if ch == u'\x03':
raise KeyboardInterrupt()
if ch == u'\x04' and not WIN: # Unix-like, Ctrl+D
raise EOFError()
if ch == u'\x1a' and WIN: # Windows, Ctrl+Z
raise EOFError()
if WIN:
import msvcrt
@contextlib.contextmanager
def raw_terminal():
yield
def getchar(echo):
# The function `getch` will return a bytes object corresponding to
# the pressed character. Since Windows 10 build 1803, it will also
# return \x00 when called a second time after pressing a regular key.
#
# `getwch` does not share this probably-bugged behavior. Moreover, it
# returns a Unicode object by default, which is what we want.
#
# Either of these functions will return \x00 or \xe0 to indicate
# a special key, and you need to call the same function again to get
# the "rest" of the code. The fun part is that \u00e0 is
# "latin small letter a with grave", so if you type that on a French
# keyboard, you _also_ get a \xe0.
# E.g., consider the Up arrow. This returns \xe0 and then \x48. The
# resulting Unicode string reads as "a with grave" + "capital H".
# This is indistinguishable from when the user actually types
# "a with grave" and then "capital H".
#
# When \xe0 is returned, we assume it's part of a special-key sequence
# and call `getwch` again, but that means that when the user types
# the \u00e0 character, `getchar` doesn't return until a second
# character is typed.
# The alternative is returning immediately, but that would mess up
# cross-platform handling of arrow keys and others that start with
# \xe0. Another option is using `getch`, but then we can't reliably
# read non-ASCII characters, because return values of `getch` are
# limited to the current 8-bit codepage.
#
# Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
# is doing the right thing in more situations than with `getch`.
if echo:
func = msvcrt.getwche
else:
func = msvcrt.getwch
rv = func()
if rv in (u'\x00', u'\xe0'):
# \x00 and \xe0 are control characters that indicate special key,
# see above.
rv += func()
_translate_ch_to_exc(rv)
return rv
else:
import tty
import termios
@contextlib.contextmanager
def raw_terminal():
if not isatty(sys.stdin):
f = open('/dev/tty')
fd = f.fileno()
else:
fd = sys.stdin.fileno()
f = None
try:
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(fd)
yield fd
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
sys.stdout.flush()
if f is not None:
f.close()
except termios.error:
pass
def getchar(echo):
with raw_terminal() as fd:
ch = os.read(fd, 32)
ch = ch.decode(get_best_encoding(sys.stdin), 'replace')
if echo and isatty(sys.stdout):
sys.stdout.write(ch)
_translate_ch_to_exc(ch)
return ch

38
libs/click/_textwrap.py Normal file
View file

@ -0,0 +1,38 @@
import textwrap
from contextlib import contextmanager
class TextWrapper(textwrap.TextWrapper):
def _handle_long_word(self, reversed_chunks, cur_line, cur_len, width):
space_left = max(width - cur_len, 1)
if self.break_long_words:
last = reversed_chunks[-1]
cut = last[:space_left]
res = last[space_left:]
cur_line.append(cut)
reversed_chunks[-1] = res
elif not cur_line:
cur_line.append(reversed_chunks.pop())
@contextmanager
def extra_indent(self, indent):
old_initial_indent = self.initial_indent
old_subsequent_indent = self.subsequent_indent
self.initial_indent += indent
self.subsequent_indent += indent
try:
yield
finally:
self.initial_indent = old_initial_indent
self.subsequent_indent = old_subsequent_indent
def indent_only(self, text):
rv = []
for idx, line in enumerate(text.splitlines()):
indent = self.initial_indent
if idx > 0:
indent = self.subsequent_indent
rv.append(indent + line)
return '\n'.join(rv)

125
libs/click/_unicodefun.py Normal file
View file

@ -0,0 +1,125 @@
import os
import sys
import codecs
from ._compat import PY2
# If someone wants to vendor click, we want to ensure the
# correct package is discovered. Ideally we could use a
# relative import here but unfortunately Python does not
# support that.
click = sys.modules[__name__.rsplit('.', 1)[0]]
def _find_unicode_literals_frame():
import __future__
if not hasattr(sys, '_getframe'): # not all Python implementations have it
return 0
frm = sys._getframe(1)
idx = 1
while frm is not None:
if frm.f_globals.get('__name__', '').startswith('click.'):
frm = frm.f_back
idx += 1
elif frm.f_code.co_flags & __future__.unicode_literals.compiler_flag:
return idx
else:
break
return 0
def _check_for_unicode_literals():
if not __debug__:
return
if not PY2 or click.disable_unicode_literals_warning:
return
bad_frame = _find_unicode_literals_frame()
if bad_frame <= 0:
return
from warnings import warn
warn(Warning('Click detected the use of the unicode_literals '
'__future__ import. This is heavily discouraged '
'because it can introduce subtle bugs in your '
'code. You should instead use explicit u"" literals '
'for your unicode strings. For more information see '
'https://click.palletsprojects.com/python3/'),
stacklevel=bad_frame)
def _verify_python3_env():
"""Ensures that the environment is good for unicode on Python 3."""
if PY2:
return
try:
import locale
fs_enc = codecs.lookup(locale.getpreferredencoding()).name
except Exception:
fs_enc = 'ascii'
if fs_enc != 'ascii':
return
extra = ''
if os.name == 'posix':
import subprocess
try:
rv = subprocess.Popen(['locale', '-a'], stdout=subprocess.PIPE,
stderr=subprocess.PIPE).communicate()[0]
except OSError:
rv = b''
good_locales = set()
has_c_utf8 = False
# Make sure we're operating on text here.
if isinstance(rv, bytes):
rv = rv.decode('ascii', 'replace')
for line in rv.splitlines():
locale = line.strip()
if locale.lower().endswith(('.utf-8', '.utf8')):
good_locales.add(locale)
if locale.lower() in ('c.utf8', 'c.utf-8'):
has_c_utf8 = True
extra += '\n\n'
if not good_locales:
extra += (
'Additional information: on this system no suitable UTF-8\n'
'locales were discovered. This most likely requires resolving\n'
'by reconfiguring the locale system.'
)
elif has_c_utf8:
extra += (
'This system supports the C.UTF-8 locale which is recommended.\n'
'You might be able to resolve your issue by exporting the\n'
'following environment variables:\n\n'
' export LC_ALL=C.UTF-8\n'
' export LANG=C.UTF-8'
)
else:
extra += (
'This system lists a couple of UTF-8 supporting locales that\n'
'you can pick from. The following suitable locales were\n'
'discovered: %s'
) % ', '.join(sorted(good_locales))
bad_locale = None
for locale in os.environ.get('LC_ALL'), os.environ.get('LANG'):
if locale and locale.lower().endswith(('.utf-8', '.utf8')):
bad_locale = locale
if locale is not None:
break
if bad_locale is not None:
extra += (
'\n\nClick discovered that you exported a UTF-8 locale\n'
'but the locale system could not pick up from it because\n'
'it does not exist. The exported locale is "%s" but it\n'
'is not supported'
) % bad_locale
raise RuntimeError(
'Click will abort further execution because Python 3 was'
' configured to use ASCII as encoding for the environment.'
' Consult https://click.palletsprojects.com/en/7.x/python3/ for'
' mitigation steps.' + extra
)

307
libs/click/_winconsole.py Normal file
View file

@ -0,0 +1,307 @@
# -*- coding: utf-8 -*-
# This module is based on the excellent work by Adam Bartoš who
# provided a lot of what went into the implementation here in
# the discussion to issue1602 in the Python bug tracker.
#
# There are some general differences in regards to how this works
# compared to the original patches as we do not need to patch
# the entire interpreter but just work in our little world of
# echo and prmopt.
import io
import os
import sys
import zlib
import time
import ctypes
import msvcrt
from ._compat import _NonClosingTextIOWrapper, text_type, PY2
from ctypes import byref, POINTER, c_int, c_char, c_char_p, \
c_void_p, py_object, c_ssize_t, c_ulong, windll, WINFUNCTYPE
try:
from ctypes import pythonapi
PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
PyBuffer_Release = pythonapi.PyBuffer_Release
except ImportError:
pythonapi = None
from ctypes.wintypes import LPWSTR, LPCWSTR
c_ssize_p = POINTER(c_ssize_t)
kernel32 = windll.kernel32
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetLastError = kernel32.GetLastError
GetCommandLineW = WINFUNCTYPE(LPWSTR)(
('GetCommandLineW', windll.kernel32))
CommandLineToArgvW = WINFUNCTYPE(
POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
('CommandLineToArgvW', windll.shell32))
STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)
PyBUF_SIMPLE = 0
PyBUF_WRITABLE = 1
ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995
STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2
EOF = b'\x1a'
MAX_BYTES_WRITTEN = 32767
class Py_buffer(ctypes.Structure):
_fields_ = [
('buf', c_void_p),
('obj', py_object),
('len', c_ssize_t),
('itemsize', c_ssize_t),
('readonly', c_int),
('ndim', c_int),
('format', c_char_p),
('shape', c_ssize_p),
('strides', c_ssize_p),
('suboffsets', c_ssize_p),
('internal', c_void_p)
]
if PY2:
_fields_.insert(-1, ('smalltable', c_ssize_t * 2))
# On PyPy we cannot get buffers so our ability to operate here is
# serverly limited.
if pythonapi is None:
get_buffer = None
else:
def get_buffer(obj, writable=False):
buf = Py_buffer()
flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
PyObject_GetBuffer(py_object(obj), byref(buf), flags)
try:
buffer_type = c_char * buf.len
return buffer_type.from_address(buf.buf)
finally:
PyBuffer_Release(byref(buf))
class _WindowsConsoleRawIOBase(io.RawIOBase):
def __init__(self, handle):
self.handle = handle
def isatty(self):
io.RawIOBase.isatty(self)
return True
class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
def readable(self):
return True
def readinto(self, b):
bytes_to_be_read = len(b)
if not bytes_to_be_read:
return 0
elif bytes_to_be_read % 2:
raise ValueError('cannot read odd number of bytes from '
'UTF-16-LE encoded console')
buffer = get_buffer(b, writable=True)
code_units_to_be_read = bytes_to_be_read // 2
code_units_read = c_ulong()
rv = ReadConsoleW(self.handle, buffer, code_units_to_be_read,
byref(code_units_read), None)
if GetLastError() == ERROR_OPERATION_ABORTED:
# wait for KeyboardInterrupt
time.sleep(0.1)
if not rv:
raise OSError('Windows error: %s' % GetLastError())
if buffer[0] == EOF:
return 0
return 2 * code_units_read.value
class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
def writable(self):
return True
@staticmethod
def _get_error_message(errno):
if errno == ERROR_SUCCESS:
return 'ERROR_SUCCESS'
elif errno == ERROR_NOT_ENOUGH_MEMORY:
return 'ERROR_NOT_ENOUGH_MEMORY'
return 'Windows error %s' % errno
def write(self, b):
bytes_to_be_written = len(b)
buf = get_buffer(b)
code_units_to_be_written = min(bytes_to_be_written,
MAX_BYTES_WRITTEN) // 2
code_units_written = c_ulong()
WriteConsoleW(self.handle, buf, code_units_to_be_written,
byref(code_units_written), None)
bytes_written = 2 * code_units_written.value
if bytes_written == 0 and bytes_to_be_written > 0:
raise OSError(self._get_error_message(GetLastError()))
return bytes_written
class ConsoleStream(object):
def __init__(self, text_stream, byte_stream):
self._text_stream = text_stream
self.buffer = byte_stream
@property
def name(self):
return self.buffer.name
def write(self, x):
if isinstance(x, text_type):
return self._text_stream.write(x)
try:
self.flush()
except Exception:
pass
return self.buffer.write(x)
def writelines(self, lines):
for line in lines:
self.write(line)
def __getattr__(self, name):
return getattr(self._text_stream, name)
def isatty(self):
return self.buffer.isatty()
def __repr__(self):
return '<ConsoleStream name=%r encoding=%r>' % (
self.name,
self.encoding,
)
class WindowsChunkedWriter(object):
"""
Wraps a stream (such as stdout), acting as a transparent proxy for all
attribute access apart from method 'write()' which we wrap to write in
limited chunks due to a Windows limitation on binary console streams.
"""
def __init__(self, wrapped):
# double-underscore everything to prevent clashes with names of
# attributes on the wrapped stream object.
self.__wrapped = wrapped
def __getattr__(self, name):
return getattr(self.__wrapped, name)
def write(self, text):
total_to_write = len(text)
written = 0
while written < total_to_write:
to_write = min(total_to_write - written, MAX_BYTES_WRITTEN)
self.__wrapped.write(text[written:written+to_write])
written += to_write
_wrapped_std_streams = set()
def _wrap_std_stream(name):
# Python 2 & Windows 7 and below
if PY2 and sys.getwindowsversion()[:2] <= (6, 1) and name not in _wrapped_std_streams:
setattr(sys, name, WindowsChunkedWriter(getattr(sys, name)))
_wrapped_std_streams.add(name)
def _get_text_stdin(buffer_stream):
text_stream = _NonClosingTextIOWrapper(
io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)),
'utf-16-le', 'strict', line_buffering=True)
return ConsoleStream(text_stream, buffer_stream)
def _get_text_stdout(buffer_stream):
text_stream = _NonClosingTextIOWrapper(
io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)),
'utf-16-le', 'strict', line_buffering=True)
return ConsoleStream(text_stream, buffer_stream)
def _get_text_stderr(buffer_stream):
text_stream = _NonClosingTextIOWrapper(
io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)),
'utf-16-le', 'strict', line_buffering=True)
return ConsoleStream(text_stream, buffer_stream)
if PY2:
def _hash_py_argv():
return zlib.crc32('\x00'.join(sys.argv[1:]))
_initial_argv_hash = _hash_py_argv()
def _get_windows_argv():
argc = c_int(0)
argv_unicode = CommandLineToArgvW(GetCommandLineW(), byref(argc))
argv = [argv_unicode[i] for i in range(0, argc.value)]
if not hasattr(sys, 'frozen'):
argv = argv[1:]
while len(argv) > 0:
arg = argv[0]
if not arg.startswith('-') or arg == '-':
break
argv = argv[1:]
if arg.startswith(('-c', '-m')):
break
return argv[1:]
_stream_factories = {
0: _get_text_stdin,
1: _get_text_stdout,
2: _get_text_stderr,
}
def _get_windows_console_stream(f, encoding, errors):
if get_buffer is not None and \
encoding in ('utf-16-le', None) \
and errors in ('strict', None) and \
hasattr(f, 'isatty') and f.isatty():
func = _stream_factories.get(f.fileno())
if func is not None:
if not PY2:
f = getattr(f, 'buffer', None)
if f is None:
return None
else:
# If we are on Python 2 we need to set the stream that we
# deal with to binary mode as otherwise the exercise if a
# bit moot. The same problems apply as for
# get_binary_stdin and friends from _compat.
msvcrt.setmode(f.fileno(), os.O_BINARY)
return func(f)

1856
libs/click/core.py Normal file

File diff suppressed because it is too large Load diff

311
libs/click/decorators.py Normal file
View file

@ -0,0 +1,311 @@
import sys
import inspect
from functools import update_wrapper
from ._compat import iteritems
from ._unicodefun import _check_for_unicode_literals
from .utils import echo
from .globals import get_current_context
def pass_context(f):
"""Marks a callback as wanting to receive the current context
object as first argument.
"""
def new_func(*args, **kwargs):
return f(get_current_context(), *args, **kwargs)
return update_wrapper(new_func, f)
def pass_obj(f):
"""Similar to :func:`pass_context`, but only pass the object on the
context onwards (:attr:`Context.obj`). This is useful if that object
represents the state of a nested system.
"""
def new_func(*args, **kwargs):
return f(get_current_context().obj, *args, **kwargs)
return update_wrapper(new_func, f)
def make_pass_decorator(object_type, ensure=False):
"""Given an object type this creates a decorator that will work
similar to :func:`pass_obj` but instead of passing the object of the
current context, it will find the innermost context of type
:func:`object_type`.
This generates a decorator that works roughly like this::
from functools import update_wrapper
def decorator(f):
@pass_context
def new_func(ctx, *args, **kwargs):
obj = ctx.find_object(object_type)
return ctx.invoke(f, obj, *args, **kwargs)
return update_wrapper(new_func, f)
return decorator
:param object_type: the type of the object to pass.
:param ensure: if set to `True`, a new object will be created and
remembered on the context if it's not there yet.
"""
def decorator(f):
def new_func(*args, **kwargs):
ctx = get_current_context()
if ensure:
obj = ctx.ensure_object(object_type)
else:
obj = ctx.find_object(object_type)
if obj is None:
raise RuntimeError('Managed to invoke callback without a '
'context object of type %r existing'
% object_type.__name__)
return ctx.invoke(f, obj, *args, **kwargs)
return update_wrapper(new_func, f)
return decorator
def _make_command(f, name, attrs, cls):
if isinstance(f, Command):
raise TypeError('Attempted to convert a callback into a '
'command twice.')
try:
params = f.__click_params__
params.reverse()
del f.__click_params__
except AttributeError:
params = []
help = attrs.get('help')
if help is None:
help = inspect.getdoc(f)
if isinstance(help, bytes):
help = help.decode('utf-8')
else:
help = inspect.cleandoc(help)
attrs['help'] = help
_check_for_unicode_literals()
return cls(name=name or f.__name__.lower().replace('_', '-'),
callback=f, params=params, **attrs)
def command(name=None, cls=None, **attrs):
r"""Creates a new :class:`Command` and uses the decorated function as
callback. This will also automatically attach all decorated
:func:`option`\s and :func:`argument`\s as parameters to the command.
The name of the command defaults to the name of the function. If you
want to change that, you can pass the intended name as the first
argument.
All keyword arguments are forwarded to the underlying command class.
Once decorated the function turns into a :class:`Command` instance
that can be invoked as a command line utility or be attached to a
command :class:`Group`.
:param name: the name of the command. This defaults to the function
name with underscores replaced by dashes.
:param cls: the command class to instantiate. This defaults to
:class:`Command`.
"""
if cls is None:
cls = Command
def decorator(f):
cmd = _make_command(f, name, attrs, cls)
cmd.__doc__ = f.__doc__
return cmd
return decorator
def group(name=None, **attrs):
"""Creates a new :class:`Group` with a function as callback. This
works otherwise the same as :func:`command` just that the `cls`
parameter is set to :class:`Group`.
"""
attrs.setdefault('cls', Group)
return command(name, **attrs)
def _param_memo(f, param):
if isinstance(f, Command):
f.params.append(param)
else:
if not hasattr(f, '__click_params__'):
f.__click_params__ = []
f.__click_params__.append(param)
def argument(*param_decls, **attrs):
"""Attaches an argument to the command. All positional arguments are
passed as parameter declarations to :class:`Argument`; all keyword
arguments are forwarded unchanged (except ``cls``).
This is equivalent to creating an :class:`Argument` instance manually
and attaching it to the :attr:`Command.params` list.
:param cls: the argument class to instantiate. This defaults to
:class:`Argument`.
"""
def decorator(f):
ArgumentClass = attrs.pop('cls', Argument)
_param_memo(f, ArgumentClass(param_decls, **attrs))
return f
return decorator
def option(*param_decls, **attrs):
"""Attaches an option to the command. All positional arguments are
passed as parameter declarations to :class:`Option`; all keyword
arguments are forwarded unchanged (except ``cls``).
This is equivalent to creating an :class:`Option` instance manually
and attaching it to the :attr:`Command.params` list.
:param cls: the option class to instantiate. This defaults to
:class:`Option`.
"""
def decorator(f):
# Issue 926, copy attrs, so pre-defined options can re-use the same cls=
option_attrs = attrs.copy()
if 'help' in option_attrs:
option_attrs['help'] = inspect.cleandoc(option_attrs['help'])
OptionClass = option_attrs.pop('cls', Option)
_param_memo(f, OptionClass(param_decls, **option_attrs))
return f
return decorator
def confirmation_option(*param_decls, **attrs):
"""Shortcut for confirmation prompts that can be ignored by passing
``--yes`` as parameter.
This is equivalent to decorating a function with :func:`option` with
the following parameters::
def callback(ctx, param, value):
if not value:
ctx.abort()
@click.command()
@click.option('--yes', is_flag=True, callback=callback,
expose_value=False, prompt='Do you want to continue?')
def dropdb():
pass
"""
def decorator(f):
def callback(ctx, param, value):
if not value:
ctx.abort()
attrs.setdefault('is_flag', True)
attrs.setdefault('callback', callback)
attrs.setdefault('expose_value', False)
attrs.setdefault('prompt', 'Do you want to continue?')
attrs.setdefault('help', 'Confirm the action without prompting.')
return option(*(param_decls or ('--yes',)), **attrs)(f)
return decorator
def password_option(*param_decls, **attrs):
"""Shortcut for password prompts.
This is equivalent to decorating a function with :func:`option` with
the following parameters::
@click.command()
@click.option('--password', prompt=True, confirmation_prompt=True,
hide_input=True)
def changeadmin(password):
pass
"""
def decorator(f):
attrs.setdefault('prompt', True)
attrs.setdefault('confirmation_prompt', True)
attrs.setdefault('hide_input', True)
return option(*(param_decls or ('--password',)), **attrs)(f)
return decorator
def version_option(version=None, *param_decls, **attrs):
"""Adds a ``--version`` option which immediately ends the program
printing out the version number. This is implemented as an eager
option that prints the version and exits the program in the callback.
:param version: the version number to show. If not provided Click
attempts an auto discovery via setuptools.
:param prog_name: the name of the program (defaults to autodetection)
:param message: custom message to show instead of the default
(``'%(prog)s, version %(version)s'``)
:param others: everything else is forwarded to :func:`option`.
"""
if version is None:
if hasattr(sys, '_getframe'):
module = sys._getframe(1).f_globals.get('__name__')
else:
module = ''
def decorator(f):
prog_name = attrs.pop('prog_name', None)
message = attrs.pop('message', '%(prog)s, version %(version)s')
def callback(ctx, param, value):
if not value or ctx.resilient_parsing:
return
prog = prog_name
if prog is None:
prog = ctx.find_root().info_name
ver = version
if ver is None:
try:
import pkg_resources
except ImportError:
pass
else:
for dist in pkg_resources.working_set:
scripts = dist.get_entry_map().get('console_scripts') or {}
for script_name, entry_point in iteritems(scripts):
if entry_point.module_name == module:
ver = dist.version
break
if ver is None:
raise RuntimeError('Could not determine version')
echo(message % {
'prog': prog,
'version': ver,
}, color=ctx.color)
ctx.exit()
attrs.setdefault('is_flag', True)
attrs.setdefault('expose_value', False)
attrs.setdefault('is_eager', True)
attrs.setdefault('help', 'Show the version and exit.')
attrs['callback'] = callback
return option(*(param_decls or ('--version',)), **attrs)(f)
return decorator
def help_option(*param_decls, **attrs):
"""Adds a ``--help`` option which immediately ends the program
printing out the help page. This is usually unnecessary to add as
this is added by default to all commands unless suppressed.
Like :func:`version_option`, this is implemented as eager option that
prints in the callback and exits.
All arguments are forwarded to :func:`option`.
"""
def decorator(f):
def callback(ctx, param, value):
if value and not ctx.resilient_parsing:
echo(ctx.get_help(), color=ctx.color)
ctx.exit()
attrs.setdefault('is_flag', True)
attrs.setdefault('expose_value', False)
attrs.setdefault('help', 'Show this message and exit.')
attrs.setdefault('is_eager', True)
attrs['callback'] = callback
return option(*(param_decls or ('--help',)), **attrs)(f)
return decorator
# Circular dependencies between core and decorators
from .core import Command, Group, Argument, Option

235
libs/click/exceptions.py Normal file
View file

@ -0,0 +1,235 @@
from ._compat import PY2, filename_to_ui, get_text_stderr
from .utils import echo
def _join_param_hints(param_hint):
if isinstance(param_hint, (tuple, list)):
return ' / '.join('"%s"' % x for x in param_hint)
return param_hint
class ClickException(Exception):
"""An exception that Click can handle and show to the user."""
#: The exit code for this exception
exit_code = 1
def __init__(self, message):
ctor_msg = message
if PY2:
if ctor_msg is not None:
ctor_msg = ctor_msg.encode('utf-8')
Exception.__init__(self, ctor_msg)
self.message = message
def format_message(self):
return self.message
def __str__(self):
return self.message
if PY2:
__unicode__ = __str__
def __str__(self):
return self.message.encode('utf-8')
def show(self, file=None):
if file is None:
file = get_text_stderr()
echo('Error: %s' % self.format_message(), file=file)
class UsageError(ClickException):
"""An internal exception that signals a usage error. This typically
aborts any further handling.
:param message: the error message to display.
:param ctx: optionally the context that caused this error. Click will
fill in the context automatically in some situations.
"""
exit_code = 2
def __init__(self, message, ctx=None):
ClickException.__init__(self, message)
self.ctx = ctx
self.cmd = self.ctx and self.ctx.command or None
def show(self, file=None):
if file is None:
file = get_text_stderr()
color = None
hint = ''
if (self.cmd is not None and
self.cmd.get_help_option(self.ctx) is not None):
hint = ('Try "%s %s" for help.\n'
% (self.ctx.command_path, self.ctx.help_option_names[0]))
if self.ctx is not None:
color = self.ctx.color
echo(self.ctx.get_usage() + '\n%s' % hint, file=file, color=color)
echo('Error: %s' % self.format_message(), file=file, color=color)
class BadParameter(UsageError):
"""An exception that formats out a standardized error message for a
bad parameter. This is useful when thrown from a callback or type as
Click will attach contextual information to it (for instance, which
parameter it is).
.. versionadded:: 2.0
:param param: the parameter object that caused this error. This can
be left out, and Click will attach this info itself
if possible.
:param param_hint: a string that shows up as parameter name. This
can be used as alternative to `param` in cases
where custom validation should happen. If it is
a string it's used as such, if it's a list then
each item is quoted and separated.
"""
def __init__(self, message, ctx=None, param=None,
param_hint=None):
UsageError.__init__(self, message, ctx)
self.param = param
self.param_hint = param_hint
def format_message(self):
if self.param_hint is not None:
param_hint = self.param_hint
elif self.param is not None:
param_hint = self.param.get_error_hint(self.ctx)
else:
return 'Invalid value: %s' % self.message
param_hint = _join_param_hints(param_hint)
return 'Invalid value for %s: %s' % (param_hint, self.message)
class MissingParameter(BadParameter):
"""Raised if click required an option or argument but it was not
provided when invoking the script.
.. versionadded:: 4.0
:param param_type: a string that indicates the type of the parameter.
The default is to inherit the parameter type from
the given `param`. Valid values are ``'parameter'``,
``'option'`` or ``'argument'``.
"""
def __init__(self, message=None, ctx=None, param=None,
param_hint=None, param_type=None):
BadParameter.__init__(self, message, ctx, param, param_hint)
self.param_type = param_type
def format_message(self):
if self.param_hint is not None:
param_hint = self.param_hint
elif self.param is not None:
param_hint = self.param.get_error_hint(self.ctx)
else:
param_hint = None
param_hint = _join_param_hints(param_hint)
param_type = self.param_type
if param_type is None and self.param is not None:
param_type = self.param.param_type_name
msg = self.message
if self.param is not None:
msg_extra = self.param.type.get_missing_message(self.param)
if msg_extra:
if msg:
msg += '. ' + msg_extra
else:
msg = msg_extra
return 'Missing %s%s%s%s' % (
param_type,
param_hint and ' %s' % param_hint or '',
msg and '. ' or '.',
msg or '',
)
class NoSuchOption(UsageError):
"""Raised if click attempted to handle an option that does not
exist.
.. versionadded:: 4.0
"""
def __init__(self, option_name, message=None, possibilities=None,
ctx=None):
if message is None:
message = 'no such option: %s' % option_name
UsageError.__init__(self, message, ctx)
self.option_name = option_name
self.possibilities = possibilities
def format_message(self):
bits = [self.message]
if self.possibilities:
if len(self.possibilities) == 1:
bits.append('Did you mean %s?' % self.possibilities[0])
else:
possibilities = sorted(self.possibilities)
bits.append('(Possible options: %s)' % ', '.join(possibilities))
return ' '.join(bits)
class BadOptionUsage(UsageError):
"""Raised if an option is generally supplied but the use of the option
was incorrect. This is for instance raised if the number of arguments
for an option is not correct.
.. versionadded:: 4.0
:param option_name: the name of the option being used incorrectly.
"""
def __init__(self, option_name, message, ctx=None):
UsageError.__init__(self, message, ctx)
self.option_name = option_name
class BadArgumentUsage(UsageError):
"""Raised if an argument is generally supplied but the use of the argument
was incorrect. This is for instance raised if the number of values
for an argument is not correct.
.. versionadded:: 6.0
"""
def __init__(self, message, ctx=None):
UsageError.__init__(self, message, ctx)
class FileError(ClickException):
"""Raised if a file cannot be opened."""
def __init__(self, filename, hint=None):
ui_filename = filename_to_ui(filename)
if hint is None:
hint = 'unknown error'
ClickException.__init__(self, hint)
self.ui_filename = ui_filename
self.filename = filename
def format_message(self):
return 'Could not open file %s: %s' % (self.ui_filename, self.message)
class Abort(RuntimeError):
"""An internal signalling exception that signals Click to abort."""
class Exit(RuntimeError):
"""An exception that indicates that the application should exit with some
status code.
:param code: the status code to exit with.
"""
def __init__(self, code=0):
self.exit_code = code

256
libs/click/formatting.py Normal file
View file

@ -0,0 +1,256 @@
from contextlib import contextmanager
from .termui import get_terminal_size
from .parser import split_opt
from ._compat import term_len
# Can force a width. This is used by the test system
FORCED_WIDTH = None
def measure_table(rows):
widths = {}
for row in rows:
for idx, col in enumerate(row):
widths[idx] = max(widths.get(idx, 0), term_len(col))
return tuple(y for x, y in sorted(widths.items()))
def iter_rows(rows, col_count):
for row in rows:
row = tuple(row)
yield row + ('',) * (col_count - len(row))
def wrap_text(text, width=78, initial_indent='', subsequent_indent='',
preserve_paragraphs=False):
"""A helper function that intelligently wraps text. By default, it
assumes that it operates on a single paragraph of text but if the
`preserve_paragraphs` parameter is provided it will intelligently
handle paragraphs (defined by two empty lines).
If paragraphs are handled, a paragraph can be prefixed with an empty
line containing the ``\\b`` character (``\\x08``) to indicate that
no rewrapping should happen in that block.
:param text: the text that should be rewrapped.
:param width: the maximum width for the text.
:param initial_indent: the initial indent that should be placed on the
first line as a string.
:param subsequent_indent: the indent string that should be placed on
each consecutive line.
:param preserve_paragraphs: if this flag is set then the wrapping will
intelligently handle paragraphs.
"""
from ._textwrap import TextWrapper
text = text.expandtabs()
wrapper = TextWrapper(width, initial_indent=initial_indent,
subsequent_indent=subsequent_indent,
replace_whitespace=False)
if not preserve_paragraphs:
return wrapper.fill(text)
p = []
buf = []
indent = None
def _flush_par():
if not buf:
return
if buf[0].strip() == '\b':
p.append((indent or 0, True, '\n'.join(buf[1:])))
else:
p.append((indent or 0, False, ' '.join(buf)))
del buf[:]
for line in text.splitlines():
if not line:
_flush_par()
indent = None
else:
if indent is None:
orig_len = term_len(line)
line = line.lstrip()
indent = orig_len - term_len(line)
buf.append(line)
_flush_par()
rv = []
for indent, raw, text in p:
with wrapper.extra_indent(' ' * indent):
if raw:
rv.append(wrapper.indent_only(text))
else:
rv.append(wrapper.fill(text))
return '\n\n'.join(rv)
class HelpFormatter(object):
"""This class helps with formatting text-based help pages. It's
usually just needed for very special internal cases, but it's also
exposed so that developers can write their own fancy outputs.
At present, it always writes into memory.
:param indent_increment: the additional increment for each level.
:param width: the width for the text. This defaults to the terminal
width clamped to a maximum of 78.
"""
def __init__(self, indent_increment=2, width=None, max_width=None):
self.indent_increment = indent_increment
if max_width is None:
max_width = 80
if width is None:
width = FORCED_WIDTH
if width is None:
width = max(min(get_terminal_size()[0], max_width) - 2, 50)
self.width = width
self.current_indent = 0
self.buffer = []
def write(self, string):
"""Writes a unicode string into the internal buffer."""
self.buffer.append(string)
def indent(self):
"""Increases the indentation."""
self.current_indent += self.indent_increment
def dedent(self):
"""Decreases the indentation."""
self.current_indent -= self.indent_increment
def write_usage(self, prog, args='', prefix='Usage: '):
"""Writes a usage line into the buffer.
:param prog: the program name.
:param args: whitespace separated list of arguments.
:param prefix: the prefix for the first line.
"""
usage_prefix = '%*s%s ' % (self.current_indent, prefix, prog)
text_width = self.width - self.current_indent
if text_width >= (term_len(usage_prefix) + 20):
# The arguments will fit to the right of the prefix.
indent = ' ' * term_len(usage_prefix)
self.write(wrap_text(args, text_width,
initial_indent=usage_prefix,
subsequent_indent=indent))
else:
# The prefix is too long, put the arguments on the next line.
self.write(usage_prefix)
self.write('\n')
indent = ' ' * (max(self.current_indent, term_len(prefix)) + 4)
self.write(wrap_text(args, text_width,
initial_indent=indent,
subsequent_indent=indent))
self.write('\n')
def write_heading(self, heading):
"""Writes a heading into the buffer."""
self.write('%*s%s:\n' % (self.current_indent, '', heading))
def write_paragraph(self):
"""Writes a paragraph into the buffer."""
if self.buffer:
self.write('\n')
def write_text(self, text):
"""Writes re-indented text into the buffer. This rewraps and
preserves paragraphs.
"""
text_width = max(self.width - self.current_indent, 11)
indent = ' ' * self.current_indent
self.write(wrap_text(text, text_width,
initial_indent=indent,
subsequent_indent=indent,
preserve_paragraphs=True))
self.write('\n')
def write_dl(self, rows, col_max=30, col_spacing=2):
"""Writes a definition list into the buffer. This is how options
and commands are usually formatted.
:param rows: a list of two item tuples for the terms and values.
:param col_max: the maximum width of the first column.
:param col_spacing: the number of spaces between the first and
second column.
"""
rows = list(rows)
widths = measure_table(rows)
if len(widths) != 2:
raise TypeError('Expected two columns for definition list')
first_col = min(widths[0], col_max) + col_spacing
for first, second in iter_rows(rows, len(widths)):
self.write('%*s%s' % (self.current_indent, '', first))
if not second:
self.write('\n')
continue
if term_len(first) <= first_col - col_spacing:
self.write(' ' * (first_col - term_len(first)))
else:
self.write('\n')
self.write(' ' * (first_col + self.current_indent))
text_width = max(self.width - first_col - 2, 10)
lines = iter(wrap_text(second, text_width).splitlines())
if lines:
self.write(next(lines) + '\n')
for line in lines:
self.write('%*s%s\n' % (
first_col + self.current_indent, '', line))
else:
self.write('\n')
@contextmanager
def section(self, name):
"""Helpful context manager that writes a paragraph, a heading,
and the indents.
:param name: the section name that is written as heading.
"""
self.write_paragraph()
self.write_heading(name)
self.indent()
try:
yield
finally:
self.dedent()
@contextmanager
def indentation(self):
"""A context manager that increases the indentation."""
self.indent()
try:
yield
finally:
self.dedent()
def getvalue(self):
"""Returns the buffer contents."""
return ''.join(self.buffer)
def join_options(options):
"""Given a list of option strings this joins them in the most appropriate
way and returns them in the form ``(formatted_string,
any_prefix_is_slash)`` where the second item in the tuple is a flag that
indicates if any of the option prefixes was a slash.
"""
rv = []
any_prefix_is_slash = False
for opt in options:
prefix = split_opt(opt)[0]
if prefix == '/':
any_prefix_is_slash = True
rv.append((len(prefix), opt))
rv.sort(key=lambda x: x[0])
rv = ', '.join(x[1] for x in rv)
return rv, any_prefix_is_slash

48
libs/click/globals.py Normal file
View file

@ -0,0 +1,48 @@
from threading import local
_local = local()
def get_current_context(silent=False):
"""Returns the current click context. This can be used as a way to
access the current context object from anywhere. This is a more implicit
alternative to the :func:`pass_context` decorator. This function is
primarily useful for helpers such as :func:`echo` which might be
interested in changing its behavior based on the current context.
To push the current context, :meth:`Context.scope` can be used.
.. versionadded:: 5.0
:param silent: is set to `True` the return value is `None` if no context
is available. The default behavior is to raise a
:exc:`RuntimeError`.
"""
try:
return getattr(_local, 'stack')[-1]
except (AttributeError, IndexError):
if not silent:
raise RuntimeError('There is no active click context.')
def push_context(ctx):
"""Pushes a new context to the current stack."""
_local.__dict__.setdefault('stack', []).append(ctx)
def pop_context():
"""Removes the top level from the stack."""
_local.stack.pop()
def resolve_color_default(color=None):
""""Internal helper to get the default value of the color flag. If a
value is passed it's returned unchanged, otherwise it's looked up from
the current context.
"""
if color is not None:
return color
ctx = get_current_context(silent=True)
if ctx is not None:
return ctx.color

427
libs/click/parser.py Normal file
View file

@ -0,0 +1,427 @@
# -*- coding: utf-8 -*-
"""
click.parser
~~~~~~~~~~~~
This module started out as largely a copy paste from the stdlib's
optparse module with the features removed that we do not need from
optparse because we implement them in Click on a higher level (for
instance type handling, help formatting and a lot more).
The plan is to remove more and more from here over time.
The reason this is a different module and not optparse from the stdlib
is that there are differences in 2.x and 3.x about the error messages
generated and optparse in the stdlib uses gettext for no good reason
and might cause us issues.
"""
import re
from collections import deque
from .exceptions import UsageError, NoSuchOption, BadOptionUsage, \
BadArgumentUsage
def _unpack_args(args, nargs_spec):
"""Given an iterable of arguments and an iterable of nargs specifications,
it returns a tuple with all the unpacked arguments at the first index
and all remaining arguments as the second.
The nargs specification is the number of arguments that should be consumed
or `-1` to indicate that this position should eat up all the remainders.
Missing items are filled with `None`.
"""
args = deque(args)
nargs_spec = deque(nargs_spec)
rv = []
spos = None
def _fetch(c):
try:
if spos is None:
return c.popleft()
else:
return c.pop()
except IndexError:
return None
while nargs_spec:
nargs = _fetch(nargs_spec)
if nargs == 1:
rv.append(_fetch(args))
elif nargs > 1:
x = [_fetch(args) for _ in range(nargs)]
# If we're reversed, we're pulling in the arguments in reverse,
# so we need to turn them around.
if spos is not None:
x.reverse()
rv.append(tuple(x))
elif nargs < 0:
if spos is not None:
raise TypeError('Cannot have two nargs < 0')
spos = len(rv)
rv.append(None)
# spos is the position of the wildcard (star). If it's not `None`,
# we fill it with the remainder.
if spos is not None:
rv[spos] = tuple(args)
args = []
rv[spos + 1:] = reversed(rv[spos + 1:])
return tuple(rv), list(args)
def _error_opt_args(nargs, opt):
if nargs == 1:
raise BadOptionUsage(opt, '%s option requires an argument' % opt)
raise BadOptionUsage(opt, '%s option requires %d arguments' % (opt, nargs))
def split_opt(opt):
first = opt[:1]
if first.isalnum():
return '', opt
if opt[1:2] == first:
return opt[:2], opt[2:]
return first, opt[1:]
def normalize_opt(opt, ctx):
if ctx is None or ctx.token_normalize_func is None:
return opt
prefix, opt = split_opt(opt)
return prefix + ctx.token_normalize_func(opt)
def split_arg_string(string):
"""Given an argument string this attempts to split it into small parts."""
rv = []
for match in re.finditer(r"('([^'\\]*(?:\\.[^'\\]*)*)'"
r'|"([^"\\]*(?:\\.[^"\\]*)*)"'
r'|\S+)\s*', string, re.S):
arg = match.group().strip()
if arg[:1] == arg[-1:] and arg[:1] in '"\'':
arg = arg[1:-1].encode('ascii', 'backslashreplace') \
.decode('unicode-escape')
try:
arg = type(string)(arg)
except UnicodeError:
pass
rv.append(arg)
return rv
class Option(object):
def __init__(self, opts, dest, action=None, nargs=1, const=None, obj=None):
self._short_opts = []
self._long_opts = []
self.prefixes = set()
for opt in opts:
prefix, value = split_opt(opt)
if not prefix:
raise ValueError('Invalid start character for option (%s)'
% opt)
self.prefixes.add(prefix[0])
if len(prefix) == 1 and len(value) == 1:
self._short_opts.append(opt)
else:
self._long_opts.append(opt)
self.prefixes.add(prefix)
if action is None:
action = 'store'
self.dest = dest
self.action = action
self.nargs = nargs
self.const = const
self.obj = obj
@property
def takes_value(self):
return self.action in ('store', 'append')
def process(self, value, state):
if self.action == 'store':
state.opts[self.dest] = value
elif self.action == 'store_const':
state.opts[self.dest] = self.const
elif self.action == 'append':
state.opts.setdefault(self.dest, []).append(value)
elif self.action == 'append_const':
state.opts.setdefault(self.dest, []).append(self.const)
elif self.action == 'count':
state.opts[self.dest] = state.opts.get(self.dest, 0) + 1
else:
raise ValueError('unknown action %r' % self.action)
state.order.append(self.obj)
class Argument(object):
def __init__(self, dest, nargs=1, obj=None):
self.dest = dest
self.nargs = nargs
self.obj = obj
def process(self, value, state):
if self.nargs > 1:
holes = sum(1 for x in value if x is None)
if holes == len(value):
value = None
elif holes != 0:
raise BadArgumentUsage('argument %s takes %d values'
% (self.dest, self.nargs))
state.opts[self.dest] = value
state.order.append(self.obj)
class ParsingState(object):
def __init__(self, rargs):
self.opts = {}
self.largs = []
self.rargs = rargs
self.order = []
class OptionParser(object):
"""The option parser is an internal class that is ultimately used to
parse options and arguments. It's modelled after optparse and brings
a similar but vastly simplified API. It should generally not be used
directly as the high level Click classes wrap it for you.
It's not nearly as extensible as optparse or argparse as it does not
implement features that are implemented on a higher level (such as
types or defaults).
:param ctx: optionally the :class:`~click.Context` where this parser
should go with.
"""
def __init__(self, ctx=None):
#: The :class:`~click.Context` for this parser. This might be
#: `None` for some advanced use cases.
self.ctx = ctx
#: This controls how the parser deals with interspersed arguments.
#: If this is set to `False`, the parser will stop on the first
#: non-option. Click uses this to implement nested subcommands
#: safely.
self.allow_interspersed_args = True
#: This tells the parser how to deal with unknown options. By
#: default it will error out (which is sensible), but there is a
#: second mode where it will ignore it and continue processing
#: after shifting all the unknown options into the resulting args.
self.ignore_unknown_options = False
if ctx is not None:
self.allow_interspersed_args = ctx.allow_interspersed_args
self.ignore_unknown_options = ctx.ignore_unknown_options
self._short_opt = {}
self._long_opt = {}
self._opt_prefixes = set(['-', '--'])
self._args = []
def add_option(self, opts, dest, action=None, nargs=1, const=None,
obj=None):
"""Adds a new option named `dest` to the parser. The destination
is not inferred (unlike with optparse) and needs to be explicitly
provided. Action can be any of ``store``, ``store_const``,
``append``, ``appnd_const`` or ``count``.
The `obj` can be used to identify the option in the order list
that is returned from the parser.
"""
if obj is None:
obj = dest
opts = [normalize_opt(opt, self.ctx) for opt in opts]
option = Option(opts, dest, action=action, nargs=nargs,
const=const, obj=obj)
self._opt_prefixes.update(option.prefixes)
for opt in option._short_opts:
self._short_opt[opt] = option
for opt in option._long_opts:
self._long_opt[opt] = option
def add_argument(self, dest, nargs=1, obj=None):
"""Adds a positional argument named `dest` to the parser.
The `obj` can be used to identify the option in the order list
that is returned from the parser.
"""
if obj is None:
obj = dest
self._args.append(Argument(dest=dest, nargs=nargs, obj=obj))
def parse_args(self, args):
"""Parses positional arguments and returns ``(values, args, order)``
for the parsed options and arguments as well as the leftover
arguments if there are any. The order is a list of objects as they
appear on the command line. If arguments appear multiple times they
will be memorized multiple times as well.
"""
state = ParsingState(args)
try:
self._process_args_for_options(state)
self._process_args_for_args(state)
except UsageError:
if self.ctx is None or not self.ctx.resilient_parsing:
raise
return state.opts, state.largs, state.order
def _process_args_for_args(self, state):
pargs, args = _unpack_args(state.largs + state.rargs,
[x.nargs for x in self._args])
for idx, arg in enumerate(self._args):
arg.process(pargs[idx], state)
state.largs = args
state.rargs = []
def _process_args_for_options(self, state):
while state.rargs:
arg = state.rargs.pop(0)
arglen = len(arg)
# Double dashes always handled explicitly regardless of what
# prefixes are valid.
if arg == '--':
return
elif arg[:1] in self._opt_prefixes and arglen > 1:
self._process_opts(arg, state)
elif self.allow_interspersed_args:
state.largs.append(arg)
else:
state.rargs.insert(0, arg)
return
# Say this is the original argument list:
# [arg0, arg1, ..., arg(i-1), arg(i), arg(i+1), ..., arg(N-1)]
# ^
# (we are about to process arg(i)).
#
# Then rargs is [arg(i), ..., arg(N-1)] and largs is a *subset* of
# [arg0, ..., arg(i-1)] (any options and their arguments will have
# been removed from largs).
#
# The while loop will usually consume 1 or more arguments per pass.
# If it consumes 1 (eg. arg is an option that takes no arguments),
# then after _process_arg() is done the situation is:
#
# largs = subset of [arg0, ..., arg(i)]
# rargs = [arg(i+1), ..., arg(N-1)]
#
# If allow_interspersed_args is false, largs will always be
# *empty* -- still a subset of [arg0, ..., arg(i-1)], but
# not a very interesting subset!
def _match_long_opt(self, opt, explicit_value, state):
if opt not in self._long_opt:
possibilities = [word for word in self._long_opt
if word.startswith(opt)]
raise NoSuchOption(opt, possibilities=possibilities, ctx=self.ctx)
option = self._long_opt[opt]
if option.takes_value:
# At this point it's safe to modify rargs by injecting the
# explicit value, because no exception is raised in this
# branch. This means that the inserted value will be fully
# consumed.
if explicit_value is not None:
state.rargs.insert(0, explicit_value)
nargs = option.nargs
if len(state.rargs) < nargs:
_error_opt_args(nargs, opt)
elif nargs == 1:
value = state.rargs.pop(0)
else:
value = tuple(state.rargs[:nargs])
del state.rargs[:nargs]
elif explicit_value is not None:
raise BadOptionUsage(opt, '%s option does not take a value' % opt)
else:
value = None
option.process(value, state)
def _match_short_opt(self, arg, state):
stop = False
i = 1
prefix = arg[0]
unknown_options = []
for ch in arg[1:]:
opt = normalize_opt(prefix + ch, self.ctx)
option = self._short_opt.get(opt)
i += 1
if not option:
if self.ignore_unknown_options:
unknown_options.append(ch)
continue
raise NoSuchOption(opt, ctx=self.ctx)
if option.takes_value:
# Any characters left in arg? Pretend they're the
# next arg, and stop consuming characters of arg.
if i < len(arg):
state.rargs.insert(0, arg[i:])
stop = True
nargs = option.nargs
if len(state.rargs) < nargs:
_error_opt_args(nargs, opt)
elif nargs == 1:
value = state.rargs.pop(0)
else:
value = tuple(state.rargs[:nargs])
del state.rargs[:nargs]
else:
value = None
option.process(value, state)
if stop:
break
# If we got any unknown options we re-combinate the string of the
# remaining options and re-attach the prefix, then report that
# to the state as new larg. This way there is basic combinatorics
# that can be achieved while still ignoring unknown arguments.
if self.ignore_unknown_options and unknown_options:
state.largs.append(prefix + ''.join(unknown_options))
def _process_opts(self, arg, state):
explicit_value = None
# Long option handling happens in two parts. The first part is
# supporting explicitly attached values. In any case, we will try
# to long match the option first.
if '=' in arg:
long_opt, explicit_value = arg.split('=', 1)
else:
long_opt = arg
norm_long_opt = normalize_opt(long_opt, self.ctx)
# At this point we will match the (assumed) long option through
# the long option matching code. Note that this allows options
# like "-foo" to be matched as long options.
try:
self._match_long_opt(norm_long_opt, explicit_value, state)
except NoSuchOption:
# At this point the long option matching failed, and we need
# to try with short options. However there is a special rule
# which says, that if we have a two character options prefix
# (applies to "--foo" for instance), we do not dispatch to the
# short option code and will instead raise the no option
# error.
if arg[:2] not in self._opt_prefixes:
return self._match_short_opt(arg, state)
if not self.ignore_unknown_options:
raise
state.largs.append(arg)

606
libs/click/termui.py Normal file
View file

@ -0,0 +1,606 @@
import os
import sys
import struct
import inspect
import itertools
from ._compat import raw_input, text_type, string_types, \
isatty, strip_ansi, get_winterm_size, DEFAULT_COLUMNS, WIN
from .utils import echo
from .exceptions import Abort, UsageError
from .types import convert_type, Choice, Path
from .globals import resolve_color_default
# The prompt functions to use. The doc tools currently override these
# functions to customize how they work.
visible_prompt_func = raw_input
_ansi_colors = {
'black': 30,
'red': 31,
'green': 32,
'yellow': 33,
'blue': 34,
'magenta': 35,
'cyan': 36,
'white': 37,
'reset': 39,
'bright_black': 90,
'bright_red': 91,
'bright_green': 92,
'bright_yellow': 93,
'bright_blue': 94,
'bright_magenta': 95,
'bright_cyan': 96,
'bright_white': 97,
}
_ansi_reset_all = '\033[0m'
def hidden_prompt_func(prompt):
import getpass
return getpass.getpass(prompt)
def _build_prompt(text, suffix, show_default=False, default=None, show_choices=True, type=None):
prompt = text
if type is not None and show_choices and isinstance(type, Choice):
prompt += ' (' + ", ".join(map(str, type.choices)) + ')'
if default is not None and show_default:
prompt = '%s [%s]' % (prompt, default)
return prompt + suffix
def prompt(text, default=None, hide_input=False, confirmation_prompt=False,
type=None, value_proc=None, prompt_suffix=': ', show_default=True,
err=False, show_choices=True):
"""Prompts a user for input. This is a convenience function that can
be used to prompt a user for input later.
If the user aborts the input by sending a interrupt signal, this
function will catch it and raise a :exc:`Abort` exception.
.. versionadded:: 7.0
Added the show_choices parameter.
.. versionadded:: 6.0
Added unicode support for cmd.exe on Windows.
.. versionadded:: 4.0
Added the `err` parameter.
:param text: the text to show for the prompt.
:param default: the default value to use if no input happens. If this
is not given it will prompt until it's aborted.
:param hide_input: if this is set to true then the input value will
be hidden.
:param confirmation_prompt: asks for confirmation for the value.
:param type: the type to use to check the value against.
:param value_proc: if this parameter is provided it's a function that
is invoked instead of the type conversion to
convert a value.
:param prompt_suffix: a suffix that should be added to the prompt.
:param show_default: shows or hides the default value in the prompt.
:param err: if set to true the file defaults to ``stderr`` instead of
``stdout``, the same as with echo.
:param show_choices: Show or hide choices if the passed type is a Choice.
For example if type is a Choice of either day or week,
show_choices is true and text is "Group by" then the
prompt will be "Group by (day, week): ".
"""
result = None
def prompt_func(text):
f = hide_input and hidden_prompt_func or visible_prompt_func
try:
# Write the prompt separately so that we get nice
# coloring through colorama on Windows
echo(text, nl=False, err=err)
return f('')
except (KeyboardInterrupt, EOFError):
# getpass doesn't print a newline if the user aborts input with ^C.
# Allegedly this behavior is inherited from getpass(3).
# A doc bug has been filed at https://bugs.python.org/issue24711
if hide_input:
echo(None, err=err)
raise Abort()
if value_proc is None:
value_proc = convert_type(type, default)
prompt = _build_prompt(text, prompt_suffix, show_default, default, show_choices, type)
while 1:
while 1:
value = prompt_func(prompt)
if value:
break
elif default is not None:
if isinstance(value_proc, Path):
# validate Path default value(exists, dir_okay etc.)
value = default
break
return default
try:
result = value_proc(value)
except UsageError as e:
echo('Error: %s' % e.message, err=err)
continue
if not confirmation_prompt:
return result
while 1:
value2 = prompt_func('Repeat for confirmation: ')
if value2:
break
if value == value2:
return result
echo('Error: the two entered values do not match', err=err)
def confirm(text, default=False, abort=False, prompt_suffix=': ',
show_default=True, err=False):
"""Prompts for confirmation (yes/no question).
If the user aborts the input by sending a interrupt signal this
function will catch it and raise a :exc:`Abort` exception.
.. versionadded:: 4.0
Added the `err` parameter.
:param text: the question to ask.
:param default: the default for the prompt.
:param abort: if this is set to `True` a negative answer aborts the
exception by raising :exc:`Abort`.
:param prompt_suffix: a suffix that should be added to the prompt.
:param show_default: shows or hides the default value in the prompt.
:param err: if set to true the file defaults to ``stderr`` instead of
``stdout``, the same as with echo.
"""
prompt = _build_prompt(text, prompt_suffix, show_default,
default and 'Y/n' or 'y/N')
while 1:
try:
# Write the prompt separately so that we get nice
# coloring through colorama on Windows
echo(prompt, nl=False, err=err)
value = visible_prompt_func('').lower().strip()
except (KeyboardInterrupt, EOFError):
raise Abort()
if value in ('y', 'yes'):
rv = True
elif value in ('n', 'no'):
rv = False
elif value == '':
rv = default
else:
echo('Error: invalid input', err=err)
continue
break
if abort and not rv:
raise Abort()
return rv
def get_terminal_size():
"""Returns the current size of the terminal as tuple in the form
``(width, height)`` in columns and rows.
"""
# If shutil has get_terminal_size() (Python 3.3 and later) use that
if sys.version_info >= (3, 3):
import shutil
shutil_get_terminal_size = getattr(shutil, 'get_terminal_size', None)
if shutil_get_terminal_size:
sz = shutil_get_terminal_size()
return sz.columns, sz.lines
# We provide a sensible default for get_winterm_size() when being invoked
# inside a subprocess. Without this, it would not provide a useful input.
if get_winterm_size is not None:
size = get_winterm_size()
if size == (0, 0):
return (79, 24)
else:
return size
def ioctl_gwinsz(fd):
try:
import fcntl
import termios
cr = struct.unpack(
'hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234'))
except Exception:
return
return cr
cr = ioctl_gwinsz(0) or ioctl_gwinsz(1) or ioctl_gwinsz(2)
if not cr:
try:
fd = os.open(os.ctermid(), os.O_RDONLY)
try:
cr = ioctl_gwinsz(fd)
finally:
os.close(fd)
except Exception:
pass
if not cr or not cr[0] or not cr[1]:
cr = (os.environ.get('LINES', 25),
os.environ.get('COLUMNS', DEFAULT_COLUMNS))
return int(cr[1]), int(cr[0])
def echo_via_pager(text_or_generator, color=None):
"""This function takes a text and shows it via an environment specific
pager on stdout.
.. versionchanged:: 3.0
Added the `color` flag.
:param text_or_generator: the text to page, or alternatively, a
generator emitting the text to page.
:param color: controls if the pager supports ANSI colors or not. The
default is autodetection.
"""
color = resolve_color_default(color)
if inspect.isgeneratorfunction(text_or_generator):
i = text_or_generator()
elif isinstance(text_or_generator, string_types):
i = [text_or_generator]
else:
i = iter(text_or_generator)
# convert every element of i to a text type if necessary
text_generator = (el if isinstance(el, string_types) else text_type(el)
for el in i)
from ._termui_impl import pager
return pager(itertools.chain(text_generator, "\n"), color)
def progressbar(iterable=None, length=None, label=None, show_eta=True,
show_percent=None, show_pos=False,
item_show_func=None, fill_char='#', empty_char='-',
bar_template='%(label)s [%(bar)s] %(info)s',
info_sep=' ', width=36, file=None, color=None):
"""This function creates an iterable context manager that can be used
to iterate over something while showing a progress bar. It will
either iterate over the `iterable` or `length` items (that are counted
up). While iteration happens, this function will print a rendered
progress bar to the given `file` (defaults to stdout) and will attempt
to calculate remaining time and more. By default, this progress bar
will not be rendered if the file is not a terminal.
The context manager creates the progress bar. When the context
manager is entered the progress bar is already displayed. With every
iteration over the progress bar, the iterable passed to the bar is
advanced and the bar is updated. When the context manager exits,
a newline is printed and the progress bar is finalized on screen.
No printing must happen or the progress bar will be unintentionally
destroyed.
Example usage::
with progressbar(items) as bar:
for item in bar:
do_something_with(item)
Alternatively, if no iterable is specified, one can manually update the
progress bar through the `update()` method instead of directly
iterating over the progress bar. The update method accepts the number
of steps to increment the bar with::
with progressbar(length=chunks.total_bytes) as bar:
for chunk in chunks:
process_chunk(chunk)
bar.update(chunks.bytes)
.. versionadded:: 2.0
.. versionadded:: 4.0
Added the `color` parameter. Added a `update` method to the
progressbar object.
:param iterable: an iterable to iterate over. If not provided the length
is required.
:param length: the number of items to iterate over. By default the
progressbar will attempt to ask the iterator about its
length, which might or might not work. If an iterable is
also provided this parameter can be used to override the
length. If an iterable is not provided the progress bar
will iterate over a range of that length.
:param label: the label to show next to the progress bar.
:param show_eta: enables or disables the estimated time display. This is
automatically disabled if the length cannot be
determined.
:param show_percent: enables or disables the percentage display. The
default is `True` if the iterable has a length or
`False` if not.
:param show_pos: enables or disables the absolute position display. The
default is `False`.
:param item_show_func: a function called with the current item which
can return a string to show the current item
next to the progress bar. Note that the current
item can be `None`!
:param fill_char: the character to use to show the filled part of the
progress bar.
:param empty_char: the character to use to show the non-filled part of
the progress bar.
:param bar_template: the format string to use as template for the bar.
The parameters in it are ``label`` for the label,
``bar`` for the progress bar and ``info`` for the
info section.
:param info_sep: the separator between multiple info items (eta etc.)
:param width: the width of the progress bar in characters, 0 means full
terminal width
:param file: the file to write to. If this is not a terminal then
only the label is printed.
:param color: controls if the terminal supports ANSI colors or not. The
default is autodetection. This is only needed if ANSI
codes are included anywhere in the progress bar output
which is not the case by default.
"""
from ._termui_impl import ProgressBar
color = resolve_color_default(color)
return ProgressBar(iterable=iterable, length=length, show_eta=show_eta,
show_percent=show_percent, show_pos=show_pos,
item_show_func=item_show_func, fill_char=fill_char,
empty_char=empty_char, bar_template=bar_template,
info_sep=info_sep, file=file, label=label,
width=width, color=color)
def clear():
"""Clears the terminal screen. This will have the effect of clearing
the whole visible space of the terminal and moving the cursor to the
top left. This does not do anything if not connected to a terminal.
.. versionadded:: 2.0
"""
if not isatty(sys.stdout):
return
# If we're on Windows and we don't have colorama available, then we
# clear the screen by shelling out. Otherwise we can use an escape
# sequence.
if WIN:
os.system('cls')
else:
sys.stdout.write('\033[2J\033[1;1H')
def style(text, fg=None, bg=None, bold=None, dim=None, underline=None,
blink=None, reverse=None, reset=True):
"""Styles a text with ANSI styles and returns the new string. By
default the styling is self contained which means that at the end
of the string a reset code is issued. This can be prevented by
passing ``reset=False``.
Examples::
click.echo(click.style('Hello World!', fg='green'))
click.echo(click.style('ATTENTION!', blink=True))
click.echo(click.style('Some things', reverse=True, fg='cyan'))
Supported color names:
* ``black`` (might be a gray)
* ``red``
* ``green``
* ``yellow`` (might be an orange)
* ``blue``
* ``magenta``
* ``cyan``
* ``white`` (might be light gray)
* ``bright_black``
* ``bright_red``
* ``bright_green``
* ``bright_yellow``
* ``bright_blue``
* ``bright_magenta``
* ``bright_cyan``
* ``bright_white``
* ``reset`` (reset the color code only)
.. versionadded:: 2.0
.. versionadded:: 7.0
Added support for bright colors.
:param text: the string to style with ansi codes.
:param fg: if provided this will become the foreground color.
:param bg: if provided this will become the background color.
:param bold: if provided this will enable or disable bold mode.
:param dim: if provided this will enable or disable dim mode. This is
badly supported.
:param underline: if provided this will enable or disable underline.
:param blink: if provided this will enable or disable blinking.
:param reverse: if provided this will enable or disable inverse
rendering (foreground becomes background and the
other way round).
:param reset: by default a reset-all code is added at the end of the
string which means that styles do not carry over. This
can be disabled to compose styles.
"""
bits = []
if fg:
try:
bits.append('\033[%dm' % (_ansi_colors[fg]))
except KeyError:
raise TypeError('Unknown color %r' % fg)
if bg:
try:
bits.append('\033[%dm' % (_ansi_colors[bg] + 10))
except KeyError:
raise TypeError('Unknown color %r' % bg)
if bold is not None:
bits.append('\033[%dm' % (1 if bold else 22))
if dim is not None:
bits.append('\033[%dm' % (2 if dim else 22))
if underline is not None:
bits.append('\033[%dm' % (4 if underline else 24))
if blink is not None:
bits.append('\033[%dm' % (5 if blink else 25))
if reverse is not None:
bits.append('\033[%dm' % (7 if reverse else 27))
bits.append(text)
if reset:
bits.append(_ansi_reset_all)
return ''.join(bits)
def unstyle(text):
"""Removes ANSI styling information from a string. Usually it's not
necessary to use this function as Click's echo function will
automatically remove styling if necessary.
.. versionadded:: 2.0
:param text: the text to remove style information from.
"""
return strip_ansi(text)
def secho(message=None, file=None, nl=True, err=False, color=None, **styles):
"""This function combines :func:`echo` and :func:`style` into one
call. As such the following two calls are the same::
click.secho('Hello World!', fg='green')
click.echo(click.style('Hello World!', fg='green'))
All keyword arguments are forwarded to the underlying functions
depending on which one they go with.
.. versionadded:: 2.0
"""
if message is not None:
message = style(message, **styles)
return echo(message, file=file, nl=nl, err=err, color=color)
def edit(text=None, editor=None, env=None, require_save=True,
extension='.txt', filename=None):
r"""Edits the given text in the defined editor. If an editor is given
(should be the full path to the executable but the regular operating
system search path is used for finding the executable) it overrides
the detected editor. Optionally, some environment variables can be
used. If the editor is closed without changes, `None` is returned. In
case a file is edited directly the return value is always `None` and
`require_save` and `extension` are ignored.
If the editor cannot be opened a :exc:`UsageError` is raised.
Note for Windows: to simplify cross-platform usage, the newlines are
automatically converted from POSIX to Windows and vice versa. As such,
the message here will have ``\n`` as newline markers.
:param text: the text to edit.
:param editor: optionally the editor to use. Defaults to automatic
detection.
:param env: environment variables to forward to the editor.
:param require_save: if this is true, then not saving in the editor
will make the return value become `None`.
:param extension: the extension to tell the editor about. This defaults
to `.txt` but changing this might change syntax
highlighting.
:param filename: if provided it will edit this file instead of the
provided text contents. It will not use a temporary
file as an indirection in that case.
"""
from ._termui_impl import Editor
editor = Editor(editor=editor, env=env, require_save=require_save,
extension=extension)
if filename is None:
return editor.edit(text)
editor.edit_file(filename)
def launch(url, wait=False, locate=False):
"""This function launches the given URL (or filename) in the default
viewer application for this file type. If this is an executable, it
might launch the executable in a new session. The return value is
the exit code of the launched application. Usually, ``0`` indicates
success.
Examples::
click.launch('https://click.palletsprojects.com/')
click.launch('/my/downloaded/file', locate=True)
.. versionadded:: 2.0
:param url: URL or filename of the thing to launch.
:param wait: waits for the program to stop.
:param locate: if this is set to `True` then instead of launching the
application associated with the URL it will attempt to
launch a file manager with the file located. This
might have weird effects if the URL does not point to
the filesystem.
"""
from ._termui_impl import open_url
return open_url(url, wait=wait, locate=locate)
# If this is provided, getchar() calls into this instead. This is used
# for unittesting purposes.
_getchar = None
def getchar(echo=False):
"""Fetches a single character from the terminal and returns it. This
will always return a unicode character and under certain rare
circumstances this might return more than one character. The
situations which more than one character is returned is when for
whatever reason multiple characters end up in the terminal buffer or
standard input was not actually a terminal.
Note that this will always read from the terminal, even if something
is piped into the standard input.
Note for Windows: in rare cases when typing non-ASCII characters, this
function might wait for a second character and then return both at once.
This is because certain Unicode characters look like special-key markers.
.. versionadded:: 2.0
:param echo: if set to `True`, the character read will also show up on
the terminal. The default is to not show it.
"""
f = _getchar
if f is None:
from ._termui_impl import getchar as f
return f(echo)
def raw_terminal():
from ._termui_impl import raw_terminal as f
return f()
def pause(info='Press any key to continue ...', err=False):
"""This command stops execution and waits for the user to press any
key to continue. This is similar to the Windows batch "pause"
command. If the program is not run through a terminal, this command
will instead do nothing.
.. versionadded:: 2.0
.. versionadded:: 4.0
Added the `err` parameter.
:param info: the info string to print before pausing.
:param err: if set to message goes to ``stderr`` instead of
``stdout``, the same as with echo.
"""
if not isatty(sys.stdin) or not isatty(sys.stdout):
return
try:
if info:
echo(info, nl=False, err=err)
try:
getchar()
except (KeyboardInterrupt, EOFError):
pass
finally:
if info:
echo(err=err)

374
libs/click/testing.py Normal file
View file

@ -0,0 +1,374 @@
import os
import sys
import shutil
import tempfile
import contextlib
import shlex
from ._compat import iteritems, PY2, string_types
# If someone wants to vendor click, we want to ensure the
# correct package is discovered. Ideally we could use a
# relative import here but unfortunately Python does not
# support that.
clickpkg = sys.modules[__name__.rsplit('.', 1)[0]]
if PY2:
from cStringIO import StringIO
else:
import io
from ._compat import _find_binary_reader
class EchoingStdin(object):
def __init__(self, input, output):
self._input = input
self._output = output
def __getattr__(self, x):
return getattr(self._input, x)
def _echo(self, rv):
self._output.write(rv)
return rv
def read(self, n=-1):
return self._echo(self._input.read(n))
def readline(self, n=-1):
return self._echo(self._input.readline(n))
def readlines(self):
return [self._echo(x) for x in self._input.readlines()]
def __iter__(self):
return iter(self._echo(x) for x in self._input)
def __repr__(self):
return repr(self._input)
def make_input_stream(input, charset):
# Is already an input stream.
if hasattr(input, 'read'):
if PY2:
return input
rv = _find_binary_reader(input)
if rv is not None:
return rv
raise TypeError('Could not find binary reader for input stream.')
if input is None:
input = b''
elif not isinstance(input, bytes):
input = input.encode(charset)
if PY2:
return StringIO(input)
return io.BytesIO(input)
class Result(object):
"""Holds the captured result of an invoked CLI script."""
def __init__(self, runner, stdout_bytes, stderr_bytes, exit_code,
exception, exc_info=None):
#: The runner that created the result
self.runner = runner
#: The standard output as bytes.
self.stdout_bytes = stdout_bytes
#: The standard error as bytes, or False(y) if not available
self.stderr_bytes = stderr_bytes
#: The exit code as integer.
self.exit_code = exit_code
#: The exception that happened if one did.
self.exception = exception
#: The traceback
self.exc_info = exc_info
@property
def output(self):
"""The (standard) output as unicode string."""
return self.stdout
@property
def stdout(self):
"""The standard output as unicode string."""
return self.stdout_bytes.decode(self.runner.charset, 'replace') \
.replace('\r\n', '\n')
@property
def stderr(self):
"""The standard error as unicode string."""
if not self.stderr_bytes:
raise ValueError("stderr not separately captured")
return self.stderr_bytes.decode(self.runner.charset, 'replace') \
.replace('\r\n', '\n')
def __repr__(self):
return '<%s %s>' % (
type(self).__name__,
self.exception and repr(self.exception) or 'okay',
)
class CliRunner(object):
"""The CLI runner provides functionality to invoke a Click command line
script for unittesting purposes in a isolated environment. This only
works in single-threaded systems without any concurrency as it changes the
global interpreter state.
:param charset: the character set for the input and output data. This is
UTF-8 by default and should not be changed currently as
the reporting to Click only works in Python 2 properly.
:param env: a dictionary with environment variables for overriding.
:param echo_stdin: if this is set to `True`, then reading from stdin writes
to stdout. This is useful for showing examples in
some circumstances. Note that regular prompts
will automatically echo the input.
:param mix_stderr: if this is set to `False`, then stdout and stderr are
preserved as independent streams. This is useful for
Unix-philosophy apps that have predictable stdout and
noisy stderr, such that each may be measured
independently
"""
def __init__(self, charset=None, env=None, echo_stdin=False,
mix_stderr=True):
if charset is None:
charset = 'utf-8'
self.charset = charset
self.env = env or {}
self.echo_stdin = echo_stdin
self.mix_stderr = mix_stderr
def get_default_prog_name(self, cli):
"""Given a command object it will return the default program name
for it. The default is the `name` attribute or ``"root"`` if not
set.
"""
return cli.name or 'root'
def make_env(self, overrides=None):
"""Returns the environment overrides for invoking a script."""
rv = dict(self.env)
if overrides:
rv.update(overrides)
return rv
@contextlib.contextmanager
def isolation(self, input=None, env=None, color=False):
"""A context manager that sets up the isolation for invoking of a
command line tool. This sets up stdin with the given input data
and `os.environ` with the overrides from the given dictionary.
This also rebinds some internals in Click to be mocked (like the
prompt functionality).
This is automatically done in the :meth:`invoke` method.
.. versionadded:: 4.0
The ``color`` parameter was added.
:param input: the input stream to put into sys.stdin.
:param env: the environment overrides as dictionary.
:param color: whether the output should contain color codes. The
application can still override this explicitly.
"""
input = make_input_stream(input, self.charset)
old_stdin = sys.stdin
old_stdout = sys.stdout
old_stderr = sys.stderr
old_forced_width = clickpkg.formatting.FORCED_WIDTH
clickpkg.formatting.FORCED_WIDTH = 80
env = self.make_env(env)
if PY2:
bytes_output = StringIO()
if self.echo_stdin:
input = EchoingStdin(input, bytes_output)
sys.stdout = bytes_output
if not self.mix_stderr:
bytes_error = StringIO()
sys.stderr = bytes_error
else:
bytes_output = io.BytesIO()
if self.echo_stdin:
input = EchoingStdin(input, bytes_output)
input = io.TextIOWrapper(input, encoding=self.charset)
sys.stdout = io.TextIOWrapper(
bytes_output, encoding=self.charset)
if not self.mix_stderr:
bytes_error = io.BytesIO()
sys.stderr = io.TextIOWrapper(
bytes_error, encoding=self.charset)
if self.mix_stderr:
sys.stderr = sys.stdout
sys.stdin = input
def visible_input(prompt=None):
sys.stdout.write(prompt or '')
val = input.readline().rstrip('\r\n')
sys.stdout.write(val + '\n')
sys.stdout.flush()
return val
def hidden_input(prompt=None):
sys.stdout.write((prompt or '') + '\n')
sys.stdout.flush()
return input.readline().rstrip('\r\n')
def _getchar(echo):
char = sys.stdin.read(1)
if echo:
sys.stdout.write(char)
sys.stdout.flush()
return char
default_color = color
def should_strip_ansi(stream=None, color=None):
if color is None:
return not default_color
return not color
old_visible_prompt_func = clickpkg.termui.visible_prompt_func
old_hidden_prompt_func = clickpkg.termui.hidden_prompt_func
old__getchar_func = clickpkg.termui._getchar
old_should_strip_ansi = clickpkg.utils.should_strip_ansi
clickpkg.termui.visible_prompt_func = visible_input
clickpkg.termui.hidden_prompt_func = hidden_input
clickpkg.termui._getchar = _getchar
clickpkg.utils.should_strip_ansi = should_strip_ansi
old_env = {}
try:
for key, value in iteritems(env):
old_env[key] = os.environ.get(key)
if value is None:
try:
del os.environ[key]
except Exception:
pass
else:
os.environ[key] = value
yield (bytes_output, not self.mix_stderr and bytes_error)
finally:
for key, value in iteritems(old_env):
if value is None:
try:
del os.environ[key]
except Exception:
pass
else:
os.environ[key] = value
sys.stdout = old_stdout
sys.stderr = old_stderr
sys.stdin = old_stdin
clickpkg.termui.visible_prompt_func = old_visible_prompt_func
clickpkg.termui.hidden_prompt_func = old_hidden_prompt_func
clickpkg.termui._getchar = old__getchar_func
clickpkg.utils.should_strip_ansi = old_should_strip_ansi
clickpkg.formatting.FORCED_WIDTH = old_forced_width
def invoke(self, cli, args=None, input=None, env=None,
catch_exceptions=True, color=False, mix_stderr=False, **extra):
"""Invokes a command in an isolated environment. The arguments are
forwarded directly to the command line script, the `extra` keyword
arguments are passed to the :meth:`~clickpkg.Command.main` function of
the command.
This returns a :class:`Result` object.
.. versionadded:: 3.0
The ``catch_exceptions`` parameter was added.
.. versionchanged:: 3.0
The result object now has an `exc_info` attribute with the
traceback if available.
.. versionadded:: 4.0
The ``color`` parameter was added.
:param cli: the command to invoke
:param args: the arguments to invoke. It may be given as an iterable
or a string. When given as string it will be interpreted
as a Unix shell command. More details at
:func:`shlex.split`.
:param input: the input data for `sys.stdin`.
:param env: the environment overrides.
:param catch_exceptions: Whether to catch any other exceptions than
``SystemExit``.
:param extra: the keyword arguments to pass to :meth:`main`.
:param color: whether the output should contain color codes. The
application can still override this explicitly.
"""
exc_info = None
with self.isolation(input=input, env=env, color=color) as outstreams:
exception = None
exit_code = 0
if isinstance(args, string_types):
args = shlex.split(args)
try:
prog_name = extra.pop("prog_name")
except KeyError:
prog_name = self.get_default_prog_name(cli)
try:
cli.main(args=args or (), prog_name=prog_name, **extra)
except SystemExit as e:
exc_info = sys.exc_info()
exit_code = e.code
if exit_code is None:
exit_code = 0
if exit_code != 0:
exception = e
if not isinstance(exit_code, int):
sys.stdout.write(str(exit_code))
sys.stdout.write('\n')
exit_code = 1
except Exception as e:
if not catch_exceptions:
raise
exception = e
exit_code = 1
exc_info = sys.exc_info()
finally:
sys.stdout.flush()
stdout = outstreams[0].getvalue()
stderr = outstreams[1] and outstreams[1].getvalue()
return Result(runner=self,
stdout_bytes=stdout,
stderr_bytes=stderr,
exit_code=exit_code,
exception=exception,
exc_info=exc_info)
@contextlib.contextmanager
def isolated_filesystem(self):
"""A context manager that creates a temporary folder and changes
the current working directory to it for isolated filesystem tests.
"""
cwd = os.getcwd()
t = tempfile.mkdtemp()
os.chdir(t)
try:
yield t
finally:
os.chdir(cwd)
try:
shutil.rmtree(t)
except (OSError, IOError):
pass

668
libs/click/types.py Normal file
View file

@ -0,0 +1,668 @@
import os
import stat
from datetime import datetime
from ._compat import open_stream, text_type, filename_to_ui, \
get_filesystem_encoding, get_streerror, _get_argv_encoding, PY2
from .exceptions import BadParameter
from .utils import safecall, LazyFile
class ParamType(object):
"""Helper for converting values through types. The following is
necessary for a valid type:
* it needs a name
* it needs to pass through None unchanged
* it needs to convert from a string
* it needs to convert its result type through unchanged
(eg: needs to be idempotent)
* it needs to be able to deal with param and context being `None`.
This can be the case when the object is used with prompt
inputs.
"""
is_composite = False
#: the descriptive name of this type
name = None
#: if a list of this type is expected and the value is pulled from a
#: string environment variable, this is what splits it up. `None`
#: means any whitespace. For all parameters the general rule is that
#: whitespace splits them up. The exception are paths and files which
#: are split by ``os.path.pathsep`` by default (":" on Unix and ";" on
#: Windows).
envvar_list_splitter = None
def __call__(self, value, param=None, ctx=None):
if value is not None:
return self.convert(value, param, ctx)
def get_metavar(self, param):
"""Returns the metavar default for this param if it provides one."""
def get_missing_message(self, param):
"""Optionally might return extra information about a missing
parameter.
.. versionadded:: 2.0
"""
def convert(self, value, param, ctx):
"""Converts the value. This is not invoked for values that are
`None` (the missing value).
"""
return value
def split_envvar_value(self, rv):
"""Given a value from an environment variable this splits it up
into small chunks depending on the defined envvar list splitter.
If the splitter is set to `None`, which means that whitespace splits,
then leading and trailing whitespace is ignored. Otherwise, leading
and trailing splitters usually lead to empty items being included.
"""
return (rv or '').split(self.envvar_list_splitter)
def fail(self, message, param=None, ctx=None):
"""Helper method to fail with an invalid value message."""
raise BadParameter(message, ctx=ctx, param=param)
class CompositeParamType(ParamType):
is_composite = True
@property
def arity(self):
raise NotImplementedError()
class FuncParamType(ParamType):
def __init__(self, func):
self.name = func.__name__
self.func = func
def convert(self, value, param, ctx):
try:
return self.func(value)
except ValueError:
try:
value = text_type(value)
except UnicodeError:
value = str(value).decode('utf-8', 'replace')
self.fail(value, param, ctx)
class UnprocessedParamType(ParamType):
name = 'text'
def convert(self, value, param, ctx):
return value
def __repr__(self):
return 'UNPROCESSED'
class StringParamType(ParamType):
name = 'text'
def convert(self, value, param, ctx):
if isinstance(value, bytes):
enc = _get_argv_encoding()
try:
value = value.decode(enc)
except UnicodeError:
fs_enc = get_filesystem_encoding()
if fs_enc != enc:
try:
value = value.decode(fs_enc)
except UnicodeError:
value = value.decode('utf-8', 'replace')
return value
return value
def __repr__(self):
return 'STRING'
class Choice(ParamType):
"""The choice type allows a value to be checked against a fixed set
of supported values. All of these values have to be strings.
You should only pass a list or tuple of choices. Other iterables
(like generators) may lead to surprising results.
See :ref:`choice-opts` for an example.
:param case_sensitive: Set to false to make choices case
insensitive. Defaults to true.
"""
name = 'choice'
def __init__(self, choices, case_sensitive=True):
self.choices = choices
self.case_sensitive = case_sensitive
def get_metavar(self, param):
return '[%s]' % '|'.join(self.choices)
def get_missing_message(self, param):
return 'Choose from:\n\t%s.' % ',\n\t'.join(self.choices)
def convert(self, value, param, ctx):
# Exact match
if value in self.choices:
return value
# Match through normalization and case sensitivity
# first do token_normalize_func, then lowercase
# preserve original `value` to produce an accurate message in
# `self.fail`
normed_value = value
normed_choices = self.choices
if ctx is not None and \
ctx.token_normalize_func is not None:
normed_value = ctx.token_normalize_func(value)
normed_choices = [ctx.token_normalize_func(choice) for choice in
self.choices]
if not self.case_sensitive:
normed_value = normed_value.lower()
normed_choices = [choice.lower() for choice in normed_choices]
if normed_value in normed_choices:
return normed_value
self.fail('invalid choice: %s. (choose from %s)' %
(value, ', '.join(self.choices)), param, ctx)
def __repr__(self):
return 'Choice(%r)' % list(self.choices)
class DateTime(ParamType):
"""The DateTime type converts date strings into `datetime` objects.
The format strings which are checked are configurable, but default to some
common (non-timezone aware) ISO 8601 formats.
When specifying *DateTime* formats, you should only pass a list or a tuple.
Other iterables, like generators, may lead to surprising results.
The format strings are processed using ``datetime.strptime``, and this
consequently defines the format strings which are allowed.
Parsing is tried using each format, in order, and the first format which
parses successfully is used.
:param formats: A list or tuple of date format strings, in the order in
which they should be tried. Defaults to
``'%Y-%m-%d'``, ``'%Y-%m-%dT%H:%M:%S'``,
``'%Y-%m-%d %H:%M:%S'``.
"""
name = 'datetime'
def __init__(self, formats=None):
self.formats = formats or [
'%Y-%m-%d',
'%Y-%m-%dT%H:%M:%S',
'%Y-%m-%d %H:%M:%S'
]
def get_metavar(self, param):
return '[{}]'.format('|'.join(self.formats))
def _try_to_convert_date(self, value, format):
try:
return datetime.strptime(value, format)
except ValueError:
return None
def convert(self, value, param, ctx):
# Exact match
for format in self.formats:
dtime = self._try_to_convert_date(value, format)
if dtime:
return dtime
self.fail(
'invalid datetime format: {}. (choose from {})'.format(
value, ', '.join(self.formats)))
def __repr__(self):
return 'DateTime'
class IntParamType(ParamType):
name = 'integer'
def convert(self, value, param, ctx):
try:
return int(value)
except (ValueError, UnicodeError):
self.fail('%s is not a valid integer' % value, param, ctx)
def __repr__(self):
return 'INT'
class IntRange(IntParamType):
"""A parameter that works similar to :data:`click.INT` but restricts
the value to fit into a range. The default behavior is to fail if the
value falls outside the range, but it can also be silently clamped
between the two edges.
See :ref:`ranges` for an example.
"""
name = 'integer range'
def __init__(self, min=None, max=None, clamp=False):
self.min = min
self.max = max
self.clamp = clamp
def convert(self, value, param, ctx):
rv = IntParamType.convert(self, value, param, ctx)
if self.clamp:
if self.min is not None and rv < self.min:
return self.min
if self.max is not None and rv > self.max:
return self.max
if self.min is not None and rv < self.min or \
self.max is not None and rv > self.max:
if self.min is None:
self.fail('%s is bigger than the maximum valid value '
'%s.' % (rv, self.max), param, ctx)
elif self.max is None:
self.fail('%s is smaller than the minimum valid value '
'%s.' % (rv, self.min), param, ctx)
else:
self.fail('%s is not in the valid range of %s to %s.'
% (rv, self.min, self.max), param, ctx)
return rv
def __repr__(self):
return 'IntRange(%r, %r)' % (self.min, self.max)
class FloatParamType(ParamType):
name = 'float'
def convert(self, value, param, ctx):
try:
return float(value)
except (UnicodeError, ValueError):
self.fail('%s is not a valid floating point value' %
value, param, ctx)
def __repr__(self):
return 'FLOAT'
class FloatRange(FloatParamType):
"""A parameter that works similar to :data:`click.FLOAT` but restricts
the value to fit into a range. The default behavior is to fail if the
value falls outside the range, but it can also be silently clamped
between the two edges.
See :ref:`ranges` for an example.
"""
name = 'float range'
def __init__(self, min=None, max=None, clamp=False):
self.min = min
self.max = max
self.clamp = clamp
def convert(self, value, param, ctx):
rv = FloatParamType.convert(self, value, param, ctx)
if self.clamp:
if self.min is not None and rv < self.min:
return self.min
if self.max is not None and rv > self.max:
return self.max
if self.min is not None and rv < self.min or \
self.max is not None and rv > self.max:
if self.min is None:
self.fail('%s is bigger than the maximum valid value '
'%s.' % (rv, self.max), param, ctx)
elif self.max is None:
self.fail('%s is smaller than the minimum valid value '
'%s.' % (rv, self.min), param, ctx)
else:
self.fail('%s is not in the valid range of %s to %s.'
% (rv, self.min, self.max), param, ctx)
return rv
def __repr__(self):
return 'FloatRange(%r, %r)' % (self.min, self.max)
class BoolParamType(ParamType):
name = 'boolean'
def convert(self, value, param, ctx):
if isinstance(value, bool):
return bool(value)
value = value.lower()
if value in ('true', 't', '1', 'yes', 'y'):
return True
elif value in ('false', 'f', '0', 'no', 'n'):
return False
self.fail('%s is not a valid boolean' % value, param, ctx)
def __repr__(self):
return 'BOOL'
class UUIDParameterType(ParamType):
name = 'uuid'
def convert(self, value, param, ctx):
import uuid
try:
if PY2 and isinstance(value, text_type):
value = value.encode('ascii')
return uuid.UUID(value)
except (UnicodeError, ValueError):
self.fail('%s is not a valid UUID value' % value, param, ctx)
def __repr__(self):
return 'UUID'
class File(ParamType):
"""Declares a parameter to be a file for reading or writing. The file
is automatically closed once the context tears down (after the command
finished working).
Files can be opened for reading or writing. The special value ``-``
indicates stdin or stdout depending on the mode.
By default, the file is opened for reading text data, but it can also be
opened in binary mode or for writing. The encoding parameter can be used
to force a specific encoding.
The `lazy` flag controls if the file should be opened immediately or upon
first IO. The default is to be non-lazy for standard input and output
streams as well as files opened for reading, `lazy` otherwise. When opening a
file lazily for reading, it is still opened temporarily for validation, but
will not be held open until first IO. lazy is mainly useful when opening
for writing to avoid creating the file until it is needed.
Starting with Click 2.0, files can also be opened atomically in which
case all writes go into a separate file in the same folder and upon
completion the file will be moved over to the original location. This
is useful if a file regularly read by other users is modified.
See :ref:`file-args` for more information.
"""
name = 'filename'
envvar_list_splitter = os.path.pathsep
def __init__(self, mode='r', encoding=None, errors='strict', lazy=None,
atomic=False):
self.mode = mode
self.encoding = encoding
self.errors = errors
self.lazy = lazy
self.atomic = atomic
def resolve_lazy_flag(self, value):
if self.lazy is not None:
return self.lazy
if value == '-':
return False
elif 'w' in self.mode:
return True
return False
def convert(self, value, param, ctx):
try:
if hasattr(value, 'read') or hasattr(value, 'write'):
return value
lazy = self.resolve_lazy_flag(value)
if lazy:
f = LazyFile(value, self.mode, self.encoding, self.errors,
atomic=self.atomic)
if ctx is not None:
ctx.call_on_close(f.close_intelligently)
return f
f, should_close = open_stream(value, self.mode,
self.encoding, self.errors,
atomic=self.atomic)
# If a context is provided, we automatically close the file
# at the end of the context execution (or flush out). If a
# context does not exist, it's the caller's responsibility to
# properly close the file. This for instance happens when the
# type is used with prompts.
if ctx is not None:
if should_close:
ctx.call_on_close(safecall(f.close))
else:
ctx.call_on_close(safecall(f.flush))
return f
except (IOError, OSError) as e:
self.fail('Could not open file: %s: %s' % (
filename_to_ui(value),
get_streerror(e),
), param, ctx)
class Path(ParamType):
"""The path type is similar to the :class:`File` type but it performs
different checks. First of all, instead of returning an open file
handle it returns just the filename. Secondly, it can perform various
basic checks about what the file or directory should be.
.. versionchanged:: 6.0
`allow_dash` was added.
:param exists: if set to true, the file or directory needs to exist for
this value to be valid. If this is not required and a
file does indeed not exist, then all further checks are
silently skipped.
:param file_okay: controls if a file is a possible value.
:param dir_okay: controls if a directory is a possible value.
:param writable: if true, a writable check is performed.
:param readable: if true, a readable check is performed.
:param resolve_path: if this is true, then the path is fully resolved
before the value is passed onwards. This means
that it's absolute and symlinks are resolved. It
will not expand a tilde-prefix, as this is
supposed to be done by the shell only.
:param allow_dash: If this is set to `True`, a single dash to indicate
standard streams is permitted.
:param path_type: optionally a string type that should be used to
represent the path. The default is `None` which
means the return value will be either bytes or
unicode depending on what makes most sense given the
input data Click deals with.
"""
envvar_list_splitter = os.path.pathsep
def __init__(self, exists=False, file_okay=True, dir_okay=True,
writable=False, readable=True, resolve_path=False,
allow_dash=False, path_type=None):
self.exists = exists
self.file_okay = file_okay
self.dir_okay = dir_okay
self.writable = writable
self.readable = readable
self.resolve_path = resolve_path
self.allow_dash = allow_dash
self.type = path_type
if self.file_okay and not self.dir_okay:
self.name = 'file'
self.path_type = 'File'
elif self.dir_okay and not self.file_okay:
self.name = 'directory'
self.path_type = 'Directory'
else:
self.name = 'path'
self.path_type = 'Path'
def coerce_path_result(self, rv):
if self.type is not None and not isinstance(rv, self.type):
if self.type is text_type:
rv = rv.decode(get_filesystem_encoding())
else:
rv = rv.encode(get_filesystem_encoding())
return rv
def convert(self, value, param, ctx):
rv = value
is_dash = self.file_okay and self.allow_dash and rv in (b'-', '-')
if not is_dash:
if self.resolve_path:
rv = os.path.realpath(rv)
try:
st = os.stat(rv)
except OSError:
if not self.exists:
return self.coerce_path_result(rv)
self.fail('%s "%s" does not exist.' % (
self.path_type,
filename_to_ui(value)
), param, ctx)
if not self.file_okay and stat.S_ISREG(st.st_mode):
self.fail('%s "%s" is a file.' % (
self.path_type,
filename_to_ui(value)
), param, ctx)
if not self.dir_okay and stat.S_ISDIR(st.st_mode):
self.fail('%s "%s" is a directory.' % (
self.path_type,
filename_to_ui(value)
), param, ctx)
if self.writable and not os.access(value, os.W_OK):
self.fail('%s "%s" is not writable.' % (
self.path_type,
filename_to_ui(value)
), param, ctx)
if self.readable and not os.access(value, os.R_OK):
self.fail('%s "%s" is not readable.' % (
self.path_type,
filename_to_ui(value)
), param, ctx)
return self.coerce_path_result(rv)
class Tuple(CompositeParamType):
"""The default behavior of Click is to apply a type on a value directly.
This works well in most cases, except for when `nargs` is set to a fixed
count and different types should be used for different items. In this
case the :class:`Tuple` type can be used. This type can only be used
if `nargs` is set to a fixed number.
For more information see :ref:`tuple-type`.
This can be selected by using a Python tuple literal as a type.
:param types: a list of types that should be used for the tuple items.
"""
def __init__(self, types):
self.types = [convert_type(ty) for ty in types]
@property
def name(self):
return "<" + " ".join(ty.name for ty in self.types) + ">"
@property
def arity(self):
return len(self.types)
def convert(self, value, param, ctx):
if len(value) != len(self.types):
raise TypeError('It would appear that nargs is set to conflict '
'with the composite type arity.')
return tuple(ty(x, param, ctx) for ty, x in zip(self.types, value))
def convert_type(ty, default=None):
"""Converts a callable or python ty into the most appropriate param
ty.
"""
guessed_type = False
if ty is None and default is not None:
if isinstance(default, tuple):
ty = tuple(map(type, default))
else:
ty = type(default)
guessed_type = True
if isinstance(ty, tuple):
return Tuple(ty)
if isinstance(ty, ParamType):
return ty
if ty is text_type or ty is str or ty is None:
return STRING
if ty is int:
return INT
# Booleans are only okay if not guessed. This is done because for
# flags the default value is actually a bit of a lie in that it
# indicates which of the flags is the one we want. See get_default()
# for more information.
if ty is bool and not guessed_type:
return BOOL
if ty is float:
return FLOAT
if guessed_type:
return STRING
# Catch a common mistake
if __debug__:
try:
if issubclass(ty, ParamType):
raise AssertionError('Attempted to use an uninstantiated '
'parameter type (%s).' % ty)
except TypeError:
pass
return FuncParamType(ty)
#: A dummy parameter type that just does nothing. From a user's
#: perspective this appears to just be the same as `STRING` but internally
#: no string conversion takes place. This is necessary to achieve the
#: same bytes/unicode behavior on Python 2/3 in situations where you want
#: to not convert argument types. This is usually useful when working
#: with file paths as they can appear in bytes and unicode.
#:
#: For path related uses the :class:`Path` type is a better choice but
#: there are situations where an unprocessed type is useful which is why
#: it is is provided.
#:
#: .. versionadded:: 4.0
UNPROCESSED = UnprocessedParamType()
#: A unicode string parameter type which is the implicit default. This
#: can also be selected by using ``str`` as type.
STRING = StringParamType()
#: An integer parameter. This can also be selected by using ``int`` as
#: type.
INT = IntParamType()
#: A floating point value parameter. This can also be selected by using
#: ``float`` as type.
FLOAT = FloatParamType()
#: A boolean parameter. This is the default for boolean flags. This can
#: also be selected by using ``bool`` as a type.
BOOL = BoolParamType()
#: A UUID parameter.
UUID = UUIDParameterType()

440
libs/click/utils.py Normal file
View file

@ -0,0 +1,440 @@
import os
import sys
from .globals import resolve_color_default
from ._compat import text_type, open_stream, get_filesystem_encoding, \
get_streerror, string_types, PY2, binary_streams, text_streams, \
filename_to_ui, auto_wrap_for_ansi, strip_ansi, should_strip_ansi, \
_default_text_stdout, _default_text_stderr, is_bytes, WIN
if not PY2:
from ._compat import _find_binary_writer
elif WIN:
from ._winconsole import _get_windows_argv, \
_hash_py_argv, _initial_argv_hash
echo_native_types = string_types + (bytes, bytearray)
def _posixify(name):
return '-'.join(name.split()).lower()
def safecall(func):
"""Wraps a function so that it swallows exceptions."""
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception:
pass
return wrapper
def make_str(value):
"""Converts a value into a valid string."""
if isinstance(value, bytes):
try:
return value.decode(get_filesystem_encoding())
except UnicodeError:
return value.decode('utf-8', 'replace')
return text_type(value)
def make_default_short_help(help, max_length=45):
"""Return a condensed version of help string."""
words = help.split()
total_length = 0
result = []
done = False
for word in words:
if word[-1:] == '.':
done = True
new_length = result and 1 + len(word) or len(word)
if total_length + new_length > max_length:
result.append('...')
done = True
else:
if result:
result.append(' ')
result.append(word)
if done:
break
total_length += new_length
return ''.join(result)
class LazyFile(object):
"""A lazy file works like a regular file but it does not fully open
the file but it does perform some basic checks early to see if the
filename parameter does make sense. This is useful for safely opening
files for writing.
"""
def __init__(self, filename, mode='r', encoding=None, errors='strict',
atomic=False):
self.name = filename
self.mode = mode
self.encoding = encoding
self.errors = errors
self.atomic = atomic
if filename == '-':
self._f, self.should_close = open_stream(filename, mode,
encoding, errors)
else:
if 'r' in mode:
# Open and close the file in case we're opening it for
# reading so that we can catch at least some errors in
# some cases early.
open(filename, mode).close()
self._f = None
self.should_close = True
def __getattr__(self, name):
return getattr(self.open(), name)
def __repr__(self):
if self._f is not None:
return repr(self._f)
return '<unopened file %r %s>' % (self.name, self.mode)
def open(self):
"""Opens the file if it's not yet open. This call might fail with
a :exc:`FileError`. Not handling this error will produce an error
that Click shows.
"""
if self._f is not None:
return self._f
try:
rv, self.should_close = open_stream(self.name, self.mode,
self.encoding,
self.errors,
atomic=self.atomic)
except (IOError, OSError) as e:
from .exceptions import FileError
raise FileError(self.name, hint=get_streerror(e))
self._f = rv
return rv
def close(self):
"""Closes the underlying file, no matter what."""
if self._f is not None:
self._f.close()
def close_intelligently(self):
"""This function only closes the file if it was opened by the lazy
file wrapper. For instance this will never close stdin.
"""
if self.should_close:
self.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, tb):
self.close_intelligently()
def __iter__(self):
self.open()
return iter(self._f)
class KeepOpenFile(object):
def __init__(self, file):
self._file = file
def __getattr__(self, name):
return getattr(self._file, name)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, tb):
pass
def __repr__(self):
return repr(self._file)
def __iter__(self):
return iter(self._file)
def echo(message=None, file=None, nl=True, err=False, color=None):
"""Prints a message plus a newline to the given file or stdout. On
first sight, this looks like the print function, but it has improved
support for handling Unicode and binary data that does not fail no
matter how badly configured the system is.
Primarily it means that you can print binary data as well as Unicode
data on both 2.x and 3.x to the given file in the most appropriate way
possible. This is a very carefree function in that it will try its
best to not fail. As of Click 6.0 this includes support for unicode
output on the Windows console.
In addition to that, if `colorama`_ is installed, the echo function will
also support clever handling of ANSI codes. Essentially it will then
do the following:
- add transparent handling of ANSI color codes on Windows.
- hide ANSI codes automatically if the destination file is not a
terminal.
.. _colorama: https://pypi.org/project/colorama/
.. versionchanged:: 6.0
As of Click 6.0 the echo function will properly support unicode
output on the windows console. Not that click does not modify
the interpreter in any way which means that `sys.stdout` or the
print statement or function will still not provide unicode support.
.. versionchanged:: 2.0
Starting with version 2.0 of Click, the echo function will work
with colorama if it's installed.
.. versionadded:: 3.0
The `err` parameter was added.
.. versionchanged:: 4.0
Added the `color` flag.
:param message: the message to print
:param file: the file to write to (defaults to ``stdout``)
:param err: if set to true the file defaults to ``stderr`` instead of
``stdout``. This is faster and easier than calling
:func:`get_text_stderr` yourself.
:param nl: if set to `True` (the default) a newline is printed afterwards.
:param color: controls if the terminal supports ANSI colors or not. The
default is autodetection.
"""
if file is None:
if err:
file = _default_text_stderr()
else:
file = _default_text_stdout()
# Convert non bytes/text into the native string type.
if message is not None and not isinstance(message, echo_native_types):
message = text_type(message)
if nl:
message = message or u''
if isinstance(message, text_type):
message += u'\n'
else:
message += b'\n'
# If there is a message, and we're in Python 3, and the value looks
# like bytes, we manually need to find the binary stream and write the
# message in there. This is done separately so that most stream
# types will work as you would expect. Eg: you can write to StringIO
# for other cases.
if message and not PY2 and is_bytes(message):
binary_file = _find_binary_writer(file)
if binary_file is not None:
file.flush()
binary_file.write(message)
binary_file.flush()
return
# ANSI-style support. If there is no message or we are dealing with
# bytes nothing is happening. If we are connected to a file we want
# to strip colors. If we are on windows we either wrap the stream
# to strip the color or we use the colorama support to translate the
# ansi codes to API calls.
if message and not is_bytes(message):
color = resolve_color_default(color)
if should_strip_ansi(file, color):
message = strip_ansi(message)
elif WIN:
if auto_wrap_for_ansi is not None:
file = auto_wrap_for_ansi(file)
elif not color:
message = strip_ansi(message)
if message:
file.write(message)
file.flush()
def get_binary_stream(name):
"""Returns a system stream for byte processing. This essentially
returns the stream from the sys module with the given name but it
solves some compatibility issues between different Python versions.
Primarily this function is necessary for getting binary streams on
Python 3.
:param name: the name of the stream to open. Valid names are ``'stdin'``,
``'stdout'`` and ``'stderr'``
"""
opener = binary_streams.get(name)
if opener is None:
raise TypeError('Unknown standard stream %r' % name)
return opener()
def get_text_stream(name, encoding=None, errors='strict'):
"""Returns a system stream for text processing. This usually returns
a wrapped stream around a binary stream returned from
:func:`get_binary_stream` but it also can take shortcuts on Python 3
for already correctly configured streams.
:param name: the name of the stream to open. Valid names are ``'stdin'``,
``'stdout'`` and ``'stderr'``
:param encoding: overrides the detected default encoding.
:param errors: overrides the default error mode.
"""
opener = text_streams.get(name)
if opener is None:
raise TypeError('Unknown standard stream %r' % name)
return opener(encoding, errors)
def open_file(filename, mode='r', encoding=None, errors='strict',
lazy=False, atomic=False):
"""This is similar to how the :class:`File` works but for manual
usage. Files are opened non lazy by default. This can open regular
files as well as stdin/stdout if ``'-'`` is passed.
If stdin/stdout is returned the stream is wrapped so that the context
manager will not close the stream accidentally. This makes it possible
to always use the function like this without having to worry to
accidentally close a standard stream::
with open_file(filename) as f:
...
.. versionadded:: 3.0
:param filename: the name of the file to open (or ``'-'`` for stdin/stdout).
:param mode: the mode in which to open the file.
:param encoding: the encoding to use.
:param errors: the error handling for this file.
:param lazy: can be flipped to true to open the file lazily.
:param atomic: in atomic mode writes go into a temporary file and it's
moved on close.
"""
if lazy:
return LazyFile(filename, mode, encoding, errors, atomic=atomic)
f, should_close = open_stream(filename, mode, encoding, errors,
atomic=atomic)
if not should_close:
f = KeepOpenFile(f)
return f
def get_os_args():
"""This returns the argument part of sys.argv in the most appropriate
form for processing. What this means is that this return value is in
a format that works for Click to process but does not necessarily
correspond well to what's actually standard for the interpreter.
On most environments the return value is ``sys.argv[:1]`` unchanged.
However if you are on Windows and running Python 2 the return value
will actually be a list of unicode strings instead because the
default behavior on that platform otherwise will not be able to
carry all possible values that sys.argv can have.
.. versionadded:: 6.0
"""
# We can only extract the unicode argv if sys.argv has not been
# changed since the startup of the application.
if PY2 and WIN and _initial_argv_hash == _hash_py_argv():
return _get_windows_argv()
return sys.argv[1:]
def format_filename(filename, shorten=False):
"""Formats a filename for user display. The main purpose of this
function is to ensure that the filename can be displayed at all. This
will decode the filename to unicode if necessary in a way that it will
not fail. Optionally, it can shorten the filename to not include the
full path to the filename.
:param filename: formats a filename for UI display. This will also convert
the filename into unicode without failing.
:param shorten: this optionally shortens the filename to strip of the
path that leads up to it.
"""
if shorten:
filename = os.path.basename(filename)
return filename_to_ui(filename)
def get_app_dir(app_name, roaming=True, force_posix=False):
r"""Returns the config folder for the application. The default behavior
is to return whatever is most appropriate for the operating system.
To give you an idea, for an app called ``"Foo Bar"``, something like
the following folders could be returned:
Mac OS X:
``~/Library/Application Support/Foo Bar``
Mac OS X (POSIX):
``~/.foo-bar``
Unix:
``~/.config/foo-bar``
Unix (POSIX):
``~/.foo-bar``
Win XP (roaming):
``C:\Documents and Settings\<user>\Local Settings\Application Data\Foo Bar``
Win XP (not roaming):
``C:\Documents and Settings\<user>\Application Data\Foo Bar``
Win 7 (roaming):
``C:\Users\<user>\AppData\Roaming\Foo Bar``
Win 7 (not roaming):
``C:\Users\<user>\AppData\Local\Foo Bar``
.. versionadded:: 2.0
:param app_name: the application name. This should be properly capitalized
and can contain whitespace.
:param roaming: controls if the folder should be roaming or not on Windows.
Has no affect otherwise.
:param force_posix: if this is set to `True` then on any POSIX system the
folder will be stored in the home folder with a leading
dot instead of the XDG config home or darwin's
application support folder.
"""
if WIN:
key = roaming and 'APPDATA' or 'LOCALAPPDATA'
folder = os.environ.get(key)
if folder is None:
folder = os.path.expanduser('~')
return os.path.join(folder, app_name)
if force_posix:
return os.path.join(os.path.expanduser('~/.' + _posixify(app_name)))
if sys.platform == 'darwin':
return os.path.join(os.path.expanduser(
'~/Library/Application Support'), app_name)
return os.path.join(
os.environ.get('XDG_CONFIG_HOME', os.path.expanduser('~/.config')),
_posixify(app_name))
class PacifyFlushWrapper(object):
"""This wrapper is used to catch and suppress BrokenPipeErrors resulting
from ``.flush()`` being called on broken pipe during the shutdown/final-GC
of the Python interpreter. Notably ``.flush()`` is always called on
``sys.stdout`` and ``sys.stderr``. So as to have minimal impact on any
other cleanup code, and the case where the underlying file is not a broken
pipe, all calls and attributes are proxied.
"""
def __init__(self, wrapped):
self.wrapped = wrapped
def flush(self):
try:
self.wrapped.flush()
except IOError as e:
import errno
if e.errno != errno.EPIPE:
raise
def __getattr__(self, attr):
return getattr(self.wrapped, attr)

432
libs/decorator.py Normal file
View file

@ -0,0 +1,432 @@
# ######################### LICENSE ############################ #
# Copyright (c) 2005-2018, Michele Simionato
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# Redistributions in bytecode form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in
# the documentation and/or other materials provided with the
# distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
"""
Decorator module, see http://pypi.python.org/pypi/decorator
for the documentation.
"""
from __future__ import print_function
import re
import sys
import inspect
import operator
import itertools
import collections
__version__ = '4.3.0'
if sys.version >= '3':
from inspect import getfullargspec
def get_init(cls):
return cls.__init__
else:
FullArgSpec = collections.namedtuple(
'FullArgSpec', 'args varargs varkw defaults '
'kwonlyargs kwonlydefaults annotations')
def getfullargspec(f):
"A quick and dirty replacement for getfullargspec for Python 2.X"
return FullArgSpec._make(inspect.getargspec(f) + ([], None, {}))
def get_init(cls):
return cls.__init__.__func__
try:
iscoroutinefunction = inspect.iscoroutinefunction
except AttributeError:
# let's assume there are no coroutine functions in old Python
def iscoroutinefunction(f):
return False
DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(')
# basic functionality
class FunctionMaker(object):
"""
An object with the ability to create functions with a given signature.
It has attributes name, doc, module, signature, defaults, dict and
methods update and make.
"""
# Atomic get-and-increment provided by the GIL
_compile_count = itertools.count()
# make pylint happy
args = varargs = varkw = defaults = kwonlyargs = kwonlydefaults = ()
def __init__(self, func=None, name=None, signature=None,
defaults=None, doc=None, module=None, funcdict=None):
self.shortsignature = signature
if func:
# func can be a class or a callable, but not an instance method
self.name = func.__name__
if self.name == '<lambda>': # small hack for lambda functions
self.name = '_lambda_'
self.doc = func.__doc__
self.module = func.__module__
if inspect.isfunction(func):
argspec = getfullargspec(func)
self.annotations = getattr(func, '__annotations__', {})
for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
'kwonlydefaults'):
setattr(self, a, getattr(argspec, a))
for i, arg in enumerate(self.args):
setattr(self, 'arg%d' % i, arg)
allargs = list(self.args)
allshortargs = list(self.args)
if self.varargs:
allargs.append('*' + self.varargs)
allshortargs.append('*' + self.varargs)
elif self.kwonlyargs:
allargs.append('*') # single star syntax
for a in self.kwonlyargs:
allargs.append('%s=None' % a)
allshortargs.append('%s=%s' % (a, a))
if self.varkw:
allargs.append('**' + self.varkw)
allshortargs.append('**' + self.varkw)
self.signature = ', '.join(allargs)
self.shortsignature = ', '.join(allshortargs)
self.dict = func.__dict__.copy()
# func=None happens when decorating a caller
if name:
self.name = name
if signature is not None:
self.signature = signature
if defaults:
self.defaults = defaults
if doc:
self.doc = doc
if module:
self.module = module
if funcdict:
self.dict = funcdict
# check existence required attributes
assert hasattr(self, 'name')
if not hasattr(self, 'signature'):
raise TypeError('You are decorating a non function: %s' % func)
def update(self, func, **kw):
"Update the signature of func with the data in self"
func.__name__ = self.name
func.__doc__ = getattr(self, 'doc', None)
func.__dict__ = getattr(self, 'dict', {})
func.__defaults__ = self.defaults
func.__kwdefaults__ = self.kwonlydefaults or None
func.__annotations__ = getattr(self, 'annotations', None)
try:
frame = sys._getframe(3)
except AttributeError: # for IronPython and similar implementations
callermodule = '?'
else:
callermodule = frame.f_globals.get('__name__', '?')
func.__module__ = getattr(self, 'module', callermodule)
func.__dict__.update(kw)
def make(self, src_templ, evaldict=None, addsource=False, **attrs):
"Make a new function from a given template and update the signature"
src = src_templ % vars(self) # expand name and signature
evaldict = evaldict or {}
mo = DEF.search(src)
if mo is None:
raise SyntaxError('not a valid function template\n%s' % src)
name = mo.group(1) # extract the function name
names = set([name] + [arg.strip(' *') for arg in
self.shortsignature.split(',')])
for n in names:
if n in ('_func_', '_call_'):
raise NameError('%s is overridden in\n%s' % (n, src))
if not src.endswith('\n'): # add a newline for old Pythons
src += '\n'
# Ensure each generated function has a unique filename for profilers
# (such as cProfile) that depend on the tuple of (<filename>,
# <definition line>, <function name>) being unique.
filename = '<decorator-gen-%d>' % (next(self._compile_count),)
try:
code = compile(src, filename, 'single')
exec(code, evaldict)
except Exception:
print('Error in generated code:', file=sys.stderr)
print(src, file=sys.stderr)
raise
func = evaldict[name]
if addsource:
attrs['__source__'] = src
self.update(func, **attrs)
return func
@classmethod
def create(cls, obj, body, evaldict, defaults=None,
doc=None, module=None, addsource=True, **attrs):
"""
Create a function from the strings name, signature and body.
evaldict is the evaluation dictionary. If addsource is true an
attribute __source__ is added to the result. The attributes attrs
are added, if any.
"""
if isinstance(obj, str): # "name(signature)"
name, rest = obj.strip().split('(', 1)
signature = rest[:-1] # strip a right parens
func = None
else: # a function
name = None
signature = None
func = obj
self = cls(func, name, signature, defaults, doc, module)
ibody = '\n'.join(' ' + line for line in body.splitlines())
caller = evaldict.get('_call_') # when called from `decorate`
if caller and iscoroutinefunction(caller):
body = ('async def %(name)s(%(signature)s):\n' + ibody).replace(
'return', 'return await')
else:
body = 'def %(name)s(%(signature)s):\n' + ibody
return self.make(body, evaldict, addsource, **attrs)
def decorate(func, caller, extras=()):
"""
decorate(func, caller) decorates a function using a caller.
"""
evaldict = dict(_call_=caller, _func_=func)
es = ''
for i, extra in enumerate(extras):
ex = '_e%d_' % i
evaldict[ex] = extra
es += ex + ', '
fun = FunctionMaker.create(
func, "return _call_(_func_, %s%%(shortsignature)s)" % es,
evaldict, __wrapped__=func)
if hasattr(func, '__qualname__'):
fun.__qualname__ = func.__qualname__
return fun
def decorator(caller, _func=None):
"""decorator(caller) converts a caller function into a decorator"""
if _func is not None: # return a decorated function
# this is obsolete behavior; you should use decorate instead
return decorate(_func, caller)
# else return a decorator function
defaultargs, defaults = '', ()
if inspect.isclass(caller):
name = caller.__name__.lower()
doc = 'decorator(%s) converts functions/generators into ' \
'factories of %s objects' % (caller.__name__, caller.__name__)
elif inspect.isfunction(caller):
if caller.__name__ == '<lambda>':
name = '_lambda_'
else:
name = caller.__name__
doc = caller.__doc__
nargs = caller.__code__.co_argcount
ndefs = len(caller.__defaults__ or ())
defaultargs = ', '.join(caller.__code__.co_varnames[nargs-ndefs:nargs])
if defaultargs:
defaultargs += ','
defaults = caller.__defaults__
else: # assume caller is an object with a __call__ method
name = caller.__class__.__name__.lower()
doc = caller.__call__.__doc__
evaldict = dict(_call=caller, _decorate_=decorate)
dec = FunctionMaker.create(
'%s(%s func)' % (name, defaultargs),
'if func is None: return lambda func: _decorate_(func, _call, (%s))\n'
'return _decorate_(func, _call, (%s))' % (defaultargs, defaultargs),
evaldict, doc=doc, module=caller.__module__, __wrapped__=caller)
if defaults:
dec.__defaults__ = defaults + (None,)
return dec
# ####################### contextmanager ####################### #
try: # Python >= 3.2
from contextlib import _GeneratorContextManager
except ImportError: # Python >= 2.5
from contextlib import GeneratorContextManager as _GeneratorContextManager
class ContextManager(_GeneratorContextManager):
def __call__(self, func):
"""Context manager decorator"""
return FunctionMaker.create(
func, "with _self_: return _func_(%(shortsignature)s)",
dict(_self_=self, _func_=func), __wrapped__=func)
init = getfullargspec(_GeneratorContextManager.__init__)
n_args = len(init.args)
if n_args == 2 and not init.varargs: # (self, genobj) Python 2.7
def __init__(self, g, *a, **k):
return _GeneratorContextManager.__init__(self, g(*a, **k))
ContextManager.__init__ = __init__
elif n_args == 2 and init.varargs: # (self, gen, *a, **k) Python 3.4
pass
elif n_args == 4: # (self, gen, args, kwds) Python 3.5
def __init__(self, g, *a, **k):
return _GeneratorContextManager.__init__(self, g, a, k)
ContextManager.__init__ = __init__
_contextmanager = decorator(ContextManager)
def contextmanager(func):
# Enable Pylint config: contextmanager-decorators=decorator.contextmanager
return _contextmanager(func)
# ############################ dispatch_on ############################ #
def append(a, vancestors):
"""
Append ``a`` to the list of the virtual ancestors, unless it is already
included.
"""
add = True
for j, va in enumerate(vancestors):
if issubclass(va, a):
add = False
break
if issubclass(a, va):
vancestors[j] = a
add = False
if add:
vancestors.append(a)
# inspired from simplegeneric by P.J. Eby and functools.singledispatch
def dispatch_on(*dispatch_args):
"""
Factory of decorators turning a function into a generic function
dispatching on the given arguments.
"""
assert dispatch_args, 'No dispatch args passed'
dispatch_str = '(%s,)' % ', '.join(dispatch_args)
def check(arguments, wrong=operator.ne, msg=''):
"""Make sure one passes the expected number of arguments"""
if wrong(len(arguments), len(dispatch_args)):
raise TypeError('Expected %d arguments, got %d%s' %
(len(dispatch_args), len(arguments), msg))
def gen_func_dec(func):
"""Decorator turning a function into a generic function"""
# first check the dispatch arguments
argset = set(getfullargspec(func).args)
if not set(dispatch_args) <= argset:
raise NameError('Unknown dispatch arguments %s' % dispatch_str)
typemap = {}
def vancestors(*types):
"""
Get a list of sets of virtual ancestors for the given types
"""
check(types)
ras = [[] for _ in range(len(dispatch_args))]
for types_ in typemap:
for t, type_, ra in zip(types, types_, ras):
if issubclass(t, type_) and type_ not in t.mro():
append(type_, ra)
return [set(ra) for ra in ras]
def ancestors(*types):
"""
Get a list of virtual MROs, one for each type
"""
check(types)
lists = []
for t, vas in zip(types, vancestors(*types)):
n_vas = len(vas)
if n_vas > 1:
raise RuntimeError(
'Ambiguous dispatch for %s: %s' % (t, vas))
elif n_vas == 1:
va, = vas
mro = type('t', (t, va), {}).mro()[1:]
else:
mro = t.mro()
lists.append(mro[:-1]) # discard t and object
return lists
def register(*types):
"""
Decorator to register an implementation for the given types
"""
check(types)
def dec(f):
check(getfullargspec(f).args, operator.lt, ' in ' + f.__name__)
typemap[types] = f
return f
return dec
def dispatch_info(*types):
"""
An utility to introspect the dispatch algorithm
"""
check(types)
lst = []
for anc in itertools.product(*ancestors(*types)):
lst.append(tuple(a.__name__ for a in anc))
return lst
def _dispatch(dispatch_args, *args, **kw):
types = tuple(type(arg) for arg in dispatch_args)
try: # fast path
f = typemap[types]
except KeyError:
pass
else:
return f(*args, **kw)
combinations = itertools.product(*ancestors(*types))
next(combinations) # the first one has been already tried
for types_ in combinations:
f = typemap.get(types_)
if f is not None:
return f(*args, **kw)
# else call the default implementation
return func(*args, **kw)
return FunctionMaker.create(
func, 'return _f_(%s, %%(shortsignature)s)' % dispatch_str,
dict(_f_=_dispatch), register=register, default=func,
typemap=typemap, vancestors=vancestors, ancestors=ancestors,
dispatch_info=dispatch_info, __wrapped__=func)
gen_func_dec.__name__ = 'dispatch_on' + dispatch_str
return gen_func_dec

View file

@ -1,6 +1,4 @@
# See http://peak.telecommunity.com/DevCenter/setuptools#namespace-packages
try:
__import__('pkg_resources').declare_namespace(__name__)
except ImportError:
from pkgutil import extend_path
__path__ = extend_path(__path__, __name__)
__version__ = '0.7.1'
from .lock import Lock # noqa
from .lock import NeedRegenerationException # noqa

View file

@ -1,3 +1,4 @@
__version__ = '0.5.4'
from .region import CacheRegion, register_backend, make_region # noqa
from .region import CacheRegion, register_backend, make_region
# backwards compat
from .. import __version__ # noqa

View file

@ -1,5 +1,5 @@
import operator
from .compat import py3k
from ..util.compat import py3k
class NoValue(object):
@ -13,17 +13,26 @@ class NoValue(object):
def payload(self):
return self
def __repr__(self):
"""Ensure __repr__ is a consistent value in case NoValue is used to
fill another cache key.
"""
return '<dogpile.cache.api.NoValue object>'
if py3k:
def __bool__(self): #pragma NO COVERAGE
def __bool__(self): # pragma NO COVERAGE
return False
else:
def __nonzero__(self): #pragma NO COVERAGE
def __nonzero__(self): # pragma NO COVERAGE
return False
NO_VALUE = NoValue()
"""Value returned from ``get()`` that describes
a key not present."""
class CachedValue(tuple):
"""Represent a value stored in the cache.
@ -47,6 +56,7 @@ class CachedValue(tuple):
def __reduce__(self):
return CachedValue, (self.payload, self.metadata)
class CacheBackend(object):
"""Base class for backend implementations."""
@ -58,7 +68,7 @@ class CacheBackend(object):
"""
def __init__(self, arguments): #pragma NO COVERAGE
def __init__(self, arguments): # pragma NO COVERAGE
"""Construct a new :class:`.CacheBackend`.
Subclasses should override this to
@ -74,12 +84,15 @@ class CacheBackend(object):
def from_config_dict(cls, config_dict, prefix):
prefix_len = len(prefix)
return cls(
dict(
(key[prefix_len:], config_dict[key])
for key in config_dict
if key.startswith(prefix)
)
dict(
(key[prefix_len:], config_dict[key])
for key in config_dict
if key.startswith(prefix)
)
)
def has_lock_timeout(self):
return False
def get_mutex(self, key):
"""Return an optional mutexing object for the given key.
@ -114,7 +127,7 @@ class CacheBackend(object):
"""
return None
def get(self, key): #pragma NO COVERAGE
def get(self, key): # pragma NO COVERAGE
"""Retrieve a value from the cache.
The returned value should be an instance of
@ -124,7 +137,7 @@ class CacheBackend(object):
"""
raise NotImplementedError()
def get_multi(self, keys): #pragma NO COVERAGE
def get_multi(self, keys): # pragma NO COVERAGE
"""Retrieve multiple values from the cache.
The returned value should be a list, corresponding
@ -135,7 +148,7 @@ class CacheBackend(object):
"""
raise NotImplementedError()
def set(self, key, value): #pragma NO COVERAGE
def set(self, key, value): # pragma NO COVERAGE
"""Set a value in the cache.
The key will be whatever was passed
@ -147,21 +160,30 @@ class CacheBackend(object):
"""
raise NotImplementedError()
def set_multi(self, mapping): #pragma NO COVERAGE
def set_multi(self, mapping): # pragma NO COVERAGE
"""Set multiple values in the cache.
The key will be whatever was passed
``mapping`` is a dict in which
the key will be whatever was passed
to the registry, processed by the
"key mangling" function, if any.
The value will always be an instance
of :class:`.CachedValue`.
When implementing a new :class:`.CacheBackend` or cutomizing via
:class:`.ProxyBackend`, be aware that when this method is invoked by
:meth:`.Region.get_or_create_multi`, the ``mapping`` values are the
same ones returned to the upstream caller. If the subclass alters the
values in any way, it must not do so 'in-place' on the ``mapping`` dict
-- that will have the undesirable effect of modifying the returned
values as well.
.. versionadded:: 0.5.0
"""
raise NotImplementedError()
def delete(self, key): #pragma NO COVERAGE
def delete(self, key): # pragma NO COVERAGE
"""Delete a value from the cache.
The key will be whatever was passed
@ -175,7 +197,7 @@ class CacheBackend(object):
"""
raise NotImplementedError()
def delete_multi(self, keys): #pragma NO COVERAGE
def delete_multi(self, keys): # pragma NO COVERAGE
"""Delete multiple values from the cache.
The key will be whatever was passed

View file

@ -1,10 +1,22 @@
from dogpile.cache.region import register_backend
register_backend("dogpile.cache.null", "dogpile.cache.backends.null", "NullBackend")
register_backend("dogpile.cache.dbm", "dogpile.cache.backends.file", "DBMBackend")
register_backend("dogpile.cache.pylibmc", "dogpile.cache.backends.memcached", "PylibmcBackend")
register_backend("dogpile.cache.bmemcached", "dogpile.cache.backends.memcached", "BMemcachedBackend")
register_backend("dogpile.cache.memcached", "dogpile.cache.backends.memcached", "MemcachedBackend")
register_backend("dogpile.cache.memory", "dogpile.cache.backends.memory", "MemoryBackend")
register_backend("dogpile.cache.memory_pickle", "dogpile.cache.backends.memory", "MemoryPickleBackend")
register_backend("dogpile.cache.redis", "dogpile.cache.backends.redis", "RedisBackend")
register_backend(
"dogpile.cache.null", "dogpile.cache.backends.null", "NullBackend")
register_backend(
"dogpile.cache.dbm", "dogpile.cache.backends.file", "DBMBackend")
register_backend(
"dogpile.cache.pylibmc", "dogpile.cache.backends.memcached",
"PylibmcBackend")
register_backend(
"dogpile.cache.bmemcached", "dogpile.cache.backends.memcached",
"BMemcachedBackend")
register_backend(
"dogpile.cache.memcached", "dogpile.cache.backends.memcached",
"MemcachedBackend")
register_backend(
"dogpile.cache.memory", "dogpile.cache.backends.memory", "MemoryBackend")
register_backend(
"dogpile.cache.memory_pickle", "dogpile.cache.backends.memory",
"MemoryPickleBackend")
register_backend(
"dogpile.cache.redis", "dogpile.cache.backends.redis", "RedisBackend")

View file

@ -7,14 +7,15 @@ Provides backends that deal with local filesystem access.
"""
from __future__ import with_statement
from dogpile.cache.api import CacheBackend, NO_VALUE
from ..api import CacheBackend, NO_VALUE
from contextlib import contextmanager
from dogpile.cache import compat
from dogpile.cache import util
from ...util import compat
from ... import util
import os
__all__ = 'DBMBackend', 'FileLock', 'AbstractFileLock'
class DBMBackend(CacheBackend):
"""A file-backend using a dbm file to store keys.
@ -135,19 +136,19 @@ class DBMBackend(CacheBackend):
"""
def __init__(self, arguments):
self.filename = os.path.abspath(
os.path.normpath(arguments['filename'])
)
os.path.normpath(arguments['filename'])
)
dir_, filename = os.path.split(self.filename)
self.lock_factory = arguments.get("lock_factory", FileLock)
self._rw_lock = self._init_lock(
arguments.get('rw_lockfile'),
".rw.lock", dir_, filename)
arguments.get('rw_lockfile'),
".rw.lock", dir_, filename)
self._dogpile_lock = self._init_lock(
arguments.get('dogpile_lockfile'),
".dogpile.lock",
dir_, filename,
util.KeyReentrantMutex.factory)
arguments.get('dogpile_lockfile'),
".dogpile.lock",
dir_, filename,
util.KeyReentrantMutex.factory)
# TODO: make this configurable
if compat.py3k:
@ -162,9 +163,9 @@ class DBMBackend(CacheBackend):
lock = self.lock_factory(os.path.join(basedir, basefile + suffix))
elif argument is not False:
lock = self.lock_factory(
os.path.abspath(
os.path.normpath(argument)
))
os.path.abspath(
os.path.normpath(argument)
))
else:
return None
if wrapper:
@ -209,8 +210,9 @@ class DBMBackend(CacheBackend):
@contextmanager
def _dbm_file(self, write):
with self._use_rw_lock(write):
dbm = self.dbmmodule.open(self.filename,
"w" if write else "r")
dbm = self.dbmmodule.open(
self.filename,
"w" if write else "r")
yield dbm
dbm.close()
@ -233,12 +235,14 @@ class DBMBackend(CacheBackend):
def set(self, key, value):
with self._dbm_file(True) as dbm:
dbm[key] = compat.pickle.dumps(value)
dbm[key] = compat.pickle.dumps(value,
compat.pickle.HIGHEST_PROTOCOL)
def set_multi(self, mapping):
with self._dbm_file(True) as dbm:
for key,value in mapping.items():
dbm[key] = compat.pickle.dumps(value)
for key, value in mapping.items():
dbm[key] = compat.pickle.dumps(value,
compat.pickle.HIGHEST_PROTOCOL)
def delete(self, key):
with self._dbm_file(True) as dbm:
@ -255,6 +259,7 @@ class DBMBackend(CacheBackend):
except KeyError:
pass
class AbstractFileLock(object):
"""Coordinate read/write access to a file.
@ -275,10 +280,10 @@ class AbstractFileLock(object):
file.
Note that multithreaded environments must provide a thread-safe
version of this lock. The recommended approach for file-descriptor-based
locks is to use a Python ``threading.local()`` so that a unique file descriptor
is held per thread. See the source code of :class:`.FileLock` for an
implementation example.
version of this lock. The recommended approach for file-
descriptor-based locks is to use a Python ``threading.local()`` so
that a unique file descriptor is held per thread. See the source
code of :class:`.FileLock` for an implementation example.
"""
@ -377,6 +382,7 @@ class AbstractFileLock(object):
"""
raise NotImplementedError()
class FileLock(AbstractFileLock):
"""Use lockfiles to coordinate read/write access to a file.

View file

@ -6,15 +6,16 @@ Provides backends for talking to `memcached <http://memcached.org>`_.
"""
from dogpile.cache.api import CacheBackend, NO_VALUE
from dogpile.cache import compat
from dogpile.cache import util
from ..api import CacheBackend, NO_VALUE
from ...util import compat
from ... import util
import random
import time
__all__ = 'GenericMemcachedBackend', 'MemcachedBackend',\
'PylibmcBackend', 'BMemcachedBackend', 'MemcachedLock'
class MemcachedLock(object):
"""Simple distributed lock using memcached.
@ -23,20 +24,21 @@ class MemcachedLock(object):
"""
def __init__(self, client_fn, key):
def __init__(self, client_fn, key, timeout=0):
self.client_fn = client_fn
self.key = "_lock" + key
self.timeout = timeout
def acquire(self, wait=True):
client = self.client_fn()
i = 0
while True:
if client.add(self.key, 1):
if client.add(self.key, 1, self.timeout):
return True
elif not wait:
return False
else:
sleep_time = (((i+1)*random.random()) + 2**i) / 2.5
sleep_time = (((i + 1) * random.random()) + 2 ** i) / 2.5
time.sleep(sleep_time)
if i < 15:
i += 1
@ -45,6 +47,7 @@ class MemcachedLock(object):
client = self.client_fn()
client.delete(self.key)
class GenericMemcachedBackend(CacheBackend):
"""Base class for memcached backends.
@ -60,6 +63,12 @@ class GenericMemcachedBackend(CacheBackend):
processes will be talking to the same memcached instance.
When left at False, dogpile will coordinate on a regular
threading mutex.
:param lock_timeout: integer, number of seconds after acquiring a lock that
memcached should expire it. This argument is only valid when
``distributed_lock`` is ``True``.
.. versionadded:: 0.5.7
:param memcached_expire_time: integer, when present will
be passed as the ``time`` parameter to ``pylibmc.Client.set``.
This is used to set the memcached expiry time for a value.
@ -104,9 +113,13 @@ class GenericMemcachedBackend(CacheBackend):
# automatically.
self.url = util.to_list(arguments['url'])
self.distributed_lock = arguments.get('distributed_lock', False)
self.lock_timeout = arguments.get('lock_timeout', 0)
self.memcached_expire_time = arguments.get(
'memcached_expire_time', 0)
def has_lock_timeout(self):
return self.lock_timeout != 0
def _imports(self):
"""client library imports go here."""
raise NotImplementedError()
@ -118,6 +131,7 @@ class GenericMemcachedBackend(CacheBackend):
@util.memoized_property
def _clients(self):
backend = self
class ClientPool(compat.threading.local):
def __init__(self):
self.memcached = backend._create_client()
@ -138,7 +152,8 @@ class GenericMemcachedBackend(CacheBackend):
def get_mutex(self, key):
if self.distributed_lock:
return MemcachedLock(lambda: self.client, key)
return MemcachedLock(lambda: self.client, key,
timeout=self.lock_timeout)
else:
return None
@ -157,13 +172,15 @@ class GenericMemcachedBackend(CacheBackend):
]
def set(self, key, value):
self.client.set(key,
self.client.set(
key,
value,
**self.set_arguments
)
def set_multi(self, mapping):
self.client.set_multi(mapping,
self.client.set_multi(
mapping,
**self.set_arguments
)
@ -173,6 +190,7 @@ class GenericMemcachedBackend(CacheBackend):
def delete_multi(self, keys):
self.client.delete_multi(keys)
class MemcacheArgs(object):
"""Mixin which provides support for the 'time' argument to set(),
'min_compress_len' to other methods.
@ -183,13 +201,15 @@ class MemcacheArgs(object):
self.set_arguments = {}
if "memcached_expire_time" in arguments:
self.set_arguments["time"] =\
arguments["memcached_expire_time"]
self.set_arguments["time"] = arguments["memcached_expire_time"]
if "min_compress_len" in arguments:
self.set_arguments["min_compress_len"] =\
arguments["min_compress_len"]
self.set_arguments["min_compress_len"] = \
arguments["min_compress_len"]
super(MemcacheArgs, self).__init__(arguments)
pylibmc = None
class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend):
"""A backend for the
`pylibmc <http://sendapatch.se/projects/pylibmc/index.html>`_
@ -229,19 +249,24 @@ class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend):
self.behaviors = arguments.get('behaviors', {})
super(PylibmcBackend, self).__init__(arguments)
def _imports(self):
global pylibmc
import pylibmc
import pylibmc # noqa
def _create_client(self):
return pylibmc.Client(self.url,
return pylibmc.Client(
self.url,
binary=self.binary,
behaviors=self.behaviors
)
memcache = None
class MemcachedBackend(MemcacheArgs, GenericMemcachedBackend):
"""A backend using the standard `Python-memcached <http://www.tummy.com/Community/software/python-memcached/>`_
"""A backend using the standard
`Python-memcached <http://www.tummy.com/Community/software/\
python-memcached/>`_
library.
Example::
@ -259,14 +284,19 @@ class MemcachedBackend(MemcacheArgs, GenericMemcachedBackend):
"""
def _imports(self):
global memcache
import memcache
import memcache # noqa
def _create_client(self):
return memcache.Client(self.url)
bmemcached = None
class BMemcachedBackend(GenericMemcachedBackend):
"""A backend for the
`python-binary-memcached <https://github.com/jaysonsantos/python-binary-memcached>`_
`python-binary-memcached <https://github.com/jaysonsantos/\
python-binary-memcached>`_
memcached client.
This is a pure Python memcached client which
@ -312,16 +342,18 @@ class BMemcachedBackend(GenericMemcachedBackend):
"""
def add(self, key, value):
def add(self, key, value, timeout=0):
try:
return super(RepairBMemcachedAPI, self).add(key, value)
return super(RepairBMemcachedAPI, self).add(
key, value, timeout)
except ValueError:
return False
self.Client = RepairBMemcachedAPI
def _create_client(self):
return self.Client(self.url,
return self.Client(
self.url,
username=self.username,
password=self.password
)

View file

@ -10,8 +10,9 @@ places the value as given into the dictionary.
"""
from dogpile.cache.api import CacheBackend, NO_VALUE
from dogpile.cache.compat import pickle
from ..api import CacheBackend, NO_VALUE
from ...util.compat import pickle
class MemoryBackend(CacheBackend):
"""A backend that uses a plain dictionary.
@ -58,14 +59,15 @@ class MemoryBackend(CacheBackend):
return value
def get_multi(self, keys):
ret = [self._cache.get(key, NO_VALUE)
ret = [
self._cache.get(key, NO_VALUE)
for key in keys]
if self.pickle_values:
ret = [
pickle.loads(value)
if value is not NO_VALUE else value
for value in ret
]
pickle.loads(value)
if value is not NO_VALUE else value
for value in ret
]
return ret
def set(self, key, value):

View file

@ -10,15 +10,15 @@ caching for a region that is otherwise used normally.
"""
from dogpile.cache.api import CacheBackend, NO_VALUE
from ..api import CacheBackend, NO_VALUE
__all__ = ['NullBackend']
class NullLock(object):
def acquire(self):
pass
def acquire(self, wait=True):
return True
def release(self):
pass

View file

@ -7,8 +7,8 @@ Provides backends for talking to `Redis <http://redis.io>`_.
"""
from __future__ import absolute_import
from dogpile.cache.api import CacheBackend, NO_VALUE
from dogpile.cache.compat import pickle, u
from ..api import CacheBackend, NO_VALUE
from ...util.compat import pickle, u
redis = None
@ -30,7 +30,7 @@ class RedisBackend(CacheBackend):
'port': 6379,
'db': 0,
'redis_expiration_time': 60*60*2, # 2 hours
'distributed_lock':True
'distributed_lock': True
}
)
@ -91,6 +91,7 @@ class RedisBackend(CacheBackend):
"""
def __init__(self, arguments):
arguments = arguments.copy()
self._imports()
self.url = arguments.pop('url', None)
self.host = arguments.pop('host', 'localhost')
@ -110,7 +111,7 @@ class RedisBackend(CacheBackend):
def _imports(self):
# defer imports until backend is used
global redis
import redis
import redis # noqa
def _create_client(self):
if self.connection_pool is not None:
@ -133,7 +134,6 @@ class RedisBackend(CacheBackend):
)
return redis.StrictRedis(**args)
def get_mutex(self, key):
if self.distributed_lock:
return self.client.lock(u('_lock{0}').format(key),
@ -148,9 +148,12 @@ class RedisBackend(CacheBackend):
return pickle.loads(value)
def get_multi(self, keys):
if not keys:
return []
values = self.client.mget(keys)
return [pickle.loads(v) if v is not None else NO_VALUE
for v in values]
return [
pickle.loads(v) if v is not None else NO_VALUE
for v in values]
def set(self, key, value):
if self.redis_expiration_time:
@ -178,4 +181,3 @@ class RedisBackend(CacheBackend):
def delete_multi(self, keys):
self.client.delete(*keys)

View file

@ -15,3 +15,11 @@ class RegionNotConfigured(DogpileCacheException):
class ValidationError(DogpileCacheException):
"""Error validating a value or option."""
class PluginNotFound(DogpileCacheException):
"""The specified plugin could not be found.
.. versionadded:: 0.6.4
"""

View file

@ -2,7 +2,8 @@
Mako Integration
----------------
dogpile.cache includes a `Mako <http://www.makotemplates.org>`_ plugin that replaces `Beaker <http://beaker.groovie.org>`_
dogpile.cache includes a `Mako <http://www.makotemplates.org>`_ plugin
that replaces `Beaker <http://beaker.groovie.org>`_
as the cache backend.
Setup a Mako template lookup using the "dogpile.cache" cache implementation
and a region dictionary::
@ -31,9 +32,9 @@ and a region dictionary::
}
)
To use the above configuration in a template, use the ``cached=True`` argument on any
Mako tag which accepts it, in conjunction with the name of the desired region
as the ``cache_region`` argument::
To use the above configuration in a template, use the ``cached=True``
argument on any Mako tag which accepts it, in conjunction with the
name of the desired region as the ``cache_region`` argument::
<%def name="mysection()" cached="True" cache_region="memcached">
some content that's cached
@ -43,6 +44,7 @@ as the ``cache_region`` argument::
"""
from mako.cache import CacheImpl
class MakoPlugin(CacheImpl):
"""A Mako ``CacheImpl`` which talks to dogpile.cache."""
@ -70,8 +72,9 @@ class MakoPlugin(CacheImpl):
def get_and_replace(self, key, creation_function, **kw):
expiration_time = kw.pop("timeout", None)
return self._get_region(**kw).get_or_create(key, creation_function,
expiration_time=expiration_time)
return self._get_region(**kw).get_or_create(
key, creation_function,
expiration_time=expiration_time)
def get_or_create(self, key, creation_function, **kw):
return self.get_and_replace(key, creation_function, **kw)

View file

@ -12,6 +12,7 @@ base backend.
from .api import CacheBackend
class ProxyBackend(CacheBackend):
"""A decorator class for altering the functionality of backends.
@ -62,7 +63,9 @@ class ProxyBackend(CacheBackend):
Return an object that be used as a backend by a :class:`.CacheRegion`
object.
'''
assert(isinstance(backend, CacheBackend) or isinstance(backend, ProxyBackend))
assert(
isinstance(backend, CacheBackend) or
isinstance(backend, ProxyBackend))
self.proxied = backend
return self
@ -82,12 +85,11 @@ class ProxyBackend(CacheBackend):
def get_multi(self, keys):
return self.proxied.get_multi(keys)
def set_multi(self, keys):
self.proxied.set_multi(keys)
def set_multi(self, mapping):
self.proxied.set_multi(mapping)
def delete_multi(self, keys):
self.proxied.delete_multi(keys)
def get_mutex(self, key):
return self.proxied.get_mutex(key)

File diff suppressed because it is too large Load diff

View file

@ -1,57 +1,6 @@
from hashlib import sha1
import inspect
import re
import collections
from . import compat
def coerce_string_conf(d):
result = {}
for k, v in d.items():
if not isinstance(v, compat.string_types):
result[k] = v
continue
v = v.strip()
if re.match(r'^[-+]?\d+$', v):
result[k] = int(v)
elif re.match(r'^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?$', v):
result[k] = float(v)
elif v.lower() in ('false', 'true'):
result[k] = v.lower() == 'true'
elif v == 'None':
result[k] = None
else:
result[k] = v
return result
class PluginLoader(object):
def __init__(self, group):
self.group = group
self.impls = {}
def load(self, name):
if name in self.impls:
return self.impls[name]()
else: # pragma NO COVERAGE
import pkg_resources
for impl in pkg_resources.iter_entry_points(
self.group,
name):
self.impls[name] = impl.load
return impl.load()
else:
raise Exception(
"Can't load plugin %s %s" %
(self.group, name))
def register(self, name, modulepath, objname):
def load():
mod = __import__(modulepath)
for token in modulepath.split(".")[1:]:
mod = getattr(mod, token)
return getattr(mod, objname)
self.impls[name] = load
from ..util import compat
from ..util import langhelpers
def function_key_generator(namespace, fn, to_str=compat.string_type):
@ -62,8 +11,14 @@ def function_key_generator(namespace, fn, to_str=compat.string_type):
This is used by :meth:`.CacheRegion.cache_on_arguments`
to generate a cache key from a decorated function.
It can be replaced using the ``function_key_generator``
argument passed to :func:`.make_region`.
An alternate function may be used by specifying
the :paramref:`.CacheRegion.function_key_generator` argument
for :class:`.CacheRegion`.
.. seealso::
:func:`.kwarg_function_key_generator` - similar function that also
takes keyword arguments into account
"""
@ -72,19 +27,21 @@ def function_key_generator(namespace, fn, to_str=compat.string_type):
else:
namespace = '%s:%s|%s' % (fn.__module__, fn.__name__, namespace)
args = inspect.getargspec(fn)
args = compat.inspect_getargspec(fn)
has_self = args[0] and args[0][0] in ('self', 'cls')
def generate_key(*args, **kw):
if kw:
raise ValueError(
"dogpile.cache's default key creation "
"function does not accept keyword arguments.")
"dogpile.cache's default key creation "
"function does not accept keyword arguments.")
if has_self:
args = args[1:]
return namespace + "|" + " ".join(map(to_str, args))
return generate_key
def function_multi_key_generator(namespace, fn, to_str=compat.string_type):
if namespace is None:
@ -92,23 +49,80 @@ def function_multi_key_generator(namespace, fn, to_str=compat.string_type):
else:
namespace = '%s:%s|%s' % (fn.__module__, fn.__name__, namespace)
args = inspect.getargspec(fn)
args = compat.inspect_getargspec(fn)
has_self = args[0] and args[0][0] in ('self', 'cls')
def generate_keys(*args, **kw):
if kw:
raise ValueError(
"dogpile.cache's default key creation "
"function does not accept keyword arguments.")
"dogpile.cache's default key creation "
"function does not accept keyword arguments.")
if has_self:
args = args[1:]
return [namespace + "|" + key for key in map(to_str, args)]
return generate_keys
def kwarg_function_key_generator(namespace, fn, to_str=compat.string_type):
"""Return a function that generates a string
key, based on a given function as well as
arguments to the returned function itself.
For kwargs passed in, we will build a dict of
all argname (key) argvalue (values) including
default args from the argspec and then
alphabetize the list before generating the
key.
.. versionadded:: 0.6.2
.. seealso::
:func:`.function_key_generator` - default key generation function
"""
if namespace is None:
namespace = '%s:%s' % (fn.__module__, fn.__name__)
else:
namespace = '%s:%s|%s' % (fn.__module__, fn.__name__, namespace)
argspec = compat.inspect_getargspec(fn)
default_list = list(argspec.defaults or [])
# Reverse the list, as we want to compare the argspec by negative index,
# meaning default_list[0] should be args[-1], which works well with
# enumerate()
default_list.reverse()
# use idx*-1 to create the correct right-lookup index.
args_with_defaults = dict((argspec.args[(idx*-1)], default)
for idx, default in enumerate(default_list, 1))
if argspec.args and argspec.args[0] in ('self', 'cls'):
arg_index_start = 1
else:
arg_index_start = 0
def generate_key(*args, **kwargs):
as_kwargs = dict(
[(argspec.args[idx], arg)
for idx, arg in enumerate(args[arg_index_start:],
arg_index_start)])
as_kwargs.update(kwargs)
for arg, val in args_with_defaults.items():
if arg not in as_kwargs:
as_kwargs[arg] = val
argument_values = [as_kwargs[key]
for key in sorted(as_kwargs.keys())]
return namespace + '|' + " ".join(map(to_str, argument_values))
return generate_key
def sha1_mangle_key(key):
"""a SHA1 key mangler."""
return sha1(key).hexdigest()
def length_conditional_mangler(length, mangler):
"""a key mangler that mangles if the length of the key is
past a certain threshold.
@ -121,69 +135,11 @@ def length_conditional_mangler(length, mangler):
return key
return mangle
class memoized_property(object):
"""A read-only @property that is only evaluated once."""
def __init__(self, fget, doc=None):
self.fget = fget
self.__doc__ = doc or fget.__doc__
self.__name__ = fget.__name__
# in the 0.6 release these functions were moved to the dogpile.util namespace.
# They are linked here to maintain compatibility with older versions.
def __get__(self, obj, cls):
if obj is None:
return self
obj.__dict__[self.__name__] = result = self.fget(obj)
return result
def to_list(x, default=None):
"""Coerce to a list."""
if x is None:
return default
if not isinstance(x, (list, tuple)):
return [x]
else:
return x
class KeyReentrantMutex(object):
def __init__(self, key, mutex, keys):
self.key = key
self.mutex = mutex
self.keys = keys
@classmethod
def factory(cls, mutex):
# this collection holds zero or one
# thread idents as the key; a set of
# keynames held as the value.
keystore = collections.defaultdict(set)
def fac(key):
return KeyReentrantMutex(key, mutex, keystore)
return fac
def acquire(self, wait=True):
current_thread = compat.threading.current_thread().ident
keys = self.keys.get(current_thread)
if keys is not None and \
self.key not in keys:
# current lockholder, new key. add it in
keys.add(self.key)
return True
elif self.mutex.acquire(wait=wait):
# after acquire, create new set and add our key
self.keys[current_thread].add(self.key)
return True
else:
return False
def release(self):
current_thread = compat.threading.current_thread().ident
keys = self.keys.get(current_thread)
assert keys is not None, "this thread didn't do the acquire"
assert self.key in keys, "No acquire held for key '%s'" % self.key
keys.remove(self.key)
if not keys:
# when list of keys empty, remove
# the thread ident and unlock.
del self.keys[current_thread]
self.mutex.release()
coerce_string_conf = langhelpers.coerce_string_conf
KeyReentrantMutex = langhelpers.KeyReentrantMutex
memoized_property = langhelpers.memoized_property
PluginLoader = langhelpers.PluginLoader
to_list = langhelpers.to_list

17
libs/dogpile/core.py Normal file
View file

@ -0,0 +1,17 @@
"""Compatibility namespace for those using dogpile.core.
As of dogpile.cache 0.6.0, dogpile.core as a separate package
is no longer used by dogpile.cache.
Note that this namespace will not take effect if an actual
dogpile.core installation is present.
"""
from .util import nameregistry # noqa
from .util import readwrite_lock # noqa
from .util.readwrite_lock import ReadWriteMutex # noqa
from .util.nameregistry import NameRegistry # noqa
from .lock import Lock # noqa
from .lock import NeedRegenerationException # noqa
from . import __version__ # noqa

View file

@ -1,11 +0,0 @@
from .dogpile import NeedRegenerationException, Lock
from .nameregistry import NameRegistry
from .readwrite_lock import ReadWriteMutex
from .legacy import Dogpile, SyncReaderDogpile
__all__ = [
'Dogpile', 'SyncReaderDogpile', 'NeedRegenerationException',
'NameRegistry', 'ReadWriteMutex', 'Lock']
__version__ = '0.4.1'

View file

@ -1,154 +0,0 @@
from __future__ import with_statement
from .util import threading
from .readwrite_lock import ReadWriteMutex
from .dogpile import Lock
import time
import contextlib
class Dogpile(object):
"""Dogpile lock class.
.. deprecated:: 0.4.0
The :class:`.Lock` object specifies the full
API of the :class:`.Dogpile` object in a single way,
rather than providing multiple modes of usage which
don't necessarily work in the majority of cases.
:class:`.Dogpile` is now a wrapper around the :class:`.Lock` object
which provides dogpile.core's original usage pattern.
This usage pattern began as something simple, but was
not of general use in real-world caching environments without
several extra complicating factors; the :class:`.Lock`
object presents the "real-world" API more succinctly,
and also fixes a cross-process concurrency issue.
:param expiretime: Expiration time in seconds. Set to
``None`` for never expires.
:param init: if True, set the 'createdtime' to the
current time.
:param lock: a mutex object that provides
``acquire()`` and ``release()`` methods.
"""
def __init__(self, expiretime, init=False, lock=None):
"""Construct a new :class:`.Dogpile`.
"""
if lock:
self.dogpilelock = lock
else:
self.dogpilelock = threading.Lock()
self.expiretime = expiretime
if init:
self.createdtime = time.time()
createdtime = -1
"""The last known 'creation time' of the value,
stored as an epoch (i.e. from ``time.time()``).
If the value here is -1, it is assumed the value
should recreate immediately.
"""
def acquire(self, creator,
value_fn=None,
value_and_created_fn=None):
"""Acquire the lock, returning a context manager.
:param creator: Creation function, used if this thread
is chosen to create a new value.
:param value_fn: Optional function that returns
the value from some datasource. Will be returned
if regeneration is not needed.
:param value_and_created_fn: Like value_fn, but returns a tuple
of (value, createdtime). The returned createdtime
will replace the "createdtime" value on this dogpile
lock. This option removes the need for the dogpile lock
itself to remain persistent across usages; another
dogpile can come along later and pick up where the
previous one left off.
"""
if value_and_created_fn is None:
if value_fn is None:
def value_and_created_fn():
return None, self.createdtime
else:
def value_and_created_fn():
return value_fn(), self.createdtime
def creator_wrapper():
value = creator()
self.createdtime = time.time()
return value, self.createdtime
else:
def creator_wrapper():
value = creator()
self.createdtime = time.time()
return value
return Lock(
self.dogpilelock,
creator_wrapper,
value_and_created_fn,
self.expiretime
)
@property
def is_expired(self):
"""Return true if the expiration time is reached, or no
value is available."""
return not self.has_value or \
(
self.expiretime is not None and
time.time() - self.createdtime > self.expiretime
)
@property
def has_value(self):
"""Return true if the creation function has proceeded
at least once."""
return self.createdtime > 0
class SyncReaderDogpile(Dogpile):
"""Provide a read-write lock function on top of the :class:`.Dogpile`
class.
.. deprecated:: 0.4.0
The :class:`.ReadWriteMutex` object can be used directly.
"""
def __init__(self, *args, **kw):
super(SyncReaderDogpile, self).__init__(*args, **kw)
self.readwritelock = ReadWriteMutex()
@contextlib.contextmanager
def acquire_write_lock(self):
"""Return the "write" lock context manager.
This will provide a section that is mutexed against
all readers/writers for the dogpile-maintained value.
"""
self.readwritelock.acquire_write_lock()
try:
yield
finally:
self.readwritelock.release_write_lock()
@contextlib.contextmanager
def acquire(self, *arg, **kw):
with super(SyncReaderDogpile, self).acquire(*arg, **kw) as value:
self.readwritelock.acquire_read_lock()
try:
yield value
finally:
self.readwritelock.release_read_lock()

View file

@ -1,8 +0,0 @@
import sys
py3k = sys.version_info >= (3, 0)
try:
import threading
except ImportError:
import dummy_threading as threading

View file

@ -3,6 +3,7 @@ import logging
log = logging.getLogger(__name__)
class NeedRegenerationException(Exception):
"""An exception that when raised in the 'with' block,
forces the 'has_value' flag to False and incurs a
@ -12,6 +13,7 @@ class NeedRegenerationException(Exception):
NOT_REGENERATED = object()
class Lock(object):
"""Dogpile lock class.
@ -21,11 +23,6 @@ class Lock(object):
continue to return the previous version
of that value.
.. versionadded:: 0.4.0
The :class:`.Lock` class was added as a single-use object
representing the dogpile API without dependence on
any shared state between multiple instances.
:param mutex: A mutex object that provides ``acquire()``
and ``release()`` methods.
:param creator: Callable which returns a tuple of the form
@ -52,17 +49,16 @@ class Lock(object):
this to be used to defer invocation of the creator callable until some
later time.
.. versionadded:: 0.4.1 added the async_creator argument.
"""
def __init__(self,
mutex,
creator,
value_and_created_fn,
expiretime,
async_creator=None,
):
def __init__(
self,
mutex,
creator,
value_and_created_fn,
expiretime,
async_creator=None,
):
self.mutex = mutex
self.creator = creator
self.value_and_created_fn = value_and_created_fn
@ -73,11 +69,10 @@ class Lock(object):
"""Return true if the expiration time is reached, or no
value is available."""
return not self._has_value(createdtime) or \
(
self.expiretime is not None and
time.time() - createdtime > self.expiretime
)
return not self._has_value(createdtime) or (
self.expiretime is not None and
time.time() - createdtime > self.expiretime
)
def _has_value(self, createdtime):
"""Return true if the creation function has proceeded
@ -95,68 +90,100 @@ class Lock(object):
value = NOT_REGENERATED
createdtime = -1
generated = self._enter_create(createdtime)
generated = self._enter_create(value, createdtime)
if generated is not NOT_REGENERATED:
generated, createdtime = generated
return generated
elif value is NOT_REGENERATED:
# we called upon the creator, and it said that it
# didn't regenerate. this typically means another
# thread is running the creation function, and that the
# cache should still have a value. However,
# we don't have a value at all, which is unusual since we just
# checked for it, so check again (TODO: is this a real codepath?)
try:
value, createdtime = value_fn()
return value
except NeedRegenerationException:
raise Exception("Generation function should "
"have just been called by a concurrent "
"thread.")
raise Exception(
"Generation function should "
"have just been called by a concurrent "
"thread.")
else:
return value
def _enter_create(self, createdtime):
def _enter_create(self, value, createdtime):
if not self._is_expired(createdtime):
return NOT_REGENERATED
async = False
_async = False
if self._has_value(createdtime):
has_value = True
if not self.mutex.acquire(False):
log.debug("creation function in progress "
"elsewhere, returning")
log.debug(
"creation function in progress "
"elsewhere, returning")
return NOT_REGENERATED
else:
has_value = False
log.debug("no value, waiting for create lock")
self.mutex.acquire()
try:
log.debug("value creation lock %r acquired" % self.mutex)
# see if someone created the value already
try:
value, createdtime = self.value_and_created_fn()
except NeedRegenerationException:
pass
else:
if not self._is_expired(createdtime):
log.debug("value already present")
return value, createdtime
elif self.async_creator:
log.debug("Passing creation lock to async runner")
self.async_creator(self.mutex)
async = True
return value, createdtime
if not has_value:
# we entered without a value, or at least with "creationtime ==
# 0". Run the "getter" function again, to see if another
# thread has already generated the value while we waited on the
# mutex, or if the caller is otherwise telling us there is a
# value already which allows us to use async regeneration. (the
# latter is used by the multi-key routine).
try:
value, createdtime = self.value_and_created_fn()
except NeedRegenerationException:
# nope, nobody created the value, we're it.
# we must create it right now
pass
else:
has_value = True
# caller is telling us there is a value and that we can
# use async creation if it is expired.
if not self._is_expired(createdtime):
# it's not expired, return it
log.debug("Concurrent thread created the value")
return value, createdtime
log.debug("Calling creation function")
created = self.creator()
return created
# otherwise it's expired, call creator again
if has_value and self.async_creator:
# we have a value we can return, safe to use async_creator
log.debug("Passing creation lock to async runner")
# so...run it!
self.async_creator(self.mutex)
_async = True
# and return the expired value for now
return value, createdtime
# it's expired, and it's our turn to create it synchronously, *or*,
# there's no value at all, and we have to create it synchronously
log.debug(
"Calling creation function for %s value",
"not-yet-present" if not has_value else
"previously expired"
)
return self.creator()
finally:
if not async:
if not _async:
self.mutex.release()
log.debug("Released creation lock")
def __enter__(self):
return self._enter()
def __exit__(self, type, value, traceback):
pass

View file

@ -0,0 +1,4 @@
from .nameregistry import NameRegistry # noqa
from .readwrite_lock import ReadWriteMutex # noqa
from .langhelpers import PluginLoader, memoized_property, \
coerce_string_conf, to_list, KeyReentrantMutex # noqa

View file

@ -1,6 +1,5 @@
import sys
py2k = sys.version_info < (3, 0)
py3k = sys.version_info >= (3, 0)
py32 = sys.version_info >= (3, 2)
@ -11,10 +10,10 @@ win32 = sys.platform.startswith('win')
try:
import threading
except ImportError:
import dummy_threading as threading
import dummy_threading as threading # noqa
if py3k: # pragma: no cover
if py3k: # pragma: no cover
string_types = str,
text_type = str
string_type = str
@ -45,24 +44,44 @@ else:
def ue(s):
return unicode(s, "unicode_escape")
import ConfigParser as configparser
import StringIO as io
import ConfigParser as configparser # noqa
import StringIO as io # noqa
callable = callable # noqa
import thread # noqa
callable = callable
import thread
if py3k:
import collections
ArgSpec = collections.namedtuple(
"ArgSpec",
["args", "varargs", "keywords", "defaults"])
from inspect import getfullargspec as inspect_getfullargspec
def inspect_getargspec(func):
return ArgSpec(
*inspect_getfullargspec(func)[0:4]
)
else:
from inspect import getargspec as inspect_getargspec # noqa
if py3k or jython:
import pickle
else:
import cPickle as pickle
import cPickle as pickle # noqa
if py3k:
def read_config_file(config, fileobj):
return config.read_file(fileobj)
else:
def read_config_file(config, fileobj):
return config.readfp(fileobj)
def timedelta_total_seconds(td):
if py27:
return td.total_seconds()
else:
return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 1e6) / 1e6
return (td.microseconds + (
td.seconds + td.days * 24 * 3600) * 1e6) / 1e6

View file

@ -0,0 +1,123 @@
import re
import collections
from . import compat
def coerce_string_conf(d):
result = {}
for k, v in d.items():
if not isinstance(v, compat.string_types):
result[k] = v
continue
v = v.strip()
if re.match(r'^[-+]?\d+$', v):
result[k] = int(v)
elif re.match(r'^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?$', v):
result[k] = float(v)
elif v.lower() in ('false', 'true'):
result[k] = v.lower() == 'true'
elif v == 'None':
result[k] = None
else:
result[k] = v
return result
class PluginLoader(object):
def __init__(self, group):
self.group = group
self.impls = {}
def load(self, name):
if name in self.impls:
return self.impls[name]()
else: # pragma NO COVERAGE
import pkg_resources
for impl in pkg_resources.iter_entry_points(
self.group, name):
self.impls[name] = impl.load
return impl.load()
else:
raise self.NotFound(
"Can't load plugin %s %s" % (self.group, name)
)
def register(self, name, modulepath, objname):
def load():
mod = __import__(modulepath, fromlist=[objname])
return getattr(mod, objname)
self.impls[name] = load
class NotFound(Exception):
"""The specified plugin could not be found."""
class memoized_property(object):
"""A read-only @property that is only evaluated once."""
def __init__(self, fget, doc=None):
self.fget = fget
self.__doc__ = doc or fget.__doc__
self.__name__ = fget.__name__
def __get__(self, obj, cls):
if obj is None:
return self
obj.__dict__[self.__name__] = result = self.fget(obj)
return result
def to_list(x, default=None):
"""Coerce to a list."""
if x is None:
return default
if not isinstance(x, (list, tuple)):
return [x]
else:
return x
class KeyReentrantMutex(object):
def __init__(self, key, mutex, keys):
self.key = key
self.mutex = mutex
self.keys = keys
@classmethod
def factory(cls, mutex):
# this collection holds zero or one
# thread idents as the key; a set of
# keynames held as the value.
keystore = collections.defaultdict(set)
def fac(key):
return KeyReentrantMutex(key, mutex, keystore)
return fac
def acquire(self, wait=True):
current_thread = compat.threading.current_thread().ident
keys = self.keys.get(current_thread)
if keys is not None and \
self.key not in keys:
# current lockholder, new key. add it in
keys.add(self.key)
return True
elif self.mutex.acquire(wait=wait):
# after acquire, create new set and add our key
self.keys[current_thread].add(self.key)
return True
else:
return False
def release(self):
current_thread = compat.threading.current_thread().ident
keys = self.keys.get(current_thread)
assert keys is not None, "this thread didn't do the acquire"
assert self.key in keys, "No acquire held for key '%s'" % self.key
keys.remove(self.key)
if not keys:
# when list of keys empty, remove
# the thread ident and unlock.
del self.keys[current_thread]
self.mutex.release()

View file

@ -1,6 +1,7 @@
from .util import threading
from .compat import threading
import weakref
class NameRegistry(object):
"""Generates and return an object, keeping it as a
singleton for a certain identifier for as long as its
@ -49,7 +50,7 @@ class NameRegistry(object):
self.creator = creator
def get(self, identifier, *args, **kw):
"""Get and possibly create the value.
r"""Get and possibly create the value.
:param identifier: Hash key for the value.
If the creation function is called, this identifier
@ -74,10 +75,12 @@ class NameRegistry(object):
if identifier in self._values:
return self._values[identifier]
else:
self._values[identifier] = value = self.creator(identifier, *args, **kw)
self._values[identifier] = value = self.creator(
identifier, *args, **kw)
return value
except KeyError:
self._values[identifier] = value = self.creator(identifier, *args, **kw)
self._values[identifier] = value = self.creator(
identifier, *args, **kw)
return value
finally:
self._mutex.release()

View file

@ -1,11 +1,13 @@
from .util import threading
from .compat import threading
import logging
log = logging.getLogger(__name__)
class LockError(Exception):
pass
class ReadWriteMutex(object):
"""A mutex which allows multiple readers, single writer.
@ -21,7 +23,7 @@ class ReadWriteMutex(object):
def __init__(self):
# counts how many asynchronous methods are executing
self.async = 0
self.async_ = 0
# pointer to thread that is the current sync operation
self.current_sync_operation = None
@ -29,7 +31,7 @@ class ReadWriteMutex(object):
# condition object to lock on
self.condition = threading.Condition(threading.Lock())
def acquire_read_lock(self, wait = True):
def acquire_read_lock(self, wait=True):
"""Acquire the 'read' lock."""
self.condition.acquire()
try:
@ -43,7 +45,7 @@ class ReadWriteMutex(object):
if self.current_sync_operation is not None:
return False
self.async += 1
self.async_ += 1
log.debug("%s acquired read lock", self)
finally:
self.condition.release()
@ -55,23 +57,23 @@ class ReadWriteMutex(object):
"""Release the 'read' lock."""
self.condition.acquire()
try:
self.async -= 1
self.async_ -= 1
# check if we are the last asynchronous reader thread
# out the door.
if self.async == 0:
if self.async_ == 0:
# yes. so if a sync operation is waiting, notifyAll to wake
# it up
if self.current_sync_operation is not None:
self.condition.notifyAll()
elif self.async < 0:
elif self.async_ < 0:
raise LockError("Synchronizer error - too many "
"release_read_locks called")
log.debug("%s released read lock", self)
finally:
self.condition.release()
def acquire_write_lock(self, wait = True):
def acquire_write_lock(self, wait=True):
"""Acquire the 'write' lock."""
self.condition.acquire()
try:
@ -94,7 +96,7 @@ class ReadWriteMutex(object):
self.current_sync_operation = threading.currentThread()
# now wait again for asyncs to finish
if self.async > 0:
if self.async_ > 0:
if wait:
# wait
self.condition.wait()

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
__title__ = 'enzyme'
__version__ = '0.4.2'
__version__ = '0.4.1'
__author__ = 'Antoine Bertin'
__license__ = 'Apache 2.0'
__copyright__ = 'Copyright 2013 Antoine Bertin'

View file

@ -21,8 +21,8 @@ def setUpModule():
class MKVTestCase(unittest.TestCase):
def test_test1(self):
with io.open(os.path.join(TEST_DIR, 'test1.mkv'), 'rb') as stream:
mkv = MKV(stream)
stream = io.open(os.path.join(TEST_DIR, 'test1.mkv'), 'rb')
mkv = MKV(stream)
# info
self.assertTrue(mkv.info.title is None)
self.assertTrue(mkv.info.duration == timedelta(minutes=1, seconds=27, milliseconds=336))
@ -90,8 +90,8 @@ class MKVTestCase(unittest.TestCase):
self.assertTrue(mkv.tags[0].simpletags[2].binary is None)
def test_test2(self):
with io.open(os.path.join(TEST_DIR, 'test2.mkv'), 'rb') as stream:
mkv = MKV(stream)
stream = io.open(os.path.join(TEST_DIR, 'test2.mkv'), 'rb')
mkv = MKV(stream)
# info
self.assertTrue(mkv.info.title is None)
self.assertTrue(mkv.info.duration == timedelta(seconds=47, milliseconds=509))
@ -159,8 +159,8 @@ class MKVTestCase(unittest.TestCase):
self.assertTrue(mkv.tags[0].simpletags[2].binary is None)
def test_test3(self):
with io.open(os.path.join(TEST_DIR, 'test3.mkv'), 'rb') as stream:
mkv = MKV(stream)
stream = io.open(os.path.join(TEST_DIR, 'test3.mkv'), 'rb')
mkv = MKV(stream)
# info
self.assertTrue(mkv.info.title is None)
self.assertTrue(mkv.info.duration == timedelta(seconds=49, milliseconds=64))
@ -228,8 +228,8 @@ class MKVTestCase(unittest.TestCase):
self.assertTrue(mkv.tags[0].simpletags[2].binary is None)
def test_test5(self):
with io.open(os.path.join(TEST_DIR, 'test5.mkv'), 'rb') as stream:
mkv = MKV(stream)
stream = io.open(os.path.join(TEST_DIR, 'test5.mkv'), 'rb')
mkv = MKV(stream)
# info
self.assertTrue(mkv.info.title is None)
self.assertTrue(mkv.info.duration == timedelta(seconds=46, milliseconds=665))
@ -391,8 +391,8 @@ class MKVTestCase(unittest.TestCase):
self.assertTrue(mkv.tags[0].simpletags[2].binary is None)
def test_test6(self):
with io.open(os.path.join(TEST_DIR, 'test6.mkv'), 'rb') as stream:
mkv = MKV(stream)
stream = io.open(os.path.join(TEST_DIR, 'test6.mkv'), 'rb')
mkv = MKV(stream)
# info
self.assertTrue(mkv.info.title is None)
self.assertTrue(mkv.info.duration == timedelta(seconds=87, milliseconds=336))
@ -460,8 +460,8 @@ class MKVTestCase(unittest.TestCase):
self.assertTrue(mkv.tags[0].simpletags[2].binary is None)
def test_test7(self):
with io.open(os.path.join(TEST_DIR, 'test7.mkv'), 'rb') as stream:
mkv = MKV(stream)
stream = io.open(os.path.join(TEST_DIR, 'test7.mkv'), 'rb')
mkv = MKV(stream)
# info
self.assertTrue(mkv.info.title is None)
self.assertTrue(mkv.info.duration == timedelta(seconds=37, milliseconds=43))
@ -529,8 +529,8 @@ class MKVTestCase(unittest.TestCase):
self.assertTrue(mkv.tags[0].simpletags[2].binary is None)
def test_test8(self):
with io.open(os.path.join(TEST_DIR, 'test8.mkv'), 'rb') as stream:
mkv = MKV(stream)
stream = io.open(os.path.join(TEST_DIR, 'test8.mkv'), 'rb')
mkv = MKV(stream)
# info
self.assertTrue(mkv.info.title is None)
self.assertTrue(mkv.info.duration == timedelta(seconds=47, milliseconds=341))

0
libs/pbr/__init__.py Normal file
View file

292
libs/pbr/builddoc.py Normal file
View file

@ -0,0 +1,292 @@
# Copyright 2011 OpenStack Foundation
# Copyright 2012-2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from distutils import log
import fnmatch
import os
import sys
try:
import cStringIO
except ImportError:
import io as cStringIO
try:
import sphinx
# NOTE(dhellmann): Newer versions of Sphinx have moved the apidoc
# module into sphinx.ext and the API is slightly different (the
# function expects sys.argv[1:] instead of sys.argv[:]. So, figure
# out where we can import it from and set a flag so we can invoke
# it properly. See this change in sphinx for details:
# https://github.com/sphinx-doc/sphinx/commit/87630c8ae8bff8c0e23187676e6343d8903003a6
try:
from sphinx.ext import apidoc
apidoc_use_padding = False
except ImportError:
from sphinx import apidoc
apidoc_use_padding = True
from sphinx import application
from sphinx import setup_command
except Exception as e:
# NOTE(dhellmann): During the installation of docutils, setuptools
# tries to import pbr code to find the egg_info.writer hooks. That
# imports this module, which imports sphinx, which imports
# docutils, which is being installed. Because docutils uses 2to3
# to convert its code during installation under python 3, the
# import fails, but it fails with an error other than ImportError
# (today it's a NameError on StandardError, an exception base
# class). Convert the exception type here so it can be caught in
# packaging.py where we try to determine if we can import and use
# sphinx by importing this module. See bug #1403510 for details.
raise ImportError(str(e))
from pbr import git
from pbr import options
from pbr import version
_deprecated_options = ['autodoc_tree_index_modules', 'autodoc_index_modules',
'autodoc_tree_excludes', 'autodoc_exclude_modules']
_deprecated_envs = ['AUTODOC_TREE_INDEX_MODULES', 'AUTODOC_INDEX_MODULES']
_rst_template = """%(heading)s
%(underline)s
.. automodule:: %(module)s
:members:
:undoc-members:
:show-inheritance:
"""
def _find_modules(arg, dirname, files):
for filename in files:
if filename.endswith('.py') and filename != '__init__.py':
arg["%s.%s" % (dirname.replace('/', '.'),
filename[:-3])] = True
class LocalBuildDoc(setup_command.BuildDoc):
builders = ['html']
command_name = 'build_sphinx'
sphinx_initialized = False
def _get_source_dir(self):
option_dict = self.distribution.get_option_dict('build_sphinx')
pbr_option_dict = self.distribution.get_option_dict('pbr')
_, api_doc_dir = pbr_option_dict.get('api_doc_dir', (None, 'api'))
if 'source_dir' in option_dict:
source_dir = os.path.join(option_dict['source_dir'][1],
api_doc_dir)
else:
source_dir = 'doc/source/' + api_doc_dir
if not os.path.exists(source_dir):
os.makedirs(source_dir)
return source_dir
def generate_autoindex(self, excluded_modules=None):
log.info("[pbr] Autodocumenting from %s"
% os.path.abspath(os.curdir))
modules = {}
source_dir = self._get_source_dir()
for pkg in self.distribution.packages:
if '.' not in pkg:
for dirpath, dirnames, files in os.walk(pkg):
_find_modules(modules, dirpath, files)
def include(module):
return not any(fnmatch.fnmatch(module, pat)
for pat in excluded_modules)
module_list = sorted(mod for mod in modules.keys() if include(mod))
autoindex_filename = os.path.join(source_dir, 'autoindex.rst')
with open(autoindex_filename, 'w') as autoindex:
autoindex.write(""".. toctree::
:maxdepth: 1
""")
for module in module_list:
output_filename = os.path.join(source_dir,
"%s.rst" % module)
heading = "The :mod:`%s` Module" % module
underline = "=" * len(heading)
values = dict(module=module, heading=heading,
underline=underline)
log.info("[pbr] Generating %s"
% output_filename)
with open(output_filename, 'w') as output_file:
output_file.write(_rst_template % values)
autoindex.write(" %s.rst\n" % module)
def _sphinx_tree(self):
source_dir = self._get_source_dir()
cmd = ['-H', 'Modules', '-o', source_dir, '.']
if apidoc_use_padding:
cmd.insert(0, 'apidoc')
apidoc.main(cmd + self.autodoc_tree_excludes)
def _sphinx_run(self):
if not self.verbose:
status_stream = cStringIO.StringIO()
else:
status_stream = sys.stdout
confoverrides = {}
if self.project:
confoverrides['project'] = self.project
if self.version:
confoverrides['version'] = self.version
if self.release:
confoverrides['release'] = self.release
if self.today:
confoverrides['today'] = self.today
if self.sphinx_initialized:
confoverrides['suppress_warnings'] = [
'app.add_directive', 'app.add_role',
'app.add_generic_role', 'app.add_node',
'image.nonlocal_uri',
]
app = application.Sphinx(
self.source_dir, self.config_dir,
self.builder_target_dir, self.doctree_dir,
self.builder, confoverrides, status_stream,
freshenv=self.fresh_env, warningiserror=self.warning_is_error)
self.sphinx_initialized = True
try:
app.build(force_all=self.all_files)
except Exception as err:
from docutils import utils
if isinstance(err, utils.SystemMessage):
sys.stder.write('reST markup error:\n')
sys.stderr.write(err.args[0].encode('ascii',
'backslashreplace'))
sys.stderr.write('\n')
else:
raise
if self.link_index:
src = app.config.master_doc + app.builder.out_suffix
dst = app.builder.get_outfilename('index')
os.symlink(src, dst)
def run(self):
option_dict = self.distribution.get_option_dict('pbr')
# TODO(stephenfin): Remove this (and the entire file) when 5.0 is
# released
warn_opts = set(option_dict.keys()).intersection(_deprecated_options)
warn_env = list(filter(lambda x: os.getenv(x), _deprecated_envs))
if warn_opts or warn_env:
msg = ('The autodoc and autodoc_tree features are deprecated in '
'4.2 and will be removed in a future release. You should '
'use the sphinxcontrib-apidoc Sphinx extension instead. '
'Refer to the pbr documentation for more information.')
if warn_opts:
msg += ' Deprecated options: %s' % list(warn_opts)
if warn_env:
msg += ' Deprecated environment variables: %s' % warn_env
log.warn(msg)
if git._git_is_installed():
git.write_git_changelog(option_dict=option_dict)
git.generate_authors(option_dict=option_dict)
tree_index = options.get_boolean_option(option_dict,
'autodoc_tree_index_modules',
'AUTODOC_TREE_INDEX_MODULES')
auto_index = options.get_boolean_option(option_dict,
'autodoc_index_modules',
'AUTODOC_INDEX_MODULES')
if not os.getenv('SPHINX_DEBUG'):
# NOTE(afazekas): These options can be used together,
# but they do a very similar thing in a different way
if tree_index:
self._sphinx_tree()
if auto_index:
self.generate_autoindex(
set(option_dict.get(
"autodoc_exclude_modules",
[None, ""])[1].split()))
self.finalize_options()
is_multibuilder_sphinx = version.SemanticVersion.from_pip_string(
sphinx.__version__) >= version.SemanticVersion(1, 6)
# TODO(stephenfin): Remove support for Sphinx < 1.6 in 4.0
if not is_multibuilder_sphinx:
log.warn('[pbr] Support for Sphinx < 1.6 will be dropped in '
'pbr 4.0. Upgrade to Sphinx 1.6+')
# TODO(stephenfin): Remove this at the next MAJOR version bump
if self.builders != ['html']:
log.warn("[pbr] Sphinx 1.6 added native support for "
"specifying multiple builders in the "
"'[sphinx_build] builder' configuration option, "
"found in 'setup.cfg'. As a result, the "
"'[sphinx_build] builders' option has been "
"deprecated and will be removed in pbr 4.0. Migrate "
"to the 'builder' configuration option.")
if is_multibuilder_sphinx:
self.builder = self.builders
if is_multibuilder_sphinx:
# Sphinx >= 1.6
return setup_command.BuildDoc.run(self)
# Sphinx < 1.6
for builder in self.builders:
self.builder = builder
self.finalize_options()
self._sphinx_run()
def initialize_options(self):
# Not a new style class, super keyword does not work.
setup_command.BuildDoc.initialize_options(self)
# NOTE(dstanek): exclude setup.py from the autodoc tree index
# builds because all projects will have an issue with it
self.autodoc_tree_excludes = ['setup.py']
def finalize_options(self):
from pbr import util
# Not a new style class, super keyword does not work.
setup_command.BuildDoc.finalize_options(self)
# Handle builder option from command line - override cfg
option_dict = self.distribution.get_option_dict('build_sphinx')
if 'command line' in option_dict.get('builder', [[]])[0]:
self.builders = option_dict['builder'][1]
# Allow builders to be configurable - as a comma separated list.
if not isinstance(self.builders, list) and self.builders:
self.builders = self.builders.split(',')
self.project = self.distribution.get_name()
self.version = self.distribution.get_version()
self.release = self.distribution.get_version()
# NOTE(dstanek): check for autodoc tree exclusion overrides
# in the setup.cfg
opt = 'autodoc_tree_excludes'
option_dict = self.distribution.get_option_dict('pbr')
if opt in option_dict:
self.autodoc_tree_excludes = util.split_multiline(
option_dict[opt][1])
# handle Sphinx < 1.5.0
if not hasattr(self, 'warning_is_error'):
self.warning_is_error = False

0
libs/pbr/cmd/__init__.py Normal file
View file

112
libs/pbr/cmd/main.py Normal file
View file

@ -0,0 +1,112 @@
# Copyright 2014 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import argparse
import json
import sys
import pkg_resources
import pbr.version
def _get_metadata(package_name):
try:
return json.loads(
pkg_resources.get_distribution(
package_name).get_metadata('pbr.json'))
except pkg_resources.DistributionNotFound:
raise Exception('Package {0} not installed'.format(package_name))
except Exception:
return None
def get_sha(args):
sha = _get_info(args.name)['sha']
if sha:
print(sha)
def get_info(args):
print("{name}\t{version}\t{released}\t{sha}".format(
**_get_info(args.name)))
def _get_info(name):
metadata = _get_metadata(name)
version = pkg_resources.get_distribution(name).version
if metadata:
if metadata['is_release']:
released = 'released'
else:
released = 'pre-release'
sha = metadata['git_version']
else:
version_parts = version.split('.')
if version_parts[-1].startswith('g'):
sha = version_parts[-1][1:]
released = 'pre-release'
else:
sha = ""
released = "released"
for part in version_parts:
if not part.isdigit():
released = "pre-release"
return dict(name=name, version=version, sha=sha, released=released)
def freeze(args):
sorted_dists = sorted(pkg_resources.working_set,
key=lambda dist: dist.project_name.lower())
for dist in sorted_dists:
info = _get_info(dist.project_name)
output = "{name}=={version}".format(**info)
if info['sha']:
output += " # git sha {sha}".format(**info)
print(output)
def main():
parser = argparse.ArgumentParser(
description='pbr: Python Build Reasonableness')
parser.add_argument(
'-v', '--version', action='version',
version=str(pbr.version.VersionInfo('pbr')))
subparsers = parser.add_subparsers(
title='commands', description='valid commands', help='additional help')
cmd_sha = subparsers.add_parser('sha', help='print sha of package')
cmd_sha.set_defaults(func=get_sha)
cmd_sha.add_argument('name', help='package to print sha of')
cmd_info = subparsers.add_parser(
'info', help='print version info for package')
cmd_info.set_defaults(func=get_info)
cmd_info.add_argument('name', help='package to print info of')
cmd_freeze = subparsers.add_parser(
'freeze', help='print version info for all installed packages')
cmd_freeze.set_defaults(func=freeze)
args = parser.parse_args()
try:
args.func(args)
except Exception as e:
print(e)
if __name__ == '__main__':
sys.exit(main())

145
libs/pbr/core.py Normal file
View file

@ -0,0 +1,145 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright (C) 2013 Association of Universities for Research in Astronomy
# (AURA)
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# 3. The name of AURA and its representatives may not be used to
# endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY AURA ``AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL AURA BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
import logging
import os
import sys
import warnings
from distutils import errors
from pbr import util
if sys.version_info[0] == 3:
string_type = str
integer_types = (int,)
else:
string_type = basestring # noqa
integer_types = (int, long) # noqa
def pbr(dist, attr, value):
"""Implements the actual pbr setup() keyword.
When used, this should be the only keyword in your setup() aside from
`setup_requires`.
If given as a string, the value of pbr is assumed to be the relative path
to the setup.cfg file to use. Otherwise, if it evaluates to true, it
simply assumes that pbr should be used, and the default 'setup.cfg' is
used.
This works by reading the setup.cfg file, parsing out the supported
metadata and command options, and using them to rebuild the
`DistributionMetadata` object and set the newly added command options.
The reason for doing things this way is that a custom `Distribution` class
will not play nicely with setup_requires; however, this implementation may
not work well with distributions that do use a `Distribution` subclass.
"""
if not value:
return
if isinstance(value, string_type):
path = os.path.abspath(value)
else:
path = os.path.abspath('setup.cfg')
if not os.path.exists(path):
raise errors.DistutilsFileError(
'The setup.cfg file %s does not exist.' % path)
# Converts the setup.cfg file to setup() arguments
try:
attrs = util.cfg_to_args(path, dist.script_args)
except Exception:
e = sys.exc_info()[1]
# NB: This will output to the console if no explicit logging has
# been setup - but thats fine, this is a fatal distutils error, so
# being pretty isn't the #1 goal.. being diagnosable is.
logging.exception('Error parsing')
raise errors.DistutilsSetupError(
'Error parsing %s: %s: %s' % (path, e.__class__.__name__, e))
# There are some metadata fields that are only supported by
# setuptools and not distutils, and hence are not in
# dist.metadata. We are OK to write these in. For gory details
# see
# https://github.com/pypa/setuptools/pull/1343
_DISTUTILS_UNSUPPORTED_METADATA = (
'long_description_content_type', 'project_urls', 'provides_extras'
)
# Repeat some of the Distribution initialization code with the newly
# provided attrs
if attrs:
# Skips 'options' and 'licence' support which are rarely used; may
# add back in later if demanded
for key, val in attrs.items():
if hasattr(dist.metadata, 'set_' + key):
getattr(dist.metadata, 'set_' + key)(val)
elif hasattr(dist.metadata, key):
setattr(dist.metadata, key, val)
elif hasattr(dist, key):
setattr(dist, key, val)
elif key in _DISTUTILS_UNSUPPORTED_METADATA:
setattr(dist.metadata, key, val)
else:
msg = 'Unknown distribution option: %s' % repr(key)
warnings.warn(msg)
# Re-finalize the underlying Distribution
try:
super(dist.__class__, dist).finalize_options()
except TypeError:
# If dist is not declared as a new-style class (with object as
# a subclass) then super() will not work on it. This is the case
# for Python 2. In that case, fall back to doing this the ugly way
dist.__class__.__bases__[-1].finalize_options(dist)
# This bit comes out of distribute/setuptools
if isinstance(dist.metadata.version, integer_types + (float,)):
# Some people apparently take "version number" too literally :)
dist.metadata.version = str(dist.metadata.version)

35
libs/pbr/extra_files.py Normal file
View file

@ -0,0 +1,35 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils import errors
import os
_extra_files = []
def get_extra_files():
global _extra_files
return _extra_files
def set_extra_files(extra_files):
# Let's do a sanity check
for filename in extra_files:
if not os.path.exists(filename):
raise errors.DistutilsFileError(
'%s from the extra_files option in setup.cfg does not '
'exist' % filename)
global _extra_files
_extra_files[:] = extra_files[:]

29
libs/pbr/find_package.py Normal file
View file

@ -0,0 +1,29 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import setuptools
def smart_find_packages(package_list):
"""Run find_packages the way we intend."""
packages = []
for pkg in package_list.strip().split("\n"):
pkg_path = pkg.replace('.', os.path.sep)
packages.append(pkg)
packages.extend(['%s.%s' % (pkg, f)
for f in setuptools.find_packages(pkg_path)])
return "\n".join(set(packages))

331
libs/pbr/git.py Normal file
View file

@ -0,0 +1,331 @@
# Copyright 2011 OpenStack Foundation
# Copyright 2012-2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import unicode_literals
import distutils.errors
from distutils import log
import errno
import io
import os
import re
import subprocess
import time
import pkg_resources
from pbr import options
from pbr import version
def _run_shell_command(cmd, throw_on_error=False, buffer=True, env=None):
if buffer:
out_location = subprocess.PIPE
err_location = subprocess.PIPE
else:
out_location = None
err_location = None
newenv = os.environ.copy()
if env:
newenv.update(env)
output = subprocess.Popen(cmd,
stdout=out_location,
stderr=err_location,
env=newenv)
out = output.communicate()
if output.returncode and throw_on_error:
raise distutils.errors.DistutilsError(
"%s returned %d" % (cmd, output.returncode))
if len(out) == 0 or not out[0] or not out[0].strip():
return ''
# Since we don't control the history, and forcing users to rebase arbitrary
# history to fix utf8 issues is harsh, decode with replace.
return out[0].strip().decode('utf-8', 'replace')
def _run_git_command(cmd, git_dir, **kwargs):
if not isinstance(cmd, (list, tuple)):
cmd = [cmd]
return _run_shell_command(
['git', '--git-dir=%s' % git_dir] + cmd, **kwargs)
def _get_git_directory():
try:
return _run_shell_command(['git', 'rev-parse', '--git-dir'])
except OSError as e:
if e.errno == errno.ENOENT:
# git not installed.
return ''
raise
def _git_is_installed():
try:
# We cannot use 'which git' as it may not be available
# in some distributions, So just try 'git --version'
# to see if we run into trouble
_run_shell_command(['git', '--version'])
except OSError:
return False
return True
def _get_highest_tag(tags):
"""Find the highest tag from a list.
Pass in a list of tag strings and this will return the highest
(latest) as sorted by the pkg_resources version parser.
"""
return max(tags, key=pkg_resources.parse_version)
def _find_git_files(dirname='', git_dir=None):
"""Behave like a file finder entrypoint plugin.
We don't actually use the entrypoints system for this because it runs
at absurd times. We only want to do this when we are building an sdist.
"""
file_list = []
if git_dir is None:
git_dir = _run_git_functions()
if git_dir:
log.info("[pbr] In git context, generating filelist from git")
file_list = _run_git_command(['ls-files', '-z'], git_dir)
# Users can fix utf8 issues locally with a single commit, so we are
# strict here.
file_list = file_list.split(b'\x00'.decode('utf-8'))
return [f for f in file_list if f]
def _get_raw_tag_info(git_dir):
describe = _run_git_command(['describe', '--always'], git_dir)
if "-" in describe:
return describe.rsplit("-", 2)[-2]
if "." in describe:
return 0
return None
def get_is_release(git_dir):
return _get_raw_tag_info(git_dir) == 0
def _run_git_functions():
git_dir = None
if _git_is_installed():
git_dir = _get_git_directory()
return git_dir or None
def get_git_short_sha(git_dir=None):
"""Return the short sha for this repo, if it exists."""
if not git_dir:
git_dir = _run_git_functions()
if git_dir:
return _run_git_command(
['log', '-n1', '--pretty=format:%h'], git_dir)
return None
def _clean_changelog_message(msg):
"""Cleans any instances of invalid sphinx wording.
This escapes/removes any instances of invalid characters
that can be interpreted by sphinx as a warning or error
when translating the Changelog into an HTML file for
documentation building within projects.
* Escapes '_' which is interpreted as a link
* Escapes '*' which is interpreted as a new line
* Escapes '`' which is interpreted as a literal
"""
msg = msg.replace('*', '\*')
msg = msg.replace('_', '\_')
msg = msg.replace('`', '\`')
return msg
def _iter_changelog(changelog):
"""Convert a oneline log iterator to formatted strings.
:param changelog: An iterator of one line log entries like
that given by _iter_log_oneline.
:return: An iterator over (release, formatted changelog) tuples.
"""
first_line = True
current_release = None
yield current_release, "CHANGES\n=======\n\n"
for hash, tags, msg in changelog:
if tags:
current_release = _get_highest_tag(tags)
underline = len(current_release) * '-'
if not first_line:
yield current_release, '\n'
yield current_release, (
"%(tag)s\n%(underline)s\n\n" %
dict(tag=current_release, underline=underline))
if not msg.startswith("Merge "):
if msg.endswith("."):
msg = msg[:-1]
msg = _clean_changelog_message(msg)
yield current_release, "* %(msg)s\n" % dict(msg=msg)
first_line = False
def _iter_log_oneline(git_dir=None):
"""Iterate over --oneline log entries if possible.
This parses the output into a structured form but does not apply
presentation logic to the output - making it suitable for different
uses.
:return: An iterator of (hash, tags_set, 1st_line) tuples, or None if
changelog generation is disabled / not available.
"""
if git_dir is None:
git_dir = _get_git_directory()
if not git_dir:
return []
return _iter_log_inner(git_dir)
def _is_valid_version(candidate):
try:
version.SemanticVersion.from_pip_string(candidate)
return True
except ValueError:
return False
def _iter_log_inner(git_dir):
"""Iterate over --oneline log entries.
This parses the output intro a structured form but does not apply
presentation logic to the output - making it suitable for different
uses.
:return: An iterator of (hash, tags_set, 1st_line) tuples.
"""
log.info('[pbr] Generating ChangeLog')
log_cmd = ['log', '--decorate=full', '--format=%h%x00%s%x00%d']
changelog = _run_git_command(log_cmd, git_dir)
for line in changelog.split('\n'):
line_parts = line.split('\x00')
if len(line_parts) != 3:
continue
sha, msg, refname = line_parts
tags = set()
# refname can be:
# <empty>
# HEAD, tag: refs/tags/1.4.0, refs/remotes/origin/master, \
# refs/heads/master
# refs/tags/1.3.4
if "refs/tags/" in refname:
refname = refname.strip()[1:-1] # remove wrapping ()'s
# If we start with "tag: refs/tags/1.2b1, tag: refs/tags/1.2"
# The first split gives us "['', '1.2b1, tag:', '1.2']"
# Which is why we do the second split below on the comma
for tag_string in refname.split("refs/tags/")[1:]:
# git tag does not allow : or " " in tag names, so we split
# on ", " which is the separator between elements
candidate = tag_string.split(", ")[0]
if _is_valid_version(candidate):
tags.add(candidate)
yield sha, tags, msg
def write_git_changelog(git_dir=None, dest_dir=os.path.curdir,
option_dict=None, changelog=None):
"""Write a changelog based on the git changelog."""
start = time.time()
if not option_dict:
option_dict = {}
should_skip = options.get_boolean_option(option_dict, 'skip_changelog',
'SKIP_WRITE_GIT_CHANGELOG')
if should_skip:
return
if not changelog:
changelog = _iter_log_oneline(git_dir=git_dir)
if changelog:
changelog = _iter_changelog(changelog)
if not changelog:
return
new_changelog = os.path.join(dest_dir, 'ChangeLog')
# If there's already a ChangeLog and it's not writable, just use it
if (os.path.exists(new_changelog)
and not os.access(new_changelog, os.W_OK)):
log.info('[pbr] ChangeLog not written (file already'
' exists and it is not writeable)')
return
log.info('[pbr] Writing ChangeLog')
with io.open(new_changelog, "w", encoding="utf-8") as changelog_file:
for release, content in changelog:
changelog_file.write(content)
stop = time.time()
log.info('[pbr] ChangeLog complete (%0.1fs)' % (stop - start))
def generate_authors(git_dir=None, dest_dir='.', option_dict=dict()):
"""Create AUTHORS file using git commits."""
should_skip = options.get_boolean_option(option_dict, 'skip_authors',
'SKIP_GENERATE_AUTHORS')
if should_skip:
return
start = time.time()
old_authors = os.path.join(dest_dir, 'AUTHORS.in')
new_authors = os.path.join(dest_dir, 'AUTHORS')
# If there's already an AUTHORS file and it's not writable, just use it
if (os.path.exists(new_authors)
and not os.access(new_authors, os.W_OK)):
return
log.info('[pbr] Generating AUTHORS')
ignore_emails = '((jenkins|zuul)@review|infra@lists|jenkins@openstack)'
if git_dir is None:
git_dir = _get_git_directory()
if git_dir:
authors = []
# don't include jenkins email address in AUTHORS file
git_log_cmd = ['log', '--format=%aN <%aE>']
authors += _run_git_command(git_log_cmd, git_dir).split('\n')
authors = [a for a in authors if not re.search(ignore_emails, a)]
# get all co-authors from commit messages
co_authors_out = _run_git_command('log', git_dir)
co_authors = re.findall('Co-authored-by:.+', co_authors_out,
re.MULTILINE)
co_authors = [signed.split(":", 1)[1].strip()
for signed in co_authors if signed]
authors += co_authors
authors = sorted(set(authors))
with open(new_authors, 'wb') as new_authors_fh:
if os.path.exists(old_authors):
with open(old_authors, "rb") as old_authors_fh:
new_authors_fh.write(old_authors_fh.read())
new_authors_fh.write(('\n'.join(authors) + '\n')
.encode('utf-8'))
stop = time.time()
log.info('[pbr] AUTHORS complete (%0.1fs)' % (stop - start))

View file

@ -0,0 +1,28 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from pbr.hooks import backwards
from pbr.hooks import commands
from pbr.hooks import files
from pbr.hooks import metadata
def setup_hook(config):
"""Filter config parsed from a setup.cfg to inject our defaults."""
metadata_config = metadata.MetadataConfig(config)
metadata_config.run()
backwards.BackwardsCompatConfig(config).run()
commands.CommandsConfig(config).run()
files.FilesConfig(config, metadata_config.get_name()).run()

View file

@ -0,0 +1,33 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from pbr.hooks import base
from pbr import packaging
class BackwardsCompatConfig(base.BaseConfig):
section = 'backwards_compat'
def hook(self):
self.config['include_package_data'] = 'True'
packaging.append_text_list(
self.config, 'dependency_links',
packaging.parse_dependency_links())
packaging.append_text_list(
self.config, 'tests_require',
packaging.parse_requirements(
packaging.TEST_REQUIREMENTS_FILES,
strip_markers=True))

34
libs/pbr/hooks/base.py Normal file
View file

@ -0,0 +1,34 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
class BaseConfig(object):
section = None
def __init__(self, config):
self._global_config = config
self.config = self._global_config.get(self.section, dict())
self.pbr_config = config.get('pbr', dict())
def run(self):
self.hook()
self.save()
def hook(self):
pass
def save(self):
self._global_config[self.section] = self.config

View file

@ -0,0 +1,66 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
from setuptools.command import easy_install
from pbr.hooks import base
from pbr import options
from pbr import packaging
class CommandsConfig(base.BaseConfig):
section = 'global'
def __init__(self, config):
super(CommandsConfig, self).__init__(config)
self.commands = self.config.get('commands', "")
def save(self):
self.config['commands'] = self.commands
super(CommandsConfig, self).save()
def add_command(self, command):
self.commands = "%s\n%s" % (self.commands, command)
def hook(self):
self.add_command('pbr.packaging.LocalEggInfo')
self.add_command('pbr.packaging.LocalSDist')
self.add_command('pbr.packaging.LocalInstallScripts')
self.add_command('pbr.packaging.LocalDevelop')
self.add_command('pbr.packaging.LocalRPMVersion')
self.add_command('pbr.packaging.LocalDebVersion')
if os.name != 'nt':
easy_install.get_script_args = packaging.override_get_script_args
if packaging.have_sphinx():
self.add_command('pbr.builddoc.LocalBuildDoc')
if os.path.exists('.testr.conf') and packaging.have_testr():
# There is a .testr.conf file. We want to use it.
self.add_command('pbr.packaging.TestrTest')
elif self.config.get('nosetests', False) and packaging.have_nose():
# We seem to still have nose configured
self.add_command('pbr.packaging.NoseTest')
use_egg = options.get_boolean_option(
self.pbr_config, 'use-egg', 'PBR_USE_EGG')
# We always want non-egg install unless explicitly requested
if 'manpages' in self.pbr_config or not use_egg:
self.add_command('pbr.packaging.LocalInstall')
else:
self.add_command('pbr.packaging.InstallWithGit')

103
libs/pbr/hooks/files.py Normal file
View file

@ -0,0 +1,103 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import sys
from pbr import find_package
from pbr.hooks import base
def get_manpath():
manpath = 'share/man'
if os.path.exists(os.path.join(sys.prefix, 'man')):
# This works around a bug with install where it expects every node
# in the relative data directory to be an actual directory, since at
# least Debian derivatives (and probably other platforms as well)
# like to symlink Unixish /usr/local/man to /usr/local/share/man.
manpath = 'man'
return manpath
def get_man_section(section):
return os.path.join(get_manpath(), 'man%s' % section)
class FilesConfig(base.BaseConfig):
section = 'files'
def __init__(self, config, name):
super(FilesConfig, self).__init__(config)
self.name = name
self.data_files = self.config.get('data_files', '')
def save(self):
self.config['data_files'] = self.data_files
super(FilesConfig, self).save()
def expand_globs(self):
finished = []
for line in self.data_files.split("\n"):
if line.rstrip().endswith('*') and '=' in line:
(target, source_glob) = line.split('=')
source_prefix = source_glob.strip()[:-1]
target = target.strip()
if not target.endswith(os.path.sep):
target += os.path.sep
for (dirpath, dirnames, fnames) in os.walk(source_prefix):
finished.append(
"%s = " % dirpath.replace(source_prefix, target))
finished.extend(
[" %s" % os.path.join(dirpath, f) for f in fnames])
else:
finished.append(line)
self.data_files = "\n".join(finished)
def add_man_path(self, man_path):
self.data_files = "%s\n%s =" % (self.data_files, man_path)
def add_man_page(self, man_page):
self.data_files = "%s\n %s" % (self.data_files, man_page)
def get_man_sections(self):
man_sections = dict()
manpages = self.pbr_config['manpages']
for manpage in manpages.split():
section_number = manpage.strip()[-1]
section = man_sections.get(section_number, list())
section.append(manpage.strip())
man_sections[section_number] = section
return man_sections
def hook(self):
packages = self.config.get('packages', self.name).strip()
expanded = []
for pkg in packages.split("\n"):
if os.path.isdir(pkg.strip()):
expanded.append(find_package.smart_find_packages(pkg.strip()))
self.config['packages'] = "\n".join(expanded)
self.expand_globs()
if 'manpages' in self.pbr_config:
man_sections = self.get_man_sections()
for (section, pages) in man_sections.items():
manpath = get_man_section(section)
self.add_man_path(manpath)
for page in pages:
self.add_man_page(page)

View file

@ -0,0 +1,32 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from pbr.hooks import base
from pbr import packaging
class MetadataConfig(base.BaseConfig):
section = 'metadata'
def hook(self):
self.config['version'] = packaging.get_version(
self.config['name'], self.config.get('version', None))
packaging.append_text_list(
self.config, 'requires_dist',
packaging.parse_requirements())
def get_name(self):
return self.config['name']

53
libs/pbr/options.py Normal file
View file

@ -0,0 +1,53 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright (C) 2013 Association of Universities for Research in Astronomy
# (AURA)
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# 3. The name of AURA and its representatives may not be used to
# endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY AURA ``AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL AURA BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
# DAMAGE.
import os
TRUE_VALUES = ('true', '1', 'yes')
def get_boolean_option(option_dict, option_name, env_name):
return ((option_name in option_dict
and option_dict[option_name][1].lower() in TRUE_VALUES) or
str(os.getenv(env_name)).lower() in TRUE_VALUES)

855
libs/pbr/packaging.py Normal file
View file

@ -0,0 +1,855 @@
# Copyright 2011 OpenStack Foundation
# Copyright 2012-2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Utilities with minimum-depends for use in setup.py
"""
from __future__ import unicode_literals
from distutils.command import install as du_install
from distutils import log
import email
import email.errors
import os
import re
import sys
import warnings
import pkg_resources
import setuptools
from setuptools.command import develop
from setuptools.command import easy_install
from setuptools.command import egg_info
from setuptools.command import install
from setuptools.command import install_scripts
from setuptools.command import sdist
from pbr import extra_files
from pbr import git
from pbr import options
import pbr.pbr_json
from pbr import testr_command
from pbr import version
REQUIREMENTS_FILES = ('requirements.txt', 'tools/pip-requires')
PY_REQUIREMENTS_FILES = [x % sys.version_info[0] for x in (
'requirements-py%d.txt', 'tools/pip-requires-py%d')]
TEST_REQUIREMENTS_FILES = ('test-requirements.txt', 'tools/test-requires')
def get_requirements_files():
files = os.environ.get("PBR_REQUIREMENTS_FILES")
if files:
return tuple(f.strip() for f in files.split(','))
# Returns a list composed of:
# - REQUIREMENTS_FILES with -py2 or -py3 in the name
# (e.g. requirements-py3.txt)
# - REQUIREMENTS_FILES
return PY_REQUIREMENTS_FILES + list(REQUIREMENTS_FILES)
def append_text_list(config, key, text_list):
"""Append a \n separated list to possibly existing value."""
new_value = []
current_value = config.get(key, "")
if current_value:
new_value.append(current_value)
new_value.extend(text_list)
config[key] = '\n'.join(new_value)
def _any_existing(file_list):
return [f for f in file_list if os.path.exists(f)]
# Get requirements from the first file that exists
def get_reqs_from_files(requirements_files):
existing = _any_existing(requirements_files)
# TODO(stephenfin): Remove this in pbr 6.0+
deprecated = [f for f in existing if f in PY_REQUIREMENTS_FILES]
if deprecated:
warnings.warn('Support for \'-pyN\'-suffixed requirements files is '
'removed in pbr 5.0 and these files are now ignored. '
'Use environment markers instead. Conflicting files: '
'%r' % deprecated,
DeprecationWarning)
existing = [f for f in existing if f not in PY_REQUIREMENTS_FILES]
for requirements_file in existing:
with open(requirements_file, 'r') as fil:
return fil.read().split('\n')
return []
def parse_requirements(requirements_files=None, strip_markers=False):
if requirements_files is None:
requirements_files = get_requirements_files()
def egg_fragment(match):
# take a versioned egg fragment and return a
# versioned package requirement e.g.
# nova-1.2.3 becomes nova>=1.2.3
return re.sub(r'([\w.]+)-([\w.-]+)',
r'\1>=\2',
match.groups()[-1])
requirements = []
for line in get_reqs_from_files(requirements_files):
# Ignore comments
if (not line.strip()) or line.startswith('#'):
continue
# Ignore index URL lines
if re.match(r'^\s*(-i|--index-url|--extra-index-url).*', line):
continue
# Handle nested requirements files such as:
# -r other-requirements.txt
if line.startswith('-r'):
req_file = line.partition(' ')[2]
requirements += parse_requirements(
[req_file], strip_markers=strip_markers)
continue
try:
project_name = pkg_resources.Requirement.parse(line).project_name
except ValueError:
project_name = None
# For the requirements list, we need to inject only the portion
# after egg= so that distutils knows the package it's looking for
# such as:
# -e git://github.com/openstack/nova/master#egg=nova
# -e git://github.com/openstack/nova/master#egg=nova-1.2.3
# -e git+https://foo.com/zipball#egg=bar&subdirectory=baz
if re.match(r'\s*-e\s+', line):
line = re.sub(r'\s*-e\s+.*#egg=([^&]+).*$', egg_fragment, line)
# such as:
# http://github.com/openstack/nova/zipball/master#egg=nova
# http://github.com/openstack/nova/zipball/master#egg=nova-1.2.3
# git+https://foo.com/zipball#egg=bar&subdirectory=baz
elif re.match(r'\s*(https?|git(\+(https|ssh))?):', line):
line = re.sub(r'\s*(https?|git(\+(https|ssh))?):.*#egg=([^&]+).*$',
egg_fragment, line)
# -f lines are for index locations, and don't get used here
elif re.match(r'\s*-f\s+', line):
line = None
reason = 'Index Location'
if line is not None:
line = re.sub('#.*$', '', line)
if strip_markers:
semi_pos = line.find(';')
if semi_pos < 0:
semi_pos = None
line = line[:semi_pos]
requirements.append(line)
else:
log.info(
'[pbr] Excluding %s: %s' % (project_name, reason))
return requirements
def parse_dependency_links(requirements_files=None):
if requirements_files is None:
requirements_files = get_requirements_files()
dependency_links = []
# dependency_links inject alternate locations to find packages listed
# in requirements
for line in get_reqs_from_files(requirements_files):
# skip comments and blank lines
if re.match(r'(\s*#)|(\s*$)', line):
continue
# lines with -e or -f need the whole line, minus the flag
if re.match(r'\s*-[ef]\s+', line):
dependency_links.append(re.sub(r'\s*-[ef]\s+', '', line))
# lines that are only urls can go in unmolested
elif re.match(r'\s*(https?|git(\+(https|ssh))?):', line):
dependency_links.append(line)
return dependency_links
class InstallWithGit(install.install):
"""Extracts ChangeLog and AUTHORS from git then installs.
This is useful for e.g. readthedocs where the package is
installed and then docs built.
"""
command_name = 'install'
def run(self):
_from_git(self.distribution)
return install.install.run(self)
class LocalInstall(install.install):
"""Runs python setup.py install in a sensible manner.
Force a non-egg installed in the manner of
single-version-externally-managed, which allows us to install manpages
and config files.
"""
command_name = 'install'
def run(self):
_from_git(self.distribution)
return du_install.install.run(self)
class TestrTest(testr_command.Testr):
"""Make setup.py test do the right thing."""
command_name = 'test'
description = 'DEPRECATED: Run unit tests using testr'
def run(self):
warnings.warn('testr integration is deprecated in pbr 4.2 and will '
'be removed in a future release. Please call your test '
'runner directly',
DeprecationWarning)
# Can't use super - base class old-style class
testr_command.Testr.run(self)
class LocalRPMVersion(setuptools.Command):
__doc__ = """Output the rpm *compatible* version string of this package"""
description = __doc__
user_options = []
command_name = "rpm_version"
def run(self):
log.info("[pbr] Extracting rpm version")
name = self.distribution.get_name()
print(version.VersionInfo(name).semantic_version().rpm_string())
def initialize_options(self):
pass
def finalize_options(self):
pass
class LocalDebVersion(setuptools.Command):
__doc__ = """Output the deb *compatible* version string of this package"""
description = __doc__
user_options = []
command_name = "deb_version"
def run(self):
log.info("[pbr] Extracting deb version")
name = self.distribution.get_name()
print(version.VersionInfo(name).semantic_version().debian_string())
def initialize_options(self):
pass
def finalize_options(self):
pass
def have_testr():
return testr_command.have_testr
try:
from nose import commands
class NoseTest(commands.nosetests):
"""Fallback test runner if testr is a no-go."""
command_name = 'test'
description = 'DEPRECATED: Run unit tests using nose'
def run(self):
warnings.warn('nose integration in pbr is deprecated. Please use '
'the native nose setuptools configuration or call '
'nose directly',
DeprecationWarning)
# Can't use super - base class old-style class
commands.nosetests.run(self)
_have_nose = True
except ImportError:
_have_nose = False
def have_nose():
return _have_nose
_wsgi_text = """#PBR Generated from %(group)r
import threading
from %(module_name)s import %(import_target)s
if __name__ == "__main__":
import argparse
import socket
import sys
import wsgiref.simple_server as wss
parser = argparse.ArgumentParser(
description=%(import_target)s.__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
usage='%%(prog)s [-h] [--port PORT] [--host IP] -- [passed options]')
parser.add_argument('--port', '-p', type=int, default=8000,
help='TCP port to listen on')
parser.add_argument('--host', '-b', default='',
help='IP to bind the server to')
parser.add_argument('args',
nargs=argparse.REMAINDER,
metavar='-- [passed options]',
help="'--' is the separator of the arguments used "
"to start the WSGI server and the arguments passed "
"to the WSGI application.")
args = parser.parse_args()
if args.args:
if args.args[0] == '--':
args.args.pop(0)
else:
parser.error("unrecognized arguments: %%s" %% ' '.join(args.args))
sys.argv[1:] = args.args
server = wss.make_server(args.host, args.port, %(invoke_target)s())
print("*" * 80)
print("STARTING test server %(module_name)s.%(invoke_target)s")
url = "http://%%s:%%d/" %% (server.server_name, server.server_port)
print("Available at %%s" %% url)
print("DANGER! For testing only, do not use in production")
print("*" * 80)
sys.stdout.flush()
server.serve_forever()
else:
application = None
app_lock = threading.Lock()
with app_lock:
if application is None:
application = %(invoke_target)s()
"""
_script_text = """# PBR Generated from %(group)r
import sys
from %(module_name)s import %(import_target)s
if __name__ == "__main__":
sys.exit(%(invoke_target)s())
"""
# the following allows us to specify different templates per entry
# point group when generating pbr scripts.
ENTRY_POINTS_MAP = {
'console_scripts': _script_text,
'gui_scripts': _script_text,
'wsgi_scripts': _wsgi_text
}
def generate_script(group, entry_point, header, template):
"""Generate the script based on the template.
:param str group:
The entry-point group name, e.g., "console_scripts".
:param str header:
The first line of the script, e.g., "!#/usr/bin/env python".
:param str template:
The script template.
:returns:
The templated script content
:rtype:
str
"""
if not entry_point.attrs or len(entry_point.attrs) > 2:
raise ValueError("Script targets must be of the form "
"'func' or 'Class.class_method'.")
script_text = template % dict(
group=group,
module_name=entry_point.module_name,
import_target=entry_point.attrs[0],
invoke_target='.'.join(entry_point.attrs),
)
return header + script_text
def override_get_script_args(
dist, executable=os.path.normpath(sys.executable), is_wininst=False):
"""Override entrypoints console_script."""
header = easy_install.get_script_header("", executable, is_wininst)
for group, template in ENTRY_POINTS_MAP.items():
for name, ep in dist.get_entry_map(group).items():
yield (name, generate_script(group, ep, header, template))
class LocalDevelop(develop.develop):
command_name = 'develop'
def install_wrapper_scripts(self, dist):
if sys.platform == 'win32':
return develop.develop.install_wrapper_scripts(self, dist)
if not self.exclude_scripts:
for args in override_get_script_args(dist):
self.write_script(*args)
class LocalInstallScripts(install_scripts.install_scripts):
"""Intercepts console scripts entry_points."""
command_name = 'install_scripts'
def _make_wsgi_scripts_only(self, dist, executable, is_wininst):
header = easy_install.get_script_header("", executable, is_wininst)
wsgi_script_template = ENTRY_POINTS_MAP['wsgi_scripts']
for name, ep in dist.get_entry_map('wsgi_scripts').items():
content = generate_script(
'wsgi_scripts', ep, header, wsgi_script_template)
self.write_script(name, content)
def run(self):
import distutils.command.install_scripts
self.run_command("egg_info")
if self.distribution.scripts:
# run first to set up self.outfiles
distutils.command.install_scripts.install_scripts.run(self)
else:
self.outfiles = []
ei_cmd = self.get_finalized_command("egg_info")
dist = pkg_resources.Distribution(
ei_cmd.egg_base,
pkg_resources.PathMetadata(ei_cmd.egg_base, ei_cmd.egg_info),
ei_cmd.egg_name, ei_cmd.egg_version,
)
bs_cmd = self.get_finalized_command('build_scripts')
executable = getattr(
bs_cmd, 'executable', easy_install.sys_executable)
is_wininst = getattr(
self.get_finalized_command("bdist_wininst"), '_is_running', False
)
if 'bdist_wheel' in self.distribution.have_run:
# We're building a wheel which has no way of generating mod_wsgi
# scripts for us. Let's build them.
# NOTE(sigmavirus24): This needs to happen here because, as the
# comment below indicates, no_ep is True when building a wheel.
self._make_wsgi_scripts_only(dist, executable, is_wininst)
if self.no_ep:
# no_ep is True if we're installing into an .egg file or building
# a .whl file, in those cases, we do not want to build all of the
# entry-points listed for this package.
return
if os.name != 'nt':
get_script_args = override_get_script_args
else:
get_script_args = easy_install.get_script_args
executable = '"%s"' % executable
for args in get_script_args(dist, executable, is_wininst):
self.write_script(*args)
class LocalManifestMaker(egg_info.manifest_maker):
"""Add any files that are in git and some standard sensible files."""
def _add_pbr_defaults(self):
for template_line in [
'include AUTHORS',
'include ChangeLog',
'exclude .gitignore',
'exclude .gitreview',
'global-exclude *.pyc'
]:
self.filelist.process_template_line(template_line)
def add_defaults(self):
"""Add all the default files to self.filelist:
Extends the functionality provided by distutils to also included
additional sane defaults, such as the ``AUTHORS`` and ``ChangeLog``
files generated by *pbr*.
Warns if (``README`` or ``README.txt``) or ``setup.py`` are missing;
everything else is optional.
"""
option_dict = self.distribution.get_option_dict('pbr')
sdist.sdist.add_defaults(self)
self.filelist.append(self.template)
self.filelist.append(self.manifest)
self.filelist.extend(extra_files.get_extra_files())
should_skip = options.get_boolean_option(option_dict, 'skip_git_sdist',
'SKIP_GIT_SDIST')
if not should_skip:
rcfiles = git._find_git_files()
if rcfiles:
self.filelist.extend(rcfiles)
elif os.path.exists(self.manifest):
self.read_manifest()
ei_cmd = self.get_finalized_command('egg_info')
self._add_pbr_defaults()
self.filelist.include_pattern("*", prefix=ei_cmd.egg_info)
class LocalEggInfo(egg_info.egg_info):
"""Override the egg_info command to regenerate SOURCES.txt sensibly."""
command_name = 'egg_info'
def find_sources(self):
"""Generate SOURCES.txt only if there isn't one already.
If we are in an sdist command, then we always want to update
SOURCES.txt. If we are not in an sdist command, then it doesn't
matter one flip, and is actually destructive.
However, if we're in a git context, it's always the right thing to do
to recreate SOURCES.txt
"""
manifest_filename = os.path.join(self.egg_info, "SOURCES.txt")
if (not os.path.exists(manifest_filename) or
os.path.exists('.git') or
'sdist' in sys.argv):
log.info("[pbr] Processing SOURCES.txt")
mm = LocalManifestMaker(self.distribution)
mm.manifest = manifest_filename
mm.run()
self.filelist = mm.filelist
else:
log.info("[pbr] Reusing existing SOURCES.txt")
self.filelist = egg_info.FileList()
for entry in open(manifest_filename, 'r').read().split('\n'):
self.filelist.append(entry)
def _from_git(distribution):
option_dict = distribution.get_option_dict('pbr')
changelog = git._iter_log_oneline()
if changelog:
changelog = git._iter_changelog(changelog)
git.write_git_changelog(option_dict=option_dict, changelog=changelog)
git.generate_authors(option_dict=option_dict)
class LocalSDist(sdist.sdist):
"""Builds the ChangeLog and Authors files from VC first."""
command_name = 'sdist'
def checking_reno(self):
"""Ensure reno is installed and configured.
We can't run reno-based commands if reno isn't installed/available, and
don't want to if the user isn't using it.
"""
if hasattr(self, '_has_reno'):
return self._has_reno
option_dict = self.distribution.get_option_dict('pbr')
should_skip = options.get_boolean_option(option_dict, 'skip_reno',
'SKIP_GENERATE_RENO')
if should_skip:
self._has_reno = False
return False
try:
# versions of reno witout this module will not have the required
# feature, hence the import
from reno import setup_command # noqa
except ImportError:
log.info('[pbr] reno was not found or is too old. Skipping '
'release notes')
self._has_reno = False
return False
conf, output_file, cache_file = setup_command.load_config(
self.distribution)
if not os.path.exists(os.path.join(conf.reporoot, conf.notespath)):
log.info('[pbr] reno does not appear to be configured. Skipping '
'release notes')
self._has_reno = False
return False
self._files = [output_file, cache_file]
log.info('[pbr] Generating release notes')
self._has_reno = True
return True
sub_commands = [('build_reno', checking_reno)] + sdist.sdist.sub_commands
def run(self):
_from_git(self.distribution)
# sdist.sdist is an old style class, can't use super()
sdist.sdist.run(self)
def make_distribution(self):
# This is included in make_distribution because setuptools doesn't use
# 'get_file_list'. As such, this is the only hook point that runs after
# the commands in 'sub_commands'
if self.checking_reno():
self.filelist.extend(self._files)
self.filelist.sort()
sdist.sdist.make_distribution(self)
try:
from pbr import builddoc
_have_sphinx = True
# Import the symbols from their new home so the package API stays
# compatible.
LocalBuildDoc = builddoc.LocalBuildDoc
except ImportError:
_have_sphinx = False
LocalBuildDoc = None
def have_sphinx():
return _have_sphinx
def _get_increment_kwargs(git_dir, tag):
"""Calculate the sort of semver increment needed from git history.
Every commit from HEAD to tag is consider for Sem-Ver metadata lines.
See the pbr docs for their syntax.
:return: a dict of kwargs for passing into SemanticVersion.increment.
"""
result = {}
if tag:
version_spec = tag + "..HEAD"
else:
version_spec = "HEAD"
# Get the raw body of the commit messages so that we don't have to
# parse out any formatting whitespace and to avoid user settings on
# git log output affecting out ability to have working sem ver headers.
changelog = git._run_git_command(['log', '--pretty=%B', version_spec],
git_dir)
header_len = len('sem-ver:')
commands = [line[header_len:].strip() for line in changelog.split('\n')
if line.lower().startswith('sem-ver:')]
symbols = set()
for command in commands:
symbols.update([symbol.strip() for symbol in command.split(',')])
def _handle_symbol(symbol, symbols, impact):
if symbol in symbols:
result[impact] = True
symbols.discard(symbol)
_handle_symbol('bugfix', symbols, 'patch')
_handle_symbol('feature', symbols, 'minor')
_handle_symbol('deprecation', symbols, 'minor')
_handle_symbol('api-break', symbols, 'major')
for symbol in symbols:
log.info('[pbr] Unknown Sem-Ver symbol %r' % symbol)
# We don't want patch in the kwargs since it is not a keyword argument -
# its the default minimum increment.
result.pop('patch', None)
return result
def _get_revno_and_last_tag(git_dir):
"""Return the commit data about the most recent tag.
We use git-describe to find this out, but if there are no
tags then we fall back to counting commits since the beginning
of time.
"""
changelog = git._iter_log_oneline(git_dir=git_dir)
row_count = 0
for row_count, (ignored, tag_set, ignored) in enumerate(changelog):
version_tags = set()
semver_to_tag = dict()
for tag in list(tag_set):
try:
semver = version.SemanticVersion.from_pip_string(tag)
semver_to_tag[semver] = tag
version_tags.add(semver)
except Exception:
pass
if version_tags:
return semver_to_tag[max(version_tags)], row_count
return "", row_count
def _get_version_from_git_target(git_dir, target_version):
"""Calculate a version from a target version in git_dir.
This is used for untagged versions only. A new version is calculated as
necessary based on git metadata - distance to tags, current hash, contents
of commit messages.
:param git_dir: The git directory we're working from.
:param target_version: If None, the last tagged version (or 0 if there are
no tags yet) is incremented as needed to produce an appropriate target
version following semver rules. Otherwise target_version is used as a
constraint - if semver rules would result in a newer version then an
exception is raised.
:return: A semver version object.
"""
tag, distance = _get_revno_and_last_tag(git_dir)
last_semver = version.SemanticVersion.from_pip_string(tag or '0')
if distance == 0:
new_version = last_semver
else:
new_version = last_semver.increment(
**_get_increment_kwargs(git_dir, tag))
if target_version is not None and new_version > target_version:
raise ValueError(
"git history requires a target version of %(new)s, but target "
"version is %(target)s" %
dict(new=new_version, target=target_version))
if distance == 0:
return last_semver
new_dev = new_version.to_dev(distance)
if target_version is not None:
target_dev = target_version.to_dev(distance)
if target_dev > new_dev:
return target_dev
return new_dev
def _get_version_from_git(pre_version=None):
"""Calculate a version string from git.
If the revision is tagged, return that. Otherwise calculate a semantic
version description of the tree.
The number of revisions since the last tag is included in the dev counter
in the version for untagged versions.
:param pre_version: If supplied use this as the target version rather than
inferring one from the last tag + commit messages.
"""
git_dir = git._run_git_functions()
if git_dir:
try:
tagged = git._run_git_command(
['describe', '--exact-match'], git_dir,
throw_on_error=True).replace('-', '.')
target_version = version.SemanticVersion.from_pip_string(tagged)
except Exception:
if pre_version:
# not released yet - use pre_version as the target
target_version = version.SemanticVersion.from_pip_string(
pre_version)
else:
# not released yet - just calculate from git history
target_version = None
result = _get_version_from_git_target(git_dir, target_version)
return result.release_string()
# If we don't know the version, return an empty string so at least
# the downstream users of the value always have the same type of
# object to work with.
try:
return unicode()
except NameError:
return ''
def _get_version_from_pkg_metadata(package_name):
"""Get the version from package metadata if present.
This looks for PKG-INFO if present (for sdists), and if not looks
for METADATA (for wheels) and failing that will return None.
"""
pkg_metadata_filenames = ['PKG-INFO', 'METADATA']
pkg_metadata = {}
for filename in pkg_metadata_filenames:
try:
pkg_metadata_file = open(filename, 'r')
except (IOError, OSError):
continue
try:
pkg_metadata = email.message_from_file(pkg_metadata_file)
except email.errors.MessageError:
continue
# Check to make sure we're in our own dir
if pkg_metadata.get('Name', None) != package_name:
return None
return pkg_metadata.get('Version', None)
def get_version(package_name, pre_version=None):
"""Get the version of the project.
First, try getting it from PKG-INFO or METADATA, if it exists. If it does,
that means we're in a distribution tarball or that install has happened.
Otherwise, if there is no PKG-INFO or METADATA file, pull the version
from git.
We do not support setup.py version sanity in git archive tarballs, nor do
we support packagers directly sucking our git repo into theirs. We expect
that a source tarball be made from our git repo - or that if someone wants
to make a source tarball from a fork of our repo with additional tags in it
that they understand and desire the results of doing that.
:param pre_version: The version field from setup.cfg - if set then this
version will be the next release.
"""
version = os.environ.get(
"PBR_VERSION",
os.environ.get("OSLO_PACKAGE_VERSION", None))
if version:
return version
version = _get_version_from_pkg_metadata(package_name)
if version:
return version
version = _get_version_from_git(pre_version)
# Handle http://bugs.python.org/issue11638
# version will either be an empty unicode string or a valid
# unicode version string, but either way it's unicode and needs to
# be encoded.
if sys.version_info[0] == 2:
version = version.encode('utf-8')
if version:
return version
raise Exception("Versioning for this project requires either an sdist"
" tarball, or access to an upstream git repository."
" It's also possible that there is a mismatch between"
" the package name in setup.cfg and the argument given"
" to pbr.version.VersionInfo. Project name {name} was"
" given, but was not able to be found.".format(
name=package_name))
# This is added because pbr uses pbr to install itself. That means that
# any changes to the egg info writer entrypoints must be forward and
# backward compatible. This maintains the pbr.packaging.write_pbr_json
# path.
write_pbr_json = pbr.pbr_json.write_pbr_json

34
libs/pbr/pbr_json.py Normal file
View file

@ -0,0 +1,34 @@
# Copyright 2011 OpenStack Foundation
# Copyright 2012-2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import json
from pbr import git
def write_pbr_json(cmd, basename, filename):
if not hasattr(cmd.distribution, 'pbr') or not cmd.distribution.pbr:
return
git_dir = git._run_git_functions()
if not git_dir:
return
values = dict()
git_version = git.get_git_short_sha(git_dir)
is_release = git.get_is_release(git_dir)
if git_version is not None:
values['git_version'] = git_version
values['is_release'] = is_release
cmd.write_file('pbr', filename, json.dumps(values, sort_keys=True))

99
libs/pbr/sphinxext.py Normal file
View file

@ -0,0 +1,99 @@
# Copyright 2018 Red Hat, Inc.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os.path
from six.moves import configparser
from sphinx.util import logging
import pbr.version
_project = None
logger = logging.getLogger(__name__)
def _find_setup_cfg(srcdir):
"""Find the 'setup.cfg' file, if it exists.
This assumes we're using 'doc/source' for documentation, but also allows
for single level 'doc' paths.
"""
# TODO(stephenfin): Are we sure that this will always exist, e.g. for
# an sdist or wheel? Perhaps we should check for 'PKG-INFO' or
# 'METADATA' files, a la 'pbr.packaging._get_version_from_pkg_metadata'
for path in [
os.path.join(srcdir, os.pardir, 'setup.cfg'),
os.path.join(srcdir, os.pardir, os.pardir, 'setup.cfg')]:
if os.path.exists(path):
return path
return None
def _get_project_name(srcdir):
"""Return string name of project name, or None.
This extracts metadata from 'setup.cfg'. We don't rely on
distutils/setuptools as we don't want to actually install the package
simply to build docs.
"""
global _project
if _project is None:
parser = configparser.ConfigParser()
path = _find_setup_cfg(srcdir)
if not path or not parser.read(path):
logger.info('Could not find a setup.cfg to extract project name '
'from')
return None
try:
# for project name we use the name in setup.cfg, but if the
# length is longer then 32 we use summary. Otherwise thAe
# menu rendering looks brolen
project = parser.get('metadata', 'name')
if len(project.split()) == 1 and len(project) > 32:
project = parser.get('metadata', 'summary')
except configparser.Error:
logger.info('Could not extract project metadata from setup.cfg')
return None
_project = project
return _project
def _builder_inited(app):
# TODO(stephenfin): Once Sphinx 1.8 is released, we should move the below
# to a 'config-inited' handler
project_name = _get_project_name(app.srcdir)
try:
version_info = pbr.version.VersionInfo(project_name)
except Exception:
version_info = None
if version_info and not app.config.version and not app.config.release:
app.config.version = version_info.canonical_version_string()
app.config.release = version_info.version_string_with_vcs()
def setup(app):
app.connect('builder-inited', _builder_inited)
return {
'parallel_read_safe': True,
'parallel_write_safe': True,
}

167
libs/pbr/testr_command.py Normal file
View file

@ -0,0 +1,167 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright (c) 2013 Testrepository Contributors
#
# Licensed under either the Apache License, Version 2.0 or the BSD 3-clause
# license at the users choice. A copy of both licenses are available in the
# project source as Apache-2.0 and BSD. You may not use this file except in
# compliance with one of these two licences.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under these licenses is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# license you chose for the specific language governing permissions and
# limitations under that license.
"""setuptools/distutils command to run testr via setup.py
PBR will hook in the Testr class to provide "setup.py test" when
.testr.conf is present in the repository (see pbr/hooks/commands.py).
If we are activated but testrepository is not installed, we provide a
sensible error.
You can pass --coverage which will also export PYTHON='coverage run
--source <your package>' and automatically combine the coverage from
each testr backend test runner after the run completes.
"""
from distutils import cmd
import distutils.errors
import logging
import os
import sys
import warnings
logger = logging.getLogger(__name__)
class TestrReal(cmd.Command):
description = "DEPRECATED: Run unit tests using testr"
user_options = [
('coverage', None, "Replace PYTHON with coverage and merge coverage "
"from each testr worker."),
('testr-args=', 't', "Run 'testr' with these args"),
('omit=', 'o', "Files to omit from coverage calculations"),
('coverage-package-name=', None, "Use this name to select packages "
"for coverage (one or more, "
"comma-separated)"),
('slowest', None, "Show slowest test times after tests complete."),
('no-parallel', None, "Run testr serially"),
('log-level=', 'l', "Log level (default: info)"),
]
boolean_options = ['coverage', 'slowest', 'no_parallel']
def _run_testr(self, *args):
logger.debug("_run_testr called with args = %r", args)
return commands.run_argv([sys.argv[0]] + list(args),
sys.stdin, sys.stdout, sys.stderr)
def initialize_options(self):
self.testr_args = None
self.coverage = None
self.omit = ""
self.slowest = None
self.coverage_package_name = None
self.no_parallel = None
self.log_level = 'info'
def finalize_options(self):
self.log_level = getattr(
logging,
self.log_level.upper(),
logging.INFO)
logging.basicConfig(level=self.log_level)
logger.debug("finalize_options called")
if self.testr_args is None:
self.testr_args = []
else:
self.testr_args = self.testr_args.split()
if self.omit:
self.omit = "--omit=%s" % self.omit
logger.debug("finalize_options: self.__dict__ = %r", self.__dict__)
def run(self):
"""Set up testr repo, then run testr."""
logger.debug("run called")
warnings.warn('testr integration in pbr is deprecated. Please use '
'the \'testr\' setup command or call testr directly',
DeprecationWarning)
if not os.path.isdir(".testrepository"):
self._run_testr("init")
if self.coverage:
self._coverage_before()
if not self.no_parallel:
testr_ret = self._run_testr("run", "--parallel", *self.testr_args)
else:
testr_ret = self._run_testr("run", *self.testr_args)
if testr_ret:
raise distutils.errors.DistutilsError(
"testr failed (%d)" % testr_ret)
if self.slowest:
print("Slowest Tests")
self._run_testr("slowest")
if self.coverage:
self._coverage_after()
def _coverage_before(self):
logger.debug("_coverage_before called")
package = self.distribution.get_name()
if package.startswith('python-'):
package = package[7:]
# Use this as coverage package name
if self.coverage_package_name:
package = self.coverage_package_name
options = "--source %s --parallel-mode" % package
os.environ['PYTHON'] = ("coverage run %s" % options)
logger.debug("os.environ['PYTHON'] = %r", os.environ['PYTHON'])
def _coverage_after(self):
logger.debug("_coverage_after called")
os.system("coverage combine")
os.system("coverage html -d ./cover %s" % self.omit)
os.system("coverage xml -o ./cover/coverage.xml %s" % self.omit)
class TestrFake(cmd.Command):
description = "Run unit tests using testr"
user_options = []
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
print("Install testrepository to run 'testr' command properly.")
try:
from testrepository import commands
have_testr = True
Testr = TestrReal
except ImportError:
have_testr = False
Testr = TestrFake

View file

@ -0,0 +1,26 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import testscenarios
def load_tests(loader, standard_tests, pattern):
# top level directory cached on loader instance
this_dir = os.path.dirname(__file__)
package_tests = loader.discover(start_dir=this_dir, pattern=pattern)
result = loader.suiteClass()
result.addTests(testscenarios.generate_scenarios(standard_tests))
result.addTests(testscenarios.generate_scenarios(package_tests))
return result

221
libs/pbr/tests/base.py Normal file
View file

@ -0,0 +1,221 @@
# Copyright 2010-2011 OpenStack Foundation
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# Copyright (C) 2013 Association of Universities for Research in Astronomy
# (AURA)
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# 3. The name of AURA and its representatives may not be used to
# endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY AURA ``AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL AURA BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
"""Common utilities used in testing"""
import os
import shutil
import subprocess
import sys
import fixtures
import testresources
import testtools
from testtools import content
from pbr import options
class DiveDir(fixtures.Fixture):
"""Dive into given directory and return back on cleanup.
:ivar path: The target directory.
"""
def __init__(self, path):
self.path = path
def setUp(self):
super(DiveDir, self).setUp()
self.addCleanup(os.chdir, os.getcwd())
os.chdir(self.path)
class BaseTestCase(testtools.TestCase, testresources.ResourcedTestCase):
def setUp(self):
super(BaseTestCase, self).setUp()
test_timeout = os.environ.get('OS_TEST_TIMEOUT', 30)
try:
test_timeout = int(test_timeout)
except ValueError:
# If timeout value is invalid, fail hard.
print("OS_TEST_TIMEOUT set to invalid value"
" defaulting to no timeout")
test_timeout = 0
if test_timeout > 0:
self.useFixture(fixtures.Timeout(test_timeout, gentle=True))
if os.environ.get('OS_STDOUT_CAPTURE') in options.TRUE_VALUES:
stdout = self.useFixture(fixtures.StringStream('stdout')).stream
self.useFixture(fixtures.MonkeyPatch('sys.stdout', stdout))
if os.environ.get('OS_STDERR_CAPTURE') in options.TRUE_VALUES:
stderr = self.useFixture(fixtures.StringStream('stderr')).stream
self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr))
self.log_fixture = self.useFixture(
fixtures.FakeLogger('pbr'))
# Older git does not have config --local, so create a temporary home
# directory to permit using git config --global without stepping on
# developer configuration.
self.useFixture(fixtures.TempHomeDir())
self.useFixture(fixtures.NestedTempfile())
self.useFixture(fixtures.FakeLogger())
# TODO(lifeless) we should remove PBR_VERSION from the environment.
# rather than setting it, because thats not representative - we need to
# test non-preversioned codepaths too!
self.useFixture(fixtures.EnvironmentVariable('PBR_VERSION', '0.0'))
self.temp_dir = self.useFixture(fixtures.TempDir()).path
self.package_dir = os.path.join(self.temp_dir, 'testpackage')
shutil.copytree(os.path.join(os.path.dirname(__file__), 'testpackage'),
self.package_dir)
self.addCleanup(os.chdir, os.getcwd())
os.chdir(self.package_dir)
self.addCleanup(self._discard_testpackage)
# Tests can opt into non-PBR_VERSION by setting preversioned=False as
# an attribute.
if not getattr(self, 'preversioned', True):
self.useFixture(fixtures.EnvironmentVariable('PBR_VERSION'))
setup_cfg_path = os.path.join(self.package_dir, 'setup.cfg')
with open(setup_cfg_path, 'rt') as cfg:
content = cfg.read()
content = content.replace(u'version = 0.1.dev', u'')
with open(setup_cfg_path, 'wt') as cfg:
cfg.write(content)
def _discard_testpackage(self):
# Remove pbr.testpackage from sys.modules so that it can be freshly
# re-imported by the next test
for k in list(sys.modules):
if (k == 'pbr_testpackage' or
k.startswith('pbr_testpackage.')):
del sys.modules[k]
def run_pbr(self, *args, **kwargs):
return self._run_cmd('pbr', args, **kwargs)
def run_setup(self, *args, **kwargs):
return self._run_cmd(sys.executable, ('setup.py',) + args, **kwargs)
def _run_cmd(self, cmd, args=[], allow_fail=True, cwd=None):
"""Run a command in the root of the test working copy.
Runs a command, with the given argument list, in the root of the test
working copy--returns the stdout and stderr streams and the exit code
from the subprocess.
:param cwd: If falsy run within the test package dir, otherwise run
within the named path.
"""
cwd = cwd or self.package_dir
result = _run_cmd([cmd] + list(args), cwd=cwd)
if result[2] and not allow_fail:
raise Exception("Command failed retcode=%s" % result[2])
return result
class CapturedSubprocess(fixtures.Fixture):
"""Run a process and capture its output.
:attr stdout: The output (a string).
:attr stderr: The standard error (a string).
:attr returncode: The return code of the process.
Note that stdout and stderr are decoded from the bytestrings subprocess
returns using error=replace.
"""
def __init__(self, label, *args, **kwargs):
"""Create a CapturedSubprocess.
:param label: A label for the subprocess in the test log. E.g. 'foo'.
:param *args: The *args to pass to Popen.
:param **kwargs: The **kwargs to pass to Popen.
"""
super(CapturedSubprocess, self).__init__()
self.label = label
self.args = args
self.kwargs = kwargs
self.kwargs['stderr'] = subprocess.PIPE
self.kwargs['stdin'] = subprocess.PIPE
self.kwargs['stdout'] = subprocess.PIPE
def setUp(self):
super(CapturedSubprocess, self).setUp()
proc = subprocess.Popen(*self.args, **self.kwargs)
out, err = proc.communicate()
self.out = out.decode('utf-8', 'replace')
self.err = err.decode('utf-8', 'replace')
self.addDetail(self.label + '-stdout', content.text_content(self.out))
self.addDetail(self.label + '-stderr', content.text_content(self.err))
self.returncode = proc.returncode
if proc.returncode:
raise AssertionError('Failed process %s' % proc.returncode)
self.addCleanup(delattr, self, 'out')
self.addCleanup(delattr, self, 'err')
self.addCleanup(delattr, self, 'returncode')
def _run_cmd(args, cwd):
"""Run the command args in cwd.
:param args: The command to run e.g. ['git', 'status']
:param cwd: The directory to run the comamnd in.
:return: ((stdout, stderr), returncode)
"""
p = subprocess.Popen(
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, cwd=cwd)
streams = tuple(s.decode('latin1').strip() for s in p.communicate())
for stream_content in streams:
print(stream_content)
return (streams) + (p.returncode,)
def _config_git():
_run_cmd(
['git', 'config', '--global', 'user.email', 'example@example.com'],
None)
_run_cmd(
['git', 'config', '--global', 'user.name', 'OpenStack Developer'],
None)
_run_cmd(
['git', 'config', '--global', 'user.signingkey',
'example@example.com'], None)

View file

@ -0,0 +1,84 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright (C) 2013 Association of Universities for Research in Astronomy
# (AURA)
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# 3. The name of AURA and its representatives may not be used to
# endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY AURA ``AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL AURA BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
from testtools import content
from pbr.tests import base
class TestCommands(base.BaseTestCase):
def test_custom_build_py_command(self):
"""Test custom build_py command.
Test that a custom subclass of the build_py command runs when listed in
the commands [global] option, rather than the normal build command.
"""
stdout, stderr, return_code = self.run_setup('build_py')
self.addDetail('stdout', content.text_content(stdout))
self.addDetail('stderr', content.text_content(stderr))
self.assertIn('Running custom build_py command.', stdout)
self.assertEqual(0, return_code)
def test_custom_deb_version_py_command(self):
"""Test custom deb_version command."""
stdout, stderr, return_code = self.run_setup('deb_version')
self.addDetail('stdout', content.text_content(stdout))
self.addDetail('stderr', content.text_content(stderr))
self.assertIn('Extracting deb version', stdout)
self.assertEqual(0, return_code)
def test_custom_rpm_version_py_command(self):
"""Test custom rpm_version command."""
stdout, stderr, return_code = self.run_setup('rpm_version')
self.addDetail('stdout', content.text_content(stdout))
self.addDetail('stderr', content.text_content(stderr))
self.assertIn('Extracting rpm version', stdout)
self.assertEqual(0, return_code)
def test_freeze_command(self):
"""Test that freeze output is sorted in a case-insensitive manner."""
stdout, stderr, return_code = self.run_pbr('freeze')
self.assertEqual(0, return_code)
pkgs = []
for l in stdout.split('\n'):
pkgs.append(l.split('==')[0].lower())
pkgs_sort = sorted(pkgs[:])
self.assertEqual(pkgs_sort, pkgs)

151
libs/pbr/tests/test_core.py Normal file
View file

@ -0,0 +1,151 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright (C) 2013 Association of Universities for Research in Astronomy
# (AURA)
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# 3. The name of AURA and its representatives may not be used to
# endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY AURA ``AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL AURA BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
import glob
import os
import tarfile
import fixtures
from pbr.tests import base
class TestCore(base.BaseTestCase):
cmd_names = ('pbr_test_cmd', 'pbr_test_cmd_with_class')
def check_script_install(self, install_stdout):
for cmd_name in self.cmd_names:
install_txt = 'Installing %s script to %s' % (cmd_name,
self.temp_dir)
self.assertIn(install_txt, install_stdout)
cmd_filename = os.path.join(self.temp_dir, cmd_name)
script_txt = open(cmd_filename, 'r').read()
self.assertNotIn('pkg_resources', script_txt)
stdout, _, return_code = self._run_cmd(cmd_filename)
self.assertIn("PBR", stdout)
def test_setup_py_keywords(self):
"""setup.py --keywords.
Test that the `./setup.py --keywords` command returns the correct
value without balking.
"""
self.run_setup('egg_info')
stdout, _, _ = self.run_setup('--keywords')
assert stdout == 'packaging,distutils,setuptools'
def test_setup_py_build_sphinx(self):
stdout, _, return_code = self.run_setup('build_sphinx')
self.assertEqual(0, return_code)
def test_sdist_extra_files(self):
"""Test that the extra files are correctly added."""
stdout, _, return_code = self.run_setup('sdist', '--formats=gztar')
# There can be only one
try:
tf_path = glob.glob(os.path.join('dist', '*.tar.gz'))[0]
except IndexError:
assert False, 'source dist not found'
tf = tarfile.open(tf_path)
names = ['/'.join(p.split('/')[1:]) for p in tf.getnames()]
self.assertIn('extra-file.txt', names)
def test_console_script_install(self):
"""Test that we install a non-pkg-resources console script."""
if os.name == 'nt':
self.skipTest('Windows support is passthrough')
stdout, _, return_code = self.run_setup(
'install_scripts', '--install-dir=%s' % self.temp_dir)
self.useFixture(
fixtures.EnvironmentVariable('PYTHONPATH', '.'))
self.check_script_install(stdout)
def test_console_script_develop(self):
"""Test that we develop a non-pkg-resources console script."""
if os.name == 'nt':
self.skipTest('Windows support is passthrough')
self.useFixture(
fixtures.EnvironmentVariable(
'PYTHONPATH', ".:%s" % self.temp_dir))
stdout, _, return_code = self.run_setup(
'develop', '--install-dir=%s' % self.temp_dir)
self.check_script_install(stdout)
class TestGitSDist(base.BaseTestCase):
def setUp(self):
super(TestGitSDist, self).setUp()
stdout, _, return_code = self._run_cmd('git', ('init',))
if return_code:
self.skipTest("git not installed")
stdout, _, return_code = self._run_cmd('git', ('add', '.'))
stdout, _, return_code = self._run_cmd(
'git', ('commit', '-m', 'Turn this into a git repo'))
stdout, _, return_code = self.run_setup('sdist', '--formats=gztar')
def test_sdist_git_extra_files(self):
"""Test that extra files found in git are correctly added."""
# There can be only one
tf_path = glob.glob(os.path.join('dist', '*.tar.gz'))[0]
tf = tarfile.open(tf_path)
names = ['/'.join(p.split('/')[1:]) for p in tf.getnames()]
self.assertIn('git-extra-file.txt', names)

View file

@ -0,0 +1,78 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import os
import fixtures
from pbr.hooks import files
from pbr.tests import base
class FilesConfigTest(base.BaseTestCase):
def setUp(self):
super(FilesConfigTest, self).setUp()
pkg_fixture = fixtures.PythonPackage(
"fake_package", [
("fake_module.py", b""),
("other_fake_module.py", b""),
])
self.useFixture(pkg_fixture)
pkg_etc = os.path.join(pkg_fixture.base, 'etc')
pkg_sub = os.path.join(pkg_etc, 'sub')
subpackage = os.path.join(
pkg_fixture.base, 'fake_package', 'subpackage')
os.makedirs(pkg_sub)
os.makedirs(subpackage)
with open(os.path.join(pkg_etc, "foo"), 'w') as foo_file:
foo_file.write("Foo Data")
with open(os.path.join(pkg_sub, "bar"), 'w') as foo_file:
foo_file.write("Bar Data")
with open(os.path.join(subpackage, "__init__.py"), 'w') as foo_file:
foo_file.write("# empty")
self.useFixture(base.DiveDir(pkg_fixture.base))
def test_implicit_auto_package(self):
config = dict(
files=dict(
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn('subpackage', config['files']['packages'])
def test_auto_package(self):
config = dict(
files=dict(
packages='fake_package',
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn('subpackage', config['files']['packages'])
def test_data_files_globbing(self):
config = dict(
files=dict(
data_files="\n etc/pbr = etc/*"
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(
'\netc/pbr/ = \n etc/foo\netc/pbr/sub = \n etc/sub/bar',
config['files']['data_files'])

View file

@ -0,0 +1,75 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright (C) 2013 Association of Universities for Research in Astronomy
# (AURA)
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# 3. The name of AURA and its representatives may not be used to
# endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY AURA ``AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL AURA BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
import os
from testtools import matchers
from testtools import skipUnless
from pbr import testr_command
from pbr.tests import base
from pbr.tests import util
class TestHooks(base.BaseTestCase):
def setUp(self):
super(TestHooks, self).setUp()
with util.open_config(
os.path.join(self.package_dir, 'setup.cfg')) as cfg:
cfg.set('global', 'setup-hooks',
'pbr_testpackage._setup_hooks.test_hook_1\n'
'pbr_testpackage._setup_hooks.test_hook_2')
def test_global_setup_hooks(self):
"""Test setup_hooks.
Test that setup_hooks listed in the [global] section of setup.cfg are
executed in order.
"""
stdout, _, return_code = self.run_setup('egg_info')
assert 'test_hook_1\ntest_hook_2' in stdout
assert return_code == 0
@skipUnless(testr_command.have_testr, "testrepository not available")
def test_custom_commands_known(self):
stdout, _, return_code = self.run_setup('--help-commands')
self.assertFalse(return_code)
self.assertThat(stdout, matchers.Contains(" testr "))

View file

@ -0,0 +1,269 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import shlex
import sys
import fixtures
import testtools
import textwrap
from pbr.tests import base
from pbr.tests import test_packaging
PIPFLAGS = shlex.split(os.environ.get('PIPFLAGS', ''))
PIPVERSION = os.environ.get('PIPVERSION', 'pip')
PBRVERSION = os.environ.get('PBRVERSION', 'pbr')
REPODIR = os.environ.get('REPODIR', '')
WHEELHOUSE = os.environ.get('WHEELHOUSE', '')
PIP_CMD = ['-m', 'pip'] + PIPFLAGS + ['install', '-f', WHEELHOUSE]
PROJECTS = shlex.split(os.environ.get('PROJECTS', ''))
PBR_ROOT = os.path.abspath(os.path.join(__file__, '..', '..', '..'))
def all_projects():
if not REPODIR:
return
# Future: make this path parameterisable.
excludes = set(['tempest', 'requirements'])
for name in PROJECTS:
name = name.strip()
short_name = name.split('/')[-1]
try:
with open(os.path.join(
REPODIR, short_name, 'setup.py'), 'rt') as f:
if 'pbr' not in f.read():
continue
except IOError:
continue
if short_name in excludes:
continue
yield (short_name, dict(name=name, short_name=short_name))
class TestIntegration(base.BaseTestCase):
scenarios = list(all_projects())
def setUp(self):
# Integration tests need a higher default - big repos can be slow to
# clone, particularly under guest load.
env = fixtures.EnvironmentVariable(
'OS_TEST_TIMEOUT', os.environ.get('OS_TEST_TIMEOUT', '600'))
with env:
super(TestIntegration, self).setUp()
base._config_git()
@testtools.skipUnless(
os.environ.get('PBR_INTEGRATION', None) == '1',
'integration tests not enabled')
def test_integration(self):
# Test that we can:
# - run sdist from the repo in a venv
# - install the resulting tarball in a new venv
# - pip install the repo
# - pip install -e the repo
# We don't break these into separate tests because we'd need separate
# source dirs to isolate from side effects of running pip, and the
# overheads of setup would start to beat the benefits of parallelism.
self.useFixture(base.CapturedSubprocess(
'sync-req',
['python', 'update.py', os.path.join(REPODIR, self.short_name)],
cwd=os.path.join(REPODIR, 'requirements')))
self.useFixture(base.CapturedSubprocess(
'commit-requirements',
'git diff --quiet || git commit -amrequirements',
cwd=os.path.join(REPODIR, self.short_name), shell=True))
path = os.path.join(
self.useFixture(fixtures.TempDir()).path, 'project')
self.useFixture(base.CapturedSubprocess(
'clone',
['git', 'clone', os.path.join(REPODIR, self.short_name), path]))
venv = self.useFixture(
test_packaging.Venv('sdist',
modules=['pip', 'wheel', PBRVERSION],
pip_cmd=PIP_CMD))
python = venv.python
self.useFixture(base.CapturedSubprocess(
'sdist', [python, 'setup.py', 'sdist'], cwd=path))
venv = self.useFixture(
test_packaging.Venv('tarball',
modules=['pip', 'wheel', PBRVERSION],
pip_cmd=PIP_CMD))
python = venv.python
filename = os.path.join(
path, 'dist', os.listdir(os.path.join(path, 'dist'))[0])
self.useFixture(base.CapturedSubprocess(
'tarball', [python] + PIP_CMD + [filename]))
venv = self.useFixture(
test_packaging.Venv('install-git',
modules=['pip', 'wheel', PBRVERSION],
pip_cmd=PIP_CMD))
root = venv.path
python = venv.python
self.useFixture(base.CapturedSubprocess(
'install-git', [python] + PIP_CMD + ['git+file://' + path]))
if self.short_name == 'nova':
found = False
for _, _, filenames in os.walk(root):
if 'migrate.cfg' in filenames:
found = True
self.assertTrue(found)
venv = self.useFixture(
test_packaging.Venv('install-e',
modules=['pip', 'wheel', PBRVERSION],
pip_cmd=PIP_CMD))
root = venv.path
python = venv.python
self.useFixture(base.CapturedSubprocess(
'install-e', [python] + PIP_CMD + ['-e', path]))
class TestInstallWithoutPbr(base.BaseTestCase):
@testtools.skipUnless(
os.environ.get('PBR_INTEGRATION', None) == '1',
'integration tests not enabled')
def test_install_without_pbr(self):
# Test easy-install of a thing that depends on a thing using pbr
tempdir = self.useFixture(fixtures.TempDir()).path
# A directory containing sdists of the things we're going to depend on
# in using-package.
dist_dir = os.path.join(tempdir, 'distdir')
os.mkdir(dist_dir)
self._run_cmd(sys.executable, ('setup.py', 'sdist', '-d', dist_dir),
allow_fail=False, cwd=PBR_ROOT)
# testpkg - this requires a pbr-using package
test_pkg_dir = os.path.join(tempdir, 'testpkg')
os.mkdir(test_pkg_dir)
pkgs = {
'pkgTest': {
'setup.py': textwrap.dedent("""\
#!/usr/bin/env python
import setuptools
setuptools.setup(
name = 'pkgTest',
tests_require = ['pkgReq'],
test_suite='pkgReq'
)
"""),
'setup.cfg': textwrap.dedent("""\
[easy_install]
find_links = %s
""" % dist_dir)},
'pkgReq': {
'requirements.txt': textwrap.dedent("""\
pbr
"""),
'pkgReq/__init__.py': textwrap.dedent("""\
print("FakeTest loaded and ran")
""")},
}
pkg_dirs = self.useFixture(
test_packaging.CreatePackages(pkgs)).package_dirs
test_pkg_dir = pkg_dirs['pkgTest']
req_pkg_dir = pkg_dirs['pkgReq']
self._run_cmd(sys.executable, ('setup.py', 'sdist', '-d', dist_dir),
allow_fail=False, cwd=req_pkg_dir)
# A venv to test within
venv = self.useFixture(test_packaging.Venv('nopbr', ['pip', 'wheel']))
python = venv.python
# Run the depending script
self.useFixture(base.CapturedSubprocess(
'nopbr', [python] + ['setup.py', 'test'], cwd=test_pkg_dir))
class TestMarkersPip(base.BaseTestCase):
scenarios = [
('pip-1.5', {'modules': ['pip>=1.5,<1.6']}),
('pip-6.0', {'modules': ['pip>=6.0,<6.1']}),
('pip-latest', {'modules': ['pip']}),
('setuptools-EL7', {'modules': ['pip==1.4.1', 'setuptools==0.9.8']}),
('setuptools-Trusty', {'modules': ['pip==1.5', 'setuptools==2.2']}),
('setuptools-minimum', {'modules': ['pip==1.5', 'setuptools==0.7.2']}),
]
@testtools.skipUnless(
os.environ.get('PBR_INTEGRATION', None) == '1',
'integration tests not enabled')
def test_pip_versions(self):
pkgs = {
'test_markers':
{'requirements.txt': textwrap.dedent("""\
pkg_a; python_version=='1.2'
pkg_b; python_version!='1.2'
""")},
'pkg_a': {},
'pkg_b': {},
}
pkg_dirs = self.useFixture(
test_packaging.CreatePackages(pkgs)).package_dirs
temp_dir = self.useFixture(fixtures.TempDir()).path
repo_dir = os.path.join(temp_dir, 'repo')
venv = self.useFixture(test_packaging.Venv('markers'))
bin_python = venv.python
os.mkdir(repo_dir)
for module in self.modules:
self._run_cmd(
bin_python,
['-m', 'pip', 'install', '--upgrade', module],
cwd=venv.path, allow_fail=False)
for pkg in pkg_dirs:
self._run_cmd(
bin_python, ['setup.py', 'sdist', '-d', repo_dir],
cwd=pkg_dirs[pkg], allow_fail=False)
self._run_cmd(
bin_python,
['-m', 'pip', 'install', '--no-index', '-f', repo_dir,
'test_markers'],
cwd=venv.path, allow_fail=False)
self.assertIn('pkg-b', self._run_cmd(
bin_python, ['-m', 'pip', 'freeze'], cwd=venv.path,
allow_fail=False)[0])
class TestLTSSupport(base.BaseTestCase):
# These versions come from the versions installed from the 'virtualenv'
# command from the 'python-virtualenv' package.
scenarios = [
('EL7', {'modules': ['pip==1.4.1', 'setuptools==0.9.8'],
'py3support': True}), # And EPEL6
('Trusty', {'modules': ['pip==1.5', 'setuptools==2.2'],
'py3support': True}),
('Jessie', {'modules': ['pip==1.5.6', 'setuptools==5.5.1'],
'py3support': True}),
# Wheezy has pip1.1, which cannot be called with '-m pip'
# So we'll use a different version of pip here.
('WheezyPrecise', {'modules': ['pip==1.4.1', 'setuptools==0.6c11'],
'py3support': False})
]
@testtools.skipUnless(
os.environ.get('PBR_INTEGRATION', None) == '1',
'integration tests not enabled')
def test_lts_venv_default_versions(self):
if (sys.version_info[0] == 3 and not self.py3support):
self.skipTest('This combination will not install with py3, '
'skipping test')
venv = self.useFixture(
test_packaging.Venv('setuptools', modules=self.modules))
bin_python = venv.python
pbr = 'file://%s#egg=pbr' % PBR_ROOT
# Installing PBR is a reasonable indication that we are not broken on
# this particular combination of setuptools and pip.
self._run_cmd(bin_python, ['-m', 'pip', 'install', pbr],
cwd=venv.path, allow_fail=False)

View file

@ -0,0 +1,923 @@
# Copyright (c) 2013 New Dream Network, LLC (DreamHost)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright (C) 2013 Association of Universities for Research in Astronomy
# (AURA)
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# 3. The name of AURA and its representatives may not be used to
# endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY AURA ``AS IS'' AND ANY EXPRESS OR IMPLIED
# WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL AURA BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
import email
import email.errors
import imp
import os
import re
import sysconfig
import tempfile
import textwrap
import fixtures
import mock
import pkg_resources
import six
import testscenarios
import testtools
from testtools import matchers
import virtualenv
from wheel import wheelfile
from pbr import git
from pbr import packaging
from pbr.tests import base
PBR_ROOT = os.path.abspath(os.path.join(__file__, '..', '..', '..'))
class TestRepo(fixtures.Fixture):
"""A git repo for testing with.
Use of TempHomeDir with this fixture is strongly recommended as due to the
lack of config --local in older gits, it will write to the users global
configuration without TempHomeDir.
"""
def __init__(self, basedir):
super(TestRepo, self).__init__()
self._basedir = basedir
def setUp(self):
super(TestRepo, self).setUp()
base._run_cmd(['git', 'init', '.'], self._basedir)
base._config_git()
base._run_cmd(['git', 'add', '.'], self._basedir)
def commit(self, message_content='test commit'):
files = len(os.listdir(self._basedir))
path = self._basedir + '/%d' % files
open(path, 'wt').close()
base._run_cmd(['git', 'add', path], self._basedir)
base._run_cmd(['git', 'commit', '-m', message_content], self._basedir)
def uncommit(self):
base._run_cmd(['git', 'reset', '--hard', 'HEAD^'], self._basedir)
def tag(self, version):
base._run_cmd(
['git', 'tag', '-sm', 'test tag', version], self._basedir)
class GPGKeyFixture(fixtures.Fixture):
"""Creates a GPG key for testing.
It's recommended that this be used in concert with a unique home
directory.
"""
def setUp(self):
super(GPGKeyFixture, self).setUp()
tempdir = self.useFixture(fixtures.TempDir())
gnupg_version_re = re.compile('^gpg\s.*\s([\d+])\.([\d+])\.([\d+])')
gnupg_version = base._run_cmd(['gpg', '--version'], tempdir.path)
for line in gnupg_version[0].split('\n'):
gnupg_version = gnupg_version_re.match(line)
if gnupg_version:
gnupg_version = (int(gnupg_version.group(1)),
int(gnupg_version.group(2)),
int(gnupg_version.group(3)))
break
else:
if gnupg_version is None:
gnupg_version = (0, 0, 0)
config_file = tempdir.path + '/key-config'
f = open(config_file, 'wt')
try:
if gnupg_version[0] == 2 and gnupg_version[1] >= 1:
f.write("""
%no-protection
%transient-key
""")
f.write("""
%no-ask-passphrase
Key-Type: RSA
Name-Real: Example Key
Name-Comment: N/A
Name-Email: example@example.com
Expire-Date: 2d
Preferences: (setpref)
%commit
""")
finally:
f.close()
# Note that --quick-random (--debug-quick-random in GnuPG 2.x)
# does not have a corresponding preferences file setting and
# must be passed explicitly on the command line instead
if gnupg_version[0] == 1:
gnupg_random = '--quick-random'
elif gnupg_version[0] >= 2:
gnupg_random = '--debug-quick-random'
else:
gnupg_random = ''
base._run_cmd(
['gpg', '--gen-key', '--batch', gnupg_random, config_file],
tempdir.path)
class Venv(fixtures.Fixture):
"""Create a virtual environment for testing with.
:attr path: The path to the environment root.
:attr python: The path to the python binary in the environment.
"""
def __init__(self, reason, modules=(), pip_cmd=None):
"""Create a Venv fixture.
:param reason: A human readable string to bake into the venv
file path to aid diagnostics in the case of failures.
:param modules: A list of modules to install, defaults to latest
pip, wheel, and the working copy of PBR.
:attr pip_cmd: A list to override the default pip_cmd passed to
python for installing base packages.
"""
self._reason = reason
if modules == ():
pbr = 'file://%s#egg=pbr' % PBR_ROOT
modules = ['pip', 'wheel', pbr]
self.modules = modules
if pip_cmd is None:
self.pip_cmd = ['-m', 'pip', 'install']
else:
self.pip_cmd = pip_cmd
def _setUp(self):
path = self.useFixture(fixtures.TempDir()).path
virtualenv.create_environment(path, clear=True)
python = os.path.join(path, 'bin', 'python')
command = [python] + self.pip_cmd + ['-U']
if self.modules and len(self.modules) > 0:
command.extend(self.modules)
self.useFixture(base.CapturedSubprocess(
'mkvenv-' + self._reason, command))
self.addCleanup(delattr, self, 'path')
self.addCleanup(delattr, self, 'python')
self.path = path
self.python = python
return path, python
class CreatePackages(fixtures.Fixture):
"""Creates packages from dict with defaults
:param package_dirs: A dict of package name to directory strings
{'pkg_a': '/tmp/path/to/tmp/pkg_a', 'pkg_b': '/tmp/path/to/tmp/pkg_b'}
"""
defaults = {
'setup.py': textwrap.dedent(six.u("""\
#!/usr/bin/env python
import setuptools
setuptools.setup(
setup_requires=['pbr'],
pbr=True,
)
""")),
'setup.cfg': textwrap.dedent(six.u("""\
[metadata]
name = {pkg_name}
"""))
}
def __init__(self, packages):
"""Creates packages from dict with defaults
:param packages: a dict where the keys are the package name and a
value that is a second dict that may be empty, containing keys of
filenames and a string value of the contents.
{'package-a': {'requirements.txt': 'string', 'setup.cfg': 'string'}
"""
self.packages = packages
def _writeFile(self, directory, file_name, contents):
path = os.path.abspath(os.path.join(directory, file_name))
path_dir = os.path.dirname(path)
if not os.path.exists(path_dir):
if path_dir.startswith(directory):
os.makedirs(path_dir)
else:
raise ValueError
with open(path, 'wt') as f:
f.write(contents)
def _setUp(self):
tmpdir = self.useFixture(fixtures.TempDir()).path
package_dirs = {}
for pkg_name in self.packages:
pkg_path = os.path.join(tmpdir, pkg_name)
package_dirs[pkg_name] = pkg_path
os.mkdir(pkg_path)
for cf in ['setup.py', 'setup.cfg']:
if cf in self.packages[pkg_name]:
contents = self.packages[pkg_name].pop(cf)
else:
contents = self.defaults[cf].format(pkg_name=pkg_name)
self._writeFile(pkg_path, cf, contents)
for cf in self.packages[pkg_name]:
self._writeFile(pkg_path, cf, self.packages[pkg_name][cf])
self.useFixture(TestRepo(pkg_path)).commit()
self.addCleanup(delattr, self, 'package_dirs')
self.package_dirs = package_dirs
return package_dirs
class TestPackagingInGitRepoWithCommit(base.BaseTestCase):
scenarios = [
('preversioned', dict(preversioned=True)),
('postversioned', dict(preversioned=False)),
]
def setUp(self):
super(TestPackagingInGitRepoWithCommit, self).setUp()
self.repo = self.useFixture(TestRepo(self.package_dir))
self.repo.commit()
def test_authors(self):
self.run_setup('sdist', allow_fail=False)
# One commit, something should be in the authors list
with open(os.path.join(self.package_dir, 'AUTHORS'), 'r') as f:
body = f.read()
self.assertNotEqual(body, '')
def test_changelog(self):
self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read()
# One commit, something should be in the ChangeLog list
self.assertNotEqual(body, '')
def test_changelog_handles_astrisk(self):
self.repo.commit(message_content="Allow *.openstack.org to work")
self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read()
self.assertIn('\*', body)
def test_changelog_handles_dead_links_in_commit(self):
self.repo.commit(message_content="See os_ for to_do about qemu_.")
self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read()
self.assertIn('os\_', body)
self.assertIn('to\_do', body)
self.assertIn('qemu\_', body)
def test_changelog_handles_backticks(self):
self.repo.commit(message_content="Allow `openstack.org` to `work")
self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read()
self.assertIn('\`', body)
def test_manifest_exclude_honoured(self):
self.run_setup('sdist', allow_fail=False)
with open(os.path.join(
self.package_dir,
'pbr_testpackage.egg-info/SOURCES.txt'), 'r') as f:
body = f.read()
self.assertThat(
body, matchers.Not(matchers.Contains('pbr_testpackage/extra.py')))
self.assertThat(body, matchers.Contains('pbr_testpackage/__init__.py'))
def test_install_writes_changelog(self):
stdout, _, _ = self.run_setup(
'install', '--root', self.temp_dir + 'installed',
allow_fail=False)
self.expectThat(stdout, matchers.Contains('Generating ChangeLog'))
class TestExtrafileInstallation(base.BaseTestCase):
def test_install_glob(self):
stdout, _, _ = self.run_setup(
'install', '--root', self.temp_dir + 'installed',
allow_fail=False)
self.expectThat(
stdout, matchers.Contains('copying data_files/a.txt'))
self.expectThat(
stdout, matchers.Contains('copying data_files/b.txt'))
class TestPackagingInGitRepoWithoutCommit(base.BaseTestCase):
def setUp(self):
super(TestPackagingInGitRepoWithoutCommit, self).setUp()
self.useFixture(TestRepo(self.package_dir))
self.run_setup('sdist', allow_fail=False)
def test_authors(self):
# No commits, no authors in list
with open(os.path.join(self.package_dir, 'AUTHORS'), 'r') as f:
body = f.read()
self.assertEqual('\n', body)
def test_changelog(self):
# No commits, nothing should be in the ChangeLog list
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read()
self.assertEqual('CHANGES\n=======\n\n', body)
class TestPackagingWheels(base.BaseTestCase):
def setUp(self):
super(TestPackagingWheels, self).setUp()
self.useFixture(TestRepo(self.package_dir))
# Build the wheel
self.run_setup('bdist_wheel', allow_fail=False)
# Slowly construct the path to the generated whl
dist_dir = os.path.join(self.package_dir, 'dist')
relative_wheel_filename = os.listdir(dist_dir)[0]
absolute_wheel_filename = os.path.join(
dist_dir, relative_wheel_filename)
wheel_file = wheelfile.WheelFile(absolute_wheel_filename)
wheel_name = wheel_file.parsed_filename.group('namever')
# Create a directory path to unpack the wheel to
self.extracted_wheel_dir = os.path.join(dist_dir, wheel_name)
# Extract the wheel contents to the directory we just created
wheel_file.extractall(self.extracted_wheel_dir)
wheel_file.close()
def test_data_directory_has_wsgi_scripts(self):
# Build the path to the scripts directory
scripts_dir = os.path.join(
self.extracted_wheel_dir, 'pbr_testpackage-0.0.data/scripts')
self.assertTrue(os.path.exists(scripts_dir))
scripts = os.listdir(scripts_dir)
self.assertIn('pbr_test_wsgi', scripts)
self.assertIn('pbr_test_wsgi_with_class', scripts)
self.assertNotIn('pbr_test_cmd', scripts)
self.assertNotIn('pbr_test_cmd_with_class', scripts)
def test_generates_c_extensions(self):
built_package_dir = os.path.join(
self.extracted_wheel_dir, 'pbr_testpackage')
static_object_filename = 'testext.so'
soabi = get_soabi()
if soabi:
static_object_filename = 'testext.{0}.so'.format(soabi)
static_object_path = os.path.join(
built_package_dir, static_object_filename)
self.assertTrue(os.path.exists(built_package_dir))
self.assertTrue(os.path.exists(static_object_path))
class TestPackagingHelpers(testtools.TestCase):
def test_generate_script(self):
group = 'console_scripts'
entry_point = pkg_resources.EntryPoint(
name='test-ep',
module_name='pbr.packaging',
attrs=('LocalInstallScripts',))
header = '#!/usr/bin/env fake-header\n'
template = ('%(group)s %(module_name)s %(import_target)s '
'%(invoke_target)s')
generated_script = packaging.generate_script(
group, entry_point, header, template)
expected_script = (
'#!/usr/bin/env fake-header\nconsole_scripts pbr.packaging '
'LocalInstallScripts LocalInstallScripts'
)
self.assertEqual(expected_script, generated_script)
def test_generate_script_validates_expectations(self):
group = 'console_scripts'
entry_point = pkg_resources.EntryPoint(
name='test-ep',
module_name='pbr.packaging')
header = '#!/usr/bin/env fake-header\n'
template = ('%(group)s %(module_name)s %(import_target)s '
'%(invoke_target)s')
self.assertRaises(
ValueError, packaging.generate_script, group, entry_point, header,
template)
entry_point = pkg_resources.EntryPoint(
name='test-ep',
module_name='pbr.packaging',
attrs=('attr1', 'attr2', 'attr3'))
self.assertRaises(
ValueError, packaging.generate_script, group, entry_point, header,
template)
class TestPackagingInPlainDirectory(base.BaseTestCase):
def setUp(self):
super(TestPackagingInPlainDirectory, self).setUp()
def test_authors(self):
self.run_setup('sdist', allow_fail=False)
# Not a git repo, no AUTHORS file created
filename = os.path.join(self.package_dir, 'AUTHORS')
self.assertFalse(os.path.exists(filename))
def test_changelog(self):
self.run_setup('sdist', allow_fail=False)
# Not a git repo, no ChangeLog created
filename = os.path.join(self.package_dir, 'ChangeLog')
self.assertFalse(os.path.exists(filename))
def test_install_no_ChangeLog(self):
stdout, _, _ = self.run_setup(
'install', '--root', self.temp_dir + 'installed',
allow_fail=False)
self.expectThat(
stdout, matchers.Not(matchers.Contains('Generating ChangeLog')))
class TestPresenceOfGit(base.BaseTestCase):
def testGitIsInstalled(self):
with mock.patch.object(git,
'_run_shell_command') as _command:
_command.return_value = 'git version 1.8.4.1'
self.assertEqual(True, git._git_is_installed())
def testGitIsNotInstalled(self):
with mock.patch.object(git,
'_run_shell_command') as _command:
_command.side_effect = OSError
self.assertEqual(False, git._git_is_installed())
class ParseRequirementsTest(base.BaseTestCase):
def test_empty_requirements(self):
actual = packaging.parse_requirements([])
self.assertEqual([], actual)
def test_default_requirements(self):
"""Ensure default files used if no files provided."""
tempdir = tempfile.mkdtemp()
requirements = os.path.join(tempdir, 'requirements.txt')
with open(requirements, 'w') as f:
f.write('pbr')
# the defaults are relative to where pbr is called from so we need to
# override them. This is OK, however, as we want to validate that
# defaults are used - not what those defaults are
with mock.patch.object(packaging, 'REQUIREMENTS_FILES', (
requirements,)):
result = packaging.parse_requirements()
self.assertEqual(['pbr'], result)
def test_override_with_env(self):
"""Ensure environment variable used if no files provided."""
_, tmp_file = tempfile.mkstemp(prefix='openstack', suffix='.setup')
with open(tmp_file, 'w') as fh:
fh.write("foo\nbar")
self.useFixture(
fixtures.EnvironmentVariable('PBR_REQUIREMENTS_FILES', tmp_file))
self.assertEqual(['foo', 'bar'],
packaging.parse_requirements())
def test_override_with_env_multiple_files(self):
_, tmp_file = tempfile.mkstemp(prefix='openstack', suffix='.setup')
with open(tmp_file, 'w') as fh:
fh.write("foo\nbar")
self.useFixture(
fixtures.EnvironmentVariable('PBR_REQUIREMENTS_FILES',
"no-such-file," + tmp_file))
self.assertEqual(['foo', 'bar'],
packaging.parse_requirements())
def test_index_present(self):
tempdir = tempfile.mkdtemp()
requirements = os.path.join(tempdir, 'requirements.txt')
with open(requirements, 'w') as f:
f.write('-i https://myindex.local')
f.write(' --index-url https://myindex.local')
f.write(' --extra-index-url https://myindex.local')
result = packaging.parse_requirements([requirements])
self.assertEqual([], result)
def test_nested_requirements(self):
tempdir = tempfile.mkdtemp()
requirements = os.path.join(tempdir, 'requirements.txt')
nested = os.path.join(tempdir, 'nested.txt')
with open(requirements, 'w') as f:
f.write('-r ' + nested)
with open(nested, 'w') as f:
f.write('pbr')
result = packaging.parse_requirements([requirements])
self.assertEqual(['pbr'], result)
class ParseRequirementsTestScenarios(base.BaseTestCase):
versioned_scenarios = [
('non-versioned', {'versioned': False, 'expected': ['bar']}),
('versioned', {'versioned': True, 'expected': ['bar>=1.2.3']})
]
subdirectory_scenarios = [
('non-subdirectory', {'has_subdirectory': False}),
('has-subdirectory', {'has_subdirectory': True})
]
scenarios = [
('normal', {'url': "foo\nbar", 'expected': ['foo', 'bar']}),
('normal_with_comments', {
'url': "# this is a comment\nfoo\n# and another one\nbar",
'expected': ['foo', 'bar']}),
('removes_index_lines', {'url': '-f foobar', 'expected': []}),
]
scenarios = scenarios + testscenarios.multiply_scenarios([
('ssh_egg_url', {'url': 'git+ssh://foo.com/zipball#egg=bar'}),
('git_https_egg_url', {'url': 'git+https://foo.com/zipball#egg=bar'}),
('http_egg_url', {'url': 'https://foo.com/zipball#egg=bar'}),
], versioned_scenarios, subdirectory_scenarios)
scenarios = scenarios + testscenarios.multiply_scenarios(
[
('git_egg_url',
{'url': 'git://foo.com/zipball#egg=bar', 'name': 'bar'})
], [
('non-editable', {'editable': False}),
('editable', {'editable': True}),
],
versioned_scenarios, subdirectory_scenarios)
def test_parse_requirements(self):
tmp_file = tempfile.NamedTemporaryFile()
req_string = self.url
if hasattr(self, 'editable') and self.editable:
req_string = ("-e %s" % req_string)
if hasattr(self, 'versioned') and self.versioned:
req_string = ("%s-1.2.3" % req_string)
if hasattr(self, 'has_subdirectory') and self.has_subdirectory:
req_string = ("%s&subdirectory=baz" % req_string)
with open(tmp_file.name, 'w') as fh:
fh.write(req_string)
self.assertEqual(self.expected,
packaging.parse_requirements([tmp_file.name]))
class ParseDependencyLinksTest(base.BaseTestCase):
def setUp(self):
super(ParseDependencyLinksTest, self).setUp()
_, self.tmp_file = tempfile.mkstemp(prefix="openstack",
suffix=".setup")
def test_parse_dependency_normal(self):
with open(self.tmp_file, "w") as fh:
fh.write("http://test.com\n")
self.assertEqual(
["http://test.com"],
packaging.parse_dependency_links([self.tmp_file]))
def test_parse_dependency_with_git_egg_url(self):
with open(self.tmp_file, "w") as fh:
fh.write("-e git://foo.com/zipball#egg=bar")
self.assertEqual(
["git://foo.com/zipball#egg=bar"],
packaging.parse_dependency_links([self.tmp_file]))
class TestVersions(base.BaseTestCase):
scenarios = [
('preversioned', dict(preversioned=True)),
('postversioned', dict(preversioned=False)),
]
def setUp(self):
super(TestVersions, self).setUp()
self.repo = self.useFixture(TestRepo(self.package_dir))
self.useFixture(GPGKeyFixture())
self.useFixture(base.DiveDir(self.package_dir))
def test_email_parsing_errors_are_handled(self):
mocked_open = mock.mock_open()
with mock.patch('pbr.packaging.open', mocked_open):
with mock.patch('email.message_from_file') as message_from_file:
message_from_file.side_effect = [
email.errors.MessageError('Test'),
{'Name': 'pbr_testpackage'}]
version = packaging._get_version_from_pkg_metadata(
'pbr_testpackage')
self.assertTrue(message_from_file.called)
self.assertIsNone(version)
def test_capitalized_headers(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('Sem-Ver: api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_capitalized_headers_partial(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('Sem-ver: api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_tagged_version_has_tag_version(self):
self.repo.commit()
self.repo.tag('1.2.3')
version = packaging._get_version_from_git('1.2.3')
self.assertEqual('1.2.3', version)
def test_non_canonical_tagged_version_bump(self):
self.repo.commit()
self.repo.tag('1.4')
self.repo.commit('Sem-Ver: api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_untagged_version_has_dev_version_postversion(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit()
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.2.4.dev1'))
def test_untagged_pre_release_has_pre_dev_version_postversion(self):
self.repo.commit()
self.repo.tag('1.2.3.0a1')
self.repo.commit()
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.2.3.0a2.dev1'))
def test_untagged_version_minor_bump(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('sem-ver: deprecation')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.3.0.dev1'))
def test_untagged_version_major_bump(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('sem-ver: api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_untagged_version_has_dev_version_preversion(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit()
version = packaging._get_version_from_git('1.2.5')
self.assertThat(version, matchers.StartsWith('1.2.5.dev1'))
def test_untagged_version_after_pre_has_dev_version_preversion(self):
self.repo.commit()
self.repo.tag('1.2.3.0a1')
self.repo.commit()
version = packaging._get_version_from_git('1.2.5')
self.assertThat(version, matchers.StartsWith('1.2.5.dev1'))
def test_untagged_version_after_rc_has_dev_version_preversion(self):
self.repo.commit()
self.repo.tag('1.2.3.0a1')
self.repo.commit()
version = packaging._get_version_from_git('1.2.3')
self.assertThat(version, matchers.StartsWith('1.2.3.0a2.dev1'))
def test_preversion_too_low_simple(self):
# That is, the target version is either already released or not high
# enough for the semver requirements given api breaks etc.
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit()
# Note that we can't target 1.2.3 anymore - with 1.2.3 released we
# need to be working on 1.2.4.
err = self.assertRaises(
ValueError, packaging._get_version_from_git, '1.2.3')
self.assertThat(err.args[0], matchers.StartsWith('git history'))
def test_preversion_too_low_semver_headers(self):
# That is, the target version is either already released or not high
# enough for the semver requirements given api breaks etc.
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('sem-ver: feature')
# Note that we can't target 1.2.4, the feature header means we need
# to be working on 1.3.0 or above.
err = self.assertRaises(
ValueError, packaging._get_version_from_git, '1.2.4')
self.assertThat(err.args[0], matchers.StartsWith('git history'))
def test_get_kwargs_corner_cases(self):
# No tags:
git_dir = self.repo._basedir + '/.git'
get_kwargs = lambda tag: packaging._get_increment_kwargs(git_dir, tag)
def _check_combinations(tag):
self.repo.commit()
self.assertEqual(dict(), get_kwargs(tag))
self.repo.commit('sem-ver: bugfix')
self.assertEqual(dict(), get_kwargs(tag))
self.repo.commit('sem-ver: feature')
self.assertEqual(dict(minor=True), get_kwargs(tag))
self.repo.uncommit()
self.repo.commit('sem-ver: deprecation')
self.assertEqual(dict(minor=True), get_kwargs(tag))
self.repo.uncommit()
self.repo.commit('sem-ver: api-break')
self.assertEqual(dict(major=True), get_kwargs(tag))
self.repo.commit('sem-ver: deprecation')
self.assertEqual(dict(major=True, minor=True), get_kwargs(tag))
_check_combinations('')
self.repo.tag('1.2.3')
_check_combinations('1.2.3')
def test_invalid_tag_ignored(self):
# Fix for bug 1356784 - we treated any tag as a version, not just those
# that are valid versions.
self.repo.commit()
self.repo.tag('1')
self.repo.commit()
# when the tree is tagged and its wrong:
self.repo.tag('badver')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.0.1.dev1'))
# When the tree isn't tagged, we also fall through.
self.repo.commit()
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.0.1.dev2'))
# We don't fall through x.y versions
self.repo.commit()
self.repo.tag('1.2')
self.repo.commit()
self.repo.tag('badver2')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.2.1.dev1'))
# Or x.y.z versions
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit()
self.repo.tag('badver3')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.2.4.dev1'))
# Or alpha/beta/pre versions
self.repo.commit()
self.repo.tag('1.2.4.0a1')
self.repo.commit()
self.repo.tag('badver4')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.2.4.0a2.dev1'))
# Non-release related tags are ignored.
self.repo.commit()
self.repo.tag('2')
self.repo.commit()
self.repo.tag('non-release-tag/2014.12.16-1')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.1.dev1'))
def test_valid_tag_honoured(self):
# Fix for bug 1370608 - we converted any target into a 'dev version'
# even if there was a distance of 0 - indicating that we were on the
# tag itself.
self.repo.commit()
self.repo.tag('1.3.0.0a1')
version = packaging._get_version_from_git()
self.assertEqual('1.3.0.0a1', version)
def test_skip_write_git_changelog(self):
# Fix for bug 1467440
self.repo.commit()
self.repo.tag('1.2.3')
os.environ['SKIP_WRITE_GIT_CHANGELOG'] = '1'
version = packaging._get_version_from_git('1.2.3')
self.assertEqual('1.2.3', version)
def tearDown(self):
super(TestVersions, self).tearDown()
os.environ.pop('SKIP_WRITE_GIT_CHANGELOG', None)
class TestRequirementParsing(base.BaseTestCase):
def test_requirement_parsing(self):
pkgs = {
'test_reqparse':
{
'requirements.txt': textwrap.dedent("""\
bar
quux<1.0; python_version=='2.6'
requests-aws>=0.1.4 # BSD License (3 clause)
Routes>=1.12.3,!=2.0,!=2.1;python_version=='2.7'
requests-kerberos>=0.6;python_version=='2.7' # MIT
"""),
'setup.cfg': textwrap.dedent("""\
[metadata]
name = test_reqparse
[extras]
test =
foo
baz>3.2 :python_version=='2.7' # MIT
bar>3.3 :python_version=='2.7' # MIT # Apache
""")},
}
pkg_dirs = self.useFixture(CreatePackages(pkgs)).package_dirs
pkg_dir = pkg_dirs['test_reqparse']
# pkg_resources.split_sections uses None as the title of an
# anonymous section instead of the empty string. Weird.
expected_requirements = {
None: ['bar', 'requests-aws>=0.1.4'],
":(python_version=='2.6')": ['quux<1.0'],
":(python_version=='2.7')": ['Routes!=2.0,!=2.1,>=1.12.3',
'requests-kerberos>=0.6'],
'test': ['foo'],
"test:(python_version=='2.7')": ['baz>3.2', 'bar>3.3']
}
venv = self.useFixture(Venv('reqParse'))
bin_python = venv.python
# Two things are tested by this
# 1) pbr properly parses markers from requiremnts.txt and setup.cfg
# 2) bdist_wheel causes pbr to not evaluate markers
self._run_cmd(bin_python, ('setup.py', 'bdist_wheel'),
allow_fail=False, cwd=pkg_dir)
egg_info = os.path.join(pkg_dir, 'test_reqparse.egg-info')
requires_txt = os.path.join(egg_info, 'requires.txt')
with open(requires_txt, 'rt') as requires:
generated_requirements = dict(
pkg_resources.split_sections(requires))
# NOTE(dhellmann): We have to spell out the comparison because
# the rendering for version specifiers in a range is not
# consistent across versions of setuptools.
for section, expected in expected_requirements.items():
exp_parsed = [
pkg_resources.Requirement.parse(s)
for s in expected
]
gen_parsed = [
pkg_resources.Requirement.parse(s)
for s in generated_requirements[section]
]
self.assertEqual(exp_parsed, gen_parsed)
def get_soabi():
soabi = None
try:
soabi = sysconfig.get_config_var('SOABI')
arch = sysconfig.get_config_var('MULTIARCH')
except IOError:
pass
if soabi and arch and 'pypy' in sysconfig.get_scheme_names():
soabi = '%s-%s' % (soabi, arch)
if soabi is None and 'pypy' in sysconfig.get_scheme_names():
# NOTE(sigmavirus24): PyPy only added support for the SOABI config var
# to sysconfig in 2015. That was well after 2.2.1 was published in the
# Ubuntu 14.04 archive.
for suffix, _, _ in imp.get_suffixes():
if suffix.startswith('.pypy') and suffix.endswith('.so'):
soabi = suffix.split('.')[1]
break
return soabi

View file

@ -0,0 +1,30 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import mock
from pbr import pbr_json
from pbr.tests import base
class TestJsonContent(base.BaseTestCase):
@mock.patch('pbr.git._run_git_functions', return_value=True)
@mock.patch('pbr.git.get_git_short_sha', return_value="123456")
@mock.patch('pbr.git.get_is_release', return_value=True)
def test_content(self, mock_get_is, mock_get_git, mock_run):
cmd = mock.Mock()
pbr_json.write_pbr_json(cmd, "basename", "pbr.json")
cmd.write_file.assert_called_once_with(
'pbr',
'pbr.json',
'{"git_version": "123456", "is_release": true}'
)

View file

@ -0,0 +1,445 @@
# Copyright (c) 2011 OpenStack Foundation
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import os
try:
import cStringIO as io
BytesIO = io.StringIO
except ImportError:
import io
BytesIO = io.BytesIO
import fixtures
from pbr import git
from pbr import options
from pbr import packaging
from pbr.tests import base
class SkipFileWrites(base.BaseTestCase):
scenarios = [
('changelog_option_true',
dict(option_key='skip_changelog', option_value='True',
env_key='SKIP_WRITE_GIT_CHANGELOG', env_value=None,
pkg_func=git.write_git_changelog, filename='ChangeLog')),
('changelog_option_false',
dict(option_key='skip_changelog', option_value='False',
env_key='SKIP_WRITE_GIT_CHANGELOG', env_value=None,
pkg_func=git.write_git_changelog, filename='ChangeLog')),
('changelog_env_true',
dict(option_key='skip_changelog', option_value='False',
env_key='SKIP_WRITE_GIT_CHANGELOG', env_value='True',
pkg_func=git.write_git_changelog, filename='ChangeLog')),
('changelog_both_true',
dict(option_key='skip_changelog', option_value='True',
env_key='SKIP_WRITE_GIT_CHANGELOG', env_value='True',
pkg_func=git.write_git_changelog, filename='ChangeLog')),
('authors_option_true',
dict(option_key='skip_authors', option_value='True',
env_key='SKIP_GENERATE_AUTHORS', env_value=None,
pkg_func=git.generate_authors, filename='AUTHORS')),
('authors_option_false',
dict(option_key='skip_authors', option_value='False',
env_key='SKIP_GENERATE_AUTHORS', env_value=None,
pkg_func=git.generate_authors, filename='AUTHORS')),
('authors_env_true',
dict(option_key='skip_authors', option_value='False',
env_key='SKIP_GENERATE_AUTHORS', env_value='True',
pkg_func=git.generate_authors, filename='AUTHORS')),
('authors_both_true',
dict(option_key='skip_authors', option_value='True',
env_key='SKIP_GENERATE_AUTHORS', env_value='True',
pkg_func=git.generate_authors, filename='AUTHORS')),
]
def setUp(self):
super(SkipFileWrites, self).setUp()
self.temp_path = self.useFixture(fixtures.TempDir()).path
self.root_dir = os.path.abspath(os.path.curdir)
self.git_dir = os.path.join(self.root_dir, ".git")
if not os.path.exists(self.git_dir):
self.skipTest("%s is missing; skipping git-related checks"
% self.git_dir)
return
self.filename = os.path.join(self.temp_path, self.filename)
self.option_dict = dict()
if self.option_key is not None:
self.option_dict[self.option_key] = ('setup.cfg',
self.option_value)
self.useFixture(
fixtures.EnvironmentVariable(self.env_key, self.env_value))
def test_skip(self):
self.pkg_func(git_dir=self.git_dir,
dest_dir=self.temp_path,
option_dict=self.option_dict)
self.assertEqual(
not os.path.exists(self.filename),
(self.option_value.lower() in options.TRUE_VALUES
or self.env_value is not None))
_changelog_content = """7780758\x00Break parser\x00 (tag: refs/tags/1_foo.1)
04316fe\x00Make python\x00 (refs/heads/review/monty_taylor/27519)
378261a\x00Add an integration test script.\x00
3c373ac\x00Merge "Lib\x00 (HEAD, tag: refs/tags/2013.2.rc2, tag: refs/tags/2013.2, refs/heads/mile-proposed)
182feb3\x00Fix pip invocation for old versions of pip.\x00 (tag: refs/tags/0.5.17)
fa4f46e\x00Remove explicit depend on distribute.\x00 (tag: refs/tags/0.5.16)
d1c53dd\x00Use pip instead of easy_install for installation.\x00
a793ea1\x00Merge "Skip git-checkout related tests when .git is missing"\x00
6c27ce7\x00Skip git-checkout related tests when .git is missing\x00
451e513\x00Bug fix: create_stack() fails when waiting\x00
4c8cfe4\x00Improve test coverage: network delete API\x00 (tag: refs/tags/(evil))
d7e6167\x00Bug fix: Fix pass thru filtering in list_networks\x00 (tag: refs/tags/ev()il)
c47ec15\x00Consider 'in-use' a non-pending volume for caching\x00 (tag: refs/tags/ev)il)
8696fbd\x00Improve test coverage: private extension API\x00 (tag: refs/tags/ev(il)
f0440f8\x00Improve test coverage: hypervisor list\x00 (tag: refs/tags/e(vi)l)
04984a5\x00Refactor hooks file.\x00 (HEAD, tag: 0.6.7,b, tag: refs/tags/(12), refs/heads/master)
a65e8ee\x00Remove jinja pin.\x00 (tag: refs/tags/0.5.14, tag: refs/tags/0.5.13)
""" # noqa
def _make_old_git_changelog_format(line):
"""Convert post-1.8.1 git log format to pre-1.8.1 git log format"""
if not line.strip():
return line
sha, msg, refname = line.split('\x00')
refname = refname.replace('tag: ', '')
return '\x00'.join((sha, msg, refname))
_old_git_changelog_content = '\n'.join(
_make_old_git_changelog_format(line)
for line in _changelog_content.split('\n'))
class GitLogsTest(base.BaseTestCase):
scenarios = [
('pre1.8.3', {'changelog': _old_git_changelog_content}),
('post1.8.3', {'changelog': _changelog_content}),
]
def setUp(self):
super(GitLogsTest, self).setUp()
self.temp_path = self.useFixture(fixtures.TempDir()).path
self.root_dir = os.path.abspath(os.path.curdir)
self.git_dir = os.path.join(self.root_dir, ".git")
self.useFixture(
fixtures.EnvironmentVariable('SKIP_GENERATE_AUTHORS'))
self.useFixture(
fixtures.EnvironmentVariable('SKIP_WRITE_GIT_CHANGELOG'))
def test_write_git_changelog(self):
self.useFixture(fixtures.FakePopen(lambda _: {
"stdout": BytesIO(self.changelog.encode('utf-8'))
}))
git.write_git_changelog(git_dir=self.git_dir,
dest_dir=self.temp_path)
with open(os.path.join(self.temp_path, "ChangeLog"), "r") as ch_fh:
changelog_contents = ch_fh.read()
self.assertIn("2013.2", changelog_contents)
self.assertIn("0.5.17", changelog_contents)
self.assertIn("------", changelog_contents)
self.assertIn("Refactor hooks file", changelog_contents)
self.assertIn(
"Bug fix: create\_stack() fails when waiting",
changelog_contents)
self.assertNotIn("Refactor hooks file.", changelog_contents)
self.assertNotIn("182feb3", changelog_contents)
self.assertNotIn("review/monty_taylor/27519", changelog_contents)
self.assertNotIn("0.5.13", changelog_contents)
self.assertNotIn("0.6.7", changelog_contents)
self.assertNotIn("12", changelog_contents)
self.assertNotIn("(evil)", changelog_contents)
self.assertNotIn("ev()il", changelog_contents)
self.assertNotIn("ev(il", changelog_contents)
self.assertNotIn("ev)il", changelog_contents)
self.assertNotIn("e(vi)l", changelog_contents)
self.assertNotIn('Merge "', changelog_contents)
self.assertNotIn('1\_foo.1', changelog_contents)
def test_generate_authors(self):
author_old = u"Foo Foo <email@foo.com>"
author_new = u"Bar Bar <email@bar.com>"
co_author = u"Foo Bar <foo@bar.com>"
co_author_by = u"Co-authored-by: " + co_author
git_log_cmd = (
"git --git-dir=%s log --format=%%aN <%%aE>"
% self.git_dir)
git_co_log_cmd = ("git --git-dir=%s log" % self.git_dir)
git_top_level = "git rev-parse --show-toplevel"
cmd_map = {
git_log_cmd: author_new,
git_co_log_cmd: co_author_by,
git_top_level: self.root_dir,
}
exist_files = [self.git_dir,
os.path.join(self.temp_path, "AUTHORS.in")]
self.useFixture(fixtures.MonkeyPatch(
"os.path.exists",
lambda path: os.path.abspath(path) in exist_files))
def _fake_run_shell_command(cmd, **kwargs):
return cmd_map[" ".join(cmd)]
self.useFixture(fixtures.MonkeyPatch(
"pbr.git._run_shell_command",
_fake_run_shell_command))
with open(os.path.join(self.temp_path, "AUTHORS.in"), "w") as auth_fh:
auth_fh.write("%s\n" % author_old)
git.generate_authors(git_dir=self.git_dir,
dest_dir=self.temp_path)
with open(os.path.join(self.temp_path, "AUTHORS"), "r") as auth_fh:
authors = auth_fh.read()
self.assertTrue(author_old in authors)
self.assertTrue(author_new in authors)
self.assertTrue(co_author in authors)
class _SphinxConfig(object):
man_pages = ['foo']
class BaseSphinxTest(base.BaseTestCase):
def setUp(self):
super(BaseSphinxTest, self).setUp()
# setup_command requires the Sphinx instance to have some
# attributes that aren't set normally with the way we use the
# class (because we replace the constructor). Add default
# values directly to the class definition.
import sphinx.application
sphinx.application.Sphinx.messagelog = []
sphinx.application.Sphinx.statuscode = 0
self.useFixture(fixtures.MonkeyPatch(
"sphinx.application.Sphinx.__init__", lambda *a, **kw: None))
self.useFixture(fixtures.MonkeyPatch(
"sphinx.application.Sphinx.build", lambda *a, **kw: None))
self.useFixture(fixtures.MonkeyPatch(
"sphinx.application.Sphinx.config", _SphinxConfig))
self.useFixture(fixtures.MonkeyPatch(
"sphinx.config.Config.init_values", lambda *a: None))
self.useFixture(fixtures.MonkeyPatch(
"sphinx.config.Config.__init__", lambda *a: None))
from distutils import dist
self.distr = dist.Distribution()
self.distr.packages = ("fake_package",)
self.distr.command_options["build_sphinx"] = {
"source_dir": ["a", "."]}
pkg_fixture = fixtures.PythonPackage(
"fake_package", [("fake_module.py", b""),
("another_fake_module_for_testing.py", b""),
("fake_private_module.py", b"")])
self.useFixture(pkg_fixture)
self.useFixture(base.DiveDir(pkg_fixture.base))
self.distr.command_options["pbr"] = {}
if hasattr(self, "excludes"):
self.distr.command_options["pbr"]["autodoc_exclude_modules"] = (
'setup.cfg',
"fake_package.fake_private_module\n"
"fake_package.another_fake_*\n"
"fake_package.unknown_module")
if hasattr(self, 'has_opt') and self.has_opt:
options = self.distr.command_options["pbr"]
options["autodoc_index_modules"] = ('setup.cfg', self.autodoc)
class BuildSphinxTest(BaseSphinxTest):
scenarios = [
('true_autodoc_caps',
dict(has_opt=True, autodoc='True', has_autodoc=True)),
('true_autodoc_caps_with_excludes',
dict(has_opt=True, autodoc='True', has_autodoc=True,
excludes="fake_package.fake_private_module\n"
"fake_package.another_fake_*\n"
"fake_package.unknown_module")),
('true_autodoc_lower',
dict(has_opt=True, autodoc='true', has_autodoc=True)),
('false_autodoc',
dict(has_opt=True, autodoc='False', has_autodoc=False)),
('no_autodoc',
dict(has_opt=False, autodoc='False', has_autodoc=False)),
]
def test_build_doc(self):
build_doc = packaging.LocalBuildDoc(self.distr)
build_doc.run()
self.assertTrue(
os.path.exists("api/autoindex.rst") == self.has_autodoc)
self.assertTrue(
os.path.exists(
"api/fake_package.fake_module.rst") == self.has_autodoc)
if not self.has_autodoc or hasattr(self, "excludes"):
assertion = self.assertFalse
else:
assertion = self.assertTrue
assertion(
os.path.exists(
"api/fake_package.fake_private_module.rst"))
assertion(
os.path.exists(
"api/fake_package.another_fake_module_for_testing.rst"))
def test_builders_config(self):
build_doc = packaging.LocalBuildDoc(self.distr)
build_doc.finalize_options()
self.assertEqual(1, len(build_doc.builders))
self.assertIn('html', build_doc.builders)
build_doc = packaging.LocalBuildDoc(self.distr)
build_doc.builders = ''
build_doc.finalize_options()
self.assertEqual('', build_doc.builders)
build_doc = packaging.LocalBuildDoc(self.distr)
build_doc.builders = 'man'
build_doc.finalize_options()
self.assertEqual(1, len(build_doc.builders))
self.assertIn('man', build_doc.builders)
build_doc = packaging.LocalBuildDoc(self.distr)
build_doc.builders = 'html,man,doctest'
build_doc.finalize_options()
self.assertIn('html', build_doc.builders)
self.assertIn('man', build_doc.builders)
self.assertIn('doctest', build_doc.builders)
def test_cmd_builder_override(self):
if self.has_opt:
self.distr.command_options["pbr"] = {
"autodoc_index_modules": ('setup.cfg', self.autodoc)
}
self.distr.command_options["build_sphinx"]["builder"] = (
"command line", "non-existing-builder")
build_doc = packaging.LocalBuildDoc(self.distr)
self.assertNotIn('non-existing-builder', build_doc.builders)
self.assertIn('html', build_doc.builders)
# process command line options which should override config
build_doc.finalize_options()
self.assertIn('non-existing-builder', build_doc.builders)
self.assertNotIn('html', build_doc.builders)
def test_cmd_builder_override_multiple_builders(self):
if self.has_opt:
self.distr.command_options["pbr"] = {
"autodoc_index_modules": ('setup.cfg', self.autodoc)
}
self.distr.command_options["build_sphinx"]["builder"] = (
"command line", "builder1,builder2")
build_doc = packaging.LocalBuildDoc(self.distr)
build_doc.finalize_options()
self.assertEqual(["builder1", "builder2"], build_doc.builders)
class APIAutoDocTest(base.BaseTestCase):
def setUp(self):
super(APIAutoDocTest, self).setUp()
# setup_command requires the Sphinx instance to have some
# attributes that aren't set normally with the way we use the
# class (because we replace the constructor). Add default
# values directly to the class definition.
import sphinx.application
sphinx.application.Sphinx.messagelog = []
sphinx.application.Sphinx.statuscode = 0
self.useFixture(fixtures.MonkeyPatch(
"sphinx.application.Sphinx.__init__", lambda *a, **kw: None))
self.useFixture(fixtures.MonkeyPatch(
"sphinx.application.Sphinx.build", lambda *a, **kw: None))
self.useFixture(fixtures.MonkeyPatch(
"sphinx.application.Sphinx.config", _SphinxConfig))
self.useFixture(fixtures.MonkeyPatch(
"sphinx.config.Config.init_values", lambda *a: None))
self.useFixture(fixtures.MonkeyPatch(
"sphinx.config.Config.__init__", lambda *a: None))
from distutils import dist
self.distr = dist.Distribution()
self.distr.packages = ("fake_package",)
self.distr.command_options["build_sphinx"] = {
"source_dir": ["a", "."]}
self.sphinx_options = self.distr.command_options["build_sphinx"]
pkg_fixture = fixtures.PythonPackage(
"fake_package", [("fake_module.py", b""),
("another_fake_module_for_testing.py", b""),
("fake_private_module.py", b"")])
self.useFixture(pkg_fixture)
self.useFixture(base.DiveDir(pkg_fixture.base))
self.pbr_options = self.distr.command_options.setdefault('pbr', {})
self.pbr_options["autodoc_index_modules"] = ('setup.cfg', 'True')
def test_default_api_build_dir(self):
build_doc = packaging.LocalBuildDoc(self.distr)
build_doc.run()
print('PBR OPTIONS:', self.pbr_options)
print('DISTR OPTIONS:', self.distr.command_options)
self.assertTrue(os.path.exists("api/autoindex.rst"))
self.assertTrue(os.path.exists("api/fake_package.fake_module.rst"))
self.assertTrue(
os.path.exists(
"api/fake_package.fake_private_module.rst"))
self.assertTrue(
os.path.exists(
"api/fake_package.another_fake_module_for_testing.rst"))
def test_different_api_build_dir(self):
# Options have to come out of the settings dict as a tuple
# showing the source and the value.
self.pbr_options['api_doc_dir'] = (None, 'contributor/api')
build_doc = packaging.LocalBuildDoc(self.distr)
build_doc.run()
print('PBR OPTIONS:', self.pbr_options)
print('DISTR OPTIONS:', self.distr.command_options)
self.assertTrue(os.path.exists("contributor/api/autoindex.rst"))
self.assertTrue(
os.path.exists("contributor/api/fake_package.fake_module.rst"))
self.assertTrue(
os.path.exists(
"contributor/api/fake_package.fake_private_module.rst"))

View file

@ -0,0 +1,91 @@
# Copyright (c) 2015 Hewlett-Packard Development Company, L.P. (HP)
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import io
import textwrap
import six
from six.moves import configparser
import sys
from pbr.tests import base
from pbr import util
class TestExtrasRequireParsingScenarios(base.BaseTestCase):
scenarios = [
('simple_extras', {
'config_text': """
[extras]
first =
foo
bar==1.0
second =
baz>=3.2
foo
""",
'expected_extra_requires': {
'first': ['foo', 'bar==1.0'],
'second': ['baz>=3.2', 'foo'],
'test': ['requests-mock'],
"test:(python_version=='2.6')": ['ordereddict'],
}
}),
('with_markers', {
'config_text': """
[extras]
test =
foo:python_version=='2.6'
bar
baz<1.6 :python_version=='2.6'
zaz :python_version>'1.0'
""",
'expected_extra_requires': {
"test:(python_version=='2.6')": ['foo', 'baz<1.6'],
"test": ['bar', 'zaz']}}),
('no_extras', {
'config_text': """
[metadata]
long_description = foo
""",
'expected_extra_requires':
{}
})]
def config_from_ini(self, ini):
config = {}
if sys.version_info >= (3, 2):
parser = configparser.ConfigParser()
else:
parser = configparser.SafeConfigParser()
ini = textwrap.dedent(six.u(ini))
parser.readfp(io.StringIO(ini))
for section in parser.sections():
config[section] = dict(parser.items(section))
return config
def test_extras_parsing(self):
config = self.config_from_ini(self.config_text)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(self.expected_extra_requires,
kwargs['extras_require'])
class TestInvalidMarkers(base.BaseTestCase):
def test_invalid_marker_raises_error(self):
config = {'extras': {'test': "foo :bad_marker>'1.0'"}}
self.assertRaises(SyntaxError, util.setup_cfg_to_setup_kwargs, config)

View file

@ -0,0 +1,311 @@
# Copyright 2012 Red Hat, Inc.
# Copyright 2012-2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import itertools
from testtools import matchers
from pbr.tests import base
from pbr import version
from_pip_string = version.SemanticVersion.from_pip_string
class TestSemanticVersion(base.BaseTestCase):
def test_ordering(self):
ordered_versions = [
"1.2.3.dev6",
"1.2.3.dev7",
"1.2.3.a4.dev12",
"1.2.3.a4.dev13",
"1.2.3.a4",
"1.2.3.a5.dev1",
"1.2.3.a5",
"1.2.3.b3.dev1",
"1.2.3.b3",
"1.2.3.rc2.dev1",
"1.2.3.rc2",
"1.2.3.rc3.dev1",
"1.2.3",
"1.2.4",
"1.3.3",
"2.2.3",
]
for v in ordered_versions:
sv = version.SemanticVersion.from_pip_string(v)
self.expectThat(sv, matchers.Equals(sv))
for left, right in itertools.combinations(ordered_versions, 2):
l_pos = ordered_versions.index(left)
r_pos = ordered_versions.index(right)
if l_pos < r_pos:
m1 = matchers.LessThan
m2 = matchers.GreaterThan
else:
m1 = matchers.GreaterThan
m2 = matchers.LessThan
left_sv = version.SemanticVersion.from_pip_string(left)
right_sv = version.SemanticVersion.from_pip_string(right)
self.expectThat(left_sv, m1(right_sv))
self.expectThat(right_sv, m2(left_sv))
def test_from_pip_string_legacy_alpha(self):
expected = version.SemanticVersion(
1, 2, 0, prerelease_type='rc', prerelease=1)
parsed = from_pip_string('1.2.0rc1')
self.assertEqual(expected, parsed)
def test_from_pip_string_legacy_postN(self):
# When pbr trunk was incompatible with PEP-440, a stable release was
# made that used postN versions to represent developer builds. As
# we expect only to be parsing versions of our own, we map those
# into dev builds of the next version.
expected = version.SemanticVersion(1, 2, 4, dev_count=5)
parsed = from_pip_string('1.2.3.post5')
self.expectThat(expected, matchers.Equals(parsed))
expected = version.SemanticVersion(1, 2, 3, 'a', 5, dev_count=6)
parsed = from_pip_string('1.2.3.0a4.post6')
self.expectThat(expected, matchers.Equals(parsed))
# We can't define a mapping for .postN.devM, so it should raise.
self.expectThat(
lambda: from_pip_string('1.2.3.post5.dev6'),
matchers.raises(ValueError))
def test_from_pip_string_v_version(self):
parsed = from_pip_string('v1.2.3')
expected = version.SemanticVersion(1, 2, 3)
self.expectThat(expected, matchers.Equals(parsed))
expected = version.SemanticVersion(1, 2, 3, 'a', 5, dev_count=6)
parsed = from_pip_string('V1.2.3.0a4.post6')
self.expectThat(expected, matchers.Equals(parsed))
self.expectThat(
lambda: from_pip_string('x1.2.3'),
matchers.raises(ValueError))
def test_from_pip_string_legacy_nonzero_lead_in(self):
# reported in bug 1361251
expected = version.SemanticVersion(
0, 0, 1, prerelease_type='a', prerelease=2)
parsed = from_pip_string('0.0.1a2')
self.assertEqual(expected, parsed)
def test_from_pip_string_legacy_short_nonzero_lead_in(self):
expected = version.SemanticVersion(
0, 1, 0, prerelease_type='a', prerelease=2)
parsed = from_pip_string('0.1a2')
self.assertEqual(expected, parsed)
def test_from_pip_string_legacy_no_0_prerelease(self):
expected = version.SemanticVersion(
2, 1, 0, prerelease_type='rc', prerelease=1)
parsed = from_pip_string('2.1.0.rc1')
self.assertEqual(expected, parsed)
def test_from_pip_string_legacy_no_0_prerelease_2(self):
expected = version.SemanticVersion(
2, 0, 0, prerelease_type='rc', prerelease=1)
parsed = from_pip_string('2.0.0.rc1')
self.assertEqual(expected, parsed)
def test_from_pip_string_legacy_non_440_beta(self):
expected = version.SemanticVersion(
2014, 2, prerelease_type='b', prerelease=2)
parsed = from_pip_string('2014.2.b2')
self.assertEqual(expected, parsed)
def test_from_pip_string_pure_git_hash(self):
self.assertRaises(ValueError, from_pip_string, '6eed5ae')
def test_from_pip_string_non_digit_start(self):
self.assertRaises(ValueError, from_pip_string,
'non-release-tag/2014.12.16-1')
def test_final_version(self):
semver = version.SemanticVersion(1, 2, 3)
self.assertEqual((1, 2, 3, 'final', 0), semver.version_tuple())
self.assertEqual("1.2.3", semver.brief_string())
self.assertEqual("1.2.3", semver.debian_string())
self.assertEqual("1.2.3", semver.release_string())
self.assertEqual("1.2.3", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.3"))
def test_parsing_short_forms(self):
semver = version.SemanticVersion(1, 0, 0)
self.assertEqual(semver, from_pip_string("1"))
self.assertEqual(semver, from_pip_string("1.0"))
self.assertEqual(semver, from_pip_string("1.0.0"))
def test_dev_version(self):
semver = version.SemanticVersion(1, 2, 4, dev_count=5)
self.assertEqual((1, 2, 4, 'dev', 4), semver.version_tuple())
self.assertEqual("1.2.4", semver.brief_string())
self.assertEqual("1.2.4~dev5", semver.debian_string())
self.assertEqual("1.2.4.dev5", semver.release_string())
self.assertEqual("1.2.3.dev5", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.4.dev5"))
def test_dev_no_git_version(self):
semver = version.SemanticVersion(1, 2, 4, dev_count=5)
self.assertEqual((1, 2, 4, 'dev', 4), semver.version_tuple())
self.assertEqual("1.2.4", semver.brief_string())
self.assertEqual("1.2.4~dev5", semver.debian_string())
self.assertEqual("1.2.4.dev5", semver.release_string())
self.assertEqual("1.2.3.dev5", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.4.dev5"))
def test_dev_zero_version(self):
semver = version.SemanticVersion(1, 2, 0, dev_count=5)
self.assertEqual((1, 2, 0, 'dev', 4), semver.version_tuple())
self.assertEqual("1.2.0", semver.brief_string())
self.assertEqual("1.2.0~dev5", semver.debian_string())
self.assertEqual("1.2.0.dev5", semver.release_string())
self.assertEqual("1.1.9999.dev5", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.0.dev5"))
def test_alpha_dev_version(self):
semver = version.SemanticVersion(1, 2, 4, 'a', 1, 12)
self.assertEqual((1, 2, 4, 'alphadev', 12), semver.version_tuple())
self.assertEqual("1.2.4", semver.brief_string())
self.assertEqual("1.2.4~a1.dev12", semver.debian_string())
self.assertEqual("1.2.4.0a1.dev12", semver.release_string())
self.assertEqual("1.2.3.a1.dev12", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.4.0a1.dev12"))
def test_alpha_version(self):
semver = version.SemanticVersion(1, 2, 4, 'a', 1)
self.assertEqual((1, 2, 4, 'alpha', 1), semver.version_tuple())
self.assertEqual("1.2.4", semver.brief_string())
self.assertEqual("1.2.4~a1", semver.debian_string())
self.assertEqual("1.2.4.0a1", semver.release_string())
self.assertEqual("1.2.3.a1", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.4.0a1"))
def test_alpha_zero_version(self):
semver = version.SemanticVersion(1, 2, 0, 'a', 1)
self.assertEqual((1, 2, 0, 'alpha', 1), semver.version_tuple())
self.assertEqual("1.2.0", semver.brief_string())
self.assertEqual("1.2.0~a1", semver.debian_string())
self.assertEqual("1.2.0.0a1", semver.release_string())
self.assertEqual("1.1.9999.a1", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.0.0a1"))
def test_alpha_major_zero_version(self):
semver = version.SemanticVersion(1, 0, 0, 'a', 1)
self.assertEqual((1, 0, 0, 'alpha', 1), semver.version_tuple())
self.assertEqual("1.0.0", semver.brief_string())
self.assertEqual("1.0.0~a1", semver.debian_string())
self.assertEqual("1.0.0.0a1", semver.release_string())
self.assertEqual("0.9999.9999.a1", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.0.0.0a1"))
def test_alpha_default_version(self):
semver = version.SemanticVersion(1, 2, 4, 'a')
self.assertEqual((1, 2, 4, 'alpha', 0), semver.version_tuple())
self.assertEqual("1.2.4", semver.brief_string())
self.assertEqual("1.2.4~a0", semver.debian_string())
self.assertEqual("1.2.4.0a0", semver.release_string())
self.assertEqual("1.2.3.a0", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.4.0a0"))
def test_beta_dev_version(self):
semver = version.SemanticVersion(1, 2, 4, 'b', 1, 12)
self.assertEqual((1, 2, 4, 'betadev', 12), semver.version_tuple())
self.assertEqual("1.2.4", semver.brief_string())
self.assertEqual("1.2.4~b1.dev12", semver.debian_string())
self.assertEqual("1.2.4.0b1.dev12", semver.release_string())
self.assertEqual("1.2.3.b1.dev12", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.4.0b1.dev12"))
def test_beta_version(self):
semver = version.SemanticVersion(1, 2, 4, 'b', 1)
self.assertEqual((1, 2, 4, 'beta', 1), semver.version_tuple())
self.assertEqual("1.2.4", semver.brief_string())
self.assertEqual("1.2.4~b1", semver.debian_string())
self.assertEqual("1.2.4.0b1", semver.release_string())
self.assertEqual("1.2.3.b1", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.4.0b1"))
def test_decrement_nonrelease(self):
# The prior version of any non-release is a release
semver = version.SemanticVersion(1, 2, 4, 'b', 1)
self.assertEqual(
version.SemanticVersion(1, 2, 3), semver.decrement())
def test_decrement_nonrelease_zero(self):
# We set an arbitrary max version of 9999 when decrementing versions
# - this is part of handling rpm support.
semver = version.SemanticVersion(1, 0, 0)
self.assertEqual(
version.SemanticVersion(0, 9999, 9999), semver.decrement())
def test_decrement_release(self):
# The next patch version of a release version requires a change to the
# patch level.
semver = version.SemanticVersion(2, 2, 5)
self.assertEqual(
version.SemanticVersion(2, 2, 4), semver.decrement())
def test_increment_nonrelease(self):
# The next patch version of a non-release version is another
# non-release version as the next release doesn't need to be
# incremented.
semver = version.SemanticVersion(1, 2, 4, 'b', 1)
self.assertEqual(
version.SemanticVersion(1, 2, 4, 'b', 2), semver.increment())
# Major and minor increments however need to bump things.
self.assertEqual(
version.SemanticVersion(1, 3, 0), semver.increment(minor=True))
self.assertEqual(
version.SemanticVersion(2, 0, 0), semver.increment(major=True))
def test_increment_release(self):
# The next patch version of a release version requires a change to the
# patch level.
semver = version.SemanticVersion(1, 2, 5)
self.assertEqual(
version.SemanticVersion(1, 2, 6), semver.increment())
self.assertEqual(
version.SemanticVersion(1, 3, 0), semver.increment(minor=True))
self.assertEqual(
version.SemanticVersion(2, 0, 0), semver.increment(major=True))
def test_rc_dev_version(self):
semver = version.SemanticVersion(1, 2, 4, 'rc', 1, 12)
self.assertEqual((1, 2, 4, 'candidatedev', 12), semver.version_tuple())
self.assertEqual("1.2.4", semver.brief_string())
self.assertEqual("1.2.4~rc1.dev12", semver.debian_string())
self.assertEqual("1.2.4.0rc1.dev12", semver.release_string())
self.assertEqual("1.2.3.rc1.dev12", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.4.0rc1.dev12"))
def test_rc_version(self):
semver = version.SemanticVersion(1, 2, 4, 'rc', 1)
self.assertEqual((1, 2, 4, 'candidate', 1), semver.version_tuple())
self.assertEqual("1.2.4", semver.brief_string())
self.assertEqual("1.2.4~rc1", semver.debian_string())
self.assertEqual("1.2.4.0rc1", semver.release_string())
self.assertEqual("1.2.3.rc1", semver.rpm_string())
self.assertEqual(semver, from_pip_string("1.2.4.0rc1"))
def test_to_dev(self):
self.assertEqual(
version.SemanticVersion(1, 2, 3, dev_count=1),
version.SemanticVersion(1, 2, 3).to_dev(1))
self.assertEqual(
version.SemanticVersion(1, 2, 3, 'rc', 1, dev_count=1),
version.SemanticVersion(1, 2, 3, 'rc', 1).to_dev(1))

163
libs/pbr/tests/test_wsgi.py Normal file
View file

@ -0,0 +1,163 @@
# Copyright (c) 2015 Hewlett-Packard Development Company, L.P. (HP)
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import re
import subprocess
import sys
try:
# python 2
from urllib2 import urlopen
except ImportError:
# python 3
from urllib.request import urlopen
from pbr.tests import base
class TestWsgiScripts(base.BaseTestCase):
cmd_names = ('pbr_test_wsgi', 'pbr_test_wsgi_with_class')
def _get_path(self):
if os.path.isdir("%s/lib64" % self.temp_dir):
path = "%s/lib64" % self.temp_dir
elif os.path.isdir("%s/lib" % self.temp_dir):
path = "%s/lib" % self.temp_dir
elif os.path.isdir("%s/site-packages" % self.temp_dir):
return ".:%s/site-packages" % self.temp_dir
else:
raise Exception("Could not determine path for test")
return ".:%s/python%s.%s/site-packages" % (
path,
sys.version_info[0],
sys.version_info[1])
def test_wsgi_script_install(self):
"""Test that we install a non-pkg-resources wsgi script."""
if os.name == 'nt':
self.skipTest('Windows support is passthrough')
stdout, _, return_code = self.run_setup(
'install', '--prefix=%s' % self.temp_dir)
self._check_wsgi_install_content(stdout)
def test_wsgi_script_run(self):
"""Test that we install a runnable wsgi script.
This test actually attempts to start and interact with the
wsgi script in question to demonstrate that it's a working
wsgi script using simple server.
"""
if os.name == 'nt':
self.skipTest('Windows support is passthrough')
stdout, _, return_code = self.run_setup(
'install', '--prefix=%s' % self.temp_dir)
self._check_wsgi_install_content(stdout)
# Live test run the scripts and see that they respond to wsgi
# requests.
for cmd_name in self.cmd_names:
self._test_wsgi(cmd_name, b'Hello World')
def _test_wsgi(self, cmd_name, output, extra_args=None):
cmd = os.path.join(self.temp_dir, 'bin', cmd_name)
print("Running %s -p 0" % cmd)
popen_cmd = [cmd, '-p', '0']
if extra_args:
popen_cmd.extend(extra_args)
env = {'PYTHONPATH': self._get_path()}
p = subprocess.Popen(popen_cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, cwd=self.temp_dir,
env=env)
self.addCleanup(p.kill)
stdoutdata = p.stdout.readline() # ****...
stdoutdata = p.stdout.readline() # STARTING test server...
self.assertIn(
b"STARTING test server pbr_testpackage.wsgi",
stdoutdata)
stdoutdata = p.stdout.readline() # Available at ...
print(stdoutdata)
m = re.search(b'(http://[^:]+:\d+)/', stdoutdata)
self.assertIsNotNone(m, "Regex failed to match on %s" % stdoutdata)
stdoutdata = p.stdout.readline() # DANGER! ...
self.assertIn(
b"DANGER! For testing only, do not use in production",
stdoutdata)
stdoutdata = p.stdout.readline() # ***...
f = urlopen(m.group(1).decode('utf-8'))
self.assertEqual(output, f.read())
# Request again so that the application can force stderr.flush(),
# otherwise the log is buffered and the next readline() will hang.
urlopen(m.group(1).decode('utf-8'))
stdoutdata = p.stderr.readline()
# we should have logged an HTTP request, return code 200, that
# returned the right amount of bytes
status = '"GET / HTTP/1.1" 200 %d' % len(output)
self.assertIn(status.encode('utf-8'), stdoutdata)
def _check_wsgi_install_content(self, install_stdout):
for cmd_name in self.cmd_names:
install_txt = 'Installing %s script to %s' % (cmd_name,
self.temp_dir)
self.assertIn(install_txt, install_stdout)
cmd_filename = os.path.join(self.temp_dir, 'bin', cmd_name)
script_txt = open(cmd_filename, 'r').read()
self.assertNotIn('pkg_resources', script_txt)
main_block = """if __name__ == "__main__":
import argparse
import socket
import sys
import wsgiref.simple_server as wss"""
if cmd_name == 'pbr_test_wsgi':
app_name = "main"
else:
app_name = "WSGI.app"
starting_block = ("STARTING test server pbr_testpackage.wsgi."
"%s" % app_name)
else_block = """else:
application = None"""
self.assertIn(main_block, script_txt)
self.assertIn(starting_block, script_txt)
self.assertIn(else_block, script_txt)
def test_with_argument(self):
if os.name == 'nt':
self.skipTest('Windows support is passthrough')
stdout, _, return_code = self.run_setup(
'install', '--prefix=%s' % self.temp_dir)
self._test_wsgi('pbr_test_wsgi', b'Foo Bar', ["--", "-c", "Foo Bar"])

View file

@ -0,0 +1,86 @@
Changelog
===========
0.3 (unreleased)
------------------
- The ``glob_data_files`` hook became a pre-command hook for the install_data
command instead of being a setup-hook. This is to support the additional
functionality of requiring data_files with relative destination paths to be
install relative to the package's install path (i.e. site-packages).
- Dropped support for and deprecated the easier_install custom command.
Although it should still work, it probably won't be used anymore for
stsci_python packages.
- Added support for the ``build_optional_ext`` command, which replaces/extends
the default ``build_ext`` command. See the README for more details.
- Added the ``tag_svn_revision`` setup_hook as a replacement for the
setuptools-specific tag_svn_revision option to the egg_info command. This
new hook is easier to use than the old tag_svn_revision option: It's
automatically enabled by the presence of ``.dev`` in the version string, and
disabled otherwise.
- The ``svn_info_pre_hook`` and ``svn_info_post_hook`` have been replaced with
``version_pre_command_hook`` and ``version_post_command_hook`` respectively.
However, a new ``version_setup_hook``, which has the same purpose, has been
added. It is generally easier to use and will give more consistent results
in that it will run every time setup.py is run, regardless of which command
is used. ``stsci.distutils`` itself uses this hook--see the `setup.cfg` file
and `stsci/distutils/__init__.py` for example usage.
- Instead of creating an `svninfo.py` module, the new ``version_`` hooks create
a file called `version.py`. In addition to the SVN info that was included
in `svninfo.py`, it includes a ``__version__`` variable to be used by the
package's `__init__.py`. This allows there to be a hard-coded
``__version__`` variable included in the source code, rather than using
pkg_resources to get the version.
- In `version.py`, the variables previously named ``__svn_version__`` and
``__full_svn_info__`` are now named ``__svn_revision__`` and
``__svn_full_info__``.
- Fixed a bug when using stsci.distutils in the installation of other packages
in the ``stsci.*`` namespace package. If stsci.distutils was not already
installed, and was downloaded automatically by distribute through the
setup_requires option, then ``stsci.distutils`` would fail to import. This
is because the way the namespace package (nspkg) mechanism currently works,
all packages belonging to the nspkg *must* be on the import path at initial
import time.
So when installing stsci.tools, for example, if ``stsci.tools`` is imported
from within the source code at install time, but before ``stsci.distutils``
is downloaded and added to the path, the ``stsci`` package is already
imported and can't be extended to include the path of ``stsci.distutils``
after the fact. The easiest way of dealing with this, it seems, is to
delete ``stsci`` from ``sys.modules``, which forces it to be reimported, now
the its ``__path__`` extended to include ``stsci.distutil``'s path.
0.2.2 (2011-11-09)
------------------
- Fixed check for the issue205 bug on actual setuptools installs; before it
only worked on distribute. setuptools has the issue205 bug prior to version
0.6c10.
- Improved the fix for the issue205 bug, especially on setuptools.
setuptools, prior to 0.6c10, did not back of sys.modules either before
sandboxing, which causes serious problems. In fact, it's so bad that it's
not enough to add a sys.modules backup to the current sandbox: It's in fact
necessary to monkeypatch setuptools.sandbox.run_setup so that any subsequent
calls to it also back up sys.modules.
0.2.1 (2011-09-02)
------------------
- Fixed the dependencies so that setuptools is requirement but 'distribute'
specifically. Previously installation could fail if users had plain
setuptools installed and not distribute
0.2 (2011-08-23)
------------------
- Initial public release

View file

@ -0,0 +1,29 @@
Copyright (C) 2005 Association of Universities for Research in Astronomy (AURA)
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
3. The name of AURA and its representatives may not be used to
endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY AURA ``AS IS'' AND ANY EXPRESS OR IMPLIED
WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL AURA BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
DAMAGE.

View file

@ -0,0 +1,2 @@
include data_files/*
exclude pbr_testpackage/extra.py

View file

@ -0,0 +1,148 @@
Introduction
============
This package contains utilities used to package some of STScI's Python
projects; specifically those projects that comprise stsci_python_ and
Astrolib_.
It currently consists mostly of some setup_hook scripts meant for use with
`distutils2/packaging`_ and/or pbr_, and a customized easy_install command
meant for use with distribute_.
This package is not meant for general consumption, though it might be worth
looking at for examples of how to do certain things with your own packages, but
YMMV.
Features
========
Hook Scripts
------------
Currently the main features of this package are a couple of setup_hook scripts.
In distutils2, a setup_hook is a script that runs at the beginning of any
pysetup command, and can modify the package configuration read from setup.cfg.
There are also pre- and post-command hooks that only run before/after a
specific setup command (eg. build_ext, install) is run.
stsci.distutils.hooks.use_packages_root
'''''''''''''''''''''''''''''''''''''''
If using the ``packages_root`` option under the ``[files]`` section of
setup.cfg, this hook will add that path to ``sys.path`` so that modules in your
package can be imported and used in setup. This can be used even if
``packages_root`` is not specified--in this case it adds ``''`` to
``sys.path``.
stsci.distutils.hooks.version_setup_hook
''''''''''''''''''''''''''''''''''''''''
Creates a Python module called version.py which currently contains four
variables:
* ``__version__`` (the release version)
* ``__svn_revision__`` (the SVN revision info as returned by the ``svnversion``
command)
* ``__svn_full_info__`` (as returned by the ``svn info`` command)
* ``__setup_datetime__`` (the date and time that setup.py was last run).
These variables can be imported in the package's `__init__.py` for degugging
purposes. The version.py module will *only* be created in a package that
imports from the version module in its `__init__.py`. It should be noted that
this is generally preferable to writing these variables directly into
`__init__.py`, since this provides more control and is less likely to
unexpectedly break things in `__init__.py`.
stsci.distutils.hooks.version_pre_command_hook
''''''''''''''''''''''''''''''''''''''''''''''
Identical to version_setup_hook, but designed to be used as a pre-command
hook.
stsci.distutils.hooks.version_post_command_hook
'''''''''''''''''''''''''''''''''''''''''''''''
The complement to version_pre_command_hook. This will delete any version.py
files created during a build in order to prevent them from cluttering an SVN
working copy (note, however, that version.py is *not* deleted from the build/
directory, so a copy of it is still preserved). It will also not be deleted
if the current directory is not an SVN working copy. For example, if source
code extracted from a source tarball it will be preserved.
stsci.distutils.hooks.tag_svn_revision
''''''''''''''''''''''''''''''''''''''
A setup_hook to add the SVN revision of the current working copy path to the
package version string, but only if the version ends in .dev.
For example, ``mypackage-1.0.dev`` becomes ``mypackage-1.0.dev1234``. This is
in accordance with the version string format standardized by PEP 386.
This should be used as a replacement for the ``tag_svn_revision`` option to
the egg_info command. This hook is more compatible with packaging/distutils2,
which does not include any VCS support. This hook is also more flexible in
that it turns the revision number on/off depending on the presence of ``.dev``
in the version string, so that it's not automatically added to the version in
final releases.
This hook does require the ``svnversion`` command to be available in order to
work. It does not examine the working copy metadata directly.
stsci.distutils.hooks.numpy_extension_hook
''''''''''''''''''''''''''''''''''''''''''
This is a pre-command hook for the build_ext command. To use it, add a
``[build_ext]`` section to your setup.cfg, and add to it::
pre-hook.numpy-extension-hook = stsci.distutils.hooks.numpy_extension_hook
This hook must be used to build extension modules that use Numpy. The primary
side-effect of this hook is to add the correct numpy include directories to
`include_dirs`. To use it, add 'numpy' to the 'include-dirs' option of each
extension module that requires numpy to build. The value 'numpy' will be
replaced with the actual path to the numpy includes.
stsci.distutils.hooks.is_display_option
'''''''''''''''''''''''''''''''''''''''
This is not actually a hook, but is a useful utility function that can be used
in writing other hooks. Basically, it returns ``True`` if setup.py was run
with a "display option" such as --version or --help. This can be used to
prevent your hook from running in such cases.
stsci.distutils.hooks.glob_data_files
'''''''''''''''''''''''''''''''''''''
A pre-command hook for the install_data command. Allows filename wildcards as
understood by ``glob.glob()`` to be used in the data_files option. This hook
must be used in order to have this functionality since it does not normally
exist in distutils.
This hook also ensures that data files are installed relative to the package
path. data_files shouldn't normally be installed this way, but the
functionality is required for a few special cases.
Commands
--------
build_optional_ext
''''''''''''''''''
This serves as an optional replacement for the default built_ext command,
which compiles C extension modules. Its purpose is to allow extension modules
to be *optional*, so that if their build fails the rest of the package is
still allowed to be built and installed. This can be used when an extension
module is not definitely required to use the package.
To use this custom command, add::
commands = stsci.distutils.command.build_optional_ext.build_optional_ext
under the ``[global]`` section of your package's setup.cfg. Then, to mark
an individual extension module as optional, under the setup.cfg section for
that extension add::
optional = True
Optionally, you may also add a custom failure message by adding::
fail_message = The foobar extension module failed to compile.
This could be because you lack such and such headers.
This package will still work, but such and such features
will be disabled.
.. _stsci_python: http://www.stsci.edu/resources/software_hardware/pyraf/stsci_python
.. _Astrolib: http://www.scipy.org/AstroLib/
.. _distutils2/packaging: http://distutils2.notmyidea.org/
.. _d2to1: http://pypi.python.org/pypi/d2to1
.. _distribute: http://pypi.python.org/pypi/distribute

Some files were not shown because too many files have changed in this diff Show more