Updates vendored subliminal to 2.1.0

Updates rarfile to 3.1
Updates stevedore to 3.5.0
Updates appdirs to 1.4.4
Updates click to 8.1.3
Updates decorator to 5.1.1
Updates dogpile.cache to 1.1.8
Updates pbr to 5.11.0
Updates pysrt to 1.1.2
Updates pytz to 2022.6
Adds importlib-metadata version 3.1.1
Adds typing-extensions version 4.1.1
Adds zipp version 3.11.0
This commit is contained in:
Labrys of Knossos 2022-11-29 00:08:39 -05:00
commit f05b09f349
694 changed files with 16621 additions and 11056 deletions

Binary file not shown.

View file

@ -13,8 +13,8 @@ See <http://github.com/ActiveState/appdirs> for details and usage.
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
__version_info__ = (1, 4, 3)
__version__ = '.'.join(map(str, __version_info__))
__version__ = "1.4.4"
__version_info__ = tuple(int(segment) for segment in __version__.split("."))
import sys

BIN
libs/common/bin/beet.exe Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
libs/common/bin/guessit.exe Normal file

Binary file not shown.

BIN
libs/common/bin/mid3cp.exe Normal file

Binary file not shown.

Binary file not shown.

BIN
libs/common/bin/mid3v2.exe Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
libs/common/bin/pbr.exe Normal file

Binary file not shown.

BIN
libs/common/bin/srt.exe Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -1,97 +1,73 @@
# -*- coding: utf-8 -*-
"""
click
~~~~~
Click is a simple Python module inspired by the stdlib optparse to make
writing command line scripts fun. Unlike other modules, it's based
around a simple API that does not come with too much magic and is
composable.
:copyright: © 2014 by the Pallets team.
:license: BSD, see LICENSE.rst for more details.
"""
from .core import Argument as Argument
from .core import BaseCommand as BaseCommand
from .core import Command as Command
from .core import CommandCollection as CommandCollection
from .core import Context as Context
from .core import Group as Group
from .core import MultiCommand as MultiCommand
from .core import Option as Option
from .core import Parameter as Parameter
from .decorators import argument as argument
from .decorators import command as command
from .decorators import confirmation_option as confirmation_option
from .decorators import group as group
from .decorators import help_option as help_option
from .decorators import make_pass_decorator as make_pass_decorator
from .decorators import option as option
from .decorators import pass_context as pass_context
from .decorators import pass_obj as pass_obj
from .decorators import password_option as password_option
from .decorators import version_option as version_option
from .exceptions import Abort as Abort
from .exceptions import BadArgumentUsage as BadArgumentUsage
from .exceptions import BadOptionUsage as BadOptionUsage
from .exceptions import BadParameter as BadParameter
from .exceptions import ClickException as ClickException
from .exceptions import FileError as FileError
from .exceptions import MissingParameter as MissingParameter
from .exceptions import NoSuchOption as NoSuchOption
from .exceptions import UsageError as UsageError
from .formatting import HelpFormatter as HelpFormatter
from .formatting import wrap_text as wrap_text
from .globals import get_current_context as get_current_context
from .parser import OptionParser as OptionParser
from .termui import clear as clear
from .termui import confirm as confirm
from .termui import echo_via_pager as echo_via_pager
from .termui import edit as edit
from .termui import getchar as getchar
from .termui import launch as launch
from .termui import pause as pause
from .termui import progressbar as progressbar
from .termui import prompt as prompt
from .termui import secho as secho
from .termui import style as style
from .termui import unstyle as unstyle
from .types import BOOL as BOOL
from .types import Choice as Choice
from .types import DateTime as DateTime
from .types import File as File
from .types import FLOAT as FLOAT
from .types import FloatRange as FloatRange
from .types import INT as INT
from .types import IntRange as IntRange
from .types import ParamType as ParamType
from .types import Path as Path
from .types import STRING as STRING
from .types import Tuple as Tuple
from .types import UNPROCESSED as UNPROCESSED
from .types import UUID as UUID
from .utils import echo as echo
from .utils import format_filename as format_filename
from .utils import get_app_dir as get_app_dir
from .utils import get_binary_stream as get_binary_stream
from .utils import get_text_stream as get_text_stream
from .utils import open_file as open_file
# Core classes
from .core import Context, BaseCommand, Command, MultiCommand, Group, \
CommandCollection, Parameter, Option, Argument
# Globals
from .globals import get_current_context
# Decorators
from .decorators import pass_context, pass_obj, make_pass_decorator, \
command, group, argument, option, confirmation_option, \
password_option, version_option, help_option
# Types
from .types import ParamType, File, Path, Choice, IntRange, Tuple, \
DateTime, STRING, INT, FLOAT, BOOL, UUID, UNPROCESSED, FloatRange
# Utilities
from .utils import echo, get_binary_stream, get_text_stream, open_file, \
format_filename, get_app_dir, get_os_args
# Terminal functions
from .termui import prompt, confirm, get_terminal_size, echo_via_pager, \
progressbar, clear, style, unstyle, secho, edit, launch, getchar, \
pause
# Exceptions
from .exceptions import ClickException, UsageError, BadParameter, \
FileError, Abort, NoSuchOption, BadOptionUsage, BadArgumentUsage, \
MissingParameter
# Formatting
from .formatting import HelpFormatter, wrap_text
# Parsing
from .parser import OptionParser
__all__ = [
# Core classes
'Context', 'BaseCommand', 'Command', 'MultiCommand', 'Group',
'CommandCollection', 'Parameter', 'Option', 'Argument',
# Globals
'get_current_context',
# Decorators
'pass_context', 'pass_obj', 'make_pass_decorator', 'command', 'group',
'argument', 'option', 'confirmation_option', 'password_option',
'version_option', 'help_option',
# Types
'ParamType', 'File', 'Path', 'Choice', 'IntRange', 'Tuple',
'DateTime', 'STRING', 'INT', 'FLOAT', 'BOOL', 'UUID', 'UNPROCESSED',
'FloatRange',
# Utilities
'echo', 'get_binary_stream', 'get_text_stream', 'open_file',
'format_filename', 'get_app_dir', 'get_os_args',
# Terminal functions
'prompt', 'confirm', 'get_terminal_size', 'echo_via_pager',
'progressbar', 'clear', 'style', 'unstyle', 'secho', 'edit', 'launch',
'getchar', 'pause',
# Exceptions
'ClickException', 'UsageError', 'BadParameter', 'FileError',
'Abort', 'NoSuchOption', 'BadOptionUsage', 'BadArgumentUsage',
'MissingParameter',
# Formatting
'HelpFormatter', 'wrap_text',
# Parsing
'OptionParser',
]
# Controls if click should emit the warning about the use of unicode
# literals.
disable_unicode_literals_warning = False
__version__ = '7.0'
__version__ = "8.1.3"

View file

@ -1,293 +0,0 @@
import copy
import os
import re
from .utils import echo
from .parser import split_arg_string
from .core import MultiCommand, Option, Argument
from .types import Choice
try:
from collections import abc
except ImportError:
import collections as abc
WORDBREAK = '='
# Note, only BASH version 4.4 and later have the nosort option.
COMPLETION_SCRIPT_BASH = '''
%(complete_func)s() {
local IFS=$'\n'
COMPREPLY=( $( env COMP_WORDS="${COMP_WORDS[*]}" \\
COMP_CWORD=$COMP_CWORD \\
%(autocomplete_var)s=complete $1 ) )
return 0
}
%(complete_func)setup() {
local COMPLETION_OPTIONS=""
local BASH_VERSION_ARR=(${BASH_VERSION//./ })
# Only BASH version 4.4 and later have the nosort option.
if [ ${BASH_VERSION_ARR[0]} -gt 4 ] || ([ ${BASH_VERSION_ARR[0]} -eq 4 ] && [ ${BASH_VERSION_ARR[1]} -ge 4 ]); then
COMPLETION_OPTIONS="-o nosort"
fi
complete $COMPLETION_OPTIONS -F %(complete_func)s %(script_names)s
}
%(complete_func)setup
'''
COMPLETION_SCRIPT_ZSH = '''
%(complete_func)s() {
local -a completions
local -a completions_with_descriptions
local -a response
response=("${(@f)$( env COMP_WORDS=\"${words[*]}\" \\
COMP_CWORD=$((CURRENT-1)) \\
%(autocomplete_var)s=\"complete_zsh\" \\
%(script_names)s )}")
for key descr in ${(kv)response}; do
if [[ "$descr" == "_" ]]; then
completions+=("$key")
else
completions_with_descriptions+=("$key":"$descr")
fi
done
if [ -n "$completions_with_descriptions" ]; then
_describe -V unsorted completions_with_descriptions -U -Q
fi
if [ -n "$completions" ]; then
compadd -U -V unsorted -Q -a completions
fi
compstate[insert]="automenu"
}
compdef %(complete_func)s %(script_names)s
'''
_invalid_ident_char_re = re.compile(r'[^a-zA-Z0-9_]')
def get_completion_script(prog_name, complete_var, shell):
cf_name = _invalid_ident_char_re.sub('', prog_name.replace('-', '_'))
script = COMPLETION_SCRIPT_ZSH if shell == 'zsh' else COMPLETION_SCRIPT_BASH
return (script % {
'complete_func': '_%s_completion' % cf_name,
'script_names': prog_name,
'autocomplete_var': complete_var,
}).strip() + ';'
def resolve_ctx(cli, prog_name, args):
"""
Parse into a hierarchy of contexts. Contexts are connected through the parent variable.
:param cli: command definition
:param prog_name: the program that is running
:param args: full list of args
:return: the final context/command parsed
"""
ctx = cli.make_context(prog_name, args, resilient_parsing=True)
args = ctx.protected_args + ctx.args
while args:
if isinstance(ctx.command, MultiCommand):
if not ctx.command.chain:
cmd_name, cmd, args = ctx.command.resolve_command(ctx, args)
if cmd is None:
return ctx
ctx = cmd.make_context(cmd_name, args, parent=ctx,
resilient_parsing=True)
args = ctx.protected_args + ctx.args
else:
# Walk chained subcommand contexts saving the last one.
while args:
cmd_name, cmd, args = ctx.command.resolve_command(ctx, args)
if cmd is None:
return ctx
sub_ctx = cmd.make_context(cmd_name, args, parent=ctx,
allow_extra_args=True,
allow_interspersed_args=False,
resilient_parsing=True)
args = sub_ctx.args
ctx = sub_ctx
args = sub_ctx.protected_args + sub_ctx.args
else:
break
return ctx
def start_of_option(param_str):
"""
:param param_str: param_str to check
:return: whether or not this is the start of an option declaration (i.e. starts "-" or "--")
"""
return param_str and param_str[:1] == '-'
def is_incomplete_option(all_args, cmd_param):
"""
:param all_args: the full original list of args supplied
:param cmd_param: the current command paramter
:return: whether or not the last option declaration (i.e. starts "-" or "--") is incomplete and
corresponds to this cmd_param. In other words whether this cmd_param option can still accept
values
"""
if not isinstance(cmd_param, Option):
return False
if cmd_param.is_flag:
return False
last_option = None
for index, arg_str in enumerate(reversed([arg for arg in all_args if arg != WORDBREAK])):
if index + 1 > cmd_param.nargs:
break
if start_of_option(arg_str):
last_option = arg_str
return True if last_option and last_option in cmd_param.opts else False
def is_incomplete_argument(current_params, cmd_param):
"""
:param current_params: the current params and values for this argument as already entered
:param cmd_param: the current command parameter
:return: whether or not the last argument is incomplete and corresponds to this cmd_param. In
other words whether or not the this cmd_param argument can still accept values
"""
if not isinstance(cmd_param, Argument):
return False
current_param_values = current_params[cmd_param.name]
if current_param_values is None:
return True
if cmd_param.nargs == -1:
return True
if isinstance(current_param_values, abc.Iterable) \
and cmd_param.nargs > 1 and len(current_param_values) < cmd_param.nargs:
return True
return False
def get_user_autocompletions(ctx, args, incomplete, cmd_param):
"""
:param ctx: context associated with the parsed command
:param args: full list of args
:param incomplete: the incomplete text to autocomplete
:param cmd_param: command definition
:return: all the possible user-specified completions for the param
"""
results = []
if isinstance(cmd_param.type, Choice):
# Choices don't support descriptions.
results = [(c, None)
for c in cmd_param.type.choices if str(c).startswith(incomplete)]
elif cmd_param.autocompletion is not None:
dynamic_completions = cmd_param.autocompletion(ctx=ctx,
args=args,
incomplete=incomplete)
results = [c if isinstance(c, tuple) else (c, None)
for c in dynamic_completions]
return results
def get_visible_commands_starting_with(ctx, starts_with):
"""
:param ctx: context associated with the parsed command
:starts_with: string that visible commands must start with.
:return: all visible (not hidden) commands that start with starts_with.
"""
for c in ctx.command.list_commands(ctx):
if c.startswith(starts_with):
command = ctx.command.get_command(ctx, c)
if not command.hidden:
yield command
def add_subcommand_completions(ctx, incomplete, completions_out):
# Add subcommand completions.
if isinstance(ctx.command, MultiCommand):
completions_out.extend(
[(c.name, c.get_short_help_str()) for c in get_visible_commands_starting_with(ctx, incomplete)])
# Walk up the context list and add any other completion possibilities from chained commands
while ctx.parent is not None:
ctx = ctx.parent
if isinstance(ctx.command, MultiCommand) and ctx.command.chain:
remaining_commands = [c for c in get_visible_commands_starting_with(ctx, incomplete)
if c.name not in ctx.protected_args]
completions_out.extend([(c.name, c.get_short_help_str()) for c in remaining_commands])
def get_choices(cli, prog_name, args, incomplete):
"""
:param cli: command definition
:param prog_name: the program that is running
:param args: full list of args
:param incomplete: the incomplete text to autocomplete
:return: all the possible completions for the incomplete
"""
all_args = copy.deepcopy(args)
ctx = resolve_ctx(cli, prog_name, args)
if ctx is None:
return []
# In newer versions of bash long opts with '='s are partitioned, but it's easier to parse
# without the '='
if start_of_option(incomplete) and WORDBREAK in incomplete:
partition_incomplete = incomplete.partition(WORDBREAK)
all_args.append(partition_incomplete[0])
incomplete = partition_incomplete[2]
elif incomplete == WORDBREAK:
incomplete = ''
completions = []
if start_of_option(incomplete):
# completions for partial options
for param in ctx.command.params:
if isinstance(param, Option) and not param.hidden:
param_opts = [param_opt for param_opt in param.opts +
param.secondary_opts if param_opt not in all_args or param.multiple]
completions.extend([(o, param.help) for o in param_opts if o.startswith(incomplete)])
return completions
# completion for option values from user supplied values
for param in ctx.command.params:
if is_incomplete_option(all_args, param):
return get_user_autocompletions(ctx, all_args, incomplete, param)
# completion for argument values from user supplied values
for param in ctx.command.params:
if is_incomplete_argument(ctx.params, param):
return get_user_autocompletions(ctx, all_args, incomplete, param)
add_subcommand_completions(ctx, incomplete, completions)
# Sort before returning so that proper ordering can be enforced in custom types.
return sorted(completions)
def do_complete(cli, prog_name, include_descriptions):
cwords = split_arg_string(os.environ['COMP_WORDS'])
cword = int(os.environ['COMP_CWORD'])
args = cwords[1:cword]
try:
incomplete = cwords[cword]
except IndexError:
incomplete = ''
for item in get_choices(cli, prog_name, args, incomplete):
echo(item[0])
if include_descriptions:
# ZSH has trouble dealing with empty array parameters when returned from commands, so use a well defined character '_' to indicate no description is present.
echo(item[1] if item[1] else '_')
return True
def bashcomplete(cli, prog_name, complete_var, complete_instr):
if complete_instr.startswith('source'):
shell = 'zsh' if complete_instr == 'source_zsh' else 'bash'
echo(get_completion_script(prog_name, complete_var, shell))
return True
elif complete_instr == 'complete' or complete_instr == 'complete_zsh':
return do_complete(cli, prog_name, complete_instr == 'complete_zsh')
return False

File diff suppressed because it is too large Load diff

View file

@ -1,62 +1,56 @@
# -*- coding: utf-8 -*-
"""
click._termui_impl
~~~~~~~~~~~~~~~~~~
This module contains implementations for the termui module. To keep the
import time of Click down, some infrequently used functionality is
placed in this module and only imported as needed.
:copyright: © 2014 by the Pallets team.
:license: BSD, see LICENSE.rst for more details.
"""
import contextlib
import math
import os
import sys
import time
import math
import contextlib
from ._compat import _default_text_stdout, range_type, PY2, isatty, \
open_stream, strip_ansi, term_len, get_best_encoding, WIN, int_types, \
CYGWIN
from .utils import echo
import typing as t
from gettext import gettext as _
from ._compat import _default_text_stdout
from ._compat import CYGWIN
from ._compat import get_best_encoding
from ._compat import isatty
from ._compat import open_stream
from ._compat import strip_ansi
from ._compat import term_len
from ._compat import WIN
from .exceptions import ClickException
from .utils import echo
V = t.TypeVar("V")
if os.name == 'nt':
BEFORE_BAR = '\r'
AFTER_BAR = '\n'
if os.name == "nt":
BEFORE_BAR = "\r"
AFTER_BAR = "\n"
else:
BEFORE_BAR = '\r\033[?25l'
AFTER_BAR = '\033[?25h\n'
BEFORE_BAR = "\r\033[?25l"
AFTER_BAR = "\033[?25h\n"
def _length_hint(obj):
"""Returns the length hint of an object."""
try:
return len(obj)
except (AttributeError, TypeError):
try:
get_hint = type(obj).__length_hint__
except AttributeError:
return None
try:
hint = get_hint(obj)
except TypeError:
return None
if hint is NotImplemented or \
not isinstance(hint, int_types) or \
hint < 0:
return None
return hint
class ProgressBar(object):
def __init__(self, iterable, length=None, fill_char='#', empty_char=' ',
bar_template='%(bar)s', info_sep=' ', show_eta=True,
show_percent=None, show_pos=False, item_show_func=None,
label=None, file=None, color=None, width=30):
class ProgressBar(t.Generic[V]):
def __init__(
self,
iterable: t.Optional[t.Iterable[V]],
length: t.Optional[int] = None,
fill_char: str = "#",
empty_char: str = " ",
bar_template: str = "%(bar)s",
info_sep: str = " ",
show_eta: bool = True,
show_percent: t.Optional[bool] = None,
show_pos: bool = False,
item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
label: t.Optional[str] = None,
file: t.Optional[t.TextIO] = None,
color: t.Optional[bool] = None,
update_min_steps: int = 1,
width: int = 30,
) -> None:
self.fill_char = fill_char
self.empty_char = empty_char
self.bar_template = bar_template
@ -65,77 +59,87 @@ class ProgressBar(object):
self.show_percent = show_percent
self.show_pos = show_pos
self.item_show_func = item_show_func
self.label = label or ''
self.label = label or ""
if file is None:
file = _default_text_stdout()
self.file = file
self.color = color
self.update_min_steps = update_min_steps
self._completed_intervals = 0
self.width = width
self.autowidth = width == 0
if length is None:
length = _length_hint(iterable)
from operator import length_hint
length = length_hint(iterable, -1)
if length == -1:
length = None
if iterable is None:
if length is None:
raise TypeError('iterable or length is required')
iterable = range_type(length)
raise TypeError("iterable or length is required")
iterable = t.cast(t.Iterable[V], range(length))
self.iter = iter(iterable)
self.length = length
self.length_known = length is not None
self.pos = 0
self.avg = []
self.avg: t.List[float] = []
self.start = self.last_eta = time.time()
self.eta_known = False
self.finished = False
self.max_width = None
self.max_width: t.Optional[int] = None
self.entered = False
self.current_item = None
self.current_item: t.Optional[V] = None
self.is_hidden = not isatty(self.file)
self._last_line = None
self.short_limit = 0.5
self._last_line: t.Optional[str] = None
def __enter__(self):
def __enter__(self) -> "ProgressBar":
self.entered = True
self.render_progress()
return self
def __exit__(self, exc_type, exc_value, tb):
def __exit__(self, exc_type, exc_value, tb): # type: ignore
self.render_finish()
def __iter__(self):
def __iter__(self) -> t.Iterator[V]:
if not self.entered:
raise RuntimeError('You need to use progress bars in a with block.')
raise RuntimeError("You need to use progress bars in a with block.")
self.render_progress()
return self.generator()
def is_fast(self):
return time.time() - self.start <= self.short_limit
def __next__(self) -> V:
# Iteration is defined in terms of a generator function,
# returned by iter(self); use that to define next(). This works
# because `self.iter` is an iterable consumed by that generator,
# so it is re-entry safe. Calling `next(self.generator())`
# twice works and does "what you want".
return next(iter(self))
def render_finish(self):
if self.is_hidden or self.is_fast():
def render_finish(self) -> None:
if self.is_hidden:
return
self.file.write(AFTER_BAR)
self.file.flush()
@property
def pct(self):
def pct(self) -> float:
if self.finished:
return 1.0
return min(self.pos / (float(self.length) or 1), 1.0)
return min(self.pos / (float(self.length or 1) or 1), 1.0)
@property
def time_per_iteration(self):
def time_per_iteration(self) -> float:
if not self.avg:
return 0.0
return sum(self.avg) / float(len(self.avg))
@property
def eta(self):
if self.length_known and not self.finished:
def eta(self) -> float:
if self.length is not None and not self.finished:
return self.time_per_iteration * (self.length - self.pos)
return 0.0
def format_eta(self):
def format_eta(self) -> str:
if self.eta_known:
t = int(self.eta)
seconds = t % 60
@ -145,41 +149,44 @@ class ProgressBar(object):
hours = t % 24
t //= 24
if t > 0:
days = t
return '%dd %02d:%02d:%02d' % (days, hours, minutes, seconds)
return f"{t}d {hours:02}:{minutes:02}:{seconds:02}"
else:
return '%02d:%02d:%02d' % (hours, minutes, seconds)
return ''
return f"{hours:02}:{minutes:02}:{seconds:02}"
return ""
def format_pos(self):
def format_pos(self) -> str:
pos = str(self.pos)
if self.length_known:
pos += '/%s' % self.length
if self.length is not None:
pos += f"/{self.length}"
return pos
def format_pct(self):
return ('% 4d%%' % int(self.pct * 100))[1:]
def format_pct(self) -> str:
return f"{int(self.pct * 100): 4}%"[1:]
def format_bar(self):
if self.length_known:
def format_bar(self) -> str:
if self.length is not None:
bar_length = int(self.pct * self.width)
bar = self.fill_char * bar_length
bar += self.empty_char * (self.width - bar_length)
elif self.finished:
bar = self.fill_char * self.width
else:
bar = list(self.empty_char * (self.width or 1))
chars = list(self.empty_char * (self.width or 1))
if self.time_per_iteration != 0:
bar[int((math.cos(self.pos * self.time_per_iteration)
/ 2.0 + 0.5) * self.width)] = self.fill_char
bar = ''.join(bar)
chars[
int(
(math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5)
* self.width
)
] = self.fill_char
bar = "".join(chars)
return bar
def format_progress_line(self):
def format_progress_line(self) -> str:
show_percent = self.show_percent
info_bits = []
if self.length_known and show_percent is None:
if self.length is not None and show_percent is None:
show_percent = not self.show_pos
if self.show_pos:
@ -193,16 +200,25 @@ class ProgressBar(object):
if item_info is not None:
info_bits.append(item_info)
return (self.bar_template % {
'label': self.label,
'bar': self.format_bar(),
'info': self.info_sep.join(info_bits)
}).rstrip()
return (
self.bar_template
% {
"label": self.label,
"bar": self.format_bar(),
"info": self.info_sep.join(info_bits),
}
).rstrip()
def render_progress(self):
from .termui import get_terminal_size
def render_progress(self) -> None:
import shutil
if self.is_hidden:
# Only output the label as it changes if the output is not a
# TTY. Use file=stderr if you expect to be piping stdout.
if self._last_line != self.label:
self._last_line = self.label
echo(self.label, file=self.file, color=self.color)
return
buf = []
@ -211,10 +227,10 @@ class ProgressBar(object):
old_width = self.width
self.width = 0
clutter_length = term_len(self.format_progress_line())
new_width = max(0, get_terminal_size()[0] - clutter_length)
new_width = max(0, shutil.get_terminal_size().columns - clutter_length)
if new_width < old_width:
buf.append(BEFORE_BAR)
buf.append(' ' * self.max_width)
buf.append(" " * self.max_width) # type: ignore
self.max_width = new_width
self.width = new_width
@ -229,18 +245,18 @@ class ProgressBar(object):
self.max_width = line_len
buf.append(line)
buf.append(' ' * (clear_width - line_len))
line = ''.join(buf)
buf.append(" " * (clear_width - line_len))
line = "".join(buf)
# Render the line only if it changed.
if line != self._last_line and not self.is_fast():
if line != self._last_line:
self._last_line = line
echo(line, file=self.file, color=self.color, nl=False)
self.file.flush()
def make_step(self, n_steps):
def make_step(self, n_steps: int) -> None:
self.pos += n_steps
if self.length_known and self.pos >= self.length:
if self.length is not None and self.pos >= self.length:
self.finished = True
if (time.time() - self.last_eta) < 1.0:
@ -258,97 +274,134 @@ class ProgressBar(object):
self.avg = self.avg[-6:] + [step]
self.eta_known = self.length_known
self.eta_known = self.length is not None
def update(self, n_steps):
self.make_step(n_steps)
self.render_progress()
def update(self, n_steps: int, current_item: t.Optional[V] = None) -> None:
"""Update the progress bar by advancing a specified number of
steps, and optionally set the ``current_item`` for this new
position.
def finish(self):
self.eta_known = 0
:param n_steps: Number of steps to advance.
:param current_item: Optional item to set as ``current_item``
for the updated position.
.. versionchanged:: 8.0
Added the ``current_item`` optional parameter.
.. versionchanged:: 8.0
Only render when the number of steps meets the
``update_min_steps`` threshold.
"""
if current_item is not None:
self.current_item = current_item
self._completed_intervals += n_steps
if self._completed_intervals >= self.update_min_steps:
self.make_step(self._completed_intervals)
self.render_progress()
self._completed_intervals = 0
def finish(self) -> None:
self.eta_known = False
self.current_item = None
self.finished = True
def generator(self):
"""
Returns a generator which yields the items added to the bar during
construction, and updates the progress bar *after* the yielded block
returns.
def generator(self) -> t.Iterator[V]:
"""Return a generator which yields the items added to the bar
during construction, and updates the progress bar *after* the
yielded block returns.
"""
# WARNING: the iterator interface for `ProgressBar` relies on
# this and only works because this is a simple generator which
# doesn't create or manage additional state. If this function
# changes, the impact should be evaluated both against
# `iter(bar)` and `next(bar)`. `next()` in particular may call
# `self.generator()` repeatedly, and this must remain safe in
# order for that interface to work.
if not self.entered:
raise RuntimeError('You need to use progress bars in a with block.')
raise RuntimeError("You need to use progress bars in a with block.")
if self.is_hidden:
for rv in self.iter:
yield rv
yield from self.iter
else:
for rv in self.iter:
self.current_item = rv
# This allows show_item_func to be updated before the
# item is processed. Only trigger at the beginning of
# the update interval.
if self._completed_intervals == 0:
self.render_progress()
yield rv
self.update(1)
self.finish()
self.render_progress()
def pager(generator, color=None):
def pager(generator: t.Iterable[str], color: t.Optional[bool] = None) -> None:
"""Decide what method to use for paging through text."""
stdout = _default_text_stdout()
if not isatty(sys.stdin) or not isatty(stdout):
return _nullpager(stdout, generator, color)
pager_cmd = (os.environ.get('PAGER', None) or '').strip()
pager_cmd = (os.environ.get("PAGER", None) or "").strip()
if pager_cmd:
if WIN:
return _tempfilepager(generator, pager_cmd, color)
return _pipepager(generator, pager_cmd, color)
if os.environ.get('TERM') in ('dumb', 'emacs'):
if os.environ.get("TERM") in ("dumb", "emacs"):
return _nullpager(stdout, generator, color)
if WIN or sys.platform.startswith('os2'):
return _tempfilepager(generator, 'more <', color)
if hasattr(os, 'system') and os.system('(less) 2>/dev/null') == 0:
return _pipepager(generator, 'less', color)
if WIN or sys.platform.startswith("os2"):
return _tempfilepager(generator, "more <", color)
if hasattr(os, "system") and os.system("(less) 2>/dev/null") == 0:
return _pipepager(generator, "less", color)
import tempfile
fd, filename = tempfile.mkstemp()
os.close(fd)
try:
if hasattr(os, 'system') and os.system('more "%s"' % filename) == 0:
return _pipepager(generator, 'more', color)
if hasattr(os, "system") and os.system(f'more "{filename}"') == 0:
return _pipepager(generator, "more", color)
return _nullpager(stdout, generator, color)
finally:
os.unlink(filename)
def _pipepager(generator, cmd, color):
def _pipepager(generator: t.Iterable[str], cmd: str, color: t.Optional[bool]) -> None:
"""Page through text by feeding it to another program. Invoking a
pager through this might support colors.
"""
import subprocess
env = dict(os.environ)
# If we're piping to less we might support colors under the
# condition that
cmd_detail = cmd.rsplit('/', 1)[-1].split()
if color is None and cmd_detail[0] == 'less':
less_flags = os.environ.get('LESS', '') + ' '.join(cmd_detail[1:])
cmd_detail = cmd.rsplit("/", 1)[-1].split()
if color is None and cmd_detail[0] == "less":
less_flags = f"{os.environ.get('LESS', '')}{' '.join(cmd_detail[1:])}"
if not less_flags:
env['LESS'] = '-R'
env["LESS"] = "-R"
color = True
elif 'r' in less_flags or 'R' in less_flags:
elif "r" in less_flags or "R" in less_flags:
color = True
c = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE,
env=env)
encoding = get_best_encoding(c.stdin)
c = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, env=env)
stdin = t.cast(t.BinaryIO, c.stdin)
encoding = get_best_encoding(stdin)
try:
for text in generator:
if not color:
text = strip_ansi(text)
c.stdin.write(text.encode(encoding, 'replace'))
except (IOError, KeyboardInterrupt):
stdin.write(text.encode(encoding, "replace"))
except (OSError, KeyboardInterrupt):
pass
else:
c.stdin.close()
stdin.close()
# Less doesn't respect ^C, but catches it for its own UI purposes (aborting
# search or other commands inside less).
@ -367,24 +420,30 @@ def _pipepager(generator, cmd, color):
break
def _tempfilepager(generator, cmd, color):
def _tempfilepager(
generator: t.Iterable[str], cmd: str, color: t.Optional[bool]
) -> None:
"""Page through text by invoking a program on a temporary file."""
import tempfile
filename = tempfile.mktemp()
fd, filename = tempfile.mkstemp()
# TODO: This never terminates if the passed generator never terminates.
text = "".join(generator)
if not color:
text = strip_ansi(text)
encoding = get_best_encoding(sys.stdout)
with open_stream(filename, 'wb')[0] as f:
with open_stream(filename, "wb")[0] as f:
f.write(text.encode(encoding))
try:
os.system(cmd + ' "' + filename + '"')
os.system(f'{cmd} "{filename}"')
finally:
os.close(fd)
os.unlink(filename)
def _nullpager(stream, generator, color):
def _nullpager(
stream: t.TextIO, generator: t.Iterable[str], color: t.Optional[bool]
) -> None:
"""Simply print unformatted text. This is the ultimate fallback."""
for text in generator:
if not color:
@ -392,159 +451,184 @@ def _nullpager(stream, generator, color):
stream.write(text)
class Editor(object):
def __init__(self, editor=None, env=None, require_save=True,
extension='.txt'):
class Editor:
def __init__(
self,
editor: t.Optional[str] = None,
env: t.Optional[t.Mapping[str, str]] = None,
require_save: bool = True,
extension: str = ".txt",
) -> None:
self.editor = editor
self.env = env
self.require_save = require_save
self.extension = extension
def get_editor(self):
def get_editor(self) -> str:
if self.editor is not None:
return self.editor
for key in 'VISUAL', 'EDITOR':
for key in "VISUAL", "EDITOR":
rv = os.environ.get(key)
if rv:
return rv
if WIN:
return 'notepad'
for editor in 'vim', 'nano':
if os.system('which %s >/dev/null 2>&1' % editor) == 0:
return "notepad"
for editor in "sensible-editor", "vim", "nano":
if os.system(f"which {editor} >/dev/null 2>&1") == 0:
return editor
return 'vi'
return "vi"
def edit_file(self, filename):
def edit_file(self, filename: str) -> None:
import subprocess
editor = self.get_editor()
environ: t.Optional[t.Dict[str, str]] = None
if self.env:
environ = os.environ.copy()
environ.update(self.env)
else:
environ = None
try:
c = subprocess.Popen('%s "%s"' % (editor, filename),
env=environ, shell=True)
c = subprocess.Popen(f'{editor} "{filename}"', env=environ, shell=True)
exit_code = c.wait()
if exit_code != 0:
raise ClickException('%s: Editing failed!' % editor)
raise ClickException(
_("{editor}: Editing failed").format(editor=editor)
)
except OSError as e:
raise ClickException('%s: Editing failed: %s' % (editor, e))
raise ClickException(
_("{editor}: Editing failed: {e}").format(editor=editor, e=e)
) from e
def edit(self, text):
def edit(self, text: t.Optional[t.AnyStr]) -> t.Optional[t.AnyStr]:
import tempfile
text = text or ''
if text and not text.endswith('\n'):
text += '\n'
if not text:
data = b""
elif isinstance(text, (bytes, bytearray)):
data = text
else:
if text and not text.endswith("\n"):
text += "\n"
fd, name = tempfile.mkstemp(prefix='editor-', suffix=self.extension)
try:
if WIN:
encoding = 'utf-8-sig'
text = text.replace('\n', '\r\n')
data = text.replace("\n", "\r\n").encode("utf-8-sig")
else:
encoding = 'utf-8'
text = text.encode(encoding)
data = text.encode("utf-8")
f = os.fdopen(fd, 'wb')
f.write(text)
f.close()
fd, name = tempfile.mkstemp(prefix="editor-", suffix=self.extension)
f: t.BinaryIO
try:
with os.fdopen(fd, "wb") as f:
f.write(data)
# If the filesystem resolution is 1 second, like Mac OS
# 10.12 Extended, or 2 seconds, like FAT32, and the editor
# closes very fast, require_save can fail. Set the modified
# time to be 2 seconds in the past to work around this.
os.utime(name, (os.path.getatime(name), os.path.getmtime(name) - 2))
# Depending on the resolution, the exact value might not be
# recorded, so get the new recorded value.
timestamp = os.path.getmtime(name)
self.edit_file(name)
if self.require_save \
and os.path.getmtime(name) == timestamp:
if self.require_save and os.path.getmtime(name) == timestamp:
return None
f = open(name, 'rb')
try:
with open(name, "rb") as f:
rv = f.read()
finally:
f.close()
return rv.decode('utf-8-sig').replace('\r\n', '\n')
if isinstance(text, (bytes, bytearray)):
return rv
return rv.decode("utf-8-sig").replace("\r\n", "\n") # type: ignore
finally:
os.unlink(name)
def open_url(url, wait=False, locate=False):
def open_url(url: str, wait: bool = False, locate: bool = False) -> int:
import subprocess
def _unquote_file(url):
try:
import urllib
except ImportError:
import urllib
if url.startswith('file://'):
url = urllib.unquote(url[7:])
def _unquote_file(url: str) -> str:
from urllib.parse import unquote
if url.startswith("file://"):
url = unquote(url[7:])
return url
if sys.platform == 'darwin':
args = ['open']
if sys.platform == "darwin":
args = ["open"]
if wait:
args.append('-W')
args.append("-W")
if locate:
args.append('-R')
args.append("-R")
args.append(_unquote_file(url))
null = open('/dev/null', 'w')
null = open("/dev/null", "w")
try:
return subprocess.Popen(args, stderr=null).wait()
finally:
null.close()
elif WIN:
if locate:
url = _unquote_file(url)
args = 'explorer /select,"%s"' % _unquote_file(
url.replace('"', ''))
url = _unquote_file(url.replace('"', ""))
args = f'explorer /select,"{url}"'
else:
args = 'start %s "" "%s"' % (
wait and '/WAIT' or '', url.replace('"', ''))
url = url.replace('"', "")
wait_str = "/WAIT" if wait else ""
args = f'start {wait_str} "" "{url}"'
return os.system(args)
elif CYGWIN:
if locate:
url = _unquote_file(url)
args = 'cygstart "%s"' % (os.path.dirname(url).replace('"', ''))
url = os.path.dirname(_unquote_file(url).replace('"', ""))
args = f'cygstart "{url}"'
else:
args = 'cygstart %s "%s"' % (
wait and '-w' or '', url.replace('"', ''))
url = url.replace('"', "")
wait_str = "-w" if wait else ""
args = f'cygstart {wait_str} "{url}"'
return os.system(args)
try:
if locate:
url = os.path.dirname(_unquote_file(url)) or '.'
url = os.path.dirname(_unquote_file(url)) or "."
else:
url = _unquote_file(url)
c = subprocess.Popen(['xdg-open', url])
c = subprocess.Popen(["xdg-open", url])
if wait:
return c.wait()
return 0
except OSError:
if url.startswith(('http://', 'https://')) and not locate and not wait:
if url.startswith(("http://", "https://")) and not locate and not wait:
import webbrowser
webbrowser.open(url)
return 0
return 1
def _translate_ch_to_exc(ch):
if ch == u'\x03':
def _translate_ch_to_exc(ch: str) -> t.Optional[BaseException]:
if ch == "\x03":
raise KeyboardInterrupt()
if ch == u'\x04' and not WIN: # Unix-like, Ctrl+D
if ch == "\x04" and not WIN: # Unix-like, Ctrl+D
raise EOFError()
if ch == u'\x1a' and WIN: # Windows, Ctrl+Z
if ch == "\x1a" and WIN: # Windows, Ctrl+Z
raise EOFError()
return None
if WIN:
import msvcrt
@contextlib.contextmanager
def raw_terminal():
yield
def raw_terminal() -> t.Iterator[int]:
yield -1
def getchar(echo):
def getchar(echo: bool) -> str:
# The function `getch` will return a bytes object corresponding to
# the pressed character. Since Windows 10 build 1803, it will also
# return \x00 when called a second time after pressing a regular key.
@ -574,48 +658,60 @@ if WIN:
#
# Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
# is doing the right thing in more situations than with `getch`.
func: t.Callable[[], str]
if echo:
func = msvcrt.getwche
func = msvcrt.getwche # type: ignore
else:
func = msvcrt.getwch
func = msvcrt.getwch # type: ignore
rv = func()
if rv in (u'\x00', u'\xe0'):
if rv in ("\x00", "\xe0"):
# \x00 and \xe0 are control characters that indicate special key,
# see above.
rv += func()
_translate_ch_to_exc(rv)
return rv
else:
import tty
import termios
@contextlib.contextmanager
def raw_terminal():
def raw_terminal() -> t.Iterator[int]:
f: t.Optional[t.TextIO]
fd: int
if not isatty(sys.stdin):
f = open('/dev/tty')
f = open("/dev/tty")
fd = f.fileno()
else:
fd = sys.stdin.fileno()
f = None
try:
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(fd)
yield fd
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
sys.stdout.flush()
if f is not None:
f.close()
except termios.error:
pass
def getchar(echo):
def getchar(echo: bool) -> str:
with raw_terminal() as fd:
ch = os.read(fd, 32)
ch = ch.decode(get_best_encoding(sys.stdin), 'replace')
ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace")
if echo and isatty(sys.stdout):
sys.stdout.write(ch)
_translate_ch_to_exc(ch)
return ch

View file

@ -1,10 +1,16 @@
import textwrap
import typing as t
from contextlib import contextmanager
class TextWrapper(textwrap.TextWrapper):
def _handle_long_word(self, reversed_chunks, cur_line, cur_len, width):
def _handle_long_word(
self,
reversed_chunks: t.List[str],
cur_line: t.List[str],
cur_len: int,
width: int,
) -> None:
space_left = max(width - cur_len, 1)
if self.break_long_words:
@ -17,22 +23,27 @@ class TextWrapper(textwrap.TextWrapper):
cur_line.append(reversed_chunks.pop())
@contextmanager
def extra_indent(self, indent):
def extra_indent(self, indent: str) -> t.Iterator[None]:
old_initial_indent = self.initial_indent
old_subsequent_indent = self.subsequent_indent
self.initial_indent += indent
self.subsequent_indent += indent
try:
yield
finally:
self.initial_indent = old_initial_indent
self.subsequent_indent = old_subsequent_indent
def indent_only(self, text):
def indent_only(self, text: str) -> str:
rv = []
for idx, line in enumerate(text.splitlines()):
indent = self.initial_indent
if idx > 0:
indent = self.subsequent_indent
rv.append(indent + line)
return '\n'.join(rv)
rv.append(f"{indent}{line}")
return "\n".join(rv)

View file

@ -1,125 +0,0 @@
import os
import sys
import codecs
from ._compat import PY2
# If someone wants to vendor click, we want to ensure the
# correct package is discovered. Ideally we could use a
# relative import here but unfortunately Python does not
# support that.
click = sys.modules[__name__.rsplit('.', 1)[0]]
def _find_unicode_literals_frame():
import __future__
if not hasattr(sys, '_getframe'): # not all Python implementations have it
return 0
frm = sys._getframe(1)
idx = 1
while frm is not None:
if frm.f_globals.get('__name__', '').startswith('click.'):
frm = frm.f_back
idx += 1
elif frm.f_code.co_flags & __future__.unicode_literals.compiler_flag:
return idx
else:
break
return 0
def _check_for_unicode_literals():
if not __debug__:
return
if not PY2 or click.disable_unicode_literals_warning:
return
bad_frame = _find_unicode_literals_frame()
if bad_frame <= 0:
return
from warnings import warn
warn(Warning('Click detected the use of the unicode_literals '
'__future__ import. This is heavily discouraged '
'because it can introduce subtle bugs in your '
'code. You should instead use explicit u"" literals '
'for your unicode strings. For more information see '
'https://click.palletsprojects.com/python3/'),
stacklevel=bad_frame)
def _verify_python3_env():
"""Ensures that the environment is good for unicode on Python 3."""
if PY2:
return
try:
import locale
fs_enc = codecs.lookup(locale.getpreferredencoding()).name
except Exception:
fs_enc = 'ascii'
if fs_enc != 'ascii':
return
extra = ''
if os.name == 'posix':
import subprocess
try:
rv = subprocess.Popen(['locale', '-a'], stdout=subprocess.PIPE,
stderr=subprocess.PIPE).communicate()[0]
except OSError:
rv = b''
good_locales = set()
has_c_utf8 = False
# Make sure we're operating on text here.
if isinstance(rv, bytes):
rv = rv.decode('ascii', 'replace')
for line in rv.splitlines():
locale = line.strip()
if locale.lower().endswith(('.utf-8', '.utf8')):
good_locales.add(locale)
if locale.lower() in ('c.utf8', 'c.utf-8'):
has_c_utf8 = True
extra += '\n\n'
if not good_locales:
extra += (
'Additional information: on this system no suitable UTF-8\n'
'locales were discovered. This most likely requires resolving\n'
'by reconfiguring the locale system.'
)
elif has_c_utf8:
extra += (
'This system supports the C.UTF-8 locale which is recommended.\n'
'You might be able to resolve your issue by exporting the\n'
'following environment variables:\n\n'
' export LC_ALL=C.UTF-8\n'
' export LANG=C.UTF-8'
)
else:
extra += (
'This system lists a couple of UTF-8 supporting locales that\n'
'you can pick from. The following suitable locales were\n'
'discovered: %s'
) % ', '.join(sorted(good_locales))
bad_locale = None
for locale in os.environ.get('LC_ALL'), os.environ.get('LANG'):
if locale and locale.lower().endswith(('.utf-8', '.utf8')):
bad_locale = locale
if locale is not None:
break
if bad_locale is not None:
extra += (
'\n\nClick discovered that you exported a UTF-8 locale\n'
'but the locale system could not pick up from it because\n'
'it does not exist. The exported locale is "%s" but it\n'
'is not supported'
) % bad_locale
raise RuntimeError(
'Click will abort further execution because Python 3 was'
' configured to use ASCII as encoding for the environment.'
' Consult https://click.palletsprojects.com/en/7.x/python3/ for'
' mitigation steps.' + extra
)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This module is based on the excellent work by Adam Bartoš who
# provided a lot of what went into the implementation here in
# the discussion to issue1602 in the Python bug tracker.
@ -6,26 +5,32 @@
# There are some general differences in regards to how this works
# compared to the original patches as we do not need to patch
# the entire interpreter but just work in our little world of
# echo and prmopt.
# echo and prompt.
import io
import os
import sys
import zlib
import time
import ctypes
import msvcrt
from ._compat import _NonClosingTextIOWrapper, text_type, PY2
from ctypes import byref, POINTER, c_int, c_char, c_char_p, \
c_void_p, py_object, c_ssize_t, c_ulong, windll, WINFUNCTYPE
try:
from ctypes import pythonapi
PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
PyBuffer_Release = pythonapi.PyBuffer_Release
except ImportError:
pythonapi = None
from ctypes.wintypes import LPWSTR, LPCWSTR
import typing as t
from ctypes import byref
from ctypes import c_char
from ctypes import c_char_p
from ctypes import c_int
from ctypes import c_ssize_t
from ctypes import c_ulong
from ctypes import c_void_p
from ctypes import POINTER
from ctypes import py_object
from ctypes import Structure
from ctypes.wintypes import DWORD
from ctypes.wintypes import HANDLE
from ctypes.wintypes import LPCWSTR
from ctypes.wintypes import LPWSTR
from ._compat import _NonClosingTextIOWrapper
assert sys.platform == "win32"
import msvcrt # noqa: E402
from ctypes import windll # noqa: E402
from ctypes import WINFUNCTYPE # noqa: E402
c_ssize_p = POINTER(c_ssize_t)
@ -33,19 +38,18 @@ kernel32 = windll.kernel32
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetConsoleMode = kernel32.GetConsoleMode
GetLastError = kernel32.GetLastError
GetCommandLineW = WINFUNCTYPE(LPWSTR)(
('GetCommandLineW', windll.kernel32))
CommandLineToArgvW = WINFUNCTYPE(
POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
('CommandLineToArgvW', windll.shell32))
GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
("CommandLineToArgvW", windll.shell32)
)
LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))
STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)
PyBUF_SIMPLE = 0
PyBUF_WRITABLE = 1
@ -57,38 +61,40 @@ STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2
EOF = b'\x1a'
EOF = b"\x1a"
MAX_BYTES_WRITTEN = 32767
class Py_buffer(ctypes.Structure):
_fields_ = [
('buf', c_void_p),
('obj', py_object),
('len', c_ssize_t),
('itemsize', c_ssize_t),
('readonly', c_int),
('ndim', c_int),
('format', c_char_p),
('shape', c_ssize_p),
('strides', c_ssize_p),
('suboffsets', c_ssize_p),
('internal', c_void_p)
]
if PY2:
_fields_.insert(-1, ('smalltable', c_ssize_t * 2))
# On PyPy we cannot get buffers so our ability to operate here is
# serverly limited.
if pythonapi is None:
try:
from ctypes import pythonapi
except ImportError:
# On PyPy we cannot get buffers so our ability to operate here is
# severely limited.
get_buffer = None
else:
class Py_buffer(Structure):
_fields_ = [
("buf", c_void_p),
("obj", py_object),
("len", c_ssize_t),
("itemsize", c_ssize_t),
("readonly", c_int),
("ndim", c_int),
("format", c_char_p),
("shape", c_ssize_p),
("strides", c_ssize_p),
("suboffsets", c_ssize_p),
("internal", c_void_p),
]
PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
PyBuffer_Release = pythonapi.PyBuffer_Release
def get_buffer(obj, writable=False):
buf = Py_buffer()
flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
PyObject_GetBuffer(py_object(obj), byref(buf), flags)
try:
buffer_type = c_char * buf.len
return buffer_type.from_address(buf.buf)
@ -97,17 +103,15 @@ else:
class _WindowsConsoleRawIOBase(io.RawIOBase):
def __init__(self, handle):
self.handle = handle
def isatty(self):
io.RawIOBase.isatty(self)
super().isatty()
return True
class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
def readable(self):
return True
@ -116,20 +120,26 @@ class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
if not bytes_to_be_read:
return 0
elif bytes_to_be_read % 2:
raise ValueError('cannot read odd number of bytes from '
'UTF-16-LE encoded console')
raise ValueError(
"cannot read odd number of bytes from UTF-16-LE encoded console"
)
buffer = get_buffer(b, writable=True)
code_units_to_be_read = bytes_to_be_read // 2
code_units_read = c_ulong()
rv = ReadConsoleW(self.handle, buffer, code_units_to_be_read,
byref(code_units_read), None)
rv = ReadConsoleW(
HANDLE(self.handle),
buffer,
code_units_to_be_read,
byref(code_units_read),
None,
)
if GetLastError() == ERROR_OPERATION_ABORTED:
# wait for KeyboardInterrupt
time.sleep(0.1)
if not rv:
raise OSError('Windows error: %s' % GetLastError())
raise OSError(f"Windows error: {GetLastError()}")
if buffer[0] == EOF:
return 0
@ -137,27 +147,30 @@ class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
def writable(self):
return True
@staticmethod
def _get_error_message(errno):
if errno == ERROR_SUCCESS:
return 'ERROR_SUCCESS'
return "ERROR_SUCCESS"
elif errno == ERROR_NOT_ENOUGH_MEMORY:
return 'ERROR_NOT_ENOUGH_MEMORY'
return 'Windows error %s' % errno
return "ERROR_NOT_ENOUGH_MEMORY"
return f"Windows error {errno}"
def write(self, b):
bytes_to_be_written = len(b)
buf = get_buffer(b)
code_units_to_be_written = min(bytes_to_be_written,
MAX_BYTES_WRITTEN) // 2
code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
code_units_written = c_ulong()
WriteConsoleW(self.handle, buf, code_units_to_be_written,
byref(code_units_written), None)
WriteConsoleW(
HANDLE(self.handle),
buf,
code_units_to_be_written,
byref(code_units_written),
None,
)
bytes_written = 2 * code_units_written.value
if bytes_written == 0 and bytes_to_be_written > 0:
@ -165,18 +178,17 @@ class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
return bytes_written
class ConsoleStream(object):
def __init__(self, text_stream, byte_stream):
class ConsoleStream:
def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
self._text_stream = text_stream
self.buffer = byte_stream
@property
def name(self):
def name(self) -> str:
return self.buffer.name
def write(self, x):
if isinstance(x, text_type):
def write(self, x: t.AnyStr) -> int:
if isinstance(x, str):
return self._text_stream.write(x)
try:
self.flush()
@ -184,124 +196,84 @@ class ConsoleStream(object):
pass
return self.buffer.write(x)
def writelines(self, lines):
def writelines(self, lines: t.Iterable[t.AnyStr]) -> None:
for line in lines:
self.write(line)
def __getattr__(self, name):
def __getattr__(self, name: str) -> t.Any:
return getattr(self._text_stream, name)
def isatty(self):
def isatty(self) -> bool:
return self.buffer.isatty()
def __repr__(self):
return '<ConsoleStream name=%r encoding=%r>' % (
self.name,
self.encoding,
)
return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"
class WindowsChunkedWriter(object):
"""
Wraps a stream (such as stdout), acting as a transparent proxy for all
attribute access apart from method 'write()' which we wrap to write in
limited chunks due to a Windows limitation on binary console streams.
"""
def __init__(self, wrapped):
# double-underscore everything to prevent clashes with names of
# attributes on the wrapped stream object.
self.__wrapped = wrapped
def __getattr__(self, name):
return getattr(self.__wrapped, name)
def write(self, text):
total_to_write = len(text)
written = 0
while written < total_to_write:
to_write = min(total_to_write - written, MAX_BYTES_WRITTEN)
self.__wrapped.write(text[written:written+to_write])
written += to_write
_wrapped_std_streams = set()
def _wrap_std_stream(name):
# Python 2 & Windows 7 and below
if PY2 and sys.getwindowsversion()[:2] <= (6, 1) and name not in _wrapped_std_streams:
setattr(sys, name, WindowsChunkedWriter(getattr(sys, name)))
_wrapped_std_streams.add(name)
def _get_text_stdin(buffer_stream):
def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
text_stream = _NonClosingTextIOWrapper(
io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)),
'utf-16-le', 'strict', line_buffering=True)
return ConsoleStream(text_stream, buffer_stream)
"utf-16-le",
"strict",
line_buffering=True,
)
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
def _get_text_stdout(buffer_stream):
def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
text_stream = _NonClosingTextIOWrapper(
io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)),
'utf-16-le', 'strict', line_buffering=True)
return ConsoleStream(text_stream, buffer_stream)
"utf-16-le",
"strict",
line_buffering=True,
)
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
def _get_text_stderr(buffer_stream):
def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
text_stream = _NonClosingTextIOWrapper(
io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)),
'utf-16-le', 'strict', line_buffering=True)
return ConsoleStream(text_stream, buffer_stream)
"utf-16-le",
"strict",
line_buffering=True,
)
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
if PY2:
def _hash_py_argv():
return zlib.crc32('\x00'.join(sys.argv[1:]))
_initial_argv_hash = _hash_py_argv()
def _get_windows_argv():
argc = c_int(0)
argv_unicode = CommandLineToArgvW(GetCommandLineW(), byref(argc))
argv = [argv_unicode[i] for i in range(0, argc.value)]
if not hasattr(sys, 'frozen'):
argv = argv[1:]
while len(argv) > 0:
arg = argv[0]
if not arg.startswith('-') or arg == '-':
break
argv = argv[1:]
if arg.startswith(('-c', '-m')):
break
return argv[1:]
_stream_factories = {
_stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
0: _get_text_stdin,
1: _get_text_stdout,
2: _get_text_stderr,
}
def _get_windows_console_stream(f, encoding, errors):
if get_buffer is not None and \
encoding in ('utf-16-le', None) \
and errors in ('strict', None) and \
hasattr(f, 'isatty') and f.isatty():
def _is_console(f: t.TextIO) -> bool:
if not hasattr(f, "fileno"):
return False
try:
fileno = f.fileno()
except (OSError, io.UnsupportedOperation):
return False
handle = msvcrt.get_osfhandle(fileno)
return bool(GetConsoleMode(handle, byref(DWORD())))
def _get_windows_console_stream(
f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
) -> t.Optional[t.TextIO]:
if (
get_buffer is not None
and encoding in {"utf-16-le", None}
and errors in {"strict", None}
and _is_console(f)
):
func = _stream_factories.get(f.fileno())
if func is not None:
if not PY2:
f = getattr(f, 'buffer', None)
if f is None:
return None
else:
# If we are on Python 2 we need to set the stream that we
# deal with to binary mode as otherwise the exercise if a
# bit moot. The same problems apply as for
# get_binary_stdin and friends from _compat.
msvcrt.setmode(f.fileno(), os.O_BINARY)
return func(f)
b = getattr(f, "buffer", None)
if b is None:
return None
return func(b)

File diff suppressed because it is too large Load diff

View file

@ -1,34 +1,48 @@
import sys
import inspect
import types
import typing as t
from functools import update_wrapper
from gettext import gettext as _
from ._compat import iteritems
from ._unicodefun import _check_for_unicode_literals
from .utils import echo
from .core import Argument
from .core import Command
from .core import Context
from .core import Group
from .core import Option
from .core import Parameter
from .globals import get_current_context
from .utils import echo
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
FC = t.TypeVar("FC", bound=t.Union[t.Callable[..., t.Any], Command])
def pass_context(f):
def pass_context(f: F) -> F:
"""Marks a callback as wanting to receive the current context
object as first argument.
"""
def new_func(*args, **kwargs):
def new_func(*args, **kwargs): # type: ignore
return f(get_current_context(), *args, **kwargs)
return update_wrapper(new_func, f)
return update_wrapper(t.cast(F, new_func), f)
def pass_obj(f):
def pass_obj(f: F) -> F:
"""Similar to :func:`pass_context`, but only pass the object on the
context onwards (:attr:`Context.obj`). This is useful if that object
represents the state of a nested system.
"""
def new_func(*args, **kwargs):
def new_func(*args, **kwargs): # type: ignore
return f(get_current_context().obj, *args, **kwargs)
return update_wrapper(new_func, f)
return update_wrapper(t.cast(F, new_func), f)
def make_pass_decorator(object_type, ensure=False):
def make_pass_decorator(
object_type: t.Type, ensure: bool = False
) -> "t.Callable[[F], F]":
"""Given an object type this creates a decorator that will work
similar to :func:`pass_obj` but instead of passing the object of the
current context, it will find the innermost context of type
@ -50,55 +64,106 @@ def make_pass_decorator(object_type, ensure=False):
:param ensure: if set to `True`, a new object will be created and
remembered on the context if it's not there yet.
"""
def decorator(f):
def new_func(*args, **kwargs):
def decorator(f: F) -> F:
def new_func(*args, **kwargs): # type: ignore
ctx = get_current_context()
if ensure:
obj = ctx.ensure_object(object_type)
else:
obj = ctx.find_object(object_type)
if obj is None:
raise RuntimeError('Managed to invoke callback without a '
'context object of type %r existing'
% object_type.__name__)
raise RuntimeError(
"Managed to invoke callback without a context"
f" object of type {object_type.__name__!r}"
" existing."
)
return ctx.invoke(f, obj, *args, **kwargs)
return update_wrapper(new_func, f)
return update_wrapper(t.cast(F, new_func), f)
return decorator
def _make_command(f, name, attrs, cls):
if isinstance(f, Command):
raise TypeError('Attempted to convert a callback into a '
'command twice.')
try:
params = f.__click_params__
params.reverse()
del f.__click_params__
except AttributeError:
params = []
help = attrs.get('help')
if help is None:
help = inspect.getdoc(f)
if isinstance(help, bytes):
help = help.decode('utf-8')
else:
help = inspect.cleandoc(help)
attrs['help'] = help
_check_for_unicode_literals()
return cls(name=name or f.__name__.lower().replace('_', '-'),
callback=f, params=params, **attrs)
def pass_meta_key(
key: str, *, doc_description: t.Optional[str] = None
) -> "t.Callable[[F], F]":
"""Create a decorator that passes a key from
:attr:`click.Context.meta` as the first argument to the decorated
function.
:param key: Key in ``Context.meta`` to pass.
:param doc_description: Description of the object being passed,
inserted into the decorator's docstring. Defaults to "the 'key'
key from Context.meta".
.. versionadded:: 8.0
"""
def decorator(f: F) -> F:
def new_func(*args, **kwargs): # type: ignore
ctx = get_current_context()
obj = ctx.meta[key]
return ctx.invoke(f, obj, *args, **kwargs)
return update_wrapper(t.cast(F, new_func), f)
if doc_description is None:
doc_description = f"the {key!r} key from :attr:`click.Context.meta`"
decorator.__doc__ = (
f"Decorator that passes {doc_description} as the first argument"
" to the decorated function."
)
return decorator
def command(name=None, cls=None, **attrs):
CmdType = t.TypeVar("CmdType", bound=Command)
@t.overload
def command(
__func: t.Callable[..., t.Any],
) -> Command:
...
@t.overload
def command(
name: t.Optional[str] = None,
**attrs: t.Any,
) -> t.Callable[..., Command]:
...
@t.overload
def command(
name: t.Optional[str] = None,
cls: t.Type[CmdType] = ...,
**attrs: t.Any,
) -> t.Callable[..., CmdType]:
...
def command(
name: t.Union[str, t.Callable[..., t.Any], None] = None,
cls: t.Optional[t.Type[Command]] = None,
**attrs: t.Any,
) -> t.Union[Command, t.Callable[..., Command]]:
r"""Creates a new :class:`Command` and uses the decorated function as
callback. This will also automatically attach all decorated
:func:`option`\s and :func:`argument`\s as parameters to the command.
The name of the command defaults to the name of the function. If you
want to change that, you can pass the intended name as the first
argument.
The name of the command defaults to the name of the function with
underscores replaced by dashes. If you want to change that, you can
pass the intended name as the first argument.
All keyword arguments are forwarded to the underlying command class.
For the ``params`` argument, any decorated params are appended to
the end of the list.
Once decorated the function turns into a :class:`Command` instance
that can be invoked as a command line utility or be attached to a
@ -108,35 +173,105 @@ def command(name=None, cls=None, **attrs):
name with underscores replaced by dashes.
:param cls: the command class to instantiate. This defaults to
:class:`Command`.
.. versionchanged:: 8.1
This decorator can be applied without parentheses.
.. versionchanged:: 8.1
The ``params`` argument can be used. Decorated params are
appended to the end of the list.
"""
func: t.Optional[t.Callable[..., t.Any]] = None
if callable(name):
func = name
name = None
assert cls is None, "Use 'command(cls=cls)(callable)' to specify a class."
assert not attrs, "Use 'command(**kwargs)(callable)' to provide arguments."
if cls is None:
cls = Command
def decorator(f):
cmd = _make_command(f, name, attrs, cls)
def decorator(f: t.Callable[..., t.Any]) -> Command:
if isinstance(f, Command):
raise TypeError("Attempted to convert a callback into a command twice.")
attr_params = attrs.pop("params", None)
params = attr_params if attr_params is not None else []
try:
decorator_params = f.__click_params__ # type: ignore
except AttributeError:
pass
else:
del f.__click_params__ # type: ignore
params.extend(reversed(decorator_params))
if attrs.get("help") is None:
attrs["help"] = f.__doc__
cmd = cls( # type: ignore[misc]
name=name or f.__name__.lower().replace("_", "-"), # type: ignore[arg-type]
callback=f,
params=params,
**attrs,
)
cmd.__doc__ = f.__doc__
return cmd
if func is not None:
return decorator(func)
return decorator
def group(name=None, **attrs):
@t.overload
def group(
__func: t.Callable[..., t.Any],
) -> Group:
...
@t.overload
def group(
name: t.Optional[str] = None,
**attrs: t.Any,
) -> t.Callable[[F], Group]:
...
def group(
name: t.Union[str, t.Callable[..., t.Any], None] = None, **attrs: t.Any
) -> t.Union[Group, t.Callable[[F], Group]]:
"""Creates a new :class:`Group` with a function as callback. This
works otherwise the same as :func:`command` just that the `cls`
parameter is set to :class:`Group`.
.. versionchanged:: 8.1
This decorator can be applied without parentheses.
"""
attrs.setdefault('cls', Group)
return command(name, **attrs)
if attrs.get("cls") is None:
attrs["cls"] = Group
if callable(name):
grp: t.Callable[[F], Group] = t.cast(Group, command(**attrs))
return grp(name)
return t.cast(Group, command(name, **attrs))
def _param_memo(f, param):
def _param_memo(f: FC, param: Parameter) -> None:
if isinstance(f, Command):
f.params.append(param)
else:
if not hasattr(f, '__click_params__'):
f.__click_params__ = []
f.__click_params__.append(param)
if not hasattr(f, "__click_params__"):
f.__click_params__ = [] # type: ignore
f.__click_params__.append(param) # type: ignore
def argument(*param_decls, **attrs):
def argument(*param_decls: str, **attrs: t.Any) -> t.Callable[[FC], FC]:
"""Attaches an argument to the command. All positional arguments are
passed as parameter declarations to :class:`Argument`; all keyword
arguments are forwarded unchanged (except ``cls``).
@ -146,14 +281,16 @@ def argument(*param_decls, **attrs):
:param cls: the argument class to instantiate. This defaults to
:class:`Argument`.
"""
def decorator(f):
ArgumentClass = attrs.pop('cls', Argument)
def decorator(f: FC) -> FC:
ArgumentClass = attrs.pop("cls", None) or Argument
_param_memo(f, ArgumentClass(param_decls, **attrs))
return f
return decorator
def option(*param_decls, **attrs):
def option(*param_decls: str, **attrs: t.Any) -> t.Callable[[FC], FC]:
"""Attaches an option to the command. All positional arguments are
passed as parameter declarations to :class:`Option`; all keyword
arguments are forwarded unchanged (except ``cls``).
@ -163,149 +300,198 @@ def option(*param_decls, **attrs):
:param cls: the option class to instantiate. This defaults to
:class:`Option`.
"""
def decorator(f):
def decorator(f: FC) -> FC:
# Issue 926, copy attrs, so pre-defined options can re-use the same cls=
option_attrs = attrs.copy()
if 'help' in option_attrs:
option_attrs['help'] = inspect.cleandoc(option_attrs['help'])
OptionClass = option_attrs.pop('cls', Option)
OptionClass = option_attrs.pop("cls", None) or Option
_param_memo(f, OptionClass(param_decls, **option_attrs))
return f
return decorator
def confirmation_option(*param_decls, **attrs):
"""Shortcut for confirmation prompts that can be ignored by passing
``--yes`` as parameter.
def confirmation_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
"""Add a ``--yes`` option which shows a prompt before continuing if
not passed. If the prompt is declined, the program will exit.
This is equivalent to decorating a function with :func:`option` with
the following parameters::
def callback(ctx, param, value):
if not value:
ctx.abort()
@click.command()
@click.option('--yes', is_flag=True, callback=callback,
expose_value=False, prompt='Do you want to continue?')
def dropdb():
pass
:param param_decls: One or more option names. Defaults to the single
value ``"--yes"``.
:param kwargs: Extra arguments are passed to :func:`option`.
"""
def decorator(f):
def callback(ctx, param, value):
if not value:
ctx.abort()
attrs.setdefault('is_flag', True)
attrs.setdefault('callback', callback)
attrs.setdefault('expose_value', False)
attrs.setdefault('prompt', 'Do you want to continue?')
attrs.setdefault('help', 'Confirm the action without prompting.')
return option(*(param_decls or ('--yes',)), **attrs)(f)
return decorator
def callback(ctx: Context, param: Parameter, value: bool) -> None:
if not value:
ctx.abort()
if not param_decls:
param_decls = ("--yes",)
kwargs.setdefault("is_flag", True)
kwargs.setdefault("callback", callback)
kwargs.setdefault("expose_value", False)
kwargs.setdefault("prompt", "Do you want to continue?")
kwargs.setdefault("help", "Confirm the action without prompting.")
return option(*param_decls, **kwargs)
def password_option(*param_decls, **attrs):
"""Shortcut for password prompts.
def password_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
"""Add a ``--password`` option which prompts for a password, hiding
input and asking to enter the value again for confirmation.
This is equivalent to decorating a function with :func:`option` with
the following parameters::
@click.command()
@click.option('--password', prompt=True, confirmation_prompt=True,
hide_input=True)
def changeadmin(password):
pass
:param param_decls: One or more option names. Defaults to the single
value ``"--password"``.
:param kwargs: Extra arguments are passed to :func:`option`.
"""
def decorator(f):
attrs.setdefault('prompt', True)
attrs.setdefault('confirmation_prompt', True)
attrs.setdefault('hide_input', True)
return option(*(param_decls or ('--password',)), **attrs)(f)
return decorator
if not param_decls:
param_decls = ("--password",)
kwargs.setdefault("prompt", True)
kwargs.setdefault("confirmation_prompt", True)
kwargs.setdefault("hide_input", True)
return option(*param_decls, **kwargs)
def version_option(version=None, *param_decls, **attrs):
"""Adds a ``--version`` option which immediately ends the program
printing out the version number. This is implemented as an eager
option that prints the version and exits the program in the callback.
def version_option(
version: t.Optional[str] = None,
*param_decls: str,
package_name: t.Optional[str] = None,
prog_name: t.Optional[str] = None,
message: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Callable[[FC], FC]:
"""Add a ``--version`` option which immediately prints the version
number and exits the program.
:param version: the version number to show. If not provided Click
attempts an auto discovery via setuptools.
:param prog_name: the name of the program (defaults to autodetection)
:param message: custom message to show instead of the default
(``'%(prog)s, version %(version)s'``)
:param others: everything else is forwarded to :func:`option`.
If ``version`` is not provided, Click will try to detect it using
:func:`importlib.metadata.version` to get the version for the
``package_name``. On Python < 3.8, the ``importlib_metadata``
backport must be installed.
If ``package_name`` is not provided, Click will try to detect it by
inspecting the stack frames. This will be used to detect the
version, so it must match the name of the installed package.
:param version: The version number to show. If not provided, Click
will try to detect it.
:param param_decls: One or more option names. Defaults to the single
value ``"--version"``.
:param package_name: The package name to detect the version from. If
not provided, Click will try to detect it.
:param prog_name: The name of the CLI to show in the message. If not
provided, it will be detected from the command.
:param message: The message to show. The values ``%(prog)s``,
``%(package)s``, and ``%(version)s`` are available. Defaults to
``"%(prog)s, version %(version)s"``.
:param kwargs: Extra arguments are passed to :func:`option`.
:raise RuntimeError: ``version`` could not be detected.
.. versionchanged:: 8.0
Add the ``package_name`` parameter, and the ``%(package)s``
value for messages.
.. versionchanged:: 8.0
Use :mod:`importlib.metadata` instead of ``pkg_resources``. The
version is detected based on the package name, not the entry
point name. The Python package name must match the installed
package name, or be passed with ``package_name=``.
"""
if version is None:
if hasattr(sys, '_getframe'):
module = sys._getframe(1).f_globals.get('__name__')
else:
module = ''
if message is None:
message = _("%(prog)s, version %(version)s")
def decorator(f):
prog_name = attrs.pop('prog_name', None)
message = attrs.pop('message', '%(prog)s, version %(version)s')
if version is None and package_name is None:
frame = inspect.currentframe()
f_back = frame.f_back if frame is not None else None
f_globals = f_back.f_globals if f_back is not None else None
# break reference cycle
# https://docs.python.org/3/library/inspect.html#the-interpreter-stack
del frame
def callback(ctx, param, value):
if not value or ctx.resilient_parsing:
return
prog = prog_name
if prog is None:
prog = ctx.find_root().info_name
ver = version
if ver is None:
try:
import pkg_resources
except ImportError:
pass
else:
for dist in pkg_resources.working_set:
scripts = dist.get_entry_map().get('console_scripts') or {}
for script_name, entry_point in iteritems(scripts):
if entry_point.module_name == module:
ver = dist.version
break
if ver is None:
raise RuntimeError('Could not determine version')
echo(message % {
'prog': prog,
'version': ver,
}, color=ctx.color)
ctx.exit()
if f_globals is not None:
package_name = f_globals.get("__name__")
attrs.setdefault('is_flag', True)
attrs.setdefault('expose_value', False)
attrs.setdefault('is_eager', True)
attrs.setdefault('help', 'Show the version and exit.')
attrs['callback'] = callback
return option(*(param_decls or ('--version',)), **attrs)(f)
return decorator
if package_name == "__main__":
package_name = f_globals.get("__package__")
if package_name:
package_name = package_name.partition(".")[0]
def callback(ctx: Context, param: Parameter, value: bool) -> None:
if not value or ctx.resilient_parsing:
return
nonlocal prog_name
nonlocal version
if prog_name is None:
prog_name = ctx.find_root().info_name
if version is None and package_name is not None:
metadata: t.Optional[types.ModuleType]
try:
from importlib import metadata # type: ignore
except ImportError:
# Python < 3.8
import importlib_metadata as metadata # type: ignore
try:
version = metadata.version(package_name) # type: ignore
except metadata.PackageNotFoundError: # type: ignore
raise RuntimeError(
f"{package_name!r} is not installed. Try passing"
" 'package_name' instead."
) from None
if version is None:
raise RuntimeError(
f"Could not determine the version for {package_name!r} automatically."
)
echo(
t.cast(str, message)
% {"prog": prog_name, "package": package_name, "version": version},
color=ctx.color,
)
ctx.exit()
if not param_decls:
param_decls = ("--version",)
kwargs.setdefault("is_flag", True)
kwargs.setdefault("expose_value", False)
kwargs.setdefault("is_eager", True)
kwargs.setdefault("help", _("Show the version and exit."))
kwargs["callback"] = callback
return option(*param_decls, **kwargs)
def help_option(*param_decls, **attrs):
"""Adds a ``--help`` option which immediately ends the program
printing out the help page. This is usually unnecessary to add as
this is added by default to all commands unless suppressed.
def help_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
"""Add a ``--help`` option which immediately prints the help page
and exits the program.
Like :func:`version_option`, this is implemented as eager option that
prints in the callback and exits.
This is usually unnecessary, as the ``--help`` option is added to
each command automatically unless ``add_help_option=False`` is
passed.
All arguments are forwarded to :func:`option`.
:param param_decls: One or more option names. Defaults to the single
value ``"--help"``.
:param kwargs: Extra arguments are passed to :func:`option`.
"""
def decorator(f):
def callback(ctx, param, value):
if value and not ctx.resilient_parsing:
echo(ctx.get_help(), color=ctx.color)
ctx.exit()
attrs.setdefault('is_flag', True)
attrs.setdefault('expose_value', False)
attrs.setdefault('help', 'Show this message and exit.')
attrs.setdefault('is_eager', True)
attrs['callback'] = callback
return option(*(param_decls or ('--help',)), **attrs)(f)
return decorator
def callback(ctx: Context, param: Parameter, value: bool) -> None:
if not value or ctx.resilient_parsing:
return
# Circular dependencies between core and decorators
from .core import Command, Group, Argument, Option
echo(ctx.get_help(), color=ctx.color)
ctx.exit()
if not param_decls:
param_decls = ("--help",)
kwargs.setdefault("is_flag", True)
kwargs.setdefault("expose_value", False)
kwargs.setdefault("is_eager", True)
kwargs.setdefault("help", _("Show this message and exit."))
kwargs["callback"] = callback
return option(*param_decls, **kwargs)

View file

@ -1,43 +1,46 @@
from ._compat import PY2, filename_to_ui, get_text_stderr
import os
import typing as t
from gettext import gettext as _
from gettext import ngettext
from ._compat import get_text_stderr
from .utils import echo
if t.TYPE_CHECKING:
from .core import Context
from .core import Parameter
def _join_param_hints(
param_hint: t.Optional[t.Union[t.Sequence[str], str]]
) -> t.Optional[str]:
if param_hint is not None and not isinstance(param_hint, str):
return " / ".join(repr(x) for x in param_hint)
def _join_param_hints(param_hint):
if isinstance(param_hint, (tuple, list)):
return ' / '.join('"%s"' % x for x in param_hint)
return param_hint
class ClickException(Exception):
"""An exception that Click can handle and show to the user."""
#: The exit code for this exception
#: The exit code for this exception.
exit_code = 1
def __init__(self, message):
ctor_msg = message
if PY2:
if ctor_msg is not None:
ctor_msg = ctor_msg.encode('utf-8')
Exception.__init__(self, ctor_msg)
def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message
def format_message(self):
def format_message(self) -> str:
return self.message
def __str__(self):
def __str__(self) -> str:
return self.message
if PY2:
__unicode__ = __str__
def __str__(self):
return self.message.encode('utf-8')
def show(self, file=None):
def show(self, file: t.Optional[t.IO] = None) -> None:
if file is None:
file = get_text_stderr()
echo('Error: %s' % self.format_message(), file=file)
echo(_("Error: {message}").format(message=self.format_message()), file=file)
class UsageError(ClickException):
@ -48,26 +51,35 @@ class UsageError(ClickException):
:param ctx: optionally the context that caused this error. Click will
fill in the context automatically in some situations.
"""
exit_code = 2
def __init__(self, message, ctx=None):
ClickException.__init__(self, message)
def __init__(self, message: str, ctx: t.Optional["Context"] = None) -> None:
super().__init__(message)
self.ctx = ctx
self.cmd = self.ctx and self.ctx.command or None
self.cmd = self.ctx.command if self.ctx else None
def show(self, file=None):
def show(self, file: t.Optional[t.IO] = None) -> None:
if file is None:
file = get_text_stderr()
color = None
hint = ''
if (self.cmd is not None and
self.cmd.get_help_option(self.ctx) is not None):
hint = ('Try "%s %s" for help.\n'
% (self.ctx.command_path, self.ctx.help_option_names[0]))
hint = ""
if (
self.ctx is not None
and self.ctx.command.get_help_option(self.ctx) is not None
):
hint = _("Try '{command} {option}' for help.").format(
command=self.ctx.command_path, option=self.ctx.help_option_names[0]
)
hint = f"{hint}\n"
if self.ctx is not None:
color = self.ctx.color
echo(self.ctx.get_usage() + '\n%s' % hint, file=file, color=color)
echo('Error: %s' % self.format_message(), file=file, color=color)
echo(f"{self.ctx.get_usage()}\n{hint}", file=file, color=color)
echo(
_("Error: {message}").format(message=self.format_message()),
file=file,
color=color,
)
class BadParameter(UsageError):
@ -88,22 +100,28 @@ class BadParameter(UsageError):
each item is quoted and separated.
"""
def __init__(self, message, ctx=None, param=None,
param_hint=None):
UsageError.__init__(self, message, ctx)
def __init__(
self,
message: str,
ctx: t.Optional["Context"] = None,
param: t.Optional["Parameter"] = None,
param_hint: t.Optional[str] = None,
) -> None:
super().__init__(message, ctx)
self.param = param
self.param_hint = param_hint
def format_message(self):
def format_message(self) -> str:
if self.param_hint is not None:
param_hint = self.param_hint
elif self.param is not None:
param_hint = self.param.get_error_hint(self.ctx)
param_hint = self.param.get_error_hint(self.ctx) # type: ignore
else:
return 'Invalid value: %s' % self.message
param_hint = _join_param_hints(param_hint)
return _("Invalid value: {message}").format(message=self.message)
return 'Invalid value for %s: %s' % (param_hint, self.message)
return _("Invalid value for {param_hint}: {message}").format(
param_hint=_join_param_hints(param_hint), message=self.message
)
class MissingParameter(BadParameter):
@ -118,19 +136,27 @@ class MissingParameter(BadParameter):
``'option'`` or ``'argument'``.
"""
def __init__(self, message=None, ctx=None, param=None,
param_hint=None, param_type=None):
BadParameter.__init__(self, message, ctx, param, param_hint)
def __init__(
self,
message: t.Optional[str] = None,
ctx: t.Optional["Context"] = None,
param: t.Optional["Parameter"] = None,
param_hint: t.Optional[str] = None,
param_type: t.Optional[str] = None,
) -> None:
super().__init__(message or "", ctx, param, param_hint)
self.param_type = param_type
def format_message(self):
def format_message(self) -> str:
if self.param_hint is not None:
param_hint = self.param_hint
param_hint: t.Optional[str] = self.param_hint
elif self.param is not None:
param_hint = self.param.get_error_hint(self.ctx)
param_hint = self.param.get_error_hint(self.ctx) # type: ignore
else:
param_hint = None
param_hint = _join_param_hints(param_hint)
param_hint = f" {param_hint}" if param_hint else ""
param_type = self.param_type
if param_type is None and self.param is not None:
@ -141,16 +167,30 @@ class MissingParameter(BadParameter):
msg_extra = self.param.type.get_missing_message(self.param)
if msg_extra:
if msg:
msg += '. ' + msg_extra
msg += f". {msg_extra}"
else:
msg = msg_extra
return 'Missing %s%s%s%s' % (
param_type,
param_hint and ' %s' % param_hint or '',
msg and '. ' or '.',
msg or '',
)
msg = f" {msg}" if msg else ""
# Translate param_type for known types.
if param_type == "argument":
missing = _("Missing argument")
elif param_type == "option":
missing = _("Missing option")
elif param_type == "parameter":
missing = _("Missing parameter")
else:
missing = _("Missing {param_type}").format(param_type=param_type)
return f"{missing}{param_hint}.{msg}"
def __str__(self) -> str:
if not self.message:
param_name = self.param.name if self.param else None
return _("Missing parameter: {param_name}").format(param_name=param_name)
else:
return self.message
class NoSuchOption(UsageError):
@ -160,23 +200,31 @@ class NoSuchOption(UsageError):
.. versionadded:: 4.0
"""
def __init__(self, option_name, message=None, possibilities=None,
ctx=None):
def __init__(
self,
option_name: str,
message: t.Optional[str] = None,
possibilities: t.Optional[t.Sequence[str]] = None,
ctx: t.Optional["Context"] = None,
) -> None:
if message is None:
message = 'no such option: %s' % option_name
UsageError.__init__(self, message, ctx)
message = _("No such option: {name}").format(name=option_name)
super().__init__(message, ctx)
self.option_name = option_name
self.possibilities = possibilities
def format_message(self):
bits = [self.message]
if self.possibilities:
if len(self.possibilities) == 1:
bits.append('Did you mean %s?' % self.possibilities[0])
else:
possibilities = sorted(self.possibilities)
bits.append('(Possible options: %s)' % ', '.join(possibilities))
return ' '.join(bits)
def format_message(self) -> str:
if not self.possibilities:
return self.message
possibility_str = ", ".join(sorted(self.possibilities))
suggest = ngettext(
"Did you mean {possibility}?",
"(Possible options: {possibilities})",
len(self.possibilities),
).format(possibility=possibility_str, possibilities=possibility_str)
return f"{self.message} {suggest}"
class BadOptionUsage(UsageError):
@ -189,8 +237,10 @@ class BadOptionUsage(UsageError):
:param option_name: the name of the option being used incorrectly.
"""
def __init__(self, option_name, message, ctx=None):
UsageError.__init__(self, message, ctx)
def __init__(
self, option_name: str, message: str, ctx: t.Optional["Context"] = None
) -> None:
super().__init__(message, ctx)
self.option_name = option_name
@ -202,23 +252,22 @@ class BadArgumentUsage(UsageError):
.. versionadded:: 6.0
"""
def __init__(self, message, ctx=None):
UsageError.__init__(self, message, ctx)
class FileError(ClickException):
"""Raised if a file cannot be opened."""
def __init__(self, filename, hint=None):
ui_filename = filename_to_ui(filename)
def __init__(self, filename: str, hint: t.Optional[str] = None) -> None:
if hint is None:
hint = 'unknown error'
ClickException.__init__(self, hint)
self.ui_filename = ui_filename
hint = _("unknown error")
super().__init__(hint)
self.ui_filename = os.fsdecode(filename)
self.filename = filename
def format_message(self):
return 'Could not open file %s: %s' % (self.ui_filename, self.message)
def format_message(self) -> str:
return _("Could not open file {filename!r}: {message}").format(
filename=self.ui_filename, message=self.message
)
class Abort(RuntimeError):
@ -231,5 +280,8 @@ class Exit(RuntimeError):
:param code: the status code to exit with.
"""
def __init__(self, code=0):
__slots__ = ("exit_code",)
def __init__(self, code: int = 0) -> None:
self.exit_code = code

View file

@ -1,29 +1,38 @@
import typing as t
from contextlib import contextmanager
from .termui import get_terminal_size
from .parser import split_opt
from ._compat import term_len
from gettext import gettext as _
from ._compat import term_len
from .parser import split_opt
# Can force a width. This is used by the test system
FORCED_WIDTH = None
FORCED_WIDTH: t.Optional[int] = None
def measure_table(rows):
widths = {}
def measure_table(rows: t.Iterable[t.Tuple[str, str]]) -> t.Tuple[int, ...]:
widths: t.Dict[int, int] = {}
for row in rows:
for idx, col in enumerate(row):
widths[idx] = max(widths.get(idx, 0), term_len(col))
return tuple(y for x, y in sorted(widths.items()))
def iter_rows(rows, col_count):
def iter_rows(
rows: t.Iterable[t.Tuple[str, str]], col_count: int
) -> t.Iterator[t.Tuple[str, ...]]:
for row in rows:
row = tuple(row)
yield row + ('',) * (col_count - len(row))
yield row + ("",) * (col_count - len(row))
def wrap_text(text, width=78, initial_indent='', subsequent_indent='',
preserve_paragraphs=False):
def wrap_text(
text: str,
width: int = 78,
initial_indent: str = "",
subsequent_indent: str = "",
preserve_paragraphs: bool = False,
) -> str:
"""A helper function that intelligently wraps text. By default, it
assumes that it operates on a single paragraph of text but if the
`preserve_paragraphs` parameter is provided it will intelligently
@ -43,24 +52,28 @@ def wrap_text(text, width=78, initial_indent='', subsequent_indent='',
intelligently handle paragraphs.
"""
from ._textwrap import TextWrapper
text = text.expandtabs()
wrapper = TextWrapper(width, initial_indent=initial_indent,
subsequent_indent=subsequent_indent,
replace_whitespace=False)
wrapper = TextWrapper(
width,
initial_indent=initial_indent,
subsequent_indent=subsequent_indent,
replace_whitespace=False,
)
if not preserve_paragraphs:
return wrapper.fill(text)
p = []
buf = []
p: t.List[t.Tuple[int, bool, str]] = []
buf: t.List[str] = []
indent = None
def _flush_par():
def _flush_par() -> None:
if not buf:
return
if buf[0].strip() == '\b':
p.append((indent or 0, True, '\n'.join(buf[1:])))
if buf[0].strip() == "\b":
p.append((indent or 0, True, "\n".join(buf[1:])))
else:
p.append((indent or 0, False, ' '.join(buf)))
p.append((indent or 0, False, " ".join(buf)))
del buf[:]
for line in text.splitlines():
@ -77,16 +90,16 @@ def wrap_text(text, width=78, initial_indent='', subsequent_indent='',
rv = []
for indent, raw, text in p:
with wrapper.extra_indent(' ' * indent):
with wrapper.extra_indent(" " * indent):
if raw:
rv.append(wrapper.indent_only(text))
else:
rv.append(wrapper.fill(text))
return '\n\n'.join(rv)
return "\n\n".join(rv)
class HelpFormatter(object):
class HelpFormatter:
"""This class helps with formatting text-based help pages. It's
usually just needed for very special internal cases, but it's also
exposed so that developers can write their own fancy outputs.
@ -98,79 +111,108 @@ class HelpFormatter(object):
width clamped to a maximum of 78.
"""
def __init__(self, indent_increment=2, width=None, max_width=None):
def __init__(
self,
indent_increment: int = 2,
width: t.Optional[int] = None,
max_width: t.Optional[int] = None,
) -> None:
import shutil
self.indent_increment = indent_increment
if max_width is None:
max_width = 80
if width is None:
width = FORCED_WIDTH
if width is None:
width = max(min(get_terminal_size()[0], max_width) - 2, 50)
width = max(min(shutil.get_terminal_size().columns, max_width) - 2, 50)
self.width = width
self.current_indent = 0
self.buffer = []
self.buffer: t.List[str] = []
def write(self, string):
def write(self, string: str) -> None:
"""Writes a unicode string into the internal buffer."""
self.buffer.append(string)
def indent(self):
def indent(self) -> None:
"""Increases the indentation."""
self.current_indent += self.indent_increment
def dedent(self):
def dedent(self) -> None:
"""Decreases the indentation."""
self.current_indent -= self.indent_increment
def write_usage(self, prog, args='', prefix='Usage: '):
def write_usage(
self, prog: str, args: str = "", prefix: t.Optional[str] = None
) -> None:
"""Writes a usage line into the buffer.
:param prog: the program name.
:param args: whitespace separated list of arguments.
:param prefix: the prefix for the first line.
:param prefix: The prefix for the first line. Defaults to
``"Usage: "``.
"""
usage_prefix = '%*s%s ' % (self.current_indent, prefix, prog)
if prefix is None:
prefix = f"{_('Usage:')} "
usage_prefix = f"{prefix:>{self.current_indent}}{prog} "
text_width = self.width - self.current_indent
if text_width >= (term_len(usage_prefix) + 20):
# The arguments will fit to the right of the prefix.
indent = ' ' * term_len(usage_prefix)
self.write(wrap_text(args, text_width,
initial_indent=usage_prefix,
subsequent_indent=indent))
indent = " " * term_len(usage_prefix)
self.write(
wrap_text(
args,
text_width,
initial_indent=usage_prefix,
subsequent_indent=indent,
)
)
else:
# The prefix is too long, put the arguments on the next line.
self.write(usage_prefix)
self.write('\n')
indent = ' ' * (max(self.current_indent, term_len(prefix)) + 4)
self.write(wrap_text(args, text_width,
initial_indent=indent,
subsequent_indent=indent))
self.write("\n")
indent = " " * (max(self.current_indent, term_len(prefix)) + 4)
self.write(
wrap_text(
args, text_width, initial_indent=indent, subsequent_indent=indent
)
)
self.write('\n')
self.write("\n")
def write_heading(self, heading):
def write_heading(self, heading: str) -> None:
"""Writes a heading into the buffer."""
self.write('%*s%s:\n' % (self.current_indent, '', heading))
self.write(f"{'':>{self.current_indent}}{heading}:\n")
def write_paragraph(self):
def write_paragraph(self) -> None:
"""Writes a paragraph into the buffer."""
if self.buffer:
self.write('\n')
self.write("\n")
def write_text(self, text):
def write_text(self, text: str) -> None:
"""Writes re-indented text into the buffer. This rewraps and
preserves paragraphs.
"""
text_width = max(self.width - self.current_indent, 11)
indent = ' ' * self.current_indent
self.write(wrap_text(text, text_width,
initial_indent=indent,
subsequent_indent=indent,
preserve_paragraphs=True))
self.write('\n')
indent = " " * self.current_indent
self.write(
wrap_text(
text,
self.width,
initial_indent=indent,
subsequent_indent=indent,
preserve_paragraphs=True,
)
)
self.write("\n")
def write_dl(self, rows, col_max=30, col_spacing=2):
def write_dl(
self,
rows: t.Sequence[t.Tuple[str, str]],
col_max: int = 30,
col_spacing: int = 2,
) -> None:
"""Writes a definition list into the buffer. This is how options
and commands are usually formatted.
@ -182,33 +224,35 @@ class HelpFormatter(object):
rows = list(rows)
widths = measure_table(rows)
if len(widths) != 2:
raise TypeError('Expected two columns for definition list')
raise TypeError("Expected two columns for definition list")
first_col = min(widths[0], col_max) + col_spacing
for first, second in iter_rows(rows, len(widths)):
self.write('%*s%s' % (self.current_indent, '', first))
self.write(f"{'':>{self.current_indent}}{first}")
if not second:
self.write('\n')
self.write("\n")
continue
if term_len(first) <= first_col - col_spacing:
self.write(' ' * (first_col - term_len(first)))
self.write(" " * (first_col - term_len(first)))
else:
self.write('\n')
self.write(' ' * (first_col + self.current_indent))
self.write("\n")
self.write(" " * (first_col + self.current_indent))
text_width = max(self.width - first_col - 2, 10)
lines = iter(wrap_text(second, text_width).splitlines())
wrapped_text = wrap_text(second, text_width, preserve_paragraphs=True)
lines = wrapped_text.splitlines()
if lines:
self.write(next(lines) + '\n')
for line in lines:
self.write('%*s%s\n' % (
first_col + self.current_indent, '', line))
self.write(f"{lines[0]}\n")
for line in lines[1:]:
self.write(f"{'':>{first_col + self.current_indent}}{line}\n")
else:
self.write('\n')
self.write("\n")
@contextmanager
def section(self, name):
def section(self, name: str) -> t.Iterator[None]:
"""Helpful context manager that writes a paragraph, a heading,
and the indents.
@ -223,7 +267,7 @@ class HelpFormatter(object):
self.dedent()
@contextmanager
def indentation(self):
def indentation(self) -> t.Iterator[None]:
"""A context manager that increases the indentation."""
self.indent()
try:
@ -231,12 +275,12 @@ class HelpFormatter(object):
finally:
self.dedent()
def getvalue(self):
def getvalue(self) -> str:
"""Returns the buffer contents."""
return ''.join(self.buffer)
return "".join(self.buffer)
def join_options(options):
def join_options(options: t.Sequence[str]) -> t.Tuple[str, bool]:
"""Given a list of option strings this joins them in the most appropriate
way and returns them in the form ``(formatted_string,
any_prefix_is_slash)`` where the second item in the tuple is a flag that
@ -244,13 +288,14 @@ def join_options(options):
"""
rv = []
any_prefix_is_slash = False
for opt in options:
prefix = split_opt(opt)[0]
if prefix == '/':
if prefix == "/":
any_prefix_is_slash = True
rv.append((len(prefix), opt))
rv.sort(key=lambda x: x[0])
rv = ', '.join(x[1] for x in rv)
return rv, any_prefix_is_slash
return ", ".join(x[1] for x in rv), any_prefix_is_slash

View file

@ -1,10 +1,24 @@
import typing as t
from threading import local
if t.TYPE_CHECKING:
import typing_extensions as te
from .core import Context
_local = local()
def get_current_context(silent=False):
@t.overload
def get_current_context(silent: "te.Literal[False]" = False) -> "Context":
...
@t.overload
def get_current_context(silent: bool = ...) -> t.Optional["Context"]:
...
def get_current_context(silent: bool = False) -> t.Optional["Context"]:
"""Returns the current click context. This can be used as a way to
access the current context object from anywhere. This is a more implicit
alternative to the :func:`pass_context` decorator. This function is
@ -15,34 +29,40 @@ def get_current_context(silent=False):
.. versionadded:: 5.0
:param silent: is set to `True` the return value is `None` if no context
:param silent: if set to `True` the return value is `None` if no context
is available. The default behavior is to raise a
:exc:`RuntimeError`.
"""
try:
return getattr(_local, 'stack')[-1]
except (AttributeError, IndexError):
return t.cast("Context", _local.stack[-1])
except (AttributeError, IndexError) as e:
if not silent:
raise RuntimeError('There is no active click context.')
raise RuntimeError("There is no active click context.") from e
return None
def push_context(ctx):
def push_context(ctx: "Context") -> None:
"""Pushes a new context to the current stack."""
_local.__dict__.setdefault('stack', []).append(ctx)
_local.__dict__.setdefault("stack", []).append(ctx)
def pop_context():
def pop_context() -> None:
"""Removes the top level from the stack."""
_local.stack.pop()
def resolve_color_default(color=None):
""""Internal helper to get the default value of the color flag. If a
def resolve_color_default(color: t.Optional[bool] = None) -> t.Optional[bool]:
"""Internal helper to get the default value of the color flag. If a
value is passed it's returned unchanged, otherwise it's looked up from
the current context.
"""
if color is not None:
return color
ctx = get_current_context(silent=True)
if ctx is not None:
return ctx.color
return None

View file

@ -1,8 +1,4 @@
# -*- coding: utf-8 -*-
"""
click.parser
~~~~~~~~~~~~
This module started out as largely a copy paste from the stdlib's
optparse module with the features removed that we do not need from
optparse because we implement them in Click on a higher level (for
@ -14,15 +10,45 @@ The reason this is a different module and not optparse from the stdlib
is that there are differences in 2.x and 3.x about the error messages
generated and optparse in the stdlib uses gettext for no good reason
and might cause us issues.
Click uses parts of optparse written by Gregory P. Ward and maintained
by the Python Software Foundation. This is limited to code in parser.py.
Copyright 2001-2006 Gregory P. Ward. All rights reserved.
Copyright 2002-2006 Python Software Foundation. All rights reserved.
"""
import re
# This code uses parts of optparse written by Gregory P. Ward and
# maintained by the Python Software Foundation.
# Copyright 2001-2006 Gregory P. Ward
# Copyright 2002-2006 Python Software Foundation
import typing as t
from collections import deque
from .exceptions import UsageError, NoSuchOption, BadOptionUsage, \
BadArgumentUsage
from gettext import gettext as _
from gettext import ngettext
from .exceptions import BadArgumentUsage
from .exceptions import BadOptionUsage
from .exceptions import NoSuchOption
from .exceptions import UsageError
if t.TYPE_CHECKING:
import typing_extensions as te
from .core import Argument as CoreArgument
from .core import Context
from .core import Option as CoreOption
from .core import Parameter as CoreParameter
V = t.TypeVar("V")
# Sentinel value that indicates an option was passed as a flag without a
# value but is not a flag option. Option.consume_value uses this to
# prompt or use the flag_value.
_flag_needs_value = object()
def _unpack_args(args, nargs_spec):
def _unpack_args(
args: t.Sequence[str], nargs_spec: t.Sequence[int]
) -> t.Tuple[t.Sequence[t.Union[str, t.Sequence[t.Optional[str]], None]], t.List[str]]:
"""Given an iterable of arguments and an iterable of nargs specifications,
it returns a tuple with all the unpacked arguments at the first index
and all remaining arguments as the second.
@ -34,10 +60,10 @@ def _unpack_args(args, nargs_spec):
"""
args = deque(args)
nargs_spec = deque(nargs_spec)
rv = []
spos = None
rv: t.List[t.Union[str, t.Tuple[t.Optional[str], ...], None]] = []
spos: t.Optional[int] = None
def _fetch(c):
def _fetch(c: "te.Deque[V]") -> t.Optional[V]:
try:
if spos is None:
return c.popleft()
@ -48,18 +74,25 @@ def _unpack_args(args, nargs_spec):
while nargs_spec:
nargs = _fetch(nargs_spec)
if nargs is None:
continue
if nargs == 1:
rv.append(_fetch(args))
elif nargs > 1:
x = [_fetch(args) for _ in range(nargs)]
# If we're reversed, we're pulling in the arguments in reverse,
# so we need to turn them around.
if spos is not None:
x.reverse()
rv.append(tuple(x))
elif nargs < 0:
if spos is not None:
raise TypeError('Cannot have two nargs < 0')
raise TypeError("Cannot have two nargs < 0")
spos = len(rv)
rv.append(None)
@ -68,54 +101,71 @@ def _unpack_args(args, nargs_spec):
if spos is not None:
rv[spos] = tuple(args)
args = []
rv[spos + 1:] = reversed(rv[spos + 1:])
rv[spos + 1 :] = reversed(rv[spos + 1 :])
return tuple(rv), list(args)
def _error_opt_args(nargs, opt):
if nargs == 1:
raise BadOptionUsage(opt, '%s option requires an argument' % opt)
raise BadOptionUsage(opt, '%s option requires %d arguments' % (opt, nargs))
def split_opt(opt):
def split_opt(opt: str) -> t.Tuple[str, str]:
first = opt[:1]
if first.isalnum():
return '', opt
return "", opt
if opt[1:2] == first:
return opt[:2], opt[2:]
return first, opt[1:]
def normalize_opt(opt, ctx):
def normalize_opt(opt: str, ctx: t.Optional["Context"]) -> str:
if ctx is None or ctx.token_normalize_func is None:
return opt
prefix, opt = split_opt(opt)
return prefix + ctx.token_normalize_func(opt)
return f"{prefix}{ctx.token_normalize_func(opt)}"
def split_arg_string(string):
"""Given an argument string this attempts to split it into small parts."""
rv = []
for match in re.finditer(r"('([^'\\]*(?:\\.[^'\\]*)*)'"
r'|"([^"\\]*(?:\\.[^"\\]*)*)"'
r'|\S+)\s*', string, re.S):
arg = match.group().strip()
if arg[:1] == arg[-1:] and arg[:1] in '"\'':
arg = arg[1:-1].encode('ascii', 'backslashreplace') \
.decode('unicode-escape')
try:
arg = type(string)(arg)
except UnicodeError:
pass
rv.append(arg)
return rv
def split_arg_string(string: str) -> t.List[str]:
"""Split an argument string as with :func:`shlex.split`, but don't
fail if the string is incomplete. Ignores a missing closing quote or
incomplete escape sequence and uses the partial token as-is.
.. code-block:: python
split_arg_string("example 'my file")
["example", "my file"]
split_arg_string("example my\\")
["example", "my"]
:param string: String to split.
"""
import shlex
lex = shlex.shlex(string, posix=True)
lex.whitespace_split = True
lex.commenters = ""
out = []
try:
for token in lex:
out.append(token)
except ValueError:
# Raised when end-of-string is reached in an invalid state. Use
# the partial token as-is. The quote or escape character is in
# lex.state, not lex.token.
out.append(lex.token)
return out
class Option(object):
def __init__(self, opts, dest, action=None, nargs=1, const=None, obj=None):
class Option:
def __init__(
self,
obj: "CoreOption",
opts: t.Sequence[str],
dest: t.Optional[str],
action: t.Optional[str] = None,
nargs: int = 1,
const: t.Optional[t.Any] = None,
):
self._short_opts = []
self._long_opts = []
self.prefixes = set()
@ -123,8 +173,7 @@ class Option(object):
for opt in opts:
prefix, value = split_opt(opt)
if not prefix:
raise ValueError('Invalid start character for option (%s)'
% opt)
raise ValueError(f"Invalid start character for option ({opt})")
self.prefixes.add(prefix[0])
if len(prefix) == 1 and len(value) == 1:
self._short_opts.append(opt)
@ -133,7 +182,7 @@ class Option(object):
self.prefixes.add(prefix)
if action is None:
action = 'store'
action = "store"
self.dest = dest
self.action = action
@ -142,54 +191,66 @@ class Option(object):
self.obj = obj
@property
def takes_value(self):
return self.action in ('store', 'append')
def takes_value(self) -> bool:
return self.action in ("store", "append")
def process(self, value, state):
if self.action == 'store':
state.opts[self.dest] = value
elif self.action == 'store_const':
state.opts[self.dest] = self.const
elif self.action == 'append':
state.opts.setdefault(self.dest, []).append(value)
elif self.action == 'append_const':
state.opts.setdefault(self.dest, []).append(self.const)
elif self.action == 'count':
state.opts[self.dest] = state.opts.get(self.dest, 0) + 1
def process(self, value: str, state: "ParsingState") -> None:
if self.action == "store":
state.opts[self.dest] = value # type: ignore
elif self.action == "store_const":
state.opts[self.dest] = self.const # type: ignore
elif self.action == "append":
state.opts.setdefault(self.dest, []).append(value) # type: ignore
elif self.action == "append_const":
state.opts.setdefault(self.dest, []).append(self.const) # type: ignore
elif self.action == "count":
state.opts[self.dest] = state.opts.get(self.dest, 0) + 1 # type: ignore
else:
raise ValueError('unknown action %r' % self.action)
raise ValueError(f"unknown action '{self.action}'")
state.order.append(self.obj)
class Argument(object):
def __init__(self, dest, nargs=1, obj=None):
class Argument:
def __init__(self, obj: "CoreArgument", dest: t.Optional[str], nargs: int = 1):
self.dest = dest
self.nargs = nargs
self.obj = obj
def process(self, value, state):
def process(
self,
value: t.Union[t.Optional[str], t.Sequence[t.Optional[str]]],
state: "ParsingState",
) -> None:
if self.nargs > 1:
assert value is not None
holes = sum(1 for x in value if x is None)
if holes == len(value):
value = None
elif holes != 0:
raise BadArgumentUsage('argument %s takes %d values'
% (self.dest, self.nargs))
state.opts[self.dest] = value
raise BadArgumentUsage(
_("Argument {name!r} takes {nargs} values.").format(
name=self.dest, nargs=self.nargs
)
)
if self.nargs == -1 and self.obj.envvar is not None and value == ():
# Replace empty tuple with None so that a value from the
# environment may be tried.
value = None
state.opts[self.dest] = value # type: ignore
state.order.append(self.obj)
class ParsingState(object):
def __init__(self, rargs):
self.opts = {}
self.largs = []
class ParsingState:
def __init__(self, rargs: t.List[str]) -> None:
self.opts: t.Dict[str, t.Any] = {}
self.largs: t.List[str] = []
self.rargs = rargs
self.order = []
self.order: t.List["CoreParameter"] = []
class OptionParser(object):
class OptionParser:
"""The option parser is an internal class that is ultimately used to
parse options and arguments. It's modelled after optparse and brings
a similar but vastly simplified API. It should generally not be used
@ -203,7 +264,7 @@ class OptionParser(object):
should go with.
"""
def __init__(self, ctx=None):
def __init__(self, ctx: t.Optional["Context"] = None) -> None:
#: The :class:`~click.Context` for this parser. This might be
#: `None` for some advanced use cases.
self.ctx = ctx
@ -217,46 +278,54 @@ class OptionParser(object):
#: second mode where it will ignore it and continue processing
#: after shifting all the unknown options into the resulting args.
self.ignore_unknown_options = False
if ctx is not None:
self.allow_interspersed_args = ctx.allow_interspersed_args
self.ignore_unknown_options = ctx.ignore_unknown_options
self._short_opt = {}
self._long_opt = {}
self._opt_prefixes = set(['-', '--'])
self._args = []
def add_option(self, opts, dest, action=None, nargs=1, const=None,
obj=None):
self._short_opt: t.Dict[str, Option] = {}
self._long_opt: t.Dict[str, Option] = {}
self._opt_prefixes = {"-", "--"}
self._args: t.List[Argument] = []
def add_option(
self,
obj: "CoreOption",
opts: t.Sequence[str],
dest: t.Optional[str],
action: t.Optional[str] = None,
nargs: int = 1,
const: t.Optional[t.Any] = None,
) -> None:
"""Adds a new option named `dest` to the parser. The destination
is not inferred (unlike with optparse) and needs to be explicitly
provided. Action can be any of ``store``, ``store_const``,
``append``, ``appnd_const`` or ``count``.
``append``, ``append_const`` or ``count``.
The `obj` can be used to identify the option in the order list
that is returned from the parser.
"""
if obj is None:
obj = dest
opts = [normalize_opt(opt, self.ctx) for opt in opts]
option = Option(opts, dest, action=action, nargs=nargs,
const=const, obj=obj)
option = Option(obj, opts, dest, action=action, nargs=nargs, const=const)
self._opt_prefixes.update(option.prefixes)
for opt in option._short_opts:
self._short_opt[opt] = option
for opt in option._long_opts:
self._long_opt[opt] = option
def add_argument(self, dest, nargs=1, obj=None):
def add_argument(
self, obj: "CoreArgument", dest: t.Optional[str], nargs: int = 1
) -> None:
"""Adds a positional argument named `dest` to the parser.
The `obj` can be used to identify the option in the order list
that is returned from the parser.
"""
if obj is None:
obj = dest
self._args.append(Argument(dest=dest, nargs=nargs, obj=obj))
self._args.append(Argument(obj, dest=dest, nargs=nargs))
def parse_args(self, args):
def parse_args(
self, args: t.List[str]
) -> t.Tuple[t.Dict[str, t.Any], t.List[str], t.List["CoreParameter"]]:
"""Parses positional arguments and returns ``(values, args, order)``
for the parsed options and arguments as well as the leftover
arguments if there are any. The order is a list of objects as they
@ -272,9 +341,10 @@ class OptionParser(object):
raise
return state.opts, state.largs, state.order
def _process_args_for_args(self, state):
pargs, args = _unpack_args(state.largs + state.rargs,
[x.nargs for x in self._args])
def _process_args_for_args(self, state: ParsingState) -> None:
pargs, args = _unpack_args(
state.largs + state.rargs, [x.nargs for x in self._args]
)
for idx, arg in enumerate(self._args):
arg.process(pargs[idx], state)
@ -282,13 +352,13 @@ class OptionParser(object):
state.largs = args
state.rargs = []
def _process_args_for_options(self, state):
def _process_args_for_options(self, state: ParsingState) -> None:
while state.rargs:
arg = state.rargs.pop(0)
arglen = len(arg)
# Double dashes always handled explicitly regardless of what
# prefixes are valid.
if arg == '--':
if arg == "--":
return
elif arg[:1] in self._opt_prefixes and arglen > 1:
self._process_opts(arg, state)
@ -318,10 +388,13 @@ class OptionParser(object):
# *empty* -- still a subset of [arg0, ..., arg(i-1)], but
# not a very interesting subset!
def _match_long_opt(self, opt, explicit_value, state):
def _match_long_opt(
self, opt: str, explicit_value: t.Optional[str], state: ParsingState
) -> None:
if opt not in self._long_opt:
possibilities = [word for word in self._long_opt
if word.startswith(opt)]
from difflib import get_close_matches
possibilities = get_close_matches(opt, self._long_opt)
raise NoSuchOption(opt, possibilities=possibilities, ctx=self.ctx)
option = self._long_opt[opt]
@ -333,31 +406,26 @@ class OptionParser(object):
if explicit_value is not None:
state.rargs.insert(0, explicit_value)
nargs = option.nargs
if len(state.rargs) < nargs:
_error_opt_args(nargs, opt)
elif nargs == 1:
value = state.rargs.pop(0)
else:
value = tuple(state.rargs[:nargs])
del state.rargs[:nargs]
value = self._get_value_from_state(opt, option, state)
elif explicit_value is not None:
raise BadOptionUsage(opt, '%s option does not take a value' % opt)
raise BadOptionUsage(
opt, _("Option {name!r} does not take a value.").format(name=opt)
)
else:
value = None
option.process(value, state)
def _match_short_opt(self, arg, state):
def _match_short_opt(self, arg: str, state: ParsingState) -> None:
stop = False
i = 1
prefix = arg[0]
unknown_options = []
for ch in arg[1:]:
opt = normalize_opt(prefix + ch, self.ctx)
opt = normalize_opt(f"{prefix}{ch}", self.ctx)
option = self._short_opt.get(opt)
i += 1
@ -373,14 +441,7 @@ class OptionParser(object):
state.rargs.insert(0, arg[i:])
stop = True
nargs = option.nargs
if len(state.rargs) < nargs:
_error_opt_args(nargs, opt)
elif nargs == 1:
value = state.rargs.pop(0)
else:
value = tuple(state.rargs[:nargs])
del state.rargs[:nargs]
value = self._get_value_from_state(opt, option, state)
else:
value = None
@ -395,15 +456,53 @@ class OptionParser(object):
# to the state as new larg. This way there is basic combinatorics
# that can be achieved while still ignoring unknown arguments.
if self.ignore_unknown_options and unknown_options:
state.largs.append(prefix + ''.join(unknown_options))
state.largs.append(f"{prefix}{''.join(unknown_options)}")
def _process_opts(self, arg, state):
def _get_value_from_state(
self, option_name: str, option: Option, state: ParsingState
) -> t.Any:
nargs = option.nargs
if len(state.rargs) < nargs:
if option.obj._flag_needs_value:
# Option allows omitting the value.
value = _flag_needs_value
else:
raise BadOptionUsage(
option_name,
ngettext(
"Option {name!r} requires an argument.",
"Option {name!r} requires {nargs} arguments.",
nargs,
).format(name=option_name, nargs=nargs),
)
elif nargs == 1:
next_rarg = state.rargs[0]
if (
option.obj._flag_needs_value
and isinstance(next_rarg, str)
and next_rarg[:1] in self._opt_prefixes
and len(next_rarg) > 1
):
# The next arg looks like the start of an option, don't
# use it as the value if omitting the value is allowed.
value = _flag_needs_value
else:
value = state.rargs.pop(0)
else:
value = tuple(state.rargs[:nargs])
del state.rargs[:nargs]
return value
def _process_opts(self, arg: str, state: ParsingState) -> None:
explicit_value = None
# Long option handling happens in two parts. The first part is
# supporting explicitly attached values. In any case, we will try
# to long match the option first.
if '=' in arg:
long_opt, explicit_value = arg.split('=', 1)
if "=" in arg:
long_opt, explicit_value = arg.split("=", 1)
else:
long_opt = arg
norm_long_opt = normalize_opt(long_opt, self.ctx)
@ -421,7 +520,10 @@ class OptionParser(object):
# short option code and will instead raise the no option
# error.
if arg[:2] not in self._opt_prefixes:
return self._match_short_opt(arg, state)
self._match_short_opt(arg, state)
return
if not self.ignore_unknown_options:
raise
state.largs.append(arg)

View file

View file

@ -0,0 +1,580 @@
import os
import re
import typing as t
from gettext import gettext as _
from .core import Argument
from .core import BaseCommand
from .core import Context
from .core import MultiCommand
from .core import Option
from .core import Parameter
from .core import ParameterSource
from .parser import split_arg_string
from .utils import echo
def shell_complete(
cli: BaseCommand,
ctx_args: t.Dict[str, t.Any],
prog_name: str,
complete_var: str,
instruction: str,
) -> int:
"""Perform shell completion for the given CLI program.
:param cli: Command being called.
:param ctx_args: Extra arguments to pass to
``cli.make_context``.
:param prog_name: Name of the executable in the shell.
:param complete_var: Name of the environment variable that holds
the completion instruction.
:param instruction: Value of ``complete_var`` with the completion
instruction and shell, in the form ``instruction_shell``.
:return: Status code to exit with.
"""
shell, _, instruction = instruction.partition("_")
comp_cls = get_completion_class(shell)
if comp_cls is None:
return 1
comp = comp_cls(cli, ctx_args, prog_name, complete_var)
if instruction == "source":
echo(comp.source())
return 0
if instruction == "complete":
echo(comp.complete())
return 0
return 1
class CompletionItem:
"""Represents a completion value and metadata about the value. The
default metadata is ``type`` to indicate special shell handling,
and ``help`` if a shell supports showing a help string next to the
value.
Arbitrary parameters can be passed when creating the object, and
accessed using ``item.attr``. If an attribute wasn't passed,
accessing it returns ``None``.
:param value: The completion suggestion.
:param type: Tells the shell script to provide special completion
support for the type. Click uses ``"dir"`` and ``"file"``.
:param help: String shown next to the value if supported.
:param kwargs: Arbitrary metadata. The built-in implementations
don't use this, but custom type completions paired with custom
shell support could use it.
"""
__slots__ = ("value", "type", "help", "_info")
def __init__(
self,
value: t.Any,
type: str = "plain",
help: t.Optional[str] = None,
**kwargs: t.Any,
) -> None:
self.value = value
self.type = type
self.help = help
self._info = kwargs
def __getattr__(self, name: str) -> t.Any:
return self._info.get(name)
# Only Bash >= 4.4 has the nosort option.
_SOURCE_BASH = """\
%(complete_func)s() {
local IFS=$'\\n'
local response
response=$(env COMP_WORDS="${COMP_WORDS[*]}" COMP_CWORD=$COMP_CWORD \
%(complete_var)s=bash_complete $1)
for completion in $response; do
IFS=',' read type value <<< "$completion"
if [[ $type == 'dir' ]]; then
COMPREPLY=()
compopt -o dirnames
elif [[ $type == 'file' ]]; then
COMPREPLY=()
compopt -o default
elif [[ $type == 'plain' ]]; then
COMPREPLY+=($value)
fi
done
return 0
}
%(complete_func)s_setup() {
complete -o nosort -F %(complete_func)s %(prog_name)s
}
%(complete_func)s_setup;
"""
_SOURCE_ZSH = """\
#compdef %(prog_name)s
%(complete_func)s() {
local -a completions
local -a completions_with_descriptions
local -a response
(( ! $+commands[%(prog_name)s] )) && return 1
response=("${(@f)$(env COMP_WORDS="${words[*]}" COMP_CWORD=$((CURRENT-1)) \
%(complete_var)s=zsh_complete %(prog_name)s)}")
for type key descr in ${response}; do
if [[ "$type" == "plain" ]]; then
if [[ "$descr" == "_" ]]; then
completions+=("$key")
else
completions_with_descriptions+=("$key":"$descr")
fi
elif [[ "$type" == "dir" ]]; then
_path_files -/
elif [[ "$type" == "file" ]]; then
_path_files -f
fi
done
if [ -n "$completions_with_descriptions" ]; then
_describe -V unsorted completions_with_descriptions -U
fi
if [ -n "$completions" ]; then
compadd -U -V unsorted -a completions
fi
}
compdef %(complete_func)s %(prog_name)s;
"""
_SOURCE_FISH = """\
function %(complete_func)s;
set -l response;
for value in (env %(complete_var)s=fish_complete COMP_WORDS=(commandline -cp) \
COMP_CWORD=(commandline -t) %(prog_name)s);
set response $response $value;
end;
for completion in $response;
set -l metadata (string split "," $completion);
if test $metadata[1] = "dir";
__fish_complete_directories $metadata[2];
else if test $metadata[1] = "file";
__fish_complete_path $metadata[2];
else if test $metadata[1] = "plain";
echo $metadata[2];
end;
end;
end;
complete --no-files --command %(prog_name)s --arguments \
"(%(complete_func)s)";
"""
class ShellComplete:
"""Base class for providing shell completion support. A subclass for
a given shell will override attributes and methods to implement the
completion instructions (``source`` and ``complete``).
:param cli: Command being called.
:param prog_name: Name of the executable in the shell.
:param complete_var: Name of the environment variable that holds
the completion instruction.
.. versionadded:: 8.0
"""
name: t.ClassVar[str]
"""Name to register the shell as with :func:`add_completion_class`.
This is used in completion instructions (``{name}_source`` and
``{name}_complete``).
"""
source_template: t.ClassVar[str]
"""Completion script template formatted by :meth:`source`. This must
be provided by subclasses.
"""
def __init__(
self,
cli: BaseCommand,
ctx_args: t.Dict[str, t.Any],
prog_name: str,
complete_var: str,
) -> None:
self.cli = cli
self.ctx_args = ctx_args
self.prog_name = prog_name
self.complete_var = complete_var
@property
def func_name(self) -> str:
"""The name of the shell function defined by the completion
script.
"""
safe_name = re.sub(r"\W*", "", self.prog_name.replace("-", "_"), re.ASCII)
return f"_{safe_name}_completion"
def source_vars(self) -> t.Dict[str, t.Any]:
"""Vars for formatting :attr:`source_template`.
By default this provides ``complete_func``, ``complete_var``,
and ``prog_name``.
"""
return {
"complete_func": self.func_name,
"complete_var": self.complete_var,
"prog_name": self.prog_name,
}
def source(self) -> str:
"""Produce the shell script that defines the completion
function. By default this ``%``-style formats
:attr:`source_template` with the dict returned by
:meth:`source_vars`.
"""
return self.source_template % self.source_vars()
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
"""Use the env vars defined by the shell script to return a
tuple of ``args, incomplete``. This must be implemented by
subclasses.
"""
raise NotImplementedError
def get_completions(
self, args: t.List[str], incomplete: str
) -> t.List[CompletionItem]:
"""Determine the context and last complete command or parameter
from the complete args. Call that object's ``shell_complete``
method to get the completions for the incomplete value.
:param args: List of complete args before the incomplete value.
:param incomplete: Value being completed. May be empty.
"""
ctx = _resolve_context(self.cli, self.ctx_args, self.prog_name, args)
obj, incomplete = _resolve_incomplete(ctx, args, incomplete)
return obj.shell_complete(ctx, incomplete)
def format_completion(self, item: CompletionItem) -> str:
"""Format a completion item into the form recognized by the
shell script. This must be implemented by subclasses.
:param item: Completion item to format.
"""
raise NotImplementedError
def complete(self) -> str:
"""Produce the completion data to send back to the shell.
By default this calls :meth:`get_completion_args`, gets the
completions, then calls :meth:`format_completion` for each
completion.
"""
args, incomplete = self.get_completion_args()
completions = self.get_completions(args, incomplete)
out = [self.format_completion(item) for item in completions]
return "\n".join(out)
class BashComplete(ShellComplete):
"""Shell completion for Bash."""
name = "bash"
source_template = _SOURCE_BASH
def _check_version(self) -> None:
import subprocess
output = subprocess.run(
["bash", "-c", "echo ${BASH_VERSION}"], stdout=subprocess.PIPE
)
match = re.search(r"^(\d+)\.(\d+)\.\d+", output.stdout.decode())
if match is not None:
major, minor = match.groups()
if major < "4" or major == "4" and minor < "4":
raise RuntimeError(
_(
"Shell completion is not supported for Bash"
" versions older than 4.4."
)
)
else:
raise RuntimeError(
_("Couldn't detect Bash version, shell completion is not supported.")
)
def source(self) -> str:
self._check_version()
return super().source()
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
cwords = split_arg_string(os.environ["COMP_WORDS"])
cword = int(os.environ["COMP_CWORD"])
args = cwords[1:cword]
try:
incomplete = cwords[cword]
except IndexError:
incomplete = ""
return args, incomplete
def format_completion(self, item: CompletionItem) -> str:
return f"{item.type},{item.value}"
class ZshComplete(ShellComplete):
"""Shell completion for Zsh."""
name = "zsh"
source_template = _SOURCE_ZSH
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
cwords = split_arg_string(os.environ["COMP_WORDS"])
cword = int(os.environ["COMP_CWORD"])
args = cwords[1:cword]
try:
incomplete = cwords[cword]
except IndexError:
incomplete = ""
return args, incomplete
def format_completion(self, item: CompletionItem) -> str:
return f"{item.type}\n{item.value}\n{item.help if item.help else '_'}"
class FishComplete(ShellComplete):
"""Shell completion for Fish."""
name = "fish"
source_template = _SOURCE_FISH
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
cwords = split_arg_string(os.environ["COMP_WORDS"])
incomplete = os.environ["COMP_CWORD"]
args = cwords[1:]
# Fish stores the partial word in both COMP_WORDS and
# COMP_CWORD, remove it from complete args.
if incomplete and args and args[-1] == incomplete:
args.pop()
return args, incomplete
def format_completion(self, item: CompletionItem) -> str:
if item.help:
return f"{item.type},{item.value}\t{item.help}"
return f"{item.type},{item.value}"
_available_shells: t.Dict[str, t.Type[ShellComplete]] = {
"bash": BashComplete,
"fish": FishComplete,
"zsh": ZshComplete,
}
def add_completion_class(
cls: t.Type[ShellComplete], name: t.Optional[str] = None
) -> None:
"""Register a :class:`ShellComplete` subclass under the given name.
The name will be provided by the completion instruction environment
variable during completion.
:param cls: The completion class that will handle completion for the
shell.
:param name: Name to register the class under. Defaults to the
class's ``name`` attribute.
"""
if name is None:
name = cls.name
_available_shells[name] = cls
def get_completion_class(shell: str) -> t.Optional[t.Type[ShellComplete]]:
"""Look up a registered :class:`ShellComplete` subclass by the name
provided by the completion instruction environment variable. If the
name isn't registered, returns ``None``.
:param shell: Name the class is registered under.
"""
return _available_shells.get(shell)
def _is_incomplete_argument(ctx: Context, param: Parameter) -> bool:
"""Determine if the given parameter is an argument that can still
accept values.
:param ctx: Invocation context for the command represented by the
parsed complete args.
:param param: Argument object being checked.
"""
if not isinstance(param, Argument):
return False
assert param.name is not None
value = ctx.params[param.name]
return (
param.nargs == -1
or ctx.get_parameter_source(param.name) is not ParameterSource.COMMANDLINE
or (
param.nargs > 1
and isinstance(value, (tuple, list))
and len(value) < param.nargs
)
)
def _start_of_option(ctx: Context, value: str) -> bool:
"""Check if the value looks like the start of an option."""
if not value:
return False
c = value[0]
return c in ctx._opt_prefixes
def _is_incomplete_option(ctx: Context, args: t.List[str], param: Parameter) -> bool:
"""Determine if the given parameter is an option that needs a value.
:param args: List of complete args before the incomplete value.
:param param: Option object being checked.
"""
if not isinstance(param, Option):
return False
if param.is_flag or param.count:
return False
last_option = None
for index, arg in enumerate(reversed(args)):
if index + 1 > param.nargs:
break
if _start_of_option(ctx, arg):
last_option = arg
return last_option is not None and last_option in param.opts
def _resolve_context(
cli: BaseCommand, ctx_args: t.Dict[str, t.Any], prog_name: str, args: t.List[str]
) -> Context:
"""Produce the context hierarchy starting with the command and
traversing the complete arguments. This only follows the commands,
it doesn't trigger input prompts or callbacks.
:param cli: Command being called.
:param prog_name: Name of the executable in the shell.
:param args: List of complete args before the incomplete value.
"""
ctx_args["resilient_parsing"] = True
ctx = cli.make_context(prog_name, args.copy(), **ctx_args)
args = ctx.protected_args + ctx.args
while args:
command = ctx.command
if isinstance(command, MultiCommand):
if not command.chain:
name, cmd, args = command.resolve_command(ctx, args)
if cmd is None:
return ctx
ctx = cmd.make_context(name, args, parent=ctx, resilient_parsing=True)
args = ctx.protected_args + ctx.args
else:
while args:
name, cmd, args = command.resolve_command(ctx, args)
if cmd is None:
return ctx
sub_ctx = cmd.make_context(
name,
args,
parent=ctx,
allow_extra_args=True,
allow_interspersed_args=False,
resilient_parsing=True,
)
args = sub_ctx.args
ctx = sub_ctx
args = [*sub_ctx.protected_args, *sub_ctx.args]
else:
break
return ctx
def _resolve_incomplete(
ctx: Context, args: t.List[str], incomplete: str
) -> t.Tuple[t.Union[BaseCommand, Parameter], str]:
"""Find the Click object that will handle the completion of the
incomplete value. Return the object and the incomplete value.
:param ctx: Invocation context for the command represented by
the parsed complete args.
:param args: List of complete args before the incomplete value.
:param incomplete: Value being completed. May be empty.
"""
# Different shells treat an "=" between a long option name and
# value differently. Might keep the value joined, return the "="
# as a separate item, or return the split name and value. Always
# split and discard the "=" to make completion easier.
if incomplete == "=":
incomplete = ""
elif "=" in incomplete and _start_of_option(ctx, incomplete):
name, _, incomplete = incomplete.partition("=")
args.append(name)
# The "--" marker tells Click to stop treating values as options
# even if they start with the option character. If it hasn't been
# given and the incomplete arg looks like an option, the current
# command will provide option name completions.
if "--" not in args and _start_of_option(ctx, incomplete):
return ctx.command, incomplete
params = ctx.command.get_params(ctx)
# If the last complete arg is an option name with an incomplete
# value, the option will provide value completions.
for param in params:
if _is_incomplete_option(ctx, args, param):
return param, incomplete
# It's not an option name or value. The first argument without a
# parsed value will provide value completions.
for param in params:
if _is_incomplete_argument(ctx, param):
return param, incomplete
# There were no unparsed arguments, the command may be a group that
# will provide command name completions.
return ctx.command, incomplete

View file

@ -1,81 +1,109 @@
import inspect
import io
import itertools
import os
import sys
import struct
import inspect
import itertools
import typing as t
from gettext import gettext as _
from ._compat import raw_input, text_type, string_types, \
isatty, strip_ansi, get_winterm_size, DEFAULT_COLUMNS, WIN
from .utils import echo
from .exceptions import Abort, UsageError
from .types import convert_type, Choice, Path
from ._compat import isatty
from ._compat import strip_ansi
from ._compat import WIN
from .exceptions import Abort
from .exceptions import UsageError
from .globals import resolve_color_default
from .types import Choice
from .types import convert_type
from .types import ParamType
from .utils import echo
from .utils import LazyFile
if t.TYPE_CHECKING:
from ._termui_impl import ProgressBar
V = t.TypeVar("V")
# The prompt functions to use. The doc tools currently override these
# functions to customize how they work.
visible_prompt_func = raw_input
visible_prompt_func: t.Callable[[str], str] = input
_ansi_colors = {
'black': 30,
'red': 31,
'green': 32,
'yellow': 33,
'blue': 34,
'magenta': 35,
'cyan': 36,
'white': 37,
'reset': 39,
'bright_black': 90,
'bright_red': 91,
'bright_green': 92,
'bright_yellow': 93,
'bright_blue': 94,
'bright_magenta': 95,
'bright_cyan': 96,
'bright_white': 97,
"black": 30,
"red": 31,
"green": 32,
"yellow": 33,
"blue": 34,
"magenta": 35,
"cyan": 36,
"white": 37,
"reset": 39,
"bright_black": 90,
"bright_red": 91,
"bright_green": 92,
"bright_yellow": 93,
"bright_blue": 94,
"bright_magenta": 95,
"bright_cyan": 96,
"bright_white": 97,
}
_ansi_reset_all = '\033[0m'
_ansi_reset_all = "\033[0m"
def hidden_prompt_func(prompt):
def hidden_prompt_func(prompt: str) -> str:
import getpass
return getpass.getpass(prompt)
def _build_prompt(text, suffix, show_default=False, default=None, show_choices=True, type=None):
def _build_prompt(
text: str,
suffix: str,
show_default: bool = False,
default: t.Optional[t.Any] = None,
show_choices: bool = True,
type: t.Optional[ParamType] = None,
) -> str:
prompt = text
if type is not None and show_choices and isinstance(type, Choice):
prompt += ' (' + ", ".join(map(str, type.choices)) + ')'
prompt += f" ({', '.join(map(str, type.choices))})"
if default is not None and show_default:
prompt = '%s [%s]' % (prompt, default)
return prompt + suffix
prompt = f"{prompt} [{_format_default(default)}]"
return f"{prompt}{suffix}"
def prompt(text, default=None, hide_input=False, confirmation_prompt=False,
type=None, value_proc=None, prompt_suffix=': ', show_default=True,
err=False, show_choices=True):
def _format_default(default: t.Any) -> t.Any:
if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"):
return default.name # type: ignore
return default
def prompt(
text: str,
default: t.Optional[t.Any] = None,
hide_input: bool = False,
confirmation_prompt: t.Union[bool, str] = False,
type: t.Optional[t.Union[ParamType, t.Any]] = None,
value_proc: t.Optional[t.Callable[[str], t.Any]] = None,
prompt_suffix: str = ": ",
show_default: bool = True,
err: bool = False,
show_choices: bool = True,
) -> t.Any:
"""Prompts a user for input. This is a convenience function that can
be used to prompt a user for input later.
If the user aborts the input by sending a interrupt signal, this
If the user aborts the input by sending an interrupt signal, this
function will catch it and raise a :exc:`Abort` exception.
.. versionadded:: 7.0
Added the show_choices parameter.
.. versionadded:: 6.0
Added unicode support for cmd.exe on Windows.
.. versionadded:: 4.0
Added the `err` parameter.
:param text: the text to show for the prompt.
:param default: the default value to use if no input happens. If this
is not given it will prompt until it's aborted.
:param hide_input: if this is set to true then the input value will
be hidden.
:param confirmation_prompt: asks for confirmation for the value.
:param confirmation_prompt: Prompt a second time to confirm the
value. Can be set to a string instead of ``True`` to customize
the message.
:param type: the type to use to check the value against.
:param value_proc: if this parameter is provided it's a function that
is invoked instead of the type conversion to
@ -88,93 +116,133 @@ def prompt(text, default=None, hide_input=False, confirmation_prompt=False,
For example if type is a Choice of either day or week,
show_choices is true and text is "Group by" then the
prompt will be "Group by (day, week): ".
"""
result = None
def prompt_func(text):
f = hide_input and hidden_prompt_func or visible_prompt_func
.. versionadded:: 8.0
``confirmation_prompt`` can be a custom string.
.. versionadded:: 7.0
Added the ``show_choices`` parameter.
.. versionadded:: 6.0
Added unicode support for cmd.exe on Windows.
.. versionadded:: 4.0
Added the `err` parameter.
"""
def prompt_func(text: str) -> str:
f = hidden_prompt_func if hide_input else visible_prompt_func
try:
# Write the prompt separately so that we get nice
# coloring through colorama on Windows
echo(text, nl=False, err=err)
return f('')
echo(text.rstrip(" "), nl=False, err=err)
# Echo a space to stdout to work around an issue where
# readline causes backspace to clear the whole line.
return f(" ")
except (KeyboardInterrupt, EOFError):
# getpass doesn't print a newline if the user aborts input with ^C.
# Allegedly this behavior is inherited from getpass(3).
# A doc bug has been filed at https://bugs.python.org/issue24711
if hide_input:
echo(None, err=err)
raise Abort()
raise Abort() from None
if value_proc is None:
value_proc = convert_type(type, default)
prompt = _build_prompt(text, prompt_suffix, show_default, default, show_choices, type)
prompt = _build_prompt(
text, prompt_suffix, show_default, default, show_choices, type
)
while 1:
while 1:
if confirmation_prompt:
if confirmation_prompt is True:
confirmation_prompt = _("Repeat for confirmation")
confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix)
while True:
while True:
value = prompt_func(prompt)
if value:
break
elif default is not None:
if isinstance(value_proc, Path):
# validate Path default value(exists, dir_okay etc.)
value = default
break
return default
value = default
break
try:
result = value_proc(value)
except UsageError as e:
echo('Error: %s' % e.message, err=err)
if hide_input:
echo(_("Error: The value you entered was invalid."), err=err)
else:
echo(_("Error: {e.message}").format(e=e), err=err) # noqa: B306
continue
if not confirmation_prompt:
return result
while 1:
value2 = prompt_func('Repeat for confirmation: ')
if value2:
while True:
value2 = prompt_func(confirmation_prompt)
is_empty = not value and not value2
if value2 or is_empty:
break
if value == value2:
return result
echo('Error: the two entered values do not match', err=err)
echo(_("Error: The two entered values do not match."), err=err)
def confirm(text, default=False, abort=False, prompt_suffix=': ',
show_default=True, err=False):
def confirm(
text: str,
default: t.Optional[bool] = False,
abort: bool = False,
prompt_suffix: str = ": ",
show_default: bool = True,
err: bool = False,
) -> bool:
"""Prompts for confirmation (yes/no question).
If the user aborts the input by sending a interrupt signal this
function will catch it and raise a :exc:`Abort` exception.
.. versionadded:: 4.0
Added the `err` parameter.
:param text: the question to ask.
:param default: the default for the prompt.
:param default: The default value to use when no input is given. If
``None``, repeat until input is given.
:param abort: if this is set to `True` a negative answer aborts the
exception by raising :exc:`Abort`.
:param prompt_suffix: a suffix that should be added to the prompt.
:param show_default: shows or hides the default value in the prompt.
:param err: if set to true the file defaults to ``stderr`` instead of
``stdout``, the same as with echo.
.. versionchanged:: 8.0
Repeat until input is given if ``default`` is ``None``.
.. versionadded:: 4.0
Added the ``err`` parameter.
"""
prompt = _build_prompt(text, prompt_suffix, show_default,
default and 'Y/n' or 'y/N')
while 1:
prompt = _build_prompt(
text,
prompt_suffix,
show_default,
"y/n" if default is None else ("Y/n" if default else "y/N"),
)
while True:
try:
# Write the prompt separately so that we get nice
# coloring through colorama on Windows
echo(prompt, nl=False, err=err)
value = visible_prompt_func('').lower().strip()
echo(prompt.rstrip(" "), nl=False, err=err)
# Echo a space to stdout to work around an issue where
# readline causes backspace to clear the whole line.
value = visible_prompt_func(" ").lower().strip()
except (KeyboardInterrupt, EOFError):
raise Abort()
if value in ('y', 'yes'):
raise Abort() from None
if value in ("y", "yes"):
rv = True
elif value in ('n', 'no'):
elif value in ("n", "no"):
rv = False
elif value == '':
elif default is not None and value == "":
rv = default
else:
echo('Error: invalid input', err=err)
echo(_("Error: invalid input"), err=err)
continue
break
if abort and not rv:
@ -182,54 +250,10 @@ def confirm(text, default=False, abort=False, prompt_suffix=': ',
return rv
def get_terminal_size():
"""Returns the current size of the terminal as tuple in the form
``(width, height)`` in columns and rows.
"""
# If shutil has get_terminal_size() (Python 3.3 and later) use that
if sys.version_info >= (3, 3):
import shutil
shutil_get_terminal_size = getattr(shutil, 'get_terminal_size', None)
if shutil_get_terminal_size:
sz = shutil_get_terminal_size()
return sz.columns, sz.lines
# We provide a sensible default for get_winterm_size() when being invoked
# inside a subprocess. Without this, it would not provide a useful input.
if get_winterm_size is not None:
size = get_winterm_size()
if size == (0, 0):
return (79, 24)
else:
return size
def ioctl_gwinsz(fd):
try:
import fcntl
import termios
cr = struct.unpack(
'hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234'))
except Exception:
return
return cr
cr = ioctl_gwinsz(0) or ioctl_gwinsz(1) or ioctl_gwinsz(2)
if not cr:
try:
fd = os.open(os.ctermid(), os.O_RDONLY)
try:
cr = ioctl_gwinsz(fd)
finally:
os.close(fd)
except Exception:
pass
if not cr or not cr[0] or not cr[1]:
cr = (os.environ.get('LINES', 25),
os.environ.get('COLUMNS', DEFAULT_COLUMNS))
return int(cr[1]), int(cr[0])
def echo_via_pager(text_or_generator, color=None):
def echo_via_pager(
text_or_generator: t.Union[t.Iterable[str], t.Callable[[], t.Iterable[str]], str],
color: t.Optional[bool] = None,
) -> None:
"""This function takes a text and shows it via an environment specific
pager on stdout.
@ -244,25 +268,37 @@ def echo_via_pager(text_or_generator, color=None):
color = resolve_color_default(color)
if inspect.isgeneratorfunction(text_or_generator):
i = text_or_generator()
elif isinstance(text_or_generator, string_types):
i = t.cast(t.Callable[[], t.Iterable[str]], text_or_generator)()
elif isinstance(text_or_generator, str):
i = [text_or_generator]
else:
i = iter(text_or_generator)
i = iter(t.cast(t.Iterable[str], text_or_generator))
# convert every element of i to a text type if necessary
text_generator = (el if isinstance(el, string_types) else text_type(el)
for el in i)
text_generator = (el if isinstance(el, str) else str(el) for el in i)
from ._termui_impl import pager
return pager(itertools.chain(text_generator, "\n"), color)
def progressbar(iterable=None, length=None, label=None, show_eta=True,
show_percent=None, show_pos=False,
item_show_func=None, fill_char='#', empty_char='-',
bar_template='%(label)s [%(bar)s] %(info)s',
info_sep=' ', width=36, file=None, color=None):
def progressbar(
iterable: t.Optional[t.Iterable[V]] = None,
length: t.Optional[int] = None,
label: t.Optional[str] = None,
show_eta: bool = True,
show_percent: t.Optional[bool] = None,
show_pos: bool = False,
item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
fill_char: str = "#",
empty_char: str = "-",
bar_template: str = "%(label)s [%(bar)s] %(info)s",
info_sep: str = " ",
width: int = 36,
file: t.Optional[t.TextIO] = None,
color: t.Optional[bool] = None,
update_min_steps: int = 1,
) -> "ProgressBar[V]":
"""This function creates an iterable context manager that can be used
to iterate over something while showing a progress bar. It will
either iterate over the `iterable` or `length` items (that are counted
@ -272,11 +308,17 @@ def progressbar(iterable=None, length=None, label=None, show_eta=True,
will not be rendered if the file is not a terminal.
The context manager creates the progress bar. When the context
manager is entered the progress bar is already displayed. With every
manager is entered the progress bar is already created. With every
iteration over the progress bar, the iterable passed to the bar is
advanced and the bar is updated. When the context manager exits,
a newline is printed and the progress bar is finalized on screen.
Note: The progress bar is currently designed for use cases where the
total progress can be expected to take at least several seconds.
Because of this, the ProgressBar class object won't display
progress that is considered too fast, and progress where the time
between steps is less than a second.
No printing must happen or the progress bar will be unintentionally
destroyed.
@ -296,11 +338,19 @@ def progressbar(iterable=None, length=None, label=None, show_eta=True,
process_chunk(chunk)
bar.update(chunks.bytes)
.. versionadded:: 2.0
The ``update()`` method also takes an optional value specifying the
``current_item`` at the new position. This is useful when used
together with ``item_show_func`` to customize the output for each
manual step::
.. versionadded:: 4.0
Added the `color` parameter. Added a `update` method to the
progressbar object.
with click.progressbar(
length=total_size,
label='Unzipping archive',
item_show_func=lambda a: a.filename
) as bar:
for archive in zip_file:
archive.extract()
bar.update(archive.size, archive)
:param iterable: an iterable to iterate over. If not provided the length
is required.
@ -319,10 +369,10 @@ def progressbar(iterable=None, length=None, label=None, show_eta=True,
`False` if not.
:param show_pos: enables or disables the absolute position display. The
default is `False`.
:param item_show_func: a function called with the current item which
can return a string to show the current item
next to the progress bar. Note that the current
item can be `None`!
:param item_show_func: A function called with the current item which
can return a string to show next to the progress bar. If the
function returns ``None`` nothing is shown. The current item can
be ``None``, such as when entering and exiting the bar.
:param fill_char: the character to use to show the filled part of the
progress bar.
:param empty_char: the character to use to show the non-filled part of
@ -334,24 +384,57 @@ def progressbar(iterable=None, length=None, label=None, show_eta=True,
:param info_sep: the separator between multiple info items (eta etc.)
:param width: the width of the progress bar in characters, 0 means full
terminal width
:param file: the file to write to. If this is not a terminal then
only the label is printed.
:param file: The file to write to. If this is not a terminal then
only the label is printed.
:param color: controls if the terminal supports ANSI colors or not. The
default is autodetection. This is only needed if ANSI
codes are included anywhere in the progress bar output
which is not the case by default.
:param update_min_steps: Render only when this many updates have
completed. This allows tuning for very fast iterators.
.. versionchanged:: 8.0
Output is shown even if execution time is less than 0.5 seconds.
.. versionchanged:: 8.0
``item_show_func`` shows the current item, not the previous one.
.. versionchanged:: 8.0
Labels are echoed if the output is not a TTY. Reverts a change
in 7.0 that removed all output.
.. versionadded:: 8.0
Added the ``update_min_steps`` parameter.
.. versionchanged:: 4.0
Added the ``color`` parameter. Added the ``update`` method to
the object.
.. versionadded:: 2.0
"""
from ._termui_impl import ProgressBar
color = resolve_color_default(color)
return ProgressBar(iterable=iterable, length=length, show_eta=show_eta,
show_percent=show_percent, show_pos=show_pos,
item_show_func=item_show_func, fill_char=fill_char,
empty_char=empty_char, bar_template=bar_template,
info_sep=info_sep, file=file, label=label,
width=width, color=color)
return ProgressBar(
iterable=iterable,
length=length,
show_eta=show_eta,
show_percent=show_percent,
show_pos=show_pos,
item_show_func=item_show_func,
fill_char=fill_char,
empty_char=empty_char,
bar_template=bar_template,
info_sep=info_sep,
file=file,
label=label,
width=width,
color=color,
update_min_steps=update_min_steps,
)
def clear():
def clear() -> None:
"""Clears the terminal screen. This will have the effect of clearing
the whole visible space of the terminal and moving the cursor to the
top left. This does not do anything if not connected to a terminal.
@ -360,17 +443,39 @@ def clear():
"""
if not isatty(sys.stdout):
return
# If we're on Windows and we don't have colorama available, then we
# clear the screen by shelling out. Otherwise we can use an escape
# sequence.
if WIN:
os.system('cls')
os.system("cls")
else:
sys.stdout.write('\033[2J\033[1;1H')
sys.stdout.write("\033[2J\033[1;1H")
def style(text, fg=None, bg=None, bold=None, dim=None, underline=None,
blink=None, reverse=None, reset=True):
def _interpret_color(
color: t.Union[int, t.Tuple[int, int, int], str], offset: int = 0
) -> str:
if isinstance(color, int):
return f"{38 + offset};5;{color:d}"
if isinstance(color, (tuple, list)):
r, g, b = color
return f"{38 + offset};2;{r:d};{g:d};{b:d}"
return str(_ansi_colors[color] + offset)
def style(
text: t.Any,
fg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
bg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
bold: t.Optional[bool] = None,
dim: t.Optional[bool] = None,
underline: t.Optional[bool] = None,
overline: t.Optional[bool] = None,
italic: t.Optional[bool] = None,
blink: t.Optional[bool] = None,
reverse: t.Optional[bool] = None,
strikethrough: t.Optional[bool] = None,
reset: bool = True,
) -> str:
"""Styles a text with ANSI styles and returns the new string. By
default the styling is self contained which means that at the end
of the string a reset code is issued. This can be prevented by
@ -381,6 +486,7 @@ def style(text, fg=None, bg=None, bold=None, dim=None, underline=None,
click.echo(click.style('Hello World!', fg='green'))
click.echo(click.style('ATTENTION!', blink=True))
click.echo(click.style('Some things', reverse=True, fg='cyan'))
click.echo(click.style('More colors', fg=(255, 12, 128), bg=117))
Supported color names:
@ -402,10 +508,15 @@ def style(text, fg=None, bg=None, bold=None, dim=None, underline=None,
* ``bright_white``
* ``reset`` (reset the color code only)
.. versionadded:: 2.0
If the terminal supports it, color may also be specified as:
.. versionadded:: 7.0
Added support for bright colors.
- An integer in the interval [0, 255]. The terminal must support
8-bit/256-color mode.
- An RGB tuple of three integers in [0, 255]. The terminal must
support 24-bit/true-color mode.
See https://en.wikipedia.org/wiki/ANSI_color and
https://gist.github.com/XVilka/8346728 for more information.
:param text: the string to style with ansi codes.
:param fg: if provided this will become the foreground color.
@ -414,42 +525,73 @@ def style(text, fg=None, bg=None, bold=None, dim=None, underline=None,
:param dim: if provided this will enable or disable dim mode. This is
badly supported.
:param underline: if provided this will enable or disable underline.
:param overline: if provided this will enable or disable overline.
:param italic: if provided this will enable or disable italic.
:param blink: if provided this will enable or disable blinking.
:param reverse: if provided this will enable or disable inverse
rendering (foreground becomes background and the
other way round).
:param strikethrough: if provided this will enable or disable
striking through text.
:param reset: by default a reset-all code is added at the end of the
string which means that styles do not carry over. This
can be disabled to compose styles.
.. versionchanged:: 8.0
A non-string ``message`` is converted to a string.
.. versionchanged:: 8.0
Added support for 256 and RGB color codes.
.. versionchanged:: 8.0
Added the ``strikethrough``, ``italic``, and ``overline``
parameters.
.. versionchanged:: 7.0
Added support for bright colors.
.. versionadded:: 2.0
"""
if not isinstance(text, str):
text = str(text)
bits = []
if fg:
try:
bits.append('\033[%dm' % (_ansi_colors[fg]))
bits.append(f"\033[{_interpret_color(fg)}m")
except KeyError:
raise TypeError('Unknown color %r' % fg)
raise TypeError(f"Unknown color {fg!r}") from None
if bg:
try:
bits.append('\033[%dm' % (_ansi_colors[bg] + 10))
bits.append(f"\033[{_interpret_color(bg, 10)}m")
except KeyError:
raise TypeError('Unknown color %r' % bg)
raise TypeError(f"Unknown color {bg!r}") from None
if bold is not None:
bits.append('\033[%dm' % (1 if bold else 22))
bits.append(f"\033[{1 if bold else 22}m")
if dim is not None:
bits.append('\033[%dm' % (2 if dim else 22))
bits.append(f"\033[{2 if dim else 22}m")
if underline is not None:
bits.append('\033[%dm' % (4 if underline else 24))
bits.append(f"\033[{4 if underline else 24}m")
if overline is not None:
bits.append(f"\033[{53 if overline else 55}m")
if italic is not None:
bits.append(f"\033[{3 if italic else 23}m")
if blink is not None:
bits.append('\033[%dm' % (5 if blink else 25))
bits.append(f"\033[{5 if blink else 25}m")
if reverse is not None:
bits.append('\033[%dm' % (7 if reverse else 27))
bits.append(f"\033[{7 if reverse else 27}m")
if strikethrough is not None:
bits.append(f"\033[{9 if strikethrough else 29}m")
bits.append(text)
if reset:
bits.append(_ansi_reset_all)
return ''.join(bits)
return "".join(bits)
def unstyle(text):
def unstyle(text: str) -> str:
"""Removes ANSI styling information from a string. Usually it's not
necessary to use this function as Click's echo function will
automatically remove styling if necessary.
@ -461,7 +603,14 @@ def unstyle(text):
return strip_ansi(text)
def secho(message=None, file=None, nl=True, err=False, color=None, **styles):
def secho(
message: t.Optional[t.Any] = None,
file: t.Optional[t.IO[t.AnyStr]] = None,
nl: bool = True,
err: bool = False,
color: t.Optional[bool] = None,
**styles: t.Any,
) -> None:
"""This function combines :func:`echo` and :func:`style` into one
call. As such the following two calls are the same::
@ -471,15 +620,31 @@ def secho(message=None, file=None, nl=True, err=False, color=None, **styles):
All keyword arguments are forwarded to the underlying functions
depending on which one they go with.
Non-string types will be converted to :class:`str`. However,
:class:`bytes` are passed directly to :meth:`echo` without applying
style. If you want to style bytes that represent text, call
:meth:`bytes.decode` first.
.. versionchanged:: 8.0
A non-string ``message`` is converted to a string. Bytes are
passed through without style applied.
.. versionadded:: 2.0
"""
if message is not None:
if message is not None and not isinstance(message, (bytes, bytearray)):
message = style(message, **styles)
return echo(message, file=file, nl=nl, err=err, color=color)
def edit(text=None, editor=None, env=None, require_save=True,
extension='.txt', filename=None):
def edit(
text: t.Optional[t.AnyStr] = None,
editor: t.Optional[str] = None,
env: t.Optional[t.Mapping[str, str]] = None,
require_save: bool = True,
extension: str = ".txt",
filename: t.Optional[str] = None,
) -> t.Optional[t.AnyStr]:
r"""Edits the given text in the defined editor. If an editor is given
(should be the full path to the executable but the regular operating
system search path is used for finding the executable) it overrides
@ -508,14 +673,17 @@ def edit(text=None, editor=None, env=None, require_save=True,
file as an indirection in that case.
"""
from ._termui_impl import Editor
editor = Editor(editor=editor, env=env, require_save=require_save,
extension=extension)
ed = Editor(editor=editor, env=env, require_save=require_save, extension=extension)
if filename is None:
return editor.edit(text)
editor.edit_file(filename)
return ed.edit(text)
ed.edit_file(filename)
return None
def launch(url, wait=False, locate=False):
def launch(url: str, wait: bool = False, locate: bool = False) -> int:
"""This function launches the given URL (or filename) in the default
viewer application for this file type. If this is an executable, it
might launch the executable in a new session. The return value is
@ -530,7 +698,9 @@ def launch(url, wait=False, locate=False):
.. versionadded:: 2.0
:param url: URL or filename of the thing to launch.
:param wait: waits for the program to stop.
:param wait: Wait for the program to exit before returning. This
only works if the launched program blocks. In particular,
``xdg-open`` on Linux does not block.
:param locate: if this is set to `True` then instead of launching the
application associated with the URL it will attempt to
launch a file manager with the file located. This
@ -538,15 +708,16 @@ def launch(url, wait=False, locate=False):
the filesystem.
"""
from ._termui_impl import open_url
return open_url(url, wait=wait, locate=locate)
# If this is provided, getchar() calls into this instead. This is used
# for unittesting purposes.
_getchar = None
_getchar: t.Optional[t.Callable[[bool], str]] = None
def getchar(echo=False):
def getchar(echo: bool = False) -> str:
"""Fetches a single character from the terminal and returns it. This
will always return a unicode character and under certain rare
circumstances this might return more than one character. The
@ -566,18 +737,23 @@ def getchar(echo=False):
:param echo: if set to `True`, the character read will also show up on
the terminal. The default is to not show it.
"""
f = _getchar
if f is None:
global _getchar
if _getchar is None:
from ._termui_impl import getchar as f
return f(echo)
_getchar = f
return _getchar(echo)
def raw_terminal():
def raw_terminal() -> t.ContextManager[int]:
from ._termui_impl import raw_terminal as f
return f()
def pause(info='Press any key to continue ...', err=False):
def pause(info: t.Optional[str] = None, err: bool = False) -> None:
"""This command stops execution and waits for the user to press any
key to continue. This is similar to the Windows batch "pause"
command. If the program is not run through a terminal, this command
@ -588,12 +764,17 @@ def pause(info='Press any key to continue ...', err=False):
.. versionadded:: 4.0
Added the `err` parameter.
:param info: the info string to print before pausing.
:param info: The message to print before pausing. Defaults to
``"Press any key to continue..."``.
:param err: if set to message goes to ``stderr`` instead of
``stdout``, the same as with echo.
"""
if not isatty(sys.stdin) or not isatty(sys.stdout):
return
if info is None:
info = _("Press any key to continue...")
try:
if info:
echo(info, nl=False, err=err)

View file

@ -1,86 +1,128 @@
import os
import sys
import shutil
import tempfile
import contextlib
import io
import os
import shlex
import shutil
import sys
import tempfile
import typing as t
from types import TracebackType
from ._compat import iteritems, PY2, string_types
from . import formatting
from . import termui
from . import utils
from ._compat import _find_binary_reader
if t.TYPE_CHECKING:
from .core import BaseCommand
# If someone wants to vendor click, we want to ensure the
# correct package is discovered. Ideally we could use a
# relative import here but unfortunately Python does not
# support that.
clickpkg = sys.modules[__name__.rsplit('.', 1)[0]]
if PY2:
from cStringIO import StringIO
else:
import io
from ._compat import _find_binary_reader
class EchoingStdin(object):
def __init__(self, input, output):
class EchoingStdin:
def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None:
self._input = input
self._output = output
self._paused = False
def __getattr__(self, x):
def __getattr__(self, x: str) -> t.Any:
return getattr(self._input, x)
def _echo(self, rv):
self._output.write(rv)
def _echo(self, rv: bytes) -> bytes:
if not self._paused:
self._output.write(rv)
return rv
def read(self, n=-1):
def read(self, n: int = -1) -> bytes:
return self._echo(self._input.read(n))
def readline(self, n=-1):
def read1(self, n: int = -1) -> bytes:
return self._echo(self._input.read1(n)) # type: ignore
def readline(self, n: int = -1) -> bytes:
return self._echo(self._input.readline(n))
def readlines(self):
def readlines(self) -> t.List[bytes]:
return [self._echo(x) for x in self._input.readlines()]
def __iter__(self):
def __iter__(self) -> t.Iterator[bytes]:
return iter(self._echo(x) for x in self._input)
def __repr__(self):
def __repr__(self) -> str:
return repr(self._input)
def make_input_stream(input, charset):
@contextlib.contextmanager
def _pause_echo(stream: t.Optional[EchoingStdin]) -> t.Iterator[None]:
if stream is None:
yield
else:
stream._paused = True
yield
stream._paused = False
class _NamedTextIOWrapper(io.TextIOWrapper):
def __init__(
self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any
) -> None:
super().__init__(buffer, **kwargs)
self._name = name
self._mode = mode
@property
def name(self) -> str:
return self._name
@property
def mode(self) -> str:
return self._mode
def make_input_stream(
input: t.Optional[t.Union[str, bytes, t.IO]], charset: str
) -> t.BinaryIO:
# Is already an input stream.
if hasattr(input, 'read'):
if PY2:
return input
rv = _find_binary_reader(input)
if hasattr(input, "read"):
rv = _find_binary_reader(t.cast(t.IO, input))
if rv is not None:
return rv
raise TypeError('Could not find binary reader for input stream.')
raise TypeError("Could not find binary reader for input stream.")
if input is None:
input = b''
elif not isinstance(input, bytes):
input = b""
elif isinstance(input, str):
input = input.encode(charset)
if PY2:
return StringIO(input)
return io.BytesIO(input)
return io.BytesIO(t.cast(bytes, input))
class Result(object):
class Result:
"""Holds the captured result of an invoked CLI script."""
def __init__(self, runner, stdout_bytes, stderr_bytes, exit_code,
exception, exc_info=None):
def __init__(
self,
runner: "CliRunner",
stdout_bytes: bytes,
stderr_bytes: t.Optional[bytes],
return_value: t.Any,
exit_code: int,
exception: t.Optional[BaseException],
exc_info: t.Optional[
t.Tuple[t.Type[BaseException], BaseException, TracebackType]
] = None,
):
#: The runner that created the result
self.runner = runner
#: The standard output as bytes.
self.stdout_bytes = stdout_bytes
#: The standard error as bytes, or False(y) if not available
#: The standard error as bytes, or None if not available
self.stderr_bytes = stderr_bytes
#: The value returned from the invoked command.
#:
#: .. versionadded:: 8.0
self.return_value = return_value
#: The exit code as integer.
self.exit_code = exit_code
#: The exception that happened if one did.
@ -89,41 +131,38 @@ class Result(object):
self.exc_info = exc_info
@property
def output(self):
def output(self) -> str:
"""The (standard) output as unicode string."""
return self.stdout
@property
def stdout(self):
def stdout(self) -> str:
"""The standard output as unicode string."""
return self.stdout_bytes.decode(self.runner.charset, 'replace') \
.replace('\r\n', '\n')
@property
def stderr(self):
"""The standard error as unicode string."""
if not self.stderr_bytes:
raise ValueError("stderr not separately captured")
return self.stderr_bytes.decode(self.runner.charset, 'replace') \
.replace('\r\n', '\n')
def __repr__(self):
return '<%s %s>' % (
type(self).__name__,
self.exception and repr(self.exception) or 'okay',
return self.stdout_bytes.decode(self.runner.charset, "replace").replace(
"\r\n", "\n"
)
@property
def stderr(self) -> str:
"""The standard error as unicode string."""
if self.stderr_bytes is None:
raise ValueError("stderr not separately captured")
return self.stderr_bytes.decode(self.runner.charset, "replace").replace(
"\r\n", "\n"
)
class CliRunner(object):
def __repr__(self) -> str:
exc_str = repr(self.exception) if self.exception else "okay"
return f"<{type(self).__name__} {exc_str}>"
class CliRunner:
"""The CLI runner provides functionality to invoke a Click command line
script for unittesting purposes in a isolated environment. This only
works in single-threaded systems without any concurrency as it changes the
global interpreter state.
:param charset: the character set for the input and output data. This is
UTF-8 by default and should not be changed currently as
the reporting to Click only works in Python 2 properly.
:param charset: the character set for the input and output data.
:param env: a dictionary with environment variables for overriding.
:param echo_stdin: if this is set to `True`, then reading from stdin writes
to stdout. This is useful for showing examples in
@ -136,23 +175,28 @@ class CliRunner(object):
independently
"""
def __init__(self, charset=None, env=None, echo_stdin=False,
mix_stderr=True):
if charset is None:
charset = 'utf-8'
def __init__(
self,
charset: str = "utf-8",
env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
echo_stdin: bool = False,
mix_stderr: bool = True,
) -> None:
self.charset = charset
self.env = env or {}
self.echo_stdin = echo_stdin
self.mix_stderr = mix_stderr
def get_default_prog_name(self, cli):
def get_default_prog_name(self, cli: "BaseCommand") -> str:
"""Given a command object it will return the default program name
for it. The default is the `name` attribute or ``"root"`` if not
set.
"""
return cli.name or 'root'
return cli.name or "root"
def make_env(self, overrides=None):
def make_env(
self, overrides: t.Optional[t.Mapping[str, t.Optional[str]]] = None
) -> t.Mapping[str, t.Optional[str]]:
"""Returns the environment overrides for invoking a script."""
rv = dict(self.env)
if overrides:
@ -160,7 +204,12 @@ class CliRunner(object):
return rv
@contextlib.contextmanager
def isolation(self, input=None, env=None, color=False):
def isolation(
self,
input: t.Optional[t.Union[str, bytes, t.IO]] = None,
env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
color: bool = False,
) -> t.Iterator[t.Tuple[io.BytesIO, t.Optional[io.BytesIO]]]:
"""A context manager that sets up the isolation for invoking of a
command line tool. This sets up stdin with the given input data
and `os.environ` with the overrides from the given dictionary.
@ -169,87 +218,107 @@ class CliRunner(object):
This is automatically done in the :meth:`invoke` method.
.. versionadded:: 4.0
The ``color`` parameter was added.
:param input: the input stream to put into sys.stdin.
:param env: the environment overrides as dictionary.
:param color: whether the output should contain color codes. The
application can still override this explicitly.
.. versionchanged:: 8.0
``stderr`` is opened with ``errors="backslashreplace"``
instead of the default ``"strict"``.
.. versionchanged:: 4.0
Added the ``color`` parameter.
"""
input = make_input_stream(input, self.charset)
bytes_input = make_input_stream(input, self.charset)
echo_input = None
old_stdin = sys.stdin
old_stdout = sys.stdout
old_stderr = sys.stderr
old_forced_width = clickpkg.formatting.FORCED_WIDTH
clickpkg.formatting.FORCED_WIDTH = 80
old_forced_width = formatting.FORCED_WIDTH
formatting.FORCED_WIDTH = 80
env = self.make_env(env)
if PY2:
bytes_output = StringIO()
if self.echo_stdin:
input = EchoingStdin(input, bytes_output)
sys.stdout = bytes_output
if not self.mix_stderr:
bytes_error = StringIO()
sys.stderr = bytes_error
else:
bytes_output = io.BytesIO()
if self.echo_stdin:
input = EchoingStdin(input, bytes_output)
input = io.TextIOWrapper(input, encoding=self.charset)
sys.stdout = io.TextIOWrapper(
bytes_output, encoding=self.charset)
if not self.mix_stderr:
bytes_error = io.BytesIO()
sys.stderr = io.TextIOWrapper(
bytes_error, encoding=self.charset)
bytes_output = io.BytesIO()
if self.echo_stdin:
bytes_input = echo_input = t.cast(
t.BinaryIO, EchoingStdin(bytes_input, bytes_output)
)
sys.stdin = text_input = _NamedTextIOWrapper(
bytes_input, encoding=self.charset, name="<stdin>", mode="r"
)
if self.echo_stdin:
# Force unbuffered reads, otherwise TextIOWrapper reads a
# large chunk which is echoed early.
text_input._CHUNK_SIZE = 1 # type: ignore
sys.stdout = _NamedTextIOWrapper(
bytes_output, encoding=self.charset, name="<stdout>", mode="w"
)
bytes_error = None
if self.mix_stderr:
sys.stderr = sys.stdout
else:
bytes_error = io.BytesIO()
sys.stderr = _NamedTextIOWrapper(
bytes_error,
encoding=self.charset,
name="<stderr>",
mode="w",
errors="backslashreplace",
)
sys.stdin = input
def visible_input(prompt=None):
sys.stdout.write(prompt or '')
val = input.readline().rstrip('\r\n')
sys.stdout.write(val + '\n')
@_pause_echo(echo_input) # type: ignore
def visible_input(prompt: t.Optional[str] = None) -> str:
sys.stdout.write(prompt or "")
val = text_input.readline().rstrip("\r\n")
sys.stdout.write(f"{val}\n")
sys.stdout.flush()
return val
def hidden_input(prompt=None):
sys.stdout.write((prompt or '') + '\n')
@_pause_echo(echo_input) # type: ignore
def hidden_input(prompt: t.Optional[str] = None) -> str:
sys.stdout.write(f"{prompt or ''}\n")
sys.stdout.flush()
return input.readline().rstrip('\r\n')
return text_input.readline().rstrip("\r\n")
def _getchar(echo):
@_pause_echo(echo_input) # type: ignore
def _getchar(echo: bool) -> str:
char = sys.stdin.read(1)
if echo:
sys.stdout.write(char)
sys.stdout.flush()
sys.stdout.flush()
return char
default_color = color
def should_strip_ansi(stream=None, color=None):
def should_strip_ansi(
stream: t.Optional[t.IO] = None, color: t.Optional[bool] = None
) -> bool:
if color is None:
return not default_color
return not color
old_visible_prompt_func = clickpkg.termui.visible_prompt_func
old_hidden_prompt_func = clickpkg.termui.hidden_prompt_func
old__getchar_func = clickpkg.termui._getchar
old_should_strip_ansi = clickpkg.utils.should_strip_ansi
clickpkg.termui.visible_prompt_func = visible_input
clickpkg.termui.hidden_prompt_func = hidden_input
clickpkg.termui._getchar = _getchar
clickpkg.utils.should_strip_ansi = should_strip_ansi
old_visible_prompt_func = termui.visible_prompt_func
old_hidden_prompt_func = termui.hidden_prompt_func
old__getchar_func = termui._getchar
old_should_strip_ansi = utils.should_strip_ansi # type: ignore
termui.visible_prompt_func = visible_input
termui.hidden_prompt_func = hidden_input
termui._getchar = _getchar
utils.should_strip_ansi = should_strip_ansi # type: ignore
old_env = {}
try:
for key, value in iteritems(env):
for key, value in env.items():
old_env[key] = os.environ.get(key)
if value is None:
try:
@ -258,9 +327,9 @@ class CliRunner(object):
pass
else:
os.environ[key] = value
yield (bytes_output, not self.mix_stderr and bytes_error)
yield (bytes_output, bytes_error)
finally:
for key, value in iteritems(old_env):
for key, value in old_env.items():
if value is None:
try:
del os.environ[key]
@ -271,14 +340,22 @@ class CliRunner(object):
sys.stdout = old_stdout
sys.stderr = old_stderr
sys.stdin = old_stdin
clickpkg.termui.visible_prompt_func = old_visible_prompt_func
clickpkg.termui.hidden_prompt_func = old_hidden_prompt_func
clickpkg.termui._getchar = old__getchar_func
clickpkg.utils.should_strip_ansi = old_should_strip_ansi
clickpkg.formatting.FORCED_WIDTH = old_forced_width
termui.visible_prompt_func = old_visible_prompt_func
termui.hidden_prompt_func = old_hidden_prompt_func
termui._getchar = old__getchar_func
utils.should_strip_ansi = old_should_strip_ansi # type: ignore
formatting.FORCED_WIDTH = old_forced_width
def invoke(self, cli, args=None, input=None, env=None,
catch_exceptions=True, color=False, mix_stderr=False, **extra):
def invoke(
self,
cli: "BaseCommand",
args: t.Optional[t.Union[str, t.Sequence[str]]] = None,
input: t.Optional[t.Union[str, bytes, t.IO]] = None,
env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
catch_exceptions: bool = True,
color: bool = False,
**extra: t.Any,
) -> Result:
"""Invokes a command in an isolated environment. The arguments are
forwarded directly to the command line script, the `extra` keyword
arguments are passed to the :meth:`~clickpkg.Command.main` function of
@ -286,16 +363,6 @@ class CliRunner(object):
This returns a :class:`Result` object.
.. versionadded:: 3.0
The ``catch_exceptions`` parameter was added.
.. versionchanged:: 3.0
The result object now has an `exc_info` attribute with the
traceback if available.
.. versionadded:: 4.0
The ``color`` parameter was added.
:param cli: the command to invoke
:param args: the arguments to invoke. It may be given as an iterable
or a string. When given as string it will be interpreted
@ -308,13 +375,28 @@ class CliRunner(object):
:param extra: the keyword arguments to pass to :meth:`main`.
:param color: whether the output should contain color codes. The
application can still override this explicitly.
.. versionchanged:: 8.0
The result object has the ``return_value`` attribute with
the value returned from the invoked command.
.. versionchanged:: 4.0
Added the ``color`` parameter.
.. versionchanged:: 3.0
Added the ``catch_exceptions`` parameter.
.. versionchanged:: 3.0
The result object has the ``exc_info`` attribute with the
traceback if available.
"""
exc_info = None
with self.isolation(input=input, env=env, color=color) as outstreams:
exception = None
return_value = None
exception: t.Optional[BaseException] = None
exit_code = 0
if isinstance(args, string_types):
if isinstance(args, str):
args = shlex.split(args)
try:
@ -323,20 +405,23 @@ class CliRunner(object):
prog_name = self.get_default_prog_name(cli)
try:
cli.main(args=args or (), prog_name=prog_name, **extra)
return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
except SystemExit as e:
exc_info = sys.exc_info()
exit_code = e.code
if exit_code is None:
exit_code = 0
e_code = t.cast(t.Optional[t.Union[int, t.Any]], e.code)
if exit_code != 0:
if e_code is None:
e_code = 0
if e_code != 0:
exception = e
if not isinstance(exit_code, int):
sys.stdout.write(str(exit_code))
sys.stdout.write('\n')
exit_code = 1
if not isinstance(e_code, int):
sys.stdout.write(str(e_code))
sys.stdout.write("\n")
e_code = 1
exit_code = e_code
except Exception as e:
if not catch_exceptions:
@ -347,28 +432,48 @@ class CliRunner(object):
finally:
sys.stdout.flush()
stdout = outstreams[0].getvalue()
stderr = outstreams[1] and outstreams[1].getvalue()
if self.mix_stderr:
stderr = None
else:
stderr = outstreams[1].getvalue() # type: ignore
return Result(runner=self,
stdout_bytes=stdout,
stderr_bytes=stderr,
exit_code=exit_code,
exception=exception,
exc_info=exc_info)
return Result(
runner=self,
stdout_bytes=stdout,
stderr_bytes=stderr,
return_value=return_value,
exit_code=exit_code,
exception=exception,
exc_info=exc_info, # type: ignore
)
@contextlib.contextmanager
def isolated_filesystem(self):
"""A context manager that creates a temporary folder and changes
the current working directory to it for isolated filesystem tests.
def isolated_filesystem(
self, temp_dir: t.Optional[t.Union[str, os.PathLike]] = None
) -> t.Iterator[str]:
"""A context manager that creates a temporary directory and
changes the current working directory to it. This isolates tests
that affect the contents of the CWD to prevent them from
interfering with each other.
:param temp_dir: Create the temporary directory under this
directory. If given, the created directory is not removed
when exiting.
.. versionchanged:: 8.0
Added the ``temp_dir`` parameter.
"""
cwd = os.getcwd()
t = tempfile.mkdtemp()
os.chdir(t)
dt = tempfile.mkdtemp(dir=temp_dir) # type: ignore[type-var]
os.chdir(dt)
try:
yield t
yield t.cast(str, dt)
finally:
os.chdir(cwd)
try:
shutil.rmtree(t)
except (OSError, IOError):
pass
if temp_dir is None:
try:
shutil.rmtree(dt)
except OSError: # noqa: B014
pass

File diff suppressed because it is too large Load diff

View file

@ -1,92 +1,131 @@
import os
import re
import sys
import typing as t
from functools import update_wrapper
from types import ModuleType
from ._compat import _default_text_stderr
from ._compat import _default_text_stdout
from ._compat import _find_binary_writer
from ._compat import auto_wrap_for_ansi
from ._compat import binary_streams
from ._compat import get_filesystem_encoding
from ._compat import open_stream
from ._compat import should_strip_ansi
from ._compat import strip_ansi
from ._compat import text_streams
from ._compat import WIN
from .globals import resolve_color_default
from ._compat import text_type, open_stream, get_filesystem_encoding, \
get_streerror, string_types, PY2, binary_streams, text_streams, \
filename_to_ui, auto_wrap_for_ansi, strip_ansi, should_strip_ansi, \
_default_text_stdout, _default_text_stderr, is_bytes, WIN
if t.TYPE_CHECKING:
import typing_extensions as te
if not PY2:
from ._compat import _find_binary_writer
elif WIN:
from ._winconsole import _get_windows_argv, \
_hash_py_argv, _initial_argv_hash
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
echo_native_types = string_types + (bytes, bytearray)
def _posixify(name: str) -> str:
return "-".join(name.split()).lower()
def _posixify(name):
return '-'.join(name.split()).lower()
def safecall(func):
def safecall(func: F) -> F:
"""Wraps a function so that it swallows exceptions."""
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs): # type: ignore
try:
return func(*args, **kwargs)
except Exception:
pass
return wrapper
return update_wrapper(t.cast(F, wrapper), func)
def make_str(value):
def make_str(value: t.Any) -> str:
"""Converts a value into a valid string."""
if isinstance(value, bytes):
try:
return value.decode(get_filesystem_encoding())
except UnicodeError:
return value.decode('utf-8', 'replace')
return text_type(value)
return value.decode("utf-8", "replace")
return str(value)
def make_default_short_help(help, max_length=45):
"""Return a condensed version of help string."""
def make_default_short_help(help: str, max_length: int = 45) -> str:
"""Returns a condensed version of help string."""
# Consider only the first paragraph.
paragraph_end = help.find("\n\n")
if paragraph_end != -1:
help = help[:paragraph_end]
# Collapse newlines, tabs, and spaces.
words = help.split()
if not words:
return ""
# The first paragraph started with a "no rewrap" marker, ignore it.
if words[0] == "\b":
words = words[1:]
total_length = 0
result = []
done = False
last_index = len(words) - 1
for word in words:
if word[-1:] == '.':
done = True
new_length = result and 1 + len(word) or len(word)
if total_length + new_length > max_length:
result.append('...')
done = True
else:
if result:
result.append(' ')
result.append(word)
if done:
for i, word in enumerate(words):
total_length += len(word) + (i > 0)
if total_length > max_length: # too long, truncate
break
total_length += new_length
return ''.join(result)
if word[-1] == ".": # sentence end, truncate without "..."
return " ".join(words[: i + 1])
if total_length == max_length and i != last_index:
break # not at sentence end, truncate with "..."
else:
return " ".join(words) # no truncation needed
# Account for the length of the suffix.
total_length += len("...")
# remove words until the length is short enough
while i > 0:
total_length -= len(words[i]) + (i > 0)
if total_length <= max_length:
break
i -= 1
return " ".join(words[:i]) + "..."
class LazyFile(object):
class LazyFile:
"""A lazy file works like a regular file but it does not fully open
the file but it does perform some basic checks early to see if the
filename parameter does make sense. This is useful for safely opening
files for writing.
"""
def __init__(self, filename, mode='r', encoding=None, errors='strict',
atomic=False):
def __init__(
self,
filename: str,
mode: str = "r",
encoding: t.Optional[str] = None,
errors: t.Optional[str] = "strict",
atomic: bool = False,
):
self.name = filename
self.mode = mode
self.encoding = encoding
self.errors = errors
self.atomic = atomic
self._f: t.Optional[t.IO]
if filename == '-':
self._f, self.should_close = open_stream(filename, mode,
encoding, errors)
if filename == "-":
self._f, self.should_close = open_stream(filename, mode, encoding, errors)
else:
if 'r' in mode:
if "r" in mode:
# Open and close the file in case we're opening it for
# reading so that we can catch at least some errors in
# some cases early.
@ -94,15 +133,15 @@ class LazyFile(object):
self._f = None
self.should_close = True
def __getattr__(self, name):
def __getattr__(self, name: str) -> t.Any:
return getattr(self.open(), name)
def __repr__(self):
def __repr__(self) -> str:
if self._f is not None:
return repr(self._f)
return '<unopened file %r %s>' % (self.name, self.mode)
return f"<unopened file '{self.name}' {self.mode}>"
def open(self):
def open(self) -> t.IO:
"""Opens the file if it's not yet open. This call might fail with
a :exc:`FileError`. Not handling this error will produce an error
that Click shows.
@ -110,106 +149,103 @@ class LazyFile(object):
if self._f is not None:
return self._f
try:
rv, self.should_close = open_stream(self.name, self.mode,
self.encoding,
self.errors,
atomic=self.atomic)
except (IOError, OSError) as e:
rv, self.should_close = open_stream(
self.name, self.mode, self.encoding, self.errors, atomic=self.atomic
)
except OSError as e: # noqa: E402
from .exceptions import FileError
raise FileError(self.name, hint=get_streerror(e))
raise FileError(self.name, hint=e.strerror) from e
self._f = rv
return rv
def close(self):
def close(self) -> None:
"""Closes the underlying file, no matter what."""
if self._f is not None:
self._f.close()
def close_intelligently(self):
def close_intelligently(self) -> None:
"""This function only closes the file if it was opened by the lazy
file wrapper. For instance this will never close stdin.
"""
if self.should_close:
self.close()
def __enter__(self):
def __enter__(self) -> "LazyFile":
return self
def __exit__(self, exc_type, exc_value, tb):
def __exit__(self, exc_type, exc_value, tb): # type: ignore
self.close_intelligently()
def __iter__(self):
def __iter__(self) -> t.Iterator[t.AnyStr]:
self.open()
return iter(self._f)
return iter(self._f) # type: ignore
class KeepOpenFile(object):
def __init__(self, file):
class KeepOpenFile:
def __init__(self, file: t.IO) -> None:
self._file = file
def __getattr__(self, name):
def __getattr__(self, name: str) -> t.Any:
return getattr(self._file, name)
def __enter__(self):
def __enter__(self) -> "KeepOpenFile":
return self
def __exit__(self, exc_type, exc_value, tb):
def __exit__(self, exc_type, exc_value, tb): # type: ignore
pass
def __repr__(self):
def __repr__(self) -> str:
return repr(self._file)
def __iter__(self):
def __iter__(self) -> t.Iterator[t.AnyStr]:
return iter(self._file)
def echo(message=None, file=None, nl=True, err=False, color=None):
"""Prints a message plus a newline to the given file or stdout. On
first sight, this looks like the print function, but it has improved
support for handling Unicode and binary data that does not fail no
matter how badly configured the system is.
def echo(
message: t.Optional[t.Any] = None,
file: t.Optional[t.IO[t.Any]] = None,
nl: bool = True,
err: bool = False,
color: t.Optional[bool] = None,
) -> None:
"""Print a message and newline to stdout or a file. This should be
used instead of :func:`print` because it provides better support
for different data, files, and environments.
Primarily it means that you can print binary data as well as Unicode
data on both 2.x and 3.x to the given file in the most appropriate way
possible. This is a very carefree function in that it will try its
best to not fail. As of Click 6.0 this includes support for unicode
output on the Windows console.
Compared to :func:`print`, this does the following:
In addition to that, if `colorama`_ is installed, the echo function will
also support clever handling of ANSI codes. Essentially it will then
do the following:
- Ensures that the output encoding is not misconfigured on Linux.
- Supports Unicode in the Windows console.
- Supports writing to binary outputs, and supports writing bytes
to text outputs.
- Supports colors and styles on Windows.
- Removes ANSI color and style codes if the output does not look
like an interactive terminal.
- Always flushes the output.
- add transparent handling of ANSI color codes on Windows.
- hide ANSI codes automatically if the destination file is not a
terminal.
.. _colorama: https://pypi.org/project/colorama/
:param message: The string or bytes to output. Other objects are
converted to strings.
:param file: The file to write to. Defaults to ``stdout``.
:param err: Write to ``stderr`` instead of ``stdout``.
:param nl: Print a newline after the message. Enabled by default.
:param color: Force showing or hiding colors and other styles. By
default Click will remove color if the output does not look like
an interactive terminal.
.. versionchanged:: 6.0
As of Click 6.0 the echo function will properly support unicode
output on the windows console. Not that click does not modify
the interpreter in any way which means that `sys.stdout` or the
print statement or function will still not provide unicode support.
.. versionchanged:: 2.0
Starting with version 2.0 of Click, the echo function will work
with colorama if it's installed.
.. versionadded:: 3.0
The `err` parameter was added.
Support Unicode output on the Windows console. Click does not
modify ``sys.stdout``, so ``sys.stdout.write()`` and ``print()``
will still not support Unicode.
.. versionchanged:: 4.0
Added the `color` flag.
Added the ``color`` parameter.
:param message: the message to print
:param file: the file to write to (defaults to ``stdout``)
:param err: if set to true the file defaults to ``stderr`` instead of
``stdout``. This is faster and easier than calling
:func:`get_text_stderr` yourself.
:param nl: if set to `True` (the default) a newline is printed afterwards.
:param color: controls if the terminal supports ANSI colors or not. The
default is autodetection.
.. versionadded:: 3.0
Added the ``err`` parameter.
.. versionchanged:: 2.0
Support colors on Windows if colorama is installed.
"""
if file is None:
if err:
@ -218,70 +254,73 @@ def echo(message=None, file=None, nl=True, err=False, color=None):
file = _default_text_stdout()
# Convert non bytes/text into the native string type.
if message is not None and not isinstance(message, echo_native_types):
message = text_type(message)
if message is not None and not isinstance(message, (str, bytes, bytearray)):
out: t.Optional[t.Union[str, bytes]] = str(message)
else:
out = message
if nl:
message = message or u''
if isinstance(message, text_type):
message += u'\n'
out = out or ""
if isinstance(out, str):
out += "\n"
else:
message += b'\n'
out += b"\n"
# If there is a message, and we're in Python 3, and the value looks
# like bytes, we manually need to find the binary stream and write the
# message in there. This is done separately so that most stream
# types will work as you would expect. Eg: you can write to StringIO
# for other cases.
if message and not PY2 and is_bytes(message):
if not out:
file.flush()
return
# If there is a message and the value looks like bytes, we manually
# need to find the binary stream and write the message in there.
# This is done separately so that most stream types will work as you
# would expect. Eg: you can write to StringIO for other cases.
if isinstance(out, (bytes, bytearray)):
binary_file = _find_binary_writer(file)
if binary_file is not None:
file.flush()
binary_file.write(message)
binary_file.write(out)
binary_file.flush()
return
# ANSI-style support. If there is no message or we are dealing with
# bytes nothing is happening. If we are connected to a file we want
# to strip colors. If we are on windows we either wrap the stream
# to strip the color or we use the colorama support to translate the
# ansi codes to API calls.
if message and not is_bytes(message):
# ANSI style code support. For no message or bytes, nothing happens.
# When outputting to a file instead of a terminal, strip codes.
else:
color = resolve_color_default(color)
if should_strip_ansi(file, color):
message = strip_ansi(message)
out = strip_ansi(out)
elif WIN:
if auto_wrap_for_ansi is not None:
file = auto_wrap_for_ansi(file)
file = auto_wrap_for_ansi(file) # type: ignore
elif not color:
message = strip_ansi(message)
out = strip_ansi(out)
if message:
file.write(message)
file.write(out) # type: ignore
file.flush()
def get_binary_stream(name):
"""Returns a system stream for byte processing. This essentially
returns the stream from the sys module with the given name but it
solves some compatibility issues between different Python versions.
Primarily this function is necessary for getting binary streams on
Python 3.
def get_binary_stream(name: "te.Literal['stdin', 'stdout', 'stderr']") -> t.BinaryIO:
"""Returns a system stream for byte processing.
:param name: the name of the stream to open. Valid names are ``'stdin'``,
``'stdout'`` and ``'stderr'``
"""
opener = binary_streams.get(name)
if opener is None:
raise TypeError('Unknown standard stream %r' % name)
raise TypeError(f"Unknown standard stream '{name}'")
return opener()
def get_text_stream(name, encoding=None, errors='strict'):
def get_text_stream(
name: "te.Literal['stdin', 'stdout', 'stderr']",
encoding: t.Optional[str] = None,
errors: t.Optional[str] = "strict",
) -> t.TextIO:
"""Returns a system stream for text processing. This usually returns
a wrapped stream around a binary stream returned from
:func:`get_binary_stream` but it also can take shortcuts on Python 3
for already correctly configured streams.
:func:`get_binary_stream` but it also can take shortcuts for already
correctly configured streams.
:param name: the name of the stream to open. Valid names are ``'stdin'``,
``'stdout'`` and ``'stderr'``
@ -290,65 +329,60 @@ def get_text_stream(name, encoding=None, errors='strict'):
"""
opener = text_streams.get(name)
if opener is None:
raise TypeError('Unknown standard stream %r' % name)
raise TypeError(f"Unknown standard stream '{name}'")
return opener(encoding, errors)
def open_file(filename, mode='r', encoding=None, errors='strict',
lazy=False, atomic=False):
"""This is similar to how the :class:`File` works but for manual
usage. Files are opened non lazy by default. This can open regular
files as well as stdin/stdout if ``'-'`` is passed.
def open_file(
filename: str,
mode: str = "r",
encoding: t.Optional[str] = None,
errors: t.Optional[str] = "strict",
lazy: bool = False,
atomic: bool = False,
) -> t.IO:
"""Open a file, with extra behavior to handle ``'-'`` to indicate
a standard stream, lazy open on write, and atomic write. Similar to
the behavior of the :class:`~click.File` param type.
If stdin/stdout is returned the stream is wrapped so that the context
manager will not close the stream accidentally. This makes it possible
to always use the function like this without having to worry to
accidentally close a standard stream::
If ``'-'`` is given to open ``stdout`` or ``stdin``, the stream is
wrapped so that using it in a context manager will not close it.
This makes it possible to use the function without accidentally
closing a standard stream:
.. code-block:: python
with open_file(filename) as f:
...
.. versionadded:: 3.0
:param filename: The name of the file to open, or ``'-'`` for
``stdin``/``stdout``.
:param mode: The mode in which to open the file.
:param encoding: The encoding to decode or encode a file opened in
text mode.
:param errors: The error handling mode.
:param lazy: Wait to open the file until it is accessed. For read
mode, the file is temporarily opened to raise access errors
early, then closed until it is read again.
:param atomic: Write to a temporary file and replace the given file
on close.
:param filename: the name of the file to open (or ``'-'`` for stdin/stdout).
:param mode: the mode in which to open the file.
:param encoding: the encoding to use.
:param errors: the error handling for this file.
:param lazy: can be flipped to true to open the file lazily.
:param atomic: in atomic mode writes go into a temporary file and it's
moved on close.
.. versionadded:: 3.0
"""
if lazy:
return LazyFile(filename, mode, encoding, errors, atomic=atomic)
f, should_close = open_stream(filename, mode, encoding, errors,
atomic=atomic)
return t.cast(t.IO, LazyFile(filename, mode, encoding, errors, atomic=atomic))
f, should_close = open_stream(filename, mode, encoding, errors, atomic=atomic)
if not should_close:
f = KeepOpenFile(f)
f = t.cast(t.IO, KeepOpenFile(f))
return f
def get_os_args():
"""This returns the argument part of sys.argv in the most appropriate
form for processing. What this means is that this return value is in
a format that works for Click to process but does not necessarily
correspond well to what's actually standard for the interpreter.
On most environments the return value is ``sys.argv[:1]`` unchanged.
However if you are on Windows and running Python 2 the return value
will actually be a list of unicode strings instead because the
default behavior on that platform otherwise will not be able to
carry all possible values that sys.argv can have.
.. versionadded:: 6.0
"""
# We can only extract the unicode argv if sys.argv has not been
# changed since the startup of the application.
if PY2 and WIN and _initial_argv_hash == _hash_py_argv():
return _get_windows_argv()
return sys.argv[1:]
def format_filename(filename, shorten=False):
def format_filename(
filename: t.Union[str, bytes, os.PathLike], shorten: bool = False
) -> str:
"""Formats a filename for user display. The main purpose of this
function is to ensure that the filename can be displayed at all. This
will decode the filename to unicode if necessary in a way that it will
@ -362,10 +396,11 @@ def format_filename(filename, shorten=False):
"""
if shorten:
filename = os.path.basename(filename)
return filename_to_ui(filename)
return os.fsdecode(filename)
def get_app_dir(app_name, roaming=True, force_posix=False):
def get_app_dir(app_name: str, roaming: bool = True, force_posix: bool = False) -> str:
r"""Returns the config folder for the application. The default behavior
is to return whatever is most appropriate for the operating system.
@ -380,13 +415,9 @@ def get_app_dir(app_name, roaming=True, force_posix=False):
``~/.config/foo-bar``
Unix (POSIX):
``~/.foo-bar``
Win XP (roaming):
``C:\Documents and Settings\<user>\Local Settings\Application Data\Foo Bar``
Win XP (not roaming):
``C:\Documents and Settings\<user>\Application Data\Foo Bar``
Win 7 (roaming):
Windows (roaming):
``C:\Users\<user>\AppData\Roaming\Foo Bar``
Win 7 (not roaming):
Windows (not roaming):
``C:\Users\<user>\AppData\Local\Foo Bar``
.. versionadded:: 2.0
@ -401,22 +432,24 @@ def get_app_dir(app_name, roaming=True, force_posix=False):
application support folder.
"""
if WIN:
key = roaming and 'APPDATA' or 'LOCALAPPDATA'
key = "APPDATA" if roaming else "LOCALAPPDATA"
folder = os.environ.get(key)
if folder is None:
folder = os.path.expanduser('~')
folder = os.path.expanduser("~")
return os.path.join(folder, app_name)
if force_posix:
return os.path.join(os.path.expanduser('~/.' + _posixify(app_name)))
if sys.platform == 'darwin':
return os.path.join(os.path.expanduser(
'~/Library/Application Support'), app_name)
return os.path.join(os.path.expanduser(f"~/.{_posixify(app_name)}"))
if sys.platform == "darwin":
return os.path.join(
os.path.expanduser("~/Library/Application Support"), app_name
)
return os.path.join(
os.environ.get('XDG_CONFIG_HOME', os.path.expanduser('~/.config')),
_posixify(app_name))
os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")),
_posixify(app_name),
)
class PacifyFlushWrapper(object):
class PacifyFlushWrapper:
"""This wrapper is used to catch and suppress BrokenPipeErrors resulting
from ``.flush()`` being called on broken pipe during the shutdown/final-GC
of the Python interpreter. Notably ``.flush()`` is always called on
@ -425,16 +458,123 @@ class PacifyFlushWrapper(object):
pipe, all calls and attributes are proxied.
"""
def __init__(self, wrapped):
def __init__(self, wrapped: t.IO) -> None:
self.wrapped = wrapped
def flush(self):
def flush(self) -> None:
try:
self.wrapped.flush()
except IOError as e:
except OSError as e:
import errno
if e.errno != errno.EPIPE:
raise
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> t.Any:
return getattr(self.wrapped, attr)
def _detect_program_name(
path: t.Optional[str] = None, _main: t.Optional[ModuleType] = None
) -> str:
"""Determine the command used to run the program, for use in help
text. If a file or entry point was executed, the file name is
returned. If ``python -m`` was used to execute a module or package,
``python -m name`` is returned.
This doesn't try to be too precise, the goal is to give a concise
name for help text. Files are only shown as their name without the
path. ``python`` is only shown for modules, and the full path to
``sys.executable`` is not shown.
:param path: The Python file being executed. Python puts this in
``sys.argv[0]``, which is used by default.
:param _main: The ``__main__`` module. This should only be passed
during internal testing.
.. versionadded:: 8.0
Based on command args detection in the Werkzeug reloader.
:meta private:
"""
if _main is None:
_main = sys.modules["__main__"]
if not path:
path = sys.argv[0]
# The value of __package__ indicates how Python was called. It may
# not exist if a setuptools script is installed as an egg. It may be
# set incorrectly for entry points created with pip on Windows.
if getattr(_main, "__package__", None) is None or (
os.name == "nt"
and _main.__package__ == ""
and not os.path.exists(path)
and os.path.exists(f"{path}.exe")
):
# Executed a file, like "python app.py".
return os.path.basename(path)
# Executed a module, like "python -m example".
# Rewritten by Python from "-m script" to "/path/to/script.py".
# Need to look at main module to determine how it was executed.
py_module = t.cast(str, _main.__package__)
name = os.path.splitext(os.path.basename(path))[0]
# A submodule like "example.cli".
if name != "__main__":
py_module = f"{py_module}.{name}"
return f"python -m {py_module.lstrip('.')}"
def _expand_args(
args: t.Iterable[str],
*,
user: bool = True,
env: bool = True,
glob_recursive: bool = True,
) -> t.List[str]:
"""Simulate Unix shell expansion with Python functions.
See :func:`glob.glob`, :func:`os.path.expanduser`, and
:func:`os.path.expandvars`.
This is intended for use on Windows, where the shell does not do any
expansion. It may not exactly match what a Unix shell would do.
:param args: List of command line arguments to expand.
:param user: Expand user home directory.
:param env: Expand environment variables.
:param glob_recursive: ``**`` matches directories recursively.
.. versionchanged:: 8.1
Invalid glob patterns are treated as empty expansions rather
than raising an error.
.. versionadded:: 8.0
:meta private:
"""
from glob import glob
out = []
for arg in args:
if user:
arg = os.path.expanduser(arg)
if env:
arg = os.path.expandvars(arg)
try:
matches = glob(arg, recursive=glob_recursive)
except re.error:
matches = []
if not matches:
out.append(arg)
else:
out.extend(matches)
return out

File diff suppressed because it is too large Load diff

View file

@ -1,2 +0,0 @@
"""Project version"""
__version__ = '5.1.0'

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
# ######################### LICENSE ############################ #
# Copyright (c) 2005-2018, Michele Simionato
# Copyright (c) 2005-2021, Michele Simionato
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
@ -28,49 +28,26 @@
# DAMAGE.
"""
Decorator module, see http://pypi.python.org/pypi/decorator
Decorator module, see
https://github.com/micheles/decorator/blob/master/docs/documentation.md
for the documentation.
"""
from __future__ import print_function
import re
import sys
import inspect
import operator
import itertools
import collections
__version__ = '4.3.0'
if sys.version >= '3':
from inspect import getfullargspec
def get_init(cls):
return cls.__init__
else:
FullArgSpec = collections.namedtuple(
'FullArgSpec', 'args varargs varkw defaults '
'kwonlyargs kwonlydefaults annotations')
def getfullargspec(f):
"A quick and dirty replacement for getfullargspec for Python 2.X"
return FullArgSpec._make(inspect.getargspec(f) + ([], None, {}))
def get_init(cls):
return cls.__init__.__func__
try:
iscoroutinefunction = inspect.iscoroutinefunction
except AttributeError:
# let's assume there are no coroutine functions in old Python
def iscoroutinefunction(f):
return False
from contextlib import _GeneratorContextManager
from inspect import getfullargspec, iscoroutinefunction, isgeneratorfunction
__version__ = '5.1.1'
DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(')
POS = inspect.Parameter.POSITIONAL_OR_KEYWORD
EMPTY = inspect.Parameter.empty
# basic functionality
# this is not used anymore in the core, but kept for backward compatibility
class FunctionMaker(object):
"""
An object with the ability to create functions with a given signature.
@ -94,7 +71,7 @@ class FunctionMaker(object):
self.name = '_lambda_'
self.doc = func.__doc__
self.module = func.__module__
if inspect.isfunction(func):
if inspect.isroutine(func):
argspec = getfullargspec(func)
self.annotations = getattr(func, '__annotations__', {})
for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
@ -137,7 +114,9 @@ class FunctionMaker(object):
raise TypeError('You are decorating a non function: %s' % func)
def update(self, func, **kw):
"Update the signature of func with the data in self"
"""
Update the signature of func with the data in self
"""
func.__name__ = self.name
func.__doc__ = getattr(self, 'doc', None)
func.__dict__ = getattr(self, 'dict', {})
@ -154,7 +133,9 @@ class FunctionMaker(object):
func.__dict__.update(kw)
def make(self, src_templ, evaldict=None, addsource=False, **attrs):
"Make a new function from a given template and update the signature"
"""
Make a new function from a given template and update the signature
"""
src = src_templ % vars(self) # expand name and signature
evaldict = evaldict or {}
mo = DEF.search(src)
@ -173,7 +154,7 @@ class FunctionMaker(object):
# Ensure each generated function has a unique filename for profilers
# (such as cProfile) that depend on the tuple of (<filename>,
# <definition line>, <function name>) being unique.
filename = '<decorator-gen-%d>' % (next(self._compile_count),)
filename = '<decorator-gen-%d>' % next(self._compile_count)
try:
code = compile(src, filename, 'single')
exec(code, evaldict)
@ -215,90 +196,128 @@ class FunctionMaker(object):
return self.make(body, evaldict, addsource, **attrs)
def decorate(func, caller, extras=()):
def fix(args, kwargs, sig):
"""
decorate(func, caller) decorates a function using a caller.
Fix args and kwargs to be consistent with the signature
"""
evaldict = dict(_call_=caller, _func_=func)
es = ''
for i, extra in enumerate(extras):
ex = '_e%d_' % i
evaldict[ex] = extra
es += ex + ', '
fun = FunctionMaker.create(
func, "return _call_(_func_, %s%%(shortsignature)s)" % es,
evaldict, __wrapped__=func)
if hasattr(func, '__qualname__'):
fun.__qualname__ = func.__qualname__
ba = sig.bind(*args, **kwargs)
ba.apply_defaults() # needed for test_dan_schult
return ba.args, ba.kwargs
def decorate(func, caller, extras=(), kwsyntax=False):
"""
Decorates a function/generator/coroutine using a caller.
If kwsyntax is True calling the decorated functions with keyword
syntax will pass the named arguments inside the ``kw`` dictionary,
even if such argument are positional, similarly to what functools.wraps
does. By default kwsyntax is False and the the arguments are untouched.
"""
sig = inspect.signature(func)
if iscoroutinefunction(caller):
async def fun(*args, **kw):
if not kwsyntax:
args, kw = fix(args, kw, sig)
return await caller(func, *(extras + args), **kw)
elif isgeneratorfunction(caller):
def fun(*args, **kw):
if not kwsyntax:
args, kw = fix(args, kw, sig)
for res in caller(func, *(extras + args), **kw):
yield res
else:
def fun(*args, **kw):
if not kwsyntax:
args, kw = fix(args, kw, sig)
return caller(func, *(extras + args), **kw)
fun.__name__ = func.__name__
fun.__doc__ = func.__doc__
fun.__wrapped__ = func
fun.__signature__ = sig
fun.__qualname__ = func.__qualname__
# builtin functions like defaultdict.__setitem__ lack many attributes
try:
fun.__defaults__ = func.__defaults__
except AttributeError:
pass
try:
fun.__kwdefaults__ = func.__kwdefaults__
except AttributeError:
pass
try:
fun.__annotations__ = func.__annotations__
except AttributeError:
pass
try:
fun.__module__ = func.__module__
except AttributeError:
pass
try:
fun.__dict__.update(func.__dict__)
except AttributeError:
pass
return fun
def decorator(caller, _func=None):
"""decorator(caller) converts a caller function into a decorator"""
def decoratorx(caller):
"""
A version of "decorator" implemented via "exec" and not via the
Signature object. Use this if you are want to preserve the `.__code__`
object properties (https://github.com/micheles/decorator/issues/129).
"""
def dec(func):
return FunctionMaker.create(
func,
"return _call_(_func_, %(shortsignature)s)",
dict(_call_=caller, _func_=func),
__wrapped__=func, __qualname__=func.__qualname__)
return dec
def decorator(caller, _func=None, kwsyntax=False):
"""
decorator(caller) converts a caller function into a decorator
"""
if _func is not None: # return a decorated function
# this is obsolete behavior; you should use decorate instead
return decorate(_func, caller)
return decorate(_func, caller, (), kwsyntax)
# else return a decorator function
defaultargs, defaults = '', ()
if inspect.isclass(caller):
name = caller.__name__.lower()
doc = 'decorator(%s) converts functions/generators into ' \
'factories of %s objects' % (caller.__name__, caller.__name__)
elif inspect.isfunction(caller):
if caller.__name__ == '<lambda>':
name = '_lambda_'
sig = inspect.signature(caller)
dec_params = [p for p in sig.parameters.values() if p.kind is POS]
def dec(func=None, *args, **kw):
na = len(args) + 1
extras = args + tuple(kw.get(p.name, p.default)
for p in dec_params[na:]
if p.default is not EMPTY)
if func is None:
return lambda func: decorate(func, caller, extras, kwsyntax)
else:
name = caller.__name__
doc = caller.__doc__
nargs = caller.__code__.co_argcount
ndefs = len(caller.__defaults__ or ())
defaultargs = ', '.join(caller.__code__.co_varnames[nargs-ndefs:nargs])
if defaultargs:
defaultargs += ','
defaults = caller.__defaults__
else: # assume caller is an object with a __call__ method
name = caller.__class__.__name__.lower()
doc = caller.__call__.__doc__
evaldict = dict(_call=caller, _decorate_=decorate)
dec = FunctionMaker.create(
'%s(%s func)' % (name, defaultargs),
'if func is None: return lambda func: _decorate_(func, _call, (%s))\n'
'return _decorate_(func, _call, (%s))' % (defaultargs, defaultargs),
evaldict, doc=doc, module=caller.__module__, __wrapped__=caller)
if defaults:
dec.__defaults__ = defaults + (None,)
return decorate(func, caller, extras, kwsyntax)
dec.__signature__ = sig.replace(parameters=dec_params)
dec.__name__ = caller.__name__
dec.__doc__ = caller.__doc__
dec.__wrapped__ = caller
dec.__qualname__ = caller.__qualname__
dec.__kwdefaults__ = getattr(caller, '__kwdefaults__', None)
dec.__dict__.update(caller.__dict__)
return dec
# ####################### contextmanager ####################### #
try: # Python >= 3.2
from contextlib import _GeneratorContextManager
except ImportError: # Python >= 2.5
from contextlib import GeneratorContextManager as _GeneratorContextManager
class ContextManager(_GeneratorContextManager):
def __init__(self, g, *a, **k):
_GeneratorContextManager.__init__(self, g, a, k)
def __call__(self, func):
"""Context manager decorator"""
return FunctionMaker.create(
func, "with _self_: return _func_(%(shortsignature)s)",
dict(_self_=self, _func_=func), __wrapped__=func)
def caller(f, *a, **k):
with self.__class__(self.func, *self.args, **self.kwds):
return f(*a, **k)
return decorate(func, caller)
init = getfullargspec(_GeneratorContextManager.__init__)
n_args = len(init.args)
if n_args == 2 and not init.varargs: # (self, genobj) Python 2.7
def __init__(self, g, *a, **k):
return _GeneratorContextManager.__init__(self, g(*a, **k))
ContextManager.__init__ = __init__
elif n_args == 2 and init.varargs: # (self, gen, *a, **k) Python 3.4
pass
elif n_args == 4: # (self, gen, args, kwds) Python 3.5
def __init__(self, g, *a, **k):
return _GeneratorContextManager.__init__(self, g, a, k)
ContextManager.__init__ = __init__
_contextmanager = decorator(ContextManager)

View file

@ -1,4 +1,4 @@
__version__ = '0.7.1'
__version__ = "1.1.8"
from .lock import Lock # noqa
from .lock import NeedRegenerationException # noqa

View file

@ -1,4 +1,6 @@
from .region import CacheRegion, register_backend, make_region # noqa
from .region import CacheRegion # noqa
from .region import make_region # noqa
from .region import register_backend # noqa
from .. import __version__ # noqa
# backwards compat
from .. import __version__ # noqa

View file

@ -1,14 +1,22 @@
import operator
from ..util.compat import py3k
import abc
import pickle
from typing import Any
from typing import Callable
from typing import cast
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Union
class NoValue(object):
class NoValue:
"""Describe a missing cache value.
The :attr:`.NO_VALUE` module global
should be used.
The :data:`.NO_VALUE` constant should be used.
"""
@property
def payload(self):
return self
@ -18,49 +26,125 @@ class NoValue(object):
fill another cache key.
"""
return '<dogpile.cache.api.NoValue object>'
return "<dogpile.cache.api.NoValue object>"
if py3k:
def __bool__(self): # pragma NO COVERAGE
return False
else:
def __nonzero__(self): # pragma NO COVERAGE
return False
def __bool__(self): # pragma NO COVERAGE
return False
NO_VALUE = NoValue()
"""Value returned from ``get()`` that describes
a key not present."""
MetaDataType = Mapping[str, Any]
class CachedValue(tuple):
KeyType = str
"""A cache key."""
ValuePayload = Any
"""An object to be placed in the cache against a key."""
KeyManglerType = Callable[[KeyType], KeyType]
Serializer = Callable[[ValuePayload], bytes]
Deserializer = Callable[[bytes], ValuePayload]
class CacheMutex(abc.ABC):
"""Describes a mutexing object with acquire and release methods.
This is an abstract base class; any object that has acquire/release
methods may be used.
.. versionadded:: 1.1
.. seealso::
:meth:`.CacheBackend.get_mutex` - the backend method that optionally
returns this locking object.
"""
@abc.abstractmethod
def acquire(self, wait: bool = True) -> bool:
"""Acquire the mutex.
:param wait: if True, block until available, else return True/False
immediately.
:return: True if the lock succeeded.
"""
raise NotImplementedError()
@abc.abstractmethod
def release(self) -> None:
"""Release the mutex."""
raise NotImplementedError()
@abc.abstractmethod
def locked(self) -> bool:
"""Check if the mutex was acquired.
:return: true if the lock is acquired.
.. versionadded:: 1.1.2
"""
raise NotImplementedError()
@classmethod
def __subclasshook__(cls, C):
return hasattr(C, "acquire") and hasattr(C, "release")
class CachedValue(NamedTuple):
"""Represent a value stored in the cache.
:class:`.CachedValue` is a two-tuple of
``(payload, metadata)``, where ``metadata``
is dogpile.cache's tracking information (
currently the creation time). The metadata
and tuple structure is pickleable, if
the backend requires serialization.
currently the creation time).
"""
payload = property(operator.itemgetter(0))
"""Named accessor for the payload."""
metadata = property(operator.itemgetter(1))
"""Named accessor for the dogpile.cache metadata dictionary."""
payload: ValuePayload
def __new__(cls, payload, metadata):
return tuple.__new__(cls, (payload, metadata))
def __reduce__(self):
return CachedValue, (self.payload, self.metadata)
metadata: MetaDataType
class CacheBackend(object):
"""Base class for backend implementations."""
CacheReturnType = Union[CachedValue, NoValue]
"""The non-serialized form of what may be returned from a backend
get method.
key_mangler = None
"""
SerializedReturnType = Union[bytes, NoValue]
"""the serialized form of what may be returned from a backend get method."""
BackendFormatted = Union[CacheReturnType, SerializedReturnType]
"""Describes the type returned from the :meth:`.CacheBackend.get` method."""
BackendSetType = Union[CachedValue, bytes]
"""Describes the value argument passed to the :meth:`.CacheBackend.set`
method."""
BackendArguments = Mapping[str, Any]
class CacheBackend:
"""Base class for backend implementations.
Backends which set and get Python object values should subclass this
backend. For backends in which the value that's stored is ultimately
a stream of bytes, the :class:`.BytesBackend` should be used.
"""
key_mangler: Optional[Callable[[KeyType], KeyType]] = None
"""Key mangling function.
May be None, or otherwise declared
@ -68,7 +152,23 @@ class CacheBackend(object):
"""
def __init__(self, arguments): # pragma NO COVERAGE
serializer: Union[None, Serializer] = None
"""Serializer function that will be used by default if not overridden
by the region.
.. versionadded:: 1.1
"""
deserializer: Union[None, Deserializer] = None
"""deserializer function that will be used by default if not overridden
by the region.
.. versionadded:: 1.1
"""
def __init__(self, arguments: BackendArguments): # pragma NO COVERAGE
"""Construct a new :class:`.CacheBackend`.
Subclasses should override this to
@ -91,10 +191,10 @@ class CacheBackend(object):
)
)
def has_lock_timeout(self):
def has_lock_timeout(self) -> bool:
return False
def get_mutex(self, key):
def get_mutex(self, key: KeyType) -> Optional[CacheMutex]:
"""Return an optional mutexing object for the given key.
This object need only provide an ``acquire()``
@ -127,48 +227,141 @@ class CacheBackend(object):
"""
return None
def get(self, key): # pragma NO COVERAGE
"""Retrieve a value from the cache.
def get(self, key: KeyType) -> BackendFormatted: # pragma NO COVERAGE
"""Retrieve an optionally serialized value from the cache.
The returned value should be an instance of
:class:`.CachedValue`, or ``NO_VALUE`` if
not present.
:param key: String key that was passed to the :meth:`.CacheRegion.get`
method, which will also be processed by the "key mangling" function
if one was present.
:return: the Python object that corresponds to
what was established via the :meth:`.CacheBackend.set` method,
or the :data:`.NO_VALUE` constant if not present.
If a serializer is in use, this method will only be called if the
:meth:`.CacheBackend.get_serialized` method is not overridden.
"""
raise NotImplementedError()
def get_multi(self, keys): # pragma NO COVERAGE
"""Retrieve multiple values from the cache.
def get_multi(
self, keys: Sequence[KeyType]
) -> Sequence[BackendFormatted]: # pragma NO COVERAGE
"""Retrieve multiple optionally serialized values from the cache.
The returned value should be a list, corresponding
to the list of keys given.
:param keys: sequence of string keys that was passed to the
:meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.
:return a list of values as would be returned
individually via the :meth:`.CacheBackend.get` method, corresponding
to the list of keys given.
If a serializer is in use, this method will only be called if the
:meth:`.CacheBackend.get_serialized_multi` method is not overridden.
.. versionadded:: 0.5.0
"""
raise NotImplementedError()
def set(self, key, value): # pragma NO COVERAGE
"""Set a value in the cache.
def get_serialized(self, key: KeyType) -> SerializedReturnType:
"""Retrieve a serialized value from the cache.
The key will be whatever was passed
to the registry, processed by the
"key mangling" function, if any.
The value will always be an instance
of :class:`.CachedValue`.
:param key: String key that was passed to the :meth:`.CacheRegion.get`
method, which will also be processed by the "key mangling" function
if one was present.
:return: a bytes object, or :data:`.NO_VALUE`
constant if not present.
The default implementation of this method for :class:`.CacheBackend`
returns the value of the :meth:`.CacheBackend.get` method.
.. versionadded:: 1.1
.. seealso::
:class:`.BytesBackend`
"""
return cast(SerializedReturnType, self.get(key))
def get_serialized_multi(
self, keys: Sequence[KeyType]
) -> Sequence[SerializedReturnType]: # pragma NO COVERAGE
"""Retrieve multiple serialized values from the cache.
:param keys: sequence of string keys that was passed to the
:meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.
:return: list of bytes objects
The default implementation of this method for :class:`.CacheBackend`
returns the value of the :meth:`.CacheBackend.get_multi` method.
.. versionadded:: 1.1
.. seealso::
:class:`.BytesBackend`
"""
return cast(Sequence[SerializedReturnType], self.get_multi(keys))
def set(
self, key: KeyType, value: BackendSetType
) -> None: # pragma NO COVERAGE
"""Set an optionally serialized value in the cache.
:param key: String key that was passed to the :meth:`.CacheRegion.set`
method, which will also be processed by the "key mangling" function
if one was present.
:param value: The optionally serialized :class:`.CachedValue` object.
May be an instance of :class:`.CachedValue` or a bytes object
depending on if a serializer is in use with the region and if the
:meth:`.CacheBackend.set_serialized` method is not overridden.
.. seealso::
:meth:`.CacheBackend.set_serialized`
"""
raise NotImplementedError()
def set_multi(self, mapping): # pragma NO COVERAGE
def set_serialized(
self, key: KeyType, value: bytes
) -> None: # pragma NO COVERAGE
"""Set a serialized value in the cache.
:param key: String key that was passed to the :meth:`.CacheRegion.set`
method, which will also be processed by the "key mangling" function
if one was present.
:param value: a bytes object to be stored.
The default implementation of this method for :class:`.CacheBackend`
calls upon the :meth:`.CacheBackend.set` method.
.. versionadded:: 1.1
.. seealso::
:class:`.BytesBackend`
"""
self.set(key, value)
def set_multi(
self, mapping: Mapping[KeyType, BackendSetType]
) -> None: # pragma NO COVERAGE
"""Set multiple values in the cache.
``mapping`` is a dict in which
the key will be whatever was passed
to the registry, processed by the
"key mangling" function, if any.
The value will always be an instance
of :class:`.CachedValue`.
:param mapping: a dict in which the key will be whatever was passed to
the :meth:`.CacheRegion.set_multi` method, processed by the "key
mangling" function, if any.
When implementing a new :class:`.CacheBackend` or cutomizing via
:class:`.ProxyBackend`, be aware that when this method is invoked by
@ -178,17 +371,52 @@ class CacheBackend(object):
-- that will have the undesirable effect of modifying the returned
values as well.
If a serializer is in use, this method will only be called if the
:meth:`.CacheBackend.set_serialized_multi` method is not overridden.
.. versionadded:: 0.5.0
"""
raise NotImplementedError()
def delete(self, key): # pragma NO COVERAGE
def set_serialized_multi(
self, mapping: Mapping[KeyType, bytes]
) -> None: # pragma NO COVERAGE
"""Set multiple serialized values in the cache.
:param mapping: a dict in which the key will be whatever was passed to
the :meth:`.CacheRegion.set_multi` method, processed by the "key
mangling" function, if any.
When implementing a new :class:`.CacheBackend` or cutomizing via
:class:`.ProxyBackend`, be aware that when this method is invoked by
:meth:`.Region.get_or_create_multi`, the ``mapping`` values are the
same ones returned to the upstream caller. If the subclass alters the
values in any way, it must not do so 'in-place' on the ``mapping`` dict
-- that will have the undesirable effect of modifying the returned
values as well.
.. versionadded:: 1.1
The default implementation of this method for :class:`.CacheBackend`
calls upon the :meth:`.CacheBackend.set_multi` method.
.. seealso::
:class:`.BytesBackend`
"""
self.set_multi(mapping)
def delete(self, key: KeyType) -> None: # pragma NO COVERAGE
"""Delete a value from the cache.
The key will be whatever was passed
to the registry, processed by the
"key mangling" function, if any.
:param key: String key that was passed to the
:meth:`.CacheRegion.delete`
method, which will also be processed by the "key mangling" function
if one was present.
The behavior here should be idempotent,
that is, can be called any number of times
@ -197,12 +425,14 @@ class CacheBackend(object):
"""
raise NotImplementedError()
def delete_multi(self, keys): # pragma NO COVERAGE
def delete_multi(
self, keys: Sequence[KeyType]
) -> None: # pragma NO COVERAGE
"""Delete multiple values from the cache.
The key will be whatever was passed
to the registry, processed by the
"key mangling" function, if any.
:param keys: sequence of string keys that was passed to the
:meth:`.CacheRegion.delete_multi` method, which will also be processed
by the "key mangling" function if one was present.
The behavior here should be idempotent,
that is, can be called any number of times
@ -213,3 +443,95 @@ class CacheBackend(object):
"""
raise NotImplementedError()
class DefaultSerialization:
serializer: Union[None, Serializer] = staticmethod( # type: ignore
pickle.dumps
)
deserializer: Union[None, Deserializer] = staticmethod( # type: ignore
pickle.loads
)
class BytesBackend(DefaultSerialization, CacheBackend):
"""A cache backend that receives and returns series of bytes.
This backend only supports the "serialized" form of values; subclasses
should implement :meth:`.BytesBackend.get_serialized`,
:meth:`.BytesBackend.get_serialized_multi`,
:meth:`.BytesBackend.set_serialized`,
:meth:`.BytesBackend.set_serialized_multi`.
.. versionadded:: 1.1
"""
def get_serialized(self, key: KeyType) -> SerializedReturnType:
"""Retrieve a serialized value from the cache.
:param key: String key that was passed to the :meth:`.CacheRegion.get`
method, which will also be processed by the "key mangling" function
if one was present.
:return: a bytes object, or :data:`.NO_VALUE`
constant if not present.
.. versionadded:: 1.1
"""
raise NotImplementedError()
def get_serialized_multi(
self, keys: Sequence[KeyType]
) -> Sequence[SerializedReturnType]: # pragma NO COVERAGE
"""Retrieve multiple serialized values from the cache.
:param keys: sequence of string keys that was passed to the
:meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.
:return: list of bytes objects
.. versionadded:: 1.1
"""
raise NotImplementedError()
def set_serialized(
self, key: KeyType, value: bytes
) -> None: # pragma NO COVERAGE
"""Set a serialized value in the cache.
:param key: String key that was passed to the :meth:`.CacheRegion.set`
method, which will also be processed by the "key mangling" function
if one was present.
:param value: a bytes object to be stored.
.. versionadded:: 1.1
"""
raise NotImplementedError()
def set_serialized_multi(
self, mapping: Mapping[KeyType, bytes]
) -> None: # pragma NO COVERAGE
"""Set multiple serialized values in the cache.
:param mapping: a dict in which the key will be whatever was passed to
the :meth:`.CacheRegion.set_multi` method, processed by the "key
mangling" function, if any.
When implementing a new :class:`.CacheBackend` or cutomizing via
:class:`.ProxyBackend`, be aware that when this method is invoked by
:meth:`.Region.get_or_create_multi`, the ``mapping`` values are the
same ones returned to the upstream caller. If the subclass alters the
values in any way, it must not do so 'in-place' on the ``mapping`` dict
-- that will have the undesirable effect of modifying the returned
values as well.
.. versionadded:: 1.1
"""
raise NotImplementedError()

View file

@ -1,22 +1,47 @@
from dogpile.cache.region import register_backend
from ...util import PluginLoader
_backend_loader = PluginLoader("dogpile.cache")
register_backend = _backend_loader.register
register_backend(
"dogpile.cache.null", "dogpile.cache.backends.null", "NullBackend")
"dogpile.cache.null", "dogpile.cache.backends.null", "NullBackend"
)
register_backend(
"dogpile.cache.dbm", "dogpile.cache.backends.file", "DBMBackend")
"dogpile.cache.dbm", "dogpile.cache.backends.file", "DBMBackend"
)
register_backend(
"dogpile.cache.pylibmc", "dogpile.cache.backends.memcached",
"PylibmcBackend")
"dogpile.cache.pylibmc",
"dogpile.cache.backends.memcached",
"PylibmcBackend",
)
register_backend(
"dogpile.cache.bmemcached", "dogpile.cache.backends.memcached",
"BMemcachedBackend")
"dogpile.cache.bmemcached",
"dogpile.cache.backends.memcached",
"BMemcachedBackend",
)
register_backend(
"dogpile.cache.memcached", "dogpile.cache.backends.memcached",
"MemcachedBackend")
"dogpile.cache.memcached",
"dogpile.cache.backends.memcached",
"MemcachedBackend",
)
register_backend(
"dogpile.cache.memory", "dogpile.cache.backends.memory", "MemoryBackend")
"dogpile.cache.pymemcache",
"dogpile.cache.backends.memcached",
"PyMemcacheBackend",
)
register_backend(
"dogpile.cache.memory_pickle", "dogpile.cache.backends.memory",
"MemoryPickleBackend")
"dogpile.cache.memory", "dogpile.cache.backends.memory", "MemoryBackend"
)
register_backend(
"dogpile.cache.redis", "dogpile.cache.backends.redis", "RedisBackend")
"dogpile.cache.memory_pickle",
"dogpile.cache.backends.memory",
"MemoryPickleBackend",
)
register_backend(
"dogpile.cache.redis", "dogpile.cache.backends.redis", "RedisBackend"
)
register_backend(
"dogpile.cache.redis_sentinel",
"dogpile.cache.backends.redis",
"RedisSentinelBackend",
)

View file

@ -7,16 +7,20 @@ Provides backends that deal with local filesystem access.
"""
from __future__ import with_statement
from ..api import CacheBackend, NO_VALUE
from contextlib import contextmanager
from ...util import compat
from ... import util
import dbm
import os
import threading
__all__ = 'DBMBackend', 'FileLock', 'AbstractFileLock'
from ..api import BytesBackend
from ..api import NO_VALUE
from ... import util
__all__ = ["DBMBackend", "FileLock", "AbstractFileLock"]
class DBMBackend(CacheBackend):
class DBMBackend(BytesBackend):
"""A file-backend using a dbm file to store keys.
Basic usage::
@ -134,28 +138,25 @@ class DBMBackend(CacheBackend):
"""
def __init__(self, arguments):
self.filename = os.path.abspath(
os.path.normpath(arguments['filename'])
os.path.normpath(arguments["filename"])
)
dir_, filename = os.path.split(self.filename)
self.lock_factory = arguments.get("lock_factory", FileLock)
self._rw_lock = self._init_lock(
arguments.get('rw_lockfile'),
".rw.lock", dir_, filename)
arguments.get("rw_lockfile"), ".rw.lock", dir_, filename
)
self._dogpile_lock = self._init_lock(
arguments.get('dogpile_lockfile'),
arguments.get("dogpile_lockfile"),
".dogpile.lock",
dir_, filename,
util.KeyReentrantMutex.factory)
dir_,
filename,
util.KeyReentrantMutex.factory,
)
# TODO: make this configurable
if compat.py3k:
import dbm
else:
import anydbm as dbm
self.dbmmodule = dbm
self._init_dbm_file()
def _init_lock(self, argument, suffix, basedir, basefile, wrapper=None):
@ -163,9 +164,8 @@ class DBMBackend(CacheBackend):
lock = self.lock_factory(os.path.join(basedir, basefile + suffix))
elif argument is not False:
lock = self.lock_factory(
os.path.abspath(
os.path.normpath(argument)
))
os.path.abspath(os.path.normpath(argument))
)
else:
return None
if wrapper:
@ -175,12 +175,12 @@ class DBMBackend(CacheBackend):
def _init_dbm_file(self):
exists = os.access(self.filename, os.F_OK)
if not exists:
for ext in ('db', 'dat', 'pag', 'dir'):
for ext in ("db", "dat", "pag", "dir"):
if os.access(self.filename + os.extsep + ext, os.F_OK):
exists = True
break
if not exists:
fh = self.dbmmodule.open(self.filename, 'c')
fh = dbm.open(self.filename, "c")
fh.close()
def get_mutex(self, key):
@ -210,57 +210,50 @@ class DBMBackend(CacheBackend):
@contextmanager
def _dbm_file(self, write):
with self._use_rw_lock(write):
dbm = self.dbmmodule.open(
self.filename,
"w" if write else "r")
yield dbm
dbm.close()
with dbm.open(self.filename, "w" if write else "r") as dbm_obj:
yield dbm_obj
def get(self, key):
with self._dbm_file(False) as dbm:
if hasattr(dbm, 'get'):
value = dbm.get(key, NO_VALUE)
def get_serialized(self, key):
with self._dbm_file(False) as dbm_obj:
if hasattr(dbm_obj, "get"):
value = dbm_obj.get(key, NO_VALUE)
else:
# gdbm objects lack a .get method
try:
value = dbm[key]
value = dbm_obj[key]
except KeyError:
value = NO_VALUE
if value is not NO_VALUE:
value = compat.pickle.loads(value)
return value
def get_multi(self, keys):
return [self.get(key) for key in keys]
def get_serialized_multi(self, keys):
return [self.get_serialized(key) for key in keys]
def set(self, key, value):
with self._dbm_file(True) as dbm:
dbm[key] = compat.pickle.dumps(value,
compat.pickle.HIGHEST_PROTOCOL)
def set_serialized(self, key, value):
with self._dbm_file(True) as dbm_obj:
dbm_obj[key] = value
def set_multi(self, mapping):
with self._dbm_file(True) as dbm:
def set_serialized_multi(self, mapping):
with self._dbm_file(True) as dbm_obj:
for key, value in mapping.items():
dbm[key] = compat.pickle.dumps(value,
compat.pickle.HIGHEST_PROTOCOL)
dbm_obj[key] = value
def delete(self, key):
with self._dbm_file(True) as dbm:
with self._dbm_file(True) as dbm_obj:
try:
del dbm[key]
del dbm_obj[key]
except KeyError:
pass
def delete_multi(self, keys):
with self._dbm_file(True) as dbm:
with self._dbm_file(True) as dbm_obj:
for key in keys:
try:
del dbm[key]
del dbm_obj[key]
except KeyError:
pass
class AbstractFileLock(object):
class AbstractFileLock:
"""Coordinate read/write access to a file.
typically is a file-based lock but doesn't necessarily have to be.
@ -392,17 +385,18 @@ class FileLock(AbstractFileLock):
"""
def __init__(self, filename):
self._filedescriptor = compat.threading.local()
self._filedescriptor = threading.local()
self.filename = filename
@util.memoized_property
def _module(self):
import fcntl
return fcntl
@property
def is_open(self):
return hasattr(self._filedescriptor, 'fileno')
return hasattr(self._filedescriptor, "fileno")
def acquire_read_lock(self, wait):
return self._acquire(wait, os.O_RDONLY, self._module.LOCK_SH)

View file

@ -6,23 +6,43 @@ Provides backends for talking to `memcached <http://memcached.org>`_.
"""
from ..api import CacheBackend, NO_VALUE
from ...util import compat
from ... import util
import random
import threading
import time
import typing
from typing import Any
from typing import Mapping
import warnings
__all__ = 'GenericMemcachedBackend', 'MemcachedBackend',\
'PylibmcBackend', 'BMemcachedBackend', 'MemcachedLock'
from ..api import CacheBackend
from ..api import NO_VALUE
from ... import util
if typing.TYPE_CHECKING:
import bmemcached
import memcache
import pylibmc
import pymemcache
else:
# delayed import
bmemcached = None # noqa F811
memcache = None # noqa F811
pylibmc = None # noqa F811
pymemcache = None # noqa F811
__all__ = (
"GenericMemcachedBackend",
"MemcachedBackend",
"PylibmcBackend",
"PyMemcacheBackend",
"BMemcachedBackend",
"MemcachedLock",
)
class MemcachedLock(object):
"""Simple distributed lock using memcached.
This is an adaptation of the lock featured at
http://amix.dk/blog/post/19386
"""
"""Simple distributed lock using memcached."""
def __init__(self, client_fn, key, timeout=0):
self.client_fn = client_fn
@ -38,11 +58,15 @@ class MemcachedLock(object):
elif not wait:
return False
else:
sleep_time = (((i + 1) * random.random()) + 2 ** i) / 2.5
sleep_time = (((i + 1) * random.random()) + 2**i) / 2.5
time.sleep(sleep_time)
if i < 15:
i += 1
def locked(self):
client = self.client_fn()
return client.get(self.key) is not None
def release(self):
client = self.client_fn()
client.delete(self.key)
@ -100,10 +124,17 @@ class GenericMemcachedBackend(CacheBackend):
"""
set_arguments = {}
set_arguments: Mapping[str, Any] = {}
"""Additional arguments which will be passed
to the :meth:`set` method."""
# No need to override serializer, as all the memcached libraries
# handles that themselves. Still, we support customizing the
# serializer/deserializer to use better default pickle protocol
# or completely different serialization mechanism
serializer = None
deserializer = None
def __init__(self, arguments):
self._imports()
# using a plain threading.local here. threading.local
@ -111,11 +142,10 @@ class GenericMemcachedBackend(CacheBackend):
# so the idea is that this is superior to pylibmc's
# own ThreadMappedPool which doesn't handle this
# automatically.
self.url = util.to_list(arguments['url'])
self.distributed_lock = arguments.get('distributed_lock', False)
self.lock_timeout = arguments.get('lock_timeout', 0)
self.memcached_expire_time = arguments.get(
'memcached_expire_time', 0)
self.url = util.to_list(arguments["url"])
self.distributed_lock = arguments.get("distributed_lock", False)
self.lock_timeout = arguments.get("lock_timeout", 0)
self.memcached_expire_time = arguments.get("memcached_expire_time", 0)
def has_lock_timeout(self):
return self.lock_timeout != 0
@ -132,7 +162,7 @@ class GenericMemcachedBackend(CacheBackend):
def _clients(self):
backend = self
class ClientPool(compat.threading.local):
class ClientPool(threading.local):
def __init__(self):
self.memcached = backend._create_client()
@ -152,8 +182,9 @@ class GenericMemcachedBackend(CacheBackend):
def get_mutex(self, key):
if self.distributed_lock:
return MemcachedLock(lambda: self.client, key,
timeout=self.lock_timeout)
return MemcachedLock(
lambda: self.client, key, timeout=self.lock_timeout
)
else:
return None
@ -166,23 +197,18 @@ class GenericMemcachedBackend(CacheBackend):
def get_multi(self, keys):
values = self.client.get_multi(keys)
return [
NO_VALUE if key not in values
else values[key] for key in keys
NO_VALUE if val is None else val
for val in [values.get(key, NO_VALUE) for key in keys]
]
def set(self, key, value):
self.client.set(
key,
value,
**self.set_arguments
)
self.client.set(key, value, **self.set_arguments)
def set_multi(self, mapping):
self.client.set_multi(
mapping,
**self.set_arguments
)
mapping = {key: value for key, value in mapping.items()}
self.client.set_multi(mapping, **self.set_arguments)
def delete(self, key):
self.client.delete(key)
@ -191,24 +217,23 @@ class GenericMemcachedBackend(CacheBackend):
self.client.delete_multi(keys)
class MemcacheArgs(object):
class MemcacheArgs(GenericMemcachedBackend):
"""Mixin which provides support for the 'time' argument to set(),
'min_compress_len' to other methods.
"""
def __init__(self, arguments):
self.min_compress_len = arguments.get('min_compress_len', 0)
self.min_compress_len = arguments.get("min_compress_len", 0)
self.set_arguments = {}
if "memcached_expire_time" in arguments:
self.set_arguments["time"] = arguments["memcached_expire_time"]
if "min_compress_len" in arguments:
self.set_arguments["min_compress_len"] = \
arguments["min_compress_len"]
self.set_arguments["min_compress_len"] = arguments[
"min_compress_len"
]
super(MemcacheArgs, self).__init__(arguments)
pylibmc = None
class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend):
"""A backend for the
@ -245,8 +270,8 @@ class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend):
"""
def __init__(self, arguments):
self.binary = arguments.get('binary', False)
self.behaviors = arguments.get('behaviors', {})
self.binary = arguments.get("binary", False)
self.behaviors = arguments.get("behaviors", {})
super(PylibmcBackend, self).__init__(arguments)
def _imports(self):
@ -255,13 +280,9 @@ class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend):
def _create_client(self):
return pylibmc.Client(
self.url,
binary=self.binary,
behaviors=self.behaviors
self.url, binary=self.binary, behaviors=self.behaviors
)
memcache = None
class MemcachedBackend(MemcacheArgs, GenericMemcachedBackend):
"""A backend using the standard
@ -281,16 +302,39 @@ class MemcachedBackend(MemcacheArgs, GenericMemcachedBackend):
}
)
:param dead_retry: Number of seconds memcached server is considered dead
before it is tried again. Will be passed to ``memcache.Client``
as the ``dead_retry`` parameter.
.. versionchanged:: 1.1.8 Moved the ``dead_retry`` argument which was
erroneously added to "set_parameters" to
be part of the Memcached connection arguments.
:param socket_timeout: Timeout in seconds for every call to a server.
Will be passed to ``memcache.Client`` as the ``socket_timeout``
parameter.
.. versionchanged:: 1.1.8 Moved the ``socket_timeout`` argument which
was erroneously added to "set_parameters"
to be part of the Memcached connection arguments.
"""
def __init__(self, arguments):
self.dead_retry = arguments.get("dead_retry", 30)
self.socket_timeout = arguments.get("socket_timeout", 3)
super(MemcachedBackend, self).__init__(arguments)
def _imports(self):
global memcache
import memcache # noqa
def _create_client(self):
return memcache.Client(self.url)
bmemcached = None
return memcache.Client(
self.url,
dead_retry=self.dead_retry,
socket_timeout=self.socket_timeout,
)
class BMemcachedBackend(GenericMemcachedBackend):
@ -299,9 +343,11 @@ class BMemcachedBackend(GenericMemcachedBackend):
python-binary-memcached>`_
memcached client.
This is a pure Python memcached client which
includes the ability to authenticate with a memcached
server using SASL.
This is a pure Python memcached client which includes
security features like SASL and SSL/TLS.
SASL is a standard for adding authentication mechanisms
to protocols in a way that is protocol independent.
A typical configuration using username/password::
@ -317,6 +363,25 @@ class BMemcachedBackend(GenericMemcachedBackend):
}
)
A typical configuration using tls_context::
import ssl
from dogpile.cache import make_region
ctx = ssl.create_default_context(cafile="/path/to/my-ca.pem")
region = make_region().configure(
'dogpile.cache.bmemcached',
expiration_time = 3600,
arguments = {
'url':["127.0.0.1"],
'tls_context':ctx,
}
)
For advanced ways to configure TLS creating a more complex
tls_context visit https://docs.python.org/3/library/ssl.html
Arguments which can be passed to the ``arguments``
dictionary include:
@ -324,11 +389,17 @@ class BMemcachedBackend(GenericMemcachedBackend):
SASL authentication.
:param password: optional password, will be used for
SASL authentication.
:param tls_context: optional TLS context, will be used for
TLS connections.
.. versionadded:: 1.0.2
"""
def __init__(self, arguments):
self.username = arguments.get('username', None)
self.password = arguments.get('password', None)
self.username = arguments.get("username", None)
self.password = arguments.get("password", None)
self.tls_context = arguments.get("tls_context", None)
super(BMemcachedBackend, self).__init__(arguments)
def _imports(self):
@ -345,7 +416,8 @@ class BMemcachedBackend(GenericMemcachedBackend):
def add(self, key, value, timeout=0):
try:
return super(RepairBMemcachedAPI, self).add(
key, value, timeout)
key, value, timeout
)
except ValueError:
return False
@ -355,10 +427,213 @@ class BMemcachedBackend(GenericMemcachedBackend):
return self.Client(
self.url,
username=self.username,
password=self.password
password=self.password,
tls_context=self.tls_context,
)
def delete_multi(self, keys):
"""python-binary-memcached api does not implements delete_multi"""
for key in keys:
self.delete(key)
class PyMemcacheBackend(GenericMemcachedBackend):
"""A backend for the
`pymemcache <https://github.com/pinterest/pymemcache>`_
memcached client.
A comprehensive, fast, pure Python memcached client
.. versionadded:: 1.1.2
pymemcache supports the following features:
* Complete implementation of the memcached text protocol.
* Configurable timeouts for socket connect and send/recv calls.
* Access to the "noreply" flag, which can significantly increase
the speed of writes.
* Flexible, simple approach to serialization and deserialization.
* The (optional) ability to treat network and memcached errors as
cache misses.
dogpile.cache uses the ``HashClient`` from pymemcache in order to reduce
API differences when compared to other memcached client drivers.
This allows the user to provide a single server or a list of memcached
servers.
Arguments which can be passed to the ``arguments``
dictionary include:
:param tls_context: optional TLS context, will be used for
TLS connections.
A typical configuration using tls_context::
import ssl
from dogpile.cache import make_region
ctx = ssl.create_default_context(cafile="/path/to/my-ca.pem")
region = make_region().configure(
'dogpile.cache.pymemcache',
expiration_time = 3600,
arguments = {
'url':["127.0.0.1"],
'tls_context':ctx,
}
)
.. seealso::
`<https://docs.python.org/3/library/ssl.html>`_ - additional TLS
documentation.
:param serde: optional "serde". Defaults to
``pymemcache.serde.pickle_serde``.
:param default_noreply: defaults to False. When set to True this flag
enables the pymemcache "noreply" feature. See the pymemcache
documentation for further details.
:param socket_keepalive: optional socket keepalive, will be used for
TCP keepalive configuration. Use of this parameter requires pymemcache
3.5.0 or greater. This parameter
accepts a
`pymemcache.client.base.KeepAliveOpts
<https://pymemcache.readthedocs.io/en/latest/apidoc/pymemcache.client.base.html#pymemcache.client.base.KeepaliveOpts>`_
object.
A typical configuration using ``socket_keepalive``::
from pymemcache import KeepaliveOpts
from dogpile.cache import make_region
# Using the default keepalive configuration
socket_keepalive = KeepaliveOpts()
region = make_region().configure(
'dogpile.cache.pymemcache',
expiration_time = 3600,
arguments = {
'url':["127.0.0.1"],
'socket_keepalive': socket_keepalive
}
)
.. versionadded:: 1.1.4 - added support for ``socket_keepalive``.
:param enable_retry_client: optional flag to enable retry client
mechanisms to handle failure. Defaults to False. When set to ``True``,
the :paramref:`.PyMemcacheBackend.retry_attempts` parameter must also
be set, along with optional parameters
:paramref:`.PyMemcacheBackend.retry_delay`.
:paramref:`.PyMemcacheBackend.retry_for`,
:paramref:`.PyMemcacheBackend.do_not_retry_for`.
.. seealso::
`<https://pymemcache.readthedocs.io/en/latest/getting_started.html#using-the-built-in-retrying-mechanism>`_ -
in the pymemcache documentation
.. versionadded:: 1.1.4
:param retry_attempts: how many times to attempt an action with
pymemcache's retrying wrapper before failing. Must be 1 or above.
Defaults to None.
.. versionadded:: 1.1.4
:param retry_delay: optional int|float, how many seconds to sleep between
each attempt. Used by the retry wrapper. Defaults to None.
.. versionadded:: 1.1.4
:param retry_for: optional None|tuple|set|list, what exceptions to
allow retries for. Will allow retries for all exceptions if None.
Example: ``(MemcacheClientError, MemcacheUnexpectedCloseError)``
Accepts any class that is a subclass of Exception. Defaults to None.
.. versionadded:: 1.1.4
:param do_not_retry_for: optional None|tuple|set|list, what
exceptions should be retried. Will not block retries for any Exception if
None. Example: ``(IOError, MemcacheIllegalInputError)``
Accepts any class that is a subclass of Exception. Defaults to None.
.. versionadded:: 1.1.4
:param hashclient_retry_attempts: Amount of times a client should be tried
before it is marked dead and removed from the pool in the HashClient's
internal mechanisms.
.. versionadded:: 1.1.5
:param hashclient_retry_timeout: Time in seconds that should pass between
retry attempts in the HashClient's internal mechanisms.
.. versionadded:: 1.1.5
:param dead_timeout: Time in seconds before attempting to add a node
back in the pool in the HashClient's internal mechanisms.
.. versionadded:: 1.1.5
""" # noqa E501
def __init__(self, arguments):
super().__init__(arguments)
self.serde = arguments.get("serde", pymemcache.serde.pickle_serde)
self.default_noreply = arguments.get("default_noreply", False)
self.tls_context = arguments.get("tls_context", None)
self.socket_keepalive = arguments.get("socket_keepalive", None)
self.enable_retry_client = arguments.get("enable_retry_client", False)
self.retry_attempts = arguments.get("retry_attempts", None)
self.retry_delay = arguments.get("retry_delay", None)
self.retry_for = arguments.get("retry_for", None)
self.do_not_retry_for = arguments.get("do_not_retry_for", None)
self.hashclient_retry_attempts = arguments.get(
"hashclient_retry_attempts", 2
)
self.hashclient_retry_timeout = arguments.get(
"hashclient_retry_timeout", 1
)
self.dead_timeout = arguments.get("hashclient_dead_timeout", 60)
if (
self.retry_delay is not None
or self.retry_attempts is not None
or self.retry_for is not None
or self.do_not_retry_for is not None
) and not self.enable_retry_client:
warnings.warn(
"enable_retry_client is not set; retry options "
"will be ignored"
)
def _imports(self):
global pymemcache
import pymemcache
def _create_client(self):
_kwargs = {
"serde": self.serde,
"default_noreply": self.default_noreply,
"tls_context": self.tls_context,
"retry_attempts": self.hashclient_retry_attempts,
"retry_timeout": self.hashclient_retry_timeout,
"dead_timeout": self.dead_timeout,
}
if self.socket_keepalive is not None:
_kwargs.update({"socket_keepalive": self.socket_keepalive})
client = pymemcache.client.hash.HashClient(self.url, **_kwargs)
if self.enable_retry_client:
return pymemcache.client.retrying.RetryingClient(
client,
attempts=self.retry_attempts,
retry_delay=self.retry_delay,
retry_for=self.retry_for,
do_not_retry_for=self.do_not_retry_for,
)
return client

View file

@ -10,8 +10,10 @@ places the value as given into the dictionary.
"""
from ..api import CacheBackend, NO_VALUE
from ...util.compat import pickle
from ..api import CacheBackend
from ..api import DefaultSerialization
from ..api import NO_VALUE
class MemoryBackend(CacheBackend):
@ -47,39 +49,21 @@ class MemoryBackend(CacheBackend):
"""
pickle_values = False
def __init__(self, arguments):
self._cache = arguments.pop("cache_dict", {})
def get(self, key):
value = self._cache.get(key, NO_VALUE)
if value is not NO_VALUE and self.pickle_values:
value = pickle.loads(value)
return value
return self._cache.get(key, NO_VALUE)
def get_multi(self, keys):
ret = [
self._cache.get(key, NO_VALUE)
for key in keys]
if self.pickle_values:
ret = [
pickle.loads(value)
if value is not NO_VALUE else value
for value in ret
]
return ret
return [self._cache.get(key, NO_VALUE) for key in keys]
def set(self, key, value):
if self.pickle_values:
value = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
self._cache[key] = value
def set_multi(self, mapping):
pickle_values = self.pickle_values
for key, value in mapping.items():
if pickle_values:
value = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
self._cache[key] = value
def delete(self, key):
@ -90,7 +74,7 @@ class MemoryBackend(CacheBackend):
self._cache.pop(key, None)
class MemoryPickleBackend(MemoryBackend):
class MemoryPickleBackend(DefaultSerialization, MemoryBackend):
"""A backend that uses a plain dictionary, but serializes objects on
:meth:`.MemoryBackend.set` and deserializes :meth:`.MemoryBackend.get`.
@ -121,4 +105,3 @@ class MemoryPickleBackend(MemoryBackend):
.. versionadded:: 0.5.3
"""
pickle_values = True

View file

@ -10,10 +10,11 @@ caching for a region that is otherwise used normally.
"""
from ..api import CacheBackend, NO_VALUE
from ..api import CacheBackend
from ..api import NO_VALUE
__all__ = ['NullBackend']
__all__ = ["NullBackend"]
class NullLock(object):
@ -23,6 +24,9 @@ class NullLock(object):
def release(self):
pass
def locked(self):
return False
class NullBackend(CacheBackend):
"""A "null" backend that effectively disables all cache operations.

View file

@ -7,16 +7,24 @@ Provides backends for talking to `Redis <http://redis.io>`_.
"""
from __future__ import absolute_import
from ..api import CacheBackend, NO_VALUE
from ...util.compat import pickle, u
redis = None
import typing
import warnings
__all__ = 'RedisBackend',
from ..api import BytesBackend
from ..api import NO_VALUE
if typing.TYPE_CHECKING:
import redis
else:
# delayed import
redis = None # noqa F811
__all__ = ("RedisBackend", "RedisSentinelBackend")
class RedisBackend(CacheBackend):
"""A `Redis <http://redis.io/>`_ backend, using the
class RedisBackend(BytesBackend):
r"""A `Redis <http://redis.io/>`_ backend, using the
`redis-py <http://pypi.python.org/pypi/redis/>`_ backend.
Example configuration::
@ -30,23 +38,21 @@ class RedisBackend(CacheBackend):
'port': 6379,
'db': 0,
'redis_expiration_time': 60*60*2, # 2 hours
'distributed_lock': True
'distributed_lock': True,
'thread_local_lock': False
}
)
Arguments accepted in the arguments dictionary:
:param url: string. If provided, will override separate host/port/db
params. The format is that accepted by ``StrictRedis.from_url()``.
.. versionadded:: 0.4.1
:param host: string, default is ``localhost``.
:param password: string, default is no password.
.. versionadded:: 0.4.1
:param port: integer, default is ``6379``.
:param db: integer, default is ``0``.
@ -56,57 +62,66 @@ class RedisBackend(CacheBackend):
cache expiration. By default no expiration is set.
:param distributed_lock: boolean, when True, will use a
redis-lock as the dogpile lock.
Use this when multiple
processes will be talking to the same redis instance.
When left at False, dogpile will coordinate on a regular
threading mutex.
redis-lock as the dogpile lock. Use this when multiple processes will be
talking to the same redis instance. When left at False, dogpile will
coordinate on a regular threading mutex.
:param lock_timeout: integer, number of seconds after acquiring a lock that
Redis should expire it. This argument is only valid when
``distributed_lock`` is ``True``.
.. versionadded:: 0.5.0
:param socket_timeout: float, seconds for socket timeout.
Default is None (no timeout).
.. versionadded:: 0.5.4
:param lock_sleep: integer, number of seconds to sleep when failed to
acquire a lock. This argument is only valid when
``distributed_lock`` is ``True``.
.. versionadded:: 0.5.0
:param connection_pool: ``redis.ConnectionPool`` object. If provided,
this object supersedes other connection arguments passed to the
``redis.StrictRedis`` instance, including url and/or host as well as
socket_timeout, and will be passed to ``redis.StrictRedis`` as the
source of connectivity.
.. versionadded:: 0.5.4
:param thread_local_lock: bool, whether a thread-local Redis lock object
should be used. This is the default, but is not compatible with
asynchronous runners, as they run in a different thread than the one
used to create the lock.
:param connection_kwargs: dict, additional keyword arguments are passed
along to the
``StrictRedis.from_url()`` method or ``StrictRedis()`` constructor
directly, including parameters like ``ssl``, ``ssl_certfile``,
``charset``, etc.
.. versionadded:: 1.1.6 Added ``connection_kwargs`` parameter.
"""
def __init__(self, arguments):
arguments = arguments.copy()
self._imports()
self.url = arguments.pop('url', None)
self.host = arguments.pop('host', 'localhost')
self.password = arguments.pop('password', None)
self.port = arguments.pop('port', 6379)
self.db = arguments.pop('db', 0)
self.distributed_lock = arguments.get('distributed_lock', False)
self.socket_timeout = arguments.pop('socket_timeout', None)
self.url = arguments.pop("url", None)
self.host = arguments.pop("host", "localhost")
self.password = arguments.pop("password", None)
self.port = arguments.pop("port", 6379)
self.db = arguments.pop("db", 0)
self.distributed_lock = arguments.pop("distributed_lock", False)
self.socket_timeout = arguments.pop("socket_timeout", None)
self.lock_timeout = arguments.pop("lock_timeout", None)
self.lock_sleep = arguments.pop("lock_sleep", 0.1)
self.thread_local_lock = arguments.pop("thread_local_lock", True)
self.connection_kwargs = arguments.pop("connection_kwargs", {})
self.lock_timeout = arguments.get('lock_timeout', None)
self.lock_sleep = arguments.get('lock_sleep', 0.1)
if self.distributed_lock and self.thread_local_lock:
warnings.warn(
"The Redis backend thread_local_lock parameter should be "
"set to False when distributed_lock is True"
)
self.redis_expiration_time = arguments.pop('redis_expiration_time', 0)
self.connection_pool = arguments.get('connection_pool', None)
self.client = self._create_client()
self.redis_expiration_time = arguments.pop("redis_expiration_time", 0)
self.connection_pool = arguments.pop("connection_pool", None)
self._create_client()
def _imports(self):
# defer imports until backend is used
@ -118,66 +133,207 @@ class RedisBackend(CacheBackend):
# the connection pool already has all other connection
# options present within, so here we disregard socket_timeout
# and others.
return redis.StrictRedis(connection_pool=self.connection_pool)
args = {}
if self.socket_timeout:
args['socket_timeout'] = self.socket_timeout
if self.url is not None:
args.update(url=self.url)
return redis.StrictRedis.from_url(**args)
else:
args.update(
host=self.host, password=self.password,
port=self.port, db=self.db
self.writer_client = redis.StrictRedis(
connection_pool=self.connection_pool
)
return redis.StrictRedis(**args)
self.reader_client = self.writer_client
else:
args = {}
args.update(self.connection_kwargs)
if self.socket_timeout:
args["socket_timeout"] = self.socket_timeout
if self.url is not None:
args.update(url=self.url)
self.writer_client = redis.StrictRedis.from_url(**args)
self.reader_client = self.writer_client
else:
args.update(
host=self.host,
password=self.password,
port=self.port,
db=self.db,
)
self.writer_client = redis.StrictRedis(**args)
self.reader_client = self.writer_client
def get_mutex(self, key):
if self.distributed_lock:
return self.client.lock(u('_lock{0}').format(key),
self.lock_timeout, self.lock_sleep)
return _RedisLockWrapper(
self.writer_client.lock(
"_lock{0}".format(key),
timeout=self.lock_timeout,
sleep=self.lock_sleep,
thread_local=self.thread_local_lock,
)
)
else:
return None
def get(self, key):
value = self.client.get(key)
def get_serialized(self, key):
value = self.reader_client.get(key)
if value is None:
return NO_VALUE
return pickle.loads(value)
return value
def get_multi(self, keys):
def get_serialized_multi(self, keys):
if not keys:
return []
values = self.client.mget(keys)
return [
pickle.loads(v) if v is not None else NO_VALUE
for v in values]
values = self.reader_client.mget(keys)
return [v if v is not None else NO_VALUE for v in values]
def set(self, key, value):
def set_serialized(self, key, value):
if self.redis_expiration_time:
self.client.setex(key, self.redis_expiration_time,
pickle.dumps(value, pickle.HIGHEST_PROTOCOL))
self.writer_client.setex(key, self.redis_expiration_time, value)
else:
self.client.set(key, pickle.dumps(value, pickle.HIGHEST_PROTOCOL))
def set_multi(self, mapping):
mapping = dict(
(k, pickle.dumps(v, pickle.HIGHEST_PROTOCOL))
for k, v in mapping.items()
)
self.writer_client.set(key, value)
def set_serialized_multi(self, mapping):
if not self.redis_expiration_time:
self.client.mset(mapping)
self.writer_client.mset(mapping)
else:
pipe = self.client.pipeline()
pipe = self.writer_client.pipeline()
for key, value in mapping.items():
pipe.setex(key, self.redis_expiration_time, value)
pipe.execute()
def delete(self, key):
self.client.delete(key)
self.writer_client.delete(key)
def delete_multi(self, keys):
self.client.delete(*keys)
self.writer_client.delete(*keys)
class _RedisLockWrapper:
__slots__ = ("mutex", "__weakref__")
def __init__(self, mutex: typing.Any):
self.mutex = mutex
def acquire(self, wait: bool = True) -> typing.Any:
return self.mutex.acquire(blocking=wait)
def release(self) -> typing.Any:
return self.mutex.release()
def locked(self) -> bool:
return self.mutex.locked() # type: ignore
class RedisSentinelBackend(RedisBackend):
"""A `Redis <http://redis.io/>`_ backend, using the
`redis-py <http://pypi.python.org/pypi/redis/>`_ backend.
It will use the Sentinel of a Redis cluster.
.. versionadded:: 1.0.0
Example configuration::
from dogpile.cache import make_region
region = make_region().configure(
'dogpile.cache.redis_sentinel',
arguments = {
'sentinels': [
['redis_sentinel_1', 26379],
['redis_sentinel_2', 26379]
],
'db': 0,
'redis_expiration_time': 60*60*2, # 2 hours
'distributed_lock': True,
'thread_local_lock': False
}
)
Arguments accepted in the arguments dictionary:
:param db: integer, default is ``0``.
:param redis_expiration_time: integer, number of seconds after setting
a value that Redis should expire it. This should be larger than dogpile's
cache expiration. By default no expiration is set.
:param distributed_lock: boolean, when True, will use a
redis-lock as the dogpile lock. Use this when multiple processes will be
talking to the same redis instance. When False, dogpile will
coordinate on a regular threading mutex, Default is True.
:param lock_timeout: integer, number of seconds after acquiring a lock that
Redis should expire it. This argument is only valid when
``distributed_lock`` is ``True``.
:param socket_timeout: float, seconds for socket timeout.
Default is None (no timeout).
:param sentinels: is a list of sentinel nodes. Each node is represented by
a pair (hostname, port).
Default is None (not in sentinel mode).
:param service_name: str, the service name.
Default is 'mymaster'.
:param sentinel_kwargs: is a dictionary of connection arguments used when
connecting to sentinel instances. Any argument that can be passed to
a normal Redis connection can be specified here.
Default is {}.
:param connection_kwargs: dict, additional keyword arguments are passed
along to the
``StrictRedis.from_url()`` method or ``StrictRedis()`` constructor
directly, including parameters like ``ssl``, ``ssl_certfile``,
``charset``, etc.
:param lock_sleep: integer, number of seconds to sleep when failed to
acquire a lock. This argument is only valid when
``distributed_lock`` is ``True``.
:param thread_local_lock: bool, whether a thread-local Redis lock object
should be used. This is the default, but is not compatible with
asynchronous runners, as they run in a different thread than the one
used to create the lock.
"""
def __init__(self, arguments):
arguments = arguments.copy()
self.sentinels = arguments.pop("sentinels", None)
self.service_name = arguments.pop("service_name", "mymaster")
self.sentinel_kwargs = arguments.pop("sentinel_kwargs", {})
super().__init__(
arguments={
"distributed_lock": True,
"thread_local_lock": False,
**arguments,
}
)
def _imports(self):
# defer imports until backend is used
global redis
import redis.sentinel # noqa
def _create_client(self):
sentinel_kwargs = {}
sentinel_kwargs.update(self.sentinel_kwargs)
sentinel_kwargs.setdefault("password", self.password)
connection_kwargs = {}
connection_kwargs.update(self.connection_kwargs)
connection_kwargs.setdefault("password", self.password)
if self.db is not None:
connection_kwargs.setdefault("db", self.db)
sentinel_kwargs.setdefault("db", self.db)
if self.socket_timeout is not None:
connection_kwargs.setdefault("socket_timeout", self.socket_timeout)
sentinel = redis.sentinel.Sentinel(
self.sentinels,
sentinel_kwargs=sentinel_kwargs,
**connection_kwargs,
)
self.writer_client = sentinel.master_for(self.service_name)
self.reader_client = sentinel.slave_for(self.service_name)

View file

@ -51,20 +51,22 @@ class MakoPlugin(CacheImpl):
def __init__(self, cache):
super(MakoPlugin, self).__init__(cache)
try:
self.regions = self.cache.template.cache_args['regions']
self.regions = self.cache.template.cache_args["regions"]
except KeyError:
raise KeyError(
"'cache_regions' argument is required on the "
"Mako Lookup or Template object for usage "
"with the dogpile.cache plugin.")
"with the dogpile.cache plugin."
)
def _get_region(self, **kw):
try:
region = kw['region']
region = kw["region"]
except KeyError:
raise KeyError(
"'cache_region' argument must be specified with 'cache=True'"
"within templates for usage with the dogpile.cache plugin.")
"within templates for usage with the dogpile.cache plugin."
)
try:
return self.regions[region]
except KeyError:
@ -73,8 +75,8 @@ class MakoPlugin(CacheImpl):
def get_and_replace(self, key, creation_function, **kw):
expiration_time = kw.pop("timeout", None)
return self._get_region(**kw).get_or_create(
key, creation_function,
expiration_time=expiration_time)
key, creation_function, expiration_time=expiration_time
)
def get_or_create(self, key, creation_function, **kw):
return self.get_and_replace(key, creation_function, **kw)

View file

@ -10,7 +10,16 @@ base backend.
"""
from typing import Mapping
from typing import Optional
from typing import Sequence
from .api import BackendFormatted
from .api import BackendSetType
from .api import CacheBackend
from .api import CacheMutex
from .api import KeyType
from .api import SerializedReturnType
class ProxyBackend(CacheBackend):
@ -55,17 +64,17 @@ class ProxyBackend(CacheBackend):
"""
def __init__(self, *args, **kwargs):
self.proxied = None
def __init__(self, *arg, **kw):
pass
def wrap(self, backend):
''' Take a backend as an argument and setup the self.proxied property.
def wrap(self, backend: CacheBackend) -> "ProxyBackend":
"""Take a backend as an argument and setup the self.proxied property.
Return an object that be used as a backend by a :class:`.CacheRegion`
object.
'''
assert(
isinstance(backend, CacheBackend) or
isinstance(backend, ProxyBackend))
"""
assert isinstance(backend, CacheBackend) or isinstance(
backend, ProxyBackend
)
self.proxied = backend
return self
@ -73,23 +82,37 @@ class ProxyBackend(CacheBackend):
# Delegate any functions that are not already overridden to
# the proxies backend
#
def get(self, key):
def get(self, key: KeyType) -> BackendFormatted:
return self.proxied.get(key)
def set(self, key, value):
def set(self, key: KeyType, value: BackendSetType) -> None:
self.proxied.set(key, value)
def delete(self, key):
def delete(self, key: KeyType) -> None:
self.proxied.delete(key)
def get_multi(self, keys):
def get_multi(self, keys: Sequence[KeyType]) -> Sequence[BackendFormatted]:
return self.proxied.get_multi(keys)
def set_multi(self, mapping):
def set_multi(self, mapping: Mapping[KeyType, BackendSetType]) -> None:
self.proxied.set_multi(mapping)
def delete_multi(self, keys):
def delete_multi(self, keys: Sequence[KeyType]) -> None:
self.proxied.delete_multi(keys)
def get_mutex(self, key):
def get_mutex(self, key: KeyType) -> Optional[CacheMutex]:
return self.proxied.get_mutex(key)
def get_serialized(self, key: KeyType) -> SerializedReturnType:
return self.proxied.get_serialized(key)
def get_serialized_multi(
self, keys: Sequence[KeyType]
) -> Sequence[SerializedReturnType]:
return self.proxied.get_serialized_multi(keys)
def set_serialized(self, key: KeyType, value: bytes) -> None:
self.proxied.set_serialized(key, value)
def set_serialized_multi(self, mapping: Mapping[KeyType, bytes]) -> None:
self.proxied.set_serialized_multi(mapping)

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,10 @@
from hashlib import sha1
from ..util import compat
from ..util import langhelpers
def function_key_generator(namespace, fn, to_str=compat.string_type):
def function_key_generator(namespace, fn, to_str=str):
"""Return a function that generates a string
key, based on a given function as well as
arguments to the returned function itself.
@ -23,47 +24,51 @@ def function_key_generator(namespace, fn, to_str=compat.string_type):
"""
if namespace is None:
namespace = '%s:%s' % (fn.__module__, fn.__name__)
namespace = "%s:%s" % (fn.__module__, fn.__name__)
else:
namespace = '%s:%s|%s' % (fn.__module__, fn.__name__, namespace)
namespace = "%s:%s|%s" % (fn.__module__, fn.__name__, namespace)
args = compat.inspect_getargspec(fn)
has_self = args[0] and args[0][0] in ('self', 'cls')
has_self = args[0] and args[0][0] in ("self", "cls")
def generate_key(*args, **kw):
if kw:
raise ValueError(
"dogpile.cache's default key creation "
"function does not accept keyword arguments.")
"function does not accept keyword arguments."
)
if has_self:
args = args[1:]
return namespace + "|" + " ".join(map(to_str, args))
return generate_key
def function_multi_key_generator(namespace, fn, to_str=compat.string_type):
def function_multi_key_generator(namespace, fn, to_str=str):
if namespace is None:
namespace = '%s:%s' % (fn.__module__, fn.__name__)
namespace = "%s:%s" % (fn.__module__, fn.__name__)
else:
namespace = '%s:%s|%s' % (fn.__module__, fn.__name__, namespace)
namespace = "%s:%s|%s" % (fn.__module__, fn.__name__, namespace)
args = compat.inspect_getargspec(fn)
has_self = args[0] and args[0][0] in ('self', 'cls')
has_self = args[0] and args[0][0] in ("self", "cls")
def generate_keys(*args, **kw):
if kw:
raise ValueError(
"dogpile.cache's default key creation "
"function does not accept keyword arguments.")
"function does not accept keyword arguments."
)
if has_self:
args = args[1:]
return [namespace + "|" + key for key in map(to_str, args)]
return generate_keys
def kwarg_function_key_generator(namespace, fn, to_str=compat.string_type):
def kwarg_function_key_generator(namespace, fn, to_str=str):
"""Return a function that generates a string
key, based on a given function as well as
arguments to the returned function itself.
@ -83,9 +88,9 @@ def kwarg_function_key_generator(namespace, fn, to_str=compat.string_type):
"""
if namespace is None:
namespace = '%s:%s' % (fn.__module__, fn.__name__)
namespace = "%s:%s" % (fn.__module__, fn.__name__)
else:
namespace = '%s:%s|%s' % (fn.__module__, fn.__name__, namespace)
namespace = "%s:%s|%s" % (fn.__module__, fn.__name__, namespace)
argspec = compat.inspect_getargspec(fn)
default_list = list(argspec.defaults or [])
@ -94,32 +99,41 @@ def kwarg_function_key_generator(namespace, fn, to_str=compat.string_type):
# enumerate()
default_list.reverse()
# use idx*-1 to create the correct right-lookup index.
args_with_defaults = dict((argspec.args[(idx*-1)], default)
for idx, default in enumerate(default_list, 1))
if argspec.args and argspec.args[0] in ('self', 'cls'):
args_with_defaults = dict(
(argspec.args[(idx * -1)], default)
for idx, default in enumerate(default_list, 1)
)
if argspec.args and argspec.args[0] in ("self", "cls"):
arg_index_start = 1
else:
arg_index_start = 0
def generate_key(*args, **kwargs):
as_kwargs = dict(
[(argspec.args[idx], arg)
for idx, arg in enumerate(args[arg_index_start:],
arg_index_start)])
[
(argspec.args[idx], arg)
for idx, arg in enumerate(
args[arg_index_start:], arg_index_start
)
]
)
as_kwargs.update(kwargs)
for arg, val in args_with_defaults.items():
if arg not in as_kwargs:
as_kwargs[arg] = val
argument_values = [as_kwargs[key]
for key in sorted(as_kwargs.keys())]
return namespace + '|' + " ".join(map(to_str, argument_values))
argument_values = [as_kwargs[key] for key in sorted(as_kwargs.keys())]
return namespace + "|" + " ".join(map(to_str, argument_values))
return generate_key
def sha1_mangle_key(key):
"""a SHA1 key mangler."""
if isinstance(key, str):
key = key.encode("utf-8")
return sha1(key).hexdigest()
@ -128,13 +142,16 @@ def length_conditional_mangler(length, mangler):
past a certain threshold.
"""
def mangle(key):
if len(key) >= length:
return mangler(key)
else:
return key
return mangle
# in the 0.6 release these functions were moved to the dogpile.util namespace.
# They are linked here to maintain compatibility with older versions.
@ -143,3 +160,30 @@ KeyReentrantMutex = langhelpers.KeyReentrantMutex
memoized_property = langhelpers.memoized_property
PluginLoader = langhelpers.PluginLoader
to_list = langhelpers.to_list
class repr_obj:
__slots__ = ("value", "max_chars")
def __init__(self, value, max_chars=300):
self.value = value
self.max_chars = max_chars
def __eq__(self, other):
return other.value == self.value
def __repr__(self):
rep = repr(self.value)
lenrep = len(rep)
if lenrep > self.max_chars:
segment_length = self.max_chars // 2
rep = (
rep[0:segment_length]
+ (
" ... (%d characters truncated) ... "
% (lenrep - self.max_chars)
)
+ rep[-segment_length:]
)
return rep

View file

@ -8,10 +8,10 @@ dogpile.core installation is present.
"""
from .util import nameregistry # noqa
from .util import readwrite_lock # noqa
from .util.readwrite_lock import ReadWriteMutex # noqa
from .util.nameregistry import NameRegistry # noqa
from . import __version__ # noqa
from .lock import Lock # noqa
from .lock import NeedRegenerationException # noqa
from . import __version__ # noqa
from .util import nameregistry # noqa
from .util import readwrite_lock # noqa
from .util.nameregistry import NameRegistry # noqa
from .util.readwrite_lock import ReadWriteMutex # noqa

View file

@ -1,5 +1,5 @@
import time
import logging
import time
log = logging.getLogger(__name__)
@ -11,10 +11,11 @@ class NeedRegenerationException(Exception):
"""
NOT_REGENERATED = object()
class Lock(object):
class Lock:
"""Dogpile lock class.
Provides an interface around an arbitrary mutex
@ -70,8 +71,8 @@ class Lock(object):
value is available."""
return not self._has_value(createdtime) or (
self.expiretime is not None and
time.time() - createdtime > self.expiretime
self.expiretime is not None
and time.time() - createdtime > self.expiretime
)
def _has_value(self, createdtime):
@ -109,7 +110,8 @@ class Lock(object):
raise Exception(
"Generation function should "
"have just been called by a concurrent "
"thread.")
"thread."
)
else:
return value
@ -122,9 +124,7 @@ class Lock(object):
if self._has_value(createdtime):
has_value = True
if not self.mutex.acquire(False):
log.debug(
"creation function in progress "
"elsewhere, returning")
log.debug("creation function in progress elsewhere, returning")
return NOT_REGENERATED
else:
has_value = False
@ -173,8 +173,7 @@ class Lock(object):
# there's no value at all, and we have to create it synchronously
log.debug(
"Calling creation function for %s value",
"not-yet-present" if not has_value else
"previously expired"
"not-yet-present" if not has_value else "previously expired",
)
return self.creator()
finally:
@ -185,5 +184,5 @@ class Lock(object):
def __enter__(self):
return self._enter()
def __exit__(self, type, value, traceback):
def __exit__(self, type_, value, traceback):
pass

View file

@ -1,4 +1,7 @@
from .langhelpers import coerce_string_conf # noqa
from .langhelpers import KeyReentrantMutex # noqa
from .langhelpers import memoized_property # noqa
from .langhelpers import PluginLoader # noqa
from .langhelpers import to_list # noqa
from .nameregistry import NameRegistry # noqa
from .readwrite_lock import ReadWriteMutex # noqa
from .langhelpers import PluginLoader, memoized_property, \
coerce_string_conf, to_list, KeyReentrantMutex # noqa

View file

@ -1,87 +1,72 @@
import sys
py2k = sys.version_info < (3, 0)
py3k = sys.version_info >= (3, 0)
py32 = sys.version_info >= (3, 2)
py27 = sys.version_info >= (2, 7)
jython = sys.platform.startswith('java')
win32 = sys.platform.startswith('win')
try:
import threading
except ImportError:
import dummy_threading as threading # noqa
import collections
import inspect
if py3k: # pragma: no cover
string_types = str,
text_type = str
string_type = str
FullArgSpec = collections.namedtuple(
"FullArgSpec",
[
"args",
"varargs",
"varkw",
"defaults",
"kwonlyargs",
"kwonlydefaults",
"annotations",
],
)
if py32:
callable = callable
else:
def callable(fn):
return hasattr(fn, '__call__')
def u(s):
return s
def ue(s):
return s
import configparser
import io
import _thread as thread
else:
string_types = basestring,
text_type = unicode
string_type = str
def u(s):
return unicode(s, "utf-8")
def ue(s):
return unicode(s, "unicode_escape")
import ConfigParser as configparser # noqa
import StringIO as io # noqa
callable = callable # noqa
import thread # noqa
ArgSpec = collections.namedtuple(
"ArgSpec", ["args", "varargs", "keywords", "defaults"]
)
if py3k:
import collections
ArgSpec = collections.namedtuple(
"ArgSpec",
["args", "varargs", "keywords", "defaults"])
def inspect_getfullargspec(func):
"""Fully vendored version of getfullargspec from Python 3.3.
from inspect import getfullargspec as inspect_getfullargspec
This version is more performant than the one which appeared in
later Python 3 versions.
def inspect_getargspec(func):
return ArgSpec(
*inspect_getfullargspec(func)[0:4]
)
else:
from inspect import getargspec as inspect_getargspec # noqa
"""
if py3k or jython:
import pickle
else:
import cPickle as pickle # noqa
# if a Signature is already present, as is the case with newer
# "decorator" package, defer back to built in
if hasattr(func, "__signature__"):
return inspect.getfullargspec(func)
if py3k:
def read_config_file(config, fileobj):
return config.read_file(fileobj)
else:
def read_config_file(config, fileobj):
return config.readfp(fileobj)
if inspect.ismethod(func):
func = func.__func__
if not inspect.isfunction(func):
raise TypeError("{!r} is not a Python function".format(func))
co = func.__code__
if not inspect.iscode(co):
raise TypeError("{!r} is not a code object".format(co))
nargs = co.co_argcount
names = co.co_varnames
nkwargs = co.co_kwonlyargcount
args = list(names[:nargs])
kwonlyargs = list(names[nargs : nargs + nkwargs])
nargs += nkwargs
varargs = None
if co.co_flags & inspect.CO_VARARGS:
varargs = co.co_varnames[nargs]
nargs = nargs + 1
varkw = None
if co.co_flags & inspect.CO_VARKEYWORDS:
varkw = co.co_varnames[nargs]
return FullArgSpec(
args,
varargs,
varkw,
func.__defaults__,
kwonlyargs,
func.__kwdefaults__,
func.__annotations__,
)
def timedelta_total_seconds(td):
if py27:
return td.total_seconds()
else:
return (td.microseconds + (
td.seconds + td.days * 24 * 3600) * 1e6) / 1e6
def inspect_getargspec(func):
return ArgSpec(*inspect_getfullargspec(func)[0:4])

View file

@ -1,44 +1,54 @@
import re
import abc
import collections
from . import compat
import re
import threading
from typing import MutableMapping
from typing import MutableSet
import stevedore
def coerce_string_conf(d):
result = {}
for k, v in d.items():
if not isinstance(v, compat.string_types):
if not isinstance(v, str):
result[k] = v
continue
v = v.strip()
if re.match(r'^[-+]?\d+$', v):
if re.match(r"^[-+]?\d+$", v):
result[k] = int(v)
elif re.match(r'^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?$', v):
elif re.match(r"^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?$", v):
result[k] = float(v)
elif v.lower() in ('false', 'true'):
result[k] = v.lower() == 'true'
elif v == 'None':
elif v.lower() in ("false", "true"):
result[k] = v.lower() == "true"
elif v == "None":
result[k] = None
else:
result[k] = v
return result
class PluginLoader(object):
class PluginLoader:
def __init__(self, group):
self.group = group
self.impls = {}
self.impls = {} # loaded plugins
self._mgr = None # lazily defined stevedore manager
self._unloaded = {} # plugins registered but not loaded
def load(self, name):
if name in self._unloaded:
self.impls[name] = self._unloaded[name]()
return self.impls[name]
if name in self.impls:
return self.impls[name]()
return self.impls[name]
else: # pragma NO COVERAGE
import pkg_resources
for impl in pkg_resources.iter_entry_points(
self.group, name):
self.impls[name] = impl.load
return impl.load()
else:
if self._mgr is None:
self._mgr = stevedore.ExtensionManager(self.group)
try:
self.impls[name] = self._mgr[name].plugin
return self.impls[name]
except KeyError:
raise self.NotFound(
"Can't load plugin %s %s" % (self.group, name)
)
@ -47,14 +57,16 @@ class PluginLoader(object):
def load():
mod = __import__(modulepath, fromlist=[objname])
return getattr(mod, objname)
self.impls[name] = load
self._unloaded[name] = load
class NotFound(Exception):
"""The specified plugin could not be found."""
class memoized_property(object):
class memoized_property:
"""A read-only @property that is only evaluated once."""
def __init__(self, fget, doc=None):
self.fget = fget
self.__doc__ = doc or fget.__doc__
@ -77,9 +89,23 @@ def to_list(x, default=None):
return x
class KeyReentrantMutex(object):
class Mutex(abc.ABC):
@abc.abstractmethod
def acquire(self, wait: bool = True) -> bool:
raise NotImplementedError()
def __init__(self, key, mutex, keys):
@abc.abstractmethod
def release(self) -> None:
raise NotImplementedError()
class KeyReentrantMutex:
def __init__(
self,
key: str,
mutex: Mutex,
keys: MutableMapping[int, MutableSet[str]],
):
self.key = key
self.mutex = mutex
self.keys = keys
@ -89,17 +115,19 @@ class KeyReentrantMutex(object):
# this collection holds zero or one
# thread idents as the key; a set of
# keynames held as the value.
keystore = collections.defaultdict(set)
keystore: MutableMapping[
int, MutableSet[str]
] = collections.defaultdict(set)
def fac(key):
return KeyReentrantMutex(key, mutex, keystore)
return fac
def acquire(self, wait=True):
current_thread = compat.threading.current_thread().ident
current_thread = threading.get_ident()
keys = self.keys.get(current_thread)
if keys is not None and \
self.key not in keys:
if keys is not None and self.key not in keys:
# current lockholder, new key. add it in
keys.add(self.key)
return True
@ -111,7 +139,7 @@ class KeyReentrantMutex(object):
return False
def release(self):
current_thread = compat.threading.current_thread().ident
current_thread = threading.get_ident()
keys = self.keys.get(current_thread)
assert keys is not None, "this thread didn't do the acquire"
assert self.key in keys, "No acquire held for key '%s'" % self.key
@ -121,3 +149,10 @@ class KeyReentrantMutex(object):
# the thread ident and unlock.
del self.keys[current_thread]
self.mutex.release()
def locked(self):
current_thread = threading.get_ident()
keys = self.keys.get(current_thread)
if keys is None:
return False
return self.key in keys

View file

@ -1,4 +1,7 @@
from .compat import threading
import threading
from typing import Any
from typing import Callable
from typing import MutableMapping
import weakref
@ -37,19 +40,16 @@ class NameRegistry(object):
method.
"""
_locks = weakref.WeakValueDictionary()
_mutex = threading.RLock()
def __init__(self, creator):
"""Create a new :class:`.NameRegistry`.
"""
self._values = weakref.WeakValueDictionary()
def __init__(self, creator: Callable[..., Any]):
"""Create a new :class:`.NameRegistry`."""
self._values: MutableMapping[str, Any] = weakref.WeakValueDictionary()
self._mutex = threading.RLock()
self.creator = creator
def get(self, identifier, *args, **kw):
def get(self, identifier: str, *args: Any, **kw: Any) -> Any:
r"""Get and possibly create the value.
:param identifier: Hash key for the value.
@ -68,7 +68,7 @@ class NameRegistry(object):
except KeyError:
return self._sync_get(identifier, *args, **kw)
def _sync_get(self, identifier, *args, **kw):
def _sync_get(self, identifier: str, *args: Any, **kw: Any) -> Any:
self._mutex.acquire()
try:
try:
@ -76,11 +76,13 @@ class NameRegistry(object):
return self._values[identifier]
else:
self._values[identifier] = value = self.creator(
identifier, *args, **kw)
identifier, *args, **kw
)
return value
except KeyError:
self._values[identifier] = value = self.creator(
identifier, *args, **kw)
identifier, *args, **kw
)
return value
finally:
self._mutex.release()

View file

@ -1,6 +1,6 @@
from .compat import threading
import logging
import threading
log = logging.getLogger(__name__)
@ -62,13 +62,15 @@ class ReadWriteMutex(object):
# check if we are the last asynchronous reader thread
# out the door.
if self.async_ == 0:
# yes. so if a sync operation is waiting, notifyAll to wake
# yes. so if a sync operation is waiting, notify_all to wake
# it up
if self.current_sync_operation is not None:
self.condition.notifyAll()
self.condition.notify_all()
elif self.async_ < 0:
raise LockError("Synchronizer error - too many "
"release_read_locks called")
raise LockError(
"Synchronizer error - too many "
"release_read_locks called"
)
log.debug("%s released read lock", self)
finally:
self.condition.release()
@ -93,7 +95,7 @@ class ReadWriteMutex(object):
# establish ourselves as the current sync
# this indicates to other read/write operations
# that they should wait until this is None again
self.current_sync_operation = threading.currentThread()
self.current_sync_operation = threading.current_thread()
# now wait again for asyncs to finish
if self.async_ > 0:
@ -115,16 +117,18 @@ class ReadWriteMutex(object):
"""Release the 'write' lock."""
self.condition.acquire()
try:
if self.current_sync_operation is not threading.currentThread():
raise LockError("Synchronizer error - current thread doesn't "
"have the write lock")
if self.current_sync_operation is not threading.current_thread():
raise LockError(
"Synchronizer error - current thread doesn't "
"have the write lock"
)
# reset the current sync operation so
# another can get it
self.current_sync_operation = None
# tell everyone to get ready
self.condition.notifyAll()
self.condition.notify_all()
log.debug("%s released write lock", self)
finally:

View file

@ -0,0 +1,631 @@
import io
import os
import re
import abc
import csv
import sys
import zipp
import email
import pathlib
import operator
import functools
import itertools
import posixpath
import collections
from ._compat import (
NullFinder,
PyPy_repr,
install,
)
from configparser import ConfigParser
from contextlib import suppress
from importlib import import_module
from importlib.abc import MetaPathFinder
from itertools import starmap
__all__ = [
'Distribution',
'DistributionFinder',
'PackageNotFoundError',
'distribution',
'distributions',
'entry_points',
'files',
'metadata',
'requires',
'version',
]
class PackageNotFoundError(ModuleNotFoundError):
"""The package was not found."""
def __str__(self):
tmpl = "No package metadata was found for {self.name}"
return tmpl.format(**locals())
@property
def name(self):
(name,) = self.args
return name
class EntryPoint(
PyPy_repr, collections.namedtuple('EntryPointBase', 'name value group')
):
"""An entry point as defined by Python packaging conventions.
See `the packaging docs on entry points
<https://packaging.python.org/specifications/entry-points/>`_
for more information.
"""
pattern = re.compile(
r'(?P<module>[\w.]+)\s*'
r'(:\s*(?P<attr>[\w.]+))?\s*'
r'(?P<extras>\[.*\])?\s*$'
)
"""
A regular expression describing the syntax for an entry point,
which might look like:
- module
- package.module
- package.module:attribute
- package.module:object.attribute
- package.module:attr [extra1, extra2]
Other combinations are possible as well.
The expression is lenient about whitespace around the ':',
following the attr, and following any extras.
"""
def load(self):
"""Load the entry point from its definition. If only a module
is indicated by the value, return that module. Otherwise,
return the named object.
"""
match = self.pattern.match(self.value)
module = import_module(match.group('module'))
attrs = filter(None, (match.group('attr') or '').split('.'))
return functools.reduce(getattr, attrs, module)
@property
def module(self):
match = self.pattern.match(self.value)
return match.group('module')
@property
def attr(self):
match = self.pattern.match(self.value)
return match.group('attr')
@property
def extras(self):
match = self.pattern.match(self.value)
return list(re.finditer(r'\w+', match.group('extras') or ''))
@classmethod
def _from_config(cls, config):
return [
cls(name, value, group)
for group in config.sections()
for name, value in config.items(group)
]
@classmethod
def _from_text(cls, text):
config = ConfigParser(delimiters='=')
# case sensitive: https://stackoverflow.com/q/1611799/812183
config.optionxform = str
try:
config.read_string(text)
except AttributeError: # pragma: nocover
# Python 2 has no read_string
config.readfp(io.StringIO(text))
return EntryPoint._from_config(config)
def __iter__(self):
"""
Supply iter so one may construct dicts of EntryPoints easily.
"""
return iter((self.name, self))
def __reduce__(self):
return (
self.__class__,
(self.name, self.value, self.group),
)
class PackagePath(pathlib.PurePosixPath):
"""A reference to a path in a package"""
def read_text(self, encoding='utf-8'):
with self.locate().open(encoding=encoding) as stream:
return stream.read()
def read_binary(self):
with self.locate().open('rb') as stream:
return stream.read()
def locate(self):
"""Return a path-like object for this path"""
return self.dist.locate_file(self)
class FileHash:
def __init__(self, spec):
self.mode, _, self.value = spec.partition('=')
def __repr__(self):
return '<FileHash mode: {} value: {}>'.format(self.mode, self.value)
class Distribution:
"""A Python distribution package."""
@abc.abstractmethod
def read_text(self, filename):
"""Attempt to load metadata file given by the name.
:param filename: The name of the file in the distribution info.
:return: The text if found, otherwise None.
"""
@abc.abstractmethod
def locate_file(self, path):
"""
Given a path to a file in this distribution, return a path
to it.
"""
@classmethod
def from_name(cls, name):
"""Return the Distribution for the given package name.
:param name: The name of the distribution package to search for.
:return: The Distribution instance (or subclass thereof) for the named
package, if found.
:raises PackageNotFoundError: When the named package's distribution
metadata cannot be found.
"""
for resolver in cls._discover_resolvers():
dists = resolver(DistributionFinder.Context(name=name))
dist = next(iter(dists), None)
if dist is not None:
return dist
else:
raise PackageNotFoundError(name)
@classmethod
def discover(cls, **kwargs):
"""Return an iterable of Distribution objects for all packages.
Pass a ``context`` or pass keyword arguments for constructing
a context.
:context: A ``DistributionFinder.Context`` object.
:return: Iterable of Distribution objects for all packages.
"""
context = kwargs.pop('context', None)
if context and kwargs:
raise ValueError("cannot accept context and kwargs")
context = context or DistributionFinder.Context(**kwargs)
return itertools.chain.from_iterable(
resolver(context) for resolver in cls._discover_resolvers()
)
@staticmethod
def at(path):
"""Return a Distribution for the indicated metadata path
:param path: a string or path-like object
:return: a concrete Distribution instance for the path
"""
return PathDistribution(pathlib.Path(path))
@staticmethod
def _discover_resolvers():
"""Search the meta_path for resolvers."""
declared = (
getattr(finder, 'find_distributions', None) for finder in sys.meta_path
)
return filter(None, declared)
@classmethod
def _local(cls, root='.'):
from pep517 import build, meta
system = build.compat_system(root)
builder = functools.partial(
meta.build,
source_dir=root,
system=system,
)
return PathDistribution(zipp.Path(meta.build_as_zip(builder)))
@property
def metadata(self):
"""Return the parsed metadata for this Distribution.
The returned object will have keys that name the various bits of
metadata. See PEP 566 for details.
"""
text = (
self.read_text('METADATA')
or self.read_text('PKG-INFO')
# This last clause is here to support old egg-info files. Its
# effect is to just end up using the PathDistribution's self._path
# (which points to the egg-info file) attribute unchanged.
or self.read_text('')
)
return email.message_from_string(text)
@property
def version(self):
"""Return the 'Version' metadata for the distribution package."""
return self.metadata['Version']
@property
def entry_points(self):
return EntryPoint._from_text(self.read_text('entry_points.txt'))
@property
def files(self):
"""Files in this distribution.
:return: List of PackagePath for this distribution or None
Result is `None` if the metadata file that enumerates files
(i.e. RECORD for dist-info or SOURCES.txt for egg-info) is
missing.
Result may be empty if the metadata exists but is empty.
"""
file_lines = self._read_files_distinfo() or self._read_files_egginfo()
def make_file(name, hash=None, size_str=None):
result = PackagePath(name)
result.hash = FileHash(hash) if hash else None
result.size = int(size_str) if size_str else None
result.dist = self
return result
return file_lines and list(starmap(make_file, csv.reader(file_lines)))
def _read_files_distinfo(self):
"""
Read the lines of RECORD
"""
text = self.read_text('RECORD')
return text and text.splitlines()
def _read_files_egginfo(self):
"""
SOURCES.txt might contain literal commas, so wrap each line
in quotes.
"""
text = self.read_text('SOURCES.txt')
return text and map('"{}"'.format, text.splitlines())
@property
def requires(self):
"""Generated requirements specified for this Distribution"""
reqs = self._read_dist_info_reqs() or self._read_egg_info_reqs()
return reqs and list(reqs)
def _read_dist_info_reqs(self):
return self.metadata.get_all('Requires-Dist')
def _read_egg_info_reqs(self):
source = self.read_text('requires.txt')
return source and self._deps_from_requires_text(source)
@classmethod
def _deps_from_requires_text(cls, source):
section_pairs = cls._read_sections(source.splitlines())
sections = {
section: list(map(operator.itemgetter('line'), results))
for section, results in itertools.groupby(
section_pairs, operator.itemgetter('section')
)
}
return cls._convert_egg_info_reqs_to_simple_reqs(sections)
@staticmethod
def _read_sections(lines):
section = None
for line in filter(None, lines):
section_match = re.match(r'\[(.*)\]$', line)
if section_match:
section = section_match.group(1)
continue
yield locals()
@staticmethod
def _convert_egg_info_reqs_to_simple_reqs(sections):
"""
Historically, setuptools would solicit and store 'extra'
requirements, including those with environment markers,
in separate sections. More modern tools expect each
dependency to be defined separately, with any relevant
extras and environment markers attached directly to that
requirement. This method converts the former to the
latter. See _test_deps_from_requires_text for an example.
"""
def make_condition(name):
return name and 'extra == "{name}"'.format(name=name)
def parse_condition(section):
section = section or ''
extra, sep, markers = section.partition(':')
if extra and markers:
markers = '({markers})'.format(markers=markers)
conditions = list(filter(None, [markers, make_condition(extra)]))
return '; ' + ' and '.join(conditions) if conditions else ''
for section, deps in sections.items():
for dep in deps:
yield dep + parse_condition(section)
class DistributionFinder(MetaPathFinder):
"""
A MetaPathFinder capable of discovering installed distributions.
"""
class Context:
"""
Keyword arguments presented by the caller to
``distributions()`` or ``Distribution.discover()``
to narrow the scope of a search for distributions
in all DistributionFinders.
Each DistributionFinder may expect any parameters
and should attempt to honor the canonical
parameters defined below when appropriate.
"""
name = None
"""
Specific name for which a distribution finder should match.
A name of ``None`` matches all distributions.
"""
def __init__(self, **kwargs):
vars(self).update(kwargs)
@property
def path(self):
"""
The path that a distribution finder should search.
Typically refers to Python package paths and defaults
to ``sys.path``.
"""
return vars(self).get('path', sys.path)
@abc.abstractmethod
def find_distributions(self, context=Context()):
"""
Find distributions.
Return an iterable of all Distribution instances capable of
loading the metadata for packages matching the ``context``,
a DistributionFinder.Context instance.
"""
class FastPath:
"""
Micro-optimized class for searching a path for
children.
"""
def __init__(self, root):
self.root = str(root)
self.base = os.path.basename(self.root).lower()
def joinpath(self, child):
return pathlib.Path(self.root, child)
def children(self):
with suppress(Exception):
return os.listdir(self.root or '')
with suppress(Exception):
return self.zip_children()
return []
def zip_children(self):
zip_path = zipp.Path(self.root)
names = zip_path.root.namelist()
self.joinpath = zip_path.joinpath
return dict.fromkeys(child.split(posixpath.sep, 1)[0] for child in names)
def search(self, name):
return (
self.joinpath(child)
for child in self.children()
if name.matches(child, self.base)
)
class Prepared:
"""
A prepared search for metadata on a possibly-named package.
"""
normalized = None
suffixes = '.dist-info', '.egg-info'
exact_matches = [''][:0]
def __init__(self, name):
self.name = name
if name is None:
return
self.normalized = self.normalize(name)
self.exact_matches = [self.normalized + suffix for suffix in self.suffixes]
@staticmethod
def normalize(name):
"""
PEP 503 normalization plus dashes as underscores.
"""
return re.sub(r"[-_.]+", "-", name).lower().replace('-', '_')
@staticmethod
def legacy_normalize(name):
"""
Normalize the package name as found in the convention in
older packaging tools versions and specs.
"""
return name.lower().replace('-', '_')
def matches(self, cand, base):
low = cand.lower()
pre, ext = os.path.splitext(low)
name, sep, rest = pre.partition('-')
return (
low in self.exact_matches
or ext in self.suffixes
and (not self.normalized or name.replace('.', '_') == self.normalized)
# legacy case:
or self.is_egg(base)
and low == 'egg-info'
)
def is_egg(self, base):
normalized = self.legacy_normalize(self.name or '')
prefix = normalized + '-' if normalized else ''
versionless_egg_name = normalized + '.egg' if self.name else ''
return (
base == versionless_egg_name
or base.startswith(prefix)
and base.endswith('.egg')
)
@install
class MetadataPathFinder(NullFinder, DistributionFinder):
"""A degenerate finder for distribution packages on the file system.
This finder supplies only a find_distributions() method for versions
of Python that do not have a PathFinder find_distributions().
"""
def find_distributions(self, context=DistributionFinder.Context()):
"""
Find distributions.
Return an iterable of all Distribution instances capable of
loading the metadata for packages matching ``context.name``
(or all names if ``None`` indicated) along the paths in the list
of directories ``context.path``.
"""
found = self._search_paths(context.name, context.path)
return map(PathDistribution, found)
@classmethod
def _search_paths(cls, name, paths):
"""Find metadata directories in paths heuristically."""
return itertools.chain.from_iterable(
path.search(Prepared(name)) for path in map(FastPath, paths)
)
class PathDistribution(Distribution):
def __init__(self, path):
"""Construct a distribution from a path to the metadata directory.
:param path: A pathlib.Path or similar object supporting
.joinpath(), __div__, .parent, and .read_text().
"""
self._path = path
def read_text(self, filename):
with suppress(
FileNotFoundError,
IsADirectoryError,
KeyError,
NotADirectoryError,
PermissionError,
):
return self._path.joinpath(filename).read_text(encoding='utf-8')
read_text.__doc__ = Distribution.read_text.__doc__
def locate_file(self, path):
return self._path.parent / path
def distribution(distribution_name):
"""Get the ``Distribution`` instance for the named package.
:param distribution_name: The name of the distribution package as a string.
:return: A ``Distribution`` instance (or subclass thereof).
"""
return Distribution.from_name(distribution_name)
def distributions(**kwargs):
"""Get all ``Distribution`` instances in the current environment.
:return: An iterable of ``Distribution`` instances.
"""
return Distribution.discover(**kwargs)
def metadata(distribution_name):
"""Get the metadata for the named package.
:param distribution_name: The name of the distribution package to query.
:return: An email.Message containing the parsed metadata.
"""
return Distribution.from_name(distribution_name).metadata
def version(distribution_name):
"""Get the version string for the named package.
:param distribution_name: The name of the distribution package to query.
:return: The version string for the package as defined in the package's
"Version" metadata key.
"""
return distribution(distribution_name).version
def entry_points():
"""Return EntryPoint objects for all installed packages.
:return: EntryPoint objects for all installed packages.
"""
eps = itertools.chain.from_iterable(dist.entry_points for dist in distributions())
by_group = operator.attrgetter('group')
ordered = sorted(eps, key=by_group)
grouped = itertools.groupby(ordered, by_group)
return {group: tuple(eps) for group, eps in grouped}
def files(distribution_name):
"""Return a list of files for the named package.
:param distribution_name: The name of the distribution package to query.
:return: List of files composing the distribution.
"""
return distribution(distribution_name).files
def requires(distribution_name):
"""
Return a list of requirements for the named package.
:return: An iterator of requirements, suitable for
packaging.requirement.Requirement.
"""
return distribution(distribution_name).requires

View file

@ -0,0 +1,75 @@
import sys
__all__ = ['install', 'NullFinder', 'PyPy_repr']
def install(cls):
"""
Class decorator for installation on sys.meta_path.
Adds the backport DistributionFinder to sys.meta_path and
attempts to disable the finder functionality of the stdlib
DistributionFinder.
"""
sys.meta_path.append(cls())
disable_stdlib_finder()
return cls
def disable_stdlib_finder():
"""
Give the backport primacy for discovering path-based distributions
by monkey-patching the stdlib O_O.
See #91 for more background for rationale on this sketchy
behavior.
"""
def matches(finder):
return getattr(
finder, '__module__', None
) == '_frozen_importlib_external' and hasattr(finder, 'find_distributions')
for finder in filter(matches, sys.meta_path): # pragma: nocover
del finder.find_distributions
class NullFinder:
"""
A "Finder" (aka "MetaClassFinder") that never finds any modules,
but may find distributions.
"""
@staticmethod
def find_spec(*args, **kwargs):
return None
# In Python 2, the import system requires finders
# to have a find_module() method, but this usage
# is deprecated in Python 3 in favor of find_spec().
# For the purposes of this finder (i.e. being present
# on sys.meta_path but having no other import
# system functionality), the two methods are identical.
find_module = find_spec
class PyPy_repr:
"""
Override repr for EntryPoint objects on PyPy to avoid __iter__ access.
Ref #97, #102.
"""
affected = hasattr(sys, 'pypy_version_info')
def __compat_repr__(self): # pragma: nocover
def make_param(name):
value = getattr(self, name)
return '{name}={value!r}'.format(**locals())
params = ', '.join(map(make_param, self._fields))
return 'EntryPoint({params})'.format(**locals())
if affected: # pragma: nocover
__repr__ = __compat_repr__
del affected

61
libs/common/pbr/build.py Normal file
View file

@ -0,0 +1,61 @@
# Copyright 2021 Monty Taylor
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""pep-517 support
Add::
[build-system]
requires = ["pbr>=5.7.0", "setuptools>=36.6.0", "wheel"]
build-backend = "pbr.build"
to pyproject.toml to use this
"""
from setuptools import build_meta
__all__ = [
'get_requires_for_build_sdist',
'get_requires_for_build_wheel',
'prepare_metadata_for_build_wheel',
'build_wheel',
'build_sdist',
]
def get_requires_for_build_wheel(config_settings=None):
return build_meta.get_requires_for_build_wheel(config_settings)
def get_requires_for_build_sdist(config_settings=None):
return build_meta.get_requires_for_build_sdist(config_settings)
def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None):
return build_meta.prepare_metadata_for_build_wheel(
metadata_directory, config_settings)
def build_wheel(
wheel_directory,
config_settings=None,
metadata_directory=None,
):
return build_meta.build_wheel(
wheel_directory, config_settings, metadata_directory,
)
def build_sdist(sdist_directory, config_settings=None):
return build_meta.build_sdist(sdist_directory, config_settings)

View file

@ -132,11 +132,11 @@ class LocalBuildDoc(setup_command.BuildDoc):
autoindex.write(" %s.rst\n" % module)
def _sphinx_tree(self):
source_dir = self._get_source_dir()
cmd = ['-H', 'Modules', '-o', source_dir, '.']
if apidoc_use_padding:
cmd.insert(0, 'apidoc')
apidoc.main(cmd + self.autodoc_tree_excludes)
source_dir = self._get_source_dir()
cmd = ['-H', 'Modules', '-o', source_dir, '.']
if apidoc_use_padding:
cmd.insert(0, 'apidoc')
apidoc.main(cmd + self.autodoc_tree_excludes)
def _sphinx_run(self):
if not self.verbose:

View file

@ -40,8 +40,11 @@ def get_sha(args):
def get_info(args):
print("{name}\t{version}\t{released}\t{sha}".format(
**_get_info(args.name)))
if args.short:
print("{version}".format(**_get_info(args.name)))
else:
print("{name}\t{version}\t{released}\t{sha}".format(
**_get_info(args.name)))
def _get_info(name):
@ -86,7 +89,9 @@ def main():
version=str(pbr.version.VersionInfo('pbr')))
subparsers = parser.add_subparsers(
title='commands', description='valid commands', help='additional help')
title='commands', description='valid commands', help='additional help',
dest='cmd')
subparsers.required = True
cmd_sha = subparsers.add_parser('sha', help='print sha of package')
cmd_sha.set_defaults(func=get_sha)
@ -96,6 +101,8 @@ def main():
'info', help='print version info for package')
cmd_info.set_defaults(func=get_info)
cmd_info.add_argument('name', help='package to print info of')
cmd_info.add_argument('-s', '--short', action="store_true",
help='only display package version')
cmd_freeze = subparsers.add_parser(
'freeze', help='print version info for all installed packages')

View file

@ -61,6 +61,11 @@ else:
integer_types = (int, long) # noqa
# We use this canary to detect whether the module has already been called,
# in order to avoid recursion
in_use = False
def pbr(dist, attr, value):
"""Implements the actual pbr setup() keyword.
@ -81,6 +86,16 @@ def pbr(dist, attr, value):
not work well with distributions that do use a `Distribution` subclass.
"""
# Distribution.finalize_options() is what calls this method. That means
# there is potential for recursion here. Recursion seems to be an issue
# particularly when using PEP517 build-system configs without
# setup_requires in setup.py. We can avoid the recursion by setting
# this canary so we don't repeat ourselves.
global in_use
if in_use:
return
in_use = True
if not value:
return
if isinstance(value, string_type):

View file

@ -156,9 +156,9 @@ def _clean_changelog_message(msg):
* Escapes '`' which is interpreted as a literal
"""
msg = msg.replace('*', '\*')
msg = msg.replace('_', '\_')
msg = msg.replace('`', '\`')
msg = msg.replace('*', r'\*')
msg = msg.replace('_', r'\_')
msg = msg.replace('`', r'\`')
return msg
@ -223,6 +223,11 @@ def _iter_log_inner(git_dir):
presentation logic to the output - making it suitable for different
uses.
.. caution:: this function risk to return a tag that doesn't exist really
inside the git objects list due to replacement made
to tag name to also list pre-release suffix.
Compliant with the SemVer specification (e.g 1.2.3-rc1)
:return: An iterator of (hash, tags_set, 1st_line) tuples.
"""
log.info('[pbr] Generating ChangeLog')
@ -248,7 +253,7 @@ def _iter_log_inner(git_dir):
for tag_string in refname.split("refs/tags/")[1:]:
# git tag does not allow : or " " in tag names, so we split
# on ", " which is the separator between elements
candidate = tag_string.split(", ")[0]
candidate = tag_string.split(", ")[0].replace("-", ".")
if _is_valid_version(candidate):
tags.add(candidate)
@ -271,13 +276,14 @@ def write_git_changelog(git_dir=None, dest_dir=os.path.curdir,
changelog = _iter_changelog(changelog)
if not changelog:
return
new_changelog = os.path.join(dest_dir, 'ChangeLog')
# If there's already a ChangeLog and it's not writable, just use it
if (os.path.exists(new_changelog)
and not os.access(new_changelog, os.W_OK)):
if os.path.exists(new_changelog) and not os.access(new_changelog, os.W_OK):
# If there's already a ChangeLog and it's not writable, just use it
log.info('[pbr] ChangeLog not written (file already'
' exists and it is not writeable)')
return
log.info('[pbr] Writing ChangeLog')
with io.open(new_changelog, "w", encoding="utf-8") as changelog_file:
for release, content in changelog:
@ -292,13 +298,14 @@ def generate_authors(git_dir=None, dest_dir='.', option_dict=dict()):
'SKIP_GENERATE_AUTHORS')
if should_skip:
return
start = time.time()
old_authors = os.path.join(dest_dir, 'AUTHORS.in')
new_authors = os.path.join(dest_dir, 'AUTHORS')
# If there's already an AUTHORS file and it's not writable, just use it
if (os.path.exists(new_authors)
and not os.access(new_authors, os.W_OK)):
if os.path.exists(new_authors) and not os.access(new_authors, os.W_OK):
# If there's already an AUTHORS file and it's not writable, just use it
return
log.info('[pbr] Generating AUTHORS')
ignore_emails = '((jenkins|zuul)@review|infra@lists|jenkins@openstack)'
if git_dir is None:

View file

@ -14,6 +14,7 @@
# under the License.
import os
import shlex
import sys
from pbr import find_package
@ -35,6 +36,21 @@ def get_man_section(section):
return os.path.join(get_manpath(), 'man%s' % section)
def unquote_path(path):
# unquote the full path, e.g: "'a/full/path'" becomes "a/full/path", also
# strip the quotes off individual path components because os.walk cannot
# handle paths like: "'i like spaces'/'another dir'", so we will pass it
# "i like spaces/another dir" instead.
if os.name == 'nt':
# shlex cannot handle paths that contain backslashes, treating those
# as escape characters.
path = path.replace("\\", "/")
return "".join(shlex.split(path)).replace("/", "\\")
return "".join(shlex.split(path))
class FilesConfig(base.BaseConfig):
section = 'files'
@ -57,21 +73,28 @@ class FilesConfig(base.BaseConfig):
target = target.strip()
if not target.endswith(os.path.sep):
target += os.path.sep
for (dirpath, dirnames, fnames) in os.walk(source_prefix):
finished.append(
"%s = " % dirpath.replace(source_prefix, target))
unquoted_prefix = unquote_path(source_prefix)
unquoted_target = unquote_path(target)
for (dirpath, dirnames, fnames) in os.walk(unquoted_prefix):
# As source_prefix is always matched, using replace with a
# a limit of one is always going to replace the path prefix
# and not accidentally replace some text in the middle of
# the path
new_prefix = dirpath.replace(unquoted_prefix,
unquoted_target, 1)
finished.append("'%s' = " % new_prefix)
finished.extend(
[" %s" % os.path.join(dirpath, f) for f in fnames])
[" '%s'" % os.path.join(dirpath, f) for f in fnames])
else:
finished.append(line)
self.data_files = "\n".join(finished)
def add_man_path(self, man_path):
self.data_files = "%s\n%s =" % (self.data_files, man_path)
self.data_files = "%s\n'%s' =" % (self.data_files, man_path)
def add_man_page(self, man_page):
self.data_files = "%s\n %s" % (self.data_files, man_page)
self.data_files = "%s\n '%s'" % (self.data_files, man_page)
def get_man_sections(self):
man_sections = dict()

View file

@ -48,6 +48,6 @@ TRUE_VALUES = ('true', '1', 'yes')
def get_boolean_option(option_dict, option_name, env_name):
return ((option_name in option_dict
and option_dict[option_name][1].lower() in TRUE_VALUES) or
return ((option_name in option_dict and
option_dict[option_name][1].lower() in TRUE_VALUES) or
str(os.getenv(env_name)).lower() in TRUE_VALUES)

View file

@ -22,6 +22,16 @@ from __future__ import unicode_literals
from distutils.command import install as du_install
from distutils import log
# (hberaud) do not use six here to import urlparse
# to keep this module free from external dependencies
# to avoid cross dependencies errors on minimal system
# free from dependencies.
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
import email
import email.errors
import os
@ -98,19 +108,31 @@ def get_reqs_from_files(requirements_files):
return []
def egg_fragment(match):
return re.sub(r'(?P<PackageName>[\w.-]+)-'
r'(?P<GlobalVersion>'
r'(?P<VersionTripple>'
r'(?P<Major>0|[1-9][0-9]*)\.'
r'(?P<Minor>0|[1-9][0-9]*)\.'
r'(?P<Patch>0|[1-9][0-9]*)){1}'
r'(?P<Tags>(?:\-'
r'(?P<Prerelease>(?:(?=[0]{1}[0-9A-Za-z-]{0})(?:[0]{1})|'
r'(?=[1-9]{1}[0-9]*[A-Za-z]{0})(?:[0-9]+)|'
r'(?=[0-9]*[A-Za-z-]+[0-9A-Za-z-]*)(?:[0-9A-Za-z-]+)){1}'
r'(?:\.(?=[0]{1}[0-9A-Za-z-]{0})(?:[0]{1})|'
r'\.(?=[1-9]{1}[0-9]*[A-Za-z]{0})(?:[0-9]+)|'
r'\.(?=[0-9]*[A-Za-z-]+[0-9A-Za-z-]*)'
r'(?:[0-9A-Za-z-]+))*){1}){0,1}(?:\+'
r'(?P<Meta>(?:[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))){0,1}))',
r'\g<PackageName>>=\g<GlobalVersion>',
match.groups()[-1])
def parse_requirements(requirements_files=None, strip_markers=False):
if requirements_files is None:
requirements_files = get_requirements_files()
def egg_fragment(match):
# take a versioned egg fragment and return a
# versioned package requirement e.g.
# nova-1.2.3 becomes nova>=1.2.3
return re.sub(r'([\w.]+)-([\w.-]+)',
r'\1>=\2',
match.groups()[-1])
requirements = []
for line in get_reqs_from_files(requirements_files):
# Ignore comments
@ -118,7 +140,8 @@ def parse_requirements(requirements_files=None, strip_markers=False):
continue
# Ignore index URL lines
if re.match(r'^\s*(-i|--index-url|--extra-index-url).*', line):
if re.match(r'^\s*(-i|--index-url|--extra-index-url|--find-links).*',
line):
continue
# Handle nested requirements files such as:
@ -140,16 +163,19 @@ def parse_requirements(requirements_files=None, strip_markers=False):
# -e git://github.com/openstack/nova/master#egg=nova
# -e git://github.com/openstack/nova/master#egg=nova-1.2.3
# -e git+https://foo.com/zipball#egg=bar&subdirectory=baz
if re.match(r'\s*-e\s+', line):
line = re.sub(r'\s*-e\s+.*#egg=([^&]+).*$', egg_fragment, line)
# such as:
# http://github.com/openstack/nova/zipball/master#egg=nova
# http://github.com/openstack/nova/zipball/master#egg=nova-1.2.3
# git+https://foo.com/zipball#egg=bar&subdirectory=baz
elif re.match(r'\s*(https?|git(\+(https|ssh))?):', line):
line = re.sub(r'\s*(https?|git(\+(https|ssh))?):.*#egg=([^&]+).*$',
egg_fragment, line)
# git+[ssh]://github.com/openstack/nova/zipball/master#egg=nova-1.2.3
# hg+[ssh]://github.com/openstack/nova/zipball/master#egg=nova-1.2.3
# svn+[proto]://github.com/openstack/nova/zipball/master#egg=nova-1.2.3
# -f lines are for index locations, and don't get used here
if re.match(r'\s*-e\s+', line):
extract = re.match(r'\s*-e\s+(.*)$', line)
line = extract.group(1)
egg = urlparse(line)
if egg.scheme:
line = re.sub(r'egg=([^&]+).*$', egg_fragment, egg.fragment)
elif re.match(r'\s*-f\s+', line):
line = None
reason = 'Index Location'
@ -183,7 +209,7 @@ def parse_dependency_links(requirements_files=None):
if re.match(r'\s*-[ef]\s+', line):
dependency_links.append(re.sub(r'\s*-[ef]\s+', '', line))
# lines that are only urls can go in unmolested
elif re.match(r'\s*(https?|git(\+(https|ssh))?):', line):
elif re.match(r'^\s*(https?|git(\+(https|ssh))?|svn|hg)\S*:', line):
dependency_links.append(line)
return dependency_links
@ -302,6 +328,7 @@ except ImportError:
def have_nose():
return _have_nose
_wsgi_text = """#PBR Generated from %(group)r
import threading
@ -404,9 +431,13 @@ def generate_script(group, entry_point, header, template):
def override_get_script_args(
dist, executable=os.path.normpath(sys.executable), is_wininst=False):
dist, executable=os.path.normpath(sys.executable)):
"""Override entrypoints console_script."""
header = easy_install.get_script_header("", executable, is_wininst)
# get_script_header() is deprecated since Setuptools 12.0
try:
header = easy_install.ScriptWriter.get_header("", executable)
except AttributeError:
header = easy_install.get_script_header("", executable)
for group, template in ENTRY_POINTS_MAP.items():
for name, ep in dist.get_entry_map(group).items():
yield (name, generate_script(group, ep, header, template))
@ -428,8 +459,12 @@ class LocalInstallScripts(install_scripts.install_scripts):
"""Intercepts console scripts entry_points."""
command_name = 'install_scripts'
def _make_wsgi_scripts_only(self, dist, executable, is_wininst):
header = easy_install.get_script_header("", executable, is_wininst)
def _make_wsgi_scripts_only(self, dist, executable):
# get_script_header() is deprecated since Setuptools 12.0
try:
header = easy_install.ScriptWriter.get_header("", executable)
except AttributeError:
header = easy_install.get_script_header("", executable)
wsgi_script_template = ENTRY_POINTS_MAP['wsgi_scripts']
for name, ep in dist.get_entry_map('wsgi_scripts').items():
content = generate_script(
@ -455,16 +490,12 @@ class LocalInstallScripts(install_scripts.install_scripts):
bs_cmd = self.get_finalized_command('build_scripts')
executable = getattr(
bs_cmd, 'executable', easy_install.sys_executable)
is_wininst = getattr(
self.get_finalized_command("bdist_wininst"), '_is_running', False
)
if 'bdist_wheel' in self.distribution.have_run:
# We're building a wheel which has no way of generating mod_wsgi
# scripts for us. Let's build them.
# NOTE(sigmavirus24): This needs to happen here because, as the
# comment below indicates, no_ep is True when building a wheel.
self._make_wsgi_scripts_only(dist, executable, is_wininst)
self._make_wsgi_scripts_only(dist, executable)
if self.no_ep:
# no_ep is True if we're installing into an .egg file or building
@ -478,7 +509,7 @@ class LocalInstallScripts(install_scripts.install_scripts):
get_script_args = easy_install.get_script_args
executable = '"%s"' % executable
for args in get_script_args(dist, executable, is_wininst):
for args in get_script_args(dist, executable):
self.write_script(*args)
@ -550,8 +581,9 @@ class LocalEggInfo(egg_info.egg_info):
else:
log.info("[pbr] Reusing existing SOURCES.txt")
self.filelist = egg_info.FileList()
for entry in open(manifest_filename, 'r').read().split('\n'):
self.filelist.append(entry)
with open(manifest_filename, 'r') as fil:
for entry in fil.read().split('\n'):
self.filelist.append(entry)
def _from_git(distribution):
@ -626,6 +658,7 @@ class LocalSDist(sdist.sdist):
self.filelist.sort()
sdist.sdist.make_distribution(self)
try:
from pbr import builddoc
_have_sphinx = True
@ -659,12 +692,14 @@ def _get_increment_kwargs(git_dir, tag):
# git log output affecting out ability to have working sem ver headers.
changelog = git._run_git_command(['log', '--pretty=%B', version_spec],
git_dir)
header_len = len('sem-ver:')
commands = [line[header_len:].strip() for line in changelog.split('\n')
if line.lower().startswith('sem-ver:')]
symbols = set()
for command in commands:
symbols.update([symbol.strip() for symbol in command.split(',')])
header = 'sem-ver:'
for line in changelog.split("\n"):
line = line.lower().strip()
if not line.lower().strip().startswith(header):
continue
new_symbols = line[len(header):].strip().split(",")
symbols.update([symbol.strip() for symbol in new_symbols])
def _handle_symbol(symbol, symbols, impact):
if symbol in symbols:
@ -791,12 +826,9 @@ def _get_version_from_pkg_metadata(package_name):
pkg_metadata = {}
for filename in pkg_metadata_filenames:
try:
pkg_metadata_file = open(filename, 'r')
except (IOError, OSError):
continue
try:
pkg_metadata = email.message_from_file(pkg_metadata_file)
except email.errors.MessageError:
with open(filename, 'r') as pkg_metadata_file:
pkg_metadata = email.message_from_file(pkg_metadata_file)
except (IOError, OSError, email.errors.MessageError):
continue
# Check to make sure we're in our own dir

View file

@ -187,7 +187,9 @@ class CapturedSubprocess(fixtures.Fixture):
self.addDetail(self.label + '-stderr', content.text_content(self.err))
self.returncode = proc.returncode
if proc.returncode:
raise AssertionError('Failed process %s' % proc.returncode)
raise AssertionError(
'Failed process args=%r, kwargs=%r, returncode=%s' % (
self.args, self.kwargs, proc.returncode))
self.addCleanup(delattr, self, 'out')
self.addCleanup(delattr, self, 'err')
self.addCleanup(delattr, self, 'returncode')
@ -200,12 +202,15 @@ def _run_cmd(args, cwd):
:param cwd: The directory to run the comamnd in.
:return: ((stdout, stderr), returncode)
"""
print('Running %s' % ' '.join(args))
p = subprocess.Popen(
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, cwd=cwd)
streams = tuple(s.decode('latin1').strip() for s in p.communicate())
for stream_content in streams:
print(stream_content)
print('STDOUT:')
print(streams[0])
print('STDERR:')
print(streams[1])
return (streams) + (p.returncode,)

View file

@ -78,7 +78,7 @@ class TestCommands(base.BaseTestCase):
stdout, stderr, return_code = self.run_pbr('freeze')
self.assertEqual(0, return_code)
pkgs = []
for l in stdout.split('\n'):
pkgs.append(l.split('==')[0].lower())
for line in stdout.split('\n'):
pkgs.append(line.split('==')[0].lower())
pkgs_sort = sorted(pkgs[:])
self.assertEqual(pkgs_sort, pkgs)

View file

@ -40,6 +40,7 @@
import glob
import os
import sys
import tarfile
import fixtures
@ -74,7 +75,7 @@ class TestCore(base.BaseTestCase):
self.run_setup('egg_info')
stdout, _, _ = self.run_setup('--keywords')
assert stdout == 'packaging,distutils,setuptools'
assert stdout == 'packaging, distutils, setuptools'
def test_setup_py_build_sphinx(self):
stdout, _, return_code = self.run_setup('build_sphinx')
@ -113,6 +114,12 @@ class TestCore(base.BaseTestCase):
def test_console_script_develop(self):
"""Test that we develop a non-pkg-resources console script."""
if sys.version_info < (3, 0):
self.skipTest(
'Fails with recent virtualenv due to '
'https://github.com/pypa/virtualenv/issues/1638'
)
if os.name == 'nt':
self.skipTest('Windows support is passthrough')

View file

@ -35,17 +35,31 @@ class FilesConfigTest(base.BaseTestCase):
])
self.useFixture(pkg_fixture)
pkg_etc = os.path.join(pkg_fixture.base, 'etc')
pkg_ansible = os.path.join(pkg_fixture.base, 'ansible',
'kolla-ansible', 'test')
dir_spcs = os.path.join(pkg_fixture.base, 'dir with space')
dir_subdir_spc = os.path.join(pkg_fixture.base, 'multi space',
'more spaces')
pkg_sub = os.path.join(pkg_etc, 'sub')
subpackage = os.path.join(
pkg_fixture.base, 'fake_package', 'subpackage')
os.makedirs(pkg_sub)
os.makedirs(subpackage)
os.makedirs(pkg_ansible)
os.makedirs(dir_spcs)
os.makedirs(dir_subdir_spc)
with open(os.path.join(pkg_etc, "foo"), 'w') as foo_file:
foo_file.write("Foo Data")
with open(os.path.join(pkg_sub, "bar"), 'w') as foo_file:
foo_file.write("Bar Data")
with open(os.path.join(pkg_ansible, "baz"), 'w') as baz_file:
baz_file.write("Baz Data")
with open(os.path.join(subpackage, "__init__.py"), 'w') as foo_file:
foo_file.write("# empty")
with open(os.path.join(dir_spcs, "file with spc"), 'w') as spc_file:
spc_file.write("# empty")
with open(os.path.join(dir_subdir_spc, "file with spc"), 'w') as file_:
file_.write("# empty")
self.useFixture(base.DiveDir(pkg_fixture.base))
@ -74,5 +88,61 @@ class FilesConfigTest(base.BaseTestCase):
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(
'\netc/pbr/ = \n etc/foo\netc/pbr/sub = \n etc/sub/bar',
"\n'etc/pbr/' = \n 'etc/foo'\n'etc/pbr/sub' = \n 'etc/sub/bar'",
config['files']['data_files'])
def test_data_files_with_spaces(self):
config = dict(
files=dict(
data_files="\n 'i like spaces' = 'dir with space'/*"
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(
"\n'i like spaces/' = \n 'dir with space/file with spc'",
config['files']['data_files'])
def test_data_files_with_spaces_subdirectories(self):
# test that we can handle whitespace in subdirectories
data_files = "\n 'one space/two space' = 'multi space/more spaces'/*"
expected = (
"\n'one space/two space/' = "
"\n 'multi space/more spaces/file with spc'")
config = dict(
files=dict(
data_files=data_files
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(expected, config['files']['data_files'])
def test_data_files_with_spaces_quoted_components(self):
# test that we can quote individual path components
data_files = (
"\n'one space'/'two space' = 'multi space'/'more spaces'/*"
)
expected = ("\n'one space/two space/' = "
"\n 'multi space/more spaces/file with spc'")
config = dict(
files=dict(
data_files=data_files
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(expected, config['files']['data_files'])
def test_data_files_globbing_source_prefix_in_directory_name(self):
# We want to test that the string, "docs", is not replaced in a
# subdirectory name, "sub-docs"
config = dict(
files=dict(
data_files="\n share/ansible = ansible/*"
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(
"\n'share/ansible/' = "
"\n'share/ansible/kolla-ansible' = "
"\n'share/ansible/kolla-ansible/test' = "
"\n 'ansible/kolla-ansible/test/baz'",
config['files']['data_files'])

View file

@ -11,7 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import configparser
except ImportError:
import ConfigParser as configparser
import os.path
import pkg_resources
import shlex
import sys
@ -77,19 +82,35 @@ class TestIntegration(base.BaseTestCase):
# We don't break these into separate tests because we'd need separate
# source dirs to isolate from side effects of running pip, and the
# overheads of setup would start to beat the benefits of parallelism.
self.useFixture(base.CapturedSubprocess(
'sync-req',
['python', 'update.py', os.path.join(REPODIR, self.short_name)],
cwd=os.path.join(REPODIR, 'requirements')))
self.useFixture(base.CapturedSubprocess(
'commit-requirements',
'git diff --quiet || git commit -amrequirements',
cwd=os.path.join(REPODIR, self.short_name), shell=True))
path = os.path.join(
self.useFixture(fixtures.TempDir()).path, 'project')
self.useFixture(base.CapturedSubprocess(
'clone',
['git', 'clone', os.path.join(REPODIR, self.short_name), path]))
path = os.path.join(REPODIR, self.short_name)
setup_cfg = os.path.join(path, 'setup.cfg')
project_name = pkg_resources.safe_name(self.short_name).lower()
# These projects should all have setup.cfg files but we'll be careful
if os.path.exists(setup_cfg):
config = configparser.ConfigParser()
config.read(setup_cfg)
if config.has_section('metadata'):
raw_name = config.get('metadata', 'name',
fallback='notapackagename')
# Technically we should really only need to use the raw
# name because all our projects should be good and use
# normalized names but they don't...
project_name = pkg_resources.safe_name(raw_name).lower()
constraints = os.path.join(REPODIR, 'requirements',
'upper-constraints.txt')
tmp_constraints = os.path.join(
self.useFixture(fixtures.TempDir()).path,
'upper-constraints.txt')
# We need to filter out the package we are installing to avoid
# conflicts with the constraints.
with open(constraints, 'r') as src:
with open(tmp_constraints, 'w') as dest:
for line in src:
constraint = line.split('===')[0]
if project_name != constraint:
dest.write(line)
pip_cmd = PIP_CMD + ['-c', tmp_constraints]
venv = self.useFixture(
test_packaging.Venv('sdist',
modules=['pip', 'wheel', PBRVERSION],
@ -105,7 +126,7 @@ class TestIntegration(base.BaseTestCase):
filename = os.path.join(
path, 'dist', os.listdir(os.path.join(path, 'dist'))[0])
self.useFixture(base.CapturedSubprocess(
'tarball', [python] + PIP_CMD + [filename]))
'tarball', [python] + pip_cmd + [filename]))
venv = self.useFixture(
test_packaging.Venv('install-git',
modules=['pip', 'wheel', PBRVERSION],
@ -113,7 +134,7 @@ class TestIntegration(base.BaseTestCase):
root = venv.path
python = venv.python
self.useFixture(base.CapturedSubprocess(
'install-git', [python] + PIP_CMD + ['git+file://' + path]))
'install-git', [python] + pip_cmd + ['git+file://' + path]))
if self.short_name == 'nova':
found = False
for _, _, filenames in os.walk(root):
@ -127,7 +148,7 @@ class TestIntegration(base.BaseTestCase):
root = venv.path
python = venv.python
self.useFixture(base.CapturedSubprocess(
'install-e', [python] + PIP_CMD + ['-e', path]))
'install-e', [python] + pip_cmd + ['-e', path]))
class TestInstallWithoutPbr(base.BaseTestCase):
@ -188,12 +209,16 @@ class TestInstallWithoutPbr(base.BaseTestCase):
class TestMarkersPip(base.BaseTestCase):
scenarios = [
('pip-1.5', {'modules': ['pip>=1.5,<1.6']}),
('pip-6.0', {'modules': ['pip>=6.0,<6.1']}),
('pip-latest', {'modules': ['pip']}),
('setuptools-EL7', {'modules': ['pip==1.4.1', 'setuptools==0.9.8']}),
('setuptools-Trusty', {'modules': ['pip==1.5', 'setuptools==2.2']}),
('setuptools-minimum', {'modules': ['pip==1.5', 'setuptools==0.7.2']}),
('setuptools-Bionic', {
'modules': ['pip==9.0.1', 'setuptools==39.0.1']}),
('setuptools-Stretch', {
'modules': ['pip==9.0.1', 'setuptools==33.1.1']}),
('setuptools-EL8', {'modules': ['pip==9.0.3', 'setuptools==39.2.0']}),
('setuptools-Buster', {
'modules': ['pip==18.1', 'setuptools==40.8.0']}),
('setuptools-Focal', {
'modules': ['pip==20.0.2', 'setuptools==45.2.0']}),
]
@testtools.skipUnless(
@ -240,25 +265,17 @@ class TestLTSSupport(base.BaseTestCase):
# These versions come from the versions installed from the 'virtualenv'
# command from the 'python-virtualenv' package.
scenarios = [
('EL7', {'modules': ['pip==1.4.1', 'setuptools==0.9.8'],
'py3support': True}), # And EPEL6
('Trusty', {'modules': ['pip==1.5', 'setuptools==2.2'],
'py3support': True}),
('Jessie', {'modules': ['pip==1.5.6', 'setuptools==5.5.1'],
'py3support': True}),
# Wheezy has pip1.1, which cannot be called with '-m pip'
# So we'll use a different version of pip here.
('WheezyPrecise', {'modules': ['pip==1.4.1', 'setuptools==0.6c11'],
'py3support': False})
('Bionic', {'modules': ['pip==9.0.1', 'setuptools==39.0.1']}),
('Stretch', {'modules': ['pip==9.0.1', 'setuptools==33.1.1']}),
('EL8', {'modules': ['pip==9.0.3', 'setuptools==39.2.0']}),
('Buster', {'modules': ['pip==18.1', 'setuptools==40.8.0']}),
('Focal', {'modules': ['pip==20.0.2', 'setuptools==45.2.0']}),
]
@testtools.skipUnless(
os.environ.get('PBR_INTEGRATION', None) == '1',
'integration tests not enabled')
def test_lts_venv_default_versions(self):
if (sys.version_info[0] == 3 and not self.py3support):
self.skipTest('This combination will not install with py3, '
'skipping test')
venv = self.useFixture(
test_packaging.Venv('setuptools', modules=self.modules))
bin_python = venv.python

View file

@ -48,7 +48,10 @@ import tempfile
import textwrap
import fixtures
import mock
try:
from unittest import mock
except ImportError:
import mock
import pkg_resources
import six
import testscenarios
@ -108,7 +111,7 @@ class GPGKeyFixture(fixtures.Fixture):
def setUp(self):
super(GPGKeyFixture, self).setUp()
tempdir = self.useFixture(fixtures.TempDir())
gnupg_version_re = re.compile('^gpg\s.*\s([\d+])\.([\d+])\.([\d+])')
gnupg_version_re = re.compile(r'^gpg\s.*\s([\d+])\.([\d+])\.([\d+])')
gnupg_version = base._run_cmd(['gpg', '--version'], tempdir.path)
for line in gnupg_version[0].split('\n'):
gnupg_version = gnupg_version_re.match(line)
@ -120,9 +123,9 @@ class GPGKeyFixture(fixtures.Fixture):
else:
if gnupg_version is None:
gnupg_version = (0, 0, 0)
config_file = tempdir.path + '/key-config'
f = open(config_file, 'wt')
try:
config_file = os.path.join(tempdir.path, 'key-config')
with open(config_file, 'wt') as f:
if gnupg_version[0] == 2 and gnupg_version[1] >= 1:
f.write("""
%no-protection
@ -135,11 +138,9 @@ class GPGKeyFixture(fixtures.Fixture):
Name-Comment: N/A
Name-Email: example@example.com
Expire-Date: 2d
Preferences: (setpref)
%commit
""")
finally:
f.close()
# Note that --quick-random (--debug-quick-random in GnuPG 2.x)
# does not have a corresponding preferences file setting and
# must be passed explicitly on the command line instead
@ -149,6 +150,7 @@ class GPGKeyFixture(fixtures.Fixture):
gnupg_random = '--debug-quick-random'
else:
gnupg_random = ''
base._run_cmd(
['gpg', '--gen-key', '--batch', gnupg_random, config_file],
tempdir.path)
@ -173,17 +175,17 @@ class Venv(fixtures.Fixture):
"""
self._reason = reason
if modules == ():
pbr = 'file://%s#egg=pbr' % PBR_ROOT
modules = ['pip', 'wheel', pbr]
modules = ['pip', 'wheel', 'build', PBR_ROOT]
self.modules = modules
if pip_cmd is None:
self.pip_cmd = ['-m', 'pip', 'install']
self.pip_cmd = ['-m', 'pip', '-v', 'install']
else:
self.pip_cmd = pip_cmd
def _setUp(self):
path = self.useFixture(fixtures.TempDir()).path
virtualenv.create_environment(path, clear=True)
virtualenv.cli_run([path])
python = os.path.join(path, 'bin', 'python')
command = [python] + self.pip_cmd + ['-U']
if self.modules and len(self.modules) > 0:
@ -293,23 +295,23 @@ class TestPackagingInGitRepoWithCommit(base.BaseTestCase):
self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read()
self.assertIn('\*', body)
self.assertIn(r'\*', body)
def test_changelog_handles_dead_links_in_commit(self):
self.repo.commit(message_content="See os_ for to_do about qemu_.")
self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read()
self.assertIn('os\_', body)
self.assertIn('to\_do', body)
self.assertIn('qemu\_', body)
self.assertIn(r'os\_', body)
self.assertIn(r'to\_do', body)
self.assertIn(r'qemu\_', body)
def test_changelog_handles_backticks(self):
self.repo.commit(message_content="Allow `openstack.org` to `work")
self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read()
self.assertIn('\`', body)
self.assertIn(r'\`', body)
def test_manifest_exclude_honoured(self):
self.run_setup('sdist', allow_fail=False)
@ -379,6 +381,12 @@ class TestPackagingWheels(base.BaseTestCase):
wheel_file.extractall(self.extracted_wheel_dir)
wheel_file.close()
def test_metadata_directory_has_pbr_json(self):
# Build the path to the scripts directory
pbr_json = os.path.join(
self.extracted_wheel_dir, 'pbr_testpackage-0.0.dist-info/pbr.json')
self.assertTrue(os.path.exists(pbr_json))
def test_data_directory_has_wsgi_scripts(self):
# Build the path to the scripts directory
scripts_dir = os.path.join(
@ -531,11 +539,13 @@ class ParseRequirementsTest(base.BaseTestCase):
tempdir = tempfile.mkdtemp()
requirements = os.path.join(tempdir, 'requirements.txt')
with open(requirements, 'w') as f:
f.write('-i https://myindex.local')
f.write(' --index-url https://myindex.local')
f.write(' --extra-index-url https://myindex.local')
f.write('-i https://myindex.local\n')
f.write(' --index-url https://myindex.local\n')
f.write(' --extra-index-url https://myindex.local\n')
f.write('--find-links https://myindex.local\n')
f.write('arequirement>=1.0\n')
result = packaging.parse_requirements([requirements])
self.assertEqual([], result)
self.assertEqual(['arequirement>=1.0'], result)
def test_nested_requirements(self):
tempdir = tempfile.mkdtemp()
@ -662,12 +672,65 @@ class TestVersions(base.BaseTestCase):
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_multi_inline_symbols_no_space(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('Sem-ver: feature,api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_multi_inline_symbols_spaced(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('Sem-ver: feature, api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_multi_inline_symbols_reversed(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('Sem-ver: api-break,feature')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_leading_space(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit(' sem-ver: api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_leading_space_multiline(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit(
(
' Some cool text\n'
' sem-ver: api-break'
)
)
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_leading_characters_symbol_not_found(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit(' ssem-ver: api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.2.4.dev1'))
def test_tagged_version_has_tag_version(self):
self.repo.commit()
self.repo.tag('1.2.3')
version = packaging._get_version_from_git('1.2.3')
self.assertEqual('1.2.3', version)
def test_tagged_version_with_semver_compliant_prerelease(self):
self.repo.commit()
self.repo.tag('1.2.3-rc2')
version = packaging._get_version_from_git()
self.assertEqual('1.2.3.0rc2', version)
def test_non_canonical_tagged_version_bump(self):
self.repo.commit()
self.repo.tag('1.4')
@ -724,6 +787,13 @@ class TestVersions(base.BaseTestCase):
version = packaging._get_version_from_git('1.2.3')
self.assertThat(version, matchers.StartsWith('1.2.3.0a2.dev1'))
def test_untagged_version_after_semver_compliant_prerelease_tag(self):
self.repo.commit()
self.repo.tag('1.2.3-rc2')
self.repo.commit()
version = packaging._get_version_from_git()
self.assertEqual('1.2.3.0rc3.dev1', version)
def test_preversion_too_low_simple(self):
# That is, the target version is either already released or not high
# enough for the semver requirements given api breaks etc.
@ -750,8 +820,10 @@ class TestVersions(base.BaseTestCase):
def test_get_kwargs_corner_cases(self):
# No tags:
git_dir = self.repo._basedir + '/.git'
get_kwargs = lambda tag: packaging._get_increment_kwargs(git_dir, tag)
def get_kwargs(tag):
git_dir = self.repo._basedir + '/.git'
return packaging._get_increment_kwargs(git_dir, tag)
def _check_combinations(tag):
self.repo.commit()
@ -903,6 +975,235 @@ class TestRequirementParsing(base.BaseTestCase):
self.assertEqual(exp_parsed, gen_parsed)
class TestPEP517Support(base.BaseTestCase):
def test_pep_517_support(self):
# Note that the current PBR PEP517 entrypoints rely on a valid
# PBR setup.py existing.
pkgs = {
'test_pep517':
{
'requirements.txt': textwrap.dedent("""\
sphinx
iso8601
"""),
# Override default setup.py to remove setup_requires.
'setup.py': textwrap.dedent("""\
#!/usr/bin/env python
import setuptools
setuptools.setup(pbr=True)
"""),
'setup.cfg': textwrap.dedent("""\
[metadata]
name = test_pep517
summary = A tiny test project
author = PBR Team
author-email = foo@example.com
home-page = https://example.com/
classifier =
Intended Audience :: Information Technology
Intended Audience :: System Administrators
License :: OSI Approved :: Apache Software License
Operating System :: POSIX :: Linux
Programming Language :: Python
Programming Language :: Python :: 2
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
"""),
'pyproject.toml': textwrap.dedent("""\
[build-system]
requires = ["pbr", "setuptools>=36.6.0", "wheel"]
build-backend = "pbr.build"
""")},
}
pkg_dirs = self.useFixture(CreatePackages(pkgs)).package_dirs
pkg_dir = pkg_dirs['test_pep517']
venv = self.useFixture(Venv('PEP517'))
# Test building sdists and wheels works. Note we do not use pip here
# because pip will forcefully install the latest version of PBR on
# pypi to satisfy the build-system requires. This means we can't self
# test changes using pip. Build with --no-isolation appears to avoid
# this problem.
self._run_cmd(venv.python, ('-m', 'build', '--no-isolation', '.'),
allow_fail=False, cwd=pkg_dir)
class TestRepositoryURLDependencies(base.BaseTestCase):
def setUp(self):
super(TestRepositoryURLDependencies, self).setUp()
self.requirements = os.path.join(tempfile.mkdtemp(),
'requirements.txt')
with open(self.requirements, 'w') as f:
f.write('\n'.join([
'-e git+git://git.pro-ject.org/oslo.messaging#egg=oslo.messaging-1.0.0-rc', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize-beta', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-4.0.1', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-1.0.0-alpha.beta.1', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-2.0.0-rc.1+build.123', # noqa
'-e git+git://git.project.org/Proj#egg=Proj1',
'git+https://git.project.org/Proj#egg=Proj2-0.0.1',
'-e git+ssh://git.project.org/Proj#egg=Proj3',
'svn+svn://svn.project.org/svn/Proj#egg=Proj4-0.0.2',
'-e svn+http://svn.project.org/svn/Proj/trunk@2019#egg=Proj5',
'hg+http://hg.project.org/Proj@da39a3ee5e6b#egg=Proj-0.0.3',
'-e hg+http://hg.project.org/Proj@2019#egg=Proj',
'hg+http://hg.project.org/Proj@v1.0#egg=Proj-0.0.4',
'-e hg+http://hg.project.org/Proj@special_feature#egg=Proj',
'git://foo.com/zipball#egg=foo-bar-1.2.4',
'pypi-proj1', 'pypi-proj2']))
def test_egg_fragment(self):
expected = [
'django-thumborize',
'django-thumborize-beta',
'django-thumborize2-beta',
'django-thumborize2-beta>=4.0.1',
'django-thumborize2-beta>=1.0.0-alpha.beta.1',
'django-thumborize2-beta>=1.0.0-alpha-a.b-c-long+build.1-aef.1-its-okay', # noqa
'django-thumborize2-beta>=2.0.0-rc.1+build.123',
'django-thumborize-beta>=0.0.4',
'django-thumborize-beta>=1.2.3',
'django-thumborize-beta>=10.20.30',
'django-thumborize-beta>=1.1.2-prerelease+meta',
'django-thumborize-beta>=1.1.2+meta',
'django-thumborize-beta>=1.1.2+meta-valid',
'django-thumborize-beta>=1.0.0-alpha',
'django-thumborize-beta>=1.0.0-beta',
'django-thumborize-beta>=1.0.0-alpha.beta',
'django-thumborize-beta>=1.0.0-alpha.beta.1',
'django-thumborize-beta>=1.0.0-alpha.1',
'django-thumborize-beta>=1.0.0-alpha0.valid',
'django-thumborize-beta>=1.0.0-alpha.0valid',
'django-thumborize-beta>=1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'django-thumborize-beta>=1.0.0-rc.1+build.1',
'django-thumborize-beta>=2.0.0-rc.1+build.123',
'django-thumborize-beta>=1.2.3-beta',
'django-thumborize-beta>=10.2.3-DEV-SNAPSHOT',
'django-thumborize-beta>=1.2.3-SNAPSHOT-123',
'django-thumborize-beta>=1.0.0',
'django-thumborize-beta>=2.0.0',
'django-thumborize-beta>=1.1.7',
'django-thumborize-beta>=2.0.0+build.1848',
'django-thumborize-beta>=2.0.1-alpha.1227',
'django-thumborize-beta>=1.0.0-alpha+beta',
'django-thumborize-beta>=1.2.3----RC-SNAPSHOT.12.9.1--.12+788',
'django-thumborize-beta>=1.2.3----R-S.12.9.1--.12+meta',
'django-thumborize-beta>=1.2.3----RC-SNAPSHOT.12.9.1--.12',
'django-thumborize-beta>=1.0.0+0.build.1-rc.10000aaa-kk-0.1',
'django-thumborize-beta>=999999999999999999.99999999999999.9999999999999', # noqa
'Proj1',
'Proj2>=0.0.1',
'Proj3',
'Proj4>=0.0.2',
'Proj5',
'Proj>=0.0.3',
'Proj',
'Proj>=0.0.4',
'Proj',
'foo-bar>=1.2.4',
]
tests = [
'egg=django-thumborize',
'egg=django-thumborize-beta',
'egg=django-thumborize2-beta',
'egg=django-thumborize2-beta-4.0.1',
'egg=django-thumborize2-beta-1.0.0-alpha.beta.1',
'egg=django-thumborize2-beta-1.0.0-alpha-a.b-c-long+build.1-aef.1-its-okay', # noqa
'egg=django-thumborize2-beta-2.0.0-rc.1+build.123',
'egg=django-thumborize-beta-0.0.4',
'egg=django-thumborize-beta-1.2.3',
'egg=django-thumborize-beta-10.20.30',
'egg=django-thumborize-beta-1.1.2-prerelease+meta',
'egg=django-thumborize-beta-1.1.2+meta',
'egg=django-thumborize-beta-1.1.2+meta-valid',
'egg=django-thumborize-beta-1.0.0-alpha',
'egg=django-thumborize-beta-1.0.0-beta',
'egg=django-thumborize-beta-1.0.0-alpha.beta',
'egg=django-thumborize-beta-1.0.0-alpha.beta.1',
'egg=django-thumborize-beta-1.0.0-alpha.1',
'egg=django-thumborize-beta-1.0.0-alpha0.valid',
'egg=django-thumborize-beta-1.0.0-alpha.0valid',
'egg=django-thumborize-beta-1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'egg=django-thumborize-beta-1.0.0-rc.1+build.1',
'egg=django-thumborize-beta-2.0.0-rc.1+build.123',
'egg=django-thumborize-beta-1.2.3-beta',
'egg=django-thumborize-beta-10.2.3-DEV-SNAPSHOT',
'egg=django-thumborize-beta-1.2.3-SNAPSHOT-123',
'egg=django-thumborize-beta-1.0.0',
'egg=django-thumborize-beta-2.0.0',
'egg=django-thumborize-beta-1.1.7',
'egg=django-thumborize-beta-2.0.0+build.1848',
'egg=django-thumborize-beta-2.0.1-alpha.1227',
'egg=django-thumborize-beta-1.0.0-alpha+beta',
'egg=django-thumborize-beta-1.2.3----RC-SNAPSHOT.12.9.1--.12+788', # noqa
'egg=django-thumborize-beta-1.2.3----R-S.12.9.1--.12+meta',
'egg=django-thumborize-beta-1.2.3----RC-SNAPSHOT.12.9.1--.12',
'egg=django-thumborize-beta-1.0.0+0.build.1-rc.10000aaa-kk-0.1', # noqa
'egg=django-thumborize-beta-999999999999999999.99999999999999.9999999999999', # noqa
'egg=Proj1',
'egg=Proj2-0.0.1',
'egg=Proj3',
'egg=Proj4-0.0.2',
'egg=Proj5',
'egg=Proj-0.0.3',
'egg=Proj',
'egg=Proj-0.0.4',
'egg=Proj',
'egg=foo-bar-1.2.4',
]
for index, test in enumerate(tests):
self.assertEqual(expected[index],
re.sub(r'egg=([^&]+).*$',
packaging.egg_fragment,
test))
def test_parse_repo_url_requirements(self):
result = packaging.parse_requirements([self.requirements])
self.assertEqual(['oslo.messaging>=1.0.0-rc',
'django-thumborize',
'django-thumborize-beta',
'django-thumborize2-beta',
'django-thumborize2-beta>=4.0.1',
'django-thumborize2-beta>=1.0.0-alpha.beta.1',
'django-thumborize2-beta>=1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'django-thumborize2-beta>=2.0.0-rc.1+build.123',
'Proj1', 'Proj2>=0.0.1', 'Proj3',
'Proj4>=0.0.2', 'Proj5', 'Proj>=0.0.3',
'Proj', 'Proj>=0.0.4', 'Proj',
'foo-bar>=1.2.4', 'pypi-proj1',
'pypi-proj2'], result)
def test_parse_repo_url_dependency_links(self):
result = packaging.parse_dependency_links([self.requirements])
self.assertEqual(
[
'git+git://git.pro-ject.org/oslo.messaging#egg=oslo.messaging-1.0.0-rc', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize-beta', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-4.0.1', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-1.0.0-alpha.beta.1', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-2.0.0-rc.1+build.123', # noqa
'git+git://git.project.org/Proj#egg=Proj1',
'git+https://git.project.org/Proj#egg=Proj2-0.0.1',
'git+ssh://git.project.org/Proj#egg=Proj3',
'svn+svn://svn.project.org/svn/Proj#egg=Proj4-0.0.2',
'svn+http://svn.project.org/svn/Proj/trunk@2019#egg=Proj5',
'hg+http://hg.project.org/Proj@da39a3ee5e6b#egg=Proj-0.0.3',
'hg+http://hg.project.org/Proj@2019#egg=Proj',
'hg+http://hg.project.org/Proj@v1.0#egg=Proj-0.0.4',
'hg+http://hg.project.org/Proj@special_feature#egg=Proj',
'git://foo.com/zipball#egg=foo-bar-1.2.4'], result)
def get_soabi():
soabi = None
try:

View file

@ -10,7 +10,10 @@
# License for the specific language governing permissions and limitations
# under the License.
import mock
try:
from unittest import mock
except ImportError:
import mock
from pbr import pbr_json
from pbr.tests import base

View file

@ -93,8 +93,9 @@ class SkipFileWrites(base.BaseTestCase):
option_dict=self.option_dict)
self.assertEqual(
not os.path.exists(self.filename),
(self.option_value.lower() in options.TRUE_VALUES
or self.env_value is not None))
(self.option_value.lower() in options.TRUE_VALUES or
self.env_value is not None))
_changelog_content = """7780758\x00Break parser\x00 (tag: refs/tags/1_foo.1)
04316fe\x00Make python\x00 (refs/heads/review/monty_taylor/27519)
@ -125,6 +126,7 @@ def _make_old_git_changelog_format(line):
refname = refname.replace('tag: ', '')
return '\x00'.join((sha, msg, refname))
_old_git_changelog_content = '\n'.join(
_make_old_git_changelog_format(line)
for line in _changelog_content.split('\n'))
@ -162,7 +164,7 @@ class GitLogsTest(base.BaseTestCase):
self.assertIn("------", changelog_contents)
self.assertIn("Refactor hooks file", changelog_contents)
self.assertIn(
"Bug fix: create\_stack() fails when waiting",
r"Bug fix: create\_stack() fails when waiting",
changelog_contents)
self.assertNotIn("Refactor hooks file.", changelog_contents)
self.assertNotIn("182feb3", changelog_contents)
@ -176,7 +178,7 @@ class GitLogsTest(base.BaseTestCase):
self.assertNotIn("ev)il", changelog_contents)
self.assertNotIn("e(vi)l", changelog_contents)
self.assertNotIn('Merge "', changelog_contents)
self.assertNotIn('1\_foo.1', changelog_contents)
self.assertNotIn(r'1\_foo.1', changelog_contents)
def test_generate_authors(self):
author_old = u"Foo Foo <email@foo.com>"
@ -216,9 +218,9 @@ class GitLogsTest(base.BaseTestCase):
with open(os.path.join(self.temp_path, "AUTHORS"), "r") as auth_fh:
authors = auth_fh.read()
self.assertTrue(author_old in authors)
self.assertTrue(author_new in authors)
self.assertTrue(co_author in authors)
self.assertIn(author_old, authors)
self.assertIn(author_new, authors)
self.assertIn(co_author, authors)
class _SphinxConfig(object):

View file

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2015 Hewlett-Packard Development Company, L.P. (HP)
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
@ -13,6 +14,7 @@
# under the License.
import io
import tempfile
import textwrap
import six
@ -23,6 +25,122 @@ from pbr.tests import base
from pbr import util
def config_from_ini(ini):
config = {}
ini = textwrap.dedent(six.u(ini))
if sys.version_info >= (3, 2):
parser = configparser.ConfigParser()
parser.read_file(io.StringIO(ini))
else:
parser = configparser.SafeConfigParser()
parser.readfp(io.StringIO(ini))
for section in parser.sections():
config[section] = dict(parser.items(section))
return config
class TestBasics(base.BaseTestCase):
def test_basics(self):
self.maxDiff = None
config_text = """
[metadata]
name = foo
version = 1.0
author = John Doe
author_email = jd@example.com
maintainer = Jim Burke
maintainer_email = jb@example.com
home_page = http://example.com
summary = A foobar project.
description = Hello, world. This is a long description.
download_url = http://opendev.org/x/pbr
classifier =
Development Status :: 5 - Production/Stable
Programming Language :: Python
platform =
any
license = Apache 2.0
requires_dist =
Sphinx
requests
setup_requires_dist =
docutils
python_requires = >=3.6
provides_dist =
bax
provides_extras =
bar
obsoletes_dist =
baz
[files]
packages_root = src
packages =
foo
package_data =
"" = *.txt, *.rst
foo = *.msg
namespace_packages =
hello
data_files =
bitmaps =
bm/b1.gif
bm/b2.gif
config =
cfg/data.cfg
scripts =
scripts/hello-world.py
modules =
mod1
"""
expected = {
'name': u'foo',
'version': u'1.0',
'author': u'John Doe',
'author_email': u'jd@example.com',
'maintainer': u'Jim Burke',
'maintainer_email': u'jb@example.com',
'url': u'http://example.com',
'description': u'A foobar project.',
'long_description': u'Hello, world. This is a long description.',
'download_url': u'http://opendev.org/x/pbr',
'classifiers': [
u'Development Status :: 5 - Production/Stable',
u'Programming Language :: Python',
],
'platforms': [u'any'],
'license': u'Apache 2.0',
'install_requires': [
u'Sphinx',
u'requests',
],
'setup_requires': [u'docutils'],
'python_requires': u'>=3.6',
'provides': [u'bax'],
'provides_extras': [u'bar'],
'obsoletes': [u'baz'],
'extras_require': {},
'package_dir': {'': u'src'},
'packages': [u'foo'],
'package_data': {
'': ['*.txt,', '*.rst'],
'foo': ['*.msg'],
},
'namespace_packages': [u'hello'],
'data_files': [
('bitmaps', ['bm/b1.gif', 'bm/b2.gif']),
('config', ['cfg/data.cfg']),
],
'scripts': [u'scripts/hello-world.py'],
'py_modules': [u'mod1'],
}
config = config_from_ini(config_text)
actual = util.setup_cfg_to_setup_kwargs(config)
self.assertDictEqual(expected, actual)
class TestExtrasRequireParsingScenarios(base.BaseTestCase):
scenarios = [
@ -64,20 +182,8 @@ class TestExtrasRequireParsingScenarios(base.BaseTestCase):
{}
})]
def config_from_ini(self, ini):
config = {}
if sys.version_info >= (3, 2):
parser = configparser.ConfigParser()
else:
parser = configparser.SafeConfigParser()
ini = textwrap.dedent(six.u(ini))
parser.readfp(io.StringIO(ini))
for section in parser.sections():
config[section] = dict(parser.items(section))
return config
def test_extras_parsing(self):
config = self.config_from_ini(self.config_text)
config = config_from_ini(self.config_text)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(self.expected_extra_requires,
@ -89,3 +195,127 @@ class TestInvalidMarkers(base.BaseTestCase):
def test_invalid_marker_raises_error(self):
config = {'extras': {'test': "foo :bad_marker>'1.0'"}}
self.assertRaises(SyntaxError, util.setup_cfg_to_setup_kwargs, config)
class TestMapFieldsParsingScenarios(base.BaseTestCase):
scenarios = [
('simple_project_urls', {
'config_text': """
[metadata]
project_urls =
Bug Tracker = https://bugs.launchpad.net/pbr/
Documentation = https://docs.openstack.org/pbr/
Source Code = https://opendev.org/openstack/pbr
""", # noqa: E501
'expected_project_urls': {
'Bug Tracker': 'https://bugs.launchpad.net/pbr/',
'Documentation': 'https://docs.openstack.org/pbr/',
'Source Code': 'https://opendev.org/openstack/pbr',
},
}),
('query_parameters', {
'config_text': """
[metadata]
project_urls =
Bug Tracker = https://bugs.launchpad.net/pbr/?query=true
Documentation = https://docs.openstack.org/pbr/?foo=bar
Source Code = https://git.openstack.org/cgit/openstack-dev/pbr/commit/?id=hash
""", # noqa: E501
'expected_project_urls': {
'Bug Tracker': 'https://bugs.launchpad.net/pbr/?query=true',
'Documentation': 'https://docs.openstack.org/pbr/?foo=bar',
'Source Code': 'https://git.openstack.org/cgit/openstack-dev/pbr/commit/?id=hash', # noqa: E501
},
}),
]
def test_project_url_parsing(self):
config = config_from_ini(self.config_text)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(self.expected_project_urls, kwargs['project_urls'])
class TestKeywordsParsingScenarios(base.BaseTestCase):
scenarios = [
('keywords_list', {
'config_text': """
[metadata]
keywords =
one
two
three
""", # noqa: E501
'expected_keywords': ['one', 'two', 'three'],
},
),
('inline_keywords', {
'config_text': """
[metadata]
keywords = one, two, three
""", # noqa: E501
'expected_keywords': ['one, two, three'],
}),
]
def test_keywords_parsing(self):
config = config_from_ini(self.config_text)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(self.expected_keywords, kwargs['keywords'])
class TestProvidesExtras(base.BaseTestCase):
def test_provides_extras(self):
ini = """
[metadata]
provides_extras = foo
bar
"""
config = config_from_ini(ini)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(['foo', 'bar'], kwargs['provides_extras'])
class TestDataFilesParsing(base.BaseTestCase):
scenarios = [
('data_files', {
'config_text': """
[files]
data_files =
'i like spaces/' =
'dir with space/file with spc 2'
'dir with space/file with spc 1'
""",
'data_files': [
('i like spaces/', ['dir with space/file with spc 2',
'dir with space/file with spc 1'])
]
})]
def test_handling_of_whitespace_in_data_files(self):
config = config_from_ini(self.config_text)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(self.data_files, kwargs['data_files'])
class TestUTF8DescriptionFile(base.BaseTestCase):
def test_utf8_description_file(self):
_, path = tempfile.mkstemp()
ini_template = """
[metadata]
description_file = %s
"""
# Two \n's because pbr strips the file content and adds \n\n
# This way we can use it directly as the assert comparison
unicode_description = u'UTF8 description: é"…-ʃŋ\'\n\n'
ini = ini_template % path
with io.open(path, 'w', encoding='utf8') as f:
f.write(unicode_description)
config = config_from_ini(ini)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(unicode_description, kwargs['long_description'])

View file

@ -77,8 +77,8 @@ class TestWsgiScripts(base.BaseTestCase):
def _test_wsgi(self, cmd_name, output, extra_args=None):
cmd = os.path.join(self.temp_dir, 'bin', cmd_name)
print("Running %s -p 0" % cmd)
popen_cmd = [cmd, '-p', '0']
print("Running %s -p 0 -b 127.0.0.1" % cmd)
popen_cmd = [cmd, '-p', '0', '-b', '127.0.0.1']
if extra_args:
popen_cmd.extend(extra_args)
@ -98,7 +98,7 @@ class TestWsgiScripts(base.BaseTestCase):
stdoutdata = p.stdout.readline() # Available at ...
print(stdoutdata)
m = re.search(b'(http://[^:]+:\d+)/', stdoutdata)
m = re.search(br'(http://[^:]+:\d+)/', stdoutdata)
self.assertIsNotNone(m, "Regex failed to match on %s" % stdoutdata)
stdoutdata = p.stdout.readline() # DANGER! ...

View file

@ -12,17 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
sys.path.insert(0, os.path.abspath('../..'))
# -- General configuration ----------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [
'sphinx.ext.autodoc',
#'sphinx.ext.intersphinx',
]
# autodoc generation is a bit aggressive and a nuisance when doing heavy
@ -49,17 +45,9 @@ add_module_names = True
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# -- Options for HTML output --------------------------------------------------
# The theme to use for HTML and HTML Help pages. Major themes that come with
# Sphinx are currently 'default' and 'sphinxdoc'.
# html_theme_path = ["."]
# html_theme = '_theme'
# html_static_path = ['static']
# Output file base name for HTML help builder.
htmlhelp_basename = '%sdoc' % project
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass
# [howto/manual]).
@ -69,6 +57,3 @@ latex_documents = [
u'%s Documentation' % project,
u'OpenStack Foundation', 'manual'),
]
# Example configuration for intersphinx: refer to the Python standard library.
#intersphinx_mapping = {'http://docs.python.org/': None}

View file

@ -53,9 +53,9 @@ except ImportError:
@contextlib.contextmanager
def open_config(filename):
if sys.version_info >= (3, 2):
cfg = configparser.ConfigParser()
cfg = configparser.ConfigParser()
else:
cfg = configparser.SafeConfigParser()
cfg = configparser.SafeConfigParser()
cfg.read(filename)
yield cfg
with open(filename, 'w') as fp:

View file

@ -62,8 +62,10 @@ except ImportError:
import logging # noqa
from collections import defaultdict
import io
import os
import re
import shlex
import sys
import traceback
@ -86,50 +88,52 @@ import pbr.hooks
# predicates in ()
_VERSION_SPEC_RE = re.compile(r'\s*(.*?)\s*\((.*)\)\s*$')
# Mappings from setup() keyword arguments to setup.cfg options;
# The values are (section, option) tuples, or simply (section,) tuples if
# the option has the same name as the setup() argument
D1_D2_SETUP_ARGS = {
"name": ("metadata",),
"version": ("metadata",),
"author": ("metadata",),
"author_email": ("metadata",),
"maintainer": ("metadata",),
"maintainer_email": ("metadata",),
"url": ("metadata", "home_page"),
"project_urls": ("metadata",),
"description": ("metadata", "summary"),
"keywords": ("metadata",),
"long_description": ("metadata", "description"),
"long_description_content_type": ("metadata", "description_content_type"),
"download_url": ("metadata",),
"classifiers": ("metadata", "classifier"),
"platforms": ("metadata", "platform"), # **
"license": ("metadata",),
# Mappings from setup.cfg options, in (section, option) form, to setup()
# keyword arguments
CFG_TO_PY_SETUP_ARGS = (
(('metadata', 'name'), 'name'),
(('metadata', 'version'), 'version'),
(('metadata', 'author'), 'author'),
(('metadata', 'author_email'), 'author_email'),
(('metadata', 'maintainer'), 'maintainer'),
(('metadata', 'maintainer_email'), 'maintainer_email'),
(('metadata', 'home_page'), 'url'),
(('metadata', 'project_urls'), 'project_urls'),
(('metadata', 'summary'), 'description'),
(('metadata', 'keywords'), 'keywords'),
(('metadata', 'description'), 'long_description'),
(
('metadata', 'description_content_type'),
'long_description_content_type',
),
(('metadata', 'download_url'), 'download_url'),
(('metadata', 'classifier'), 'classifiers'),
(('metadata', 'platform'), 'platforms'), # **
(('metadata', 'license'), 'license'),
# Use setuptools install_requires, not
# broken distutils requires
"install_requires": ("metadata", "requires_dist"),
"setup_requires": ("metadata", "setup_requires_dist"),
"python_requires": ("metadata",),
"provides": ("metadata", "provides_dist"), # **
"obsoletes": ("metadata", "obsoletes_dist"), # **
"package_dir": ("files", 'packages_root'),
"packages": ("files",),
"package_data": ("files",),
"namespace_packages": ("files",),
"data_files": ("files",),
"scripts": ("files",),
"py_modules": ("files", "modules"), # **
"cmdclass": ("global", "commands"),
(('metadata', 'requires_dist'), 'install_requires'),
(('metadata', 'setup_requires_dist'), 'setup_requires'),
(('metadata', 'python_requires'), 'python_requires'),
(('metadata', 'requires_python'), 'python_requires'),
(('metadata', 'provides_dist'), 'provides'), # **
(('metadata', 'provides_extras'), 'provides_extras'),
(('metadata', 'obsoletes_dist'), 'obsoletes'), # **
(('files', 'packages_root'), 'package_dir'),
(('files', 'packages'), 'packages'),
(('files', 'package_data'), 'package_data'),
(('files', 'namespace_packages'), 'namespace_packages'),
(('files', 'data_files'), 'data_files'),
(('files', 'scripts'), 'scripts'),
(('files', 'modules'), 'py_modules'), # **
(('global', 'commands'), 'cmdclass'),
# Not supported in distutils2, but provided for
# backwards compatibility with setuptools
"use_2to3": ("backwards_compat", "use_2to3"),
"zip_safe": ("backwards_compat", "zip_safe"),
"tests_require": ("backwards_compat", "tests_require"),
"dependency_links": ("backwards_compat",),
"include_package_data": ("backwards_compat",),
}
(('backwards_compat', 'zip_safe'), 'zip_safe'),
(('backwards_compat', 'tests_require'), 'tests_require'),
(('backwards_compat', 'dependency_links'), 'dependency_links'),
(('backwards_compat', 'include_package_data'), 'include_package_data'),
)
# setup() arguments that can have multiple values in setup.cfg
MULTI_FIELDS = ("classifiers",
@ -146,16 +150,27 @@ MULTI_FIELDS = ("classifiers",
"dependency_links",
"setup_requires",
"tests_require",
"cmdclass")
"keywords",
"cmdclass",
"provides_extras")
# setup() arguments that can have mapping values in setup.cfg
MAP_FIELDS = ("project_urls",)
# setup() arguments that contain boolean values
BOOL_FIELDS = ("use_2to3", "zip_safe", "include_package_data")
BOOL_FIELDS = ("zip_safe", "include_package_data")
CSV_FIELDS = ()
CSV_FIELDS = ("keywords",)
def shlex_split(path):
if os.name == 'nt':
# shlex cannot handle paths that contain backslashes, treating those
# as escape characters.
path = path.replace("\\", "/")
return [x.replace("/", "\\") for x in shlex.split(path)]
return shlex.split(path)
def resolve_name(name):
@ -205,10 +220,11 @@ def cfg_to_args(path='setup.cfg', script_args=()):
"""
# The method source code really starts here.
if sys.version_info >= (3, 2):
parser = configparser.ConfigParser()
if sys.version_info >= (3, 0):
parser = configparser.ConfigParser()
else:
parser = configparser.SafeConfigParser()
parser = configparser.SafeConfigParser()
if not os.path.exists(path):
raise errors.DistutilsFileError("file '%s' does not exist" %
os.path.abspath(path))
@ -297,34 +313,25 @@ def setup_cfg_to_setup_kwargs(config, script_args=()):
# parse env_markers.
all_requirements = {}
for arg in D1_D2_SETUP_ARGS:
if len(D1_D2_SETUP_ARGS[arg]) == 2:
# The distutils field name is different than distutils2's.
section, option = D1_D2_SETUP_ARGS[arg]
elif len(D1_D2_SETUP_ARGS[arg]) == 1:
# The distutils field name is the same thant distutils2's.
section = D1_D2_SETUP_ARGS[arg][0]
option = arg
for alias, arg in CFG_TO_PY_SETUP_ARGS:
section, option = alias
in_cfg_value = has_get_option(config, section, option)
if not in_cfg_value and arg == "long_description":
in_cfg_value = has_get_option(config, section, "description_file")
if in_cfg_value:
in_cfg_value = split_multiline(in_cfg_value)
value = ''
for filename in in_cfg_value:
description_file = io.open(filename, encoding='utf-8')
try:
value += description_file.read().strip() + '\n\n'
finally:
description_file.close()
in_cfg_value = value
if not in_cfg_value:
# There is no such option in the setup.cfg
if arg == "long_description":
in_cfg_value = has_get_option(config, section,
"description_file")
if in_cfg_value:
in_cfg_value = split_multiline(in_cfg_value)
value = ''
for filename in in_cfg_value:
description_file = open(filename)
try:
value += description_file.read().strip() + '\n\n'
finally:
description_file.close()
in_cfg_value = value
else:
continue
continue
if arg in CSV_FIELDS:
in_cfg_value = split_csv(in_cfg_value)
@ -333,7 +340,7 @@ def setup_cfg_to_setup_kwargs(config, script_args=()):
elif arg in MAP_FIELDS:
in_cfg_map = {}
for i in split_multiline(in_cfg_value):
k, v = i.split('=')
k, v = i.split('=', 1)
in_cfg_map[k.strip()] = v.strip()
in_cfg_value = in_cfg_map
elif arg in BOOL_FIELDS:
@ -370,26 +377,27 @@ def setup_cfg_to_setup_kwargs(config, script_args=()):
for line in in_cfg_value:
if '=' in line:
key, value = line.split('=', 1)
key, value = (key.strip(), value.strip())
key_unquoted = shlex_split(key.strip())[0]
key, value = (key_unquoted, value.strip())
if key in data_files:
# Multiple duplicates of the same package name;
# this is for backwards compatibility of the old
# format prior to d2to1 0.2.6.
prev = data_files[key]
prev.extend(value.split())
prev.extend(shlex_split(value))
else:
prev = data_files[key.strip()] = value.split()
prev = data_files[key.strip()] = shlex_split(value)
elif firstline:
raise errors.DistutilsOptionError(
'malformed package_data first line %r (misses '
'"=")' % line)
else:
prev.extend(line.strip().split())
prev.extend(shlex_split(line.strip()))
firstline = False
if arg == 'data_files':
# the data_files value is a pointlessly different structure
# from the package_data value
data_files = data_files.items()
data_files = sorted(data_files.items())
in_cfg_value = data_files
elif arg == 'cmdclass':
cmdclass = {}
@ -532,7 +540,7 @@ def get_extension_modules(config):
else:
# Backwards compatibility for old syntax; don't use this though
labels = section.split('=', 1)
labels = [l.strip() for l in labels]
labels = [label.strip() for label in labels]
if (len(labels) == 2) and (labels[0] == 'extension'):
ext_args = {}
for field in EXTENSION_FIELDS:

View file

@ -15,13 +15,24 @@
# under the License.
"""
Utilities for consuming the version from pkg_resources.
Utilities for consuming the version from importlib-metadata.
"""
import itertools
import operator
import sys
# TODO(stephenfin): Remove this once we drop support for Python < 3.8
if sys.version_info >= (3, 8):
from importlib import metadata as importlib_metadata
use_importlib = True
else:
try:
import importlib_metadata
use_importlib = True
except ImportError:
use_importlib = False
def _is_int(string):
try:
@ -323,8 +334,8 @@ class SemanticVersion(object):
version number of the component to preserve sorting. (Used for
rpm support)
"""
if ((self._prerelease_type or self._dev_count)
and pre_separator is None):
if ((self._prerelease_type or self._dev_count) and
pre_separator is None):
segments = [self.decrement().brief_string()]
pre_separator = "."
else:
@ -431,12 +442,15 @@ class VersionInfo(object):
"""Obtain a version from pkg_resources or setup-time logic if missing.
This will try to get the version of the package from the pkg_resources
This will try to get the version of the package from the
record associated with the package, and if there is no such record
importlib_metadata record associated with the package, and if there
falls back to the logic sdist would use.
is no such record falls back to the logic sdist would use.
"""
# Lazy import because pkg_resources is costly to import so defer until
# we absolutely need it.
import pkg_resources
try:
requirement = pkg_resources.Requirement.parse(self.package)
provider = pkg_resources.get_provider(requirement)
@ -447,6 +461,25 @@ class VersionInfo(object):
# installed into anything. Revert to setup-time logic.
from pbr import packaging
result_string = packaging.get_version(self.package)
return SemanticVersion.from_pip_string(result_string)
def _get_version_from_importlib_metadata(self):
"""Obtain a version from importlib or setup-time logic if missing.
This will try to get the version of the package from the
importlib_metadata record associated with the package, and if there
is no such record falls back to the logic sdist would use.
"""
try:
distribution = importlib_metadata.distribution(self.package)
result_string = distribution.version
except importlib_metadata.PackageNotFoundError:
# The most likely cause for this is running tests in a tree
# produced from a tarball where the package itself has not been
# installed into anything. Revert to setup-time logic.
from pbr import packaging
result_string = packaging.get_version(self.package)
return SemanticVersion.from_pip_string(result_string)
def release_string(self):
@ -459,7 +492,12 @@ class VersionInfo(object):
def semantic_version(self):
"""Return the SemanticVersion object for this version."""
if self._semantic is None:
self._semantic = self._get_version_from_pkg_resources()
# TODO(damami): simplify this once Python 3.8 is the oldest
# we support
if use_importlib:
self._semantic = self._get_version_from_importlib_metadata()
else:
self._semantic = self._get_version_from_pkg_resources()
return self._semantic
def version_string(self):

View file

@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# pylint: disable-all
from __future__ import print_function
import os
import re
@ -132,9 +133,14 @@ class SubRipShifter(object):
def run(self, args):
self.arguments = self.build_parser().parse_args(args)
if self.arguments.in_place:
self.create_backup()
self.arguments.action()
if os.path.isfile(self.arguments.file):
if self.arguments.in_place:
self.create_backup()
self.arguments.action()
else:
print('No such file', self.arguments.file)
def parse_time(self, time_string):
negative = time_string.startswith('-')

View file

@ -290,7 +290,7 @@ class SubRipFile(UserList, object):
@classmethod
def _open_unicode_file(cls, path, claimed_encoding=None):
encoding = claimed_encoding or cls._detect_encoding(path)
source_file = codecs.open(path, 'rU', encoding=encoding)
source_file = codecs.open(path, 'r', encoding=encoding)
# get rid of BOM if any
possible_bom = CODECS_BOMS.get(encoding, None)

View file

@ -16,14 +16,14 @@ from pytz.exceptions import AmbiguousTimeError
from pytz.exceptions import InvalidTimeError
from pytz.exceptions import NonExistentTimeError
from pytz.exceptions import UnknownTimeZoneError
from pytz.lazy import LazyDict, LazyList, LazySet
from pytz.lazy import LazyDict, LazyList, LazySet # noqa
from pytz.tzinfo import unpickler, BaseTzInfo
from pytz.tzfile import build_tzinfo
# The IANA (nee Olson) database is updated several times a year.
OLSON_VERSION = '2018g'
VERSION = '2018.7' # pip compatible version number.
OLSON_VERSION = '2022f'
VERSION = '2022.6' # pip compatible version number.
__version__ = VERSION
OLSEN_VERSION = OLSON_VERSION # Old releases had this misspelling
@ -34,7 +34,7 @@ __all__ = [
'NonExistentTimeError', 'UnknownTimeZoneError',
'all_timezones', 'all_timezones_set',
'common_timezones', 'common_timezones_set',
'BaseTzInfo',
'BaseTzInfo', 'FixedOffset',
]
@ -86,7 +86,7 @@ def open_resource(name):
"""
name_parts = name.lstrip('/').split('/')
for part in name_parts:
if part == os.path.pardir or os.path.sep in part:
if part == os.path.pardir or os.sep in part:
raise ValueError('Bad path segment: %r' % part)
zoneinfo_dir = os.environ.get('PYTZ_TZDATADIR', None)
if zoneinfo_dir is not None:
@ -111,6 +111,13 @@ def open_resource(name):
def resource_exists(name):
"""Return true if the given resource exists"""
try:
if os.environ.get('PYTZ_SKIPEXISTSCHECK', ''):
# In "standard" distributions, we can assume that
# all the listed timezones are present. As an
# import-speed optimization, you can set the
# PYTZ_SKIPEXISTSCHECK flag to skip checking
# for the presence of the resource file on disk.
return True
open_resource(name).close()
return True
except IOError:
@ -157,6 +164,9 @@ def timezone(zone):
Unknown
'''
if zone is None:
raise UnknownTimeZoneError(None)
if zone.upper() == 'UTC':
return utc
@ -166,9 +176,9 @@ def timezone(zone):
# All valid timezones are ASCII
raise UnknownTimeZoneError(zone)
zone = _unmunge_zone(zone)
zone = _case_insensitive_zone_lookup(_unmunge_zone(zone))
if zone not in _tzinfo_cache:
if zone in all_timezones_set:
if zone in all_timezones_set: # noqa
fp = open_resource(zone)
try:
_tzinfo_cache[zone] = build_tzinfo(zone, fp)
@ -185,6 +195,17 @@ def _unmunge_zone(zone):
return zone.replace('_plus_', '+').replace('_minus_', '-')
_all_timezones_lower_to_standard = None
def _case_insensitive_zone_lookup(zone):
"""case-insensitively matching timezone, else return zone unchanged"""
global _all_timezones_lower_to_standard
if _all_timezones_lower_to_standard is None:
_all_timezones_lower_to_standard = dict((tz.lower(), tz) for tz in all_timezones) # noqa
return _all_timezones_lower_to_standard.get(zone.lower()) or zone # noqa
ZERO = datetime.timedelta(0)
HOUR = datetime.timedelta(hours=1)
@ -249,8 +270,8 @@ def _UTC():
module global.
These examples belong in the UTC class above, but it is obscured; or in
the README.txt, but we are not depending on Python 2.4 so integrating
the README.txt examples with the unit tests is not trivial.
the README.rst, but we are not depending on Python 2.4 so integrating
the README.rst examples with the unit tests is not trivial.
>>> import datetime, pickle
>>> dt = datetime.datetime(2005, 3, 1, 14, 13, 21, tzinfo=utc)
@ -272,6 +293,8 @@ def _UTC():
False
"""
return utc
_UTC.__safe_for_unpickling__ = True
@ -282,6 +305,8 @@ def _p(*args):
by shortening the path.
"""
return unpickler(*args)
_p.__safe_for_unpickling__ = True
@ -330,7 +355,7 @@ class _CountryTimezoneDict(LazyDict):
if line.startswith('#'):
continue
code, coordinates, zone = line.split(None, 4)[:3]
if zone not in all_timezones_set:
if zone not in all_timezones_set: # noqa
continue
try:
data[code].append(zone)
@ -340,6 +365,7 @@ class _CountryTimezoneDict(LazyDict):
finally:
zone_tab.close()
country_timezones = _CountryTimezoneDict()
@ -363,6 +389,7 @@ class _CountryNameDict(LazyDict):
finally:
zone_tab.close()
country_names = _CountryNameDict()
@ -474,6 +501,7 @@ def FixedOffset(offset, _tzinfos={}):
return info
FixedOffset.__safe_for_unpickling__ = True
@ -483,6 +511,7 @@ def _test():
import pytz
return doctest.testmod(pytz)
if __name__ == '__main__':
_test()
all_timezones = \
@ -661,6 +690,7 @@ all_timezones = \
'America/North_Dakota/Beulah',
'America/North_Dakota/Center',
'America/North_Dakota/New_Salem',
'America/Nuuk',
'America/Ojinaga',
'America/Panama',
'America/Pangnirtung',
@ -787,6 +817,7 @@ all_timezones = \
'Asia/Pontianak',
'Asia/Pyongyang',
'Asia/Qatar',
'Asia/Qostanay',
'Asia/Qyzylorda',
'Asia/Rangoon',
'Asia/Riyadh',
@ -933,6 +964,7 @@ all_timezones = \
'Europe/Kaliningrad',
'Europe/Kiev',
'Europe/Kirov',
'Europe/Kyiv',
'Europe/Lisbon',
'Europe/Ljubljana',
'Europe/London',
@ -1027,6 +1059,7 @@ all_timezones = \
'Pacific/Guam',
'Pacific/Honolulu',
'Pacific/Johnston',
'Pacific/Kanton',
'Pacific/Kiritimati',
'Pacific/Kosrae',
'Pacific/Kwajalein',
@ -1187,7 +1220,6 @@ common_timezones = \
'America/Fort_Nelson',
'America/Fortaleza',
'America/Glace_Bay',
'America/Godthab',
'America/Goose_Bay',
'America/Grand_Turk',
'America/Grenada',
@ -1235,12 +1267,12 @@ common_timezones = \
'America/Montserrat',
'America/Nassau',
'America/New_York',
'America/Nipigon',
'America/Nome',
'America/Noronha',
'America/North_Dakota/Beulah',
'America/North_Dakota/Center',
'America/North_Dakota/New_Salem',
'America/Nuuk',
'America/Ojinaga',
'America/Panama',
'America/Pangnirtung',
@ -1251,7 +1283,6 @@ common_timezones = \
'America/Porto_Velho',
'America/Puerto_Rico',
'America/Punta_Arenas',
'America/Rainy_River',
'America/Rankin_Inlet',
'America/Recife',
'America/Regina',
@ -1272,7 +1303,6 @@ common_timezones = \
'America/Swift_Current',
'America/Tegucigalpa',
'America/Thule',
'America/Thunder_Bay',
'America/Tijuana',
'America/Toronto',
'America/Tortola',
@ -1351,6 +1381,7 @@ common_timezones = \
'Asia/Pontianak',
'Asia/Pyongyang',
'Asia/Qatar',
'Asia/Qostanay',
'Asia/Qyzylorda',
'Asia/Riyadh',
'Asia/Sakhalin',
@ -1388,7 +1419,6 @@ common_timezones = \
'Australia/Adelaide',
'Australia/Brisbane',
'Australia/Broken_Hill',
'Australia/Currie',
'Australia/Darwin',
'Australia/Eucla',
'Australia/Hobart',
@ -1424,8 +1454,8 @@ common_timezones = \
'Europe/Istanbul',
'Europe/Jersey',
'Europe/Kaliningrad',
'Europe/Kiev',
'Europe/Kirov',
'Europe/Kyiv',
'Europe/Lisbon',
'Europe/Ljubljana',
'Europe/London',
@ -1453,7 +1483,6 @@ common_timezones = \
'Europe/Tallinn',
'Europe/Tirane',
'Europe/Ulyanovsk',
'Europe/Uzhgorod',
'Europe/Vaduz',
'Europe/Vatican',
'Europe/Vienna',
@ -1461,7 +1490,6 @@ common_timezones = \
'Europe/Volgograd',
'Europe/Warsaw',
'Europe/Zagreb',
'Europe/Zaporozhye',
'Europe/Zurich',
'GMT',
'Indian/Antananarivo',
@ -1482,7 +1510,6 @@ common_timezones = \
'Pacific/Chuuk',
'Pacific/Easter',
'Pacific/Efate',
'Pacific/Enderbury',
'Pacific/Fakaofo',
'Pacific/Fiji',
'Pacific/Funafuti',
@ -1491,6 +1518,7 @@ common_timezones = \
'Pacific/Guadalcanal',
'Pacific/Guam',
'Pacific/Honolulu',
'Pacific/Kanton',
'Pacific/Kiritimati',
'Pacific/Kosrae',
'Pacific/Kwajalein',

View file

@ -8,7 +8,11 @@ __all__ = [
]
class UnknownTimeZoneError(KeyError):
class Error(Exception):
'''Base class for all exceptions raised by the pytz library'''
class UnknownTimeZoneError(KeyError, Error):
'''Exception raised when pytz is passed an unknown timezone.
>>> isinstance(UnknownTimeZoneError(), LookupError)
@ -20,11 +24,18 @@ class UnknownTimeZoneError(KeyError):
>>> isinstance(UnknownTimeZoneError(), KeyError)
True
And also a subclass of pytz.exceptions.Error, as are other pytz
exceptions.
>>> isinstance(UnknownTimeZoneError(), Error)
True
'''
pass
class InvalidTimeError(Exception):
class InvalidTimeError(Error):
'''Base class for invalid time exceptions.'''

View file

@ -1,4 +1,3 @@
#!/usr/bin/env python
'''
$Id: tzfile.py,v 1.8 2004/06/03 00:15:24 zenzen Exp $
'''

Some files were not shown because too many files have changed in this diff Show more