Updates vendored subliminal to 2.1.0

Updates rarfile to 3.1
Updates stevedore to 3.5.0
Updates appdirs to 1.4.4
Updates click to 8.1.3
Updates decorator to 5.1.1
Updates dogpile.cache to 1.1.8
Updates pbr to 5.11.0
Updates pysrt to 1.1.2
Updates pytz to 2022.6
Adds importlib-metadata version 3.1.1
Adds typing-extensions version 4.1.1
Adds zipp version 3.11.0
This commit is contained in:
Labrys of Knossos 2022-11-29 00:08:39 -05:00
commit f05b09f349
694 changed files with 16621 additions and 11056 deletions

Binary file not shown.

View file

@ -13,8 +13,8 @@ See <http://github.com/ActiveState/appdirs> for details and usage.
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html # - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html # - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
__version_info__ = (1, 4, 3) __version__ = "1.4.4"
__version__ = '.'.join(map(str, __version_info__)) __version_info__ = tuple(int(segment) for segment in __version__.split("."))
import sys import sys

BIN
libs/common/bin/beet.exe Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
libs/common/bin/guessit.exe Normal file

Binary file not shown.

BIN
libs/common/bin/mid3cp.exe Normal file

Binary file not shown.

Binary file not shown.

BIN
libs/common/bin/mid3v2.exe Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

BIN
libs/common/bin/pbr.exe Normal file

Binary file not shown.

BIN
libs/common/bin/srt.exe Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -1,97 +1,73 @@
# -*- coding: utf-8 -*-
""" """
click
~~~~~
Click is a simple Python module inspired by the stdlib optparse to make Click is a simple Python module inspired by the stdlib optparse to make
writing command line scripts fun. Unlike other modules, it's based writing command line scripts fun. Unlike other modules, it's based
around a simple API that does not come with too much magic and is around a simple API that does not come with too much magic and is
composable. composable.
:copyright: © 2014 by the Pallets team.
:license: BSD, see LICENSE.rst for more details.
""" """
from .core import Argument as Argument
from .core import BaseCommand as BaseCommand
from .core import Command as Command
from .core import CommandCollection as CommandCollection
from .core import Context as Context
from .core import Group as Group
from .core import MultiCommand as MultiCommand
from .core import Option as Option
from .core import Parameter as Parameter
from .decorators import argument as argument
from .decorators import command as command
from .decorators import confirmation_option as confirmation_option
from .decorators import group as group
from .decorators import help_option as help_option
from .decorators import make_pass_decorator as make_pass_decorator
from .decorators import option as option
from .decorators import pass_context as pass_context
from .decorators import pass_obj as pass_obj
from .decorators import password_option as password_option
from .decorators import version_option as version_option
from .exceptions import Abort as Abort
from .exceptions import BadArgumentUsage as BadArgumentUsage
from .exceptions import BadOptionUsage as BadOptionUsage
from .exceptions import BadParameter as BadParameter
from .exceptions import ClickException as ClickException
from .exceptions import FileError as FileError
from .exceptions import MissingParameter as MissingParameter
from .exceptions import NoSuchOption as NoSuchOption
from .exceptions import UsageError as UsageError
from .formatting import HelpFormatter as HelpFormatter
from .formatting import wrap_text as wrap_text
from .globals import get_current_context as get_current_context
from .parser import OptionParser as OptionParser
from .termui import clear as clear
from .termui import confirm as confirm
from .termui import echo_via_pager as echo_via_pager
from .termui import edit as edit
from .termui import getchar as getchar
from .termui import launch as launch
from .termui import pause as pause
from .termui import progressbar as progressbar
from .termui import prompt as prompt
from .termui import secho as secho
from .termui import style as style
from .termui import unstyle as unstyle
from .types import BOOL as BOOL
from .types import Choice as Choice
from .types import DateTime as DateTime
from .types import File as File
from .types import FLOAT as FLOAT
from .types import FloatRange as FloatRange
from .types import INT as INT
from .types import IntRange as IntRange
from .types import ParamType as ParamType
from .types import Path as Path
from .types import STRING as STRING
from .types import Tuple as Tuple
from .types import UNPROCESSED as UNPROCESSED
from .types import UUID as UUID
from .utils import echo as echo
from .utils import format_filename as format_filename
from .utils import get_app_dir as get_app_dir
from .utils import get_binary_stream as get_binary_stream
from .utils import get_text_stream as get_text_stream
from .utils import open_file as open_file
# Core classes __version__ = "8.1.3"
from .core import Context, BaseCommand, Command, MultiCommand, Group, \
CommandCollection, Parameter, Option, Argument
# Globals
from .globals import get_current_context
# Decorators
from .decorators import pass_context, pass_obj, make_pass_decorator, \
command, group, argument, option, confirmation_option, \
password_option, version_option, help_option
# Types
from .types import ParamType, File, Path, Choice, IntRange, Tuple, \
DateTime, STRING, INT, FLOAT, BOOL, UUID, UNPROCESSED, FloatRange
# Utilities
from .utils import echo, get_binary_stream, get_text_stream, open_file, \
format_filename, get_app_dir, get_os_args
# Terminal functions
from .termui import prompt, confirm, get_terminal_size, echo_via_pager, \
progressbar, clear, style, unstyle, secho, edit, launch, getchar, \
pause
# Exceptions
from .exceptions import ClickException, UsageError, BadParameter, \
FileError, Abort, NoSuchOption, BadOptionUsage, BadArgumentUsage, \
MissingParameter
# Formatting
from .formatting import HelpFormatter, wrap_text
# Parsing
from .parser import OptionParser
__all__ = [
# Core classes
'Context', 'BaseCommand', 'Command', 'MultiCommand', 'Group',
'CommandCollection', 'Parameter', 'Option', 'Argument',
# Globals
'get_current_context',
# Decorators
'pass_context', 'pass_obj', 'make_pass_decorator', 'command', 'group',
'argument', 'option', 'confirmation_option', 'password_option',
'version_option', 'help_option',
# Types
'ParamType', 'File', 'Path', 'Choice', 'IntRange', 'Tuple',
'DateTime', 'STRING', 'INT', 'FLOAT', 'BOOL', 'UUID', 'UNPROCESSED',
'FloatRange',
# Utilities
'echo', 'get_binary_stream', 'get_text_stream', 'open_file',
'format_filename', 'get_app_dir', 'get_os_args',
# Terminal functions
'prompt', 'confirm', 'get_terminal_size', 'echo_via_pager',
'progressbar', 'clear', 'style', 'unstyle', 'secho', 'edit', 'launch',
'getchar', 'pause',
# Exceptions
'ClickException', 'UsageError', 'BadParameter', 'FileError',
'Abort', 'NoSuchOption', 'BadOptionUsage', 'BadArgumentUsage',
'MissingParameter',
# Formatting
'HelpFormatter', 'wrap_text',
# Parsing
'OptionParser',
]
# Controls if click should emit the warning about the use of unicode
# literals.
disable_unicode_literals_warning = False
__version__ = '7.0'

View file

@ -1,293 +0,0 @@
import copy
import os
import re
from .utils import echo
from .parser import split_arg_string
from .core import MultiCommand, Option, Argument
from .types import Choice
try:
from collections import abc
except ImportError:
import collections as abc
WORDBREAK = '='
# Note, only BASH version 4.4 and later have the nosort option.
COMPLETION_SCRIPT_BASH = '''
%(complete_func)s() {
local IFS=$'\n'
COMPREPLY=( $( env COMP_WORDS="${COMP_WORDS[*]}" \\
COMP_CWORD=$COMP_CWORD \\
%(autocomplete_var)s=complete $1 ) )
return 0
}
%(complete_func)setup() {
local COMPLETION_OPTIONS=""
local BASH_VERSION_ARR=(${BASH_VERSION//./ })
# Only BASH version 4.4 and later have the nosort option.
if [ ${BASH_VERSION_ARR[0]} -gt 4 ] || ([ ${BASH_VERSION_ARR[0]} -eq 4 ] && [ ${BASH_VERSION_ARR[1]} -ge 4 ]); then
COMPLETION_OPTIONS="-o nosort"
fi
complete $COMPLETION_OPTIONS -F %(complete_func)s %(script_names)s
}
%(complete_func)setup
'''
COMPLETION_SCRIPT_ZSH = '''
%(complete_func)s() {
local -a completions
local -a completions_with_descriptions
local -a response
response=("${(@f)$( env COMP_WORDS=\"${words[*]}\" \\
COMP_CWORD=$((CURRENT-1)) \\
%(autocomplete_var)s=\"complete_zsh\" \\
%(script_names)s )}")
for key descr in ${(kv)response}; do
if [[ "$descr" == "_" ]]; then
completions+=("$key")
else
completions_with_descriptions+=("$key":"$descr")
fi
done
if [ -n "$completions_with_descriptions" ]; then
_describe -V unsorted completions_with_descriptions -U -Q
fi
if [ -n "$completions" ]; then
compadd -U -V unsorted -Q -a completions
fi
compstate[insert]="automenu"
}
compdef %(complete_func)s %(script_names)s
'''
_invalid_ident_char_re = re.compile(r'[^a-zA-Z0-9_]')
def get_completion_script(prog_name, complete_var, shell):
cf_name = _invalid_ident_char_re.sub('', prog_name.replace('-', '_'))
script = COMPLETION_SCRIPT_ZSH if shell == 'zsh' else COMPLETION_SCRIPT_BASH
return (script % {
'complete_func': '_%s_completion' % cf_name,
'script_names': prog_name,
'autocomplete_var': complete_var,
}).strip() + ';'
def resolve_ctx(cli, prog_name, args):
"""
Parse into a hierarchy of contexts. Contexts are connected through the parent variable.
:param cli: command definition
:param prog_name: the program that is running
:param args: full list of args
:return: the final context/command parsed
"""
ctx = cli.make_context(prog_name, args, resilient_parsing=True)
args = ctx.protected_args + ctx.args
while args:
if isinstance(ctx.command, MultiCommand):
if not ctx.command.chain:
cmd_name, cmd, args = ctx.command.resolve_command(ctx, args)
if cmd is None:
return ctx
ctx = cmd.make_context(cmd_name, args, parent=ctx,
resilient_parsing=True)
args = ctx.protected_args + ctx.args
else:
# Walk chained subcommand contexts saving the last one.
while args:
cmd_name, cmd, args = ctx.command.resolve_command(ctx, args)
if cmd is None:
return ctx
sub_ctx = cmd.make_context(cmd_name, args, parent=ctx,
allow_extra_args=True,
allow_interspersed_args=False,
resilient_parsing=True)
args = sub_ctx.args
ctx = sub_ctx
args = sub_ctx.protected_args + sub_ctx.args
else:
break
return ctx
def start_of_option(param_str):
"""
:param param_str: param_str to check
:return: whether or not this is the start of an option declaration (i.e. starts "-" or "--")
"""
return param_str and param_str[:1] == '-'
def is_incomplete_option(all_args, cmd_param):
"""
:param all_args: the full original list of args supplied
:param cmd_param: the current command paramter
:return: whether or not the last option declaration (i.e. starts "-" or "--") is incomplete and
corresponds to this cmd_param. In other words whether this cmd_param option can still accept
values
"""
if not isinstance(cmd_param, Option):
return False
if cmd_param.is_flag:
return False
last_option = None
for index, arg_str in enumerate(reversed([arg for arg in all_args if arg != WORDBREAK])):
if index + 1 > cmd_param.nargs:
break
if start_of_option(arg_str):
last_option = arg_str
return True if last_option and last_option in cmd_param.opts else False
def is_incomplete_argument(current_params, cmd_param):
"""
:param current_params: the current params and values for this argument as already entered
:param cmd_param: the current command parameter
:return: whether or not the last argument is incomplete and corresponds to this cmd_param. In
other words whether or not the this cmd_param argument can still accept values
"""
if not isinstance(cmd_param, Argument):
return False
current_param_values = current_params[cmd_param.name]
if current_param_values is None:
return True
if cmd_param.nargs == -1:
return True
if isinstance(current_param_values, abc.Iterable) \
and cmd_param.nargs > 1 and len(current_param_values) < cmd_param.nargs:
return True
return False
def get_user_autocompletions(ctx, args, incomplete, cmd_param):
"""
:param ctx: context associated with the parsed command
:param args: full list of args
:param incomplete: the incomplete text to autocomplete
:param cmd_param: command definition
:return: all the possible user-specified completions for the param
"""
results = []
if isinstance(cmd_param.type, Choice):
# Choices don't support descriptions.
results = [(c, None)
for c in cmd_param.type.choices if str(c).startswith(incomplete)]
elif cmd_param.autocompletion is not None:
dynamic_completions = cmd_param.autocompletion(ctx=ctx,
args=args,
incomplete=incomplete)
results = [c if isinstance(c, tuple) else (c, None)
for c in dynamic_completions]
return results
def get_visible_commands_starting_with(ctx, starts_with):
"""
:param ctx: context associated with the parsed command
:starts_with: string that visible commands must start with.
:return: all visible (not hidden) commands that start with starts_with.
"""
for c in ctx.command.list_commands(ctx):
if c.startswith(starts_with):
command = ctx.command.get_command(ctx, c)
if not command.hidden:
yield command
def add_subcommand_completions(ctx, incomplete, completions_out):
# Add subcommand completions.
if isinstance(ctx.command, MultiCommand):
completions_out.extend(
[(c.name, c.get_short_help_str()) for c in get_visible_commands_starting_with(ctx, incomplete)])
# Walk up the context list and add any other completion possibilities from chained commands
while ctx.parent is not None:
ctx = ctx.parent
if isinstance(ctx.command, MultiCommand) and ctx.command.chain:
remaining_commands = [c for c in get_visible_commands_starting_with(ctx, incomplete)
if c.name not in ctx.protected_args]
completions_out.extend([(c.name, c.get_short_help_str()) for c in remaining_commands])
def get_choices(cli, prog_name, args, incomplete):
"""
:param cli: command definition
:param prog_name: the program that is running
:param args: full list of args
:param incomplete: the incomplete text to autocomplete
:return: all the possible completions for the incomplete
"""
all_args = copy.deepcopy(args)
ctx = resolve_ctx(cli, prog_name, args)
if ctx is None:
return []
# In newer versions of bash long opts with '='s are partitioned, but it's easier to parse
# without the '='
if start_of_option(incomplete) and WORDBREAK in incomplete:
partition_incomplete = incomplete.partition(WORDBREAK)
all_args.append(partition_incomplete[0])
incomplete = partition_incomplete[2]
elif incomplete == WORDBREAK:
incomplete = ''
completions = []
if start_of_option(incomplete):
# completions for partial options
for param in ctx.command.params:
if isinstance(param, Option) and not param.hidden:
param_opts = [param_opt for param_opt in param.opts +
param.secondary_opts if param_opt not in all_args or param.multiple]
completions.extend([(o, param.help) for o in param_opts if o.startswith(incomplete)])
return completions
# completion for option values from user supplied values
for param in ctx.command.params:
if is_incomplete_option(all_args, param):
return get_user_autocompletions(ctx, all_args, incomplete, param)
# completion for argument values from user supplied values
for param in ctx.command.params:
if is_incomplete_argument(ctx.params, param):
return get_user_autocompletions(ctx, all_args, incomplete, param)
add_subcommand_completions(ctx, incomplete, completions)
# Sort before returning so that proper ordering can be enforced in custom types.
return sorted(completions)
def do_complete(cli, prog_name, include_descriptions):
cwords = split_arg_string(os.environ['COMP_WORDS'])
cword = int(os.environ['COMP_CWORD'])
args = cwords[1:cword]
try:
incomplete = cwords[cword]
except IndexError:
incomplete = ''
for item in get_choices(cli, prog_name, args, incomplete):
echo(item[0])
if include_descriptions:
# ZSH has trouble dealing with empty array parameters when returned from commands, so use a well defined character '_' to indicate no description is present.
echo(item[1] if item[1] else '_')
return True
def bashcomplete(cli, prog_name, complete_var, complete_instr):
if complete_instr.startswith('source'):
shell = 'zsh' if complete_instr == 'source_zsh' else 'bash'
echo(get_completion_script(prog_name, complete_var, shell))
return True
elif complete_instr == 'complete' or complete_instr == 'complete_zsh':
return do_complete(cli, prog_name, complete_instr == 'complete_zsh')
return False

File diff suppressed because it is too large Load diff

View file

@ -1,62 +1,56 @@
# -*- coding: utf-8 -*-
""" """
click._termui_impl
~~~~~~~~~~~~~~~~~~
This module contains implementations for the termui module. To keep the This module contains implementations for the termui module. To keep the
import time of Click down, some infrequently used functionality is import time of Click down, some infrequently used functionality is
placed in this module and only imported as needed. placed in this module and only imported as needed.
:copyright: © 2014 by the Pallets team.
:license: BSD, see LICENSE.rst for more details.
""" """
import contextlib
import math
import os import os
import sys import sys
import time import time
import math import typing as t
import contextlib from gettext import gettext as _
from ._compat import _default_text_stdout, range_type, PY2, isatty, \
open_stream, strip_ansi, term_len, get_best_encoding, WIN, int_types, \ from ._compat import _default_text_stdout
CYGWIN from ._compat import CYGWIN
from .utils import echo from ._compat import get_best_encoding
from ._compat import isatty
from ._compat import open_stream
from ._compat import strip_ansi
from ._compat import term_len
from ._compat import WIN
from .exceptions import ClickException from .exceptions import ClickException
from .utils import echo
V = t.TypeVar("V")
if os.name == 'nt': if os.name == "nt":
BEFORE_BAR = '\r' BEFORE_BAR = "\r"
AFTER_BAR = '\n' AFTER_BAR = "\n"
else: else:
BEFORE_BAR = '\r\033[?25l' BEFORE_BAR = "\r\033[?25l"
AFTER_BAR = '\033[?25h\n' AFTER_BAR = "\033[?25h\n"
def _length_hint(obj): class ProgressBar(t.Generic[V]):
"""Returns the length hint of an object.""" def __init__(
try: self,
return len(obj) iterable: t.Optional[t.Iterable[V]],
except (AttributeError, TypeError): length: t.Optional[int] = None,
try: fill_char: str = "#",
get_hint = type(obj).__length_hint__ empty_char: str = " ",
except AttributeError: bar_template: str = "%(bar)s",
return None info_sep: str = " ",
try: show_eta: bool = True,
hint = get_hint(obj) show_percent: t.Optional[bool] = None,
except TypeError: show_pos: bool = False,
return None item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
if hint is NotImplemented or \ label: t.Optional[str] = None,
not isinstance(hint, int_types) or \ file: t.Optional[t.TextIO] = None,
hint < 0: color: t.Optional[bool] = None,
return None update_min_steps: int = 1,
return hint width: int = 30,
) -> None:
class ProgressBar(object):
def __init__(self, iterable, length=None, fill_char='#', empty_char=' ',
bar_template='%(bar)s', info_sep=' ', show_eta=True,
show_percent=None, show_pos=False, item_show_func=None,
label=None, file=None, color=None, width=30):
self.fill_char = fill_char self.fill_char = fill_char
self.empty_char = empty_char self.empty_char = empty_char
self.bar_template = bar_template self.bar_template = bar_template
@ -65,77 +59,87 @@ class ProgressBar(object):
self.show_percent = show_percent self.show_percent = show_percent
self.show_pos = show_pos self.show_pos = show_pos
self.item_show_func = item_show_func self.item_show_func = item_show_func
self.label = label or '' self.label = label or ""
if file is None: if file is None:
file = _default_text_stdout() file = _default_text_stdout()
self.file = file self.file = file
self.color = color self.color = color
self.update_min_steps = update_min_steps
self._completed_intervals = 0
self.width = width self.width = width
self.autowidth = width == 0 self.autowidth = width == 0
if length is None: if length is None:
length = _length_hint(iterable) from operator import length_hint
length = length_hint(iterable, -1)
if length == -1:
length = None
if iterable is None: if iterable is None:
if length is None: if length is None:
raise TypeError('iterable or length is required') raise TypeError("iterable or length is required")
iterable = range_type(length) iterable = t.cast(t.Iterable[V], range(length))
self.iter = iter(iterable) self.iter = iter(iterable)
self.length = length self.length = length
self.length_known = length is not None
self.pos = 0 self.pos = 0
self.avg = [] self.avg: t.List[float] = []
self.start = self.last_eta = time.time() self.start = self.last_eta = time.time()
self.eta_known = False self.eta_known = False
self.finished = False self.finished = False
self.max_width = None self.max_width: t.Optional[int] = None
self.entered = False self.entered = False
self.current_item = None self.current_item: t.Optional[V] = None
self.is_hidden = not isatty(self.file) self.is_hidden = not isatty(self.file)
self._last_line = None self._last_line: t.Optional[str] = None
self.short_limit = 0.5
def __enter__(self): def __enter__(self) -> "ProgressBar":
self.entered = True self.entered = True
self.render_progress() self.render_progress()
return self return self
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb): # type: ignore
self.render_finish() self.render_finish()
def __iter__(self): def __iter__(self) -> t.Iterator[V]:
if not self.entered: if not self.entered:
raise RuntimeError('You need to use progress bars in a with block.') raise RuntimeError("You need to use progress bars in a with block.")
self.render_progress() self.render_progress()
return self.generator() return self.generator()
def is_fast(self): def __next__(self) -> V:
return time.time() - self.start <= self.short_limit # Iteration is defined in terms of a generator function,
# returned by iter(self); use that to define next(). This works
# because `self.iter` is an iterable consumed by that generator,
# so it is re-entry safe. Calling `next(self.generator())`
# twice works and does "what you want".
return next(iter(self))
def render_finish(self): def render_finish(self) -> None:
if self.is_hidden or self.is_fast(): if self.is_hidden:
return return
self.file.write(AFTER_BAR) self.file.write(AFTER_BAR)
self.file.flush() self.file.flush()
@property @property
def pct(self): def pct(self) -> float:
if self.finished: if self.finished:
return 1.0 return 1.0
return min(self.pos / (float(self.length) or 1), 1.0) return min(self.pos / (float(self.length or 1) or 1), 1.0)
@property @property
def time_per_iteration(self): def time_per_iteration(self) -> float:
if not self.avg: if not self.avg:
return 0.0 return 0.0
return sum(self.avg) / float(len(self.avg)) return sum(self.avg) / float(len(self.avg))
@property @property
def eta(self): def eta(self) -> float:
if self.length_known and not self.finished: if self.length is not None and not self.finished:
return self.time_per_iteration * (self.length - self.pos) return self.time_per_iteration * (self.length - self.pos)
return 0.0 return 0.0
def format_eta(self): def format_eta(self) -> str:
if self.eta_known: if self.eta_known:
t = int(self.eta) t = int(self.eta)
seconds = t % 60 seconds = t % 60
@ -145,41 +149,44 @@ class ProgressBar(object):
hours = t % 24 hours = t % 24
t //= 24 t //= 24
if t > 0: if t > 0:
days = t return f"{t}d {hours:02}:{minutes:02}:{seconds:02}"
return '%dd %02d:%02d:%02d' % (days, hours, minutes, seconds)
else: else:
return '%02d:%02d:%02d' % (hours, minutes, seconds) return f"{hours:02}:{minutes:02}:{seconds:02}"
return '' return ""
def format_pos(self): def format_pos(self) -> str:
pos = str(self.pos) pos = str(self.pos)
if self.length_known: if self.length is not None:
pos += '/%s' % self.length pos += f"/{self.length}"
return pos return pos
def format_pct(self): def format_pct(self) -> str:
return ('% 4d%%' % int(self.pct * 100))[1:] return f"{int(self.pct * 100): 4}%"[1:]
def format_bar(self): def format_bar(self) -> str:
if self.length_known: if self.length is not None:
bar_length = int(self.pct * self.width) bar_length = int(self.pct * self.width)
bar = self.fill_char * bar_length bar = self.fill_char * bar_length
bar += self.empty_char * (self.width - bar_length) bar += self.empty_char * (self.width - bar_length)
elif self.finished: elif self.finished:
bar = self.fill_char * self.width bar = self.fill_char * self.width
else: else:
bar = list(self.empty_char * (self.width or 1)) chars = list(self.empty_char * (self.width or 1))
if self.time_per_iteration != 0: if self.time_per_iteration != 0:
bar[int((math.cos(self.pos * self.time_per_iteration) chars[
/ 2.0 + 0.5) * self.width)] = self.fill_char int(
bar = ''.join(bar) (math.cos(self.pos * self.time_per_iteration) / 2.0 + 0.5)
* self.width
)
] = self.fill_char
bar = "".join(chars)
return bar return bar
def format_progress_line(self): def format_progress_line(self) -> str:
show_percent = self.show_percent show_percent = self.show_percent
info_bits = [] info_bits = []
if self.length_known and show_percent is None: if self.length is not None and show_percent is None:
show_percent = not self.show_pos show_percent = not self.show_pos
if self.show_pos: if self.show_pos:
@ -193,16 +200,25 @@ class ProgressBar(object):
if item_info is not None: if item_info is not None:
info_bits.append(item_info) info_bits.append(item_info)
return (self.bar_template % { return (
'label': self.label, self.bar_template
'bar': self.format_bar(), % {
'info': self.info_sep.join(info_bits) "label": self.label,
}).rstrip() "bar": self.format_bar(),
"info": self.info_sep.join(info_bits),
}
).rstrip()
def render_progress(self): def render_progress(self) -> None:
from .termui import get_terminal_size import shutil
if self.is_hidden: if self.is_hidden:
# Only output the label as it changes if the output is not a
# TTY. Use file=stderr if you expect to be piping stdout.
if self._last_line != self.label:
self._last_line = self.label
echo(self.label, file=self.file, color=self.color)
return return
buf = [] buf = []
@ -211,10 +227,10 @@ class ProgressBar(object):
old_width = self.width old_width = self.width
self.width = 0 self.width = 0
clutter_length = term_len(self.format_progress_line()) clutter_length = term_len(self.format_progress_line())
new_width = max(0, get_terminal_size()[0] - clutter_length) new_width = max(0, shutil.get_terminal_size().columns - clutter_length)
if new_width < old_width: if new_width < old_width:
buf.append(BEFORE_BAR) buf.append(BEFORE_BAR)
buf.append(' ' * self.max_width) buf.append(" " * self.max_width) # type: ignore
self.max_width = new_width self.max_width = new_width
self.width = new_width self.width = new_width
@ -229,18 +245,18 @@ class ProgressBar(object):
self.max_width = line_len self.max_width = line_len
buf.append(line) buf.append(line)
buf.append(' ' * (clear_width - line_len)) buf.append(" " * (clear_width - line_len))
line = ''.join(buf) line = "".join(buf)
# Render the line only if it changed. # Render the line only if it changed.
if line != self._last_line and not self.is_fast(): if line != self._last_line:
self._last_line = line self._last_line = line
echo(line, file=self.file, color=self.color, nl=False) echo(line, file=self.file, color=self.color, nl=False)
self.file.flush() self.file.flush()
def make_step(self, n_steps): def make_step(self, n_steps: int) -> None:
self.pos += n_steps self.pos += n_steps
if self.length_known and self.pos >= self.length: if self.length is not None and self.pos >= self.length:
self.finished = True self.finished = True
if (time.time() - self.last_eta) < 1.0: if (time.time() - self.last_eta) < 1.0:
@ -258,97 +274,134 @@ class ProgressBar(object):
self.avg = self.avg[-6:] + [step] self.avg = self.avg[-6:] + [step]
self.eta_known = self.length_known self.eta_known = self.length is not None
def update(self, n_steps): def update(self, n_steps: int, current_item: t.Optional[V] = None) -> None:
self.make_step(n_steps) """Update the progress bar by advancing a specified number of
self.render_progress() steps, and optionally set the ``current_item`` for this new
position.
def finish(self): :param n_steps: Number of steps to advance.
self.eta_known = 0 :param current_item: Optional item to set as ``current_item``
for the updated position.
.. versionchanged:: 8.0
Added the ``current_item`` optional parameter.
.. versionchanged:: 8.0
Only render when the number of steps meets the
``update_min_steps`` threshold.
"""
if current_item is not None:
self.current_item = current_item
self._completed_intervals += n_steps
if self._completed_intervals >= self.update_min_steps:
self.make_step(self._completed_intervals)
self.render_progress()
self._completed_intervals = 0
def finish(self) -> None:
self.eta_known = False
self.current_item = None self.current_item = None
self.finished = True self.finished = True
def generator(self): def generator(self) -> t.Iterator[V]:
""" """Return a generator which yields the items added to the bar
Returns a generator which yields the items added to the bar during during construction, and updates the progress bar *after* the
construction, and updates the progress bar *after* the yielded block yielded block returns.
returns.
""" """
# WARNING: the iterator interface for `ProgressBar` relies on
# this and only works because this is a simple generator which
# doesn't create or manage additional state. If this function
# changes, the impact should be evaluated both against
# `iter(bar)` and `next(bar)`. `next()` in particular may call
# `self.generator()` repeatedly, and this must remain safe in
# order for that interface to work.
if not self.entered: if not self.entered:
raise RuntimeError('You need to use progress bars in a with block.') raise RuntimeError("You need to use progress bars in a with block.")
if self.is_hidden: if self.is_hidden:
for rv in self.iter: yield from self.iter
yield rv
else: else:
for rv in self.iter: for rv in self.iter:
self.current_item = rv self.current_item = rv
# This allows show_item_func to be updated before the
# item is processed. Only trigger at the beginning of
# the update interval.
if self._completed_intervals == 0:
self.render_progress()
yield rv yield rv
self.update(1) self.update(1)
self.finish() self.finish()
self.render_progress() self.render_progress()
def pager(generator, color=None): def pager(generator: t.Iterable[str], color: t.Optional[bool] = None) -> None:
"""Decide what method to use for paging through text.""" """Decide what method to use for paging through text."""
stdout = _default_text_stdout() stdout = _default_text_stdout()
if not isatty(sys.stdin) or not isatty(stdout): if not isatty(sys.stdin) or not isatty(stdout):
return _nullpager(stdout, generator, color) return _nullpager(stdout, generator, color)
pager_cmd = (os.environ.get('PAGER', None) or '').strip() pager_cmd = (os.environ.get("PAGER", None) or "").strip()
if pager_cmd: if pager_cmd:
if WIN: if WIN:
return _tempfilepager(generator, pager_cmd, color) return _tempfilepager(generator, pager_cmd, color)
return _pipepager(generator, pager_cmd, color) return _pipepager(generator, pager_cmd, color)
if os.environ.get('TERM') in ('dumb', 'emacs'): if os.environ.get("TERM") in ("dumb", "emacs"):
return _nullpager(stdout, generator, color) return _nullpager(stdout, generator, color)
if WIN or sys.platform.startswith('os2'): if WIN or sys.platform.startswith("os2"):
return _tempfilepager(generator, 'more <', color) return _tempfilepager(generator, "more <", color)
if hasattr(os, 'system') and os.system('(less) 2>/dev/null') == 0: if hasattr(os, "system") and os.system("(less) 2>/dev/null") == 0:
return _pipepager(generator, 'less', color) return _pipepager(generator, "less", color)
import tempfile import tempfile
fd, filename = tempfile.mkstemp() fd, filename = tempfile.mkstemp()
os.close(fd) os.close(fd)
try: try:
if hasattr(os, 'system') and os.system('more "%s"' % filename) == 0: if hasattr(os, "system") and os.system(f'more "{filename}"') == 0:
return _pipepager(generator, 'more', color) return _pipepager(generator, "more", color)
return _nullpager(stdout, generator, color) return _nullpager(stdout, generator, color)
finally: finally:
os.unlink(filename) os.unlink(filename)
def _pipepager(generator, cmd, color): def _pipepager(generator: t.Iterable[str], cmd: str, color: t.Optional[bool]) -> None:
"""Page through text by feeding it to another program. Invoking a """Page through text by feeding it to another program. Invoking a
pager through this might support colors. pager through this might support colors.
""" """
import subprocess import subprocess
env = dict(os.environ) env = dict(os.environ)
# If we're piping to less we might support colors under the # If we're piping to less we might support colors under the
# condition that # condition that
cmd_detail = cmd.rsplit('/', 1)[-1].split() cmd_detail = cmd.rsplit("/", 1)[-1].split()
if color is None and cmd_detail[0] == 'less': if color is None and cmd_detail[0] == "less":
less_flags = os.environ.get('LESS', '') + ' '.join(cmd_detail[1:]) less_flags = f"{os.environ.get('LESS', '')}{' '.join(cmd_detail[1:])}"
if not less_flags: if not less_flags:
env['LESS'] = '-R' env["LESS"] = "-R"
color = True color = True
elif 'r' in less_flags or 'R' in less_flags: elif "r" in less_flags or "R" in less_flags:
color = True color = True
c = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, c = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, env=env)
env=env) stdin = t.cast(t.BinaryIO, c.stdin)
encoding = get_best_encoding(c.stdin) encoding = get_best_encoding(stdin)
try: try:
for text in generator: for text in generator:
if not color: if not color:
text = strip_ansi(text) text = strip_ansi(text)
c.stdin.write(text.encode(encoding, 'replace')) stdin.write(text.encode(encoding, "replace"))
except (IOError, KeyboardInterrupt): except (OSError, KeyboardInterrupt):
pass pass
else: else:
c.stdin.close() stdin.close()
# Less doesn't respect ^C, but catches it for its own UI purposes (aborting # Less doesn't respect ^C, but catches it for its own UI purposes (aborting
# search or other commands inside less). # search or other commands inside less).
@ -367,24 +420,30 @@ def _pipepager(generator, cmd, color):
break break
def _tempfilepager(generator, cmd, color): def _tempfilepager(
generator: t.Iterable[str], cmd: str, color: t.Optional[bool]
) -> None:
"""Page through text by invoking a program on a temporary file.""" """Page through text by invoking a program on a temporary file."""
import tempfile import tempfile
filename = tempfile.mktemp()
fd, filename = tempfile.mkstemp()
# TODO: This never terminates if the passed generator never terminates. # TODO: This never terminates if the passed generator never terminates.
text = "".join(generator) text = "".join(generator)
if not color: if not color:
text = strip_ansi(text) text = strip_ansi(text)
encoding = get_best_encoding(sys.stdout) encoding = get_best_encoding(sys.stdout)
with open_stream(filename, 'wb')[0] as f: with open_stream(filename, "wb")[0] as f:
f.write(text.encode(encoding)) f.write(text.encode(encoding))
try: try:
os.system(cmd + ' "' + filename + '"') os.system(f'{cmd} "{filename}"')
finally: finally:
os.close(fd)
os.unlink(filename) os.unlink(filename)
def _nullpager(stream, generator, color): def _nullpager(
stream: t.TextIO, generator: t.Iterable[str], color: t.Optional[bool]
) -> None:
"""Simply print unformatted text. This is the ultimate fallback.""" """Simply print unformatted text. This is the ultimate fallback."""
for text in generator: for text in generator:
if not color: if not color:
@ -392,159 +451,184 @@ def _nullpager(stream, generator, color):
stream.write(text) stream.write(text)
class Editor(object): class Editor:
def __init__(
def __init__(self, editor=None, env=None, require_save=True, self,
extension='.txt'): editor: t.Optional[str] = None,
env: t.Optional[t.Mapping[str, str]] = None,
require_save: bool = True,
extension: str = ".txt",
) -> None:
self.editor = editor self.editor = editor
self.env = env self.env = env
self.require_save = require_save self.require_save = require_save
self.extension = extension self.extension = extension
def get_editor(self): def get_editor(self) -> str:
if self.editor is not None: if self.editor is not None:
return self.editor return self.editor
for key in 'VISUAL', 'EDITOR': for key in "VISUAL", "EDITOR":
rv = os.environ.get(key) rv = os.environ.get(key)
if rv: if rv:
return rv return rv
if WIN: if WIN:
return 'notepad' return "notepad"
for editor in 'vim', 'nano': for editor in "sensible-editor", "vim", "nano":
if os.system('which %s >/dev/null 2>&1' % editor) == 0: if os.system(f"which {editor} >/dev/null 2>&1") == 0:
return editor return editor
return 'vi' return "vi"
def edit_file(self, filename): def edit_file(self, filename: str) -> None:
import subprocess import subprocess
editor = self.get_editor() editor = self.get_editor()
environ: t.Optional[t.Dict[str, str]] = None
if self.env: if self.env:
environ = os.environ.copy() environ = os.environ.copy()
environ.update(self.env) environ.update(self.env)
else:
environ = None
try: try:
c = subprocess.Popen('%s "%s"' % (editor, filename), c = subprocess.Popen(f'{editor} "{filename}"', env=environ, shell=True)
env=environ, shell=True)
exit_code = c.wait() exit_code = c.wait()
if exit_code != 0: if exit_code != 0:
raise ClickException('%s: Editing failed!' % editor) raise ClickException(
_("{editor}: Editing failed").format(editor=editor)
)
except OSError as e: except OSError as e:
raise ClickException('%s: Editing failed: %s' % (editor, e)) raise ClickException(
_("{editor}: Editing failed: {e}").format(editor=editor, e=e)
) from e
def edit(self, text): def edit(self, text: t.Optional[t.AnyStr]) -> t.Optional[t.AnyStr]:
import tempfile import tempfile
text = text or '' if not text:
if text and not text.endswith('\n'): data = b""
text += '\n' elif isinstance(text, (bytes, bytearray)):
data = text
else:
if text and not text.endswith("\n"):
text += "\n"
fd, name = tempfile.mkstemp(prefix='editor-', suffix=self.extension)
try:
if WIN: if WIN:
encoding = 'utf-8-sig' data = text.replace("\n", "\r\n").encode("utf-8-sig")
text = text.replace('\n', '\r\n')
else: else:
encoding = 'utf-8' data = text.encode("utf-8")
text = text.encode(encoding)
f = os.fdopen(fd, 'wb') fd, name = tempfile.mkstemp(prefix="editor-", suffix=self.extension)
f.write(text) f: t.BinaryIO
f.close()
try:
with os.fdopen(fd, "wb") as f:
f.write(data)
# If the filesystem resolution is 1 second, like Mac OS
# 10.12 Extended, or 2 seconds, like FAT32, and the editor
# closes very fast, require_save can fail. Set the modified
# time to be 2 seconds in the past to work around this.
os.utime(name, (os.path.getatime(name), os.path.getmtime(name) - 2))
# Depending on the resolution, the exact value might not be
# recorded, so get the new recorded value.
timestamp = os.path.getmtime(name) timestamp = os.path.getmtime(name)
self.edit_file(name) self.edit_file(name)
if self.require_save \ if self.require_save and os.path.getmtime(name) == timestamp:
and os.path.getmtime(name) == timestamp:
return None return None
f = open(name, 'rb') with open(name, "rb") as f:
try:
rv = f.read() rv = f.read()
finally:
f.close() if isinstance(text, (bytes, bytearray)):
return rv.decode('utf-8-sig').replace('\r\n', '\n') return rv
return rv.decode("utf-8-sig").replace("\r\n", "\n") # type: ignore
finally: finally:
os.unlink(name) os.unlink(name)
def open_url(url, wait=False, locate=False): def open_url(url: str, wait: bool = False, locate: bool = False) -> int:
import subprocess import subprocess
def _unquote_file(url): def _unquote_file(url: str) -> str:
try: from urllib.parse import unquote
import urllib
except ImportError: if url.startswith("file://"):
import urllib url = unquote(url[7:])
if url.startswith('file://'):
url = urllib.unquote(url[7:])
return url return url
if sys.platform == 'darwin': if sys.platform == "darwin":
args = ['open'] args = ["open"]
if wait: if wait:
args.append('-W') args.append("-W")
if locate: if locate:
args.append('-R') args.append("-R")
args.append(_unquote_file(url)) args.append(_unquote_file(url))
null = open('/dev/null', 'w') null = open("/dev/null", "w")
try: try:
return subprocess.Popen(args, stderr=null).wait() return subprocess.Popen(args, stderr=null).wait()
finally: finally:
null.close() null.close()
elif WIN: elif WIN:
if locate: if locate:
url = _unquote_file(url) url = _unquote_file(url.replace('"', ""))
args = 'explorer /select,"%s"' % _unquote_file( args = f'explorer /select,"{url}"'
url.replace('"', ''))
else: else:
args = 'start %s "" "%s"' % ( url = url.replace('"', "")
wait and '/WAIT' or '', url.replace('"', '')) wait_str = "/WAIT" if wait else ""
args = f'start {wait_str} "" "{url}"'
return os.system(args) return os.system(args)
elif CYGWIN: elif CYGWIN:
if locate: if locate:
url = _unquote_file(url) url = os.path.dirname(_unquote_file(url).replace('"', ""))
args = 'cygstart "%s"' % (os.path.dirname(url).replace('"', '')) args = f'cygstart "{url}"'
else: else:
args = 'cygstart %s "%s"' % ( url = url.replace('"', "")
wait and '-w' or '', url.replace('"', '')) wait_str = "-w" if wait else ""
args = f'cygstart {wait_str} "{url}"'
return os.system(args) return os.system(args)
try: try:
if locate: if locate:
url = os.path.dirname(_unquote_file(url)) or '.' url = os.path.dirname(_unquote_file(url)) or "."
else: else:
url = _unquote_file(url) url = _unquote_file(url)
c = subprocess.Popen(['xdg-open', url]) c = subprocess.Popen(["xdg-open", url])
if wait: if wait:
return c.wait() return c.wait()
return 0 return 0
except OSError: except OSError:
if url.startswith(('http://', 'https://')) and not locate and not wait: if url.startswith(("http://", "https://")) and not locate and not wait:
import webbrowser import webbrowser
webbrowser.open(url) webbrowser.open(url)
return 0 return 0
return 1 return 1
def _translate_ch_to_exc(ch): def _translate_ch_to_exc(ch: str) -> t.Optional[BaseException]:
if ch == u'\x03': if ch == "\x03":
raise KeyboardInterrupt() raise KeyboardInterrupt()
if ch == u'\x04' and not WIN: # Unix-like, Ctrl+D
if ch == "\x04" and not WIN: # Unix-like, Ctrl+D
raise EOFError() raise EOFError()
if ch == u'\x1a' and WIN: # Windows, Ctrl+Z
if ch == "\x1a" and WIN: # Windows, Ctrl+Z
raise EOFError() raise EOFError()
return None
if WIN: if WIN:
import msvcrt import msvcrt
@contextlib.contextmanager @contextlib.contextmanager
def raw_terminal(): def raw_terminal() -> t.Iterator[int]:
yield yield -1
def getchar(echo): def getchar(echo: bool) -> str:
# The function `getch` will return a bytes object corresponding to # The function `getch` will return a bytes object corresponding to
# the pressed character. Since Windows 10 build 1803, it will also # the pressed character. Since Windows 10 build 1803, it will also
# return \x00 when called a second time after pressing a regular key. # return \x00 when called a second time after pressing a regular key.
@ -574,48 +658,60 @@ if WIN:
# #
# Anyway, Click doesn't claim to do this Right(tm), and using `getwch` # Anyway, Click doesn't claim to do this Right(tm), and using `getwch`
# is doing the right thing in more situations than with `getch`. # is doing the right thing in more situations than with `getch`.
func: t.Callable[[], str]
if echo: if echo:
func = msvcrt.getwche func = msvcrt.getwche # type: ignore
else: else:
func = msvcrt.getwch func = msvcrt.getwch # type: ignore
rv = func() rv = func()
if rv in (u'\x00', u'\xe0'):
if rv in ("\x00", "\xe0"):
# \x00 and \xe0 are control characters that indicate special key, # \x00 and \xe0 are control characters that indicate special key,
# see above. # see above.
rv += func() rv += func()
_translate_ch_to_exc(rv) _translate_ch_to_exc(rv)
return rv return rv
else: else:
import tty import tty
import termios import termios
@contextlib.contextmanager @contextlib.contextmanager
def raw_terminal(): def raw_terminal() -> t.Iterator[int]:
f: t.Optional[t.TextIO]
fd: int
if not isatty(sys.stdin): if not isatty(sys.stdin):
f = open('/dev/tty') f = open("/dev/tty")
fd = f.fileno() fd = f.fileno()
else: else:
fd = sys.stdin.fileno() fd = sys.stdin.fileno()
f = None f = None
try: try:
old_settings = termios.tcgetattr(fd) old_settings = termios.tcgetattr(fd)
try: try:
tty.setraw(fd) tty.setraw(fd)
yield fd yield fd
finally: finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
sys.stdout.flush() sys.stdout.flush()
if f is not None: if f is not None:
f.close() f.close()
except termios.error: except termios.error:
pass pass
def getchar(echo): def getchar(echo: bool) -> str:
with raw_terminal() as fd: with raw_terminal() as fd:
ch = os.read(fd, 32) ch = os.read(fd, 32).decode(get_best_encoding(sys.stdin), "replace")
ch = ch.decode(get_best_encoding(sys.stdin), 'replace')
if echo and isatty(sys.stdout): if echo and isatty(sys.stdout):
sys.stdout.write(ch) sys.stdout.write(ch)
_translate_ch_to_exc(ch) _translate_ch_to_exc(ch)
return ch return ch

View file

@ -1,10 +1,16 @@
import textwrap import textwrap
import typing as t
from contextlib import contextmanager from contextlib import contextmanager
class TextWrapper(textwrap.TextWrapper): class TextWrapper(textwrap.TextWrapper):
def _handle_long_word(
def _handle_long_word(self, reversed_chunks, cur_line, cur_len, width): self,
reversed_chunks: t.List[str],
cur_line: t.List[str],
cur_len: int,
width: int,
) -> None:
space_left = max(width - cur_len, 1) space_left = max(width - cur_len, 1)
if self.break_long_words: if self.break_long_words:
@ -17,22 +23,27 @@ class TextWrapper(textwrap.TextWrapper):
cur_line.append(reversed_chunks.pop()) cur_line.append(reversed_chunks.pop())
@contextmanager @contextmanager
def extra_indent(self, indent): def extra_indent(self, indent: str) -> t.Iterator[None]:
old_initial_indent = self.initial_indent old_initial_indent = self.initial_indent
old_subsequent_indent = self.subsequent_indent old_subsequent_indent = self.subsequent_indent
self.initial_indent += indent self.initial_indent += indent
self.subsequent_indent += indent self.subsequent_indent += indent
try: try:
yield yield
finally: finally:
self.initial_indent = old_initial_indent self.initial_indent = old_initial_indent
self.subsequent_indent = old_subsequent_indent self.subsequent_indent = old_subsequent_indent
def indent_only(self, text): def indent_only(self, text: str) -> str:
rv = [] rv = []
for idx, line in enumerate(text.splitlines()): for idx, line in enumerate(text.splitlines()):
indent = self.initial_indent indent = self.initial_indent
if idx > 0: if idx > 0:
indent = self.subsequent_indent indent = self.subsequent_indent
rv.append(indent + line)
return '\n'.join(rv) rv.append(f"{indent}{line}")
return "\n".join(rv)

View file

@ -1,125 +0,0 @@
import os
import sys
import codecs
from ._compat import PY2
# If someone wants to vendor click, we want to ensure the
# correct package is discovered. Ideally we could use a
# relative import here but unfortunately Python does not
# support that.
click = sys.modules[__name__.rsplit('.', 1)[0]]
def _find_unicode_literals_frame():
import __future__
if not hasattr(sys, '_getframe'): # not all Python implementations have it
return 0
frm = sys._getframe(1)
idx = 1
while frm is not None:
if frm.f_globals.get('__name__', '').startswith('click.'):
frm = frm.f_back
idx += 1
elif frm.f_code.co_flags & __future__.unicode_literals.compiler_flag:
return idx
else:
break
return 0
def _check_for_unicode_literals():
if not __debug__:
return
if not PY2 or click.disable_unicode_literals_warning:
return
bad_frame = _find_unicode_literals_frame()
if bad_frame <= 0:
return
from warnings import warn
warn(Warning('Click detected the use of the unicode_literals '
'__future__ import. This is heavily discouraged '
'because it can introduce subtle bugs in your '
'code. You should instead use explicit u"" literals '
'for your unicode strings. For more information see '
'https://click.palletsprojects.com/python3/'),
stacklevel=bad_frame)
def _verify_python3_env():
"""Ensures that the environment is good for unicode on Python 3."""
if PY2:
return
try:
import locale
fs_enc = codecs.lookup(locale.getpreferredencoding()).name
except Exception:
fs_enc = 'ascii'
if fs_enc != 'ascii':
return
extra = ''
if os.name == 'posix':
import subprocess
try:
rv = subprocess.Popen(['locale', '-a'], stdout=subprocess.PIPE,
stderr=subprocess.PIPE).communicate()[0]
except OSError:
rv = b''
good_locales = set()
has_c_utf8 = False
# Make sure we're operating on text here.
if isinstance(rv, bytes):
rv = rv.decode('ascii', 'replace')
for line in rv.splitlines():
locale = line.strip()
if locale.lower().endswith(('.utf-8', '.utf8')):
good_locales.add(locale)
if locale.lower() in ('c.utf8', 'c.utf-8'):
has_c_utf8 = True
extra += '\n\n'
if not good_locales:
extra += (
'Additional information: on this system no suitable UTF-8\n'
'locales were discovered. This most likely requires resolving\n'
'by reconfiguring the locale system.'
)
elif has_c_utf8:
extra += (
'This system supports the C.UTF-8 locale which is recommended.\n'
'You might be able to resolve your issue by exporting the\n'
'following environment variables:\n\n'
' export LC_ALL=C.UTF-8\n'
' export LANG=C.UTF-8'
)
else:
extra += (
'This system lists a couple of UTF-8 supporting locales that\n'
'you can pick from. The following suitable locales were\n'
'discovered: %s'
) % ', '.join(sorted(good_locales))
bad_locale = None
for locale in os.environ.get('LC_ALL'), os.environ.get('LANG'):
if locale and locale.lower().endswith(('.utf-8', '.utf8')):
bad_locale = locale
if locale is not None:
break
if bad_locale is not None:
extra += (
'\n\nClick discovered that you exported a UTF-8 locale\n'
'but the locale system could not pick up from it because\n'
'it does not exist. The exported locale is "%s" but it\n'
'is not supported'
) % bad_locale
raise RuntimeError(
'Click will abort further execution because Python 3 was'
' configured to use ASCII as encoding for the environment.'
' Consult https://click.palletsprojects.com/en/7.x/python3/ for'
' mitigation steps.' + extra
)

View file

@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# This module is based on the excellent work by Adam Bartoš who # This module is based on the excellent work by Adam Bartoš who
# provided a lot of what went into the implementation here in # provided a lot of what went into the implementation here in
# the discussion to issue1602 in the Python bug tracker. # the discussion to issue1602 in the Python bug tracker.
@ -6,26 +5,32 @@
# There are some general differences in regards to how this works # There are some general differences in regards to how this works
# compared to the original patches as we do not need to patch # compared to the original patches as we do not need to patch
# the entire interpreter but just work in our little world of # the entire interpreter but just work in our little world of
# echo and prmopt. # echo and prompt.
import io import io
import os
import sys import sys
import zlib
import time import time
import ctypes import typing as t
import msvcrt from ctypes import byref
from ._compat import _NonClosingTextIOWrapper, text_type, PY2 from ctypes import c_char
from ctypes import byref, POINTER, c_int, c_char, c_char_p, \ from ctypes import c_char_p
c_void_p, py_object, c_ssize_t, c_ulong, windll, WINFUNCTYPE from ctypes import c_int
try: from ctypes import c_ssize_t
from ctypes import pythonapi from ctypes import c_ulong
PyObject_GetBuffer = pythonapi.PyObject_GetBuffer from ctypes import c_void_p
PyBuffer_Release = pythonapi.PyBuffer_Release from ctypes import POINTER
except ImportError: from ctypes import py_object
pythonapi = None from ctypes import Structure
from ctypes.wintypes import LPWSTR, LPCWSTR from ctypes.wintypes import DWORD
from ctypes.wintypes import HANDLE
from ctypes.wintypes import LPCWSTR
from ctypes.wintypes import LPWSTR
from ._compat import _NonClosingTextIOWrapper
assert sys.platform == "win32"
import msvcrt # noqa: E402
from ctypes import windll # noqa: E402
from ctypes import WINFUNCTYPE # noqa: E402
c_ssize_p = POINTER(c_ssize_t) c_ssize_p = POINTER(c_ssize_t)
@ -33,19 +38,18 @@ kernel32 = windll.kernel32
GetStdHandle = kernel32.GetStdHandle GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW WriteConsoleW = kernel32.WriteConsoleW
GetConsoleMode = kernel32.GetConsoleMode
GetLastError = kernel32.GetLastError GetLastError = kernel32.GetLastError
GetCommandLineW = WINFUNCTYPE(LPWSTR)( GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
('GetCommandLineW', windll.kernel32)) CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
CommandLineToArgvW = WINFUNCTYPE( ("CommandLineToArgvW", windll.shell32)
POINTER(LPWSTR), LPCWSTR, POINTER(c_int))( )
('CommandLineToArgvW', windll.shell32)) LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))
STDIN_HANDLE = GetStdHandle(-10) STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11) STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12) STDERR_HANDLE = GetStdHandle(-12)
PyBUF_SIMPLE = 0 PyBUF_SIMPLE = 0
PyBUF_WRITABLE = 1 PyBUF_WRITABLE = 1
@ -57,38 +61,40 @@ STDIN_FILENO = 0
STDOUT_FILENO = 1 STDOUT_FILENO = 1
STDERR_FILENO = 2 STDERR_FILENO = 2
EOF = b'\x1a' EOF = b"\x1a"
MAX_BYTES_WRITTEN = 32767 MAX_BYTES_WRITTEN = 32767
try:
class Py_buffer(ctypes.Structure): from ctypes import pythonapi
_fields_ = [ except ImportError:
('buf', c_void_p), # On PyPy we cannot get buffers so our ability to operate here is
('obj', py_object), # severely limited.
('len', c_ssize_t),
('itemsize', c_ssize_t),
('readonly', c_int),
('ndim', c_int),
('format', c_char_p),
('shape', c_ssize_p),
('strides', c_ssize_p),
('suboffsets', c_ssize_p),
('internal', c_void_p)
]
if PY2:
_fields_.insert(-1, ('smalltable', c_ssize_t * 2))
# On PyPy we cannot get buffers so our ability to operate here is
# serverly limited.
if pythonapi is None:
get_buffer = None get_buffer = None
else: else:
class Py_buffer(Structure):
_fields_ = [
("buf", c_void_p),
("obj", py_object),
("len", c_ssize_t),
("itemsize", c_ssize_t),
("readonly", c_int),
("ndim", c_int),
("format", c_char_p),
("shape", c_ssize_p),
("strides", c_ssize_p),
("suboffsets", c_ssize_p),
("internal", c_void_p),
]
PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
PyBuffer_Release = pythonapi.PyBuffer_Release
def get_buffer(obj, writable=False): def get_buffer(obj, writable=False):
buf = Py_buffer() buf = Py_buffer()
flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
PyObject_GetBuffer(py_object(obj), byref(buf), flags) PyObject_GetBuffer(py_object(obj), byref(buf), flags)
try: try:
buffer_type = c_char * buf.len buffer_type = c_char * buf.len
return buffer_type.from_address(buf.buf) return buffer_type.from_address(buf.buf)
@ -97,17 +103,15 @@ else:
class _WindowsConsoleRawIOBase(io.RawIOBase): class _WindowsConsoleRawIOBase(io.RawIOBase):
def __init__(self, handle): def __init__(self, handle):
self.handle = handle self.handle = handle
def isatty(self): def isatty(self):
io.RawIOBase.isatty(self) super().isatty()
return True return True
class _WindowsConsoleReader(_WindowsConsoleRawIOBase): class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
def readable(self): def readable(self):
return True return True
@ -116,20 +120,26 @@ class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
if not bytes_to_be_read: if not bytes_to_be_read:
return 0 return 0
elif bytes_to_be_read % 2: elif bytes_to_be_read % 2:
raise ValueError('cannot read odd number of bytes from ' raise ValueError(
'UTF-16-LE encoded console') "cannot read odd number of bytes from UTF-16-LE encoded console"
)
buffer = get_buffer(b, writable=True) buffer = get_buffer(b, writable=True)
code_units_to_be_read = bytes_to_be_read // 2 code_units_to_be_read = bytes_to_be_read // 2
code_units_read = c_ulong() code_units_read = c_ulong()
rv = ReadConsoleW(self.handle, buffer, code_units_to_be_read, rv = ReadConsoleW(
byref(code_units_read), None) HANDLE(self.handle),
buffer,
code_units_to_be_read,
byref(code_units_read),
None,
)
if GetLastError() == ERROR_OPERATION_ABORTED: if GetLastError() == ERROR_OPERATION_ABORTED:
# wait for KeyboardInterrupt # wait for KeyboardInterrupt
time.sleep(0.1) time.sleep(0.1)
if not rv: if not rv:
raise OSError('Windows error: %s' % GetLastError()) raise OSError(f"Windows error: {GetLastError()}")
if buffer[0] == EOF: if buffer[0] == EOF:
return 0 return 0
@ -137,27 +147,30 @@ class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
class _WindowsConsoleWriter(_WindowsConsoleRawIOBase): class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
def writable(self): def writable(self):
return True return True
@staticmethod @staticmethod
def _get_error_message(errno): def _get_error_message(errno):
if errno == ERROR_SUCCESS: if errno == ERROR_SUCCESS:
return 'ERROR_SUCCESS' return "ERROR_SUCCESS"
elif errno == ERROR_NOT_ENOUGH_MEMORY: elif errno == ERROR_NOT_ENOUGH_MEMORY:
return 'ERROR_NOT_ENOUGH_MEMORY' return "ERROR_NOT_ENOUGH_MEMORY"
return 'Windows error %s' % errno return f"Windows error {errno}"
def write(self, b): def write(self, b):
bytes_to_be_written = len(b) bytes_to_be_written = len(b)
buf = get_buffer(b) buf = get_buffer(b)
code_units_to_be_written = min(bytes_to_be_written, code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
MAX_BYTES_WRITTEN) // 2
code_units_written = c_ulong() code_units_written = c_ulong()
WriteConsoleW(self.handle, buf, code_units_to_be_written, WriteConsoleW(
byref(code_units_written), None) HANDLE(self.handle),
buf,
code_units_to_be_written,
byref(code_units_written),
None,
)
bytes_written = 2 * code_units_written.value bytes_written = 2 * code_units_written.value
if bytes_written == 0 and bytes_to_be_written > 0: if bytes_written == 0 and bytes_to_be_written > 0:
@ -165,18 +178,17 @@ class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
return bytes_written return bytes_written
class ConsoleStream(object): class ConsoleStream:
def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
def __init__(self, text_stream, byte_stream):
self._text_stream = text_stream self._text_stream = text_stream
self.buffer = byte_stream self.buffer = byte_stream
@property @property
def name(self): def name(self) -> str:
return self.buffer.name return self.buffer.name
def write(self, x): def write(self, x: t.AnyStr) -> int:
if isinstance(x, text_type): if isinstance(x, str):
return self._text_stream.write(x) return self._text_stream.write(x)
try: try:
self.flush() self.flush()
@ -184,124 +196,84 @@ class ConsoleStream(object):
pass pass
return self.buffer.write(x) return self.buffer.write(x)
def writelines(self, lines): def writelines(self, lines: t.Iterable[t.AnyStr]) -> None:
for line in lines: for line in lines:
self.write(line) self.write(line)
def __getattr__(self, name): def __getattr__(self, name: str) -> t.Any:
return getattr(self._text_stream, name) return getattr(self._text_stream, name)
def isatty(self): def isatty(self) -> bool:
return self.buffer.isatty() return self.buffer.isatty()
def __repr__(self): def __repr__(self):
return '<ConsoleStream name=%r encoding=%r>' % ( return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"
self.name,
self.encoding,
)
class WindowsChunkedWriter(object): def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
"""
Wraps a stream (such as stdout), acting as a transparent proxy for all
attribute access apart from method 'write()' which we wrap to write in
limited chunks due to a Windows limitation on binary console streams.
"""
def __init__(self, wrapped):
# double-underscore everything to prevent clashes with names of
# attributes on the wrapped stream object.
self.__wrapped = wrapped
def __getattr__(self, name):
return getattr(self.__wrapped, name)
def write(self, text):
total_to_write = len(text)
written = 0
while written < total_to_write:
to_write = min(total_to_write - written, MAX_BYTES_WRITTEN)
self.__wrapped.write(text[written:written+to_write])
written += to_write
_wrapped_std_streams = set()
def _wrap_std_stream(name):
# Python 2 & Windows 7 and below
if PY2 and sys.getwindowsversion()[:2] <= (6, 1) and name not in _wrapped_std_streams:
setattr(sys, name, WindowsChunkedWriter(getattr(sys, name)))
_wrapped_std_streams.add(name)
def _get_text_stdin(buffer_stream):
text_stream = _NonClosingTextIOWrapper( text_stream = _NonClosingTextIOWrapper(
io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)), io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)),
'utf-16-le', 'strict', line_buffering=True) "utf-16-le",
return ConsoleStream(text_stream, buffer_stream) "strict",
line_buffering=True,
)
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
def _get_text_stdout(buffer_stream): def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
text_stream = _NonClosingTextIOWrapper( text_stream = _NonClosingTextIOWrapper(
io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)), io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)),
'utf-16-le', 'strict', line_buffering=True) "utf-16-le",
return ConsoleStream(text_stream, buffer_stream) "strict",
line_buffering=True,
)
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
def _get_text_stderr(buffer_stream): def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
text_stream = _NonClosingTextIOWrapper( text_stream = _NonClosingTextIOWrapper(
io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)), io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)),
'utf-16-le', 'strict', line_buffering=True) "utf-16-le",
return ConsoleStream(text_stream, buffer_stream) "strict",
line_buffering=True,
)
return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))
if PY2: _stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
def _hash_py_argv():
return zlib.crc32('\x00'.join(sys.argv[1:]))
_initial_argv_hash = _hash_py_argv()
def _get_windows_argv():
argc = c_int(0)
argv_unicode = CommandLineToArgvW(GetCommandLineW(), byref(argc))
argv = [argv_unicode[i] for i in range(0, argc.value)]
if not hasattr(sys, 'frozen'):
argv = argv[1:]
while len(argv) > 0:
arg = argv[0]
if not arg.startswith('-') or arg == '-':
break
argv = argv[1:]
if arg.startswith(('-c', '-m')):
break
return argv[1:]
_stream_factories = {
0: _get_text_stdin, 0: _get_text_stdin,
1: _get_text_stdout, 1: _get_text_stdout,
2: _get_text_stderr, 2: _get_text_stderr,
} }
def _get_windows_console_stream(f, encoding, errors): def _is_console(f: t.TextIO) -> bool:
if get_buffer is not None and \ if not hasattr(f, "fileno"):
encoding in ('utf-16-le', None) \ return False
and errors in ('strict', None) and \
hasattr(f, 'isatty') and f.isatty(): try:
fileno = f.fileno()
except (OSError, io.UnsupportedOperation):
return False
handle = msvcrt.get_osfhandle(fileno)
return bool(GetConsoleMode(handle, byref(DWORD())))
def _get_windows_console_stream(
f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
) -> t.Optional[t.TextIO]:
if (
get_buffer is not None
and encoding in {"utf-16-le", None}
and errors in {"strict", None}
and _is_console(f)
):
func = _stream_factories.get(f.fileno()) func = _stream_factories.get(f.fileno())
if func is not None: if func is not None:
if not PY2: b = getattr(f, "buffer", None)
f = getattr(f, 'buffer', None)
if f is None: if b is None:
return None return None
else:
# If we are on Python 2 we need to set the stream that we return func(b)
# deal with to binary mode as otherwise the exercise if a
# bit moot. The same problems apply as for
# get_binary_stdin and friends from _compat.
msvcrt.setmode(f.fileno(), os.O_BINARY)
return func(f)

File diff suppressed because it is too large Load diff

View file

@ -1,34 +1,48 @@
import sys
import inspect import inspect
import types
import typing as t
from functools import update_wrapper from functools import update_wrapper
from gettext import gettext as _
from ._compat import iteritems from .core import Argument
from ._unicodefun import _check_for_unicode_literals from .core import Command
from .utils import echo from .core import Context
from .core import Group
from .core import Option
from .core import Parameter
from .globals import get_current_context from .globals import get_current_context
from .utils import echo
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
FC = t.TypeVar("FC", bound=t.Union[t.Callable[..., t.Any], Command])
def pass_context(f): def pass_context(f: F) -> F:
"""Marks a callback as wanting to receive the current context """Marks a callback as wanting to receive the current context
object as first argument. object as first argument.
""" """
def new_func(*args, **kwargs):
def new_func(*args, **kwargs): # type: ignore
return f(get_current_context(), *args, **kwargs) return f(get_current_context(), *args, **kwargs)
return update_wrapper(new_func, f)
return update_wrapper(t.cast(F, new_func), f)
def pass_obj(f): def pass_obj(f: F) -> F:
"""Similar to :func:`pass_context`, but only pass the object on the """Similar to :func:`pass_context`, but only pass the object on the
context onwards (:attr:`Context.obj`). This is useful if that object context onwards (:attr:`Context.obj`). This is useful if that object
represents the state of a nested system. represents the state of a nested system.
""" """
def new_func(*args, **kwargs):
def new_func(*args, **kwargs): # type: ignore
return f(get_current_context().obj, *args, **kwargs) return f(get_current_context().obj, *args, **kwargs)
return update_wrapper(new_func, f)
return update_wrapper(t.cast(F, new_func), f)
def make_pass_decorator(object_type, ensure=False): def make_pass_decorator(
object_type: t.Type, ensure: bool = False
) -> "t.Callable[[F], F]":
"""Given an object type this creates a decorator that will work """Given an object type this creates a decorator that will work
similar to :func:`pass_obj` but instead of passing the object of the similar to :func:`pass_obj` but instead of passing the object of the
current context, it will find the innermost context of type current context, it will find the innermost context of type
@ -50,55 +64,106 @@ def make_pass_decorator(object_type, ensure=False):
:param ensure: if set to `True`, a new object will be created and :param ensure: if set to `True`, a new object will be created and
remembered on the context if it's not there yet. remembered on the context if it's not there yet.
""" """
def decorator(f):
def new_func(*args, **kwargs): def decorator(f: F) -> F:
def new_func(*args, **kwargs): # type: ignore
ctx = get_current_context() ctx = get_current_context()
if ensure: if ensure:
obj = ctx.ensure_object(object_type) obj = ctx.ensure_object(object_type)
else: else:
obj = ctx.find_object(object_type) obj = ctx.find_object(object_type)
if obj is None: if obj is None:
raise RuntimeError('Managed to invoke callback without a ' raise RuntimeError(
'context object of type %r existing' "Managed to invoke callback without a context"
% object_type.__name__) f" object of type {object_type.__name__!r}"
" existing."
)
return ctx.invoke(f, obj, *args, **kwargs) return ctx.invoke(f, obj, *args, **kwargs)
return update_wrapper(new_func, f)
return update_wrapper(t.cast(F, new_func), f)
return decorator return decorator
def _make_command(f, name, attrs, cls): def pass_meta_key(
if isinstance(f, Command): key: str, *, doc_description: t.Optional[str] = None
raise TypeError('Attempted to convert a callback into a ' ) -> "t.Callable[[F], F]":
'command twice.') """Create a decorator that passes a key from
try: :attr:`click.Context.meta` as the first argument to the decorated
params = f.__click_params__ function.
params.reverse()
del f.__click_params__ :param key: Key in ``Context.meta`` to pass.
except AttributeError: :param doc_description: Description of the object being passed,
params = [] inserted into the decorator's docstring. Defaults to "the 'key'
help = attrs.get('help') key from Context.meta".
if help is None:
help = inspect.getdoc(f) .. versionadded:: 8.0
if isinstance(help, bytes): """
help = help.decode('utf-8')
else: def decorator(f: F) -> F:
help = inspect.cleandoc(help) def new_func(*args, **kwargs): # type: ignore
attrs['help'] = help ctx = get_current_context()
_check_for_unicode_literals() obj = ctx.meta[key]
return cls(name=name or f.__name__.lower().replace('_', '-'), return ctx.invoke(f, obj, *args, **kwargs)
callback=f, params=params, **attrs)
return update_wrapper(t.cast(F, new_func), f)
if doc_description is None:
doc_description = f"the {key!r} key from :attr:`click.Context.meta`"
decorator.__doc__ = (
f"Decorator that passes {doc_description} as the first argument"
" to the decorated function."
)
return decorator
def command(name=None, cls=None, **attrs): CmdType = t.TypeVar("CmdType", bound=Command)
@t.overload
def command(
__func: t.Callable[..., t.Any],
) -> Command:
...
@t.overload
def command(
name: t.Optional[str] = None,
**attrs: t.Any,
) -> t.Callable[..., Command]:
...
@t.overload
def command(
name: t.Optional[str] = None,
cls: t.Type[CmdType] = ...,
**attrs: t.Any,
) -> t.Callable[..., CmdType]:
...
def command(
name: t.Union[str, t.Callable[..., t.Any], None] = None,
cls: t.Optional[t.Type[Command]] = None,
**attrs: t.Any,
) -> t.Union[Command, t.Callable[..., Command]]:
r"""Creates a new :class:`Command` and uses the decorated function as r"""Creates a new :class:`Command` and uses the decorated function as
callback. This will also automatically attach all decorated callback. This will also automatically attach all decorated
:func:`option`\s and :func:`argument`\s as parameters to the command. :func:`option`\s and :func:`argument`\s as parameters to the command.
The name of the command defaults to the name of the function. If you The name of the command defaults to the name of the function with
want to change that, you can pass the intended name as the first underscores replaced by dashes. If you want to change that, you can
argument. pass the intended name as the first argument.
All keyword arguments are forwarded to the underlying command class. All keyword arguments are forwarded to the underlying command class.
For the ``params`` argument, any decorated params are appended to
the end of the list.
Once decorated the function turns into a :class:`Command` instance Once decorated the function turns into a :class:`Command` instance
that can be invoked as a command line utility or be attached to a that can be invoked as a command line utility or be attached to a
@ -108,35 +173,105 @@ def command(name=None, cls=None, **attrs):
name with underscores replaced by dashes. name with underscores replaced by dashes.
:param cls: the command class to instantiate. This defaults to :param cls: the command class to instantiate. This defaults to
:class:`Command`. :class:`Command`.
.. versionchanged:: 8.1
This decorator can be applied without parentheses.
.. versionchanged:: 8.1
The ``params`` argument can be used. Decorated params are
appended to the end of the list.
""" """
func: t.Optional[t.Callable[..., t.Any]] = None
if callable(name):
func = name
name = None
assert cls is None, "Use 'command(cls=cls)(callable)' to specify a class."
assert not attrs, "Use 'command(**kwargs)(callable)' to provide arguments."
if cls is None: if cls is None:
cls = Command cls = Command
def decorator(f):
cmd = _make_command(f, name, attrs, cls) def decorator(f: t.Callable[..., t.Any]) -> Command:
if isinstance(f, Command):
raise TypeError("Attempted to convert a callback into a command twice.")
attr_params = attrs.pop("params", None)
params = attr_params if attr_params is not None else []
try:
decorator_params = f.__click_params__ # type: ignore
except AttributeError:
pass
else:
del f.__click_params__ # type: ignore
params.extend(reversed(decorator_params))
if attrs.get("help") is None:
attrs["help"] = f.__doc__
cmd = cls( # type: ignore[misc]
name=name or f.__name__.lower().replace("_", "-"), # type: ignore[arg-type]
callback=f,
params=params,
**attrs,
)
cmd.__doc__ = f.__doc__ cmd.__doc__ = f.__doc__
return cmd return cmd
if func is not None:
return decorator(func)
return decorator return decorator
def group(name=None, **attrs): @t.overload
def group(
__func: t.Callable[..., t.Any],
) -> Group:
...
@t.overload
def group(
name: t.Optional[str] = None,
**attrs: t.Any,
) -> t.Callable[[F], Group]:
...
def group(
name: t.Union[str, t.Callable[..., t.Any], None] = None, **attrs: t.Any
) -> t.Union[Group, t.Callable[[F], Group]]:
"""Creates a new :class:`Group` with a function as callback. This """Creates a new :class:`Group` with a function as callback. This
works otherwise the same as :func:`command` just that the `cls` works otherwise the same as :func:`command` just that the `cls`
parameter is set to :class:`Group`. parameter is set to :class:`Group`.
.. versionchanged:: 8.1
This decorator can be applied without parentheses.
""" """
attrs.setdefault('cls', Group) if attrs.get("cls") is None:
return command(name, **attrs) attrs["cls"] = Group
if callable(name):
grp: t.Callable[[F], Group] = t.cast(Group, command(**attrs))
return grp(name)
return t.cast(Group, command(name, **attrs))
def _param_memo(f, param): def _param_memo(f: FC, param: Parameter) -> None:
if isinstance(f, Command): if isinstance(f, Command):
f.params.append(param) f.params.append(param)
else: else:
if not hasattr(f, '__click_params__'): if not hasattr(f, "__click_params__"):
f.__click_params__ = [] f.__click_params__ = [] # type: ignore
f.__click_params__.append(param)
f.__click_params__.append(param) # type: ignore
def argument(*param_decls, **attrs): def argument(*param_decls: str, **attrs: t.Any) -> t.Callable[[FC], FC]:
"""Attaches an argument to the command. All positional arguments are """Attaches an argument to the command. All positional arguments are
passed as parameter declarations to :class:`Argument`; all keyword passed as parameter declarations to :class:`Argument`; all keyword
arguments are forwarded unchanged (except ``cls``). arguments are forwarded unchanged (except ``cls``).
@ -146,14 +281,16 @@ def argument(*param_decls, **attrs):
:param cls: the argument class to instantiate. This defaults to :param cls: the argument class to instantiate. This defaults to
:class:`Argument`. :class:`Argument`.
""" """
def decorator(f):
ArgumentClass = attrs.pop('cls', Argument) def decorator(f: FC) -> FC:
ArgumentClass = attrs.pop("cls", None) or Argument
_param_memo(f, ArgumentClass(param_decls, **attrs)) _param_memo(f, ArgumentClass(param_decls, **attrs))
return f return f
return decorator return decorator
def option(*param_decls, **attrs): def option(*param_decls: str, **attrs: t.Any) -> t.Callable[[FC], FC]:
"""Attaches an option to the command. All positional arguments are """Attaches an option to the command. All positional arguments are
passed as parameter declarations to :class:`Option`; all keyword passed as parameter declarations to :class:`Option`; all keyword
arguments are forwarded unchanged (except ``cls``). arguments are forwarded unchanged (except ``cls``).
@ -163,149 +300,198 @@ def option(*param_decls, **attrs):
:param cls: the option class to instantiate. This defaults to :param cls: the option class to instantiate. This defaults to
:class:`Option`. :class:`Option`.
""" """
def decorator(f):
def decorator(f: FC) -> FC:
# Issue 926, copy attrs, so pre-defined options can re-use the same cls= # Issue 926, copy attrs, so pre-defined options can re-use the same cls=
option_attrs = attrs.copy() option_attrs = attrs.copy()
OptionClass = option_attrs.pop("cls", None) or Option
if 'help' in option_attrs:
option_attrs['help'] = inspect.cleandoc(option_attrs['help'])
OptionClass = option_attrs.pop('cls', Option)
_param_memo(f, OptionClass(param_decls, **option_attrs)) _param_memo(f, OptionClass(param_decls, **option_attrs))
return f return f
return decorator return decorator
def confirmation_option(*param_decls, **attrs): def confirmation_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
"""Shortcut for confirmation prompts that can be ignored by passing """Add a ``--yes`` option which shows a prompt before continuing if
``--yes`` as parameter. not passed. If the prompt is declined, the program will exit.
This is equivalent to decorating a function with :func:`option` with :param param_decls: One or more option names. Defaults to the single
the following parameters:: value ``"--yes"``.
:param kwargs: Extra arguments are passed to :func:`option`.
def callback(ctx, param, value):
if not value:
ctx.abort()
@click.command()
@click.option('--yes', is_flag=True, callback=callback,
expose_value=False, prompt='Do you want to continue?')
def dropdb():
pass
""" """
def decorator(f):
def callback(ctx, param, value): def callback(ctx: Context, param: Parameter, value: bool) -> None:
if not value: if not value:
ctx.abort() ctx.abort()
attrs.setdefault('is_flag', True)
attrs.setdefault('callback', callback) if not param_decls:
attrs.setdefault('expose_value', False) param_decls = ("--yes",)
attrs.setdefault('prompt', 'Do you want to continue?')
attrs.setdefault('help', 'Confirm the action without prompting.') kwargs.setdefault("is_flag", True)
return option(*(param_decls or ('--yes',)), **attrs)(f) kwargs.setdefault("callback", callback)
return decorator kwargs.setdefault("expose_value", False)
kwargs.setdefault("prompt", "Do you want to continue?")
kwargs.setdefault("help", "Confirm the action without prompting.")
return option(*param_decls, **kwargs)
def password_option(*param_decls, **attrs): def password_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
"""Shortcut for password prompts. """Add a ``--password`` option which prompts for a password, hiding
input and asking to enter the value again for confirmation.
This is equivalent to decorating a function with :func:`option` with :param param_decls: One or more option names. Defaults to the single
the following parameters:: value ``"--password"``.
:param kwargs: Extra arguments are passed to :func:`option`.
@click.command()
@click.option('--password', prompt=True, confirmation_prompt=True,
hide_input=True)
def changeadmin(password):
pass
""" """
def decorator(f): if not param_decls:
attrs.setdefault('prompt', True) param_decls = ("--password",)
attrs.setdefault('confirmation_prompt', True)
attrs.setdefault('hide_input', True) kwargs.setdefault("prompt", True)
return option(*(param_decls or ('--password',)), **attrs)(f) kwargs.setdefault("confirmation_prompt", True)
return decorator kwargs.setdefault("hide_input", True)
return option(*param_decls, **kwargs)
def version_option(version=None, *param_decls, **attrs): def version_option(
"""Adds a ``--version`` option which immediately ends the program version: t.Optional[str] = None,
printing out the version number. This is implemented as an eager *param_decls: str,
option that prints the version and exits the program in the callback. package_name: t.Optional[str] = None,
prog_name: t.Optional[str] = None,
message: t.Optional[str] = None,
**kwargs: t.Any,
) -> t.Callable[[FC], FC]:
"""Add a ``--version`` option which immediately prints the version
number and exits the program.
:param version: the version number to show. If not provided Click If ``version`` is not provided, Click will try to detect it using
attempts an auto discovery via setuptools. :func:`importlib.metadata.version` to get the version for the
:param prog_name: the name of the program (defaults to autodetection) ``package_name``. On Python < 3.8, the ``importlib_metadata``
:param message: custom message to show instead of the default backport must be installed.
(``'%(prog)s, version %(version)s'``)
:param others: everything else is forwarded to :func:`option`. If ``package_name`` is not provided, Click will try to detect it by
inspecting the stack frames. This will be used to detect the
version, so it must match the name of the installed package.
:param version: The version number to show. If not provided, Click
will try to detect it.
:param param_decls: One or more option names. Defaults to the single
value ``"--version"``.
:param package_name: The package name to detect the version from. If
not provided, Click will try to detect it.
:param prog_name: The name of the CLI to show in the message. If not
provided, it will be detected from the command.
:param message: The message to show. The values ``%(prog)s``,
``%(package)s``, and ``%(version)s`` are available. Defaults to
``"%(prog)s, version %(version)s"``.
:param kwargs: Extra arguments are passed to :func:`option`.
:raise RuntimeError: ``version`` could not be detected.
.. versionchanged:: 8.0
Add the ``package_name`` parameter, and the ``%(package)s``
value for messages.
.. versionchanged:: 8.0
Use :mod:`importlib.metadata` instead of ``pkg_resources``. The
version is detected based on the package name, not the entry
point name. The Python package name must match the installed
package name, or be passed with ``package_name=``.
""" """
if version is None: if message is None:
if hasattr(sys, '_getframe'): message = _("%(prog)s, version %(version)s")
module = sys._getframe(1).f_globals.get('__name__')
else:
module = ''
def decorator(f): if version is None and package_name is None:
prog_name = attrs.pop('prog_name', None) frame = inspect.currentframe()
message = attrs.pop('message', '%(prog)s, version %(version)s') f_back = frame.f_back if frame is not None else None
f_globals = f_back.f_globals if f_back is not None else None
# break reference cycle
# https://docs.python.org/3/library/inspect.html#the-interpreter-stack
del frame
def callback(ctx, param, value): if f_globals is not None:
if not value or ctx.resilient_parsing: package_name = f_globals.get("__name__")
return
prog = prog_name
if prog is None:
prog = ctx.find_root().info_name
ver = version
if ver is None:
try:
import pkg_resources
except ImportError:
pass
else:
for dist in pkg_resources.working_set:
scripts = dist.get_entry_map().get('console_scripts') or {}
for script_name, entry_point in iteritems(scripts):
if entry_point.module_name == module:
ver = dist.version
break
if ver is None:
raise RuntimeError('Could not determine version')
echo(message % {
'prog': prog,
'version': ver,
}, color=ctx.color)
ctx.exit()
attrs.setdefault('is_flag', True) if package_name == "__main__":
attrs.setdefault('expose_value', False) package_name = f_globals.get("__package__")
attrs.setdefault('is_eager', True)
attrs.setdefault('help', 'Show the version and exit.') if package_name:
attrs['callback'] = callback package_name = package_name.partition(".")[0]
return option(*(param_decls or ('--version',)), **attrs)(f)
return decorator def callback(ctx: Context, param: Parameter, value: bool) -> None:
if not value or ctx.resilient_parsing:
return
nonlocal prog_name
nonlocal version
if prog_name is None:
prog_name = ctx.find_root().info_name
if version is None and package_name is not None:
metadata: t.Optional[types.ModuleType]
try:
from importlib import metadata # type: ignore
except ImportError:
# Python < 3.8
import importlib_metadata as metadata # type: ignore
try:
version = metadata.version(package_name) # type: ignore
except metadata.PackageNotFoundError: # type: ignore
raise RuntimeError(
f"{package_name!r} is not installed. Try passing"
" 'package_name' instead."
) from None
if version is None:
raise RuntimeError(
f"Could not determine the version for {package_name!r} automatically."
)
echo(
t.cast(str, message)
% {"prog": prog_name, "package": package_name, "version": version},
color=ctx.color,
)
ctx.exit()
if not param_decls:
param_decls = ("--version",)
kwargs.setdefault("is_flag", True)
kwargs.setdefault("expose_value", False)
kwargs.setdefault("is_eager", True)
kwargs.setdefault("help", _("Show the version and exit."))
kwargs["callback"] = callback
return option(*param_decls, **kwargs)
def help_option(*param_decls, **attrs): def help_option(*param_decls: str, **kwargs: t.Any) -> t.Callable[[FC], FC]:
"""Adds a ``--help`` option which immediately ends the program """Add a ``--help`` option which immediately prints the help page
printing out the help page. This is usually unnecessary to add as and exits the program.
this is added by default to all commands unless suppressed.
Like :func:`version_option`, this is implemented as eager option that This is usually unnecessary, as the ``--help`` option is added to
prints in the callback and exits. each command automatically unless ``add_help_option=False`` is
passed.
All arguments are forwarded to :func:`option`. :param param_decls: One or more option names. Defaults to the single
value ``"--help"``.
:param kwargs: Extra arguments are passed to :func:`option`.
""" """
def decorator(f):
def callback(ctx, param, value):
if value and not ctx.resilient_parsing:
echo(ctx.get_help(), color=ctx.color)
ctx.exit()
attrs.setdefault('is_flag', True)
attrs.setdefault('expose_value', False)
attrs.setdefault('help', 'Show this message and exit.')
attrs.setdefault('is_eager', True)
attrs['callback'] = callback
return option(*(param_decls or ('--help',)), **attrs)(f)
return decorator
def callback(ctx: Context, param: Parameter, value: bool) -> None:
if not value or ctx.resilient_parsing:
return
# Circular dependencies between core and decorators echo(ctx.get_help(), color=ctx.color)
from .core import Command, Group, Argument, Option ctx.exit()
if not param_decls:
param_decls = ("--help",)
kwargs.setdefault("is_flag", True)
kwargs.setdefault("expose_value", False)
kwargs.setdefault("is_eager", True)
kwargs.setdefault("help", _("Show this message and exit."))
kwargs["callback"] = callback
return option(*param_decls, **kwargs)

View file

@ -1,43 +1,46 @@
from ._compat import PY2, filename_to_ui, get_text_stderr import os
import typing as t
from gettext import gettext as _
from gettext import ngettext
from ._compat import get_text_stderr
from .utils import echo from .utils import echo
if t.TYPE_CHECKING:
from .core import Context
from .core import Parameter
def _join_param_hints(
param_hint: t.Optional[t.Union[t.Sequence[str], str]]
) -> t.Optional[str]:
if param_hint is not None and not isinstance(param_hint, str):
return " / ".join(repr(x) for x in param_hint)
def _join_param_hints(param_hint):
if isinstance(param_hint, (tuple, list)):
return ' / '.join('"%s"' % x for x in param_hint)
return param_hint return param_hint
class ClickException(Exception): class ClickException(Exception):
"""An exception that Click can handle and show to the user.""" """An exception that Click can handle and show to the user."""
#: The exit code for this exception #: The exit code for this exception.
exit_code = 1 exit_code = 1
def __init__(self, message): def __init__(self, message: str) -> None:
ctor_msg = message super().__init__(message)
if PY2:
if ctor_msg is not None:
ctor_msg = ctor_msg.encode('utf-8')
Exception.__init__(self, ctor_msg)
self.message = message self.message = message
def format_message(self): def format_message(self) -> str:
return self.message return self.message
def __str__(self): def __str__(self) -> str:
return self.message return self.message
if PY2: def show(self, file: t.Optional[t.IO] = None) -> None:
__unicode__ = __str__
def __str__(self):
return self.message.encode('utf-8')
def show(self, file=None):
if file is None: if file is None:
file = get_text_stderr() file = get_text_stderr()
echo('Error: %s' % self.format_message(), file=file)
echo(_("Error: {message}").format(message=self.format_message()), file=file)
class UsageError(ClickException): class UsageError(ClickException):
@ -48,26 +51,35 @@ class UsageError(ClickException):
:param ctx: optionally the context that caused this error. Click will :param ctx: optionally the context that caused this error. Click will
fill in the context automatically in some situations. fill in the context automatically in some situations.
""" """
exit_code = 2 exit_code = 2
def __init__(self, message, ctx=None): def __init__(self, message: str, ctx: t.Optional["Context"] = None) -> None:
ClickException.__init__(self, message) super().__init__(message)
self.ctx = ctx self.ctx = ctx
self.cmd = self.ctx and self.ctx.command or None self.cmd = self.ctx.command if self.ctx else None
def show(self, file=None): def show(self, file: t.Optional[t.IO] = None) -> None:
if file is None: if file is None:
file = get_text_stderr() file = get_text_stderr()
color = None color = None
hint = '' hint = ""
if (self.cmd is not None and if (
self.cmd.get_help_option(self.ctx) is not None): self.ctx is not None
hint = ('Try "%s %s" for help.\n' and self.ctx.command.get_help_option(self.ctx) is not None
% (self.ctx.command_path, self.ctx.help_option_names[0])) ):
hint = _("Try '{command} {option}' for help.").format(
command=self.ctx.command_path, option=self.ctx.help_option_names[0]
)
hint = f"{hint}\n"
if self.ctx is not None: if self.ctx is not None:
color = self.ctx.color color = self.ctx.color
echo(self.ctx.get_usage() + '\n%s' % hint, file=file, color=color) echo(f"{self.ctx.get_usage()}\n{hint}", file=file, color=color)
echo('Error: %s' % self.format_message(), file=file, color=color) echo(
_("Error: {message}").format(message=self.format_message()),
file=file,
color=color,
)
class BadParameter(UsageError): class BadParameter(UsageError):
@ -88,22 +100,28 @@ class BadParameter(UsageError):
each item is quoted and separated. each item is quoted and separated.
""" """
def __init__(self, message, ctx=None, param=None, def __init__(
param_hint=None): self,
UsageError.__init__(self, message, ctx) message: str,
ctx: t.Optional["Context"] = None,
param: t.Optional["Parameter"] = None,
param_hint: t.Optional[str] = None,
) -> None:
super().__init__(message, ctx)
self.param = param self.param = param
self.param_hint = param_hint self.param_hint = param_hint
def format_message(self): def format_message(self) -> str:
if self.param_hint is not None: if self.param_hint is not None:
param_hint = self.param_hint param_hint = self.param_hint
elif self.param is not None: elif self.param is not None:
param_hint = self.param.get_error_hint(self.ctx) param_hint = self.param.get_error_hint(self.ctx) # type: ignore
else: else:
return 'Invalid value: %s' % self.message return _("Invalid value: {message}").format(message=self.message)
param_hint = _join_param_hints(param_hint)
return 'Invalid value for %s: %s' % (param_hint, self.message) return _("Invalid value for {param_hint}: {message}").format(
param_hint=_join_param_hints(param_hint), message=self.message
)
class MissingParameter(BadParameter): class MissingParameter(BadParameter):
@ -118,19 +136,27 @@ class MissingParameter(BadParameter):
``'option'`` or ``'argument'``. ``'option'`` or ``'argument'``.
""" """
def __init__(self, message=None, ctx=None, param=None, def __init__(
param_hint=None, param_type=None): self,
BadParameter.__init__(self, message, ctx, param, param_hint) message: t.Optional[str] = None,
ctx: t.Optional["Context"] = None,
param: t.Optional["Parameter"] = None,
param_hint: t.Optional[str] = None,
param_type: t.Optional[str] = None,
) -> None:
super().__init__(message or "", ctx, param, param_hint)
self.param_type = param_type self.param_type = param_type
def format_message(self): def format_message(self) -> str:
if self.param_hint is not None: if self.param_hint is not None:
param_hint = self.param_hint param_hint: t.Optional[str] = self.param_hint
elif self.param is not None: elif self.param is not None:
param_hint = self.param.get_error_hint(self.ctx) param_hint = self.param.get_error_hint(self.ctx) # type: ignore
else: else:
param_hint = None param_hint = None
param_hint = _join_param_hints(param_hint) param_hint = _join_param_hints(param_hint)
param_hint = f" {param_hint}" if param_hint else ""
param_type = self.param_type param_type = self.param_type
if param_type is None and self.param is not None: if param_type is None and self.param is not None:
@ -141,16 +167,30 @@ class MissingParameter(BadParameter):
msg_extra = self.param.type.get_missing_message(self.param) msg_extra = self.param.type.get_missing_message(self.param)
if msg_extra: if msg_extra:
if msg: if msg:
msg += '. ' + msg_extra msg += f". {msg_extra}"
else: else:
msg = msg_extra msg = msg_extra
return 'Missing %s%s%s%s' % ( msg = f" {msg}" if msg else ""
param_type,
param_hint and ' %s' % param_hint or '', # Translate param_type for known types.
msg and '. ' or '.', if param_type == "argument":
msg or '', missing = _("Missing argument")
) elif param_type == "option":
missing = _("Missing option")
elif param_type == "parameter":
missing = _("Missing parameter")
else:
missing = _("Missing {param_type}").format(param_type=param_type)
return f"{missing}{param_hint}.{msg}"
def __str__(self) -> str:
if not self.message:
param_name = self.param.name if self.param else None
return _("Missing parameter: {param_name}").format(param_name=param_name)
else:
return self.message
class NoSuchOption(UsageError): class NoSuchOption(UsageError):
@ -160,23 +200,31 @@ class NoSuchOption(UsageError):
.. versionadded:: 4.0 .. versionadded:: 4.0
""" """
def __init__(self, option_name, message=None, possibilities=None, def __init__(
ctx=None): self,
option_name: str,
message: t.Optional[str] = None,
possibilities: t.Optional[t.Sequence[str]] = None,
ctx: t.Optional["Context"] = None,
) -> None:
if message is None: if message is None:
message = 'no such option: %s' % option_name message = _("No such option: {name}").format(name=option_name)
UsageError.__init__(self, message, ctx)
super().__init__(message, ctx)
self.option_name = option_name self.option_name = option_name
self.possibilities = possibilities self.possibilities = possibilities
def format_message(self): def format_message(self) -> str:
bits = [self.message] if not self.possibilities:
if self.possibilities: return self.message
if len(self.possibilities) == 1:
bits.append('Did you mean %s?' % self.possibilities[0]) possibility_str = ", ".join(sorted(self.possibilities))
else: suggest = ngettext(
possibilities = sorted(self.possibilities) "Did you mean {possibility}?",
bits.append('(Possible options: %s)' % ', '.join(possibilities)) "(Possible options: {possibilities})",
return ' '.join(bits) len(self.possibilities),
).format(possibility=possibility_str, possibilities=possibility_str)
return f"{self.message} {suggest}"
class BadOptionUsage(UsageError): class BadOptionUsage(UsageError):
@ -189,8 +237,10 @@ class BadOptionUsage(UsageError):
:param option_name: the name of the option being used incorrectly. :param option_name: the name of the option being used incorrectly.
""" """
def __init__(self, option_name, message, ctx=None): def __init__(
UsageError.__init__(self, message, ctx) self, option_name: str, message: str, ctx: t.Optional["Context"] = None
) -> None:
super().__init__(message, ctx)
self.option_name = option_name self.option_name = option_name
@ -202,23 +252,22 @@ class BadArgumentUsage(UsageError):
.. versionadded:: 6.0 .. versionadded:: 6.0
""" """
def __init__(self, message, ctx=None):
UsageError.__init__(self, message, ctx)
class FileError(ClickException): class FileError(ClickException):
"""Raised if a file cannot be opened.""" """Raised if a file cannot be opened."""
def __init__(self, filename, hint=None): def __init__(self, filename: str, hint: t.Optional[str] = None) -> None:
ui_filename = filename_to_ui(filename)
if hint is None: if hint is None:
hint = 'unknown error' hint = _("unknown error")
ClickException.__init__(self, hint)
self.ui_filename = ui_filename super().__init__(hint)
self.ui_filename = os.fsdecode(filename)
self.filename = filename self.filename = filename
def format_message(self): def format_message(self) -> str:
return 'Could not open file %s: %s' % (self.ui_filename, self.message) return _("Could not open file {filename!r}: {message}").format(
filename=self.ui_filename, message=self.message
)
class Abort(RuntimeError): class Abort(RuntimeError):
@ -231,5 +280,8 @@ class Exit(RuntimeError):
:param code: the status code to exit with. :param code: the status code to exit with.
""" """
def __init__(self, code=0):
__slots__ = ("exit_code",)
def __init__(self, code: int = 0) -> None:
self.exit_code = code self.exit_code = code

View file

@ -1,29 +1,38 @@
import typing as t
from contextlib import contextmanager from contextlib import contextmanager
from .termui import get_terminal_size from gettext import gettext as _
from .parser import split_opt
from ._compat import term_len
from ._compat import term_len
from .parser import split_opt
# Can force a width. This is used by the test system # Can force a width. This is used by the test system
FORCED_WIDTH = None FORCED_WIDTH: t.Optional[int] = None
def measure_table(rows): def measure_table(rows: t.Iterable[t.Tuple[str, str]]) -> t.Tuple[int, ...]:
widths = {} widths: t.Dict[int, int] = {}
for row in rows: for row in rows:
for idx, col in enumerate(row): for idx, col in enumerate(row):
widths[idx] = max(widths.get(idx, 0), term_len(col)) widths[idx] = max(widths.get(idx, 0), term_len(col))
return tuple(y for x, y in sorted(widths.items())) return tuple(y for x, y in sorted(widths.items()))
def iter_rows(rows, col_count): def iter_rows(
rows: t.Iterable[t.Tuple[str, str]], col_count: int
) -> t.Iterator[t.Tuple[str, ...]]:
for row in rows: for row in rows:
row = tuple(row) yield row + ("",) * (col_count - len(row))
yield row + ('',) * (col_count - len(row))
def wrap_text(text, width=78, initial_indent='', subsequent_indent='', def wrap_text(
preserve_paragraphs=False): text: str,
width: int = 78,
initial_indent: str = "",
subsequent_indent: str = "",
preserve_paragraphs: bool = False,
) -> str:
"""A helper function that intelligently wraps text. By default, it """A helper function that intelligently wraps text. By default, it
assumes that it operates on a single paragraph of text but if the assumes that it operates on a single paragraph of text but if the
`preserve_paragraphs` parameter is provided it will intelligently `preserve_paragraphs` parameter is provided it will intelligently
@ -43,24 +52,28 @@ def wrap_text(text, width=78, initial_indent='', subsequent_indent='',
intelligently handle paragraphs. intelligently handle paragraphs.
""" """
from ._textwrap import TextWrapper from ._textwrap import TextWrapper
text = text.expandtabs() text = text.expandtabs()
wrapper = TextWrapper(width, initial_indent=initial_indent, wrapper = TextWrapper(
subsequent_indent=subsequent_indent, width,
replace_whitespace=False) initial_indent=initial_indent,
subsequent_indent=subsequent_indent,
replace_whitespace=False,
)
if not preserve_paragraphs: if not preserve_paragraphs:
return wrapper.fill(text) return wrapper.fill(text)
p = [] p: t.List[t.Tuple[int, bool, str]] = []
buf = [] buf: t.List[str] = []
indent = None indent = None
def _flush_par(): def _flush_par() -> None:
if not buf: if not buf:
return return
if buf[0].strip() == '\b': if buf[0].strip() == "\b":
p.append((indent or 0, True, '\n'.join(buf[1:]))) p.append((indent or 0, True, "\n".join(buf[1:])))
else: else:
p.append((indent or 0, False, ' '.join(buf))) p.append((indent or 0, False, " ".join(buf)))
del buf[:] del buf[:]
for line in text.splitlines(): for line in text.splitlines():
@ -77,16 +90,16 @@ def wrap_text(text, width=78, initial_indent='', subsequent_indent='',
rv = [] rv = []
for indent, raw, text in p: for indent, raw, text in p:
with wrapper.extra_indent(' ' * indent): with wrapper.extra_indent(" " * indent):
if raw: if raw:
rv.append(wrapper.indent_only(text)) rv.append(wrapper.indent_only(text))
else: else:
rv.append(wrapper.fill(text)) rv.append(wrapper.fill(text))
return '\n\n'.join(rv) return "\n\n".join(rv)
class HelpFormatter(object): class HelpFormatter:
"""This class helps with formatting text-based help pages. It's """This class helps with formatting text-based help pages. It's
usually just needed for very special internal cases, but it's also usually just needed for very special internal cases, but it's also
exposed so that developers can write their own fancy outputs. exposed so that developers can write their own fancy outputs.
@ -98,79 +111,108 @@ class HelpFormatter(object):
width clamped to a maximum of 78. width clamped to a maximum of 78.
""" """
def __init__(self, indent_increment=2, width=None, max_width=None): def __init__(
self,
indent_increment: int = 2,
width: t.Optional[int] = None,
max_width: t.Optional[int] = None,
) -> None:
import shutil
self.indent_increment = indent_increment self.indent_increment = indent_increment
if max_width is None: if max_width is None:
max_width = 80 max_width = 80
if width is None: if width is None:
width = FORCED_WIDTH width = FORCED_WIDTH
if width is None: if width is None:
width = max(min(get_terminal_size()[0], max_width) - 2, 50) width = max(min(shutil.get_terminal_size().columns, max_width) - 2, 50)
self.width = width self.width = width
self.current_indent = 0 self.current_indent = 0
self.buffer = [] self.buffer: t.List[str] = []
def write(self, string): def write(self, string: str) -> None:
"""Writes a unicode string into the internal buffer.""" """Writes a unicode string into the internal buffer."""
self.buffer.append(string) self.buffer.append(string)
def indent(self): def indent(self) -> None:
"""Increases the indentation.""" """Increases the indentation."""
self.current_indent += self.indent_increment self.current_indent += self.indent_increment
def dedent(self): def dedent(self) -> None:
"""Decreases the indentation.""" """Decreases the indentation."""
self.current_indent -= self.indent_increment self.current_indent -= self.indent_increment
def write_usage(self, prog, args='', prefix='Usage: '): def write_usage(
self, prog: str, args: str = "", prefix: t.Optional[str] = None
) -> None:
"""Writes a usage line into the buffer. """Writes a usage line into the buffer.
:param prog: the program name. :param prog: the program name.
:param args: whitespace separated list of arguments. :param args: whitespace separated list of arguments.
:param prefix: the prefix for the first line. :param prefix: The prefix for the first line. Defaults to
``"Usage: "``.
""" """
usage_prefix = '%*s%s ' % (self.current_indent, prefix, prog) if prefix is None:
prefix = f"{_('Usage:')} "
usage_prefix = f"{prefix:>{self.current_indent}}{prog} "
text_width = self.width - self.current_indent text_width = self.width - self.current_indent
if text_width >= (term_len(usage_prefix) + 20): if text_width >= (term_len(usage_prefix) + 20):
# The arguments will fit to the right of the prefix. # The arguments will fit to the right of the prefix.
indent = ' ' * term_len(usage_prefix) indent = " " * term_len(usage_prefix)
self.write(wrap_text(args, text_width, self.write(
initial_indent=usage_prefix, wrap_text(
subsequent_indent=indent)) args,
text_width,
initial_indent=usage_prefix,
subsequent_indent=indent,
)
)
else: else:
# The prefix is too long, put the arguments on the next line. # The prefix is too long, put the arguments on the next line.
self.write(usage_prefix) self.write(usage_prefix)
self.write('\n') self.write("\n")
indent = ' ' * (max(self.current_indent, term_len(prefix)) + 4) indent = " " * (max(self.current_indent, term_len(prefix)) + 4)
self.write(wrap_text(args, text_width, self.write(
initial_indent=indent, wrap_text(
subsequent_indent=indent)) args, text_width, initial_indent=indent, subsequent_indent=indent
)
)
self.write('\n') self.write("\n")
def write_heading(self, heading): def write_heading(self, heading: str) -> None:
"""Writes a heading into the buffer.""" """Writes a heading into the buffer."""
self.write('%*s%s:\n' % (self.current_indent, '', heading)) self.write(f"{'':>{self.current_indent}}{heading}:\n")
def write_paragraph(self): def write_paragraph(self) -> None:
"""Writes a paragraph into the buffer.""" """Writes a paragraph into the buffer."""
if self.buffer: if self.buffer:
self.write('\n') self.write("\n")
def write_text(self, text): def write_text(self, text: str) -> None:
"""Writes re-indented text into the buffer. This rewraps and """Writes re-indented text into the buffer. This rewraps and
preserves paragraphs. preserves paragraphs.
""" """
text_width = max(self.width - self.current_indent, 11) indent = " " * self.current_indent
indent = ' ' * self.current_indent self.write(
self.write(wrap_text(text, text_width, wrap_text(
initial_indent=indent, text,
subsequent_indent=indent, self.width,
preserve_paragraphs=True)) initial_indent=indent,
self.write('\n') subsequent_indent=indent,
preserve_paragraphs=True,
)
)
self.write("\n")
def write_dl(self, rows, col_max=30, col_spacing=2): def write_dl(
self,
rows: t.Sequence[t.Tuple[str, str]],
col_max: int = 30,
col_spacing: int = 2,
) -> None:
"""Writes a definition list into the buffer. This is how options """Writes a definition list into the buffer. This is how options
and commands are usually formatted. and commands are usually formatted.
@ -182,33 +224,35 @@ class HelpFormatter(object):
rows = list(rows) rows = list(rows)
widths = measure_table(rows) widths = measure_table(rows)
if len(widths) != 2: if len(widths) != 2:
raise TypeError('Expected two columns for definition list') raise TypeError("Expected two columns for definition list")
first_col = min(widths[0], col_max) + col_spacing first_col = min(widths[0], col_max) + col_spacing
for first, second in iter_rows(rows, len(widths)): for first, second in iter_rows(rows, len(widths)):
self.write('%*s%s' % (self.current_indent, '', first)) self.write(f"{'':>{self.current_indent}}{first}")
if not second: if not second:
self.write('\n') self.write("\n")
continue continue
if term_len(first) <= first_col - col_spacing: if term_len(first) <= first_col - col_spacing:
self.write(' ' * (first_col - term_len(first))) self.write(" " * (first_col - term_len(first)))
else: else:
self.write('\n') self.write("\n")
self.write(' ' * (first_col + self.current_indent)) self.write(" " * (first_col + self.current_indent))
text_width = max(self.width - first_col - 2, 10) text_width = max(self.width - first_col - 2, 10)
lines = iter(wrap_text(second, text_width).splitlines()) wrapped_text = wrap_text(second, text_width, preserve_paragraphs=True)
lines = wrapped_text.splitlines()
if lines: if lines:
self.write(next(lines) + '\n') self.write(f"{lines[0]}\n")
for line in lines:
self.write('%*s%s\n' % ( for line in lines[1:]:
first_col + self.current_indent, '', line)) self.write(f"{'':>{first_col + self.current_indent}}{line}\n")
else: else:
self.write('\n') self.write("\n")
@contextmanager @contextmanager
def section(self, name): def section(self, name: str) -> t.Iterator[None]:
"""Helpful context manager that writes a paragraph, a heading, """Helpful context manager that writes a paragraph, a heading,
and the indents. and the indents.
@ -223,7 +267,7 @@ class HelpFormatter(object):
self.dedent() self.dedent()
@contextmanager @contextmanager
def indentation(self): def indentation(self) -> t.Iterator[None]:
"""A context manager that increases the indentation.""" """A context manager that increases the indentation."""
self.indent() self.indent()
try: try:
@ -231,12 +275,12 @@ class HelpFormatter(object):
finally: finally:
self.dedent() self.dedent()
def getvalue(self): def getvalue(self) -> str:
"""Returns the buffer contents.""" """Returns the buffer contents."""
return ''.join(self.buffer) return "".join(self.buffer)
def join_options(options): def join_options(options: t.Sequence[str]) -> t.Tuple[str, bool]:
"""Given a list of option strings this joins them in the most appropriate """Given a list of option strings this joins them in the most appropriate
way and returns them in the form ``(formatted_string, way and returns them in the form ``(formatted_string,
any_prefix_is_slash)`` where the second item in the tuple is a flag that any_prefix_is_slash)`` where the second item in the tuple is a flag that
@ -244,13 +288,14 @@ def join_options(options):
""" """
rv = [] rv = []
any_prefix_is_slash = False any_prefix_is_slash = False
for opt in options: for opt in options:
prefix = split_opt(opt)[0] prefix = split_opt(opt)[0]
if prefix == '/':
if prefix == "/":
any_prefix_is_slash = True any_prefix_is_slash = True
rv.append((len(prefix), opt)) rv.append((len(prefix), opt))
rv.sort(key=lambda x: x[0]) rv.sort(key=lambda x: x[0])
return ", ".join(x[1] for x in rv), any_prefix_is_slash
rv = ', '.join(x[1] for x in rv)
return rv, any_prefix_is_slash

View file

@ -1,10 +1,24 @@
import typing as t
from threading import local from threading import local
if t.TYPE_CHECKING:
import typing_extensions as te
from .core import Context
_local = local() _local = local()
def get_current_context(silent=False): @t.overload
def get_current_context(silent: "te.Literal[False]" = False) -> "Context":
...
@t.overload
def get_current_context(silent: bool = ...) -> t.Optional["Context"]:
...
def get_current_context(silent: bool = False) -> t.Optional["Context"]:
"""Returns the current click context. This can be used as a way to """Returns the current click context. This can be used as a way to
access the current context object from anywhere. This is a more implicit access the current context object from anywhere. This is a more implicit
alternative to the :func:`pass_context` decorator. This function is alternative to the :func:`pass_context` decorator. This function is
@ -15,34 +29,40 @@ def get_current_context(silent=False):
.. versionadded:: 5.0 .. versionadded:: 5.0
:param silent: is set to `True` the return value is `None` if no context :param silent: if set to `True` the return value is `None` if no context
is available. The default behavior is to raise a is available. The default behavior is to raise a
:exc:`RuntimeError`. :exc:`RuntimeError`.
""" """
try: try:
return getattr(_local, 'stack')[-1] return t.cast("Context", _local.stack[-1])
except (AttributeError, IndexError): except (AttributeError, IndexError) as e:
if not silent: if not silent:
raise RuntimeError('There is no active click context.') raise RuntimeError("There is no active click context.") from e
return None
def push_context(ctx): def push_context(ctx: "Context") -> None:
"""Pushes a new context to the current stack.""" """Pushes a new context to the current stack."""
_local.__dict__.setdefault('stack', []).append(ctx) _local.__dict__.setdefault("stack", []).append(ctx)
def pop_context(): def pop_context() -> None:
"""Removes the top level from the stack.""" """Removes the top level from the stack."""
_local.stack.pop() _local.stack.pop()
def resolve_color_default(color=None): def resolve_color_default(color: t.Optional[bool] = None) -> t.Optional[bool]:
""""Internal helper to get the default value of the color flag. If a """Internal helper to get the default value of the color flag. If a
value is passed it's returned unchanged, otherwise it's looked up from value is passed it's returned unchanged, otherwise it's looked up from
the current context. the current context.
""" """
if color is not None: if color is not None:
return color return color
ctx = get_current_context(silent=True) ctx = get_current_context(silent=True)
if ctx is not None: if ctx is not None:
return ctx.color return ctx.color
return None

View file

@ -1,8 +1,4 @@
# -*- coding: utf-8 -*-
""" """
click.parser
~~~~~~~~~~~~
This module started out as largely a copy paste from the stdlib's This module started out as largely a copy paste from the stdlib's
optparse module with the features removed that we do not need from optparse module with the features removed that we do not need from
optparse because we implement them in Click on a higher level (for optparse because we implement them in Click on a higher level (for
@ -14,15 +10,45 @@ The reason this is a different module and not optparse from the stdlib
is that there are differences in 2.x and 3.x about the error messages is that there are differences in 2.x and 3.x about the error messages
generated and optparse in the stdlib uses gettext for no good reason generated and optparse in the stdlib uses gettext for no good reason
and might cause us issues. and might cause us issues.
Click uses parts of optparse written by Gregory P. Ward and maintained
by the Python Software Foundation. This is limited to code in parser.py.
Copyright 2001-2006 Gregory P. Ward. All rights reserved.
Copyright 2002-2006 Python Software Foundation. All rights reserved.
""" """
# This code uses parts of optparse written by Gregory P. Ward and
import re # maintained by the Python Software Foundation.
# Copyright 2001-2006 Gregory P. Ward
# Copyright 2002-2006 Python Software Foundation
import typing as t
from collections import deque from collections import deque
from .exceptions import UsageError, NoSuchOption, BadOptionUsage, \ from gettext import gettext as _
BadArgumentUsage from gettext import ngettext
from .exceptions import BadArgumentUsage
from .exceptions import BadOptionUsage
from .exceptions import NoSuchOption
from .exceptions import UsageError
if t.TYPE_CHECKING:
import typing_extensions as te
from .core import Argument as CoreArgument
from .core import Context
from .core import Option as CoreOption
from .core import Parameter as CoreParameter
V = t.TypeVar("V")
# Sentinel value that indicates an option was passed as a flag without a
# value but is not a flag option. Option.consume_value uses this to
# prompt or use the flag_value.
_flag_needs_value = object()
def _unpack_args(args, nargs_spec): def _unpack_args(
args: t.Sequence[str], nargs_spec: t.Sequence[int]
) -> t.Tuple[t.Sequence[t.Union[str, t.Sequence[t.Optional[str]], None]], t.List[str]]:
"""Given an iterable of arguments and an iterable of nargs specifications, """Given an iterable of arguments and an iterable of nargs specifications,
it returns a tuple with all the unpacked arguments at the first index it returns a tuple with all the unpacked arguments at the first index
and all remaining arguments as the second. and all remaining arguments as the second.
@ -34,10 +60,10 @@ def _unpack_args(args, nargs_spec):
""" """
args = deque(args) args = deque(args)
nargs_spec = deque(nargs_spec) nargs_spec = deque(nargs_spec)
rv = [] rv: t.List[t.Union[str, t.Tuple[t.Optional[str], ...], None]] = []
spos = None spos: t.Optional[int] = None
def _fetch(c): def _fetch(c: "te.Deque[V]") -> t.Optional[V]:
try: try:
if spos is None: if spos is None:
return c.popleft() return c.popleft()
@ -48,18 +74,25 @@ def _unpack_args(args, nargs_spec):
while nargs_spec: while nargs_spec:
nargs = _fetch(nargs_spec) nargs = _fetch(nargs_spec)
if nargs is None:
continue
if nargs == 1: if nargs == 1:
rv.append(_fetch(args)) rv.append(_fetch(args))
elif nargs > 1: elif nargs > 1:
x = [_fetch(args) for _ in range(nargs)] x = [_fetch(args) for _ in range(nargs)]
# If we're reversed, we're pulling in the arguments in reverse, # If we're reversed, we're pulling in the arguments in reverse,
# so we need to turn them around. # so we need to turn them around.
if spos is not None: if spos is not None:
x.reverse() x.reverse()
rv.append(tuple(x)) rv.append(tuple(x))
elif nargs < 0: elif nargs < 0:
if spos is not None: if spos is not None:
raise TypeError('Cannot have two nargs < 0') raise TypeError("Cannot have two nargs < 0")
spos = len(rv) spos = len(rv)
rv.append(None) rv.append(None)
@ -68,54 +101,71 @@ def _unpack_args(args, nargs_spec):
if spos is not None: if spos is not None:
rv[spos] = tuple(args) rv[spos] = tuple(args)
args = [] args = []
rv[spos + 1:] = reversed(rv[spos + 1:]) rv[spos + 1 :] = reversed(rv[spos + 1 :])
return tuple(rv), list(args) return tuple(rv), list(args)
def _error_opt_args(nargs, opt): def split_opt(opt: str) -> t.Tuple[str, str]:
if nargs == 1:
raise BadOptionUsage(opt, '%s option requires an argument' % opt)
raise BadOptionUsage(opt, '%s option requires %d arguments' % (opt, nargs))
def split_opt(opt):
first = opt[:1] first = opt[:1]
if first.isalnum(): if first.isalnum():
return '', opt return "", opt
if opt[1:2] == first: if opt[1:2] == first:
return opt[:2], opt[2:] return opt[:2], opt[2:]
return first, opt[1:] return first, opt[1:]
def normalize_opt(opt, ctx): def normalize_opt(opt: str, ctx: t.Optional["Context"]) -> str:
if ctx is None or ctx.token_normalize_func is None: if ctx is None or ctx.token_normalize_func is None:
return opt return opt
prefix, opt = split_opt(opt) prefix, opt = split_opt(opt)
return prefix + ctx.token_normalize_func(opt) return f"{prefix}{ctx.token_normalize_func(opt)}"
def split_arg_string(string): def split_arg_string(string: str) -> t.List[str]:
"""Given an argument string this attempts to split it into small parts.""" """Split an argument string as with :func:`shlex.split`, but don't
rv = [] fail if the string is incomplete. Ignores a missing closing quote or
for match in re.finditer(r"('([^'\\]*(?:\\.[^'\\]*)*)'" incomplete escape sequence and uses the partial token as-is.
r'|"([^"\\]*(?:\\.[^"\\]*)*)"'
r'|\S+)\s*', string, re.S): .. code-block:: python
arg = match.group().strip()
if arg[:1] == arg[-1:] and arg[:1] in '"\'': split_arg_string("example 'my file")
arg = arg[1:-1].encode('ascii', 'backslashreplace') \ ["example", "my file"]
.decode('unicode-escape')
try: split_arg_string("example my\\")
arg = type(string)(arg) ["example", "my"]
except UnicodeError:
pass :param string: String to split.
rv.append(arg) """
return rv import shlex
lex = shlex.shlex(string, posix=True)
lex.whitespace_split = True
lex.commenters = ""
out = []
try:
for token in lex:
out.append(token)
except ValueError:
# Raised when end-of-string is reached in an invalid state. Use
# the partial token as-is. The quote or escape character is in
# lex.state, not lex.token.
out.append(lex.token)
return out
class Option(object): class Option:
def __init__(
def __init__(self, opts, dest, action=None, nargs=1, const=None, obj=None): self,
obj: "CoreOption",
opts: t.Sequence[str],
dest: t.Optional[str],
action: t.Optional[str] = None,
nargs: int = 1,
const: t.Optional[t.Any] = None,
):
self._short_opts = [] self._short_opts = []
self._long_opts = [] self._long_opts = []
self.prefixes = set() self.prefixes = set()
@ -123,8 +173,7 @@ class Option(object):
for opt in opts: for opt in opts:
prefix, value = split_opt(opt) prefix, value = split_opt(opt)
if not prefix: if not prefix:
raise ValueError('Invalid start character for option (%s)' raise ValueError(f"Invalid start character for option ({opt})")
% opt)
self.prefixes.add(prefix[0]) self.prefixes.add(prefix[0])
if len(prefix) == 1 and len(value) == 1: if len(prefix) == 1 and len(value) == 1:
self._short_opts.append(opt) self._short_opts.append(opt)
@ -133,7 +182,7 @@ class Option(object):
self.prefixes.add(prefix) self.prefixes.add(prefix)
if action is None: if action is None:
action = 'store' action = "store"
self.dest = dest self.dest = dest
self.action = action self.action = action
@ -142,54 +191,66 @@ class Option(object):
self.obj = obj self.obj = obj
@property @property
def takes_value(self): def takes_value(self) -> bool:
return self.action in ('store', 'append') return self.action in ("store", "append")
def process(self, value, state): def process(self, value: str, state: "ParsingState") -> None:
if self.action == 'store': if self.action == "store":
state.opts[self.dest] = value state.opts[self.dest] = value # type: ignore
elif self.action == 'store_const': elif self.action == "store_const":
state.opts[self.dest] = self.const state.opts[self.dest] = self.const # type: ignore
elif self.action == 'append': elif self.action == "append":
state.opts.setdefault(self.dest, []).append(value) state.opts.setdefault(self.dest, []).append(value) # type: ignore
elif self.action == 'append_const': elif self.action == "append_const":
state.opts.setdefault(self.dest, []).append(self.const) state.opts.setdefault(self.dest, []).append(self.const) # type: ignore
elif self.action == 'count': elif self.action == "count":
state.opts[self.dest] = state.opts.get(self.dest, 0) + 1 state.opts[self.dest] = state.opts.get(self.dest, 0) + 1 # type: ignore
else: else:
raise ValueError('unknown action %r' % self.action) raise ValueError(f"unknown action '{self.action}'")
state.order.append(self.obj) state.order.append(self.obj)
class Argument(object): class Argument:
def __init__(self, obj: "CoreArgument", dest: t.Optional[str], nargs: int = 1):
def __init__(self, dest, nargs=1, obj=None):
self.dest = dest self.dest = dest
self.nargs = nargs self.nargs = nargs
self.obj = obj self.obj = obj
def process(self, value, state): def process(
self,
value: t.Union[t.Optional[str], t.Sequence[t.Optional[str]]],
state: "ParsingState",
) -> None:
if self.nargs > 1: if self.nargs > 1:
assert value is not None
holes = sum(1 for x in value if x is None) holes = sum(1 for x in value if x is None)
if holes == len(value): if holes == len(value):
value = None value = None
elif holes != 0: elif holes != 0:
raise BadArgumentUsage('argument %s takes %d values' raise BadArgumentUsage(
% (self.dest, self.nargs)) _("Argument {name!r} takes {nargs} values.").format(
state.opts[self.dest] = value name=self.dest, nargs=self.nargs
)
)
if self.nargs == -1 and self.obj.envvar is not None and value == ():
# Replace empty tuple with None so that a value from the
# environment may be tried.
value = None
state.opts[self.dest] = value # type: ignore
state.order.append(self.obj) state.order.append(self.obj)
class ParsingState(object): class ParsingState:
def __init__(self, rargs: t.List[str]) -> None:
def __init__(self, rargs): self.opts: t.Dict[str, t.Any] = {}
self.opts = {} self.largs: t.List[str] = []
self.largs = []
self.rargs = rargs self.rargs = rargs
self.order = [] self.order: t.List["CoreParameter"] = []
class OptionParser(object): class OptionParser:
"""The option parser is an internal class that is ultimately used to """The option parser is an internal class that is ultimately used to
parse options and arguments. It's modelled after optparse and brings parse options and arguments. It's modelled after optparse and brings
a similar but vastly simplified API. It should generally not be used a similar but vastly simplified API. It should generally not be used
@ -203,7 +264,7 @@ class OptionParser(object):
should go with. should go with.
""" """
def __init__(self, ctx=None): def __init__(self, ctx: t.Optional["Context"] = None) -> None:
#: The :class:`~click.Context` for this parser. This might be #: The :class:`~click.Context` for this parser. This might be
#: `None` for some advanced use cases. #: `None` for some advanced use cases.
self.ctx = ctx self.ctx = ctx
@ -217,46 +278,54 @@ class OptionParser(object):
#: second mode where it will ignore it and continue processing #: second mode where it will ignore it and continue processing
#: after shifting all the unknown options into the resulting args. #: after shifting all the unknown options into the resulting args.
self.ignore_unknown_options = False self.ignore_unknown_options = False
if ctx is not None: if ctx is not None:
self.allow_interspersed_args = ctx.allow_interspersed_args self.allow_interspersed_args = ctx.allow_interspersed_args
self.ignore_unknown_options = ctx.ignore_unknown_options self.ignore_unknown_options = ctx.ignore_unknown_options
self._short_opt = {}
self._long_opt = {}
self._opt_prefixes = set(['-', '--'])
self._args = []
def add_option(self, opts, dest, action=None, nargs=1, const=None, self._short_opt: t.Dict[str, Option] = {}
obj=None): self._long_opt: t.Dict[str, Option] = {}
self._opt_prefixes = {"-", "--"}
self._args: t.List[Argument] = []
def add_option(
self,
obj: "CoreOption",
opts: t.Sequence[str],
dest: t.Optional[str],
action: t.Optional[str] = None,
nargs: int = 1,
const: t.Optional[t.Any] = None,
) -> None:
"""Adds a new option named `dest` to the parser. The destination """Adds a new option named `dest` to the parser. The destination
is not inferred (unlike with optparse) and needs to be explicitly is not inferred (unlike with optparse) and needs to be explicitly
provided. Action can be any of ``store``, ``store_const``, provided. Action can be any of ``store``, ``store_const``,
``append``, ``appnd_const`` or ``count``. ``append``, ``append_const`` or ``count``.
The `obj` can be used to identify the option in the order list The `obj` can be used to identify the option in the order list
that is returned from the parser. that is returned from the parser.
""" """
if obj is None:
obj = dest
opts = [normalize_opt(opt, self.ctx) for opt in opts] opts = [normalize_opt(opt, self.ctx) for opt in opts]
option = Option(opts, dest, action=action, nargs=nargs, option = Option(obj, opts, dest, action=action, nargs=nargs, const=const)
const=const, obj=obj)
self._opt_prefixes.update(option.prefixes) self._opt_prefixes.update(option.prefixes)
for opt in option._short_opts: for opt in option._short_opts:
self._short_opt[opt] = option self._short_opt[opt] = option
for opt in option._long_opts: for opt in option._long_opts:
self._long_opt[opt] = option self._long_opt[opt] = option
def add_argument(self, dest, nargs=1, obj=None): def add_argument(
self, obj: "CoreArgument", dest: t.Optional[str], nargs: int = 1
) -> None:
"""Adds a positional argument named `dest` to the parser. """Adds a positional argument named `dest` to the parser.
The `obj` can be used to identify the option in the order list The `obj` can be used to identify the option in the order list
that is returned from the parser. that is returned from the parser.
""" """
if obj is None: self._args.append(Argument(obj, dest=dest, nargs=nargs))
obj = dest
self._args.append(Argument(dest=dest, nargs=nargs, obj=obj))
def parse_args(self, args): def parse_args(
self, args: t.List[str]
) -> t.Tuple[t.Dict[str, t.Any], t.List[str], t.List["CoreParameter"]]:
"""Parses positional arguments and returns ``(values, args, order)`` """Parses positional arguments and returns ``(values, args, order)``
for the parsed options and arguments as well as the leftover for the parsed options and arguments as well as the leftover
arguments if there are any. The order is a list of objects as they arguments if there are any. The order is a list of objects as they
@ -272,9 +341,10 @@ class OptionParser(object):
raise raise
return state.opts, state.largs, state.order return state.opts, state.largs, state.order
def _process_args_for_args(self, state): def _process_args_for_args(self, state: ParsingState) -> None:
pargs, args = _unpack_args(state.largs + state.rargs, pargs, args = _unpack_args(
[x.nargs for x in self._args]) state.largs + state.rargs, [x.nargs for x in self._args]
)
for idx, arg in enumerate(self._args): for idx, arg in enumerate(self._args):
arg.process(pargs[idx], state) arg.process(pargs[idx], state)
@ -282,13 +352,13 @@ class OptionParser(object):
state.largs = args state.largs = args
state.rargs = [] state.rargs = []
def _process_args_for_options(self, state): def _process_args_for_options(self, state: ParsingState) -> None:
while state.rargs: while state.rargs:
arg = state.rargs.pop(0) arg = state.rargs.pop(0)
arglen = len(arg) arglen = len(arg)
# Double dashes always handled explicitly regardless of what # Double dashes always handled explicitly regardless of what
# prefixes are valid. # prefixes are valid.
if arg == '--': if arg == "--":
return return
elif arg[:1] in self._opt_prefixes and arglen > 1: elif arg[:1] in self._opt_prefixes and arglen > 1:
self._process_opts(arg, state) self._process_opts(arg, state)
@ -318,10 +388,13 @@ class OptionParser(object):
# *empty* -- still a subset of [arg0, ..., arg(i-1)], but # *empty* -- still a subset of [arg0, ..., arg(i-1)], but
# not a very interesting subset! # not a very interesting subset!
def _match_long_opt(self, opt, explicit_value, state): def _match_long_opt(
self, opt: str, explicit_value: t.Optional[str], state: ParsingState
) -> None:
if opt not in self._long_opt: if opt not in self._long_opt:
possibilities = [word for word in self._long_opt from difflib import get_close_matches
if word.startswith(opt)]
possibilities = get_close_matches(opt, self._long_opt)
raise NoSuchOption(opt, possibilities=possibilities, ctx=self.ctx) raise NoSuchOption(opt, possibilities=possibilities, ctx=self.ctx)
option = self._long_opt[opt] option = self._long_opt[opt]
@ -333,31 +406,26 @@ class OptionParser(object):
if explicit_value is not None: if explicit_value is not None:
state.rargs.insert(0, explicit_value) state.rargs.insert(0, explicit_value)
nargs = option.nargs value = self._get_value_from_state(opt, option, state)
if len(state.rargs) < nargs:
_error_opt_args(nargs, opt)
elif nargs == 1:
value = state.rargs.pop(0)
else:
value = tuple(state.rargs[:nargs])
del state.rargs[:nargs]
elif explicit_value is not None: elif explicit_value is not None:
raise BadOptionUsage(opt, '%s option does not take a value' % opt) raise BadOptionUsage(
opt, _("Option {name!r} does not take a value.").format(name=opt)
)
else: else:
value = None value = None
option.process(value, state) option.process(value, state)
def _match_short_opt(self, arg, state): def _match_short_opt(self, arg: str, state: ParsingState) -> None:
stop = False stop = False
i = 1 i = 1
prefix = arg[0] prefix = arg[0]
unknown_options = [] unknown_options = []
for ch in arg[1:]: for ch in arg[1:]:
opt = normalize_opt(prefix + ch, self.ctx) opt = normalize_opt(f"{prefix}{ch}", self.ctx)
option = self._short_opt.get(opt) option = self._short_opt.get(opt)
i += 1 i += 1
@ -373,14 +441,7 @@ class OptionParser(object):
state.rargs.insert(0, arg[i:]) state.rargs.insert(0, arg[i:])
stop = True stop = True
nargs = option.nargs value = self._get_value_from_state(opt, option, state)
if len(state.rargs) < nargs:
_error_opt_args(nargs, opt)
elif nargs == 1:
value = state.rargs.pop(0)
else:
value = tuple(state.rargs[:nargs])
del state.rargs[:nargs]
else: else:
value = None value = None
@ -395,15 +456,53 @@ class OptionParser(object):
# to the state as new larg. This way there is basic combinatorics # to the state as new larg. This way there is basic combinatorics
# that can be achieved while still ignoring unknown arguments. # that can be achieved while still ignoring unknown arguments.
if self.ignore_unknown_options and unknown_options: if self.ignore_unknown_options and unknown_options:
state.largs.append(prefix + ''.join(unknown_options)) state.largs.append(f"{prefix}{''.join(unknown_options)}")
def _process_opts(self, arg, state): def _get_value_from_state(
self, option_name: str, option: Option, state: ParsingState
) -> t.Any:
nargs = option.nargs
if len(state.rargs) < nargs:
if option.obj._flag_needs_value:
# Option allows omitting the value.
value = _flag_needs_value
else:
raise BadOptionUsage(
option_name,
ngettext(
"Option {name!r} requires an argument.",
"Option {name!r} requires {nargs} arguments.",
nargs,
).format(name=option_name, nargs=nargs),
)
elif nargs == 1:
next_rarg = state.rargs[0]
if (
option.obj._flag_needs_value
and isinstance(next_rarg, str)
and next_rarg[:1] in self._opt_prefixes
and len(next_rarg) > 1
):
# The next arg looks like the start of an option, don't
# use it as the value if omitting the value is allowed.
value = _flag_needs_value
else:
value = state.rargs.pop(0)
else:
value = tuple(state.rargs[:nargs])
del state.rargs[:nargs]
return value
def _process_opts(self, arg: str, state: ParsingState) -> None:
explicit_value = None explicit_value = None
# Long option handling happens in two parts. The first part is # Long option handling happens in two parts. The first part is
# supporting explicitly attached values. In any case, we will try # supporting explicitly attached values. In any case, we will try
# to long match the option first. # to long match the option first.
if '=' in arg: if "=" in arg:
long_opt, explicit_value = arg.split('=', 1) long_opt, explicit_value = arg.split("=", 1)
else: else:
long_opt = arg long_opt = arg
norm_long_opt = normalize_opt(long_opt, self.ctx) norm_long_opt = normalize_opt(long_opt, self.ctx)
@ -421,7 +520,10 @@ class OptionParser(object):
# short option code and will instead raise the no option # short option code and will instead raise the no option
# error. # error.
if arg[:2] not in self._opt_prefixes: if arg[:2] not in self._opt_prefixes:
return self._match_short_opt(arg, state) self._match_short_opt(arg, state)
return
if not self.ignore_unknown_options: if not self.ignore_unknown_options:
raise raise
state.largs.append(arg) state.largs.append(arg)

View file

View file

@ -0,0 +1,580 @@
import os
import re
import typing as t
from gettext import gettext as _
from .core import Argument
from .core import BaseCommand
from .core import Context
from .core import MultiCommand
from .core import Option
from .core import Parameter
from .core import ParameterSource
from .parser import split_arg_string
from .utils import echo
def shell_complete(
cli: BaseCommand,
ctx_args: t.Dict[str, t.Any],
prog_name: str,
complete_var: str,
instruction: str,
) -> int:
"""Perform shell completion for the given CLI program.
:param cli: Command being called.
:param ctx_args: Extra arguments to pass to
``cli.make_context``.
:param prog_name: Name of the executable in the shell.
:param complete_var: Name of the environment variable that holds
the completion instruction.
:param instruction: Value of ``complete_var`` with the completion
instruction and shell, in the form ``instruction_shell``.
:return: Status code to exit with.
"""
shell, _, instruction = instruction.partition("_")
comp_cls = get_completion_class(shell)
if comp_cls is None:
return 1
comp = comp_cls(cli, ctx_args, prog_name, complete_var)
if instruction == "source":
echo(comp.source())
return 0
if instruction == "complete":
echo(comp.complete())
return 0
return 1
class CompletionItem:
"""Represents a completion value and metadata about the value. The
default metadata is ``type`` to indicate special shell handling,
and ``help`` if a shell supports showing a help string next to the
value.
Arbitrary parameters can be passed when creating the object, and
accessed using ``item.attr``. If an attribute wasn't passed,
accessing it returns ``None``.
:param value: The completion suggestion.
:param type: Tells the shell script to provide special completion
support for the type. Click uses ``"dir"`` and ``"file"``.
:param help: String shown next to the value if supported.
:param kwargs: Arbitrary metadata. The built-in implementations
don't use this, but custom type completions paired with custom
shell support could use it.
"""
__slots__ = ("value", "type", "help", "_info")
def __init__(
self,
value: t.Any,
type: str = "plain",
help: t.Optional[str] = None,
**kwargs: t.Any,
) -> None:
self.value = value
self.type = type
self.help = help
self._info = kwargs
def __getattr__(self, name: str) -> t.Any:
return self._info.get(name)
# Only Bash >= 4.4 has the nosort option.
_SOURCE_BASH = """\
%(complete_func)s() {
local IFS=$'\\n'
local response
response=$(env COMP_WORDS="${COMP_WORDS[*]}" COMP_CWORD=$COMP_CWORD \
%(complete_var)s=bash_complete $1)
for completion in $response; do
IFS=',' read type value <<< "$completion"
if [[ $type == 'dir' ]]; then
COMPREPLY=()
compopt -o dirnames
elif [[ $type == 'file' ]]; then
COMPREPLY=()
compopt -o default
elif [[ $type == 'plain' ]]; then
COMPREPLY+=($value)
fi
done
return 0
}
%(complete_func)s_setup() {
complete -o nosort -F %(complete_func)s %(prog_name)s
}
%(complete_func)s_setup;
"""
_SOURCE_ZSH = """\
#compdef %(prog_name)s
%(complete_func)s() {
local -a completions
local -a completions_with_descriptions
local -a response
(( ! $+commands[%(prog_name)s] )) && return 1
response=("${(@f)$(env COMP_WORDS="${words[*]}" COMP_CWORD=$((CURRENT-1)) \
%(complete_var)s=zsh_complete %(prog_name)s)}")
for type key descr in ${response}; do
if [[ "$type" == "plain" ]]; then
if [[ "$descr" == "_" ]]; then
completions+=("$key")
else
completions_with_descriptions+=("$key":"$descr")
fi
elif [[ "$type" == "dir" ]]; then
_path_files -/
elif [[ "$type" == "file" ]]; then
_path_files -f
fi
done
if [ -n "$completions_with_descriptions" ]; then
_describe -V unsorted completions_with_descriptions -U
fi
if [ -n "$completions" ]; then
compadd -U -V unsorted -a completions
fi
}
compdef %(complete_func)s %(prog_name)s;
"""
_SOURCE_FISH = """\
function %(complete_func)s;
set -l response;
for value in (env %(complete_var)s=fish_complete COMP_WORDS=(commandline -cp) \
COMP_CWORD=(commandline -t) %(prog_name)s);
set response $response $value;
end;
for completion in $response;
set -l metadata (string split "," $completion);
if test $metadata[1] = "dir";
__fish_complete_directories $metadata[2];
else if test $metadata[1] = "file";
__fish_complete_path $metadata[2];
else if test $metadata[1] = "plain";
echo $metadata[2];
end;
end;
end;
complete --no-files --command %(prog_name)s --arguments \
"(%(complete_func)s)";
"""
class ShellComplete:
"""Base class for providing shell completion support. A subclass for
a given shell will override attributes and methods to implement the
completion instructions (``source`` and ``complete``).
:param cli: Command being called.
:param prog_name: Name of the executable in the shell.
:param complete_var: Name of the environment variable that holds
the completion instruction.
.. versionadded:: 8.0
"""
name: t.ClassVar[str]
"""Name to register the shell as with :func:`add_completion_class`.
This is used in completion instructions (``{name}_source`` and
``{name}_complete``).
"""
source_template: t.ClassVar[str]
"""Completion script template formatted by :meth:`source`. This must
be provided by subclasses.
"""
def __init__(
self,
cli: BaseCommand,
ctx_args: t.Dict[str, t.Any],
prog_name: str,
complete_var: str,
) -> None:
self.cli = cli
self.ctx_args = ctx_args
self.prog_name = prog_name
self.complete_var = complete_var
@property
def func_name(self) -> str:
"""The name of the shell function defined by the completion
script.
"""
safe_name = re.sub(r"\W*", "", self.prog_name.replace("-", "_"), re.ASCII)
return f"_{safe_name}_completion"
def source_vars(self) -> t.Dict[str, t.Any]:
"""Vars for formatting :attr:`source_template`.
By default this provides ``complete_func``, ``complete_var``,
and ``prog_name``.
"""
return {
"complete_func": self.func_name,
"complete_var": self.complete_var,
"prog_name": self.prog_name,
}
def source(self) -> str:
"""Produce the shell script that defines the completion
function. By default this ``%``-style formats
:attr:`source_template` with the dict returned by
:meth:`source_vars`.
"""
return self.source_template % self.source_vars()
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
"""Use the env vars defined by the shell script to return a
tuple of ``args, incomplete``. This must be implemented by
subclasses.
"""
raise NotImplementedError
def get_completions(
self, args: t.List[str], incomplete: str
) -> t.List[CompletionItem]:
"""Determine the context and last complete command or parameter
from the complete args. Call that object's ``shell_complete``
method to get the completions for the incomplete value.
:param args: List of complete args before the incomplete value.
:param incomplete: Value being completed. May be empty.
"""
ctx = _resolve_context(self.cli, self.ctx_args, self.prog_name, args)
obj, incomplete = _resolve_incomplete(ctx, args, incomplete)
return obj.shell_complete(ctx, incomplete)
def format_completion(self, item: CompletionItem) -> str:
"""Format a completion item into the form recognized by the
shell script. This must be implemented by subclasses.
:param item: Completion item to format.
"""
raise NotImplementedError
def complete(self) -> str:
"""Produce the completion data to send back to the shell.
By default this calls :meth:`get_completion_args`, gets the
completions, then calls :meth:`format_completion` for each
completion.
"""
args, incomplete = self.get_completion_args()
completions = self.get_completions(args, incomplete)
out = [self.format_completion(item) for item in completions]
return "\n".join(out)
class BashComplete(ShellComplete):
"""Shell completion for Bash."""
name = "bash"
source_template = _SOURCE_BASH
def _check_version(self) -> None:
import subprocess
output = subprocess.run(
["bash", "-c", "echo ${BASH_VERSION}"], stdout=subprocess.PIPE
)
match = re.search(r"^(\d+)\.(\d+)\.\d+", output.stdout.decode())
if match is not None:
major, minor = match.groups()
if major < "4" or major == "4" and minor < "4":
raise RuntimeError(
_(
"Shell completion is not supported for Bash"
" versions older than 4.4."
)
)
else:
raise RuntimeError(
_("Couldn't detect Bash version, shell completion is not supported.")
)
def source(self) -> str:
self._check_version()
return super().source()
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
cwords = split_arg_string(os.environ["COMP_WORDS"])
cword = int(os.environ["COMP_CWORD"])
args = cwords[1:cword]
try:
incomplete = cwords[cword]
except IndexError:
incomplete = ""
return args, incomplete
def format_completion(self, item: CompletionItem) -> str:
return f"{item.type},{item.value}"
class ZshComplete(ShellComplete):
"""Shell completion for Zsh."""
name = "zsh"
source_template = _SOURCE_ZSH
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
cwords = split_arg_string(os.environ["COMP_WORDS"])
cword = int(os.environ["COMP_CWORD"])
args = cwords[1:cword]
try:
incomplete = cwords[cword]
except IndexError:
incomplete = ""
return args, incomplete
def format_completion(self, item: CompletionItem) -> str:
return f"{item.type}\n{item.value}\n{item.help if item.help else '_'}"
class FishComplete(ShellComplete):
"""Shell completion for Fish."""
name = "fish"
source_template = _SOURCE_FISH
def get_completion_args(self) -> t.Tuple[t.List[str], str]:
cwords = split_arg_string(os.environ["COMP_WORDS"])
incomplete = os.environ["COMP_CWORD"]
args = cwords[1:]
# Fish stores the partial word in both COMP_WORDS and
# COMP_CWORD, remove it from complete args.
if incomplete and args and args[-1] == incomplete:
args.pop()
return args, incomplete
def format_completion(self, item: CompletionItem) -> str:
if item.help:
return f"{item.type},{item.value}\t{item.help}"
return f"{item.type},{item.value}"
_available_shells: t.Dict[str, t.Type[ShellComplete]] = {
"bash": BashComplete,
"fish": FishComplete,
"zsh": ZshComplete,
}
def add_completion_class(
cls: t.Type[ShellComplete], name: t.Optional[str] = None
) -> None:
"""Register a :class:`ShellComplete` subclass under the given name.
The name will be provided by the completion instruction environment
variable during completion.
:param cls: The completion class that will handle completion for the
shell.
:param name: Name to register the class under. Defaults to the
class's ``name`` attribute.
"""
if name is None:
name = cls.name
_available_shells[name] = cls
def get_completion_class(shell: str) -> t.Optional[t.Type[ShellComplete]]:
"""Look up a registered :class:`ShellComplete` subclass by the name
provided by the completion instruction environment variable. If the
name isn't registered, returns ``None``.
:param shell: Name the class is registered under.
"""
return _available_shells.get(shell)
def _is_incomplete_argument(ctx: Context, param: Parameter) -> bool:
"""Determine if the given parameter is an argument that can still
accept values.
:param ctx: Invocation context for the command represented by the
parsed complete args.
:param param: Argument object being checked.
"""
if not isinstance(param, Argument):
return False
assert param.name is not None
value = ctx.params[param.name]
return (
param.nargs == -1
or ctx.get_parameter_source(param.name) is not ParameterSource.COMMANDLINE
or (
param.nargs > 1
and isinstance(value, (tuple, list))
and len(value) < param.nargs
)
)
def _start_of_option(ctx: Context, value: str) -> bool:
"""Check if the value looks like the start of an option."""
if not value:
return False
c = value[0]
return c in ctx._opt_prefixes
def _is_incomplete_option(ctx: Context, args: t.List[str], param: Parameter) -> bool:
"""Determine if the given parameter is an option that needs a value.
:param args: List of complete args before the incomplete value.
:param param: Option object being checked.
"""
if not isinstance(param, Option):
return False
if param.is_flag or param.count:
return False
last_option = None
for index, arg in enumerate(reversed(args)):
if index + 1 > param.nargs:
break
if _start_of_option(ctx, arg):
last_option = arg
return last_option is not None and last_option in param.opts
def _resolve_context(
cli: BaseCommand, ctx_args: t.Dict[str, t.Any], prog_name: str, args: t.List[str]
) -> Context:
"""Produce the context hierarchy starting with the command and
traversing the complete arguments. This only follows the commands,
it doesn't trigger input prompts or callbacks.
:param cli: Command being called.
:param prog_name: Name of the executable in the shell.
:param args: List of complete args before the incomplete value.
"""
ctx_args["resilient_parsing"] = True
ctx = cli.make_context(prog_name, args.copy(), **ctx_args)
args = ctx.protected_args + ctx.args
while args:
command = ctx.command
if isinstance(command, MultiCommand):
if not command.chain:
name, cmd, args = command.resolve_command(ctx, args)
if cmd is None:
return ctx
ctx = cmd.make_context(name, args, parent=ctx, resilient_parsing=True)
args = ctx.protected_args + ctx.args
else:
while args:
name, cmd, args = command.resolve_command(ctx, args)
if cmd is None:
return ctx
sub_ctx = cmd.make_context(
name,
args,
parent=ctx,
allow_extra_args=True,
allow_interspersed_args=False,
resilient_parsing=True,
)
args = sub_ctx.args
ctx = sub_ctx
args = [*sub_ctx.protected_args, *sub_ctx.args]
else:
break
return ctx
def _resolve_incomplete(
ctx: Context, args: t.List[str], incomplete: str
) -> t.Tuple[t.Union[BaseCommand, Parameter], str]:
"""Find the Click object that will handle the completion of the
incomplete value. Return the object and the incomplete value.
:param ctx: Invocation context for the command represented by
the parsed complete args.
:param args: List of complete args before the incomplete value.
:param incomplete: Value being completed. May be empty.
"""
# Different shells treat an "=" between a long option name and
# value differently. Might keep the value joined, return the "="
# as a separate item, or return the split name and value. Always
# split and discard the "=" to make completion easier.
if incomplete == "=":
incomplete = ""
elif "=" in incomplete and _start_of_option(ctx, incomplete):
name, _, incomplete = incomplete.partition("=")
args.append(name)
# The "--" marker tells Click to stop treating values as options
# even if they start with the option character. If it hasn't been
# given and the incomplete arg looks like an option, the current
# command will provide option name completions.
if "--" not in args and _start_of_option(ctx, incomplete):
return ctx.command, incomplete
params = ctx.command.get_params(ctx)
# If the last complete arg is an option name with an incomplete
# value, the option will provide value completions.
for param in params:
if _is_incomplete_option(ctx, args, param):
return param, incomplete
# It's not an option name or value. The first argument without a
# parsed value will provide value completions.
for param in params:
if _is_incomplete_argument(ctx, param):
return param, incomplete
# There were no unparsed arguments, the command may be a group that
# will provide command name completions.
return ctx.command, incomplete

View file

@ -1,81 +1,109 @@
import inspect
import io
import itertools
import os import os
import sys import sys
import struct import typing as t
import inspect from gettext import gettext as _
import itertools
from ._compat import raw_input, text_type, string_types, \ from ._compat import isatty
isatty, strip_ansi, get_winterm_size, DEFAULT_COLUMNS, WIN from ._compat import strip_ansi
from .utils import echo from ._compat import WIN
from .exceptions import Abort, UsageError from .exceptions import Abort
from .types import convert_type, Choice, Path from .exceptions import UsageError
from .globals import resolve_color_default from .globals import resolve_color_default
from .types import Choice
from .types import convert_type
from .types import ParamType
from .utils import echo
from .utils import LazyFile
if t.TYPE_CHECKING:
from ._termui_impl import ProgressBar
V = t.TypeVar("V")
# The prompt functions to use. The doc tools currently override these # The prompt functions to use. The doc tools currently override these
# functions to customize how they work. # functions to customize how they work.
visible_prompt_func = raw_input visible_prompt_func: t.Callable[[str], str] = input
_ansi_colors = { _ansi_colors = {
'black': 30, "black": 30,
'red': 31, "red": 31,
'green': 32, "green": 32,
'yellow': 33, "yellow": 33,
'blue': 34, "blue": 34,
'magenta': 35, "magenta": 35,
'cyan': 36, "cyan": 36,
'white': 37, "white": 37,
'reset': 39, "reset": 39,
'bright_black': 90, "bright_black": 90,
'bright_red': 91, "bright_red": 91,
'bright_green': 92, "bright_green": 92,
'bright_yellow': 93, "bright_yellow": 93,
'bright_blue': 94, "bright_blue": 94,
'bright_magenta': 95, "bright_magenta": 95,
'bright_cyan': 96, "bright_cyan": 96,
'bright_white': 97, "bright_white": 97,
} }
_ansi_reset_all = '\033[0m' _ansi_reset_all = "\033[0m"
def hidden_prompt_func(prompt): def hidden_prompt_func(prompt: str) -> str:
import getpass import getpass
return getpass.getpass(prompt) return getpass.getpass(prompt)
def _build_prompt(text, suffix, show_default=False, default=None, show_choices=True, type=None): def _build_prompt(
text: str,
suffix: str,
show_default: bool = False,
default: t.Optional[t.Any] = None,
show_choices: bool = True,
type: t.Optional[ParamType] = None,
) -> str:
prompt = text prompt = text
if type is not None and show_choices and isinstance(type, Choice): if type is not None and show_choices and isinstance(type, Choice):
prompt += ' (' + ", ".join(map(str, type.choices)) + ')' prompt += f" ({', '.join(map(str, type.choices))})"
if default is not None and show_default: if default is not None and show_default:
prompt = '%s [%s]' % (prompt, default) prompt = f"{prompt} [{_format_default(default)}]"
return prompt + suffix return f"{prompt}{suffix}"
def prompt(text, default=None, hide_input=False, confirmation_prompt=False, def _format_default(default: t.Any) -> t.Any:
type=None, value_proc=None, prompt_suffix=': ', show_default=True, if isinstance(default, (io.IOBase, LazyFile)) and hasattr(default, "name"):
err=False, show_choices=True): return default.name # type: ignore
return default
def prompt(
text: str,
default: t.Optional[t.Any] = None,
hide_input: bool = False,
confirmation_prompt: t.Union[bool, str] = False,
type: t.Optional[t.Union[ParamType, t.Any]] = None,
value_proc: t.Optional[t.Callable[[str], t.Any]] = None,
prompt_suffix: str = ": ",
show_default: bool = True,
err: bool = False,
show_choices: bool = True,
) -> t.Any:
"""Prompts a user for input. This is a convenience function that can """Prompts a user for input. This is a convenience function that can
be used to prompt a user for input later. be used to prompt a user for input later.
If the user aborts the input by sending a interrupt signal, this If the user aborts the input by sending an interrupt signal, this
function will catch it and raise a :exc:`Abort` exception. function will catch it and raise a :exc:`Abort` exception.
.. versionadded:: 7.0
Added the show_choices parameter.
.. versionadded:: 6.0
Added unicode support for cmd.exe on Windows.
.. versionadded:: 4.0
Added the `err` parameter.
:param text: the text to show for the prompt. :param text: the text to show for the prompt.
:param default: the default value to use if no input happens. If this :param default: the default value to use if no input happens. If this
is not given it will prompt until it's aborted. is not given it will prompt until it's aborted.
:param hide_input: if this is set to true then the input value will :param hide_input: if this is set to true then the input value will
be hidden. be hidden.
:param confirmation_prompt: asks for confirmation for the value. :param confirmation_prompt: Prompt a second time to confirm the
value. Can be set to a string instead of ``True`` to customize
the message.
:param type: the type to use to check the value against. :param type: the type to use to check the value against.
:param value_proc: if this parameter is provided it's a function that :param value_proc: if this parameter is provided it's a function that
is invoked instead of the type conversion to is invoked instead of the type conversion to
@ -88,93 +116,133 @@ def prompt(text, default=None, hide_input=False, confirmation_prompt=False,
For example if type is a Choice of either day or week, For example if type is a Choice of either day or week,
show_choices is true and text is "Group by" then the show_choices is true and text is "Group by" then the
prompt will be "Group by (day, week): ". prompt will be "Group by (day, week): ".
"""
result = None
def prompt_func(text): .. versionadded:: 8.0
f = hide_input and hidden_prompt_func or visible_prompt_func ``confirmation_prompt`` can be a custom string.
.. versionadded:: 7.0
Added the ``show_choices`` parameter.
.. versionadded:: 6.0
Added unicode support for cmd.exe on Windows.
.. versionadded:: 4.0
Added the `err` parameter.
"""
def prompt_func(text: str) -> str:
f = hidden_prompt_func if hide_input else visible_prompt_func
try: try:
# Write the prompt separately so that we get nice # Write the prompt separately so that we get nice
# coloring through colorama on Windows # coloring through colorama on Windows
echo(text, nl=False, err=err) echo(text.rstrip(" "), nl=False, err=err)
return f('') # Echo a space to stdout to work around an issue where
# readline causes backspace to clear the whole line.
return f(" ")
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
# getpass doesn't print a newline if the user aborts input with ^C. # getpass doesn't print a newline if the user aborts input with ^C.
# Allegedly this behavior is inherited from getpass(3). # Allegedly this behavior is inherited from getpass(3).
# A doc bug has been filed at https://bugs.python.org/issue24711 # A doc bug has been filed at https://bugs.python.org/issue24711
if hide_input: if hide_input:
echo(None, err=err) echo(None, err=err)
raise Abort() raise Abort() from None
if value_proc is None: if value_proc is None:
value_proc = convert_type(type, default) value_proc = convert_type(type, default)
prompt = _build_prompt(text, prompt_suffix, show_default, default, show_choices, type) prompt = _build_prompt(
text, prompt_suffix, show_default, default, show_choices, type
)
while 1: if confirmation_prompt:
while 1: if confirmation_prompt is True:
confirmation_prompt = _("Repeat for confirmation")
confirmation_prompt = _build_prompt(confirmation_prompt, prompt_suffix)
while True:
while True:
value = prompt_func(prompt) value = prompt_func(prompt)
if value: if value:
break break
elif default is not None: elif default is not None:
if isinstance(value_proc, Path): value = default
# validate Path default value(exists, dir_okay etc.) break
value = default
break
return default
try: try:
result = value_proc(value) result = value_proc(value)
except UsageError as e: except UsageError as e:
echo('Error: %s' % e.message, err=err) if hide_input:
echo(_("Error: The value you entered was invalid."), err=err)
else:
echo(_("Error: {e.message}").format(e=e), err=err) # noqa: B306
continue continue
if not confirmation_prompt: if not confirmation_prompt:
return result return result
while 1: while True:
value2 = prompt_func('Repeat for confirmation: ') value2 = prompt_func(confirmation_prompt)
if value2: is_empty = not value and not value2
if value2 or is_empty:
break break
if value == value2: if value == value2:
return result return result
echo('Error: the two entered values do not match', err=err) echo(_("Error: The two entered values do not match."), err=err)
def confirm(text, default=False, abort=False, prompt_suffix=': ', def confirm(
show_default=True, err=False): text: str,
default: t.Optional[bool] = False,
abort: bool = False,
prompt_suffix: str = ": ",
show_default: bool = True,
err: bool = False,
) -> bool:
"""Prompts for confirmation (yes/no question). """Prompts for confirmation (yes/no question).
If the user aborts the input by sending a interrupt signal this If the user aborts the input by sending a interrupt signal this
function will catch it and raise a :exc:`Abort` exception. function will catch it and raise a :exc:`Abort` exception.
.. versionadded:: 4.0
Added the `err` parameter.
:param text: the question to ask. :param text: the question to ask.
:param default: the default for the prompt. :param default: The default value to use when no input is given. If
``None``, repeat until input is given.
:param abort: if this is set to `True` a negative answer aborts the :param abort: if this is set to `True` a negative answer aborts the
exception by raising :exc:`Abort`. exception by raising :exc:`Abort`.
:param prompt_suffix: a suffix that should be added to the prompt. :param prompt_suffix: a suffix that should be added to the prompt.
:param show_default: shows or hides the default value in the prompt. :param show_default: shows or hides the default value in the prompt.
:param err: if set to true the file defaults to ``stderr`` instead of :param err: if set to true the file defaults to ``stderr`` instead of
``stdout``, the same as with echo. ``stdout``, the same as with echo.
.. versionchanged:: 8.0
Repeat until input is given if ``default`` is ``None``.
.. versionadded:: 4.0
Added the ``err`` parameter.
""" """
prompt = _build_prompt(text, prompt_suffix, show_default, prompt = _build_prompt(
default and 'Y/n' or 'y/N') text,
while 1: prompt_suffix,
show_default,
"y/n" if default is None else ("Y/n" if default else "y/N"),
)
while True:
try: try:
# Write the prompt separately so that we get nice # Write the prompt separately so that we get nice
# coloring through colorama on Windows # coloring through colorama on Windows
echo(prompt, nl=False, err=err) echo(prompt.rstrip(" "), nl=False, err=err)
value = visible_prompt_func('').lower().strip() # Echo a space to stdout to work around an issue where
# readline causes backspace to clear the whole line.
value = visible_prompt_func(" ").lower().strip()
except (KeyboardInterrupt, EOFError): except (KeyboardInterrupt, EOFError):
raise Abort() raise Abort() from None
if value in ('y', 'yes'): if value in ("y", "yes"):
rv = True rv = True
elif value in ('n', 'no'): elif value in ("n", "no"):
rv = False rv = False
elif value == '': elif default is not None and value == "":
rv = default rv = default
else: else:
echo('Error: invalid input', err=err) echo(_("Error: invalid input"), err=err)
continue continue
break break
if abort and not rv: if abort and not rv:
@ -182,54 +250,10 @@ def confirm(text, default=False, abort=False, prompt_suffix=': ',
return rv return rv
def get_terminal_size(): def echo_via_pager(
"""Returns the current size of the terminal as tuple in the form text_or_generator: t.Union[t.Iterable[str], t.Callable[[], t.Iterable[str]], str],
``(width, height)`` in columns and rows. color: t.Optional[bool] = None,
""" ) -> None:
# If shutil has get_terminal_size() (Python 3.3 and later) use that
if sys.version_info >= (3, 3):
import shutil
shutil_get_terminal_size = getattr(shutil, 'get_terminal_size', None)
if shutil_get_terminal_size:
sz = shutil_get_terminal_size()
return sz.columns, sz.lines
# We provide a sensible default for get_winterm_size() when being invoked
# inside a subprocess. Without this, it would not provide a useful input.
if get_winterm_size is not None:
size = get_winterm_size()
if size == (0, 0):
return (79, 24)
else:
return size
def ioctl_gwinsz(fd):
try:
import fcntl
import termios
cr = struct.unpack(
'hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234'))
except Exception:
return
return cr
cr = ioctl_gwinsz(0) or ioctl_gwinsz(1) or ioctl_gwinsz(2)
if not cr:
try:
fd = os.open(os.ctermid(), os.O_RDONLY)
try:
cr = ioctl_gwinsz(fd)
finally:
os.close(fd)
except Exception:
pass
if not cr or not cr[0] or not cr[1]:
cr = (os.environ.get('LINES', 25),
os.environ.get('COLUMNS', DEFAULT_COLUMNS))
return int(cr[1]), int(cr[0])
def echo_via_pager(text_or_generator, color=None):
"""This function takes a text and shows it via an environment specific """This function takes a text and shows it via an environment specific
pager on stdout. pager on stdout.
@ -244,25 +268,37 @@ def echo_via_pager(text_or_generator, color=None):
color = resolve_color_default(color) color = resolve_color_default(color)
if inspect.isgeneratorfunction(text_or_generator): if inspect.isgeneratorfunction(text_or_generator):
i = text_or_generator() i = t.cast(t.Callable[[], t.Iterable[str]], text_or_generator)()
elif isinstance(text_or_generator, string_types): elif isinstance(text_or_generator, str):
i = [text_or_generator] i = [text_or_generator]
else: else:
i = iter(text_or_generator) i = iter(t.cast(t.Iterable[str], text_or_generator))
# convert every element of i to a text type if necessary # convert every element of i to a text type if necessary
text_generator = (el if isinstance(el, string_types) else text_type(el) text_generator = (el if isinstance(el, str) else str(el) for el in i)
for el in i)
from ._termui_impl import pager from ._termui_impl import pager
return pager(itertools.chain(text_generator, "\n"), color) return pager(itertools.chain(text_generator, "\n"), color)
def progressbar(iterable=None, length=None, label=None, show_eta=True, def progressbar(
show_percent=None, show_pos=False, iterable: t.Optional[t.Iterable[V]] = None,
item_show_func=None, fill_char='#', empty_char='-', length: t.Optional[int] = None,
bar_template='%(label)s [%(bar)s] %(info)s', label: t.Optional[str] = None,
info_sep=' ', width=36, file=None, color=None): show_eta: bool = True,
show_percent: t.Optional[bool] = None,
show_pos: bool = False,
item_show_func: t.Optional[t.Callable[[t.Optional[V]], t.Optional[str]]] = None,
fill_char: str = "#",
empty_char: str = "-",
bar_template: str = "%(label)s [%(bar)s] %(info)s",
info_sep: str = " ",
width: int = 36,
file: t.Optional[t.TextIO] = None,
color: t.Optional[bool] = None,
update_min_steps: int = 1,
) -> "ProgressBar[V]":
"""This function creates an iterable context manager that can be used """This function creates an iterable context manager that can be used
to iterate over something while showing a progress bar. It will to iterate over something while showing a progress bar. It will
either iterate over the `iterable` or `length` items (that are counted either iterate over the `iterable` or `length` items (that are counted
@ -272,11 +308,17 @@ def progressbar(iterable=None, length=None, label=None, show_eta=True,
will not be rendered if the file is not a terminal. will not be rendered if the file is not a terminal.
The context manager creates the progress bar. When the context The context manager creates the progress bar. When the context
manager is entered the progress bar is already displayed. With every manager is entered the progress bar is already created. With every
iteration over the progress bar, the iterable passed to the bar is iteration over the progress bar, the iterable passed to the bar is
advanced and the bar is updated. When the context manager exits, advanced and the bar is updated. When the context manager exits,
a newline is printed and the progress bar is finalized on screen. a newline is printed and the progress bar is finalized on screen.
Note: The progress bar is currently designed for use cases where the
total progress can be expected to take at least several seconds.
Because of this, the ProgressBar class object won't display
progress that is considered too fast, and progress where the time
between steps is less than a second.
No printing must happen or the progress bar will be unintentionally No printing must happen or the progress bar will be unintentionally
destroyed. destroyed.
@ -296,11 +338,19 @@ def progressbar(iterable=None, length=None, label=None, show_eta=True,
process_chunk(chunk) process_chunk(chunk)
bar.update(chunks.bytes) bar.update(chunks.bytes)
.. versionadded:: 2.0 The ``update()`` method also takes an optional value specifying the
``current_item`` at the new position. This is useful when used
together with ``item_show_func`` to customize the output for each
manual step::
.. versionadded:: 4.0 with click.progressbar(
Added the `color` parameter. Added a `update` method to the length=total_size,
progressbar object. label='Unzipping archive',
item_show_func=lambda a: a.filename
) as bar:
for archive in zip_file:
archive.extract()
bar.update(archive.size, archive)
:param iterable: an iterable to iterate over. If not provided the length :param iterable: an iterable to iterate over. If not provided the length
is required. is required.
@ -319,10 +369,10 @@ def progressbar(iterable=None, length=None, label=None, show_eta=True,
`False` if not. `False` if not.
:param show_pos: enables or disables the absolute position display. The :param show_pos: enables or disables the absolute position display. The
default is `False`. default is `False`.
:param item_show_func: a function called with the current item which :param item_show_func: A function called with the current item which
can return a string to show the current item can return a string to show next to the progress bar. If the
next to the progress bar. Note that the current function returns ``None`` nothing is shown. The current item can
item can be `None`! be ``None``, such as when entering and exiting the bar.
:param fill_char: the character to use to show the filled part of the :param fill_char: the character to use to show the filled part of the
progress bar. progress bar.
:param empty_char: the character to use to show the non-filled part of :param empty_char: the character to use to show the non-filled part of
@ -334,24 +384,57 @@ def progressbar(iterable=None, length=None, label=None, show_eta=True,
:param info_sep: the separator between multiple info items (eta etc.) :param info_sep: the separator between multiple info items (eta etc.)
:param width: the width of the progress bar in characters, 0 means full :param width: the width of the progress bar in characters, 0 means full
terminal width terminal width
:param file: the file to write to. If this is not a terminal then :param file: The file to write to. If this is not a terminal then
only the label is printed. only the label is printed.
:param color: controls if the terminal supports ANSI colors or not. The :param color: controls if the terminal supports ANSI colors or not. The
default is autodetection. This is only needed if ANSI default is autodetection. This is only needed if ANSI
codes are included anywhere in the progress bar output codes are included anywhere in the progress bar output
which is not the case by default. which is not the case by default.
:param update_min_steps: Render only when this many updates have
completed. This allows tuning for very fast iterators.
.. versionchanged:: 8.0
Output is shown even if execution time is less than 0.5 seconds.
.. versionchanged:: 8.0
``item_show_func`` shows the current item, not the previous one.
.. versionchanged:: 8.0
Labels are echoed if the output is not a TTY. Reverts a change
in 7.0 that removed all output.
.. versionadded:: 8.0
Added the ``update_min_steps`` parameter.
.. versionchanged:: 4.0
Added the ``color`` parameter. Added the ``update`` method to
the object.
.. versionadded:: 2.0
""" """
from ._termui_impl import ProgressBar from ._termui_impl import ProgressBar
color = resolve_color_default(color) color = resolve_color_default(color)
return ProgressBar(iterable=iterable, length=length, show_eta=show_eta, return ProgressBar(
show_percent=show_percent, show_pos=show_pos, iterable=iterable,
item_show_func=item_show_func, fill_char=fill_char, length=length,
empty_char=empty_char, bar_template=bar_template, show_eta=show_eta,
info_sep=info_sep, file=file, label=label, show_percent=show_percent,
width=width, color=color) show_pos=show_pos,
item_show_func=item_show_func,
fill_char=fill_char,
empty_char=empty_char,
bar_template=bar_template,
info_sep=info_sep,
file=file,
label=label,
width=width,
color=color,
update_min_steps=update_min_steps,
)
def clear(): def clear() -> None:
"""Clears the terminal screen. This will have the effect of clearing """Clears the terminal screen. This will have the effect of clearing
the whole visible space of the terminal and moving the cursor to the the whole visible space of the terminal and moving the cursor to the
top left. This does not do anything if not connected to a terminal. top left. This does not do anything if not connected to a terminal.
@ -360,17 +443,39 @@ def clear():
""" """
if not isatty(sys.stdout): if not isatty(sys.stdout):
return return
# If we're on Windows and we don't have colorama available, then we
# clear the screen by shelling out. Otherwise we can use an escape
# sequence.
if WIN: if WIN:
os.system('cls') os.system("cls")
else: else:
sys.stdout.write('\033[2J\033[1;1H') sys.stdout.write("\033[2J\033[1;1H")
def style(text, fg=None, bg=None, bold=None, dim=None, underline=None, def _interpret_color(
blink=None, reverse=None, reset=True): color: t.Union[int, t.Tuple[int, int, int], str], offset: int = 0
) -> str:
if isinstance(color, int):
return f"{38 + offset};5;{color:d}"
if isinstance(color, (tuple, list)):
r, g, b = color
return f"{38 + offset};2;{r:d};{g:d};{b:d}"
return str(_ansi_colors[color] + offset)
def style(
text: t.Any,
fg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
bg: t.Optional[t.Union[int, t.Tuple[int, int, int], str]] = None,
bold: t.Optional[bool] = None,
dim: t.Optional[bool] = None,
underline: t.Optional[bool] = None,
overline: t.Optional[bool] = None,
italic: t.Optional[bool] = None,
blink: t.Optional[bool] = None,
reverse: t.Optional[bool] = None,
strikethrough: t.Optional[bool] = None,
reset: bool = True,
) -> str:
"""Styles a text with ANSI styles and returns the new string. By """Styles a text with ANSI styles and returns the new string. By
default the styling is self contained which means that at the end default the styling is self contained which means that at the end
of the string a reset code is issued. This can be prevented by of the string a reset code is issued. This can be prevented by
@ -381,6 +486,7 @@ def style(text, fg=None, bg=None, bold=None, dim=None, underline=None,
click.echo(click.style('Hello World!', fg='green')) click.echo(click.style('Hello World!', fg='green'))
click.echo(click.style('ATTENTION!', blink=True)) click.echo(click.style('ATTENTION!', blink=True))
click.echo(click.style('Some things', reverse=True, fg='cyan')) click.echo(click.style('Some things', reverse=True, fg='cyan'))
click.echo(click.style('More colors', fg=(255, 12, 128), bg=117))
Supported color names: Supported color names:
@ -402,10 +508,15 @@ def style(text, fg=None, bg=None, bold=None, dim=None, underline=None,
* ``bright_white`` * ``bright_white``
* ``reset`` (reset the color code only) * ``reset`` (reset the color code only)
.. versionadded:: 2.0 If the terminal supports it, color may also be specified as:
.. versionadded:: 7.0 - An integer in the interval [0, 255]. The terminal must support
Added support for bright colors. 8-bit/256-color mode.
- An RGB tuple of three integers in [0, 255]. The terminal must
support 24-bit/true-color mode.
See https://en.wikipedia.org/wiki/ANSI_color and
https://gist.github.com/XVilka/8346728 for more information.
:param text: the string to style with ansi codes. :param text: the string to style with ansi codes.
:param fg: if provided this will become the foreground color. :param fg: if provided this will become the foreground color.
@ -414,42 +525,73 @@ def style(text, fg=None, bg=None, bold=None, dim=None, underline=None,
:param dim: if provided this will enable or disable dim mode. This is :param dim: if provided this will enable or disable dim mode. This is
badly supported. badly supported.
:param underline: if provided this will enable or disable underline. :param underline: if provided this will enable or disable underline.
:param overline: if provided this will enable or disable overline.
:param italic: if provided this will enable or disable italic.
:param blink: if provided this will enable or disable blinking. :param blink: if provided this will enable or disable blinking.
:param reverse: if provided this will enable or disable inverse :param reverse: if provided this will enable or disable inverse
rendering (foreground becomes background and the rendering (foreground becomes background and the
other way round). other way round).
:param strikethrough: if provided this will enable or disable
striking through text.
:param reset: by default a reset-all code is added at the end of the :param reset: by default a reset-all code is added at the end of the
string which means that styles do not carry over. This string which means that styles do not carry over. This
can be disabled to compose styles. can be disabled to compose styles.
.. versionchanged:: 8.0
A non-string ``message`` is converted to a string.
.. versionchanged:: 8.0
Added support for 256 and RGB color codes.
.. versionchanged:: 8.0
Added the ``strikethrough``, ``italic``, and ``overline``
parameters.
.. versionchanged:: 7.0
Added support for bright colors.
.. versionadded:: 2.0
""" """
if not isinstance(text, str):
text = str(text)
bits = [] bits = []
if fg: if fg:
try: try:
bits.append('\033[%dm' % (_ansi_colors[fg])) bits.append(f"\033[{_interpret_color(fg)}m")
except KeyError: except KeyError:
raise TypeError('Unknown color %r' % fg) raise TypeError(f"Unknown color {fg!r}") from None
if bg: if bg:
try: try:
bits.append('\033[%dm' % (_ansi_colors[bg] + 10)) bits.append(f"\033[{_interpret_color(bg, 10)}m")
except KeyError: except KeyError:
raise TypeError('Unknown color %r' % bg) raise TypeError(f"Unknown color {bg!r}") from None
if bold is not None: if bold is not None:
bits.append('\033[%dm' % (1 if bold else 22)) bits.append(f"\033[{1 if bold else 22}m")
if dim is not None: if dim is not None:
bits.append('\033[%dm' % (2 if dim else 22)) bits.append(f"\033[{2 if dim else 22}m")
if underline is not None: if underline is not None:
bits.append('\033[%dm' % (4 if underline else 24)) bits.append(f"\033[{4 if underline else 24}m")
if overline is not None:
bits.append(f"\033[{53 if overline else 55}m")
if italic is not None:
bits.append(f"\033[{3 if italic else 23}m")
if blink is not None: if blink is not None:
bits.append('\033[%dm' % (5 if blink else 25)) bits.append(f"\033[{5 if blink else 25}m")
if reverse is not None: if reverse is not None:
bits.append('\033[%dm' % (7 if reverse else 27)) bits.append(f"\033[{7 if reverse else 27}m")
if strikethrough is not None:
bits.append(f"\033[{9 if strikethrough else 29}m")
bits.append(text) bits.append(text)
if reset: if reset:
bits.append(_ansi_reset_all) bits.append(_ansi_reset_all)
return ''.join(bits) return "".join(bits)
def unstyle(text): def unstyle(text: str) -> str:
"""Removes ANSI styling information from a string. Usually it's not """Removes ANSI styling information from a string. Usually it's not
necessary to use this function as Click's echo function will necessary to use this function as Click's echo function will
automatically remove styling if necessary. automatically remove styling if necessary.
@ -461,7 +603,14 @@ def unstyle(text):
return strip_ansi(text) return strip_ansi(text)
def secho(message=None, file=None, nl=True, err=False, color=None, **styles): def secho(
message: t.Optional[t.Any] = None,
file: t.Optional[t.IO[t.AnyStr]] = None,
nl: bool = True,
err: bool = False,
color: t.Optional[bool] = None,
**styles: t.Any,
) -> None:
"""This function combines :func:`echo` and :func:`style` into one """This function combines :func:`echo` and :func:`style` into one
call. As such the following two calls are the same:: call. As such the following two calls are the same::
@ -471,15 +620,31 @@ def secho(message=None, file=None, nl=True, err=False, color=None, **styles):
All keyword arguments are forwarded to the underlying functions All keyword arguments are forwarded to the underlying functions
depending on which one they go with. depending on which one they go with.
Non-string types will be converted to :class:`str`. However,
:class:`bytes` are passed directly to :meth:`echo` without applying
style. If you want to style bytes that represent text, call
:meth:`bytes.decode` first.
.. versionchanged:: 8.0
A non-string ``message`` is converted to a string. Bytes are
passed through without style applied.
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
if message is not None: if message is not None and not isinstance(message, (bytes, bytearray)):
message = style(message, **styles) message = style(message, **styles)
return echo(message, file=file, nl=nl, err=err, color=color) return echo(message, file=file, nl=nl, err=err, color=color)
def edit(text=None, editor=None, env=None, require_save=True, def edit(
extension='.txt', filename=None): text: t.Optional[t.AnyStr] = None,
editor: t.Optional[str] = None,
env: t.Optional[t.Mapping[str, str]] = None,
require_save: bool = True,
extension: str = ".txt",
filename: t.Optional[str] = None,
) -> t.Optional[t.AnyStr]:
r"""Edits the given text in the defined editor. If an editor is given r"""Edits the given text in the defined editor. If an editor is given
(should be the full path to the executable but the regular operating (should be the full path to the executable but the regular operating
system search path is used for finding the executable) it overrides system search path is used for finding the executable) it overrides
@ -508,14 +673,17 @@ def edit(text=None, editor=None, env=None, require_save=True,
file as an indirection in that case. file as an indirection in that case.
""" """
from ._termui_impl import Editor from ._termui_impl import Editor
editor = Editor(editor=editor, env=env, require_save=require_save,
extension=extension) ed = Editor(editor=editor, env=env, require_save=require_save, extension=extension)
if filename is None: if filename is None:
return editor.edit(text) return ed.edit(text)
editor.edit_file(filename)
ed.edit_file(filename)
return None
def launch(url, wait=False, locate=False): def launch(url: str, wait: bool = False, locate: bool = False) -> int:
"""This function launches the given URL (or filename) in the default """This function launches the given URL (or filename) in the default
viewer application for this file type. If this is an executable, it viewer application for this file type. If this is an executable, it
might launch the executable in a new session. The return value is might launch the executable in a new session. The return value is
@ -530,7 +698,9 @@ def launch(url, wait=False, locate=False):
.. versionadded:: 2.0 .. versionadded:: 2.0
:param url: URL or filename of the thing to launch. :param url: URL or filename of the thing to launch.
:param wait: waits for the program to stop. :param wait: Wait for the program to exit before returning. This
only works if the launched program blocks. In particular,
``xdg-open`` on Linux does not block.
:param locate: if this is set to `True` then instead of launching the :param locate: if this is set to `True` then instead of launching the
application associated with the URL it will attempt to application associated with the URL it will attempt to
launch a file manager with the file located. This launch a file manager with the file located. This
@ -538,15 +708,16 @@ def launch(url, wait=False, locate=False):
the filesystem. the filesystem.
""" """
from ._termui_impl import open_url from ._termui_impl import open_url
return open_url(url, wait=wait, locate=locate) return open_url(url, wait=wait, locate=locate)
# If this is provided, getchar() calls into this instead. This is used # If this is provided, getchar() calls into this instead. This is used
# for unittesting purposes. # for unittesting purposes.
_getchar = None _getchar: t.Optional[t.Callable[[bool], str]] = None
def getchar(echo=False): def getchar(echo: bool = False) -> str:
"""Fetches a single character from the terminal and returns it. This """Fetches a single character from the terminal and returns it. This
will always return a unicode character and under certain rare will always return a unicode character and under certain rare
circumstances this might return more than one character. The circumstances this might return more than one character. The
@ -566,18 +737,23 @@ def getchar(echo=False):
:param echo: if set to `True`, the character read will also show up on :param echo: if set to `True`, the character read will also show up on
the terminal. The default is to not show it. the terminal. The default is to not show it.
""" """
f = _getchar global _getchar
if f is None:
if _getchar is None:
from ._termui_impl import getchar as f from ._termui_impl import getchar as f
return f(echo)
_getchar = f
return _getchar(echo)
def raw_terminal(): def raw_terminal() -> t.ContextManager[int]:
from ._termui_impl import raw_terminal as f from ._termui_impl import raw_terminal as f
return f() return f()
def pause(info='Press any key to continue ...', err=False): def pause(info: t.Optional[str] = None, err: bool = False) -> None:
"""This command stops execution and waits for the user to press any """This command stops execution and waits for the user to press any
key to continue. This is similar to the Windows batch "pause" key to continue. This is similar to the Windows batch "pause"
command. If the program is not run through a terminal, this command command. If the program is not run through a terminal, this command
@ -588,12 +764,17 @@ def pause(info='Press any key to continue ...', err=False):
.. versionadded:: 4.0 .. versionadded:: 4.0
Added the `err` parameter. Added the `err` parameter.
:param info: the info string to print before pausing. :param info: The message to print before pausing. Defaults to
``"Press any key to continue..."``.
:param err: if set to message goes to ``stderr`` instead of :param err: if set to message goes to ``stderr`` instead of
``stdout``, the same as with echo. ``stdout``, the same as with echo.
""" """
if not isatty(sys.stdin) or not isatty(sys.stdout): if not isatty(sys.stdin) or not isatty(sys.stdout):
return return
if info is None:
info = _("Press any key to continue...")
try: try:
if info: if info:
echo(info, nl=False, err=err) echo(info, nl=False, err=err)

View file

@ -1,86 +1,128 @@
import os
import sys
import shutil
import tempfile
import contextlib import contextlib
import io
import os
import shlex import shlex
import shutil
import sys
import tempfile
import typing as t
from types import TracebackType
from ._compat import iteritems, PY2, string_types from . import formatting
from . import termui
from . import utils
from ._compat import _find_binary_reader
if t.TYPE_CHECKING:
from .core import BaseCommand
# If someone wants to vendor click, we want to ensure the class EchoingStdin:
# correct package is discovered. Ideally we could use a def __init__(self, input: t.BinaryIO, output: t.BinaryIO) -> None:
# relative import here but unfortunately Python does not
# support that.
clickpkg = sys.modules[__name__.rsplit('.', 1)[0]]
if PY2:
from cStringIO import StringIO
else:
import io
from ._compat import _find_binary_reader
class EchoingStdin(object):
def __init__(self, input, output):
self._input = input self._input = input
self._output = output self._output = output
self._paused = False
def __getattr__(self, x): def __getattr__(self, x: str) -> t.Any:
return getattr(self._input, x) return getattr(self._input, x)
def _echo(self, rv): def _echo(self, rv: bytes) -> bytes:
self._output.write(rv) if not self._paused:
self._output.write(rv)
return rv return rv
def read(self, n=-1): def read(self, n: int = -1) -> bytes:
return self._echo(self._input.read(n)) return self._echo(self._input.read(n))
def readline(self, n=-1): def read1(self, n: int = -1) -> bytes:
return self._echo(self._input.read1(n)) # type: ignore
def readline(self, n: int = -1) -> bytes:
return self._echo(self._input.readline(n)) return self._echo(self._input.readline(n))
def readlines(self): def readlines(self) -> t.List[bytes]:
return [self._echo(x) for x in self._input.readlines()] return [self._echo(x) for x in self._input.readlines()]
def __iter__(self): def __iter__(self) -> t.Iterator[bytes]:
return iter(self._echo(x) for x in self._input) return iter(self._echo(x) for x in self._input)
def __repr__(self): def __repr__(self) -> str:
return repr(self._input) return repr(self._input)
def make_input_stream(input, charset): @contextlib.contextmanager
def _pause_echo(stream: t.Optional[EchoingStdin]) -> t.Iterator[None]:
if stream is None:
yield
else:
stream._paused = True
yield
stream._paused = False
class _NamedTextIOWrapper(io.TextIOWrapper):
def __init__(
self, buffer: t.BinaryIO, name: str, mode: str, **kwargs: t.Any
) -> None:
super().__init__(buffer, **kwargs)
self._name = name
self._mode = mode
@property
def name(self) -> str:
return self._name
@property
def mode(self) -> str:
return self._mode
def make_input_stream(
input: t.Optional[t.Union[str, bytes, t.IO]], charset: str
) -> t.BinaryIO:
# Is already an input stream. # Is already an input stream.
if hasattr(input, 'read'): if hasattr(input, "read"):
if PY2: rv = _find_binary_reader(t.cast(t.IO, input))
return input
rv = _find_binary_reader(input)
if rv is not None: if rv is not None:
return rv return rv
raise TypeError('Could not find binary reader for input stream.')
raise TypeError("Could not find binary reader for input stream.")
if input is None: if input is None:
input = b'' input = b""
elif not isinstance(input, bytes): elif isinstance(input, str):
input = input.encode(charset) input = input.encode(charset)
if PY2:
return StringIO(input) return io.BytesIO(t.cast(bytes, input))
return io.BytesIO(input)
class Result(object): class Result:
"""Holds the captured result of an invoked CLI script.""" """Holds the captured result of an invoked CLI script."""
def __init__(self, runner, stdout_bytes, stderr_bytes, exit_code, def __init__(
exception, exc_info=None): self,
runner: "CliRunner",
stdout_bytes: bytes,
stderr_bytes: t.Optional[bytes],
return_value: t.Any,
exit_code: int,
exception: t.Optional[BaseException],
exc_info: t.Optional[
t.Tuple[t.Type[BaseException], BaseException, TracebackType]
] = None,
):
#: The runner that created the result #: The runner that created the result
self.runner = runner self.runner = runner
#: The standard output as bytes. #: The standard output as bytes.
self.stdout_bytes = stdout_bytes self.stdout_bytes = stdout_bytes
#: The standard error as bytes, or False(y) if not available #: The standard error as bytes, or None if not available
self.stderr_bytes = stderr_bytes self.stderr_bytes = stderr_bytes
#: The value returned from the invoked command.
#:
#: .. versionadded:: 8.0
self.return_value = return_value
#: The exit code as integer. #: The exit code as integer.
self.exit_code = exit_code self.exit_code = exit_code
#: The exception that happened if one did. #: The exception that happened if one did.
@ -89,41 +131,38 @@ class Result(object):
self.exc_info = exc_info self.exc_info = exc_info
@property @property
def output(self): def output(self) -> str:
"""The (standard) output as unicode string.""" """The (standard) output as unicode string."""
return self.stdout return self.stdout
@property @property
def stdout(self): def stdout(self) -> str:
"""The standard output as unicode string.""" """The standard output as unicode string."""
return self.stdout_bytes.decode(self.runner.charset, 'replace') \ return self.stdout_bytes.decode(self.runner.charset, "replace").replace(
.replace('\r\n', '\n') "\r\n", "\n"
@property
def stderr(self):
"""The standard error as unicode string."""
if not self.stderr_bytes:
raise ValueError("stderr not separately captured")
return self.stderr_bytes.decode(self.runner.charset, 'replace') \
.replace('\r\n', '\n')
def __repr__(self):
return '<%s %s>' % (
type(self).__name__,
self.exception and repr(self.exception) or 'okay',
) )
@property
def stderr(self) -> str:
"""The standard error as unicode string."""
if self.stderr_bytes is None:
raise ValueError("stderr not separately captured")
return self.stderr_bytes.decode(self.runner.charset, "replace").replace(
"\r\n", "\n"
)
class CliRunner(object): def __repr__(self) -> str:
exc_str = repr(self.exception) if self.exception else "okay"
return f"<{type(self).__name__} {exc_str}>"
class CliRunner:
"""The CLI runner provides functionality to invoke a Click command line """The CLI runner provides functionality to invoke a Click command line
script for unittesting purposes in a isolated environment. This only script for unittesting purposes in a isolated environment. This only
works in single-threaded systems without any concurrency as it changes the works in single-threaded systems without any concurrency as it changes the
global interpreter state. global interpreter state.
:param charset: the character set for the input and output data. This is :param charset: the character set for the input and output data.
UTF-8 by default and should not be changed currently as
the reporting to Click only works in Python 2 properly.
:param env: a dictionary with environment variables for overriding. :param env: a dictionary with environment variables for overriding.
:param echo_stdin: if this is set to `True`, then reading from stdin writes :param echo_stdin: if this is set to `True`, then reading from stdin writes
to stdout. This is useful for showing examples in to stdout. This is useful for showing examples in
@ -136,23 +175,28 @@ class CliRunner(object):
independently independently
""" """
def __init__(self, charset=None, env=None, echo_stdin=False, def __init__(
mix_stderr=True): self,
if charset is None: charset: str = "utf-8",
charset = 'utf-8' env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
echo_stdin: bool = False,
mix_stderr: bool = True,
) -> None:
self.charset = charset self.charset = charset
self.env = env or {} self.env = env or {}
self.echo_stdin = echo_stdin self.echo_stdin = echo_stdin
self.mix_stderr = mix_stderr self.mix_stderr = mix_stderr
def get_default_prog_name(self, cli): def get_default_prog_name(self, cli: "BaseCommand") -> str:
"""Given a command object it will return the default program name """Given a command object it will return the default program name
for it. The default is the `name` attribute or ``"root"`` if not for it. The default is the `name` attribute or ``"root"`` if not
set. set.
""" """
return cli.name or 'root' return cli.name or "root"
def make_env(self, overrides=None): def make_env(
self, overrides: t.Optional[t.Mapping[str, t.Optional[str]]] = None
) -> t.Mapping[str, t.Optional[str]]:
"""Returns the environment overrides for invoking a script.""" """Returns the environment overrides for invoking a script."""
rv = dict(self.env) rv = dict(self.env)
if overrides: if overrides:
@ -160,7 +204,12 @@ class CliRunner(object):
return rv return rv
@contextlib.contextmanager @contextlib.contextmanager
def isolation(self, input=None, env=None, color=False): def isolation(
self,
input: t.Optional[t.Union[str, bytes, t.IO]] = None,
env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
color: bool = False,
) -> t.Iterator[t.Tuple[io.BytesIO, t.Optional[io.BytesIO]]]:
"""A context manager that sets up the isolation for invoking of a """A context manager that sets up the isolation for invoking of a
command line tool. This sets up stdin with the given input data command line tool. This sets up stdin with the given input data
and `os.environ` with the overrides from the given dictionary. and `os.environ` with the overrides from the given dictionary.
@ -169,87 +218,107 @@ class CliRunner(object):
This is automatically done in the :meth:`invoke` method. This is automatically done in the :meth:`invoke` method.
.. versionadded:: 4.0
The ``color`` parameter was added.
:param input: the input stream to put into sys.stdin. :param input: the input stream to put into sys.stdin.
:param env: the environment overrides as dictionary. :param env: the environment overrides as dictionary.
:param color: whether the output should contain color codes. The :param color: whether the output should contain color codes. The
application can still override this explicitly. application can still override this explicitly.
.. versionchanged:: 8.0
``stderr`` is opened with ``errors="backslashreplace"``
instead of the default ``"strict"``.
.. versionchanged:: 4.0
Added the ``color`` parameter.
""" """
input = make_input_stream(input, self.charset) bytes_input = make_input_stream(input, self.charset)
echo_input = None
old_stdin = sys.stdin old_stdin = sys.stdin
old_stdout = sys.stdout old_stdout = sys.stdout
old_stderr = sys.stderr old_stderr = sys.stderr
old_forced_width = clickpkg.formatting.FORCED_WIDTH old_forced_width = formatting.FORCED_WIDTH
clickpkg.formatting.FORCED_WIDTH = 80 formatting.FORCED_WIDTH = 80
env = self.make_env(env) env = self.make_env(env)
if PY2: bytes_output = io.BytesIO()
bytes_output = StringIO()
if self.echo_stdin:
input = EchoingStdin(input, bytes_output)
sys.stdout = bytes_output
if not self.mix_stderr:
bytes_error = StringIO()
sys.stderr = bytes_error
else:
bytes_output = io.BytesIO()
if self.echo_stdin:
input = EchoingStdin(input, bytes_output)
input = io.TextIOWrapper(input, encoding=self.charset)
sys.stdout = io.TextIOWrapper(
bytes_output, encoding=self.charset)
if not self.mix_stderr:
bytes_error = io.BytesIO()
sys.stderr = io.TextIOWrapper(
bytes_error, encoding=self.charset)
if self.echo_stdin:
bytes_input = echo_input = t.cast(
t.BinaryIO, EchoingStdin(bytes_input, bytes_output)
)
sys.stdin = text_input = _NamedTextIOWrapper(
bytes_input, encoding=self.charset, name="<stdin>", mode="r"
)
if self.echo_stdin:
# Force unbuffered reads, otherwise TextIOWrapper reads a
# large chunk which is echoed early.
text_input._CHUNK_SIZE = 1 # type: ignore
sys.stdout = _NamedTextIOWrapper(
bytes_output, encoding=self.charset, name="<stdout>", mode="w"
)
bytes_error = None
if self.mix_stderr: if self.mix_stderr:
sys.stderr = sys.stdout sys.stderr = sys.stdout
else:
bytes_error = io.BytesIO()
sys.stderr = _NamedTextIOWrapper(
bytes_error,
encoding=self.charset,
name="<stderr>",
mode="w",
errors="backslashreplace",
)
sys.stdin = input @_pause_echo(echo_input) # type: ignore
def visible_input(prompt: t.Optional[str] = None) -> str:
def visible_input(prompt=None): sys.stdout.write(prompt or "")
sys.stdout.write(prompt or '') val = text_input.readline().rstrip("\r\n")
val = input.readline().rstrip('\r\n') sys.stdout.write(f"{val}\n")
sys.stdout.write(val + '\n')
sys.stdout.flush() sys.stdout.flush()
return val return val
def hidden_input(prompt=None): @_pause_echo(echo_input) # type: ignore
sys.stdout.write((prompt or '') + '\n') def hidden_input(prompt: t.Optional[str] = None) -> str:
sys.stdout.write(f"{prompt or ''}\n")
sys.stdout.flush() sys.stdout.flush()
return input.readline().rstrip('\r\n') return text_input.readline().rstrip("\r\n")
def _getchar(echo): @_pause_echo(echo_input) # type: ignore
def _getchar(echo: bool) -> str:
char = sys.stdin.read(1) char = sys.stdin.read(1)
if echo: if echo:
sys.stdout.write(char) sys.stdout.write(char)
sys.stdout.flush()
sys.stdout.flush()
return char return char
default_color = color default_color = color
def should_strip_ansi(stream=None, color=None): def should_strip_ansi(
stream: t.Optional[t.IO] = None, color: t.Optional[bool] = None
) -> bool:
if color is None: if color is None:
return not default_color return not default_color
return not color return not color
old_visible_prompt_func = clickpkg.termui.visible_prompt_func old_visible_prompt_func = termui.visible_prompt_func
old_hidden_prompt_func = clickpkg.termui.hidden_prompt_func old_hidden_prompt_func = termui.hidden_prompt_func
old__getchar_func = clickpkg.termui._getchar old__getchar_func = termui._getchar
old_should_strip_ansi = clickpkg.utils.should_strip_ansi old_should_strip_ansi = utils.should_strip_ansi # type: ignore
clickpkg.termui.visible_prompt_func = visible_input termui.visible_prompt_func = visible_input
clickpkg.termui.hidden_prompt_func = hidden_input termui.hidden_prompt_func = hidden_input
clickpkg.termui._getchar = _getchar termui._getchar = _getchar
clickpkg.utils.should_strip_ansi = should_strip_ansi utils.should_strip_ansi = should_strip_ansi # type: ignore
old_env = {} old_env = {}
try: try:
for key, value in iteritems(env): for key, value in env.items():
old_env[key] = os.environ.get(key) old_env[key] = os.environ.get(key)
if value is None: if value is None:
try: try:
@ -258,9 +327,9 @@ class CliRunner(object):
pass pass
else: else:
os.environ[key] = value os.environ[key] = value
yield (bytes_output, not self.mix_stderr and bytes_error) yield (bytes_output, bytes_error)
finally: finally:
for key, value in iteritems(old_env): for key, value in old_env.items():
if value is None: if value is None:
try: try:
del os.environ[key] del os.environ[key]
@ -271,14 +340,22 @@ class CliRunner(object):
sys.stdout = old_stdout sys.stdout = old_stdout
sys.stderr = old_stderr sys.stderr = old_stderr
sys.stdin = old_stdin sys.stdin = old_stdin
clickpkg.termui.visible_prompt_func = old_visible_prompt_func termui.visible_prompt_func = old_visible_prompt_func
clickpkg.termui.hidden_prompt_func = old_hidden_prompt_func termui.hidden_prompt_func = old_hidden_prompt_func
clickpkg.termui._getchar = old__getchar_func termui._getchar = old__getchar_func
clickpkg.utils.should_strip_ansi = old_should_strip_ansi utils.should_strip_ansi = old_should_strip_ansi # type: ignore
clickpkg.formatting.FORCED_WIDTH = old_forced_width formatting.FORCED_WIDTH = old_forced_width
def invoke(self, cli, args=None, input=None, env=None, def invoke(
catch_exceptions=True, color=False, mix_stderr=False, **extra): self,
cli: "BaseCommand",
args: t.Optional[t.Union[str, t.Sequence[str]]] = None,
input: t.Optional[t.Union[str, bytes, t.IO]] = None,
env: t.Optional[t.Mapping[str, t.Optional[str]]] = None,
catch_exceptions: bool = True,
color: bool = False,
**extra: t.Any,
) -> Result:
"""Invokes a command in an isolated environment. The arguments are """Invokes a command in an isolated environment. The arguments are
forwarded directly to the command line script, the `extra` keyword forwarded directly to the command line script, the `extra` keyword
arguments are passed to the :meth:`~clickpkg.Command.main` function of arguments are passed to the :meth:`~clickpkg.Command.main` function of
@ -286,16 +363,6 @@ class CliRunner(object):
This returns a :class:`Result` object. This returns a :class:`Result` object.
.. versionadded:: 3.0
The ``catch_exceptions`` parameter was added.
.. versionchanged:: 3.0
The result object now has an `exc_info` attribute with the
traceback if available.
.. versionadded:: 4.0
The ``color`` parameter was added.
:param cli: the command to invoke :param cli: the command to invoke
:param args: the arguments to invoke. It may be given as an iterable :param args: the arguments to invoke. It may be given as an iterable
or a string. When given as string it will be interpreted or a string. When given as string it will be interpreted
@ -308,13 +375,28 @@ class CliRunner(object):
:param extra: the keyword arguments to pass to :meth:`main`. :param extra: the keyword arguments to pass to :meth:`main`.
:param color: whether the output should contain color codes. The :param color: whether the output should contain color codes. The
application can still override this explicitly. application can still override this explicitly.
.. versionchanged:: 8.0
The result object has the ``return_value`` attribute with
the value returned from the invoked command.
.. versionchanged:: 4.0
Added the ``color`` parameter.
.. versionchanged:: 3.0
Added the ``catch_exceptions`` parameter.
.. versionchanged:: 3.0
The result object has the ``exc_info`` attribute with the
traceback if available.
""" """
exc_info = None exc_info = None
with self.isolation(input=input, env=env, color=color) as outstreams: with self.isolation(input=input, env=env, color=color) as outstreams:
exception = None return_value = None
exception: t.Optional[BaseException] = None
exit_code = 0 exit_code = 0
if isinstance(args, string_types): if isinstance(args, str):
args = shlex.split(args) args = shlex.split(args)
try: try:
@ -323,20 +405,23 @@ class CliRunner(object):
prog_name = self.get_default_prog_name(cli) prog_name = self.get_default_prog_name(cli)
try: try:
cli.main(args=args or (), prog_name=prog_name, **extra) return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
except SystemExit as e: except SystemExit as e:
exc_info = sys.exc_info() exc_info = sys.exc_info()
exit_code = e.code e_code = t.cast(t.Optional[t.Union[int, t.Any]], e.code)
if exit_code is None:
exit_code = 0
if exit_code != 0: if e_code is None:
e_code = 0
if e_code != 0:
exception = e exception = e
if not isinstance(exit_code, int): if not isinstance(e_code, int):
sys.stdout.write(str(exit_code)) sys.stdout.write(str(e_code))
sys.stdout.write('\n') sys.stdout.write("\n")
exit_code = 1 e_code = 1
exit_code = e_code
except Exception as e: except Exception as e:
if not catch_exceptions: if not catch_exceptions:
@ -347,28 +432,48 @@ class CliRunner(object):
finally: finally:
sys.stdout.flush() sys.stdout.flush()
stdout = outstreams[0].getvalue() stdout = outstreams[0].getvalue()
stderr = outstreams[1] and outstreams[1].getvalue() if self.mix_stderr:
stderr = None
else:
stderr = outstreams[1].getvalue() # type: ignore
return Result(runner=self, return Result(
stdout_bytes=stdout, runner=self,
stderr_bytes=stderr, stdout_bytes=stdout,
exit_code=exit_code, stderr_bytes=stderr,
exception=exception, return_value=return_value,
exc_info=exc_info) exit_code=exit_code,
exception=exception,
exc_info=exc_info, # type: ignore
)
@contextlib.contextmanager @contextlib.contextmanager
def isolated_filesystem(self): def isolated_filesystem(
"""A context manager that creates a temporary folder and changes self, temp_dir: t.Optional[t.Union[str, os.PathLike]] = None
the current working directory to it for isolated filesystem tests. ) -> t.Iterator[str]:
"""A context manager that creates a temporary directory and
changes the current working directory to it. This isolates tests
that affect the contents of the CWD to prevent them from
interfering with each other.
:param temp_dir: Create the temporary directory under this
directory. If given, the created directory is not removed
when exiting.
.. versionchanged:: 8.0
Added the ``temp_dir`` parameter.
""" """
cwd = os.getcwd() cwd = os.getcwd()
t = tempfile.mkdtemp() dt = tempfile.mkdtemp(dir=temp_dir) # type: ignore[type-var]
os.chdir(t) os.chdir(dt)
try: try:
yield t yield t.cast(str, dt)
finally: finally:
os.chdir(cwd) os.chdir(cwd)
try:
shutil.rmtree(t) if temp_dir is None:
except (OSError, IOError): try:
pass shutil.rmtree(dt)
except OSError: # noqa: B014
pass

File diff suppressed because it is too large Load diff

View file

@ -1,92 +1,131 @@
import os import os
import re
import sys import sys
import typing as t
from functools import update_wrapper
from types import ModuleType
from ._compat import _default_text_stderr
from ._compat import _default_text_stdout
from ._compat import _find_binary_writer
from ._compat import auto_wrap_for_ansi
from ._compat import binary_streams
from ._compat import get_filesystem_encoding
from ._compat import open_stream
from ._compat import should_strip_ansi
from ._compat import strip_ansi
from ._compat import text_streams
from ._compat import WIN
from .globals import resolve_color_default from .globals import resolve_color_default
from ._compat import text_type, open_stream, get_filesystem_encoding, \ if t.TYPE_CHECKING:
get_streerror, string_types, PY2, binary_streams, text_streams, \ import typing_extensions as te
filename_to_ui, auto_wrap_for_ansi, strip_ansi, should_strip_ansi, \
_default_text_stdout, _default_text_stderr, is_bytes, WIN
if not PY2: F = t.TypeVar("F", bound=t.Callable[..., t.Any])
from ._compat import _find_binary_writer
elif WIN:
from ._winconsole import _get_windows_argv, \
_hash_py_argv, _initial_argv_hash
echo_native_types = string_types + (bytes, bytearray) def _posixify(name: str) -> str:
return "-".join(name.split()).lower()
def _posixify(name): def safecall(func: F) -> F:
return '-'.join(name.split()).lower()
def safecall(func):
"""Wraps a function so that it swallows exceptions.""" """Wraps a function so that it swallows exceptions."""
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs): # type: ignore
try: try:
return func(*args, **kwargs) return func(*args, **kwargs)
except Exception: except Exception:
pass pass
return wrapper
return update_wrapper(t.cast(F, wrapper), func)
def make_str(value): def make_str(value: t.Any) -> str:
"""Converts a value into a valid string.""" """Converts a value into a valid string."""
if isinstance(value, bytes): if isinstance(value, bytes):
try: try:
return value.decode(get_filesystem_encoding()) return value.decode(get_filesystem_encoding())
except UnicodeError: except UnicodeError:
return value.decode('utf-8', 'replace') return value.decode("utf-8", "replace")
return text_type(value) return str(value)
def make_default_short_help(help, max_length=45): def make_default_short_help(help: str, max_length: int = 45) -> str:
"""Return a condensed version of help string.""" """Returns a condensed version of help string."""
# Consider only the first paragraph.
paragraph_end = help.find("\n\n")
if paragraph_end != -1:
help = help[:paragraph_end]
# Collapse newlines, tabs, and spaces.
words = help.split() words = help.split()
if not words:
return ""
# The first paragraph started with a "no rewrap" marker, ignore it.
if words[0] == "\b":
words = words[1:]
total_length = 0 total_length = 0
result = [] last_index = len(words) - 1
done = False
for word in words: for i, word in enumerate(words):
if word[-1:] == '.': total_length += len(word) + (i > 0)
done = True
new_length = result and 1 + len(word) or len(word) if total_length > max_length: # too long, truncate
if total_length + new_length > max_length:
result.append('...')
done = True
else:
if result:
result.append(' ')
result.append(word)
if done:
break break
total_length += new_length
return ''.join(result) if word[-1] == ".": # sentence end, truncate without "..."
return " ".join(words[: i + 1])
if total_length == max_length and i != last_index:
break # not at sentence end, truncate with "..."
else:
return " ".join(words) # no truncation needed
# Account for the length of the suffix.
total_length += len("...")
# remove words until the length is short enough
while i > 0:
total_length -= len(words[i]) + (i > 0)
if total_length <= max_length:
break
i -= 1
return " ".join(words[:i]) + "..."
class LazyFile(object): class LazyFile:
"""A lazy file works like a regular file but it does not fully open """A lazy file works like a regular file but it does not fully open
the file but it does perform some basic checks early to see if the the file but it does perform some basic checks early to see if the
filename parameter does make sense. This is useful for safely opening filename parameter does make sense. This is useful for safely opening
files for writing. files for writing.
""" """
def __init__(self, filename, mode='r', encoding=None, errors='strict', def __init__(
atomic=False): self,
filename: str,
mode: str = "r",
encoding: t.Optional[str] = None,
errors: t.Optional[str] = "strict",
atomic: bool = False,
):
self.name = filename self.name = filename
self.mode = mode self.mode = mode
self.encoding = encoding self.encoding = encoding
self.errors = errors self.errors = errors
self.atomic = atomic self.atomic = atomic
self._f: t.Optional[t.IO]
if filename == '-': if filename == "-":
self._f, self.should_close = open_stream(filename, mode, self._f, self.should_close = open_stream(filename, mode, encoding, errors)
encoding, errors)
else: else:
if 'r' in mode: if "r" in mode:
# Open and close the file in case we're opening it for # Open and close the file in case we're opening it for
# reading so that we can catch at least some errors in # reading so that we can catch at least some errors in
# some cases early. # some cases early.
@ -94,15 +133,15 @@ class LazyFile(object):
self._f = None self._f = None
self.should_close = True self.should_close = True
def __getattr__(self, name): def __getattr__(self, name: str) -> t.Any:
return getattr(self.open(), name) return getattr(self.open(), name)
def __repr__(self): def __repr__(self) -> str:
if self._f is not None: if self._f is not None:
return repr(self._f) return repr(self._f)
return '<unopened file %r %s>' % (self.name, self.mode) return f"<unopened file '{self.name}' {self.mode}>"
def open(self): def open(self) -> t.IO:
"""Opens the file if it's not yet open. This call might fail with """Opens the file if it's not yet open. This call might fail with
a :exc:`FileError`. Not handling this error will produce an error a :exc:`FileError`. Not handling this error will produce an error
that Click shows. that Click shows.
@ -110,106 +149,103 @@ class LazyFile(object):
if self._f is not None: if self._f is not None:
return self._f return self._f
try: try:
rv, self.should_close = open_stream(self.name, self.mode, rv, self.should_close = open_stream(
self.encoding, self.name, self.mode, self.encoding, self.errors, atomic=self.atomic
self.errors, )
atomic=self.atomic) except OSError as e: # noqa: E402
except (IOError, OSError) as e:
from .exceptions import FileError from .exceptions import FileError
raise FileError(self.name, hint=get_streerror(e))
raise FileError(self.name, hint=e.strerror) from e
self._f = rv self._f = rv
return rv return rv
def close(self): def close(self) -> None:
"""Closes the underlying file, no matter what.""" """Closes the underlying file, no matter what."""
if self._f is not None: if self._f is not None:
self._f.close() self._f.close()
def close_intelligently(self): def close_intelligently(self) -> None:
"""This function only closes the file if it was opened by the lazy """This function only closes the file if it was opened by the lazy
file wrapper. For instance this will never close stdin. file wrapper. For instance this will never close stdin.
""" """
if self.should_close: if self.should_close:
self.close() self.close()
def __enter__(self): def __enter__(self) -> "LazyFile":
return self return self
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb): # type: ignore
self.close_intelligently() self.close_intelligently()
def __iter__(self): def __iter__(self) -> t.Iterator[t.AnyStr]:
self.open() self.open()
return iter(self._f) return iter(self._f) # type: ignore
class KeepOpenFile(object): class KeepOpenFile:
def __init__(self, file: t.IO) -> None:
def __init__(self, file):
self._file = file self._file = file
def __getattr__(self, name): def __getattr__(self, name: str) -> t.Any:
return getattr(self._file, name) return getattr(self._file, name)
def __enter__(self): def __enter__(self) -> "KeepOpenFile":
return self return self
def __exit__(self, exc_type, exc_value, tb): def __exit__(self, exc_type, exc_value, tb): # type: ignore
pass pass
def __repr__(self): def __repr__(self) -> str:
return repr(self._file) return repr(self._file)
def __iter__(self): def __iter__(self) -> t.Iterator[t.AnyStr]:
return iter(self._file) return iter(self._file)
def echo(message=None, file=None, nl=True, err=False, color=None): def echo(
"""Prints a message plus a newline to the given file or stdout. On message: t.Optional[t.Any] = None,
first sight, this looks like the print function, but it has improved file: t.Optional[t.IO[t.Any]] = None,
support for handling Unicode and binary data that does not fail no nl: bool = True,
matter how badly configured the system is. err: bool = False,
color: t.Optional[bool] = None,
) -> None:
"""Print a message and newline to stdout or a file. This should be
used instead of :func:`print` because it provides better support
for different data, files, and environments.
Primarily it means that you can print binary data as well as Unicode Compared to :func:`print`, this does the following:
data on both 2.x and 3.x to the given file in the most appropriate way
possible. This is a very carefree function in that it will try its
best to not fail. As of Click 6.0 this includes support for unicode
output on the Windows console.
In addition to that, if `colorama`_ is installed, the echo function will - Ensures that the output encoding is not misconfigured on Linux.
also support clever handling of ANSI codes. Essentially it will then - Supports Unicode in the Windows console.
do the following: - Supports writing to binary outputs, and supports writing bytes
to text outputs.
- Supports colors and styles on Windows.
- Removes ANSI color and style codes if the output does not look
like an interactive terminal.
- Always flushes the output.
- add transparent handling of ANSI color codes on Windows. :param message: The string or bytes to output. Other objects are
- hide ANSI codes automatically if the destination file is not a converted to strings.
terminal. :param file: The file to write to. Defaults to ``stdout``.
:param err: Write to ``stderr`` instead of ``stdout``.
.. _colorama: https://pypi.org/project/colorama/ :param nl: Print a newline after the message. Enabled by default.
:param color: Force showing or hiding colors and other styles. By
default Click will remove color if the output does not look like
an interactive terminal.
.. versionchanged:: 6.0 .. versionchanged:: 6.0
As of Click 6.0 the echo function will properly support unicode Support Unicode output on the Windows console. Click does not
output on the windows console. Not that click does not modify modify ``sys.stdout``, so ``sys.stdout.write()`` and ``print()``
the interpreter in any way which means that `sys.stdout` or the will still not support Unicode.
print statement or function will still not provide unicode support.
.. versionchanged:: 2.0
Starting with version 2.0 of Click, the echo function will work
with colorama if it's installed.
.. versionadded:: 3.0
The `err` parameter was added.
.. versionchanged:: 4.0 .. versionchanged:: 4.0
Added the `color` flag. Added the ``color`` parameter.
:param message: the message to print .. versionadded:: 3.0
:param file: the file to write to (defaults to ``stdout``) Added the ``err`` parameter.
:param err: if set to true the file defaults to ``stderr`` instead of
``stdout``. This is faster and easier than calling .. versionchanged:: 2.0
:func:`get_text_stderr` yourself. Support colors on Windows if colorama is installed.
:param nl: if set to `True` (the default) a newline is printed afterwards.
:param color: controls if the terminal supports ANSI colors or not. The
default is autodetection.
""" """
if file is None: if file is None:
if err: if err:
@ -218,70 +254,73 @@ def echo(message=None, file=None, nl=True, err=False, color=None):
file = _default_text_stdout() file = _default_text_stdout()
# Convert non bytes/text into the native string type. # Convert non bytes/text into the native string type.
if message is not None and not isinstance(message, echo_native_types): if message is not None and not isinstance(message, (str, bytes, bytearray)):
message = text_type(message) out: t.Optional[t.Union[str, bytes]] = str(message)
else:
out = message
if nl: if nl:
message = message or u'' out = out or ""
if isinstance(message, text_type): if isinstance(out, str):
message += u'\n' out += "\n"
else: else:
message += b'\n' out += b"\n"
# If there is a message, and we're in Python 3, and the value looks if not out:
# like bytes, we manually need to find the binary stream and write the file.flush()
# message in there. This is done separately so that most stream return
# types will work as you would expect. Eg: you can write to StringIO
# for other cases. # If there is a message and the value looks like bytes, we manually
if message and not PY2 and is_bytes(message): # need to find the binary stream and write the message in there.
# This is done separately so that most stream types will work as you
# would expect. Eg: you can write to StringIO for other cases.
if isinstance(out, (bytes, bytearray)):
binary_file = _find_binary_writer(file) binary_file = _find_binary_writer(file)
if binary_file is not None: if binary_file is not None:
file.flush() file.flush()
binary_file.write(message) binary_file.write(out)
binary_file.flush() binary_file.flush()
return return
# ANSI-style support. If there is no message or we are dealing with # ANSI style code support. For no message or bytes, nothing happens.
# bytes nothing is happening. If we are connected to a file we want # When outputting to a file instead of a terminal, strip codes.
# to strip colors. If we are on windows we either wrap the stream else:
# to strip the color or we use the colorama support to translate the
# ansi codes to API calls.
if message and not is_bytes(message):
color = resolve_color_default(color) color = resolve_color_default(color)
if should_strip_ansi(file, color): if should_strip_ansi(file, color):
message = strip_ansi(message) out = strip_ansi(out)
elif WIN: elif WIN:
if auto_wrap_for_ansi is not None: if auto_wrap_for_ansi is not None:
file = auto_wrap_for_ansi(file) file = auto_wrap_for_ansi(file) # type: ignore
elif not color: elif not color:
message = strip_ansi(message) out = strip_ansi(out)
if message: file.write(out) # type: ignore
file.write(message)
file.flush() file.flush()
def get_binary_stream(name): def get_binary_stream(name: "te.Literal['stdin', 'stdout', 'stderr']") -> t.BinaryIO:
"""Returns a system stream for byte processing. This essentially """Returns a system stream for byte processing.
returns the stream from the sys module with the given name but it
solves some compatibility issues between different Python versions.
Primarily this function is necessary for getting binary streams on
Python 3.
:param name: the name of the stream to open. Valid names are ``'stdin'``, :param name: the name of the stream to open. Valid names are ``'stdin'``,
``'stdout'`` and ``'stderr'`` ``'stdout'`` and ``'stderr'``
""" """
opener = binary_streams.get(name) opener = binary_streams.get(name)
if opener is None: if opener is None:
raise TypeError('Unknown standard stream %r' % name) raise TypeError(f"Unknown standard stream '{name}'")
return opener() return opener()
def get_text_stream(name, encoding=None, errors='strict'): def get_text_stream(
name: "te.Literal['stdin', 'stdout', 'stderr']",
encoding: t.Optional[str] = None,
errors: t.Optional[str] = "strict",
) -> t.TextIO:
"""Returns a system stream for text processing. This usually returns """Returns a system stream for text processing. This usually returns
a wrapped stream around a binary stream returned from a wrapped stream around a binary stream returned from
:func:`get_binary_stream` but it also can take shortcuts on Python 3 :func:`get_binary_stream` but it also can take shortcuts for already
for already correctly configured streams. correctly configured streams.
:param name: the name of the stream to open. Valid names are ``'stdin'``, :param name: the name of the stream to open. Valid names are ``'stdin'``,
``'stdout'`` and ``'stderr'`` ``'stdout'`` and ``'stderr'``
@ -290,65 +329,60 @@ def get_text_stream(name, encoding=None, errors='strict'):
""" """
opener = text_streams.get(name) opener = text_streams.get(name)
if opener is None: if opener is None:
raise TypeError('Unknown standard stream %r' % name) raise TypeError(f"Unknown standard stream '{name}'")
return opener(encoding, errors) return opener(encoding, errors)
def open_file(filename, mode='r', encoding=None, errors='strict', def open_file(
lazy=False, atomic=False): filename: str,
"""This is similar to how the :class:`File` works but for manual mode: str = "r",
usage. Files are opened non lazy by default. This can open regular encoding: t.Optional[str] = None,
files as well as stdin/stdout if ``'-'`` is passed. errors: t.Optional[str] = "strict",
lazy: bool = False,
atomic: bool = False,
) -> t.IO:
"""Open a file, with extra behavior to handle ``'-'`` to indicate
a standard stream, lazy open on write, and atomic write. Similar to
the behavior of the :class:`~click.File` param type.
If stdin/stdout is returned the stream is wrapped so that the context If ``'-'`` is given to open ``stdout`` or ``stdin``, the stream is
manager will not close the stream accidentally. This makes it possible wrapped so that using it in a context manager will not close it.
to always use the function like this without having to worry to This makes it possible to use the function without accidentally
accidentally close a standard stream:: closing a standard stream:
.. code-block:: python
with open_file(filename) as f: with open_file(filename) as f:
... ...
.. versionadded:: 3.0 :param filename: The name of the file to open, or ``'-'`` for
``stdin``/``stdout``.
:param mode: The mode in which to open the file.
:param encoding: The encoding to decode or encode a file opened in
text mode.
:param errors: The error handling mode.
:param lazy: Wait to open the file until it is accessed. For read
mode, the file is temporarily opened to raise access errors
early, then closed until it is read again.
:param atomic: Write to a temporary file and replace the given file
on close.
:param filename: the name of the file to open (or ``'-'`` for stdin/stdout). .. versionadded:: 3.0
:param mode: the mode in which to open the file.
:param encoding: the encoding to use.
:param errors: the error handling for this file.
:param lazy: can be flipped to true to open the file lazily.
:param atomic: in atomic mode writes go into a temporary file and it's
moved on close.
""" """
if lazy: if lazy:
return LazyFile(filename, mode, encoding, errors, atomic=atomic) return t.cast(t.IO, LazyFile(filename, mode, encoding, errors, atomic=atomic))
f, should_close = open_stream(filename, mode, encoding, errors,
atomic=atomic) f, should_close = open_stream(filename, mode, encoding, errors, atomic=atomic)
if not should_close: if not should_close:
f = KeepOpenFile(f) f = t.cast(t.IO, KeepOpenFile(f))
return f return f
def get_os_args(): def format_filename(
"""This returns the argument part of sys.argv in the most appropriate filename: t.Union[str, bytes, os.PathLike], shorten: bool = False
form for processing. What this means is that this return value is in ) -> str:
a format that works for Click to process but does not necessarily
correspond well to what's actually standard for the interpreter.
On most environments the return value is ``sys.argv[:1]`` unchanged.
However if you are on Windows and running Python 2 the return value
will actually be a list of unicode strings instead because the
default behavior on that platform otherwise will not be able to
carry all possible values that sys.argv can have.
.. versionadded:: 6.0
"""
# We can only extract the unicode argv if sys.argv has not been
# changed since the startup of the application.
if PY2 and WIN and _initial_argv_hash == _hash_py_argv():
return _get_windows_argv()
return sys.argv[1:]
def format_filename(filename, shorten=False):
"""Formats a filename for user display. The main purpose of this """Formats a filename for user display. The main purpose of this
function is to ensure that the filename can be displayed at all. This function is to ensure that the filename can be displayed at all. This
will decode the filename to unicode if necessary in a way that it will will decode the filename to unicode if necessary in a way that it will
@ -362,10 +396,11 @@ def format_filename(filename, shorten=False):
""" """
if shorten: if shorten:
filename = os.path.basename(filename) filename = os.path.basename(filename)
return filename_to_ui(filename)
return os.fsdecode(filename)
def get_app_dir(app_name, roaming=True, force_posix=False): def get_app_dir(app_name: str, roaming: bool = True, force_posix: bool = False) -> str:
r"""Returns the config folder for the application. The default behavior r"""Returns the config folder for the application. The default behavior
is to return whatever is most appropriate for the operating system. is to return whatever is most appropriate for the operating system.
@ -380,13 +415,9 @@ def get_app_dir(app_name, roaming=True, force_posix=False):
``~/.config/foo-bar`` ``~/.config/foo-bar``
Unix (POSIX): Unix (POSIX):
``~/.foo-bar`` ``~/.foo-bar``
Win XP (roaming): Windows (roaming):
``C:\Documents and Settings\<user>\Local Settings\Application Data\Foo Bar``
Win XP (not roaming):
``C:\Documents and Settings\<user>\Application Data\Foo Bar``
Win 7 (roaming):
``C:\Users\<user>\AppData\Roaming\Foo Bar`` ``C:\Users\<user>\AppData\Roaming\Foo Bar``
Win 7 (not roaming): Windows (not roaming):
``C:\Users\<user>\AppData\Local\Foo Bar`` ``C:\Users\<user>\AppData\Local\Foo Bar``
.. versionadded:: 2.0 .. versionadded:: 2.0
@ -401,22 +432,24 @@ def get_app_dir(app_name, roaming=True, force_posix=False):
application support folder. application support folder.
""" """
if WIN: if WIN:
key = roaming and 'APPDATA' or 'LOCALAPPDATA' key = "APPDATA" if roaming else "LOCALAPPDATA"
folder = os.environ.get(key) folder = os.environ.get(key)
if folder is None: if folder is None:
folder = os.path.expanduser('~') folder = os.path.expanduser("~")
return os.path.join(folder, app_name) return os.path.join(folder, app_name)
if force_posix: if force_posix:
return os.path.join(os.path.expanduser('~/.' + _posixify(app_name))) return os.path.join(os.path.expanduser(f"~/.{_posixify(app_name)}"))
if sys.platform == 'darwin': if sys.platform == "darwin":
return os.path.join(os.path.expanduser( return os.path.join(
'~/Library/Application Support'), app_name) os.path.expanduser("~/Library/Application Support"), app_name
)
return os.path.join( return os.path.join(
os.environ.get('XDG_CONFIG_HOME', os.path.expanduser('~/.config')), os.environ.get("XDG_CONFIG_HOME", os.path.expanduser("~/.config")),
_posixify(app_name)) _posixify(app_name),
)
class PacifyFlushWrapper(object): class PacifyFlushWrapper:
"""This wrapper is used to catch and suppress BrokenPipeErrors resulting """This wrapper is used to catch and suppress BrokenPipeErrors resulting
from ``.flush()`` being called on broken pipe during the shutdown/final-GC from ``.flush()`` being called on broken pipe during the shutdown/final-GC
of the Python interpreter. Notably ``.flush()`` is always called on of the Python interpreter. Notably ``.flush()`` is always called on
@ -425,16 +458,123 @@ class PacifyFlushWrapper(object):
pipe, all calls and attributes are proxied. pipe, all calls and attributes are proxied.
""" """
def __init__(self, wrapped): def __init__(self, wrapped: t.IO) -> None:
self.wrapped = wrapped self.wrapped = wrapped
def flush(self): def flush(self) -> None:
try: try:
self.wrapped.flush() self.wrapped.flush()
except IOError as e: except OSError as e:
import errno import errno
if e.errno != errno.EPIPE: if e.errno != errno.EPIPE:
raise raise
def __getattr__(self, attr): def __getattr__(self, attr: str) -> t.Any:
return getattr(self.wrapped, attr) return getattr(self.wrapped, attr)
def _detect_program_name(
path: t.Optional[str] = None, _main: t.Optional[ModuleType] = None
) -> str:
"""Determine the command used to run the program, for use in help
text. If a file or entry point was executed, the file name is
returned. If ``python -m`` was used to execute a module or package,
``python -m name`` is returned.
This doesn't try to be too precise, the goal is to give a concise
name for help text. Files are only shown as their name without the
path. ``python`` is only shown for modules, and the full path to
``sys.executable`` is not shown.
:param path: The Python file being executed. Python puts this in
``sys.argv[0]``, which is used by default.
:param _main: The ``__main__`` module. This should only be passed
during internal testing.
.. versionadded:: 8.0
Based on command args detection in the Werkzeug reloader.
:meta private:
"""
if _main is None:
_main = sys.modules["__main__"]
if not path:
path = sys.argv[0]
# The value of __package__ indicates how Python was called. It may
# not exist if a setuptools script is installed as an egg. It may be
# set incorrectly for entry points created with pip on Windows.
if getattr(_main, "__package__", None) is None or (
os.name == "nt"
and _main.__package__ == ""
and not os.path.exists(path)
and os.path.exists(f"{path}.exe")
):
# Executed a file, like "python app.py".
return os.path.basename(path)
# Executed a module, like "python -m example".
# Rewritten by Python from "-m script" to "/path/to/script.py".
# Need to look at main module to determine how it was executed.
py_module = t.cast(str, _main.__package__)
name = os.path.splitext(os.path.basename(path))[0]
# A submodule like "example.cli".
if name != "__main__":
py_module = f"{py_module}.{name}"
return f"python -m {py_module.lstrip('.')}"
def _expand_args(
args: t.Iterable[str],
*,
user: bool = True,
env: bool = True,
glob_recursive: bool = True,
) -> t.List[str]:
"""Simulate Unix shell expansion with Python functions.
See :func:`glob.glob`, :func:`os.path.expanduser`, and
:func:`os.path.expandvars`.
This is intended for use on Windows, where the shell does not do any
expansion. It may not exactly match what a Unix shell would do.
:param args: List of command line arguments to expand.
:param user: Expand user home directory.
:param env: Expand environment variables.
:param glob_recursive: ``**`` matches directories recursively.
.. versionchanged:: 8.1
Invalid glob patterns are treated as empty expansions rather
than raising an error.
.. versionadded:: 8.0
:meta private:
"""
from glob import glob
out = []
for arg in args:
if user:
arg = os.path.expanduser(arg)
if env:
arg = os.path.expandvars(arg)
try:
matches = glob(arg, recursive=glob_recursive)
except re.error:
matches = []
if not matches:
out.append(arg)
else:
out.extend(matches)
return out

File diff suppressed because it is too large Load diff

View file

@ -1,2 +0,0 @@
"""Project version"""
__version__ = '5.1.0'

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
# ######################### LICENSE ############################ # # ######################### LICENSE ############################ #
# Copyright (c) 2005-2018, Michele Simionato # Copyright (c) 2005-2021, Michele Simionato
# All rights reserved. # All rights reserved.
# Redistribution and use in source and binary forms, with or without # Redistribution and use in source and binary forms, with or without
@ -28,49 +28,26 @@
# DAMAGE. # DAMAGE.
""" """
Decorator module, see http://pypi.python.org/pypi/decorator Decorator module, see
https://github.com/micheles/decorator/blob/master/docs/documentation.md
for the documentation. for the documentation.
""" """
from __future__ import print_function
import re import re
import sys import sys
import inspect import inspect
import operator import operator
import itertools import itertools
import collections from contextlib import _GeneratorContextManager
from inspect import getfullargspec, iscoroutinefunction, isgeneratorfunction
__version__ = '4.3.0'
if sys.version >= '3':
from inspect import getfullargspec
def get_init(cls):
return cls.__init__
else:
FullArgSpec = collections.namedtuple(
'FullArgSpec', 'args varargs varkw defaults '
'kwonlyargs kwonlydefaults annotations')
def getfullargspec(f):
"A quick and dirty replacement for getfullargspec for Python 2.X"
return FullArgSpec._make(inspect.getargspec(f) + ([], None, {}))
def get_init(cls):
return cls.__init__.__func__
try:
iscoroutinefunction = inspect.iscoroutinefunction
except AttributeError:
# let's assume there are no coroutine functions in old Python
def iscoroutinefunction(f):
return False
__version__ = '5.1.1'
DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(') DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(')
POS = inspect.Parameter.POSITIONAL_OR_KEYWORD
EMPTY = inspect.Parameter.empty
# basic functionality # this is not used anymore in the core, but kept for backward compatibility
class FunctionMaker(object): class FunctionMaker(object):
""" """
An object with the ability to create functions with a given signature. An object with the ability to create functions with a given signature.
@ -94,7 +71,7 @@ class FunctionMaker(object):
self.name = '_lambda_' self.name = '_lambda_'
self.doc = func.__doc__ self.doc = func.__doc__
self.module = func.__module__ self.module = func.__module__
if inspect.isfunction(func): if inspect.isroutine(func):
argspec = getfullargspec(func) argspec = getfullargspec(func)
self.annotations = getattr(func, '__annotations__', {}) self.annotations = getattr(func, '__annotations__', {})
for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs',
@ -137,7 +114,9 @@ class FunctionMaker(object):
raise TypeError('You are decorating a non function: %s' % func) raise TypeError('You are decorating a non function: %s' % func)
def update(self, func, **kw): def update(self, func, **kw):
"Update the signature of func with the data in self" """
Update the signature of func with the data in self
"""
func.__name__ = self.name func.__name__ = self.name
func.__doc__ = getattr(self, 'doc', None) func.__doc__ = getattr(self, 'doc', None)
func.__dict__ = getattr(self, 'dict', {}) func.__dict__ = getattr(self, 'dict', {})
@ -154,7 +133,9 @@ class FunctionMaker(object):
func.__dict__.update(kw) func.__dict__.update(kw)
def make(self, src_templ, evaldict=None, addsource=False, **attrs): def make(self, src_templ, evaldict=None, addsource=False, **attrs):
"Make a new function from a given template and update the signature" """
Make a new function from a given template and update the signature
"""
src = src_templ % vars(self) # expand name and signature src = src_templ % vars(self) # expand name and signature
evaldict = evaldict or {} evaldict = evaldict or {}
mo = DEF.search(src) mo = DEF.search(src)
@ -173,7 +154,7 @@ class FunctionMaker(object):
# Ensure each generated function has a unique filename for profilers # Ensure each generated function has a unique filename for profilers
# (such as cProfile) that depend on the tuple of (<filename>, # (such as cProfile) that depend on the tuple of (<filename>,
# <definition line>, <function name>) being unique. # <definition line>, <function name>) being unique.
filename = '<decorator-gen-%d>' % (next(self._compile_count),) filename = '<decorator-gen-%d>' % next(self._compile_count)
try: try:
code = compile(src, filename, 'single') code = compile(src, filename, 'single')
exec(code, evaldict) exec(code, evaldict)
@ -215,90 +196,128 @@ class FunctionMaker(object):
return self.make(body, evaldict, addsource, **attrs) return self.make(body, evaldict, addsource, **attrs)
def decorate(func, caller, extras=()): def fix(args, kwargs, sig):
""" """
decorate(func, caller) decorates a function using a caller. Fix args and kwargs to be consistent with the signature
""" """
evaldict = dict(_call_=caller, _func_=func) ba = sig.bind(*args, **kwargs)
es = '' ba.apply_defaults() # needed for test_dan_schult
for i, extra in enumerate(extras): return ba.args, ba.kwargs
ex = '_e%d_' % i
evaldict[ex] = extra
es += ex + ', ' def decorate(func, caller, extras=(), kwsyntax=False):
fun = FunctionMaker.create( """
func, "return _call_(_func_, %s%%(shortsignature)s)" % es, Decorates a function/generator/coroutine using a caller.
evaldict, __wrapped__=func) If kwsyntax is True calling the decorated functions with keyword
if hasattr(func, '__qualname__'): syntax will pass the named arguments inside the ``kw`` dictionary,
fun.__qualname__ = func.__qualname__ even if such argument are positional, similarly to what functools.wraps
does. By default kwsyntax is False and the the arguments are untouched.
"""
sig = inspect.signature(func)
if iscoroutinefunction(caller):
async def fun(*args, **kw):
if not kwsyntax:
args, kw = fix(args, kw, sig)
return await caller(func, *(extras + args), **kw)
elif isgeneratorfunction(caller):
def fun(*args, **kw):
if not kwsyntax:
args, kw = fix(args, kw, sig)
for res in caller(func, *(extras + args), **kw):
yield res
else:
def fun(*args, **kw):
if not kwsyntax:
args, kw = fix(args, kw, sig)
return caller(func, *(extras + args), **kw)
fun.__name__ = func.__name__
fun.__doc__ = func.__doc__
fun.__wrapped__ = func
fun.__signature__ = sig
fun.__qualname__ = func.__qualname__
# builtin functions like defaultdict.__setitem__ lack many attributes
try:
fun.__defaults__ = func.__defaults__
except AttributeError:
pass
try:
fun.__kwdefaults__ = func.__kwdefaults__
except AttributeError:
pass
try:
fun.__annotations__ = func.__annotations__
except AttributeError:
pass
try:
fun.__module__ = func.__module__
except AttributeError:
pass
try:
fun.__dict__.update(func.__dict__)
except AttributeError:
pass
return fun return fun
def decorator(caller, _func=None): def decoratorx(caller):
"""decorator(caller) converts a caller function into a decorator""" """
A version of "decorator" implemented via "exec" and not via the
Signature object. Use this if you are want to preserve the `.__code__`
object properties (https://github.com/micheles/decorator/issues/129).
"""
def dec(func):
return FunctionMaker.create(
func,
"return _call_(_func_, %(shortsignature)s)",
dict(_call_=caller, _func_=func),
__wrapped__=func, __qualname__=func.__qualname__)
return dec
def decorator(caller, _func=None, kwsyntax=False):
"""
decorator(caller) converts a caller function into a decorator
"""
if _func is not None: # return a decorated function if _func is not None: # return a decorated function
# this is obsolete behavior; you should use decorate instead # this is obsolete behavior; you should use decorate instead
return decorate(_func, caller) return decorate(_func, caller, (), kwsyntax)
# else return a decorator function # else return a decorator function
defaultargs, defaults = '', () sig = inspect.signature(caller)
if inspect.isclass(caller): dec_params = [p for p in sig.parameters.values() if p.kind is POS]
name = caller.__name__.lower()
doc = 'decorator(%s) converts functions/generators into ' \ def dec(func=None, *args, **kw):
'factories of %s objects' % (caller.__name__, caller.__name__) na = len(args) + 1
elif inspect.isfunction(caller): extras = args + tuple(kw.get(p.name, p.default)
if caller.__name__ == '<lambda>': for p in dec_params[na:]
name = '_lambda_' if p.default is not EMPTY)
if func is None:
return lambda func: decorate(func, caller, extras, kwsyntax)
else: else:
name = caller.__name__ return decorate(func, caller, extras, kwsyntax)
doc = caller.__doc__ dec.__signature__ = sig.replace(parameters=dec_params)
nargs = caller.__code__.co_argcount dec.__name__ = caller.__name__
ndefs = len(caller.__defaults__ or ()) dec.__doc__ = caller.__doc__
defaultargs = ', '.join(caller.__code__.co_varnames[nargs-ndefs:nargs]) dec.__wrapped__ = caller
if defaultargs: dec.__qualname__ = caller.__qualname__
defaultargs += ',' dec.__kwdefaults__ = getattr(caller, '__kwdefaults__', None)
defaults = caller.__defaults__ dec.__dict__.update(caller.__dict__)
else: # assume caller is an object with a __call__ method
name = caller.__class__.__name__.lower()
doc = caller.__call__.__doc__
evaldict = dict(_call=caller, _decorate_=decorate)
dec = FunctionMaker.create(
'%s(%s func)' % (name, defaultargs),
'if func is None: return lambda func: _decorate_(func, _call, (%s))\n'
'return _decorate_(func, _call, (%s))' % (defaultargs, defaultargs),
evaldict, doc=doc, module=caller.__module__, __wrapped__=caller)
if defaults:
dec.__defaults__ = defaults + (None,)
return dec return dec
# ####################### contextmanager ####################### # # ####################### contextmanager ####################### #
try: # Python >= 3.2
from contextlib import _GeneratorContextManager
except ImportError: # Python >= 2.5
from contextlib import GeneratorContextManager as _GeneratorContextManager
class ContextManager(_GeneratorContextManager): class ContextManager(_GeneratorContextManager):
def __init__(self, g, *a, **k):
_GeneratorContextManager.__init__(self, g, a, k)
def __call__(self, func): def __call__(self, func):
"""Context manager decorator""" def caller(f, *a, **k):
return FunctionMaker.create( with self.__class__(self.func, *self.args, **self.kwds):
func, "with _self_: return _func_(%(shortsignature)s)", return f(*a, **k)
dict(_self_=self, _func_=func), __wrapped__=func) return decorate(func, caller)
init = getfullargspec(_GeneratorContextManager.__init__)
n_args = len(init.args)
if n_args == 2 and not init.varargs: # (self, genobj) Python 2.7
def __init__(self, g, *a, **k):
return _GeneratorContextManager.__init__(self, g(*a, **k))
ContextManager.__init__ = __init__
elif n_args == 2 and init.varargs: # (self, gen, *a, **k) Python 3.4
pass
elif n_args == 4: # (self, gen, args, kwds) Python 3.5
def __init__(self, g, *a, **k):
return _GeneratorContextManager.__init__(self, g, a, k)
ContextManager.__init__ = __init__
_contextmanager = decorator(ContextManager) _contextmanager = decorator(ContextManager)

View file

@ -1,4 +1,4 @@
__version__ = '0.7.1' __version__ = "1.1.8"
from .lock import Lock # noqa from .lock import Lock # noqa
from .lock import NeedRegenerationException # noqa from .lock import NeedRegenerationException # noqa

View file

@ -1,4 +1,6 @@
from .region import CacheRegion, register_backend, make_region # noqa from .region import CacheRegion # noqa
from .region import make_region # noqa
from .region import register_backend # noqa
from .. import __version__ # noqa
# backwards compat # backwards compat
from .. import __version__ # noqa

View file

@ -1,14 +1,22 @@
import operator import abc
from ..util.compat import py3k import pickle
from typing import Any
from typing import Callable
from typing import cast
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Union
class NoValue(object): class NoValue:
"""Describe a missing cache value. """Describe a missing cache value.
The :attr:`.NO_VALUE` module global The :data:`.NO_VALUE` constant should be used.
should be used.
""" """
@property @property
def payload(self): def payload(self):
return self return self
@ -18,49 +26,125 @@ class NoValue(object):
fill another cache key. fill another cache key.
""" """
return '<dogpile.cache.api.NoValue object>' return "<dogpile.cache.api.NoValue object>"
if py3k: def __bool__(self): # pragma NO COVERAGE
def __bool__(self): # pragma NO COVERAGE return False
return False
else:
def __nonzero__(self): # pragma NO COVERAGE
return False
NO_VALUE = NoValue() NO_VALUE = NoValue()
"""Value returned from ``get()`` that describes """Value returned from ``get()`` that describes
a key not present.""" a key not present."""
MetaDataType = Mapping[str, Any]
class CachedValue(tuple):
KeyType = str
"""A cache key."""
ValuePayload = Any
"""An object to be placed in the cache against a key."""
KeyManglerType = Callable[[KeyType], KeyType]
Serializer = Callable[[ValuePayload], bytes]
Deserializer = Callable[[bytes], ValuePayload]
class CacheMutex(abc.ABC):
"""Describes a mutexing object with acquire and release methods.
This is an abstract base class; any object that has acquire/release
methods may be used.
.. versionadded:: 1.1
.. seealso::
:meth:`.CacheBackend.get_mutex` - the backend method that optionally
returns this locking object.
"""
@abc.abstractmethod
def acquire(self, wait: bool = True) -> bool:
"""Acquire the mutex.
:param wait: if True, block until available, else return True/False
immediately.
:return: True if the lock succeeded.
"""
raise NotImplementedError()
@abc.abstractmethod
def release(self) -> None:
"""Release the mutex."""
raise NotImplementedError()
@abc.abstractmethod
def locked(self) -> bool:
"""Check if the mutex was acquired.
:return: true if the lock is acquired.
.. versionadded:: 1.1.2
"""
raise NotImplementedError()
@classmethod
def __subclasshook__(cls, C):
return hasattr(C, "acquire") and hasattr(C, "release")
class CachedValue(NamedTuple):
"""Represent a value stored in the cache. """Represent a value stored in the cache.
:class:`.CachedValue` is a two-tuple of :class:`.CachedValue` is a two-tuple of
``(payload, metadata)``, where ``metadata`` ``(payload, metadata)``, where ``metadata``
is dogpile.cache's tracking information ( is dogpile.cache's tracking information (
currently the creation time). The metadata currently the creation time).
and tuple structure is pickleable, if
the backend requires serialization.
""" """
payload = property(operator.itemgetter(0))
"""Named accessor for the payload."""
metadata = property(operator.itemgetter(1)) payload: ValuePayload
"""Named accessor for the dogpile.cache metadata dictionary."""
def __new__(cls, payload, metadata): metadata: MetaDataType
return tuple.__new__(cls, (payload, metadata))
def __reduce__(self):
return CachedValue, (self.payload, self.metadata)
class CacheBackend(object): CacheReturnType = Union[CachedValue, NoValue]
"""Base class for backend implementations.""" """The non-serialized form of what may be returned from a backend
get method.
key_mangler = None """
SerializedReturnType = Union[bytes, NoValue]
"""the serialized form of what may be returned from a backend get method."""
BackendFormatted = Union[CacheReturnType, SerializedReturnType]
"""Describes the type returned from the :meth:`.CacheBackend.get` method."""
BackendSetType = Union[CachedValue, bytes]
"""Describes the value argument passed to the :meth:`.CacheBackend.set`
method."""
BackendArguments = Mapping[str, Any]
class CacheBackend:
"""Base class for backend implementations.
Backends which set and get Python object values should subclass this
backend. For backends in which the value that's stored is ultimately
a stream of bytes, the :class:`.BytesBackend` should be used.
"""
key_mangler: Optional[Callable[[KeyType], KeyType]] = None
"""Key mangling function. """Key mangling function.
May be None, or otherwise declared May be None, or otherwise declared
@ -68,7 +152,23 @@ class CacheBackend(object):
""" """
def __init__(self, arguments): # pragma NO COVERAGE serializer: Union[None, Serializer] = None
"""Serializer function that will be used by default if not overridden
by the region.
.. versionadded:: 1.1
"""
deserializer: Union[None, Deserializer] = None
"""deserializer function that will be used by default if not overridden
by the region.
.. versionadded:: 1.1
"""
def __init__(self, arguments: BackendArguments): # pragma NO COVERAGE
"""Construct a new :class:`.CacheBackend`. """Construct a new :class:`.CacheBackend`.
Subclasses should override this to Subclasses should override this to
@ -91,10 +191,10 @@ class CacheBackend(object):
) )
) )
def has_lock_timeout(self): def has_lock_timeout(self) -> bool:
return False return False
def get_mutex(self, key): def get_mutex(self, key: KeyType) -> Optional[CacheMutex]:
"""Return an optional mutexing object for the given key. """Return an optional mutexing object for the given key.
This object need only provide an ``acquire()`` This object need only provide an ``acquire()``
@ -127,48 +227,141 @@ class CacheBackend(object):
""" """
return None return None
def get(self, key): # pragma NO COVERAGE def get(self, key: KeyType) -> BackendFormatted: # pragma NO COVERAGE
"""Retrieve a value from the cache. """Retrieve an optionally serialized value from the cache.
The returned value should be an instance of :param key: String key that was passed to the :meth:`.CacheRegion.get`
:class:`.CachedValue`, or ``NO_VALUE`` if method, which will also be processed by the "key mangling" function
not present. if one was present.
:return: the Python object that corresponds to
what was established via the :meth:`.CacheBackend.set` method,
or the :data:`.NO_VALUE` constant if not present.
If a serializer is in use, this method will only be called if the
:meth:`.CacheBackend.get_serialized` method is not overridden.
""" """
raise NotImplementedError() raise NotImplementedError()
def get_multi(self, keys): # pragma NO COVERAGE def get_multi(
"""Retrieve multiple values from the cache. self, keys: Sequence[KeyType]
) -> Sequence[BackendFormatted]: # pragma NO COVERAGE
"""Retrieve multiple optionally serialized values from the cache.
The returned value should be a list, corresponding :param keys: sequence of string keys that was passed to the
to the list of keys given. :meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.
:return a list of values as would be returned
individually via the :meth:`.CacheBackend.get` method, corresponding
to the list of keys given.
If a serializer is in use, this method will only be called if the
:meth:`.CacheBackend.get_serialized_multi` method is not overridden.
.. versionadded:: 0.5.0 .. versionadded:: 0.5.0
""" """
raise NotImplementedError() raise NotImplementedError()
def set(self, key, value): # pragma NO COVERAGE def get_serialized(self, key: KeyType) -> SerializedReturnType:
"""Set a value in the cache. """Retrieve a serialized value from the cache.
The key will be whatever was passed :param key: String key that was passed to the :meth:`.CacheRegion.get`
to the registry, processed by the method, which will also be processed by the "key mangling" function
"key mangling" function, if any. if one was present.
The value will always be an instance
of :class:`.CachedValue`. :return: a bytes object, or :data:`.NO_VALUE`
constant if not present.
The default implementation of this method for :class:`.CacheBackend`
returns the value of the :meth:`.CacheBackend.get` method.
.. versionadded:: 1.1
.. seealso::
:class:`.BytesBackend`
"""
return cast(SerializedReturnType, self.get(key))
def get_serialized_multi(
self, keys: Sequence[KeyType]
) -> Sequence[SerializedReturnType]: # pragma NO COVERAGE
"""Retrieve multiple serialized values from the cache.
:param keys: sequence of string keys that was passed to the
:meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.
:return: list of bytes objects
The default implementation of this method for :class:`.CacheBackend`
returns the value of the :meth:`.CacheBackend.get_multi` method.
.. versionadded:: 1.1
.. seealso::
:class:`.BytesBackend`
"""
return cast(Sequence[SerializedReturnType], self.get_multi(keys))
def set(
self, key: KeyType, value: BackendSetType
) -> None: # pragma NO COVERAGE
"""Set an optionally serialized value in the cache.
:param key: String key that was passed to the :meth:`.CacheRegion.set`
method, which will also be processed by the "key mangling" function
if one was present.
:param value: The optionally serialized :class:`.CachedValue` object.
May be an instance of :class:`.CachedValue` or a bytes object
depending on if a serializer is in use with the region and if the
:meth:`.CacheBackend.set_serialized` method is not overridden.
.. seealso::
:meth:`.CacheBackend.set_serialized`
""" """
raise NotImplementedError() raise NotImplementedError()
def set_multi(self, mapping): # pragma NO COVERAGE def set_serialized(
self, key: KeyType, value: bytes
) -> None: # pragma NO COVERAGE
"""Set a serialized value in the cache.
:param key: String key that was passed to the :meth:`.CacheRegion.set`
method, which will also be processed by the "key mangling" function
if one was present.
:param value: a bytes object to be stored.
The default implementation of this method for :class:`.CacheBackend`
calls upon the :meth:`.CacheBackend.set` method.
.. versionadded:: 1.1
.. seealso::
:class:`.BytesBackend`
"""
self.set(key, value)
def set_multi(
self, mapping: Mapping[KeyType, BackendSetType]
) -> None: # pragma NO COVERAGE
"""Set multiple values in the cache. """Set multiple values in the cache.
``mapping`` is a dict in which :param mapping: a dict in which the key will be whatever was passed to
the key will be whatever was passed the :meth:`.CacheRegion.set_multi` method, processed by the "key
to the registry, processed by the mangling" function, if any.
"key mangling" function, if any.
The value will always be an instance
of :class:`.CachedValue`.
When implementing a new :class:`.CacheBackend` or cutomizing via When implementing a new :class:`.CacheBackend` or cutomizing via
:class:`.ProxyBackend`, be aware that when this method is invoked by :class:`.ProxyBackend`, be aware that when this method is invoked by
@ -178,17 +371,52 @@ class CacheBackend(object):
-- that will have the undesirable effect of modifying the returned -- that will have the undesirable effect of modifying the returned
values as well. values as well.
If a serializer is in use, this method will only be called if the
:meth:`.CacheBackend.set_serialized_multi` method is not overridden.
.. versionadded:: 0.5.0 .. versionadded:: 0.5.0
""" """
raise NotImplementedError() raise NotImplementedError()
def delete(self, key): # pragma NO COVERAGE def set_serialized_multi(
self, mapping: Mapping[KeyType, bytes]
) -> None: # pragma NO COVERAGE
"""Set multiple serialized values in the cache.
:param mapping: a dict in which the key will be whatever was passed to
the :meth:`.CacheRegion.set_multi` method, processed by the "key
mangling" function, if any.
When implementing a new :class:`.CacheBackend` or cutomizing via
:class:`.ProxyBackend`, be aware that when this method is invoked by
:meth:`.Region.get_or_create_multi`, the ``mapping`` values are the
same ones returned to the upstream caller. If the subclass alters the
values in any way, it must not do so 'in-place' on the ``mapping`` dict
-- that will have the undesirable effect of modifying the returned
values as well.
.. versionadded:: 1.1
The default implementation of this method for :class:`.CacheBackend`
calls upon the :meth:`.CacheBackend.set_multi` method.
.. seealso::
:class:`.BytesBackend`
"""
self.set_multi(mapping)
def delete(self, key: KeyType) -> None: # pragma NO COVERAGE
"""Delete a value from the cache. """Delete a value from the cache.
The key will be whatever was passed :param key: String key that was passed to the
to the registry, processed by the :meth:`.CacheRegion.delete`
"key mangling" function, if any. method, which will also be processed by the "key mangling" function
if one was present.
The behavior here should be idempotent, The behavior here should be idempotent,
that is, can be called any number of times that is, can be called any number of times
@ -197,12 +425,14 @@ class CacheBackend(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def delete_multi(self, keys): # pragma NO COVERAGE def delete_multi(
self, keys: Sequence[KeyType]
) -> None: # pragma NO COVERAGE
"""Delete multiple values from the cache. """Delete multiple values from the cache.
The key will be whatever was passed :param keys: sequence of string keys that was passed to the
to the registry, processed by the :meth:`.CacheRegion.delete_multi` method, which will also be processed
"key mangling" function, if any. by the "key mangling" function if one was present.
The behavior here should be idempotent, The behavior here should be idempotent,
that is, can be called any number of times that is, can be called any number of times
@ -213,3 +443,95 @@ class CacheBackend(object):
""" """
raise NotImplementedError() raise NotImplementedError()
class DefaultSerialization:
serializer: Union[None, Serializer] = staticmethod( # type: ignore
pickle.dumps
)
deserializer: Union[None, Deserializer] = staticmethod( # type: ignore
pickle.loads
)
class BytesBackend(DefaultSerialization, CacheBackend):
"""A cache backend that receives and returns series of bytes.
This backend only supports the "serialized" form of values; subclasses
should implement :meth:`.BytesBackend.get_serialized`,
:meth:`.BytesBackend.get_serialized_multi`,
:meth:`.BytesBackend.set_serialized`,
:meth:`.BytesBackend.set_serialized_multi`.
.. versionadded:: 1.1
"""
def get_serialized(self, key: KeyType) -> SerializedReturnType:
"""Retrieve a serialized value from the cache.
:param key: String key that was passed to the :meth:`.CacheRegion.get`
method, which will also be processed by the "key mangling" function
if one was present.
:return: a bytes object, or :data:`.NO_VALUE`
constant if not present.
.. versionadded:: 1.1
"""
raise NotImplementedError()
def get_serialized_multi(
self, keys: Sequence[KeyType]
) -> Sequence[SerializedReturnType]: # pragma NO COVERAGE
"""Retrieve multiple serialized values from the cache.
:param keys: sequence of string keys that was passed to the
:meth:`.CacheRegion.get_multi` method, which will also be processed
by the "key mangling" function if one was present.
:return: list of bytes objects
.. versionadded:: 1.1
"""
raise NotImplementedError()
def set_serialized(
self, key: KeyType, value: bytes
) -> None: # pragma NO COVERAGE
"""Set a serialized value in the cache.
:param key: String key that was passed to the :meth:`.CacheRegion.set`
method, which will also be processed by the "key mangling" function
if one was present.
:param value: a bytes object to be stored.
.. versionadded:: 1.1
"""
raise NotImplementedError()
def set_serialized_multi(
self, mapping: Mapping[KeyType, bytes]
) -> None: # pragma NO COVERAGE
"""Set multiple serialized values in the cache.
:param mapping: a dict in which the key will be whatever was passed to
the :meth:`.CacheRegion.set_multi` method, processed by the "key
mangling" function, if any.
When implementing a new :class:`.CacheBackend` or cutomizing via
:class:`.ProxyBackend`, be aware that when this method is invoked by
:meth:`.Region.get_or_create_multi`, the ``mapping`` values are the
same ones returned to the upstream caller. If the subclass alters the
values in any way, it must not do so 'in-place' on the ``mapping`` dict
-- that will have the undesirable effect of modifying the returned
values as well.
.. versionadded:: 1.1
"""
raise NotImplementedError()

View file

@ -1,22 +1,47 @@
from dogpile.cache.region import register_backend from ...util import PluginLoader
_backend_loader = PluginLoader("dogpile.cache")
register_backend = _backend_loader.register
register_backend( register_backend(
"dogpile.cache.null", "dogpile.cache.backends.null", "NullBackend") "dogpile.cache.null", "dogpile.cache.backends.null", "NullBackend"
)
register_backend( register_backend(
"dogpile.cache.dbm", "dogpile.cache.backends.file", "DBMBackend") "dogpile.cache.dbm", "dogpile.cache.backends.file", "DBMBackend"
)
register_backend( register_backend(
"dogpile.cache.pylibmc", "dogpile.cache.backends.memcached", "dogpile.cache.pylibmc",
"PylibmcBackend") "dogpile.cache.backends.memcached",
"PylibmcBackend",
)
register_backend( register_backend(
"dogpile.cache.bmemcached", "dogpile.cache.backends.memcached", "dogpile.cache.bmemcached",
"BMemcachedBackend") "dogpile.cache.backends.memcached",
"BMemcachedBackend",
)
register_backend( register_backend(
"dogpile.cache.memcached", "dogpile.cache.backends.memcached", "dogpile.cache.memcached",
"MemcachedBackend") "dogpile.cache.backends.memcached",
"MemcachedBackend",
)
register_backend( register_backend(
"dogpile.cache.memory", "dogpile.cache.backends.memory", "MemoryBackend") "dogpile.cache.pymemcache",
"dogpile.cache.backends.memcached",
"PyMemcacheBackend",
)
register_backend( register_backend(
"dogpile.cache.memory_pickle", "dogpile.cache.backends.memory", "dogpile.cache.memory", "dogpile.cache.backends.memory", "MemoryBackend"
"MemoryPickleBackend") )
register_backend( register_backend(
"dogpile.cache.redis", "dogpile.cache.backends.redis", "RedisBackend") "dogpile.cache.memory_pickle",
"dogpile.cache.backends.memory",
"MemoryPickleBackend",
)
register_backend(
"dogpile.cache.redis", "dogpile.cache.backends.redis", "RedisBackend"
)
register_backend(
"dogpile.cache.redis_sentinel",
"dogpile.cache.backends.redis",
"RedisSentinelBackend",
)

View file

@ -7,16 +7,20 @@ Provides backends that deal with local filesystem access.
""" """
from __future__ import with_statement from __future__ import with_statement
from ..api import CacheBackend, NO_VALUE
from contextlib import contextmanager from contextlib import contextmanager
from ...util import compat import dbm
from ... import util
import os import os
import threading
__all__ = 'DBMBackend', 'FileLock', 'AbstractFileLock' from ..api import BytesBackend
from ..api import NO_VALUE
from ... import util
__all__ = ["DBMBackend", "FileLock", "AbstractFileLock"]
class DBMBackend(CacheBackend): class DBMBackend(BytesBackend):
"""A file-backend using a dbm file to store keys. """A file-backend using a dbm file to store keys.
Basic usage:: Basic usage::
@ -134,28 +138,25 @@ class DBMBackend(CacheBackend):
""" """
def __init__(self, arguments): def __init__(self, arguments):
self.filename = os.path.abspath( self.filename = os.path.abspath(
os.path.normpath(arguments['filename']) os.path.normpath(arguments["filename"])
) )
dir_, filename = os.path.split(self.filename) dir_, filename = os.path.split(self.filename)
self.lock_factory = arguments.get("lock_factory", FileLock) self.lock_factory = arguments.get("lock_factory", FileLock)
self._rw_lock = self._init_lock( self._rw_lock = self._init_lock(
arguments.get('rw_lockfile'), arguments.get("rw_lockfile"), ".rw.lock", dir_, filename
".rw.lock", dir_, filename) )
self._dogpile_lock = self._init_lock( self._dogpile_lock = self._init_lock(
arguments.get('dogpile_lockfile'), arguments.get("dogpile_lockfile"),
".dogpile.lock", ".dogpile.lock",
dir_, filename, dir_,
util.KeyReentrantMutex.factory) filename,
util.KeyReentrantMutex.factory,
)
# TODO: make this configurable
if compat.py3k:
import dbm
else:
import anydbm as dbm
self.dbmmodule = dbm
self._init_dbm_file() self._init_dbm_file()
def _init_lock(self, argument, suffix, basedir, basefile, wrapper=None): def _init_lock(self, argument, suffix, basedir, basefile, wrapper=None):
@ -163,9 +164,8 @@ class DBMBackend(CacheBackend):
lock = self.lock_factory(os.path.join(basedir, basefile + suffix)) lock = self.lock_factory(os.path.join(basedir, basefile + suffix))
elif argument is not False: elif argument is not False:
lock = self.lock_factory( lock = self.lock_factory(
os.path.abspath( os.path.abspath(os.path.normpath(argument))
os.path.normpath(argument) )
))
else: else:
return None return None
if wrapper: if wrapper:
@ -175,12 +175,12 @@ class DBMBackend(CacheBackend):
def _init_dbm_file(self): def _init_dbm_file(self):
exists = os.access(self.filename, os.F_OK) exists = os.access(self.filename, os.F_OK)
if not exists: if not exists:
for ext in ('db', 'dat', 'pag', 'dir'): for ext in ("db", "dat", "pag", "dir"):
if os.access(self.filename + os.extsep + ext, os.F_OK): if os.access(self.filename + os.extsep + ext, os.F_OK):
exists = True exists = True
break break
if not exists: if not exists:
fh = self.dbmmodule.open(self.filename, 'c') fh = dbm.open(self.filename, "c")
fh.close() fh.close()
def get_mutex(self, key): def get_mutex(self, key):
@ -210,57 +210,50 @@ class DBMBackend(CacheBackend):
@contextmanager @contextmanager
def _dbm_file(self, write): def _dbm_file(self, write):
with self._use_rw_lock(write): with self._use_rw_lock(write):
dbm = self.dbmmodule.open( with dbm.open(self.filename, "w" if write else "r") as dbm_obj:
self.filename, yield dbm_obj
"w" if write else "r")
yield dbm
dbm.close()
def get(self, key): def get_serialized(self, key):
with self._dbm_file(False) as dbm: with self._dbm_file(False) as dbm_obj:
if hasattr(dbm, 'get'): if hasattr(dbm_obj, "get"):
value = dbm.get(key, NO_VALUE) value = dbm_obj.get(key, NO_VALUE)
else: else:
# gdbm objects lack a .get method # gdbm objects lack a .get method
try: try:
value = dbm[key] value = dbm_obj[key]
except KeyError: except KeyError:
value = NO_VALUE value = NO_VALUE
if value is not NO_VALUE:
value = compat.pickle.loads(value)
return value return value
def get_multi(self, keys): def get_serialized_multi(self, keys):
return [self.get(key) for key in keys] return [self.get_serialized(key) for key in keys]
def set(self, key, value): def set_serialized(self, key, value):
with self._dbm_file(True) as dbm: with self._dbm_file(True) as dbm_obj:
dbm[key] = compat.pickle.dumps(value, dbm_obj[key] = value
compat.pickle.HIGHEST_PROTOCOL)
def set_multi(self, mapping): def set_serialized_multi(self, mapping):
with self._dbm_file(True) as dbm: with self._dbm_file(True) as dbm_obj:
for key, value in mapping.items(): for key, value in mapping.items():
dbm[key] = compat.pickle.dumps(value, dbm_obj[key] = value
compat.pickle.HIGHEST_PROTOCOL)
def delete(self, key): def delete(self, key):
with self._dbm_file(True) as dbm: with self._dbm_file(True) as dbm_obj:
try: try:
del dbm[key] del dbm_obj[key]
except KeyError: except KeyError:
pass pass
def delete_multi(self, keys): def delete_multi(self, keys):
with self._dbm_file(True) as dbm: with self._dbm_file(True) as dbm_obj:
for key in keys: for key in keys:
try: try:
del dbm[key] del dbm_obj[key]
except KeyError: except KeyError:
pass pass
class AbstractFileLock(object): class AbstractFileLock:
"""Coordinate read/write access to a file. """Coordinate read/write access to a file.
typically is a file-based lock but doesn't necessarily have to be. typically is a file-based lock but doesn't necessarily have to be.
@ -392,17 +385,18 @@ class FileLock(AbstractFileLock):
""" """
def __init__(self, filename): def __init__(self, filename):
self._filedescriptor = compat.threading.local() self._filedescriptor = threading.local()
self.filename = filename self.filename = filename
@util.memoized_property @util.memoized_property
def _module(self): def _module(self):
import fcntl import fcntl
return fcntl return fcntl
@property @property
def is_open(self): def is_open(self):
return hasattr(self._filedescriptor, 'fileno') return hasattr(self._filedescriptor, "fileno")
def acquire_read_lock(self, wait): def acquire_read_lock(self, wait):
return self._acquire(wait, os.O_RDONLY, self._module.LOCK_SH) return self._acquire(wait, os.O_RDONLY, self._module.LOCK_SH)

View file

@ -6,23 +6,43 @@ Provides backends for talking to `memcached <http://memcached.org>`_.
""" """
from ..api import CacheBackend, NO_VALUE
from ...util import compat
from ... import util
import random import random
import threading
import time import time
import typing
from typing import Any
from typing import Mapping
import warnings
__all__ = 'GenericMemcachedBackend', 'MemcachedBackend',\ from ..api import CacheBackend
'PylibmcBackend', 'BMemcachedBackend', 'MemcachedLock' from ..api import NO_VALUE
from ... import util
if typing.TYPE_CHECKING:
import bmemcached
import memcache
import pylibmc
import pymemcache
else:
# delayed import
bmemcached = None # noqa F811
memcache = None # noqa F811
pylibmc = None # noqa F811
pymemcache = None # noqa F811
__all__ = (
"GenericMemcachedBackend",
"MemcachedBackend",
"PylibmcBackend",
"PyMemcacheBackend",
"BMemcachedBackend",
"MemcachedLock",
)
class MemcachedLock(object): class MemcachedLock(object):
"""Simple distributed lock using memcached. """Simple distributed lock using memcached."""
This is an adaptation of the lock featured at
http://amix.dk/blog/post/19386
"""
def __init__(self, client_fn, key, timeout=0): def __init__(self, client_fn, key, timeout=0):
self.client_fn = client_fn self.client_fn = client_fn
@ -38,11 +58,15 @@ class MemcachedLock(object):
elif not wait: elif not wait:
return False return False
else: else:
sleep_time = (((i + 1) * random.random()) + 2 ** i) / 2.5 sleep_time = (((i + 1) * random.random()) + 2**i) / 2.5
time.sleep(sleep_time) time.sleep(sleep_time)
if i < 15: if i < 15:
i += 1 i += 1
def locked(self):
client = self.client_fn()
return client.get(self.key) is not None
def release(self): def release(self):
client = self.client_fn() client = self.client_fn()
client.delete(self.key) client.delete(self.key)
@ -100,10 +124,17 @@ class GenericMemcachedBackend(CacheBackend):
""" """
set_arguments = {} set_arguments: Mapping[str, Any] = {}
"""Additional arguments which will be passed """Additional arguments which will be passed
to the :meth:`set` method.""" to the :meth:`set` method."""
# No need to override serializer, as all the memcached libraries
# handles that themselves. Still, we support customizing the
# serializer/deserializer to use better default pickle protocol
# or completely different serialization mechanism
serializer = None
deserializer = None
def __init__(self, arguments): def __init__(self, arguments):
self._imports() self._imports()
# using a plain threading.local here. threading.local # using a plain threading.local here. threading.local
@ -111,11 +142,10 @@ class GenericMemcachedBackend(CacheBackend):
# so the idea is that this is superior to pylibmc's # so the idea is that this is superior to pylibmc's
# own ThreadMappedPool which doesn't handle this # own ThreadMappedPool which doesn't handle this
# automatically. # automatically.
self.url = util.to_list(arguments['url']) self.url = util.to_list(arguments["url"])
self.distributed_lock = arguments.get('distributed_lock', False) self.distributed_lock = arguments.get("distributed_lock", False)
self.lock_timeout = arguments.get('lock_timeout', 0) self.lock_timeout = arguments.get("lock_timeout", 0)
self.memcached_expire_time = arguments.get( self.memcached_expire_time = arguments.get("memcached_expire_time", 0)
'memcached_expire_time', 0)
def has_lock_timeout(self): def has_lock_timeout(self):
return self.lock_timeout != 0 return self.lock_timeout != 0
@ -132,7 +162,7 @@ class GenericMemcachedBackend(CacheBackend):
def _clients(self): def _clients(self):
backend = self backend = self
class ClientPool(compat.threading.local): class ClientPool(threading.local):
def __init__(self): def __init__(self):
self.memcached = backend._create_client() self.memcached = backend._create_client()
@ -152,8 +182,9 @@ class GenericMemcachedBackend(CacheBackend):
def get_mutex(self, key): def get_mutex(self, key):
if self.distributed_lock: if self.distributed_lock:
return MemcachedLock(lambda: self.client, key, return MemcachedLock(
timeout=self.lock_timeout) lambda: self.client, key, timeout=self.lock_timeout
)
else: else:
return None return None
@ -166,23 +197,18 @@ class GenericMemcachedBackend(CacheBackend):
def get_multi(self, keys): def get_multi(self, keys):
values = self.client.get_multi(keys) values = self.client.get_multi(keys)
return [ return [
NO_VALUE if key not in values NO_VALUE if val is None else val
else values[key] for key in keys for val in [values.get(key, NO_VALUE) for key in keys]
] ]
def set(self, key, value): def set(self, key, value):
self.client.set( self.client.set(key, value, **self.set_arguments)
key,
value,
**self.set_arguments
)
def set_multi(self, mapping): def set_multi(self, mapping):
self.client.set_multi( mapping = {key: value for key, value in mapping.items()}
mapping, self.client.set_multi(mapping, **self.set_arguments)
**self.set_arguments
)
def delete(self, key): def delete(self, key):
self.client.delete(key) self.client.delete(key)
@ -191,24 +217,23 @@ class GenericMemcachedBackend(CacheBackend):
self.client.delete_multi(keys) self.client.delete_multi(keys)
class MemcacheArgs(object): class MemcacheArgs(GenericMemcachedBackend):
"""Mixin which provides support for the 'time' argument to set(), """Mixin which provides support for the 'time' argument to set(),
'min_compress_len' to other methods. 'min_compress_len' to other methods.
""" """
def __init__(self, arguments): def __init__(self, arguments):
self.min_compress_len = arguments.get('min_compress_len', 0) self.min_compress_len = arguments.get("min_compress_len", 0)
self.set_arguments = {} self.set_arguments = {}
if "memcached_expire_time" in arguments: if "memcached_expire_time" in arguments:
self.set_arguments["time"] = arguments["memcached_expire_time"] self.set_arguments["time"] = arguments["memcached_expire_time"]
if "min_compress_len" in arguments: if "min_compress_len" in arguments:
self.set_arguments["min_compress_len"] = \ self.set_arguments["min_compress_len"] = arguments[
arguments["min_compress_len"] "min_compress_len"
]
super(MemcacheArgs, self).__init__(arguments) super(MemcacheArgs, self).__init__(arguments)
pylibmc = None
class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend): class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend):
"""A backend for the """A backend for the
@ -245,8 +270,8 @@ class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend):
""" """
def __init__(self, arguments): def __init__(self, arguments):
self.binary = arguments.get('binary', False) self.binary = arguments.get("binary", False)
self.behaviors = arguments.get('behaviors', {}) self.behaviors = arguments.get("behaviors", {})
super(PylibmcBackend, self).__init__(arguments) super(PylibmcBackend, self).__init__(arguments)
def _imports(self): def _imports(self):
@ -255,13 +280,9 @@ class PylibmcBackend(MemcacheArgs, GenericMemcachedBackend):
def _create_client(self): def _create_client(self):
return pylibmc.Client( return pylibmc.Client(
self.url, self.url, binary=self.binary, behaviors=self.behaviors
binary=self.binary,
behaviors=self.behaviors
) )
memcache = None
class MemcachedBackend(MemcacheArgs, GenericMemcachedBackend): class MemcachedBackend(MemcacheArgs, GenericMemcachedBackend):
"""A backend using the standard """A backend using the standard
@ -281,16 +302,39 @@ class MemcachedBackend(MemcacheArgs, GenericMemcachedBackend):
} }
) )
:param dead_retry: Number of seconds memcached server is considered dead
before it is tried again. Will be passed to ``memcache.Client``
as the ``dead_retry`` parameter.
.. versionchanged:: 1.1.8 Moved the ``dead_retry`` argument which was
erroneously added to "set_parameters" to
be part of the Memcached connection arguments.
:param socket_timeout: Timeout in seconds for every call to a server.
Will be passed to ``memcache.Client`` as the ``socket_timeout``
parameter.
.. versionchanged:: 1.1.8 Moved the ``socket_timeout`` argument which
was erroneously added to "set_parameters"
to be part of the Memcached connection arguments.
""" """
def __init__(self, arguments):
self.dead_retry = arguments.get("dead_retry", 30)
self.socket_timeout = arguments.get("socket_timeout", 3)
super(MemcachedBackend, self).__init__(arguments)
def _imports(self): def _imports(self):
global memcache global memcache
import memcache # noqa import memcache # noqa
def _create_client(self): def _create_client(self):
return memcache.Client(self.url) return memcache.Client(
self.url,
dead_retry=self.dead_retry,
bmemcached = None socket_timeout=self.socket_timeout,
)
class BMemcachedBackend(GenericMemcachedBackend): class BMemcachedBackend(GenericMemcachedBackend):
@ -299,9 +343,11 @@ class BMemcachedBackend(GenericMemcachedBackend):
python-binary-memcached>`_ python-binary-memcached>`_
memcached client. memcached client.
This is a pure Python memcached client which This is a pure Python memcached client which includes
includes the ability to authenticate with a memcached security features like SASL and SSL/TLS.
server using SASL.
SASL is a standard for adding authentication mechanisms
to protocols in a way that is protocol independent.
A typical configuration using username/password:: A typical configuration using username/password::
@ -317,6 +363,25 @@ class BMemcachedBackend(GenericMemcachedBackend):
} }
) )
A typical configuration using tls_context::
import ssl
from dogpile.cache import make_region
ctx = ssl.create_default_context(cafile="/path/to/my-ca.pem")
region = make_region().configure(
'dogpile.cache.bmemcached',
expiration_time = 3600,
arguments = {
'url':["127.0.0.1"],
'tls_context':ctx,
}
)
For advanced ways to configure TLS creating a more complex
tls_context visit https://docs.python.org/3/library/ssl.html
Arguments which can be passed to the ``arguments`` Arguments which can be passed to the ``arguments``
dictionary include: dictionary include:
@ -324,11 +389,17 @@ class BMemcachedBackend(GenericMemcachedBackend):
SASL authentication. SASL authentication.
:param password: optional password, will be used for :param password: optional password, will be used for
SASL authentication. SASL authentication.
:param tls_context: optional TLS context, will be used for
TLS connections.
.. versionadded:: 1.0.2
""" """
def __init__(self, arguments): def __init__(self, arguments):
self.username = arguments.get('username', None) self.username = arguments.get("username", None)
self.password = arguments.get('password', None) self.password = arguments.get("password", None)
self.tls_context = arguments.get("tls_context", None)
super(BMemcachedBackend, self).__init__(arguments) super(BMemcachedBackend, self).__init__(arguments)
def _imports(self): def _imports(self):
@ -345,7 +416,8 @@ class BMemcachedBackend(GenericMemcachedBackend):
def add(self, key, value, timeout=0): def add(self, key, value, timeout=0):
try: try:
return super(RepairBMemcachedAPI, self).add( return super(RepairBMemcachedAPI, self).add(
key, value, timeout) key, value, timeout
)
except ValueError: except ValueError:
return False return False
@ -355,10 +427,213 @@ class BMemcachedBackend(GenericMemcachedBackend):
return self.Client( return self.Client(
self.url, self.url,
username=self.username, username=self.username,
password=self.password password=self.password,
tls_context=self.tls_context,
) )
def delete_multi(self, keys): def delete_multi(self, keys):
"""python-binary-memcached api does not implements delete_multi""" """python-binary-memcached api does not implements delete_multi"""
for key in keys: for key in keys:
self.delete(key) self.delete(key)
class PyMemcacheBackend(GenericMemcachedBackend):
"""A backend for the
`pymemcache <https://github.com/pinterest/pymemcache>`_
memcached client.
A comprehensive, fast, pure Python memcached client
.. versionadded:: 1.1.2
pymemcache supports the following features:
* Complete implementation of the memcached text protocol.
* Configurable timeouts for socket connect and send/recv calls.
* Access to the "noreply" flag, which can significantly increase
the speed of writes.
* Flexible, simple approach to serialization and deserialization.
* The (optional) ability to treat network and memcached errors as
cache misses.
dogpile.cache uses the ``HashClient`` from pymemcache in order to reduce
API differences when compared to other memcached client drivers.
This allows the user to provide a single server or a list of memcached
servers.
Arguments which can be passed to the ``arguments``
dictionary include:
:param tls_context: optional TLS context, will be used for
TLS connections.
A typical configuration using tls_context::
import ssl
from dogpile.cache import make_region
ctx = ssl.create_default_context(cafile="/path/to/my-ca.pem")
region = make_region().configure(
'dogpile.cache.pymemcache',
expiration_time = 3600,
arguments = {
'url':["127.0.0.1"],
'tls_context':ctx,
}
)
.. seealso::
`<https://docs.python.org/3/library/ssl.html>`_ - additional TLS
documentation.
:param serde: optional "serde". Defaults to
``pymemcache.serde.pickle_serde``.
:param default_noreply: defaults to False. When set to True this flag
enables the pymemcache "noreply" feature. See the pymemcache
documentation for further details.
:param socket_keepalive: optional socket keepalive, will be used for
TCP keepalive configuration. Use of this parameter requires pymemcache
3.5.0 or greater. This parameter
accepts a
`pymemcache.client.base.KeepAliveOpts
<https://pymemcache.readthedocs.io/en/latest/apidoc/pymemcache.client.base.html#pymemcache.client.base.KeepaliveOpts>`_
object.
A typical configuration using ``socket_keepalive``::
from pymemcache import KeepaliveOpts
from dogpile.cache import make_region
# Using the default keepalive configuration
socket_keepalive = KeepaliveOpts()
region = make_region().configure(
'dogpile.cache.pymemcache',
expiration_time = 3600,
arguments = {
'url':["127.0.0.1"],
'socket_keepalive': socket_keepalive
}
)
.. versionadded:: 1.1.4 - added support for ``socket_keepalive``.
:param enable_retry_client: optional flag to enable retry client
mechanisms to handle failure. Defaults to False. When set to ``True``,
the :paramref:`.PyMemcacheBackend.retry_attempts` parameter must also
be set, along with optional parameters
:paramref:`.PyMemcacheBackend.retry_delay`.
:paramref:`.PyMemcacheBackend.retry_for`,
:paramref:`.PyMemcacheBackend.do_not_retry_for`.
.. seealso::
`<https://pymemcache.readthedocs.io/en/latest/getting_started.html#using-the-built-in-retrying-mechanism>`_ -
in the pymemcache documentation
.. versionadded:: 1.1.4
:param retry_attempts: how many times to attempt an action with
pymemcache's retrying wrapper before failing. Must be 1 or above.
Defaults to None.
.. versionadded:: 1.1.4
:param retry_delay: optional int|float, how many seconds to sleep between
each attempt. Used by the retry wrapper. Defaults to None.
.. versionadded:: 1.1.4
:param retry_for: optional None|tuple|set|list, what exceptions to
allow retries for. Will allow retries for all exceptions if None.
Example: ``(MemcacheClientError, MemcacheUnexpectedCloseError)``
Accepts any class that is a subclass of Exception. Defaults to None.
.. versionadded:: 1.1.4
:param do_not_retry_for: optional None|tuple|set|list, what
exceptions should be retried. Will not block retries for any Exception if
None. Example: ``(IOError, MemcacheIllegalInputError)``
Accepts any class that is a subclass of Exception. Defaults to None.
.. versionadded:: 1.1.4
:param hashclient_retry_attempts: Amount of times a client should be tried
before it is marked dead and removed from the pool in the HashClient's
internal mechanisms.
.. versionadded:: 1.1.5
:param hashclient_retry_timeout: Time in seconds that should pass between
retry attempts in the HashClient's internal mechanisms.
.. versionadded:: 1.1.5
:param dead_timeout: Time in seconds before attempting to add a node
back in the pool in the HashClient's internal mechanisms.
.. versionadded:: 1.1.5
""" # noqa E501
def __init__(self, arguments):
super().__init__(arguments)
self.serde = arguments.get("serde", pymemcache.serde.pickle_serde)
self.default_noreply = arguments.get("default_noreply", False)
self.tls_context = arguments.get("tls_context", None)
self.socket_keepalive = arguments.get("socket_keepalive", None)
self.enable_retry_client = arguments.get("enable_retry_client", False)
self.retry_attempts = arguments.get("retry_attempts", None)
self.retry_delay = arguments.get("retry_delay", None)
self.retry_for = arguments.get("retry_for", None)
self.do_not_retry_for = arguments.get("do_not_retry_for", None)
self.hashclient_retry_attempts = arguments.get(
"hashclient_retry_attempts", 2
)
self.hashclient_retry_timeout = arguments.get(
"hashclient_retry_timeout", 1
)
self.dead_timeout = arguments.get("hashclient_dead_timeout", 60)
if (
self.retry_delay is not None
or self.retry_attempts is not None
or self.retry_for is not None
or self.do_not_retry_for is not None
) and not self.enable_retry_client:
warnings.warn(
"enable_retry_client is not set; retry options "
"will be ignored"
)
def _imports(self):
global pymemcache
import pymemcache
def _create_client(self):
_kwargs = {
"serde": self.serde,
"default_noreply": self.default_noreply,
"tls_context": self.tls_context,
"retry_attempts": self.hashclient_retry_attempts,
"retry_timeout": self.hashclient_retry_timeout,
"dead_timeout": self.dead_timeout,
}
if self.socket_keepalive is not None:
_kwargs.update({"socket_keepalive": self.socket_keepalive})
client = pymemcache.client.hash.HashClient(self.url, **_kwargs)
if self.enable_retry_client:
return pymemcache.client.retrying.RetryingClient(
client,
attempts=self.retry_attempts,
retry_delay=self.retry_delay,
retry_for=self.retry_for,
do_not_retry_for=self.do_not_retry_for,
)
return client

View file

@ -10,8 +10,10 @@ places the value as given into the dictionary.
""" """
from ..api import CacheBackend, NO_VALUE
from ...util.compat import pickle from ..api import CacheBackend
from ..api import DefaultSerialization
from ..api import NO_VALUE
class MemoryBackend(CacheBackend): class MemoryBackend(CacheBackend):
@ -47,39 +49,21 @@ class MemoryBackend(CacheBackend):
""" """
pickle_values = False
def __init__(self, arguments): def __init__(self, arguments):
self._cache = arguments.pop("cache_dict", {}) self._cache = arguments.pop("cache_dict", {})
def get(self, key): def get(self, key):
value = self._cache.get(key, NO_VALUE) return self._cache.get(key, NO_VALUE)
if value is not NO_VALUE and self.pickle_values:
value = pickle.loads(value)
return value
def get_multi(self, keys): def get_multi(self, keys):
ret = [ return [self._cache.get(key, NO_VALUE) for key in keys]
self._cache.get(key, NO_VALUE)
for key in keys]
if self.pickle_values:
ret = [
pickle.loads(value)
if value is not NO_VALUE else value
for value in ret
]
return ret
def set(self, key, value): def set(self, key, value):
if self.pickle_values:
value = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
self._cache[key] = value self._cache[key] = value
def set_multi(self, mapping): def set_multi(self, mapping):
pickle_values = self.pickle_values
for key, value in mapping.items(): for key, value in mapping.items():
if pickle_values:
value = pickle.dumps(value, pickle.HIGHEST_PROTOCOL)
self._cache[key] = value self._cache[key] = value
def delete(self, key): def delete(self, key):
@ -90,7 +74,7 @@ class MemoryBackend(CacheBackend):
self._cache.pop(key, None) self._cache.pop(key, None)
class MemoryPickleBackend(MemoryBackend): class MemoryPickleBackend(DefaultSerialization, MemoryBackend):
"""A backend that uses a plain dictionary, but serializes objects on """A backend that uses a plain dictionary, but serializes objects on
:meth:`.MemoryBackend.set` and deserializes :meth:`.MemoryBackend.get`. :meth:`.MemoryBackend.set` and deserializes :meth:`.MemoryBackend.get`.
@ -121,4 +105,3 @@ class MemoryPickleBackend(MemoryBackend):
.. versionadded:: 0.5.3 .. versionadded:: 0.5.3
""" """
pickle_values = True

View file

@ -10,10 +10,11 @@ caching for a region that is otherwise used normally.
""" """
from ..api import CacheBackend, NO_VALUE from ..api import CacheBackend
from ..api import NO_VALUE
__all__ = ['NullBackend'] __all__ = ["NullBackend"]
class NullLock(object): class NullLock(object):
@ -23,6 +24,9 @@ class NullLock(object):
def release(self): def release(self):
pass pass
def locked(self):
return False
class NullBackend(CacheBackend): class NullBackend(CacheBackend):
"""A "null" backend that effectively disables all cache operations. """A "null" backend that effectively disables all cache operations.

View file

@ -7,16 +7,24 @@ Provides backends for talking to `Redis <http://redis.io>`_.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from ..api import CacheBackend, NO_VALUE
from ...util.compat import pickle, u
redis = None import typing
import warnings
__all__ = 'RedisBackend', from ..api import BytesBackend
from ..api import NO_VALUE
if typing.TYPE_CHECKING:
import redis
else:
# delayed import
redis = None # noqa F811
__all__ = ("RedisBackend", "RedisSentinelBackend")
class RedisBackend(CacheBackend): class RedisBackend(BytesBackend):
"""A `Redis <http://redis.io/>`_ backend, using the r"""A `Redis <http://redis.io/>`_ backend, using the
`redis-py <http://pypi.python.org/pypi/redis/>`_ backend. `redis-py <http://pypi.python.org/pypi/redis/>`_ backend.
Example configuration:: Example configuration::
@ -30,23 +38,21 @@ class RedisBackend(CacheBackend):
'port': 6379, 'port': 6379,
'db': 0, 'db': 0,
'redis_expiration_time': 60*60*2, # 2 hours 'redis_expiration_time': 60*60*2, # 2 hours
'distributed_lock': True 'distributed_lock': True,
'thread_local_lock': False
} }
) )
Arguments accepted in the arguments dictionary: Arguments accepted in the arguments dictionary:
:param url: string. If provided, will override separate host/port/db :param url: string. If provided, will override separate host/port/db
params. The format is that accepted by ``StrictRedis.from_url()``. params. The format is that accepted by ``StrictRedis.from_url()``.
.. versionadded:: 0.4.1
:param host: string, default is ``localhost``. :param host: string, default is ``localhost``.
:param password: string, default is no password. :param password: string, default is no password.
.. versionadded:: 0.4.1
:param port: integer, default is ``6379``. :param port: integer, default is ``6379``.
:param db: integer, default is ``0``. :param db: integer, default is ``0``.
@ -56,57 +62,66 @@ class RedisBackend(CacheBackend):
cache expiration. By default no expiration is set. cache expiration. By default no expiration is set.
:param distributed_lock: boolean, when True, will use a :param distributed_lock: boolean, when True, will use a
redis-lock as the dogpile lock. redis-lock as the dogpile lock. Use this when multiple processes will be
Use this when multiple talking to the same redis instance. When left at False, dogpile will
processes will be talking to the same redis instance. coordinate on a regular threading mutex.
When left at False, dogpile will coordinate on a regular
threading mutex.
:param lock_timeout: integer, number of seconds after acquiring a lock that :param lock_timeout: integer, number of seconds after acquiring a lock that
Redis should expire it. This argument is only valid when Redis should expire it. This argument is only valid when
``distributed_lock`` is ``True``. ``distributed_lock`` is ``True``.
.. versionadded:: 0.5.0
:param socket_timeout: float, seconds for socket timeout. :param socket_timeout: float, seconds for socket timeout.
Default is None (no timeout). Default is None (no timeout).
.. versionadded:: 0.5.4
:param lock_sleep: integer, number of seconds to sleep when failed to :param lock_sleep: integer, number of seconds to sleep when failed to
acquire a lock. This argument is only valid when acquire a lock. This argument is only valid when
``distributed_lock`` is ``True``. ``distributed_lock`` is ``True``.
.. versionadded:: 0.5.0
:param connection_pool: ``redis.ConnectionPool`` object. If provided, :param connection_pool: ``redis.ConnectionPool`` object. If provided,
this object supersedes other connection arguments passed to the this object supersedes other connection arguments passed to the
``redis.StrictRedis`` instance, including url and/or host as well as ``redis.StrictRedis`` instance, including url and/or host as well as
socket_timeout, and will be passed to ``redis.StrictRedis`` as the socket_timeout, and will be passed to ``redis.StrictRedis`` as the
source of connectivity. source of connectivity.
.. versionadded:: 0.5.4 :param thread_local_lock: bool, whether a thread-local Redis lock object
should be used. This is the default, but is not compatible with
asynchronous runners, as they run in a different thread than the one
used to create the lock.
:param connection_kwargs: dict, additional keyword arguments are passed
along to the
``StrictRedis.from_url()`` method or ``StrictRedis()`` constructor
directly, including parameters like ``ssl``, ``ssl_certfile``,
``charset``, etc.
.. versionadded:: 1.1.6 Added ``connection_kwargs`` parameter.
""" """
def __init__(self, arguments): def __init__(self, arguments):
arguments = arguments.copy() arguments = arguments.copy()
self._imports() self._imports()
self.url = arguments.pop('url', None) self.url = arguments.pop("url", None)
self.host = arguments.pop('host', 'localhost') self.host = arguments.pop("host", "localhost")
self.password = arguments.pop('password', None) self.password = arguments.pop("password", None)
self.port = arguments.pop('port', 6379) self.port = arguments.pop("port", 6379)
self.db = arguments.pop('db', 0) self.db = arguments.pop("db", 0)
self.distributed_lock = arguments.get('distributed_lock', False) self.distributed_lock = arguments.pop("distributed_lock", False)
self.socket_timeout = arguments.pop('socket_timeout', None) self.socket_timeout = arguments.pop("socket_timeout", None)
self.lock_timeout = arguments.pop("lock_timeout", None)
self.lock_sleep = arguments.pop("lock_sleep", 0.1)
self.thread_local_lock = arguments.pop("thread_local_lock", True)
self.connection_kwargs = arguments.pop("connection_kwargs", {})
self.lock_timeout = arguments.get('lock_timeout', None) if self.distributed_lock and self.thread_local_lock:
self.lock_sleep = arguments.get('lock_sleep', 0.1) warnings.warn(
"The Redis backend thread_local_lock parameter should be "
"set to False when distributed_lock is True"
)
self.redis_expiration_time = arguments.pop('redis_expiration_time', 0) self.redis_expiration_time = arguments.pop("redis_expiration_time", 0)
self.connection_pool = arguments.get('connection_pool', None) self.connection_pool = arguments.pop("connection_pool", None)
self.client = self._create_client() self._create_client()
def _imports(self): def _imports(self):
# defer imports until backend is used # defer imports until backend is used
@ -118,66 +133,207 @@ class RedisBackend(CacheBackend):
# the connection pool already has all other connection # the connection pool already has all other connection
# options present within, so here we disregard socket_timeout # options present within, so here we disregard socket_timeout
# and others. # and others.
return redis.StrictRedis(connection_pool=self.connection_pool) self.writer_client = redis.StrictRedis(
connection_pool=self.connection_pool
args = {}
if self.socket_timeout:
args['socket_timeout'] = self.socket_timeout
if self.url is not None:
args.update(url=self.url)
return redis.StrictRedis.from_url(**args)
else:
args.update(
host=self.host, password=self.password,
port=self.port, db=self.db
) )
return redis.StrictRedis(**args) self.reader_client = self.writer_client
else:
args = {}
args.update(self.connection_kwargs)
if self.socket_timeout:
args["socket_timeout"] = self.socket_timeout
if self.url is not None:
args.update(url=self.url)
self.writer_client = redis.StrictRedis.from_url(**args)
self.reader_client = self.writer_client
else:
args.update(
host=self.host,
password=self.password,
port=self.port,
db=self.db,
)
self.writer_client = redis.StrictRedis(**args)
self.reader_client = self.writer_client
def get_mutex(self, key): def get_mutex(self, key):
if self.distributed_lock: if self.distributed_lock:
return self.client.lock(u('_lock{0}').format(key), return _RedisLockWrapper(
self.lock_timeout, self.lock_sleep) self.writer_client.lock(
"_lock{0}".format(key),
timeout=self.lock_timeout,
sleep=self.lock_sleep,
thread_local=self.thread_local_lock,
)
)
else: else:
return None return None
def get(self, key): def get_serialized(self, key):
value = self.client.get(key) value = self.reader_client.get(key)
if value is None: if value is None:
return NO_VALUE return NO_VALUE
return pickle.loads(value) return value
def get_multi(self, keys): def get_serialized_multi(self, keys):
if not keys: if not keys:
return [] return []
values = self.client.mget(keys) values = self.reader_client.mget(keys)
return [ return [v if v is not None else NO_VALUE for v in values]
pickle.loads(v) if v is not None else NO_VALUE
for v in values]
def set(self, key, value): def set_serialized(self, key, value):
if self.redis_expiration_time: if self.redis_expiration_time:
self.client.setex(key, self.redis_expiration_time, self.writer_client.setex(key, self.redis_expiration_time, value)
pickle.dumps(value, pickle.HIGHEST_PROTOCOL))
else: else:
self.client.set(key, pickle.dumps(value, pickle.HIGHEST_PROTOCOL)) self.writer_client.set(key, value)
def set_multi(self, mapping):
mapping = dict(
(k, pickle.dumps(v, pickle.HIGHEST_PROTOCOL))
for k, v in mapping.items()
)
def set_serialized_multi(self, mapping):
if not self.redis_expiration_time: if not self.redis_expiration_time:
self.client.mset(mapping) self.writer_client.mset(mapping)
else: else:
pipe = self.client.pipeline() pipe = self.writer_client.pipeline()
for key, value in mapping.items(): for key, value in mapping.items():
pipe.setex(key, self.redis_expiration_time, value) pipe.setex(key, self.redis_expiration_time, value)
pipe.execute() pipe.execute()
def delete(self, key): def delete(self, key):
self.client.delete(key) self.writer_client.delete(key)
def delete_multi(self, keys): def delete_multi(self, keys):
self.client.delete(*keys) self.writer_client.delete(*keys)
class _RedisLockWrapper:
__slots__ = ("mutex", "__weakref__")
def __init__(self, mutex: typing.Any):
self.mutex = mutex
def acquire(self, wait: bool = True) -> typing.Any:
return self.mutex.acquire(blocking=wait)
def release(self) -> typing.Any:
return self.mutex.release()
def locked(self) -> bool:
return self.mutex.locked() # type: ignore
class RedisSentinelBackend(RedisBackend):
"""A `Redis <http://redis.io/>`_ backend, using the
`redis-py <http://pypi.python.org/pypi/redis/>`_ backend.
It will use the Sentinel of a Redis cluster.
.. versionadded:: 1.0.0
Example configuration::
from dogpile.cache import make_region
region = make_region().configure(
'dogpile.cache.redis_sentinel',
arguments = {
'sentinels': [
['redis_sentinel_1', 26379],
['redis_sentinel_2', 26379]
],
'db': 0,
'redis_expiration_time': 60*60*2, # 2 hours
'distributed_lock': True,
'thread_local_lock': False
}
)
Arguments accepted in the arguments dictionary:
:param db: integer, default is ``0``.
:param redis_expiration_time: integer, number of seconds after setting
a value that Redis should expire it. This should be larger than dogpile's
cache expiration. By default no expiration is set.
:param distributed_lock: boolean, when True, will use a
redis-lock as the dogpile lock. Use this when multiple processes will be
talking to the same redis instance. When False, dogpile will
coordinate on a regular threading mutex, Default is True.
:param lock_timeout: integer, number of seconds after acquiring a lock that
Redis should expire it. This argument is only valid when
``distributed_lock`` is ``True``.
:param socket_timeout: float, seconds for socket timeout.
Default is None (no timeout).
:param sentinels: is a list of sentinel nodes. Each node is represented by
a pair (hostname, port).
Default is None (not in sentinel mode).
:param service_name: str, the service name.
Default is 'mymaster'.
:param sentinel_kwargs: is a dictionary of connection arguments used when
connecting to sentinel instances. Any argument that can be passed to
a normal Redis connection can be specified here.
Default is {}.
:param connection_kwargs: dict, additional keyword arguments are passed
along to the
``StrictRedis.from_url()`` method or ``StrictRedis()`` constructor
directly, including parameters like ``ssl``, ``ssl_certfile``,
``charset``, etc.
:param lock_sleep: integer, number of seconds to sleep when failed to
acquire a lock. This argument is only valid when
``distributed_lock`` is ``True``.
:param thread_local_lock: bool, whether a thread-local Redis lock object
should be used. This is the default, but is not compatible with
asynchronous runners, as they run in a different thread than the one
used to create the lock.
"""
def __init__(self, arguments):
arguments = arguments.copy()
self.sentinels = arguments.pop("sentinels", None)
self.service_name = arguments.pop("service_name", "mymaster")
self.sentinel_kwargs = arguments.pop("sentinel_kwargs", {})
super().__init__(
arguments={
"distributed_lock": True,
"thread_local_lock": False,
**arguments,
}
)
def _imports(self):
# defer imports until backend is used
global redis
import redis.sentinel # noqa
def _create_client(self):
sentinel_kwargs = {}
sentinel_kwargs.update(self.sentinel_kwargs)
sentinel_kwargs.setdefault("password", self.password)
connection_kwargs = {}
connection_kwargs.update(self.connection_kwargs)
connection_kwargs.setdefault("password", self.password)
if self.db is not None:
connection_kwargs.setdefault("db", self.db)
sentinel_kwargs.setdefault("db", self.db)
if self.socket_timeout is not None:
connection_kwargs.setdefault("socket_timeout", self.socket_timeout)
sentinel = redis.sentinel.Sentinel(
self.sentinels,
sentinel_kwargs=sentinel_kwargs,
**connection_kwargs,
)
self.writer_client = sentinel.master_for(self.service_name)
self.reader_client = sentinel.slave_for(self.service_name)

View file

@ -51,20 +51,22 @@ class MakoPlugin(CacheImpl):
def __init__(self, cache): def __init__(self, cache):
super(MakoPlugin, self).__init__(cache) super(MakoPlugin, self).__init__(cache)
try: try:
self.regions = self.cache.template.cache_args['regions'] self.regions = self.cache.template.cache_args["regions"]
except KeyError: except KeyError:
raise KeyError( raise KeyError(
"'cache_regions' argument is required on the " "'cache_regions' argument is required on the "
"Mako Lookup or Template object for usage " "Mako Lookup or Template object for usage "
"with the dogpile.cache plugin.") "with the dogpile.cache plugin."
)
def _get_region(self, **kw): def _get_region(self, **kw):
try: try:
region = kw['region'] region = kw["region"]
except KeyError: except KeyError:
raise KeyError( raise KeyError(
"'cache_region' argument must be specified with 'cache=True'" "'cache_region' argument must be specified with 'cache=True'"
"within templates for usage with the dogpile.cache plugin.") "within templates for usage with the dogpile.cache plugin."
)
try: try:
return self.regions[region] return self.regions[region]
except KeyError: except KeyError:
@ -73,8 +75,8 @@ class MakoPlugin(CacheImpl):
def get_and_replace(self, key, creation_function, **kw): def get_and_replace(self, key, creation_function, **kw):
expiration_time = kw.pop("timeout", None) expiration_time = kw.pop("timeout", None)
return self._get_region(**kw).get_or_create( return self._get_region(**kw).get_or_create(
key, creation_function, key, creation_function, expiration_time=expiration_time
expiration_time=expiration_time) )
def get_or_create(self, key, creation_function, **kw): def get_or_create(self, key, creation_function, **kw):
return self.get_and_replace(key, creation_function, **kw) return self.get_and_replace(key, creation_function, **kw)

View file

@ -10,7 +10,16 @@ base backend.
""" """
from typing import Mapping
from typing import Optional
from typing import Sequence
from .api import BackendFormatted
from .api import BackendSetType
from .api import CacheBackend from .api import CacheBackend
from .api import CacheMutex
from .api import KeyType
from .api import SerializedReturnType
class ProxyBackend(CacheBackend): class ProxyBackend(CacheBackend):
@ -55,17 +64,17 @@ class ProxyBackend(CacheBackend):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *arg, **kw):
self.proxied = None pass
def wrap(self, backend): def wrap(self, backend: CacheBackend) -> "ProxyBackend":
''' Take a backend as an argument and setup the self.proxied property. """Take a backend as an argument and setup the self.proxied property.
Return an object that be used as a backend by a :class:`.CacheRegion` Return an object that be used as a backend by a :class:`.CacheRegion`
object. object.
''' """
assert( assert isinstance(backend, CacheBackend) or isinstance(
isinstance(backend, CacheBackend) or backend, ProxyBackend
isinstance(backend, ProxyBackend)) )
self.proxied = backend self.proxied = backend
return self return self
@ -73,23 +82,37 @@ class ProxyBackend(CacheBackend):
# Delegate any functions that are not already overridden to # Delegate any functions that are not already overridden to
# the proxies backend # the proxies backend
# #
def get(self, key): def get(self, key: KeyType) -> BackendFormatted:
return self.proxied.get(key) return self.proxied.get(key)
def set(self, key, value): def set(self, key: KeyType, value: BackendSetType) -> None:
self.proxied.set(key, value) self.proxied.set(key, value)
def delete(self, key): def delete(self, key: KeyType) -> None:
self.proxied.delete(key) self.proxied.delete(key)
def get_multi(self, keys): def get_multi(self, keys: Sequence[KeyType]) -> Sequence[BackendFormatted]:
return self.proxied.get_multi(keys) return self.proxied.get_multi(keys)
def set_multi(self, mapping): def set_multi(self, mapping: Mapping[KeyType, BackendSetType]) -> None:
self.proxied.set_multi(mapping) self.proxied.set_multi(mapping)
def delete_multi(self, keys): def delete_multi(self, keys: Sequence[KeyType]) -> None:
self.proxied.delete_multi(keys) self.proxied.delete_multi(keys)
def get_mutex(self, key): def get_mutex(self, key: KeyType) -> Optional[CacheMutex]:
return self.proxied.get_mutex(key) return self.proxied.get_mutex(key)
def get_serialized(self, key: KeyType) -> SerializedReturnType:
return self.proxied.get_serialized(key)
def get_serialized_multi(
self, keys: Sequence[KeyType]
) -> Sequence[SerializedReturnType]:
return self.proxied.get_serialized_multi(keys)
def set_serialized(self, key: KeyType, value: bytes) -> None:
self.proxied.set_serialized(key, value)
def set_serialized_multi(self, mapping: Mapping[KeyType, bytes]) -> None:
self.proxied.set_serialized_multi(mapping)

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,10 @@
from hashlib import sha1 from hashlib import sha1
from ..util import compat from ..util import compat
from ..util import langhelpers from ..util import langhelpers
def function_key_generator(namespace, fn, to_str=compat.string_type): def function_key_generator(namespace, fn, to_str=str):
"""Return a function that generates a string """Return a function that generates a string
key, based on a given function as well as key, based on a given function as well as
arguments to the returned function itself. arguments to the returned function itself.
@ -23,47 +24,51 @@ def function_key_generator(namespace, fn, to_str=compat.string_type):
""" """
if namespace is None: if namespace is None:
namespace = '%s:%s' % (fn.__module__, fn.__name__) namespace = "%s:%s" % (fn.__module__, fn.__name__)
else: else:
namespace = '%s:%s|%s' % (fn.__module__, fn.__name__, namespace) namespace = "%s:%s|%s" % (fn.__module__, fn.__name__, namespace)
args = compat.inspect_getargspec(fn) args = compat.inspect_getargspec(fn)
has_self = args[0] and args[0][0] in ('self', 'cls') has_self = args[0] and args[0][0] in ("self", "cls")
def generate_key(*args, **kw): def generate_key(*args, **kw):
if kw: if kw:
raise ValueError( raise ValueError(
"dogpile.cache's default key creation " "dogpile.cache's default key creation "
"function does not accept keyword arguments.") "function does not accept keyword arguments."
)
if has_self: if has_self:
args = args[1:] args = args[1:]
return namespace + "|" + " ".join(map(to_str, args)) return namespace + "|" + " ".join(map(to_str, args))
return generate_key return generate_key
def function_multi_key_generator(namespace, fn, to_str=compat.string_type): def function_multi_key_generator(namespace, fn, to_str=str):
if namespace is None: if namespace is None:
namespace = '%s:%s' % (fn.__module__, fn.__name__) namespace = "%s:%s" % (fn.__module__, fn.__name__)
else: else:
namespace = '%s:%s|%s' % (fn.__module__, fn.__name__, namespace) namespace = "%s:%s|%s" % (fn.__module__, fn.__name__, namespace)
args = compat.inspect_getargspec(fn) args = compat.inspect_getargspec(fn)
has_self = args[0] and args[0][0] in ('self', 'cls') has_self = args[0] and args[0][0] in ("self", "cls")
def generate_keys(*args, **kw): def generate_keys(*args, **kw):
if kw: if kw:
raise ValueError( raise ValueError(
"dogpile.cache's default key creation " "dogpile.cache's default key creation "
"function does not accept keyword arguments.") "function does not accept keyword arguments."
)
if has_self: if has_self:
args = args[1:] args = args[1:]
return [namespace + "|" + key for key in map(to_str, args)] return [namespace + "|" + key for key in map(to_str, args)]
return generate_keys return generate_keys
def kwarg_function_key_generator(namespace, fn, to_str=compat.string_type): def kwarg_function_key_generator(namespace, fn, to_str=str):
"""Return a function that generates a string """Return a function that generates a string
key, based on a given function as well as key, based on a given function as well as
arguments to the returned function itself. arguments to the returned function itself.
@ -83,9 +88,9 @@ def kwarg_function_key_generator(namespace, fn, to_str=compat.string_type):
""" """
if namespace is None: if namespace is None:
namespace = '%s:%s' % (fn.__module__, fn.__name__) namespace = "%s:%s" % (fn.__module__, fn.__name__)
else: else:
namespace = '%s:%s|%s' % (fn.__module__, fn.__name__, namespace) namespace = "%s:%s|%s" % (fn.__module__, fn.__name__, namespace)
argspec = compat.inspect_getargspec(fn) argspec = compat.inspect_getargspec(fn)
default_list = list(argspec.defaults or []) default_list = list(argspec.defaults or [])
@ -94,32 +99,41 @@ def kwarg_function_key_generator(namespace, fn, to_str=compat.string_type):
# enumerate() # enumerate()
default_list.reverse() default_list.reverse()
# use idx*-1 to create the correct right-lookup index. # use idx*-1 to create the correct right-lookup index.
args_with_defaults = dict((argspec.args[(idx*-1)], default) args_with_defaults = dict(
for idx, default in enumerate(default_list, 1)) (argspec.args[(idx * -1)], default)
if argspec.args and argspec.args[0] in ('self', 'cls'): for idx, default in enumerate(default_list, 1)
)
if argspec.args and argspec.args[0] in ("self", "cls"):
arg_index_start = 1 arg_index_start = 1
else: else:
arg_index_start = 0 arg_index_start = 0
def generate_key(*args, **kwargs): def generate_key(*args, **kwargs):
as_kwargs = dict( as_kwargs = dict(
[(argspec.args[idx], arg) [
for idx, arg in enumerate(args[arg_index_start:], (argspec.args[idx], arg)
arg_index_start)]) for idx, arg in enumerate(
args[arg_index_start:], arg_index_start
)
]
)
as_kwargs.update(kwargs) as_kwargs.update(kwargs)
for arg, val in args_with_defaults.items(): for arg, val in args_with_defaults.items():
if arg not in as_kwargs: if arg not in as_kwargs:
as_kwargs[arg] = val as_kwargs[arg] = val
argument_values = [as_kwargs[key] argument_values = [as_kwargs[key] for key in sorted(as_kwargs.keys())]
for key in sorted(as_kwargs.keys())] return namespace + "|" + " ".join(map(to_str, argument_values))
return namespace + '|' + " ".join(map(to_str, argument_values))
return generate_key return generate_key
def sha1_mangle_key(key): def sha1_mangle_key(key):
"""a SHA1 key mangler.""" """a SHA1 key mangler."""
if isinstance(key, str):
key = key.encode("utf-8")
return sha1(key).hexdigest() return sha1(key).hexdigest()
@ -128,13 +142,16 @@ def length_conditional_mangler(length, mangler):
past a certain threshold. past a certain threshold.
""" """
def mangle(key): def mangle(key):
if len(key) >= length: if len(key) >= length:
return mangler(key) return mangler(key)
else: else:
return key return key
return mangle return mangle
# in the 0.6 release these functions were moved to the dogpile.util namespace. # in the 0.6 release these functions were moved to the dogpile.util namespace.
# They are linked here to maintain compatibility with older versions. # They are linked here to maintain compatibility with older versions.
@ -143,3 +160,30 @@ KeyReentrantMutex = langhelpers.KeyReentrantMutex
memoized_property = langhelpers.memoized_property memoized_property = langhelpers.memoized_property
PluginLoader = langhelpers.PluginLoader PluginLoader = langhelpers.PluginLoader
to_list = langhelpers.to_list to_list = langhelpers.to_list
class repr_obj:
__slots__ = ("value", "max_chars")
def __init__(self, value, max_chars=300):
self.value = value
self.max_chars = max_chars
def __eq__(self, other):
return other.value == self.value
def __repr__(self):
rep = repr(self.value)
lenrep = len(rep)
if lenrep > self.max_chars:
segment_length = self.max_chars // 2
rep = (
rep[0:segment_length]
+ (
" ... (%d characters truncated) ... "
% (lenrep - self.max_chars)
)
+ rep[-segment_length:]
)
return rep

View file

@ -8,10 +8,10 @@ dogpile.core installation is present.
""" """
from .util import nameregistry # noqa from . import __version__ # noqa
from .util import readwrite_lock # noqa
from .util.readwrite_lock import ReadWriteMutex # noqa
from .util.nameregistry import NameRegistry # noqa
from .lock import Lock # noqa from .lock import Lock # noqa
from .lock import NeedRegenerationException # noqa from .lock import NeedRegenerationException # noqa
from . import __version__ # noqa from .util import nameregistry # noqa
from .util import readwrite_lock # noqa
from .util.nameregistry import NameRegistry # noqa
from .util.readwrite_lock import ReadWriteMutex # noqa

View file

@ -1,5 +1,5 @@
import time
import logging import logging
import time
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -11,10 +11,11 @@ class NeedRegenerationException(Exception):
""" """
NOT_REGENERATED = object() NOT_REGENERATED = object()
class Lock(object): class Lock:
"""Dogpile lock class. """Dogpile lock class.
Provides an interface around an arbitrary mutex Provides an interface around an arbitrary mutex
@ -70,8 +71,8 @@ class Lock(object):
value is available.""" value is available."""
return not self._has_value(createdtime) or ( return not self._has_value(createdtime) or (
self.expiretime is not None and self.expiretime is not None
time.time() - createdtime > self.expiretime and time.time() - createdtime > self.expiretime
) )
def _has_value(self, createdtime): def _has_value(self, createdtime):
@ -109,7 +110,8 @@ class Lock(object):
raise Exception( raise Exception(
"Generation function should " "Generation function should "
"have just been called by a concurrent " "have just been called by a concurrent "
"thread.") "thread."
)
else: else:
return value return value
@ -122,9 +124,7 @@ class Lock(object):
if self._has_value(createdtime): if self._has_value(createdtime):
has_value = True has_value = True
if not self.mutex.acquire(False): if not self.mutex.acquire(False):
log.debug( log.debug("creation function in progress elsewhere, returning")
"creation function in progress "
"elsewhere, returning")
return NOT_REGENERATED return NOT_REGENERATED
else: else:
has_value = False has_value = False
@ -173,8 +173,7 @@ class Lock(object):
# there's no value at all, and we have to create it synchronously # there's no value at all, and we have to create it synchronously
log.debug( log.debug(
"Calling creation function for %s value", "Calling creation function for %s value",
"not-yet-present" if not has_value else "not-yet-present" if not has_value else "previously expired",
"previously expired"
) )
return self.creator() return self.creator()
finally: finally:
@ -185,5 +184,5 @@ class Lock(object):
def __enter__(self): def __enter__(self):
return self._enter() return self._enter()
def __exit__(self, type, value, traceback): def __exit__(self, type_, value, traceback):
pass pass

View file

@ -1,4 +1,7 @@
from .langhelpers import coerce_string_conf # noqa
from .langhelpers import KeyReentrantMutex # noqa
from .langhelpers import memoized_property # noqa
from .langhelpers import PluginLoader # noqa
from .langhelpers import to_list # noqa
from .nameregistry import NameRegistry # noqa from .nameregistry import NameRegistry # noqa
from .readwrite_lock import ReadWriteMutex # noqa from .readwrite_lock import ReadWriteMutex # noqa
from .langhelpers import PluginLoader, memoized_property, \
coerce_string_conf, to_list, KeyReentrantMutex # noqa

View file

@ -1,87 +1,72 @@
import sys import collections
import inspect
py2k = sys.version_info < (3, 0)
py3k = sys.version_info >= (3, 0)
py32 = sys.version_info >= (3, 2)
py27 = sys.version_info >= (2, 7)
jython = sys.platform.startswith('java')
win32 = sys.platform.startswith('win')
try:
import threading
except ImportError:
import dummy_threading as threading # noqa
if py3k: # pragma: no cover FullArgSpec = collections.namedtuple(
string_types = str, "FullArgSpec",
text_type = str [
string_type = str "args",
"varargs",
"varkw",
"defaults",
"kwonlyargs",
"kwonlydefaults",
"annotations",
],
)
if py32: ArgSpec = collections.namedtuple(
callable = callable "ArgSpec", ["args", "varargs", "keywords", "defaults"]
else: )
def callable(fn):
return hasattr(fn, '__call__')
def u(s):
return s
def ue(s):
return s
import configparser
import io
import _thread as thread
else:
string_types = basestring,
text_type = unicode
string_type = str
def u(s):
return unicode(s, "utf-8")
def ue(s):
return unicode(s, "unicode_escape")
import ConfigParser as configparser # noqa
import StringIO as io # noqa
callable = callable # noqa
import thread # noqa
if py3k: def inspect_getfullargspec(func):
import collections """Fully vendored version of getfullargspec from Python 3.3.
ArgSpec = collections.namedtuple(
"ArgSpec",
["args", "varargs", "keywords", "defaults"])
from inspect import getfullargspec as inspect_getfullargspec This version is more performant than the one which appeared in
later Python 3 versions.
def inspect_getargspec(func): """
return ArgSpec(
*inspect_getfullargspec(func)[0:4]
)
else:
from inspect import getargspec as inspect_getargspec # noqa
if py3k or jython: # if a Signature is already present, as is the case with newer
import pickle # "decorator" package, defer back to built in
else: if hasattr(func, "__signature__"):
import cPickle as pickle # noqa return inspect.getfullargspec(func)
if py3k: if inspect.ismethod(func):
def read_config_file(config, fileobj): func = func.__func__
return config.read_file(fileobj) if not inspect.isfunction(func):
else: raise TypeError("{!r} is not a Python function".format(func))
def read_config_file(config, fileobj):
return config.readfp(fileobj) co = func.__code__
if not inspect.iscode(co):
raise TypeError("{!r} is not a code object".format(co))
nargs = co.co_argcount
names = co.co_varnames
nkwargs = co.co_kwonlyargcount
args = list(names[:nargs])
kwonlyargs = list(names[nargs : nargs + nkwargs])
nargs += nkwargs
varargs = None
if co.co_flags & inspect.CO_VARARGS:
varargs = co.co_varnames[nargs]
nargs = nargs + 1
varkw = None
if co.co_flags & inspect.CO_VARKEYWORDS:
varkw = co.co_varnames[nargs]
return FullArgSpec(
args,
varargs,
varkw,
func.__defaults__,
kwonlyargs,
func.__kwdefaults__,
func.__annotations__,
)
def timedelta_total_seconds(td): def inspect_getargspec(func):
if py27: return ArgSpec(*inspect_getfullargspec(func)[0:4])
return td.total_seconds()
else:
return (td.microseconds + (
td.seconds + td.days * 24 * 3600) * 1e6) / 1e6

View file

@ -1,44 +1,54 @@
import re import abc
import collections import collections
from . import compat import re
import threading
from typing import MutableMapping
from typing import MutableSet
import stevedore
def coerce_string_conf(d): def coerce_string_conf(d):
result = {} result = {}
for k, v in d.items(): for k, v in d.items():
if not isinstance(v, compat.string_types): if not isinstance(v, str):
result[k] = v result[k] = v
continue continue
v = v.strip() v = v.strip()
if re.match(r'^[-+]?\d+$', v): if re.match(r"^[-+]?\d+$", v):
result[k] = int(v) result[k] = int(v)
elif re.match(r'^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?$', v): elif re.match(r"^[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?$", v):
result[k] = float(v) result[k] = float(v)
elif v.lower() in ('false', 'true'): elif v.lower() in ("false", "true"):
result[k] = v.lower() == 'true' result[k] = v.lower() == "true"
elif v == 'None': elif v == "None":
result[k] = None result[k] = None
else: else:
result[k] = v result[k] = v
return result return result
class PluginLoader(object): class PluginLoader:
def __init__(self, group): def __init__(self, group):
self.group = group self.group = group
self.impls = {} self.impls = {} # loaded plugins
self._mgr = None # lazily defined stevedore manager
self._unloaded = {} # plugins registered but not loaded
def load(self, name): def load(self, name):
if name in self._unloaded:
self.impls[name] = self._unloaded[name]()
return self.impls[name]
if name in self.impls: if name in self.impls:
return self.impls[name]() return self.impls[name]
else: # pragma NO COVERAGE else: # pragma NO COVERAGE
import pkg_resources if self._mgr is None:
for impl in pkg_resources.iter_entry_points( self._mgr = stevedore.ExtensionManager(self.group)
self.group, name): try:
self.impls[name] = impl.load self.impls[name] = self._mgr[name].plugin
return impl.load() return self.impls[name]
else: except KeyError:
raise self.NotFound( raise self.NotFound(
"Can't load plugin %s %s" % (self.group, name) "Can't load plugin %s %s" % (self.group, name)
) )
@ -47,14 +57,16 @@ class PluginLoader(object):
def load(): def load():
mod = __import__(modulepath, fromlist=[objname]) mod = __import__(modulepath, fromlist=[objname])
return getattr(mod, objname) return getattr(mod, objname)
self.impls[name] = load
self._unloaded[name] = load
class NotFound(Exception): class NotFound(Exception):
"""The specified plugin could not be found.""" """The specified plugin could not be found."""
class memoized_property(object): class memoized_property:
"""A read-only @property that is only evaluated once.""" """A read-only @property that is only evaluated once."""
def __init__(self, fget, doc=None): def __init__(self, fget, doc=None):
self.fget = fget self.fget = fget
self.__doc__ = doc or fget.__doc__ self.__doc__ = doc or fget.__doc__
@ -77,9 +89,23 @@ def to_list(x, default=None):
return x return x
class KeyReentrantMutex(object): class Mutex(abc.ABC):
@abc.abstractmethod
def acquire(self, wait: bool = True) -> bool:
raise NotImplementedError()
def __init__(self, key, mutex, keys): @abc.abstractmethod
def release(self) -> None:
raise NotImplementedError()
class KeyReentrantMutex:
def __init__(
self,
key: str,
mutex: Mutex,
keys: MutableMapping[int, MutableSet[str]],
):
self.key = key self.key = key
self.mutex = mutex self.mutex = mutex
self.keys = keys self.keys = keys
@ -89,17 +115,19 @@ class KeyReentrantMutex(object):
# this collection holds zero or one # this collection holds zero or one
# thread idents as the key; a set of # thread idents as the key; a set of
# keynames held as the value. # keynames held as the value.
keystore = collections.defaultdict(set) keystore: MutableMapping[
int, MutableSet[str]
] = collections.defaultdict(set)
def fac(key): def fac(key):
return KeyReentrantMutex(key, mutex, keystore) return KeyReentrantMutex(key, mutex, keystore)
return fac return fac
def acquire(self, wait=True): def acquire(self, wait=True):
current_thread = compat.threading.current_thread().ident current_thread = threading.get_ident()
keys = self.keys.get(current_thread) keys = self.keys.get(current_thread)
if keys is not None and \ if keys is not None and self.key not in keys:
self.key not in keys:
# current lockholder, new key. add it in # current lockholder, new key. add it in
keys.add(self.key) keys.add(self.key)
return True return True
@ -111,7 +139,7 @@ class KeyReentrantMutex(object):
return False return False
def release(self): def release(self):
current_thread = compat.threading.current_thread().ident current_thread = threading.get_ident()
keys = self.keys.get(current_thread) keys = self.keys.get(current_thread)
assert keys is not None, "this thread didn't do the acquire" assert keys is not None, "this thread didn't do the acquire"
assert self.key in keys, "No acquire held for key '%s'" % self.key assert self.key in keys, "No acquire held for key '%s'" % self.key
@ -121,3 +149,10 @@ class KeyReentrantMutex(object):
# the thread ident and unlock. # the thread ident and unlock.
del self.keys[current_thread] del self.keys[current_thread]
self.mutex.release() self.mutex.release()
def locked(self):
current_thread = threading.get_ident()
keys = self.keys.get(current_thread)
if keys is None:
return False
return self.key in keys

View file

@ -1,4 +1,7 @@
from .compat import threading import threading
from typing import Any
from typing import Callable
from typing import MutableMapping
import weakref import weakref
@ -37,19 +40,16 @@ class NameRegistry(object):
method. method.
""" """
_locks = weakref.WeakValueDictionary()
_mutex = threading.RLock() _mutex = threading.RLock()
def __init__(self, creator): def __init__(self, creator: Callable[..., Any]):
"""Create a new :class:`.NameRegistry`. """Create a new :class:`.NameRegistry`."""
self._values: MutableMapping[str, Any] = weakref.WeakValueDictionary()
"""
self._values = weakref.WeakValueDictionary()
self._mutex = threading.RLock() self._mutex = threading.RLock()
self.creator = creator self.creator = creator
def get(self, identifier, *args, **kw): def get(self, identifier: str, *args: Any, **kw: Any) -> Any:
r"""Get and possibly create the value. r"""Get and possibly create the value.
:param identifier: Hash key for the value. :param identifier: Hash key for the value.
@ -68,7 +68,7 @@ class NameRegistry(object):
except KeyError: except KeyError:
return self._sync_get(identifier, *args, **kw) return self._sync_get(identifier, *args, **kw)
def _sync_get(self, identifier, *args, **kw): def _sync_get(self, identifier: str, *args: Any, **kw: Any) -> Any:
self._mutex.acquire() self._mutex.acquire()
try: try:
try: try:
@ -76,11 +76,13 @@ class NameRegistry(object):
return self._values[identifier] return self._values[identifier]
else: else:
self._values[identifier] = value = self.creator( self._values[identifier] = value = self.creator(
identifier, *args, **kw) identifier, *args, **kw
)
return value return value
except KeyError: except KeyError:
self._values[identifier] = value = self.creator( self._values[identifier] = value = self.creator(
identifier, *args, **kw) identifier, *args, **kw
)
return value return value
finally: finally:
self._mutex.release() self._mutex.release()

View file

@ -1,6 +1,6 @@
from .compat import threading
import logging import logging
import threading
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -62,13 +62,15 @@ class ReadWriteMutex(object):
# check if we are the last asynchronous reader thread # check if we are the last asynchronous reader thread
# out the door. # out the door.
if self.async_ == 0: if self.async_ == 0:
# yes. so if a sync operation is waiting, notifyAll to wake # yes. so if a sync operation is waiting, notify_all to wake
# it up # it up
if self.current_sync_operation is not None: if self.current_sync_operation is not None:
self.condition.notifyAll() self.condition.notify_all()
elif self.async_ < 0: elif self.async_ < 0:
raise LockError("Synchronizer error - too many " raise LockError(
"release_read_locks called") "Synchronizer error - too many "
"release_read_locks called"
)
log.debug("%s released read lock", self) log.debug("%s released read lock", self)
finally: finally:
self.condition.release() self.condition.release()
@ -93,7 +95,7 @@ class ReadWriteMutex(object):
# establish ourselves as the current sync # establish ourselves as the current sync
# this indicates to other read/write operations # this indicates to other read/write operations
# that they should wait until this is None again # that they should wait until this is None again
self.current_sync_operation = threading.currentThread() self.current_sync_operation = threading.current_thread()
# now wait again for asyncs to finish # now wait again for asyncs to finish
if self.async_ > 0: if self.async_ > 0:
@ -115,16 +117,18 @@ class ReadWriteMutex(object):
"""Release the 'write' lock.""" """Release the 'write' lock."""
self.condition.acquire() self.condition.acquire()
try: try:
if self.current_sync_operation is not threading.currentThread(): if self.current_sync_operation is not threading.current_thread():
raise LockError("Synchronizer error - current thread doesn't " raise LockError(
"have the write lock") "Synchronizer error - current thread doesn't "
"have the write lock"
)
# reset the current sync operation so # reset the current sync operation so
# another can get it # another can get it
self.current_sync_operation = None self.current_sync_operation = None
# tell everyone to get ready # tell everyone to get ready
self.condition.notifyAll() self.condition.notify_all()
log.debug("%s released write lock", self) log.debug("%s released write lock", self)
finally: finally:

View file

@ -0,0 +1,631 @@
import io
import os
import re
import abc
import csv
import sys
import zipp
import email
import pathlib
import operator
import functools
import itertools
import posixpath
import collections
from ._compat import (
NullFinder,
PyPy_repr,
install,
)
from configparser import ConfigParser
from contextlib import suppress
from importlib import import_module
from importlib.abc import MetaPathFinder
from itertools import starmap
__all__ = [
'Distribution',
'DistributionFinder',
'PackageNotFoundError',
'distribution',
'distributions',
'entry_points',
'files',
'metadata',
'requires',
'version',
]
class PackageNotFoundError(ModuleNotFoundError):
"""The package was not found."""
def __str__(self):
tmpl = "No package metadata was found for {self.name}"
return tmpl.format(**locals())
@property
def name(self):
(name,) = self.args
return name
class EntryPoint(
PyPy_repr, collections.namedtuple('EntryPointBase', 'name value group')
):
"""An entry point as defined by Python packaging conventions.
See `the packaging docs on entry points
<https://packaging.python.org/specifications/entry-points/>`_
for more information.
"""
pattern = re.compile(
r'(?P<module>[\w.]+)\s*'
r'(:\s*(?P<attr>[\w.]+))?\s*'
r'(?P<extras>\[.*\])?\s*$'
)
"""
A regular expression describing the syntax for an entry point,
which might look like:
- module
- package.module
- package.module:attribute
- package.module:object.attribute
- package.module:attr [extra1, extra2]
Other combinations are possible as well.
The expression is lenient about whitespace around the ':',
following the attr, and following any extras.
"""
def load(self):
"""Load the entry point from its definition. If only a module
is indicated by the value, return that module. Otherwise,
return the named object.
"""
match = self.pattern.match(self.value)
module = import_module(match.group('module'))
attrs = filter(None, (match.group('attr') or '').split('.'))
return functools.reduce(getattr, attrs, module)
@property
def module(self):
match = self.pattern.match(self.value)
return match.group('module')
@property
def attr(self):
match = self.pattern.match(self.value)
return match.group('attr')
@property
def extras(self):
match = self.pattern.match(self.value)
return list(re.finditer(r'\w+', match.group('extras') or ''))
@classmethod
def _from_config(cls, config):
return [
cls(name, value, group)
for group in config.sections()
for name, value in config.items(group)
]
@classmethod
def _from_text(cls, text):
config = ConfigParser(delimiters='=')
# case sensitive: https://stackoverflow.com/q/1611799/812183
config.optionxform = str
try:
config.read_string(text)
except AttributeError: # pragma: nocover
# Python 2 has no read_string
config.readfp(io.StringIO(text))
return EntryPoint._from_config(config)
def __iter__(self):
"""
Supply iter so one may construct dicts of EntryPoints easily.
"""
return iter((self.name, self))
def __reduce__(self):
return (
self.__class__,
(self.name, self.value, self.group),
)
class PackagePath(pathlib.PurePosixPath):
"""A reference to a path in a package"""
def read_text(self, encoding='utf-8'):
with self.locate().open(encoding=encoding) as stream:
return stream.read()
def read_binary(self):
with self.locate().open('rb') as stream:
return stream.read()
def locate(self):
"""Return a path-like object for this path"""
return self.dist.locate_file(self)
class FileHash:
def __init__(self, spec):
self.mode, _, self.value = spec.partition('=')
def __repr__(self):
return '<FileHash mode: {} value: {}>'.format(self.mode, self.value)
class Distribution:
"""A Python distribution package."""
@abc.abstractmethod
def read_text(self, filename):
"""Attempt to load metadata file given by the name.
:param filename: The name of the file in the distribution info.
:return: The text if found, otherwise None.
"""
@abc.abstractmethod
def locate_file(self, path):
"""
Given a path to a file in this distribution, return a path
to it.
"""
@classmethod
def from_name(cls, name):
"""Return the Distribution for the given package name.
:param name: The name of the distribution package to search for.
:return: The Distribution instance (or subclass thereof) for the named
package, if found.
:raises PackageNotFoundError: When the named package's distribution
metadata cannot be found.
"""
for resolver in cls._discover_resolvers():
dists = resolver(DistributionFinder.Context(name=name))
dist = next(iter(dists), None)
if dist is not None:
return dist
else:
raise PackageNotFoundError(name)
@classmethod
def discover(cls, **kwargs):
"""Return an iterable of Distribution objects for all packages.
Pass a ``context`` or pass keyword arguments for constructing
a context.
:context: A ``DistributionFinder.Context`` object.
:return: Iterable of Distribution objects for all packages.
"""
context = kwargs.pop('context', None)
if context and kwargs:
raise ValueError("cannot accept context and kwargs")
context = context or DistributionFinder.Context(**kwargs)
return itertools.chain.from_iterable(
resolver(context) for resolver in cls._discover_resolvers()
)
@staticmethod
def at(path):
"""Return a Distribution for the indicated metadata path
:param path: a string or path-like object
:return: a concrete Distribution instance for the path
"""
return PathDistribution(pathlib.Path(path))
@staticmethod
def _discover_resolvers():
"""Search the meta_path for resolvers."""
declared = (
getattr(finder, 'find_distributions', None) for finder in sys.meta_path
)
return filter(None, declared)
@classmethod
def _local(cls, root='.'):
from pep517 import build, meta
system = build.compat_system(root)
builder = functools.partial(
meta.build,
source_dir=root,
system=system,
)
return PathDistribution(zipp.Path(meta.build_as_zip(builder)))
@property
def metadata(self):
"""Return the parsed metadata for this Distribution.
The returned object will have keys that name the various bits of
metadata. See PEP 566 for details.
"""
text = (
self.read_text('METADATA')
or self.read_text('PKG-INFO')
# This last clause is here to support old egg-info files. Its
# effect is to just end up using the PathDistribution's self._path
# (which points to the egg-info file) attribute unchanged.
or self.read_text('')
)
return email.message_from_string(text)
@property
def version(self):
"""Return the 'Version' metadata for the distribution package."""
return self.metadata['Version']
@property
def entry_points(self):
return EntryPoint._from_text(self.read_text('entry_points.txt'))
@property
def files(self):
"""Files in this distribution.
:return: List of PackagePath for this distribution or None
Result is `None` if the metadata file that enumerates files
(i.e. RECORD for dist-info or SOURCES.txt for egg-info) is
missing.
Result may be empty if the metadata exists but is empty.
"""
file_lines = self._read_files_distinfo() or self._read_files_egginfo()
def make_file(name, hash=None, size_str=None):
result = PackagePath(name)
result.hash = FileHash(hash) if hash else None
result.size = int(size_str) if size_str else None
result.dist = self
return result
return file_lines and list(starmap(make_file, csv.reader(file_lines)))
def _read_files_distinfo(self):
"""
Read the lines of RECORD
"""
text = self.read_text('RECORD')
return text and text.splitlines()
def _read_files_egginfo(self):
"""
SOURCES.txt might contain literal commas, so wrap each line
in quotes.
"""
text = self.read_text('SOURCES.txt')
return text and map('"{}"'.format, text.splitlines())
@property
def requires(self):
"""Generated requirements specified for this Distribution"""
reqs = self._read_dist_info_reqs() or self._read_egg_info_reqs()
return reqs and list(reqs)
def _read_dist_info_reqs(self):
return self.metadata.get_all('Requires-Dist')
def _read_egg_info_reqs(self):
source = self.read_text('requires.txt')
return source and self._deps_from_requires_text(source)
@classmethod
def _deps_from_requires_text(cls, source):
section_pairs = cls._read_sections(source.splitlines())
sections = {
section: list(map(operator.itemgetter('line'), results))
for section, results in itertools.groupby(
section_pairs, operator.itemgetter('section')
)
}
return cls._convert_egg_info_reqs_to_simple_reqs(sections)
@staticmethod
def _read_sections(lines):
section = None
for line in filter(None, lines):
section_match = re.match(r'\[(.*)\]$', line)
if section_match:
section = section_match.group(1)
continue
yield locals()
@staticmethod
def _convert_egg_info_reqs_to_simple_reqs(sections):
"""
Historically, setuptools would solicit and store 'extra'
requirements, including those with environment markers,
in separate sections. More modern tools expect each
dependency to be defined separately, with any relevant
extras and environment markers attached directly to that
requirement. This method converts the former to the
latter. See _test_deps_from_requires_text for an example.
"""
def make_condition(name):
return name and 'extra == "{name}"'.format(name=name)
def parse_condition(section):
section = section or ''
extra, sep, markers = section.partition(':')
if extra and markers:
markers = '({markers})'.format(markers=markers)
conditions = list(filter(None, [markers, make_condition(extra)]))
return '; ' + ' and '.join(conditions) if conditions else ''
for section, deps in sections.items():
for dep in deps:
yield dep + parse_condition(section)
class DistributionFinder(MetaPathFinder):
"""
A MetaPathFinder capable of discovering installed distributions.
"""
class Context:
"""
Keyword arguments presented by the caller to
``distributions()`` or ``Distribution.discover()``
to narrow the scope of a search for distributions
in all DistributionFinders.
Each DistributionFinder may expect any parameters
and should attempt to honor the canonical
parameters defined below when appropriate.
"""
name = None
"""
Specific name for which a distribution finder should match.
A name of ``None`` matches all distributions.
"""
def __init__(self, **kwargs):
vars(self).update(kwargs)
@property
def path(self):
"""
The path that a distribution finder should search.
Typically refers to Python package paths and defaults
to ``sys.path``.
"""
return vars(self).get('path', sys.path)
@abc.abstractmethod
def find_distributions(self, context=Context()):
"""
Find distributions.
Return an iterable of all Distribution instances capable of
loading the metadata for packages matching the ``context``,
a DistributionFinder.Context instance.
"""
class FastPath:
"""
Micro-optimized class for searching a path for
children.
"""
def __init__(self, root):
self.root = str(root)
self.base = os.path.basename(self.root).lower()
def joinpath(self, child):
return pathlib.Path(self.root, child)
def children(self):
with suppress(Exception):
return os.listdir(self.root or '')
with suppress(Exception):
return self.zip_children()
return []
def zip_children(self):
zip_path = zipp.Path(self.root)
names = zip_path.root.namelist()
self.joinpath = zip_path.joinpath
return dict.fromkeys(child.split(posixpath.sep, 1)[0] for child in names)
def search(self, name):
return (
self.joinpath(child)
for child in self.children()
if name.matches(child, self.base)
)
class Prepared:
"""
A prepared search for metadata on a possibly-named package.
"""
normalized = None
suffixes = '.dist-info', '.egg-info'
exact_matches = [''][:0]
def __init__(self, name):
self.name = name
if name is None:
return
self.normalized = self.normalize(name)
self.exact_matches = [self.normalized + suffix for suffix in self.suffixes]
@staticmethod
def normalize(name):
"""
PEP 503 normalization plus dashes as underscores.
"""
return re.sub(r"[-_.]+", "-", name).lower().replace('-', '_')
@staticmethod
def legacy_normalize(name):
"""
Normalize the package name as found in the convention in
older packaging tools versions and specs.
"""
return name.lower().replace('-', '_')
def matches(self, cand, base):
low = cand.lower()
pre, ext = os.path.splitext(low)
name, sep, rest = pre.partition('-')
return (
low in self.exact_matches
or ext in self.suffixes
and (not self.normalized or name.replace('.', '_') == self.normalized)
# legacy case:
or self.is_egg(base)
and low == 'egg-info'
)
def is_egg(self, base):
normalized = self.legacy_normalize(self.name or '')
prefix = normalized + '-' if normalized else ''
versionless_egg_name = normalized + '.egg' if self.name else ''
return (
base == versionless_egg_name
or base.startswith(prefix)
and base.endswith('.egg')
)
@install
class MetadataPathFinder(NullFinder, DistributionFinder):
"""A degenerate finder for distribution packages on the file system.
This finder supplies only a find_distributions() method for versions
of Python that do not have a PathFinder find_distributions().
"""
def find_distributions(self, context=DistributionFinder.Context()):
"""
Find distributions.
Return an iterable of all Distribution instances capable of
loading the metadata for packages matching ``context.name``
(or all names if ``None`` indicated) along the paths in the list
of directories ``context.path``.
"""
found = self._search_paths(context.name, context.path)
return map(PathDistribution, found)
@classmethod
def _search_paths(cls, name, paths):
"""Find metadata directories in paths heuristically."""
return itertools.chain.from_iterable(
path.search(Prepared(name)) for path in map(FastPath, paths)
)
class PathDistribution(Distribution):
def __init__(self, path):
"""Construct a distribution from a path to the metadata directory.
:param path: A pathlib.Path or similar object supporting
.joinpath(), __div__, .parent, and .read_text().
"""
self._path = path
def read_text(self, filename):
with suppress(
FileNotFoundError,
IsADirectoryError,
KeyError,
NotADirectoryError,
PermissionError,
):
return self._path.joinpath(filename).read_text(encoding='utf-8')
read_text.__doc__ = Distribution.read_text.__doc__
def locate_file(self, path):
return self._path.parent / path
def distribution(distribution_name):
"""Get the ``Distribution`` instance for the named package.
:param distribution_name: The name of the distribution package as a string.
:return: A ``Distribution`` instance (or subclass thereof).
"""
return Distribution.from_name(distribution_name)
def distributions(**kwargs):
"""Get all ``Distribution`` instances in the current environment.
:return: An iterable of ``Distribution`` instances.
"""
return Distribution.discover(**kwargs)
def metadata(distribution_name):
"""Get the metadata for the named package.
:param distribution_name: The name of the distribution package to query.
:return: An email.Message containing the parsed metadata.
"""
return Distribution.from_name(distribution_name).metadata
def version(distribution_name):
"""Get the version string for the named package.
:param distribution_name: The name of the distribution package to query.
:return: The version string for the package as defined in the package's
"Version" metadata key.
"""
return distribution(distribution_name).version
def entry_points():
"""Return EntryPoint objects for all installed packages.
:return: EntryPoint objects for all installed packages.
"""
eps = itertools.chain.from_iterable(dist.entry_points for dist in distributions())
by_group = operator.attrgetter('group')
ordered = sorted(eps, key=by_group)
grouped = itertools.groupby(ordered, by_group)
return {group: tuple(eps) for group, eps in grouped}
def files(distribution_name):
"""Return a list of files for the named package.
:param distribution_name: The name of the distribution package to query.
:return: List of files composing the distribution.
"""
return distribution(distribution_name).files
def requires(distribution_name):
"""
Return a list of requirements for the named package.
:return: An iterator of requirements, suitable for
packaging.requirement.Requirement.
"""
return distribution(distribution_name).requires

View file

@ -0,0 +1,75 @@
import sys
__all__ = ['install', 'NullFinder', 'PyPy_repr']
def install(cls):
"""
Class decorator for installation on sys.meta_path.
Adds the backport DistributionFinder to sys.meta_path and
attempts to disable the finder functionality of the stdlib
DistributionFinder.
"""
sys.meta_path.append(cls())
disable_stdlib_finder()
return cls
def disable_stdlib_finder():
"""
Give the backport primacy for discovering path-based distributions
by monkey-patching the stdlib O_O.
See #91 for more background for rationale on this sketchy
behavior.
"""
def matches(finder):
return getattr(
finder, '__module__', None
) == '_frozen_importlib_external' and hasattr(finder, 'find_distributions')
for finder in filter(matches, sys.meta_path): # pragma: nocover
del finder.find_distributions
class NullFinder:
"""
A "Finder" (aka "MetaClassFinder") that never finds any modules,
but may find distributions.
"""
@staticmethod
def find_spec(*args, **kwargs):
return None
# In Python 2, the import system requires finders
# to have a find_module() method, but this usage
# is deprecated in Python 3 in favor of find_spec().
# For the purposes of this finder (i.e. being present
# on sys.meta_path but having no other import
# system functionality), the two methods are identical.
find_module = find_spec
class PyPy_repr:
"""
Override repr for EntryPoint objects on PyPy to avoid __iter__ access.
Ref #97, #102.
"""
affected = hasattr(sys, 'pypy_version_info')
def __compat_repr__(self): # pragma: nocover
def make_param(name):
value = getattr(self, name)
return '{name}={value!r}'.format(**locals())
params = ', '.join(map(make_param, self._fields))
return 'EntryPoint({params})'.format(**locals())
if affected: # pragma: nocover
__repr__ = __compat_repr__
del affected

61
libs/common/pbr/build.py Normal file
View file

@ -0,0 +1,61 @@
# Copyright 2021 Monty Taylor
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""pep-517 support
Add::
[build-system]
requires = ["pbr>=5.7.0", "setuptools>=36.6.0", "wheel"]
build-backend = "pbr.build"
to pyproject.toml to use this
"""
from setuptools import build_meta
__all__ = [
'get_requires_for_build_sdist',
'get_requires_for_build_wheel',
'prepare_metadata_for_build_wheel',
'build_wheel',
'build_sdist',
]
def get_requires_for_build_wheel(config_settings=None):
return build_meta.get_requires_for_build_wheel(config_settings)
def get_requires_for_build_sdist(config_settings=None):
return build_meta.get_requires_for_build_sdist(config_settings)
def prepare_metadata_for_build_wheel(metadata_directory, config_settings=None):
return build_meta.prepare_metadata_for_build_wheel(
metadata_directory, config_settings)
def build_wheel(
wheel_directory,
config_settings=None,
metadata_directory=None,
):
return build_meta.build_wheel(
wheel_directory, config_settings, metadata_directory,
)
def build_sdist(sdist_directory, config_settings=None):
return build_meta.build_sdist(sdist_directory, config_settings)

View file

@ -132,11 +132,11 @@ class LocalBuildDoc(setup_command.BuildDoc):
autoindex.write(" %s.rst\n" % module) autoindex.write(" %s.rst\n" % module)
def _sphinx_tree(self): def _sphinx_tree(self):
source_dir = self._get_source_dir() source_dir = self._get_source_dir()
cmd = ['-H', 'Modules', '-o', source_dir, '.'] cmd = ['-H', 'Modules', '-o', source_dir, '.']
if apidoc_use_padding: if apidoc_use_padding:
cmd.insert(0, 'apidoc') cmd.insert(0, 'apidoc')
apidoc.main(cmd + self.autodoc_tree_excludes) apidoc.main(cmd + self.autodoc_tree_excludes)
def _sphinx_run(self): def _sphinx_run(self):
if not self.verbose: if not self.verbose:

View file

@ -40,8 +40,11 @@ def get_sha(args):
def get_info(args): def get_info(args):
print("{name}\t{version}\t{released}\t{sha}".format( if args.short:
**_get_info(args.name))) print("{version}".format(**_get_info(args.name)))
else:
print("{name}\t{version}\t{released}\t{sha}".format(
**_get_info(args.name)))
def _get_info(name): def _get_info(name):
@ -86,7 +89,9 @@ def main():
version=str(pbr.version.VersionInfo('pbr'))) version=str(pbr.version.VersionInfo('pbr')))
subparsers = parser.add_subparsers( subparsers = parser.add_subparsers(
title='commands', description='valid commands', help='additional help') title='commands', description='valid commands', help='additional help',
dest='cmd')
subparsers.required = True
cmd_sha = subparsers.add_parser('sha', help='print sha of package') cmd_sha = subparsers.add_parser('sha', help='print sha of package')
cmd_sha.set_defaults(func=get_sha) cmd_sha.set_defaults(func=get_sha)
@ -96,6 +101,8 @@ def main():
'info', help='print version info for package') 'info', help='print version info for package')
cmd_info.set_defaults(func=get_info) cmd_info.set_defaults(func=get_info)
cmd_info.add_argument('name', help='package to print info of') cmd_info.add_argument('name', help='package to print info of')
cmd_info.add_argument('-s', '--short', action="store_true",
help='only display package version')
cmd_freeze = subparsers.add_parser( cmd_freeze = subparsers.add_parser(
'freeze', help='print version info for all installed packages') 'freeze', help='print version info for all installed packages')

View file

@ -61,6 +61,11 @@ else:
integer_types = (int, long) # noqa integer_types = (int, long) # noqa
# We use this canary to detect whether the module has already been called,
# in order to avoid recursion
in_use = False
def pbr(dist, attr, value): def pbr(dist, attr, value):
"""Implements the actual pbr setup() keyword. """Implements the actual pbr setup() keyword.
@ -81,6 +86,16 @@ def pbr(dist, attr, value):
not work well with distributions that do use a `Distribution` subclass. not work well with distributions that do use a `Distribution` subclass.
""" """
# Distribution.finalize_options() is what calls this method. That means
# there is potential for recursion here. Recursion seems to be an issue
# particularly when using PEP517 build-system configs without
# setup_requires in setup.py. We can avoid the recursion by setting
# this canary so we don't repeat ourselves.
global in_use
if in_use:
return
in_use = True
if not value: if not value:
return return
if isinstance(value, string_type): if isinstance(value, string_type):

View file

@ -156,9 +156,9 @@ def _clean_changelog_message(msg):
* Escapes '`' which is interpreted as a literal * Escapes '`' which is interpreted as a literal
""" """
msg = msg.replace('*', '\*') msg = msg.replace('*', r'\*')
msg = msg.replace('_', '\_') msg = msg.replace('_', r'\_')
msg = msg.replace('`', '\`') msg = msg.replace('`', r'\`')
return msg return msg
@ -223,6 +223,11 @@ def _iter_log_inner(git_dir):
presentation logic to the output - making it suitable for different presentation logic to the output - making it suitable for different
uses. uses.
.. caution:: this function risk to return a tag that doesn't exist really
inside the git objects list due to replacement made
to tag name to also list pre-release suffix.
Compliant with the SemVer specification (e.g 1.2.3-rc1)
:return: An iterator of (hash, tags_set, 1st_line) tuples. :return: An iterator of (hash, tags_set, 1st_line) tuples.
""" """
log.info('[pbr] Generating ChangeLog') log.info('[pbr] Generating ChangeLog')
@ -248,7 +253,7 @@ def _iter_log_inner(git_dir):
for tag_string in refname.split("refs/tags/")[1:]: for tag_string in refname.split("refs/tags/")[1:]:
# git tag does not allow : or " " in tag names, so we split # git tag does not allow : or " " in tag names, so we split
# on ", " which is the separator between elements # on ", " which is the separator between elements
candidate = tag_string.split(", ")[0] candidate = tag_string.split(", ")[0].replace("-", ".")
if _is_valid_version(candidate): if _is_valid_version(candidate):
tags.add(candidate) tags.add(candidate)
@ -271,13 +276,14 @@ def write_git_changelog(git_dir=None, dest_dir=os.path.curdir,
changelog = _iter_changelog(changelog) changelog = _iter_changelog(changelog)
if not changelog: if not changelog:
return return
new_changelog = os.path.join(dest_dir, 'ChangeLog') new_changelog = os.path.join(dest_dir, 'ChangeLog')
# If there's already a ChangeLog and it's not writable, just use it if os.path.exists(new_changelog) and not os.access(new_changelog, os.W_OK):
if (os.path.exists(new_changelog) # If there's already a ChangeLog and it's not writable, just use it
and not os.access(new_changelog, os.W_OK)):
log.info('[pbr] ChangeLog not written (file already' log.info('[pbr] ChangeLog not written (file already'
' exists and it is not writeable)') ' exists and it is not writeable)')
return return
log.info('[pbr] Writing ChangeLog') log.info('[pbr] Writing ChangeLog')
with io.open(new_changelog, "w", encoding="utf-8") as changelog_file: with io.open(new_changelog, "w", encoding="utf-8") as changelog_file:
for release, content in changelog: for release, content in changelog:
@ -292,13 +298,14 @@ def generate_authors(git_dir=None, dest_dir='.', option_dict=dict()):
'SKIP_GENERATE_AUTHORS') 'SKIP_GENERATE_AUTHORS')
if should_skip: if should_skip:
return return
start = time.time() start = time.time()
old_authors = os.path.join(dest_dir, 'AUTHORS.in') old_authors = os.path.join(dest_dir, 'AUTHORS.in')
new_authors = os.path.join(dest_dir, 'AUTHORS') new_authors = os.path.join(dest_dir, 'AUTHORS')
# If there's already an AUTHORS file and it's not writable, just use it if os.path.exists(new_authors) and not os.access(new_authors, os.W_OK):
if (os.path.exists(new_authors) # If there's already an AUTHORS file and it's not writable, just use it
and not os.access(new_authors, os.W_OK)):
return return
log.info('[pbr] Generating AUTHORS') log.info('[pbr] Generating AUTHORS')
ignore_emails = '((jenkins|zuul)@review|infra@lists|jenkins@openstack)' ignore_emails = '((jenkins|zuul)@review|infra@lists|jenkins@openstack)'
if git_dir is None: if git_dir is None:

View file

@ -14,6 +14,7 @@
# under the License. # under the License.
import os import os
import shlex
import sys import sys
from pbr import find_package from pbr import find_package
@ -35,6 +36,21 @@ def get_man_section(section):
return os.path.join(get_manpath(), 'man%s' % section) return os.path.join(get_manpath(), 'man%s' % section)
def unquote_path(path):
# unquote the full path, e.g: "'a/full/path'" becomes "a/full/path", also
# strip the quotes off individual path components because os.walk cannot
# handle paths like: "'i like spaces'/'another dir'", so we will pass it
# "i like spaces/another dir" instead.
if os.name == 'nt':
# shlex cannot handle paths that contain backslashes, treating those
# as escape characters.
path = path.replace("\\", "/")
return "".join(shlex.split(path)).replace("/", "\\")
return "".join(shlex.split(path))
class FilesConfig(base.BaseConfig): class FilesConfig(base.BaseConfig):
section = 'files' section = 'files'
@ -57,21 +73,28 @@ class FilesConfig(base.BaseConfig):
target = target.strip() target = target.strip()
if not target.endswith(os.path.sep): if not target.endswith(os.path.sep):
target += os.path.sep target += os.path.sep
for (dirpath, dirnames, fnames) in os.walk(source_prefix): unquoted_prefix = unquote_path(source_prefix)
finished.append( unquoted_target = unquote_path(target)
"%s = " % dirpath.replace(source_prefix, target)) for (dirpath, dirnames, fnames) in os.walk(unquoted_prefix):
# As source_prefix is always matched, using replace with a
# a limit of one is always going to replace the path prefix
# and not accidentally replace some text in the middle of
# the path
new_prefix = dirpath.replace(unquoted_prefix,
unquoted_target, 1)
finished.append("'%s' = " % new_prefix)
finished.extend( finished.extend(
[" %s" % os.path.join(dirpath, f) for f in fnames]) [" '%s'" % os.path.join(dirpath, f) for f in fnames])
else: else:
finished.append(line) finished.append(line)
self.data_files = "\n".join(finished) self.data_files = "\n".join(finished)
def add_man_path(self, man_path): def add_man_path(self, man_path):
self.data_files = "%s\n%s =" % (self.data_files, man_path) self.data_files = "%s\n'%s' =" % (self.data_files, man_path)
def add_man_page(self, man_page): def add_man_page(self, man_page):
self.data_files = "%s\n %s" % (self.data_files, man_page) self.data_files = "%s\n '%s'" % (self.data_files, man_page)
def get_man_sections(self): def get_man_sections(self):
man_sections = dict() man_sections = dict()

View file

@ -48,6 +48,6 @@ TRUE_VALUES = ('true', '1', 'yes')
def get_boolean_option(option_dict, option_name, env_name): def get_boolean_option(option_dict, option_name, env_name):
return ((option_name in option_dict return ((option_name in option_dict and
and option_dict[option_name][1].lower() in TRUE_VALUES) or option_dict[option_name][1].lower() in TRUE_VALUES) or
str(os.getenv(env_name)).lower() in TRUE_VALUES) str(os.getenv(env_name)).lower() in TRUE_VALUES)

View file

@ -22,6 +22,16 @@ from __future__ import unicode_literals
from distutils.command import install as du_install from distutils.command import install as du_install
from distutils import log from distutils import log
# (hberaud) do not use six here to import urlparse
# to keep this module free from external dependencies
# to avoid cross dependencies errors on minimal system
# free from dependencies.
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
import email import email
import email.errors import email.errors
import os import os
@ -98,19 +108,31 @@ def get_reqs_from_files(requirements_files):
return [] return []
def egg_fragment(match):
return re.sub(r'(?P<PackageName>[\w.-]+)-'
r'(?P<GlobalVersion>'
r'(?P<VersionTripple>'
r'(?P<Major>0|[1-9][0-9]*)\.'
r'(?P<Minor>0|[1-9][0-9]*)\.'
r'(?P<Patch>0|[1-9][0-9]*)){1}'
r'(?P<Tags>(?:\-'
r'(?P<Prerelease>(?:(?=[0]{1}[0-9A-Za-z-]{0})(?:[0]{1})|'
r'(?=[1-9]{1}[0-9]*[A-Za-z]{0})(?:[0-9]+)|'
r'(?=[0-9]*[A-Za-z-]+[0-9A-Za-z-]*)(?:[0-9A-Za-z-]+)){1}'
r'(?:\.(?=[0]{1}[0-9A-Za-z-]{0})(?:[0]{1})|'
r'\.(?=[1-9]{1}[0-9]*[A-Za-z]{0})(?:[0-9]+)|'
r'\.(?=[0-9]*[A-Za-z-]+[0-9A-Za-z-]*)'
r'(?:[0-9A-Za-z-]+))*){1}){0,1}(?:\+'
r'(?P<Meta>(?:[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))){0,1}))',
r'\g<PackageName>>=\g<GlobalVersion>',
match.groups()[-1])
def parse_requirements(requirements_files=None, strip_markers=False): def parse_requirements(requirements_files=None, strip_markers=False):
if requirements_files is None: if requirements_files is None:
requirements_files = get_requirements_files() requirements_files = get_requirements_files()
def egg_fragment(match):
# take a versioned egg fragment and return a
# versioned package requirement e.g.
# nova-1.2.3 becomes nova>=1.2.3
return re.sub(r'([\w.]+)-([\w.-]+)',
r'\1>=\2',
match.groups()[-1])
requirements = [] requirements = []
for line in get_reqs_from_files(requirements_files): for line in get_reqs_from_files(requirements_files):
# Ignore comments # Ignore comments
@ -118,7 +140,8 @@ def parse_requirements(requirements_files=None, strip_markers=False):
continue continue
# Ignore index URL lines # Ignore index URL lines
if re.match(r'^\s*(-i|--index-url|--extra-index-url).*', line): if re.match(r'^\s*(-i|--index-url|--extra-index-url|--find-links).*',
line):
continue continue
# Handle nested requirements files such as: # Handle nested requirements files such as:
@ -140,16 +163,19 @@ def parse_requirements(requirements_files=None, strip_markers=False):
# -e git://github.com/openstack/nova/master#egg=nova # -e git://github.com/openstack/nova/master#egg=nova
# -e git://github.com/openstack/nova/master#egg=nova-1.2.3 # -e git://github.com/openstack/nova/master#egg=nova-1.2.3
# -e git+https://foo.com/zipball#egg=bar&subdirectory=baz # -e git+https://foo.com/zipball#egg=bar&subdirectory=baz
if re.match(r'\s*-e\s+', line):
line = re.sub(r'\s*-e\s+.*#egg=([^&]+).*$', egg_fragment, line)
# such as:
# http://github.com/openstack/nova/zipball/master#egg=nova # http://github.com/openstack/nova/zipball/master#egg=nova
# http://github.com/openstack/nova/zipball/master#egg=nova-1.2.3 # http://github.com/openstack/nova/zipball/master#egg=nova-1.2.3
# git+https://foo.com/zipball#egg=bar&subdirectory=baz # git+https://foo.com/zipball#egg=bar&subdirectory=baz
elif re.match(r'\s*(https?|git(\+(https|ssh))?):', line): # git+[ssh]://github.com/openstack/nova/zipball/master#egg=nova-1.2.3
line = re.sub(r'\s*(https?|git(\+(https|ssh))?):.*#egg=([^&]+).*$', # hg+[ssh]://github.com/openstack/nova/zipball/master#egg=nova-1.2.3
egg_fragment, line) # svn+[proto]://github.com/openstack/nova/zipball/master#egg=nova-1.2.3
# -f lines are for index locations, and don't get used here # -f lines are for index locations, and don't get used here
if re.match(r'\s*-e\s+', line):
extract = re.match(r'\s*-e\s+(.*)$', line)
line = extract.group(1)
egg = urlparse(line)
if egg.scheme:
line = re.sub(r'egg=([^&]+).*$', egg_fragment, egg.fragment)
elif re.match(r'\s*-f\s+', line): elif re.match(r'\s*-f\s+', line):
line = None line = None
reason = 'Index Location' reason = 'Index Location'
@ -183,7 +209,7 @@ def parse_dependency_links(requirements_files=None):
if re.match(r'\s*-[ef]\s+', line): if re.match(r'\s*-[ef]\s+', line):
dependency_links.append(re.sub(r'\s*-[ef]\s+', '', line)) dependency_links.append(re.sub(r'\s*-[ef]\s+', '', line))
# lines that are only urls can go in unmolested # lines that are only urls can go in unmolested
elif re.match(r'\s*(https?|git(\+(https|ssh))?):', line): elif re.match(r'^\s*(https?|git(\+(https|ssh))?|svn|hg)\S*:', line):
dependency_links.append(line) dependency_links.append(line)
return dependency_links return dependency_links
@ -302,6 +328,7 @@ except ImportError:
def have_nose(): def have_nose():
return _have_nose return _have_nose
_wsgi_text = """#PBR Generated from %(group)r _wsgi_text = """#PBR Generated from %(group)r
import threading import threading
@ -404,9 +431,13 @@ def generate_script(group, entry_point, header, template):
def override_get_script_args( def override_get_script_args(
dist, executable=os.path.normpath(sys.executable), is_wininst=False): dist, executable=os.path.normpath(sys.executable)):
"""Override entrypoints console_script.""" """Override entrypoints console_script."""
header = easy_install.get_script_header("", executable, is_wininst) # get_script_header() is deprecated since Setuptools 12.0
try:
header = easy_install.ScriptWriter.get_header("", executable)
except AttributeError:
header = easy_install.get_script_header("", executable)
for group, template in ENTRY_POINTS_MAP.items(): for group, template in ENTRY_POINTS_MAP.items():
for name, ep in dist.get_entry_map(group).items(): for name, ep in dist.get_entry_map(group).items():
yield (name, generate_script(group, ep, header, template)) yield (name, generate_script(group, ep, header, template))
@ -428,8 +459,12 @@ class LocalInstallScripts(install_scripts.install_scripts):
"""Intercepts console scripts entry_points.""" """Intercepts console scripts entry_points."""
command_name = 'install_scripts' command_name = 'install_scripts'
def _make_wsgi_scripts_only(self, dist, executable, is_wininst): def _make_wsgi_scripts_only(self, dist, executable):
header = easy_install.get_script_header("", executable, is_wininst) # get_script_header() is deprecated since Setuptools 12.0
try:
header = easy_install.ScriptWriter.get_header("", executable)
except AttributeError:
header = easy_install.get_script_header("", executable)
wsgi_script_template = ENTRY_POINTS_MAP['wsgi_scripts'] wsgi_script_template = ENTRY_POINTS_MAP['wsgi_scripts']
for name, ep in dist.get_entry_map('wsgi_scripts').items(): for name, ep in dist.get_entry_map('wsgi_scripts').items():
content = generate_script( content = generate_script(
@ -455,16 +490,12 @@ class LocalInstallScripts(install_scripts.install_scripts):
bs_cmd = self.get_finalized_command('build_scripts') bs_cmd = self.get_finalized_command('build_scripts')
executable = getattr( executable = getattr(
bs_cmd, 'executable', easy_install.sys_executable) bs_cmd, 'executable', easy_install.sys_executable)
is_wininst = getattr(
self.get_finalized_command("bdist_wininst"), '_is_running', False
)
if 'bdist_wheel' in self.distribution.have_run: if 'bdist_wheel' in self.distribution.have_run:
# We're building a wheel which has no way of generating mod_wsgi # We're building a wheel which has no way of generating mod_wsgi
# scripts for us. Let's build them. # scripts for us. Let's build them.
# NOTE(sigmavirus24): This needs to happen here because, as the # NOTE(sigmavirus24): This needs to happen here because, as the
# comment below indicates, no_ep is True when building a wheel. # comment below indicates, no_ep is True when building a wheel.
self._make_wsgi_scripts_only(dist, executable, is_wininst) self._make_wsgi_scripts_only(dist, executable)
if self.no_ep: if self.no_ep:
# no_ep is True if we're installing into an .egg file or building # no_ep is True if we're installing into an .egg file or building
@ -478,7 +509,7 @@ class LocalInstallScripts(install_scripts.install_scripts):
get_script_args = easy_install.get_script_args get_script_args = easy_install.get_script_args
executable = '"%s"' % executable executable = '"%s"' % executable
for args in get_script_args(dist, executable, is_wininst): for args in get_script_args(dist, executable):
self.write_script(*args) self.write_script(*args)
@ -550,8 +581,9 @@ class LocalEggInfo(egg_info.egg_info):
else: else:
log.info("[pbr] Reusing existing SOURCES.txt") log.info("[pbr] Reusing existing SOURCES.txt")
self.filelist = egg_info.FileList() self.filelist = egg_info.FileList()
for entry in open(manifest_filename, 'r').read().split('\n'): with open(manifest_filename, 'r') as fil:
self.filelist.append(entry) for entry in fil.read().split('\n'):
self.filelist.append(entry)
def _from_git(distribution): def _from_git(distribution):
@ -626,6 +658,7 @@ class LocalSDist(sdist.sdist):
self.filelist.sort() self.filelist.sort()
sdist.sdist.make_distribution(self) sdist.sdist.make_distribution(self)
try: try:
from pbr import builddoc from pbr import builddoc
_have_sphinx = True _have_sphinx = True
@ -659,12 +692,14 @@ def _get_increment_kwargs(git_dir, tag):
# git log output affecting out ability to have working sem ver headers. # git log output affecting out ability to have working sem ver headers.
changelog = git._run_git_command(['log', '--pretty=%B', version_spec], changelog = git._run_git_command(['log', '--pretty=%B', version_spec],
git_dir) git_dir)
header_len = len('sem-ver:')
commands = [line[header_len:].strip() for line in changelog.split('\n')
if line.lower().startswith('sem-ver:')]
symbols = set() symbols = set()
for command in commands: header = 'sem-ver:'
symbols.update([symbol.strip() for symbol in command.split(',')]) for line in changelog.split("\n"):
line = line.lower().strip()
if not line.lower().strip().startswith(header):
continue
new_symbols = line[len(header):].strip().split(",")
symbols.update([symbol.strip() for symbol in new_symbols])
def _handle_symbol(symbol, symbols, impact): def _handle_symbol(symbol, symbols, impact):
if symbol in symbols: if symbol in symbols:
@ -791,12 +826,9 @@ def _get_version_from_pkg_metadata(package_name):
pkg_metadata = {} pkg_metadata = {}
for filename in pkg_metadata_filenames: for filename in pkg_metadata_filenames:
try: try:
pkg_metadata_file = open(filename, 'r') with open(filename, 'r') as pkg_metadata_file:
except (IOError, OSError): pkg_metadata = email.message_from_file(pkg_metadata_file)
continue except (IOError, OSError, email.errors.MessageError):
try:
pkg_metadata = email.message_from_file(pkg_metadata_file)
except email.errors.MessageError:
continue continue
# Check to make sure we're in our own dir # Check to make sure we're in our own dir

View file

@ -187,7 +187,9 @@ class CapturedSubprocess(fixtures.Fixture):
self.addDetail(self.label + '-stderr', content.text_content(self.err)) self.addDetail(self.label + '-stderr', content.text_content(self.err))
self.returncode = proc.returncode self.returncode = proc.returncode
if proc.returncode: if proc.returncode:
raise AssertionError('Failed process %s' % proc.returncode) raise AssertionError(
'Failed process args=%r, kwargs=%r, returncode=%s' % (
self.args, self.kwargs, proc.returncode))
self.addCleanup(delattr, self, 'out') self.addCleanup(delattr, self, 'out')
self.addCleanup(delattr, self, 'err') self.addCleanup(delattr, self, 'err')
self.addCleanup(delattr, self, 'returncode') self.addCleanup(delattr, self, 'returncode')
@ -200,12 +202,15 @@ def _run_cmd(args, cwd):
:param cwd: The directory to run the comamnd in. :param cwd: The directory to run the comamnd in.
:return: ((stdout, stderr), returncode) :return: ((stdout, stderr), returncode)
""" """
print('Running %s' % ' '.join(args))
p = subprocess.Popen( p = subprocess.Popen(
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, args, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, cwd=cwd) stderr=subprocess.PIPE, cwd=cwd)
streams = tuple(s.decode('latin1').strip() for s in p.communicate()) streams = tuple(s.decode('latin1').strip() for s in p.communicate())
for stream_content in streams: print('STDOUT:')
print(stream_content) print(streams[0])
print('STDERR:')
print(streams[1])
return (streams) + (p.returncode,) return (streams) + (p.returncode,)

View file

@ -78,7 +78,7 @@ class TestCommands(base.BaseTestCase):
stdout, stderr, return_code = self.run_pbr('freeze') stdout, stderr, return_code = self.run_pbr('freeze')
self.assertEqual(0, return_code) self.assertEqual(0, return_code)
pkgs = [] pkgs = []
for l in stdout.split('\n'): for line in stdout.split('\n'):
pkgs.append(l.split('==')[0].lower()) pkgs.append(line.split('==')[0].lower())
pkgs_sort = sorted(pkgs[:]) pkgs_sort = sorted(pkgs[:])
self.assertEqual(pkgs_sort, pkgs) self.assertEqual(pkgs_sort, pkgs)

View file

@ -40,6 +40,7 @@
import glob import glob
import os import os
import sys
import tarfile import tarfile
import fixtures import fixtures
@ -74,7 +75,7 @@ class TestCore(base.BaseTestCase):
self.run_setup('egg_info') self.run_setup('egg_info')
stdout, _, _ = self.run_setup('--keywords') stdout, _, _ = self.run_setup('--keywords')
assert stdout == 'packaging,distutils,setuptools' assert stdout == 'packaging, distutils, setuptools'
def test_setup_py_build_sphinx(self): def test_setup_py_build_sphinx(self):
stdout, _, return_code = self.run_setup('build_sphinx') stdout, _, return_code = self.run_setup('build_sphinx')
@ -113,6 +114,12 @@ class TestCore(base.BaseTestCase):
def test_console_script_develop(self): def test_console_script_develop(self):
"""Test that we develop a non-pkg-resources console script.""" """Test that we develop a non-pkg-resources console script."""
if sys.version_info < (3, 0):
self.skipTest(
'Fails with recent virtualenv due to '
'https://github.com/pypa/virtualenv/issues/1638'
)
if os.name == 'nt': if os.name == 'nt':
self.skipTest('Windows support is passthrough') self.skipTest('Windows support is passthrough')

View file

@ -35,17 +35,31 @@ class FilesConfigTest(base.BaseTestCase):
]) ])
self.useFixture(pkg_fixture) self.useFixture(pkg_fixture)
pkg_etc = os.path.join(pkg_fixture.base, 'etc') pkg_etc = os.path.join(pkg_fixture.base, 'etc')
pkg_ansible = os.path.join(pkg_fixture.base, 'ansible',
'kolla-ansible', 'test')
dir_spcs = os.path.join(pkg_fixture.base, 'dir with space')
dir_subdir_spc = os.path.join(pkg_fixture.base, 'multi space',
'more spaces')
pkg_sub = os.path.join(pkg_etc, 'sub') pkg_sub = os.path.join(pkg_etc, 'sub')
subpackage = os.path.join( subpackage = os.path.join(
pkg_fixture.base, 'fake_package', 'subpackage') pkg_fixture.base, 'fake_package', 'subpackage')
os.makedirs(pkg_sub) os.makedirs(pkg_sub)
os.makedirs(subpackage) os.makedirs(subpackage)
os.makedirs(pkg_ansible)
os.makedirs(dir_spcs)
os.makedirs(dir_subdir_spc)
with open(os.path.join(pkg_etc, "foo"), 'w') as foo_file: with open(os.path.join(pkg_etc, "foo"), 'w') as foo_file:
foo_file.write("Foo Data") foo_file.write("Foo Data")
with open(os.path.join(pkg_sub, "bar"), 'w') as foo_file: with open(os.path.join(pkg_sub, "bar"), 'w') as foo_file:
foo_file.write("Bar Data") foo_file.write("Bar Data")
with open(os.path.join(pkg_ansible, "baz"), 'w') as baz_file:
baz_file.write("Baz Data")
with open(os.path.join(subpackage, "__init__.py"), 'w') as foo_file: with open(os.path.join(subpackage, "__init__.py"), 'w') as foo_file:
foo_file.write("# empty") foo_file.write("# empty")
with open(os.path.join(dir_spcs, "file with spc"), 'w') as spc_file:
spc_file.write("# empty")
with open(os.path.join(dir_subdir_spc, "file with spc"), 'w') as file_:
file_.write("# empty")
self.useFixture(base.DiveDir(pkg_fixture.base)) self.useFixture(base.DiveDir(pkg_fixture.base))
@ -74,5 +88,61 @@ class FilesConfigTest(base.BaseTestCase):
) )
files.FilesConfig(config, 'fake_package').run() files.FilesConfig(config, 'fake_package').run()
self.assertIn( self.assertIn(
'\netc/pbr/ = \n etc/foo\netc/pbr/sub = \n etc/sub/bar', "\n'etc/pbr/' = \n 'etc/foo'\n'etc/pbr/sub' = \n 'etc/sub/bar'",
config['files']['data_files'])
def test_data_files_with_spaces(self):
config = dict(
files=dict(
data_files="\n 'i like spaces' = 'dir with space'/*"
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(
"\n'i like spaces/' = \n 'dir with space/file with spc'",
config['files']['data_files'])
def test_data_files_with_spaces_subdirectories(self):
# test that we can handle whitespace in subdirectories
data_files = "\n 'one space/two space' = 'multi space/more spaces'/*"
expected = (
"\n'one space/two space/' = "
"\n 'multi space/more spaces/file with spc'")
config = dict(
files=dict(
data_files=data_files
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(expected, config['files']['data_files'])
def test_data_files_with_spaces_quoted_components(self):
# test that we can quote individual path components
data_files = (
"\n'one space'/'two space' = 'multi space'/'more spaces'/*"
)
expected = ("\n'one space/two space/' = "
"\n 'multi space/more spaces/file with spc'")
config = dict(
files=dict(
data_files=data_files
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(expected, config['files']['data_files'])
def test_data_files_globbing_source_prefix_in_directory_name(self):
# We want to test that the string, "docs", is not replaced in a
# subdirectory name, "sub-docs"
config = dict(
files=dict(
data_files="\n share/ansible = ansible/*"
)
)
files.FilesConfig(config, 'fake_package').run()
self.assertIn(
"\n'share/ansible/' = "
"\n'share/ansible/kolla-ansible' = "
"\n'share/ansible/kolla-ansible/test' = "
"\n 'ansible/kolla-ansible/test/baz'",
config['files']['data_files']) config['files']['data_files'])

View file

@ -11,7 +11,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
try:
import configparser
except ImportError:
import ConfigParser as configparser
import os.path import os.path
import pkg_resources
import shlex import shlex
import sys import sys
@ -77,19 +82,35 @@ class TestIntegration(base.BaseTestCase):
# We don't break these into separate tests because we'd need separate # We don't break these into separate tests because we'd need separate
# source dirs to isolate from side effects of running pip, and the # source dirs to isolate from side effects of running pip, and the
# overheads of setup would start to beat the benefits of parallelism. # overheads of setup would start to beat the benefits of parallelism.
self.useFixture(base.CapturedSubprocess( path = os.path.join(REPODIR, self.short_name)
'sync-req', setup_cfg = os.path.join(path, 'setup.cfg')
['python', 'update.py', os.path.join(REPODIR, self.short_name)], project_name = pkg_resources.safe_name(self.short_name).lower()
cwd=os.path.join(REPODIR, 'requirements'))) # These projects should all have setup.cfg files but we'll be careful
self.useFixture(base.CapturedSubprocess( if os.path.exists(setup_cfg):
'commit-requirements', config = configparser.ConfigParser()
'git diff --quiet || git commit -amrequirements', config.read(setup_cfg)
cwd=os.path.join(REPODIR, self.short_name), shell=True)) if config.has_section('metadata'):
path = os.path.join( raw_name = config.get('metadata', 'name',
self.useFixture(fixtures.TempDir()).path, 'project') fallback='notapackagename')
self.useFixture(base.CapturedSubprocess( # Technically we should really only need to use the raw
'clone', # name because all our projects should be good and use
['git', 'clone', os.path.join(REPODIR, self.short_name), path])) # normalized names but they don't...
project_name = pkg_resources.safe_name(raw_name).lower()
constraints = os.path.join(REPODIR, 'requirements',
'upper-constraints.txt')
tmp_constraints = os.path.join(
self.useFixture(fixtures.TempDir()).path,
'upper-constraints.txt')
# We need to filter out the package we are installing to avoid
# conflicts with the constraints.
with open(constraints, 'r') as src:
with open(tmp_constraints, 'w') as dest:
for line in src:
constraint = line.split('===')[0]
if project_name != constraint:
dest.write(line)
pip_cmd = PIP_CMD + ['-c', tmp_constraints]
venv = self.useFixture( venv = self.useFixture(
test_packaging.Venv('sdist', test_packaging.Venv('sdist',
modules=['pip', 'wheel', PBRVERSION], modules=['pip', 'wheel', PBRVERSION],
@ -105,7 +126,7 @@ class TestIntegration(base.BaseTestCase):
filename = os.path.join( filename = os.path.join(
path, 'dist', os.listdir(os.path.join(path, 'dist'))[0]) path, 'dist', os.listdir(os.path.join(path, 'dist'))[0])
self.useFixture(base.CapturedSubprocess( self.useFixture(base.CapturedSubprocess(
'tarball', [python] + PIP_CMD + [filename])) 'tarball', [python] + pip_cmd + [filename]))
venv = self.useFixture( venv = self.useFixture(
test_packaging.Venv('install-git', test_packaging.Venv('install-git',
modules=['pip', 'wheel', PBRVERSION], modules=['pip', 'wheel', PBRVERSION],
@ -113,7 +134,7 @@ class TestIntegration(base.BaseTestCase):
root = venv.path root = venv.path
python = venv.python python = venv.python
self.useFixture(base.CapturedSubprocess( self.useFixture(base.CapturedSubprocess(
'install-git', [python] + PIP_CMD + ['git+file://' + path])) 'install-git', [python] + pip_cmd + ['git+file://' + path]))
if self.short_name == 'nova': if self.short_name == 'nova':
found = False found = False
for _, _, filenames in os.walk(root): for _, _, filenames in os.walk(root):
@ -127,7 +148,7 @@ class TestIntegration(base.BaseTestCase):
root = venv.path root = venv.path
python = venv.python python = venv.python
self.useFixture(base.CapturedSubprocess( self.useFixture(base.CapturedSubprocess(
'install-e', [python] + PIP_CMD + ['-e', path])) 'install-e', [python] + pip_cmd + ['-e', path]))
class TestInstallWithoutPbr(base.BaseTestCase): class TestInstallWithoutPbr(base.BaseTestCase):
@ -188,12 +209,16 @@ class TestInstallWithoutPbr(base.BaseTestCase):
class TestMarkersPip(base.BaseTestCase): class TestMarkersPip(base.BaseTestCase):
scenarios = [ scenarios = [
('pip-1.5', {'modules': ['pip>=1.5,<1.6']}),
('pip-6.0', {'modules': ['pip>=6.0,<6.1']}),
('pip-latest', {'modules': ['pip']}), ('pip-latest', {'modules': ['pip']}),
('setuptools-EL7', {'modules': ['pip==1.4.1', 'setuptools==0.9.8']}), ('setuptools-Bionic', {
('setuptools-Trusty', {'modules': ['pip==1.5', 'setuptools==2.2']}), 'modules': ['pip==9.0.1', 'setuptools==39.0.1']}),
('setuptools-minimum', {'modules': ['pip==1.5', 'setuptools==0.7.2']}), ('setuptools-Stretch', {
'modules': ['pip==9.0.1', 'setuptools==33.1.1']}),
('setuptools-EL8', {'modules': ['pip==9.0.3', 'setuptools==39.2.0']}),
('setuptools-Buster', {
'modules': ['pip==18.1', 'setuptools==40.8.0']}),
('setuptools-Focal', {
'modules': ['pip==20.0.2', 'setuptools==45.2.0']}),
] ]
@testtools.skipUnless( @testtools.skipUnless(
@ -240,25 +265,17 @@ class TestLTSSupport(base.BaseTestCase):
# These versions come from the versions installed from the 'virtualenv' # These versions come from the versions installed from the 'virtualenv'
# command from the 'python-virtualenv' package. # command from the 'python-virtualenv' package.
scenarios = [ scenarios = [
('EL7', {'modules': ['pip==1.4.1', 'setuptools==0.9.8'], ('Bionic', {'modules': ['pip==9.0.1', 'setuptools==39.0.1']}),
'py3support': True}), # And EPEL6 ('Stretch', {'modules': ['pip==9.0.1', 'setuptools==33.1.1']}),
('Trusty', {'modules': ['pip==1.5', 'setuptools==2.2'], ('EL8', {'modules': ['pip==9.0.3', 'setuptools==39.2.0']}),
'py3support': True}), ('Buster', {'modules': ['pip==18.1', 'setuptools==40.8.0']}),
('Jessie', {'modules': ['pip==1.5.6', 'setuptools==5.5.1'], ('Focal', {'modules': ['pip==20.0.2', 'setuptools==45.2.0']}),
'py3support': True}),
# Wheezy has pip1.1, which cannot be called with '-m pip'
# So we'll use a different version of pip here.
('WheezyPrecise', {'modules': ['pip==1.4.1', 'setuptools==0.6c11'],
'py3support': False})
] ]
@testtools.skipUnless( @testtools.skipUnless(
os.environ.get('PBR_INTEGRATION', None) == '1', os.environ.get('PBR_INTEGRATION', None) == '1',
'integration tests not enabled') 'integration tests not enabled')
def test_lts_venv_default_versions(self): def test_lts_venv_default_versions(self):
if (sys.version_info[0] == 3 and not self.py3support):
self.skipTest('This combination will not install with py3, '
'skipping test')
venv = self.useFixture( venv = self.useFixture(
test_packaging.Venv('setuptools', modules=self.modules)) test_packaging.Venv('setuptools', modules=self.modules))
bin_python = venv.python bin_python = venv.python

View file

@ -48,7 +48,10 @@ import tempfile
import textwrap import textwrap
import fixtures import fixtures
import mock try:
from unittest import mock
except ImportError:
import mock
import pkg_resources import pkg_resources
import six import six
import testscenarios import testscenarios
@ -108,7 +111,7 @@ class GPGKeyFixture(fixtures.Fixture):
def setUp(self): def setUp(self):
super(GPGKeyFixture, self).setUp() super(GPGKeyFixture, self).setUp()
tempdir = self.useFixture(fixtures.TempDir()) tempdir = self.useFixture(fixtures.TempDir())
gnupg_version_re = re.compile('^gpg\s.*\s([\d+])\.([\d+])\.([\d+])') gnupg_version_re = re.compile(r'^gpg\s.*\s([\d+])\.([\d+])\.([\d+])')
gnupg_version = base._run_cmd(['gpg', '--version'], tempdir.path) gnupg_version = base._run_cmd(['gpg', '--version'], tempdir.path)
for line in gnupg_version[0].split('\n'): for line in gnupg_version[0].split('\n'):
gnupg_version = gnupg_version_re.match(line) gnupg_version = gnupg_version_re.match(line)
@ -120,9 +123,9 @@ class GPGKeyFixture(fixtures.Fixture):
else: else:
if gnupg_version is None: if gnupg_version is None:
gnupg_version = (0, 0, 0) gnupg_version = (0, 0, 0)
config_file = tempdir.path + '/key-config'
f = open(config_file, 'wt') config_file = os.path.join(tempdir.path, 'key-config')
try: with open(config_file, 'wt') as f:
if gnupg_version[0] == 2 and gnupg_version[1] >= 1: if gnupg_version[0] == 2 and gnupg_version[1] >= 1:
f.write(""" f.write("""
%no-protection %no-protection
@ -135,11 +138,9 @@ class GPGKeyFixture(fixtures.Fixture):
Name-Comment: N/A Name-Comment: N/A
Name-Email: example@example.com Name-Email: example@example.com
Expire-Date: 2d Expire-Date: 2d
Preferences: (setpref)
%commit %commit
""") """)
finally:
f.close()
# Note that --quick-random (--debug-quick-random in GnuPG 2.x) # Note that --quick-random (--debug-quick-random in GnuPG 2.x)
# does not have a corresponding preferences file setting and # does not have a corresponding preferences file setting and
# must be passed explicitly on the command line instead # must be passed explicitly on the command line instead
@ -149,6 +150,7 @@ class GPGKeyFixture(fixtures.Fixture):
gnupg_random = '--debug-quick-random' gnupg_random = '--debug-quick-random'
else: else:
gnupg_random = '' gnupg_random = ''
base._run_cmd( base._run_cmd(
['gpg', '--gen-key', '--batch', gnupg_random, config_file], ['gpg', '--gen-key', '--batch', gnupg_random, config_file],
tempdir.path) tempdir.path)
@ -173,17 +175,17 @@ class Venv(fixtures.Fixture):
""" """
self._reason = reason self._reason = reason
if modules == (): if modules == ():
pbr = 'file://%s#egg=pbr' % PBR_ROOT modules = ['pip', 'wheel', 'build', PBR_ROOT]
modules = ['pip', 'wheel', pbr]
self.modules = modules self.modules = modules
if pip_cmd is None: if pip_cmd is None:
self.pip_cmd = ['-m', 'pip', 'install'] self.pip_cmd = ['-m', 'pip', '-v', 'install']
else: else:
self.pip_cmd = pip_cmd self.pip_cmd = pip_cmd
def _setUp(self): def _setUp(self):
path = self.useFixture(fixtures.TempDir()).path path = self.useFixture(fixtures.TempDir()).path
virtualenv.create_environment(path, clear=True) virtualenv.cli_run([path])
python = os.path.join(path, 'bin', 'python') python = os.path.join(path, 'bin', 'python')
command = [python] + self.pip_cmd + ['-U'] command = [python] + self.pip_cmd + ['-U']
if self.modules and len(self.modules) > 0: if self.modules and len(self.modules) > 0:
@ -293,23 +295,23 @@ class TestPackagingInGitRepoWithCommit(base.BaseTestCase):
self.run_setup('sdist', allow_fail=False) self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f: with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read() body = f.read()
self.assertIn('\*', body) self.assertIn(r'\*', body)
def test_changelog_handles_dead_links_in_commit(self): def test_changelog_handles_dead_links_in_commit(self):
self.repo.commit(message_content="See os_ for to_do about qemu_.") self.repo.commit(message_content="See os_ for to_do about qemu_.")
self.run_setup('sdist', allow_fail=False) self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f: with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read() body = f.read()
self.assertIn('os\_', body) self.assertIn(r'os\_', body)
self.assertIn('to\_do', body) self.assertIn(r'to\_do', body)
self.assertIn('qemu\_', body) self.assertIn(r'qemu\_', body)
def test_changelog_handles_backticks(self): def test_changelog_handles_backticks(self):
self.repo.commit(message_content="Allow `openstack.org` to `work") self.repo.commit(message_content="Allow `openstack.org` to `work")
self.run_setup('sdist', allow_fail=False) self.run_setup('sdist', allow_fail=False)
with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f: with open(os.path.join(self.package_dir, 'ChangeLog'), 'r') as f:
body = f.read() body = f.read()
self.assertIn('\`', body) self.assertIn(r'\`', body)
def test_manifest_exclude_honoured(self): def test_manifest_exclude_honoured(self):
self.run_setup('sdist', allow_fail=False) self.run_setup('sdist', allow_fail=False)
@ -379,6 +381,12 @@ class TestPackagingWheels(base.BaseTestCase):
wheel_file.extractall(self.extracted_wheel_dir) wheel_file.extractall(self.extracted_wheel_dir)
wheel_file.close() wheel_file.close()
def test_metadata_directory_has_pbr_json(self):
# Build the path to the scripts directory
pbr_json = os.path.join(
self.extracted_wheel_dir, 'pbr_testpackage-0.0.dist-info/pbr.json')
self.assertTrue(os.path.exists(pbr_json))
def test_data_directory_has_wsgi_scripts(self): def test_data_directory_has_wsgi_scripts(self):
# Build the path to the scripts directory # Build the path to the scripts directory
scripts_dir = os.path.join( scripts_dir = os.path.join(
@ -531,11 +539,13 @@ class ParseRequirementsTest(base.BaseTestCase):
tempdir = tempfile.mkdtemp() tempdir = tempfile.mkdtemp()
requirements = os.path.join(tempdir, 'requirements.txt') requirements = os.path.join(tempdir, 'requirements.txt')
with open(requirements, 'w') as f: with open(requirements, 'w') as f:
f.write('-i https://myindex.local') f.write('-i https://myindex.local\n')
f.write(' --index-url https://myindex.local') f.write(' --index-url https://myindex.local\n')
f.write(' --extra-index-url https://myindex.local') f.write(' --extra-index-url https://myindex.local\n')
f.write('--find-links https://myindex.local\n')
f.write('arequirement>=1.0\n')
result = packaging.parse_requirements([requirements]) result = packaging.parse_requirements([requirements])
self.assertEqual([], result) self.assertEqual(['arequirement>=1.0'], result)
def test_nested_requirements(self): def test_nested_requirements(self):
tempdir = tempfile.mkdtemp() tempdir = tempfile.mkdtemp()
@ -662,12 +672,65 @@ class TestVersions(base.BaseTestCase):
version = packaging._get_version_from_git() version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1')) self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_multi_inline_symbols_no_space(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('Sem-ver: feature,api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_multi_inline_symbols_spaced(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('Sem-ver: feature, api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_multi_inline_symbols_reversed(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit('Sem-ver: api-break,feature')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_leading_space(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit(' sem-ver: api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_leading_space_multiline(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit(
(
' Some cool text\n'
' sem-ver: api-break'
)
)
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('2.0.0.dev1'))
def test_leading_characters_symbol_not_found(self):
self.repo.commit()
self.repo.tag('1.2.3')
self.repo.commit(' ssem-ver: api-break')
version = packaging._get_version_from_git()
self.assertThat(version, matchers.StartsWith('1.2.4.dev1'))
def test_tagged_version_has_tag_version(self): def test_tagged_version_has_tag_version(self):
self.repo.commit() self.repo.commit()
self.repo.tag('1.2.3') self.repo.tag('1.2.3')
version = packaging._get_version_from_git('1.2.3') version = packaging._get_version_from_git('1.2.3')
self.assertEqual('1.2.3', version) self.assertEqual('1.2.3', version)
def test_tagged_version_with_semver_compliant_prerelease(self):
self.repo.commit()
self.repo.tag('1.2.3-rc2')
version = packaging._get_version_from_git()
self.assertEqual('1.2.3.0rc2', version)
def test_non_canonical_tagged_version_bump(self): def test_non_canonical_tagged_version_bump(self):
self.repo.commit() self.repo.commit()
self.repo.tag('1.4') self.repo.tag('1.4')
@ -724,6 +787,13 @@ class TestVersions(base.BaseTestCase):
version = packaging._get_version_from_git('1.2.3') version = packaging._get_version_from_git('1.2.3')
self.assertThat(version, matchers.StartsWith('1.2.3.0a2.dev1')) self.assertThat(version, matchers.StartsWith('1.2.3.0a2.dev1'))
def test_untagged_version_after_semver_compliant_prerelease_tag(self):
self.repo.commit()
self.repo.tag('1.2.3-rc2')
self.repo.commit()
version = packaging._get_version_from_git()
self.assertEqual('1.2.3.0rc3.dev1', version)
def test_preversion_too_low_simple(self): def test_preversion_too_low_simple(self):
# That is, the target version is either already released or not high # That is, the target version is either already released or not high
# enough for the semver requirements given api breaks etc. # enough for the semver requirements given api breaks etc.
@ -750,8 +820,10 @@ class TestVersions(base.BaseTestCase):
def test_get_kwargs_corner_cases(self): def test_get_kwargs_corner_cases(self):
# No tags: # No tags:
git_dir = self.repo._basedir + '/.git'
get_kwargs = lambda tag: packaging._get_increment_kwargs(git_dir, tag) def get_kwargs(tag):
git_dir = self.repo._basedir + '/.git'
return packaging._get_increment_kwargs(git_dir, tag)
def _check_combinations(tag): def _check_combinations(tag):
self.repo.commit() self.repo.commit()
@ -903,6 +975,235 @@ class TestRequirementParsing(base.BaseTestCase):
self.assertEqual(exp_parsed, gen_parsed) self.assertEqual(exp_parsed, gen_parsed)
class TestPEP517Support(base.BaseTestCase):
def test_pep_517_support(self):
# Note that the current PBR PEP517 entrypoints rely on a valid
# PBR setup.py existing.
pkgs = {
'test_pep517':
{
'requirements.txt': textwrap.dedent("""\
sphinx
iso8601
"""),
# Override default setup.py to remove setup_requires.
'setup.py': textwrap.dedent("""\
#!/usr/bin/env python
import setuptools
setuptools.setup(pbr=True)
"""),
'setup.cfg': textwrap.dedent("""\
[metadata]
name = test_pep517
summary = A tiny test project
author = PBR Team
author-email = foo@example.com
home-page = https://example.com/
classifier =
Intended Audience :: Information Technology
Intended Audience :: System Administrators
License :: OSI Approved :: Apache Software License
Operating System :: POSIX :: Linux
Programming Language :: Python
Programming Language :: Python :: 2
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
"""),
'pyproject.toml': textwrap.dedent("""\
[build-system]
requires = ["pbr", "setuptools>=36.6.0", "wheel"]
build-backend = "pbr.build"
""")},
}
pkg_dirs = self.useFixture(CreatePackages(pkgs)).package_dirs
pkg_dir = pkg_dirs['test_pep517']
venv = self.useFixture(Venv('PEP517'))
# Test building sdists and wheels works. Note we do not use pip here
# because pip will forcefully install the latest version of PBR on
# pypi to satisfy the build-system requires. This means we can't self
# test changes using pip. Build with --no-isolation appears to avoid
# this problem.
self._run_cmd(venv.python, ('-m', 'build', '--no-isolation', '.'),
allow_fail=False, cwd=pkg_dir)
class TestRepositoryURLDependencies(base.BaseTestCase):
def setUp(self):
super(TestRepositoryURLDependencies, self).setUp()
self.requirements = os.path.join(tempfile.mkdtemp(),
'requirements.txt')
with open(self.requirements, 'w') as f:
f.write('\n'.join([
'-e git+git://git.pro-ject.org/oslo.messaging#egg=oslo.messaging-1.0.0-rc', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize-beta', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-4.0.1', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-1.0.0-alpha.beta.1', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'-e git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-2.0.0-rc.1+build.123', # noqa
'-e git+git://git.project.org/Proj#egg=Proj1',
'git+https://git.project.org/Proj#egg=Proj2-0.0.1',
'-e git+ssh://git.project.org/Proj#egg=Proj3',
'svn+svn://svn.project.org/svn/Proj#egg=Proj4-0.0.2',
'-e svn+http://svn.project.org/svn/Proj/trunk@2019#egg=Proj5',
'hg+http://hg.project.org/Proj@da39a3ee5e6b#egg=Proj-0.0.3',
'-e hg+http://hg.project.org/Proj@2019#egg=Proj',
'hg+http://hg.project.org/Proj@v1.0#egg=Proj-0.0.4',
'-e hg+http://hg.project.org/Proj@special_feature#egg=Proj',
'git://foo.com/zipball#egg=foo-bar-1.2.4',
'pypi-proj1', 'pypi-proj2']))
def test_egg_fragment(self):
expected = [
'django-thumborize',
'django-thumborize-beta',
'django-thumborize2-beta',
'django-thumborize2-beta>=4.0.1',
'django-thumborize2-beta>=1.0.0-alpha.beta.1',
'django-thumborize2-beta>=1.0.0-alpha-a.b-c-long+build.1-aef.1-its-okay', # noqa
'django-thumborize2-beta>=2.0.0-rc.1+build.123',
'django-thumborize-beta>=0.0.4',
'django-thumborize-beta>=1.2.3',
'django-thumborize-beta>=10.20.30',
'django-thumborize-beta>=1.1.2-prerelease+meta',
'django-thumborize-beta>=1.1.2+meta',
'django-thumborize-beta>=1.1.2+meta-valid',
'django-thumborize-beta>=1.0.0-alpha',
'django-thumborize-beta>=1.0.0-beta',
'django-thumborize-beta>=1.0.0-alpha.beta',
'django-thumborize-beta>=1.0.0-alpha.beta.1',
'django-thumborize-beta>=1.0.0-alpha.1',
'django-thumborize-beta>=1.0.0-alpha0.valid',
'django-thumborize-beta>=1.0.0-alpha.0valid',
'django-thumborize-beta>=1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'django-thumborize-beta>=1.0.0-rc.1+build.1',
'django-thumborize-beta>=2.0.0-rc.1+build.123',
'django-thumborize-beta>=1.2.3-beta',
'django-thumborize-beta>=10.2.3-DEV-SNAPSHOT',
'django-thumborize-beta>=1.2.3-SNAPSHOT-123',
'django-thumborize-beta>=1.0.0',
'django-thumborize-beta>=2.0.0',
'django-thumborize-beta>=1.1.7',
'django-thumborize-beta>=2.0.0+build.1848',
'django-thumborize-beta>=2.0.1-alpha.1227',
'django-thumborize-beta>=1.0.0-alpha+beta',
'django-thumborize-beta>=1.2.3----RC-SNAPSHOT.12.9.1--.12+788',
'django-thumborize-beta>=1.2.3----R-S.12.9.1--.12+meta',
'django-thumborize-beta>=1.2.3----RC-SNAPSHOT.12.9.1--.12',
'django-thumborize-beta>=1.0.0+0.build.1-rc.10000aaa-kk-0.1',
'django-thumborize-beta>=999999999999999999.99999999999999.9999999999999', # noqa
'Proj1',
'Proj2>=0.0.1',
'Proj3',
'Proj4>=0.0.2',
'Proj5',
'Proj>=0.0.3',
'Proj',
'Proj>=0.0.4',
'Proj',
'foo-bar>=1.2.4',
]
tests = [
'egg=django-thumborize',
'egg=django-thumborize-beta',
'egg=django-thumborize2-beta',
'egg=django-thumborize2-beta-4.0.1',
'egg=django-thumborize2-beta-1.0.0-alpha.beta.1',
'egg=django-thumborize2-beta-1.0.0-alpha-a.b-c-long+build.1-aef.1-its-okay', # noqa
'egg=django-thumborize2-beta-2.0.0-rc.1+build.123',
'egg=django-thumborize-beta-0.0.4',
'egg=django-thumborize-beta-1.2.3',
'egg=django-thumborize-beta-10.20.30',
'egg=django-thumborize-beta-1.1.2-prerelease+meta',
'egg=django-thumborize-beta-1.1.2+meta',
'egg=django-thumborize-beta-1.1.2+meta-valid',
'egg=django-thumborize-beta-1.0.0-alpha',
'egg=django-thumborize-beta-1.0.0-beta',
'egg=django-thumborize-beta-1.0.0-alpha.beta',
'egg=django-thumborize-beta-1.0.0-alpha.beta.1',
'egg=django-thumborize-beta-1.0.0-alpha.1',
'egg=django-thumborize-beta-1.0.0-alpha0.valid',
'egg=django-thumborize-beta-1.0.0-alpha.0valid',
'egg=django-thumborize-beta-1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'egg=django-thumborize-beta-1.0.0-rc.1+build.1',
'egg=django-thumborize-beta-2.0.0-rc.1+build.123',
'egg=django-thumborize-beta-1.2.3-beta',
'egg=django-thumborize-beta-10.2.3-DEV-SNAPSHOT',
'egg=django-thumborize-beta-1.2.3-SNAPSHOT-123',
'egg=django-thumborize-beta-1.0.0',
'egg=django-thumborize-beta-2.0.0',
'egg=django-thumborize-beta-1.1.7',
'egg=django-thumborize-beta-2.0.0+build.1848',
'egg=django-thumborize-beta-2.0.1-alpha.1227',
'egg=django-thumborize-beta-1.0.0-alpha+beta',
'egg=django-thumborize-beta-1.2.3----RC-SNAPSHOT.12.9.1--.12+788', # noqa
'egg=django-thumborize-beta-1.2.3----R-S.12.9.1--.12+meta',
'egg=django-thumborize-beta-1.2.3----RC-SNAPSHOT.12.9.1--.12',
'egg=django-thumborize-beta-1.0.0+0.build.1-rc.10000aaa-kk-0.1', # noqa
'egg=django-thumborize-beta-999999999999999999.99999999999999.9999999999999', # noqa
'egg=Proj1',
'egg=Proj2-0.0.1',
'egg=Proj3',
'egg=Proj4-0.0.2',
'egg=Proj5',
'egg=Proj-0.0.3',
'egg=Proj',
'egg=Proj-0.0.4',
'egg=Proj',
'egg=foo-bar-1.2.4',
]
for index, test in enumerate(tests):
self.assertEqual(expected[index],
re.sub(r'egg=([^&]+).*$',
packaging.egg_fragment,
test))
def test_parse_repo_url_requirements(self):
result = packaging.parse_requirements([self.requirements])
self.assertEqual(['oslo.messaging>=1.0.0-rc',
'django-thumborize',
'django-thumborize-beta',
'django-thumborize2-beta',
'django-thumborize2-beta>=4.0.1',
'django-thumborize2-beta>=1.0.0-alpha.beta.1',
'django-thumborize2-beta>=1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'django-thumborize2-beta>=2.0.0-rc.1+build.123',
'Proj1', 'Proj2>=0.0.1', 'Proj3',
'Proj4>=0.0.2', 'Proj5', 'Proj>=0.0.3',
'Proj', 'Proj>=0.0.4', 'Proj',
'foo-bar>=1.2.4', 'pypi-proj1',
'pypi-proj2'], result)
def test_parse_repo_url_dependency_links(self):
result = packaging.parse_dependency_links([self.requirements])
self.assertEqual(
[
'git+git://git.pro-ject.org/oslo.messaging#egg=oslo.messaging-1.0.0-rc', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize-beta', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-4.0.1', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-1.0.0-alpha.beta.1', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-1.0.0-alpha-a.b-c-somethinglong+build.1-aef.1-its-okay', # noqa
'git+git://git.pro-ject.org/django-thumborize#egg=django-thumborize2-beta-2.0.0-rc.1+build.123', # noqa
'git+git://git.project.org/Proj#egg=Proj1',
'git+https://git.project.org/Proj#egg=Proj2-0.0.1',
'git+ssh://git.project.org/Proj#egg=Proj3',
'svn+svn://svn.project.org/svn/Proj#egg=Proj4-0.0.2',
'svn+http://svn.project.org/svn/Proj/trunk@2019#egg=Proj5',
'hg+http://hg.project.org/Proj@da39a3ee5e6b#egg=Proj-0.0.3',
'hg+http://hg.project.org/Proj@2019#egg=Proj',
'hg+http://hg.project.org/Proj@v1.0#egg=Proj-0.0.4',
'hg+http://hg.project.org/Proj@special_feature#egg=Proj',
'git://foo.com/zipball#egg=foo-bar-1.2.4'], result)
def get_soabi(): def get_soabi():
soabi = None soabi = None
try: try:

View file

@ -10,7 +10,10 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import mock try:
from unittest import mock
except ImportError:
import mock
from pbr import pbr_json from pbr import pbr_json
from pbr.tests import base from pbr.tests import base

View file

@ -93,8 +93,9 @@ class SkipFileWrites(base.BaseTestCase):
option_dict=self.option_dict) option_dict=self.option_dict)
self.assertEqual( self.assertEqual(
not os.path.exists(self.filename), not os.path.exists(self.filename),
(self.option_value.lower() in options.TRUE_VALUES (self.option_value.lower() in options.TRUE_VALUES or
or self.env_value is not None)) self.env_value is not None))
_changelog_content = """7780758\x00Break parser\x00 (tag: refs/tags/1_foo.1) _changelog_content = """7780758\x00Break parser\x00 (tag: refs/tags/1_foo.1)
04316fe\x00Make python\x00 (refs/heads/review/monty_taylor/27519) 04316fe\x00Make python\x00 (refs/heads/review/monty_taylor/27519)
@ -125,6 +126,7 @@ def _make_old_git_changelog_format(line):
refname = refname.replace('tag: ', '') refname = refname.replace('tag: ', '')
return '\x00'.join((sha, msg, refname)) return '\x00'.join((sha, msg, refname))
_old_git_changelog_content = '\n'.join( _old_git_changelog_content = '\n'.join(
_make_old_git_changelog_format(line) _make_old_git_changelog_format(line)
for line in _changelog_content.split('\n')) for line in _changelog_content.split('\n'))
@ -162,7 +164,7 @@ class GitLogsTest(base.BaseTestCase):
self.assertIn("------", changelog_contents) self.assertIn("------", changelog_contents)
self.assertIn("Refactor hooks file", changelog_contents) self.assertIn("Refactor hooks file", changelog_contents)
self.assertIn( self.assertIn(
"Bug fix: create\_stack() fails when waiting", r"Bug fix: create\_stack() fails when waiting",
changelog_contents) changelog_contents)
self.assertNotIn("Refactor hooks file.", changelog_contents) self.assertNotIn("Refactor hooks file.", changelog_contents)
self.assertNotIn("182feb3", changelog_contents) self.assertNotIn("182feb3", changelog_contents)
@ -176,7 +178,7 @@ class GitLogsTest(base.BaseTestCase):
self.assertNotIn("ev)il", changelog_contents) self.assertNotIn("ev)il", changelog_contents)
self.assertNotIn("e(vi)l", changelog_contents) self.assertNotIn("e(vi)l", changelog_contents)
self.assertNotIn('Merge "', changelog_contents) self.assertNotIn('Merge "', changelog_contents)
self.assertNotIn('1\_foo.1', changelog_contents) self.assertNotIn(r'1\_foo.1', changelog_contents)
def test_generate_authors(self): def test_generate_authors(self):
author_old = u"Foo Foo <email@foo.com>" author_old = u"Foo Foo <email@foo.com>"
@ -216,9 +218,9 @@ class GitLogsTest(base.BaseTestCase):
with open(os.path.join(self.temp_path, "AUTHORS"), "r") as auth_fh: with open(os.path.join(self.temp_path, "AUTHORS"), "r") as auth_fh:
authors = auth_fh.read() authors = auth_fh.read()
self.assertTrue(author_old in authors) self.assertIn(author_old, authors)
self.assertTrue(author_new in authors) self.assertIn(author_new, authors)
self.assertTrue(co_author in authors) self.assertIn(co_author, authors)
class _SphinxConfig(object): class _SphinxConfig(object):

View file

@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2015 Hewlett-Packard Development Company, L.P. (HP) # Copyright (c) 2015 Hewlett-Packard Development Company, L.P. (HP)
# #
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
@ -13,6 +14,7 @@
# under the License. # under the License.
import io import io
import tempfile
import textwrap import textwrap
import six import six
@ -23,6 +25,122 @@ from pbr.tests import base
from pbr import util from pbr import util
def config_from_ini(ini):
config = {}
ini = textwrap.dedent(six.u(ini))
if sys.version_info >= (3, 2):
parser = configparser.ConfigParser()
parser.read_file(io.StringIO(ini))
else:
parser = configparser.SafeConfigParser()
parser.readfp(io.StringIO(ini))
for section in parser.sections():
config[section] = dict(parser.items(section))
return config
class TestBasics(base.BaseTestCase):
def test_basics(self):
self.maxDiff = None
config_text = """
[metadata]
name = foo
version = 1.0
author = John Doe
author_email = jd@example.com
maintainer = Jim Burke
maintainer_email = jb@example.com
home_page = http://example.com
summary = A foobar project.
description = Hello, world. This is a long description.
download_url = http://opendev.org/x/pbr
classifier =
Development Status :: 5 - Production/Stable
Programming Language :: Python
platform =
any
license = Apache 2.0
requires_dist =
Sphinx
requests
setup_requires_dist =
docutils
python_requires = >=3.6
provides_dist =
bax
provides_extras =
bar
obsoletes_dist =
baz
[files]
packages_root = src
packages =
foo
package_data =
"" = *.txt, *.rst
foo = *.msg
namespace_packages =
hello
data_files =
bitmaps =
bm/b1.gif
bm/b2.gif
config =
cfg/data.cfg
scripts =
scripts/hello-world.py
modules =
mod1
"""
expected = {
'name': u'foo',
'version': u'1.0',
'author': u'John Doe',
'author_email': u'jd@example.com',
'maintainer': u'Jim Burke',
'maintainer_email': u'jb@example.com',
'url': u'http://example.com',
'description': u'A foobar project.',
'long_description': u'Hello, world. This is a long description.',
'download_url': u'http://opendev.org/x/pbr',
'classifiers': [
u'Development Status :: 5 - Production/Stable',
u'Programming Language :: Python',
],
'platforms': [u'any'],
'license': u'Apache 2.0',
'install_requires': [
u'Sphinx',
u'requests',
],
'setup_requires': [u'docutils'],
'python_requires': u'>=3.6',
'provides': [u'bax'],
'provides_extras': [u'bar'],
'obsoletes': [u'baz'],
'extras_require': {},
'package_dir': {'': u'src'},
'packages': [u'foo'],
'package_data': {
'': ['*.txt,', '*.rst'],
'foo': ['*.msg'],
},
'namespace_packages': [u'hello'],
'data_files': [
('bitmaps', ['bm/b1.gif', 'bm/b2.gif']),
('config', ['cfg/data.cfg']),
],
'scripts': [u'scripts/hello-world.py'],
'py_modules': [u'mod1'],
}
config = config_from_ini(config_text)
actual = util.setup_cfg_to_setup_kwargs(config)
self.assertDictEqual(expected, actual)
class TestExtrasRequireParsingScenarios(base.BaseTestCase): class TestExtrasRequireParsingScenarios(base.BaseTestCase):
scenarios = [ scenarios = [
@ -64,20 +182,8 @@ class TestExtrasRequireParsingScenarios(base.BaseTestCase):
{} {}
})] })]
def config_from_ini(self, ini):
config = {}
if sys.version_info >= (3, 2):
parser = configparser.ConfigParser()
else:
parser = configparser.SafeConfigParser()
ini = textwrap.dedent(six.u(ini))
parser.readfp(io.StringIO(ini))
for section in parser.sections():
config[section] = dict(parser.items(section))
return config
def test_extras_parsing(self): def test_extras_parsing(self):
config = self.config_from_ini(self.config_text) config = config_from_ini(self.config_text)
kwargs = util.setup_cfg_to_setup_kwargs(config) kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(self.expected_extra_requires, self.assertEqual(self.expected_extra_requires,
@ -89,3 +195,127 @@ class TestInvalidMarkers(base.BaseTestCase):
def test_invalid_marker_raises_error(self): def test_invalid_marker_raises_error(self):
config = {'extras': {'test': "foo :bad_marker>'1.0'"}} config = {'extras': {'test': "foo :bad_marker>'1.0'"}}
self.assertRaises(SyntaxError, util.setup_cfg_to_setup_kwargs, config) self.assertRaises(SyntaxError, util.setup_cfg_to_setup_kwargs, config)
class TestMapFieldsParsingScenarios(base.BaseTestCase):
scenarios = [
('simple_project_urls', {
'config_text': """
[metadata]
project_urls =
Bug Tracker = https://bugs.launchpad.net/pbr/
Documentation = https://docs.openstack.org/pbr/
Source Code = https://opendev.org/openstack/pbr
""", # noqa: E501
'expected_project_urls': {
'Bug Tracker': 'https://bugs.launchpad.net/pbr/',
'Documentation': 'https://docs.openstack.org/pbr/',
'Source Code': 'https://opendev.org/openstack/pbr',
},
}),
('query_parameters', {
'config_text': """
[metadata]
project_urls =
Bug Tracker = https://bugs.launchpad.net/pbr/?query=true
Documentation = https://docs.openstack.org/pbr/?foo=bar
Source Code = https://git.openstack.org/cgit/openstack-dev/pbr/commit/?id=hash
""", # noqa: E501
'expected_project_urls': {
'Bug Tracker': 'https://bugs.launchpad.net/pbr/?query=true',
'Documentation': 'https://docs.openstack.org/pbr/?foo=bar',
'Source Code': 'https://git.openstack.org/cgit/openstack-dev/pbr/commit/?id=hash', # noqa: E501
},
}),
]
def test_project_url_parsing(self):
config = config_from_ini(self.config_text)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(self.expected_project_urls, kwargs['project_urls'])
class TestKeywordsParsingScenarios(base.BaseTestCase):
scenarios = [
('keywords_list', {
'config_text': """
[metadata]
keywords =
one
two
three
""", # noqa: E501
'expected_keywords': ['one', 'two', 'three'],
},
),
('inline_keywords', {
'config_text': """
[metadata]
keywords = one, two, three
""", # noqa: E501
'expected_keywords': ['one, two, three'],
}),
]
def test_keywords_parsing(self):
config = config_from_ini(self.config_text)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(self.expected_keywords, kwargs['keywords'])
class TestProvidesExtras(base.BaseTestCase):
def test_provides_extras(self):
ini = """
[metadata]
provides_extras = foo
bar
"""
config = config_from_ini(ini)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(['foo', 'bar'], kwargs['provides_extras'])
class TestDataFilesParsing(base.BaseTestCase):
scenarios = [
('data_files', {
'config_text': """
[files]
data_files =
'i like spaces/' =
'dir with space/file with spc 2'
'dir with space/file with spc 1'
""",
'data_files': [
('i like spaces/', ['dir with space/file with spc 2',
'dir with space/file with spc 1'])
]
})]
def test_handling_of_whitespace_in_data_files(self):
config = config_from_ini(self.config_text)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(self.data_files, kwargs['data_files'])
class TestUTF8DescriptionFile(base.BaseTestCase):
def test_utf8_description_file(self):
_, path = tempfile.mkstemp()
ini_template = """
[metadata]
description_file = %s
"""
# Two \n's because pbr strips the file content and adds \n\n
# This way we can use it directly as the assert comparison
unicode_description = u'UTF8 description: é"…-ʃŋ\'\n\n'
ini = ini_template % path
with io.open(path, 'w', encoding='utf8') as f:
f.write(unicode_description)
config = config_from_ini(ini)
kwargs = util.setup_cfg_to_setup_kwargs(config)
self.assertEqual(unicode_description, kwargs['long_description'])

View file

@ -77,8 +77,8 @@ class TestWsgiScripts(base.BaseTestCase):
def _test_wsgi(self, cmd_name, output, extra_args=None): def _test_wsgi(self, cmd_name, output, extra_args=None):
cmd = os.path.join(self.temp_dir, 'bin', cmd_name) cmd = os.path.join(self.temp_dir, 'bin', cmd_name)
print("Running %s -p 0" % cmd) print("Running %s -p 0 -b 127.0.0.1" % cmd)
popen_cmd = [cmd, '-p', '0'] popen_cmd = [cmd, '-p', '0', '-b', '127.0.0.1']
if extra_args: if extra_args:
popen_cmd.extend(extra_args) popen_cmd.extend(extra_args)
@ -98,7 +98,7 @@ class TestWsgiScripts(base.BaseTestCase):
stdoutdata = p.stdout.readline() # Available at ... stdoutdata = p.stdout.readline() # Available at ...
print(stdoutdata) print(stdoutdata)
m = re.search(b'(http://[^:]+:\d+)/', stdoutdata) m = re.search(br'(http://[^:]+:\d+)/', stdoutdata)
self.assertIsNotNone(m, "Regex failed to match on %s" % stdoutdata) self.assertIsNotNone(m, "Regex failed to match on %s" % stdoutdata)
stdoutdata = p.stdout.readline() # DANGER! ... stdoutdata = p.stdout.readline() # DANGER! ...

View file

@ -12,17 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import sys
sys.path.insert(0, os.path.abspath('../..'))
# -- General configuration ---------------------------------------------------- # -- General configuration ----------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be # Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [ extensions = [
'sphinx.ext.autodoc', 'sphinx.ext.autodoc',
#'sphinx.ext.intersphinx',
] ]
# autodoc generation is a bit aggressive and a nuisance when doing heavy # autodoc generation is a bit aggressive and a nuisance when doing heavy
@ -49,17 +45,9 @@ add_module_names = True
# The name of the Pygments (syntax highlighting) style to use. # The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx' pygments_style = 'sphinx'
# -- Options for HTML output -------------------------------------------------- # -- Options for HTML output --------------------------------------------------
# The theme to use for HTML and HTML Help pages. Major themes that come with
# Sphinx are currently 'default' and 'sphinxdoc'.
# html_theme_path = ["."]
# html_theme = '_theme'
# html_static_path = ['static']
# Output file base name for HTML help builder.
htmlhelp_basename = '%sdoc' % project
# Grouping the document tree into LaTeX files. List of tuples # Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass # (source start file, target name, title, author, documentclass
# [howto/manual]). # [howto/manual]).
@ -69,6 +57,3 @@ latex_documents = [
u'%s Documentation' % project, u'%s Documentation' % project,
u'OpenStack Foundation', 'manual'), u'OpenStack Foundation', 'manual'),
] ]
# Example configuration for intersphinx: refer to the Python standard library.
#intersphinx_mapping = {'http://docs.python.org/': None}

View file

@ -53,9 +53,9 @@ except ImportError:
@contextlib.contextmanager @contextlib.contextmanager
def open_config(filename): def open_config(filename):
if sys.version_info >= (3, 2): if sys.version_info >= (3, 2):
cfg = configparser.ConfigParser() cfg = configparser.ConfigParser()
else: else:
cfg = configparser.SafeConfigParser() cfg = configparser.SafeConfigParser()
cfg.read(filename) cfg.read(filename)
yield cfg yield cfg
with open(filename, 'w') as fp: with open(filename, 'w') as fp:

View file

@ -62,8 +62,10 @@ except ImportError:
import logging # noqa import logging # noqa
from collections import defaultdict from collections import defaultdict
import io
import os import os
import re import re
import shlex
import sys import sys
import traceback import traceback
@ -86,50 +88,52 @@ import pbr.hooks
# predicates in () # predicates in ()
_VERSION_SPEC_RE = re.compile(r'\s*(.*?)\s*\((.*)\)\s*$') _VERSION_SPEC_RE = re.compile(r'\s*(.*?)\s*\((.*)\)\s*$')
# Mappings from setup.cfg options, in (section, option) form, to setup()
# Mappings from setup() keyword arguments to setup.cfg options; # keyword arguments
# The values are (section, option) tuples, or simply (section,) tuples if CFG_TO_PY_SETUP_ARGS = (
# the option has the same name as the setup() argument (('metadata', 'name'), 'name'),
D1_D2_SETUP_ARGS = { (('metadata', 'version'), 'version'),
"name": ("metadata",), (('metadata', 'author'), 'author'),
"version": ("metadata",), (('metadata', 'author_email'), 'author_email'),
"author": ("metadata",), (('metadata', 'maintainer'), 'maintainer'),
"author_email": ("metadata",), (('metadata', 'maintainer_email'), 'maintainer_email'),
"maintainer": ("metadata",), (('metadata', 'home_page'), 'url'),
"maintainer_email": ("metadata",), (('metadata', 'project_urls'), 'project_urls'),
"url": ("metadata", "home_page"), (('metadata', 'summary'), 'description'),
"project_urls": ("metadata",), (('metadata', 'keywords'), 'keywords'),
"description": ("metadata", "summary"), (('metadata', 'description'), 'long_description'),
"keywords": ("metadata",), (
"long_description": ("metadata", "description"), ('metadata', 'description_content_type'),
"long_description_content_type": ("metadata", "description_content_type"), 'long_description_content_type',
"download_url": ("metadata",), ),
"classifiers": ("metadata", "classifier"), (('metadata', 'download_url'), 'download_url'),
"platforms": ("metadata", "platform"), # ** (('metadata', 'classifier'), 'classifiers'),
"license": ("metadata",), (('metadata', 'platform'), 'platforms'), # **
(('metadata', 'license'), 'license'),
# Use setuptools install_requires, not # Use setuptools install_requires, not
# broken distutils requires # broken distutils requires
"install_requires": ("metadata", "requires_dist"), (('metadata', 'requires_dist'), 'install_requires'),
"setup_requires": ("metadata", "setup_requires_dist"), (('metadata', 'setup_requires_dist'), 'setup_requires'),
"python_requires": ("metadata",), (('metadata', 'python_requires'), 'python_requires'),
"provides": ("metadata", "provides_dist"), # ** (('metadata', 'requires_python'), 'python_requires'),
"obsoletes": ("metadata", "obsoletes_dist"), # ** (('metadata', 'provides_dist'), 'provides'), # **
"package_dir": ("files", 'packages_root'), (('metadata', 'provides_extras'), 'provides_extras'),
"packages": ("files",), (('metadata', 'obsoletes_dist'), 'obsoletes'), # **
"package_data": ("files",), (('files', 'packages_root'), 'package_dir'),
"namespace_packages": ("files",), (('files', 'packages'), 'packages'),
"data_files": ("files",), (('files', 'package_data'), 'package_data'),
"scripts": ("files",), (('files', 'namespace_packages'), 'namespace_packages'),
"py_modules": ("files", "modules"), # ** (('files', 'data_files'), 'data_files'),
"cmdclass": ("global", "commands"), (('files', 'scripts'), 'scripts'),
(('files', 'modules'), 'py_modules'), # **
(('global', 'commands'), 'cmdclass'),
# Not supported in distutils2, but provided for # Not supported in distutils2, but provided for
# backwards compatibility with setuptools # backwards compatibility with setuptools
"use_2to3": ("backwards_compat", "use_2to3"), (('backwards_compat', 'zip_safe'), 'zip_safe'),
"zip_safe": ("backwards_compat", "zip_safe"), (('backwards_compat', 'tests_require'), 'tests_require'),
"tests_require": ("backwards_compat", "tests_require"), (('backwards_compat', 'dependency_links'), 'dependency_links'),
"dependency_links": ("backwards_compat",), (('backwards_compat', 'include_package_data'), 'include_package_data'),
"include_package_data": ("backwards_compat",), )
}
# setup() arguments that can have multiple values in setup.cfg # setup() arguments that can have multiple values in setup.cfg
MULTI_FIELDS = ("classifiers", MULTI_FIELDS = ("classifiers",
@ -146,16 +150,27 @@ MULTI_FIELDS = ("classifiers",
"dependency_links", "dependency_links",
"setup_requires", "setup_requires",
"tests_require", "tests_require",
"cmdclass") "keywords",
"cmdclass",
"provides_extras")
# setup() arguments that can have mapping values in setup.cfg # setup() arguments that can have mapping values in setup.cfg
MAP_FIELDS = ("project_urls",) MAP_FIELDS = ("project_urls",)
# setup() arguments that contain boolean values # setup() arguments that contain boolean values
BOOL_FIELDS = ("use_2to3", "zip_safe", "include_package_data") BOOL_FIELDS = ("zip_safe", "include_package_data")
CSV_FIELDS = ()
CSV_FIELDS = ("keywords",) def shlex_split(path):
if os.name == 'nt':
# shlex cannot handle paths that contain backslashes, treating those
# as escape characters.
path = path.replace("\\", "/")
return [x.replace("/", "\\") for x in shlex.split(path)]
return shlex.split(path)
def resolve_name(name): def resolve_name(name):
@ -205,10 +220,11 @@ def cfg_to_args(path='setup.cfg', script_args=()):
""" """
# The method source code really starts here. # The method source code really starts here.
if sys.version_info >= (3, 2): if sys.version_info >= (3, 0):
parser = configparser.ConfigParser() parser = configparser.ConfigParser()
else: else:
parser = configparser.SafeConfigParser() parser = configparser.SafeConfigParser()
if not os.path.exists(path): if not os.path.exists(path):
raise errors.DistutilsFileError("file '%s' does not exist" % raise errors.DistutilsFileError("file '%s' does not exist" %
os.path.abspath(path)) os.path.abspath(path))
@ -297,34 +313,25 @@ def setup_cfg_to_setup_kwargs(config, script_args=()):
# parse env_markers. # parse env_markers.
all_requirements = {} all_requirements = {}
for arg in D1_D2_SETUP_ARGS: for alias, arg in CFG_TO_PY_SETUP_ARGS:
if len(D1_D2_SETUP_ARGS[arg]) == 2: section, option = alias
# The distutils field name is different than distutils2's.
section, option = D1_D2_SETUP_ARGS[arg]
elif len(D1_D2_SETUP_ARGS[arg]) == 1:
# The distutils field name is the same thant distutils2's.
section = D1_D2_SETUP_ARGS[arg][0]
option = arg
in_cfg_value = has_get_option(config, section, option) in_cfg_value = has_get_option(config, section, option)
if not in_cfg_value and arg == "long_description":
in_cfg_value = has_get_option(config, section, "description_file")
if in_cfg_value:
in_cfg_value = split_multiline(in_cfg_value)
value = ''
for filename in in_cfg_value:
description_file = io.open(filename, encoding='utf-8')
try:
value += description_file.read().strip() + '\n\n'
finally:
description_file.close()
in_cfg_value = value
if not in_cfg_value: if not in_cfg_value:
# There is no such option in the setup.cfg continue
if arg == "long_description":
in_cfg_value = has_get_option(config, section,
"description_file")
if in_cfg_value:
in_cfg_value = split_multiline(in_cfg_value)
value = ''
for filename in in_cfg_value:
description_file = open(filename)
try:
value += description_file.read().strip() + '\n\n'
finally:
description_file.close()
in_cfg_value = value
else:
continue
if arg in CSV_FIELDS: if arg in CSV_FIELDS:
in_cfg_value = split_csv(in_cfg_value) in_cfg_value = split_csv(in_cfg_value)
@ -333,7 +340,7 @@ def setup_cfg_to_setup_kwargs(config, script_args=()):
elif arg in MAP_FIELDS: elif arg in MAP_FIELDS:
in_cfg_map = {} in_cfg_map = {}
for i in split_multiline(in_cfg_value): for i in split_multiline(in_cfg_value):
k, v = i.split('=') k, v = i.split('=', 1)
in_cfg_map[k.strip()] = v.strip() in_cfg_map[k.strip()] = v.strip()
in_cfg_value = in_cfg_map in_cfg_value = in_cfg_map
elif arg in BOOL_FIELDS: elif arg in BOOL_FIELDS:
@ -370,26 +377,27 @@ def setup_cfg_to_setup_kwargs(config, script_args=()):
for line in in_cfg_value: for line in in_cfg_value:
if '=' in line: if '=' in line:
key, value = line.split('=', 1) key, value = line.split('=', 1)
key, value = (key.strip(), value.strip()) key_unquoted = shlex_split(key.strip())[0]
key, value = (key_unquoted, value.strip())
if key in data_files: if key in data_files:
# Multiple duplicates of the same package name; # Multiple duplicates of the same package name;
# this is for backwards compatibility of the old # this is for backwards compatibility of the old
# format prior to d2to1 0.2.6. # format prior to d2to1 0.2.6.
prev = data_files[key] prev = data_files[key]
prev.extend(value.split()) prev.extend(shlex_split(value))
else: else:
prev = data_files[key.strip()] = value.split() prev = data_files[key.strip()] = shlex_split(value)
elif firstline: elif firstline:
raise errors.DistutilsOptionError( raise errors.DistutilsOptionError(
'malformed package_data first line %r (misses ' 'malformed package_data first line %r (misses '
'"=")' % line) '"=")' % line)
else: else:
prev.extend(line.strip().split()) prev.extend(shlex_split(line.strip()))
firstline = False firstline = False
if arg == 'data_files': if arg == 'data_files':
# the data_files value is a pointlessly different structure # the data_files value is a pointlessly different structure
# from the package_data value # from the package_data value
data_files = data_files.items() data_files = sorted(data_files.items())
in_cfg_value = data_files in_cfg_value = data_files
elif arg == 'cmdclass': elif arg == 'cmdclass':
cmdclass = {} cmdclass = {}
@ -532,7 +540,7 @@ def get_extension_modules(config):
else: else:
# Backwards compatibility for old syntax; don't use this though # Backwards compatibility for old syntax; don't use this though
labels = section.split('=', 1) labels = section.split('=', 1)
labels = [l.strip() for l in labels] labels = [label.strip() for label in labels]
if (len(labels) == 2) and (labels[0] == 'extension'): if (len(labels) == 2) and (labels[0] == 'extension'):
ext_args = {} ext_args = {}
for field in EXTENSION_FIELDS: for field in EXTENSION_FIELDS:

View file

@ -15,13 +15,24 @@
# under the License. # under the License.
""" """
Utilities for consuming the version from pkg_resources. Utilities for consuming the version from importlib-metadata.
""" """
import itertools import itertools
import operator import operator
import sys import sys
# TODO(stephenfin): Remove this once we drop support for Python < 3.8
if sys.version_info >= (3, 8):
from importlib import metadata as importlib_metadata
use_importlib = True
else:
try:
import importlib_metadata
use_importlib = True
except ImportError:
use_importlib = False
def _is_int(string): def _is_int(string):
try: try:
@ -323,8 +334,8 @@ class SemanticVersion(object):
version number of the component to preserve sorting. (Used for version number of the component to preserve sorting. (Used for
rpm support) rpm support)
""" """
if ((self._prerelease_type or self._dev_count) if ((self._prerelease_type or self._dev_count) and
and pre_separator is None): pre_separator is None):
segments = [self.decrement().brief_string()] segments = [self.decrement().brief_string()]
pre_separator = "." pre_separator = "."
else: else:
@ -431,12 +442,15 @@ class VersionInfo(object):
"""Obtain a version from pkg_resources or setup-time logic if missing. """Obtain a version from pkg_resources or setup-time logic if missing.
This will try to get the version of the package from the pkg_resources This will try to get the version of the package from the pkg_resources
This will try to get the version of the package from the
record associated with the package, and if there is no such record record associated with the package, and if there is no such record
importlib_metadata record associated with the package, and if there
falls back to the logic sdist would use. falls back to the logic sdist would use.
is no such record falls back to the logic sdist would use.
""" """
# Lazy import because pkg_resources is costly to import so defer until
# we absolutely need it.
import pkg_resources import pkg_resources
try: try:
requirement = pkg_resources.Requirement.parse(self.package) requirement = pkg_resources.Requirement.parse(self.package)
provider = pkg_resources.get_provider(requirement) provider = pkg_resources.get_provider(requirement)
@ -447,6 +461,25 @@ class VersionInfo(object):
# installed into anything. Revert to setup-time logic. # installed into anything. Revert to setup-time logic.
from pbr import packaging from pbr import packaging
result_string = packaging.get_version(self.package) result_string = packaging.get_version(self.package)
return SemanticVersion.from_pip_string(result_string)
def _get_version_from_importlib_metadata(self):
"""Obtain a version from importlib or setup-time logic if missing.
This will try to get the version of the package from the
importlib_metadata record associated with the package, and if there
is no such record falls back to the logic sdist would use.
"""
try:
distribution = importlib_metadata.distribution(self.package)
result_string = distribution.version
except importlib_metadata.PackageNotFoundError:
# The most likely cause for this is running tests in a tree
# produced from a tarball where the package itself has not been
# installed into anything. Revert to setup-time logic.
from pbr import packaging
result_string = packaging.get_version(self.package)
return SemanticVersion.from_pip_string(result_string) return SemanticVersion.from_pip_string(result_string)
def release_string(self): def release_string(self):
@ -459,7 +492,12 @@ class VersionInfo(object):
def semantic_version(self): def semantic_version(self):
"""Return the SemanticVersion object for this version.""" """Return the SemanticVersion object for this version."""
if self._semantic is None: if self._semantic is None:
self._semantic = self._get_version_from_pkg_resources() # TODO(damami): simplify this once Python 3.8 is the oldest
# we support
if use_importlib:
self._semantic = self._get_version_from_importlib_metadata()
else:
self._semantic = self._get_version_from_pkg_resources()
return self._semantic return self._semantic
def version_string(self): def version_string(self):

View file

@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# pylint: disable-all # pylint: disable-all
from __future__ import print_function
import os import os
import re import re
@ -132,9 +133,14 @@ class SubRipShifter(object):
def run(self, args): def run(self, args):
self.arguments = self.build_parser().parse_args(args) self.arguments = self.build_parser().parse_args(args)
if self.arguments.in_place:
self.create_backup() if os.path.isfile(self.arguments.file):
self.arguments.action() if self.arguments.in_place:
self.create_backup()
self.arguments.action()
else:
print('No such file', self.arguments.file)
def parse_time(self, time_string): def parse_time(self, time_string):
negative = time_string.startswith('-') negative = time_string.startswith('-')

View file

@ -290,7 +290,7 @@ class SubRipFile(UserList, object):
@classmethod @classmethod
def _open_unicode_file(cls, path, claimed_encoding=None): def _open_unicode_file(cls, path, claimed_encoding=None):
encoding = claimed_encoding or cls._detect_encoding(path) encoding = claimed_encoding or cls._detect_encoding(path)
source_file = codecs.open(path, 'rU', encoding=encoding) source_file = codecs.open(path, 'r', encoding=encoding)
# get rid of BOM if any # get rid of BOM if any
possible_bom = CODECS_BOMS.get(encoding, None) possible_bom = CODECS_BOMS.get(encoding, None)

View file

@ -16,14 +16,14 @@ from pytz.exceptions import AmbiguousTimeError
from pytz.exceptions import InvalidTimeError from pytz.exceptions import InvalidTimeError
from pytz.exceptions import NonExistentTimeError from pytz.exceptions import NonExistentTimeError
from pytz.exceptions import UnknownTimeZoneError from pytz.exceptions import UnknownTimeZoneError
from pytz.lazy import LazyDict, LazyList, LazySet from pytz.lazy import LazyDict, LazyList, LazySet # noqa
from pytz.tzinfo import unpickler, BaseTzInfo from pytz.tzinfo import unpickler, BaseTzInfo
from pytz.tzfile import build_tzinfo from pytz.tzfile import build_tzinfo
# The IANA (nee Olson) database is updated several times a year. # The IANA (nee Olson) database is updated several times a year.
OLSON_VERSION = '2018g' OLSON_VERSION = '2022f'
VERSION = '2018.7' # pip compatible version number. VERSION = '2022.6' # pip compatible version number.
__version__ = VERSION __version__ = VERSION
OLSEN_VERSION = OLSON_VERSION # Old releases had this misspelling OLSEN_VERSION = OLSON_VERSION # Old releases had this misspelling
@ -34,7 +34,7 @@ __all__ = [
'NonExistentTimeError', 'UnknownTimeZoneError', 'NonExistentTimeError', 'UnknownTimeZoneError',
'all_timezones', 'all_timezones_set', 'all_timezones', 'all_timezones_set',
'common_timezones', 'common_timezones_set', 'common_timezones', 'common_timezones_set',
'BaseTzInfo', 'BaseTzInfo', 'FixedOffset',
] ]
@ -86,7 +86,7 @@ def open_resource(name):
""" """
name_parts = name.lstrip('/').split('/') name_parts = name.lstrip('/').split('/')
for part in name_parts: for part in name_parts:
if part == os.path.pardir or os.path.sep in part: if part == os.path.pardir or os.sep in part:
raise ValueError('Bad path segment: %r' % part) raise ValueError('Bad path segment: %r' % part)
zoneinfo_dir = os.environ.get('PYTZ_TZDATADIR', None) zoneinfo_dir = os.environ.get('PYTZ_TZDATADIR', None)
if zoneinfo_dir is not None: if zoneinfo_dir is not None:
@ -111,6 +111,13 @@ def open_resource(name):
def resource_exists(name): def resource_exists(name):
"""Return true if the given resource exists""" """Return true if the given resource exists"""
try: try:
if os.environ.get('PYTZ_SKIPEXISTSCHECK', ''):
# In "standard" distributions, we can assume that
# all the listed timezones are present. As an
# import-speed optimization, you can set the
# PYTZ_SKIPEXISTSCHECK flag to skip checking
# for the presence of the resource file on disk.
return True
open_resource(name).close() open_resource(name).close()
return True return True
except IOError: except IOError:
@ -157,6 +164,9 @@ def timezone(zone):
Unknown Unknown
''' '''
if zone is None:
raise UnknownTimeZoneError(None)
if zone.upper() == 'UTC': if zone.upper() == 'UTC':
return utc return utc
@ -166,9 +176,9 @@ def timezone(zone):
# All valid timezones are ASCII # All valid timezones are ASCII
raise UnknownTimeZoneError(zone) raise UnknownTimeZoneError(zone)
zone = _unmunge_zone(zone) zone = _case_insensitive_zone_lookup(_unmunge_zone(zone))
if zone not in _tzinfo_cache: if zone not in _tzinfo_cache:
if zone in all_timezones_set: if zone in all_timezones_set: # noqa
fp = open_resource(zone) fp = open_resource(zone)
try: try:
_tzinfo_cache[zone] = build_tzinfo(zone, fp) _tzinfo_cache[zone] = build_tzinfo(zone, fp)
@ -185,6 +195,17 @@ def _unmunge_zone(zone):
return zone.replace('_plus_', '+').replace('_minus_', '-') return zone.replace('_plus_', '+').replace('_minus_', '-')
_all_timezones_lower_to_standard = None
def _case_insensitive_zone_lookup(zone):
"""case-insensitively matching timezone, else return zone unchanged"""
global _all_timezones_lower_to_standard
if _all_timezones_lower_to_standard is None:
_all_timezones_lower_to_standard = dict((tz.lower(), tz) for tz in all_timezones) # noqa
return _all_timezones_lower_to_standard.get(zone.lower()) or zone # noqa
ZERO = datetime.timedelta(0) ZERO = datetime.timedelta(0)
HOUR = datetime.timedelta(hours=1) HOUR = datetime.timedelta(hours=1)
@ -249,8 +270,8 @@ def _UTC():
module global. module global.
These examples belong in the UTC class above, but it is obscured; or in These examples belong in the UTC class above, but it is obscured; or in
the README.txt, but we are not depending on Python 2.4 so integrating the README.rst, but we are not depending on Python 2.4 so integrating
the README.txt examples with the unit tests is not trivial. the README.rst examples with the unit tests is not trivial.
>>> import datetime, pickle >>> import datetime, pickle
>>> dt = datetime.datetime(2005, 3, 1, 14, 13, 21, tzinfo=utc) >>> dt = datetime.datetime(2005, 3, 1, 14, 13, 21, tzinfo=utc)
@ -272,6 +293,8 @@ def _UTC():
False False
""" """
return utc return utc
_UTC.__safe_for_unpickling__ = True _UTC.__safe_for_unpickling__ = True
@ -282,6 +305,8 @@ def _p(*args):
by shortening the path. by shortening the path.
""" """
return unpickler(*args) return unpickler(*args)
_p.__safe_for_unpickling__ = True _p.__safe_for_unpickling__ = True
@ -330,7 +355,7 @@ class _CountryTimezoneDict(LazyDict):
if line.startswith('#'): if line.startswith('#'):
continue continue
code, coordinates, zone = line.split(None, 4)[:3] code, coordinates, zone = line.split(None, 4)[:3]
if zone not in all_timezones_set: if zone not in all_timezones_set: # noqa
continue continue
try: try:
data[code].append(zone) data[code].append(zone)
@ -340,6 +365,7 @@ class _CountryTimezoneDict(LazyDict):
finally: finally:
zone_tab.close() zone_tab.close()
country_timezones = _CountryTimezoneDict() country_timezones = _CountryTimezoneDict()
@ -363,6 +389,7 @@ class _CountryNameDict(LazyDict):
finally: finally:
zone_tab.close() zone_tab.close()
country_names = _CountryNameDict() country_names = _CountryNameDict()
@ -474,6 +501,7 @@ def FixedOffset(offset, _tzinfos={}):
return info return info
FixedOffset.__safe_for_unpickling__ = True FixedOffset.__safe_for_unpickling__ = True
@ -483,6 +511,7 @@ def _test():
import pytz import pytz
return doctest.testmod(pytz) return doctest.testmod(pytz)
if __name__ == '__main__': if __name__ == '__main__':
_test() _test()
all_timezones = \ all_timezones = \
@ -661,6 +690,7 @@ all_timezones = \
'America/North_Dakota/Beulah', 'America/North_Dakota/Beulah',
'America/North_Dakota/Center', 'America/North_Dakota/Center',
'America/North_Dakota/New_Salem', 'America/North_Dakota/New_Salem',
'America/Nuuk',
'America/Ojinaga', 'America/Ojinaga',
'America/Panama', 'America/Panama',
'America/Pangnirtung', 'America/Pangnirtung',
@ -787,6 +817,7 @@ all_timezones = \
'Asia/Pontianak', 'Asia/Pontianak',
'Asia/Pyongyang', 'Asia/Pyongyang',
'Asia/Qatar', 'Asia/Qatar',
'Asia/Qostanay',
'Asia/Qyzylorda', 'Asia/Qyzylorda',
'Asia/Rangoon', 'Asia/Rangoon',
'Asia/Riyadh', 'Asia/Riyadh',
@ -933,6 +964,7 @@ all_timezones = \
'Europe/Kaliningrad', 'Europe/Kaliningrad',
'Europe/Kiev', 'Europe/Kiev',
'Europe/Kirov', 'Europe/Kirov',
'Europe/Kyiv',
'Europe/Lisbon', 'Europe/Lisbon',
'Europe/Ljubljana', 'Europe/Ljubljana',
'Europe/London', 'Europe/London',
@ -1027,6 +1059,7 @@ all_timezones = \
'Pacific/Guam', 'Pacific/Guam',
'Pacific/Honolulu', 'Pacific/Honolulu',
'Pacific/Johnston', 'Pacific/Johnston',
'Pacific/Kanton',
'Pacific/Kiritimati', 'Pacific/Kiritimati',
'Pacific/Kosrae', 'Pacific/Kosrae',
'Pacific/Kwajalein', 'Pacific/Kwajalein',
@ -1187,7 +1220,6 @@ common_timezones = \
'America/Fort_Nelson', 'America/Fort_Nelson',
'America/Fortaleza', 'America/Fortaleza',
'America/Glace_Bay', 'America/Glace_Bay',
'America/Godthab',
'America/Goose_Bay', 'America/Goose_Bay',
'America/Grand_Turk', 'America/Grand_Turk',
'America/Grenada', 'America/Grenada',
@ -1235,12 +1267,12 @@ common_timezones = \
'America/Montserrat', 'America/Montserrat',
'America/Nassau', 'America/Nassau',
'America/New_York', 'America/New_York',
'America/Nipigon',
'America/Nome', 'America/Nome',
'America/Noronha', 'America/Noronha',
'America/North_Dakota/Beulah', 'America/North_Dakota/Beulah',
'America/North_Dakota/Center', 'America/North_Dakota/Center',
'America/North_Dakota/New_Salem', 'America/North_Dakota/New_Salem',
'America/Nuuk',
'America/Ojinaga', 'America/Ojinaga',
'America/Panama', 'America/Panama',
'America/Pangnirtung', 'America/Pangnirtung',
@ -1251,7 +1283,6 @@ common_timezones = \
'America/Porto_Velho', 'America/Porto_Velho',
'America/Puerto_Rico', 'America/Puerto_Rico',
'America/Punta_Arenas', 'America/Punta_Arenas',
'America/Rainy_River',
'America/Rankin_Inlet', 'America/Rankin_Inlet',
'America/Recife', 'America/Recife',
'America/Regina', 'America/Regina',
@ -1272,7 +1303,6 @@ common_timezones = \
'America/Swift_Current', 'America/Swift_Current',
'America/Tegucigalpa', 'America/Tegucigalpa',
'America/Thule', 'America/Thule',
'America/Thunder_Bay',
'America/Tijuana', 'America/Tijuana',
'America/Toronto', 'America/Toronto',
'America/Tortola', 'America/Tortola',
@ -1351,6 +1381,7 @@ common_timezones = \
'Asia/Pontianak', 'Asia/Pontianak',
'Asia/Pyongyang', 'Asia/Pyongyang',
'Asia/Qatar', 'Asia/Qatar',
'Asia/Qostanay',
'Asia/Qyzylorda', 'Asia/Qyzylorda',
'Asia/Riyadh', 'Asia/Riyadh',
'Asia/Sakhalin', 'Asia/Sakhalin',
@ -1388,7 +1419,6 @@ common_timezones = \
'Australia/Adelaide', 'Australia/Adelaide',
'Australia/Brisbane', 'Australia/Brisbane',
'Australia/Broken_Hill', 'Australia/Broken_Hill',
'Australia/Currie',
'Australia/Darwin', 'Australia/Darwin',
'Australia/Eucla', 'Australia/Eucla',
'Australia/Hobart', 'Australia/Hobart',
@ -1424,8 +1454,8 @@ common_timezones = \
'Europe/Istanbul', 'Europe/Istanbul',
'Europe/Jersey', 'Europe/Jersey',
'Europe/Kaliningrad', 'Europe/Kaliningrad',
'Europe/Kiev',
'Europe/Kirov', 'Europe/Kirov',
'Europe/Kyiv',
'Europe/Lisbon', 'Europe/Lisbon',
'Europe/Ljubljana', 'Europe/Ljubljana',
'Europe/London', 'Europe/London',
@ -1453,7 +1483,6 @@ common_timezones = \
'Europe/Tallinn', 'Europe/Tallinn',
'Europe/Tirane', 'Europe/Tirane',
'Europe/Ulyanovsk', 'Europe/Ulyanovsk',
'Europe/Uzhgorod',
'Europe/Vaduz', 'Europe/Vaduz',
'Europe/Vatican', 'Europe/Vatican',
'Europe/Vienna', 'Europe/Vienna',
@ -1461,7 +1490,6 @@ common_timezones = \
'Europe/Volgograd', 'Europe/Volgograd',
'Europe/Warsaw', 'Europe/Warsaw',
'Europe/Zagreb', 'Europe/Zagreb',
'Europe/Zaporozhye',
'Europe/Zurich', 'Europe/Zurich',
'GMT', 'GMT',
'Indian/Antananarivo', 'Indian/Antananarivo',
@ -1482,7 +1510,6 @@ common_timezones = \
'Pacific/Chuuk', 'Pacific/Chuuk',
'Pacific/Easter', 'Pacific/Easter',
'Pacific/Efate', 'Pacific/Efate',
'Pacific/Enderbury',
'Pacific/Fakaofo', 'Pacific/Fakaofo',
'Pacific/Fiji', 'Pacific/Fiji',
'Pacific/Funafuti', 'Pacific/Funafuti',
@ -1491,6 +1518,7 @@ common_timezones = \
'Pacific/Guadalcanal', 'Pacific/Guadalcanal',
'Pacific/Guam', 'Pacific/Guam',
'Pacific/Honolulu', 'Pacific/Honolulu',
'Pacific/Kanton',
'Pacific/Kiritimati', 'Pacific/Kiritimati',
'Pacific/Kosrae', 'Pacific/Kosrae',
'Pacific/Kwajalein', 'Pacific/Kwajalein',

View file

@ -8,7 +8,11 @@ __all__ = [
] ]
class UnknownTimeZoneError(KeyError): class Error(Exception):
'''Base class for all exceptions raised by the pytz library'''
class UnknownTimeZoneError(KeyError, Error):
'''Exception raised when pytz is passed an unknown timezone. '''Exception raised when pytz is passed an unknown timezone.
>>> isinstance(UnknownTimeZoneError(), LookupError) >>> isinstance(UnknownTimeZoneError(), LookupError)
@ -20,11 +24,18 @@ class UnknownTimeZoneError(KeyError):
>>> isinstance(UnknownTimeZoneError(), KeyError) >>> isinstance(UnknownTimeZoneError(), KeyError)
True True
And also a subclass of pytz.exceptions.Error, as are other pytz
exceptions.
>>> isinstance(UnknownTimeZoneError(), Error)
True
''' '''
pass pass
class InvalidTimeError(Exception): class InvalidTimeError(Error):
'''Base class for invalid time exceptions.''' '''Base class for invalid time exceptions.'''

View file

@ -1,4 +1,3 @@
#!/usr/bin/env python
''' '''
$Id: tzfile.py,v 1.8 2004/06/03 00:15:24 zenzen Exp $ $Id: tzfile.py,v 1.8 2004/06/03 00:15:24 zenzen Exp $
''' '''

Some files were not shown because too many files have changed in this diff Show more