diff --git a/lib/soupsieve/__init__.py b/lib/soupsieve/__init__.py index fefc6ca0..c89b7002 100644 --- a/lib/soupsieve/__init__.py +++ b/lib/soupsieve/__init__.py @@ -30,6 +30,8 @@ from . import css_parser as cp from . import css_match as cm from . import css_types as ct from .util import DEBUG, SelectorSyntaxError # noqa: F401 +import bs4 # type: ignore[import] +from typing import Dict, Optional, Any, List, Iterator, Iterable __all__ = ( 'DEBUG', 'SelectorSyntaxError', 'SoupSieve', @@ -40,15 +42,18 @@ __all__ = ( SoupSieve = cm.SoupSieve -def compile(pattern, namespaces=None, flags=0, **kwargs): # noqa: A001 +def compile( # noqa: A001 + pattern: str, + namespaces: Optional[Dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[Dict[str, str]] = None, + **kwargs: Any +) -> cm.SoupSieve: """Compile CSS pattern.""" - if namespaces is not None: - namespaces = ct.Namespaces(namespaces) - - custom = kwargs.get('custom') - if custom is not None: - custom = ct.CustomSelectors(custom) + ns = ct.Namespaces(namespaces) if namespaces is not None else namespaces # type: Optional[ct.Namespaces] + cs = ct.CustomSelectors(custom) if custom is not None else custom # type: Optional[ct.CustomSelectors] if isinstance(pattern, SoupSieve): if flags: @@ -59,53 +64,103 @@ def compile(pattern, namespaces=None, flags=0, **kwargs): # noqa: A001 raise ValueError("Cannot process 'custom' argument on a compiled selector list") return pattern - return cp._cached_css_compile(pattern, namespaces, custom, flags) + return cp._cached_css_compile(pattern, ns, cs, flags) -def purge(): +def purge() -> None: """Purge cached patterns.""" cp._purge_cache() -def closest(select, tag, namespaces=None, flags=0, **kwargs): +def closest( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[Dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[Dict[str, str]] = None, + **kwargs: Any +) -> 'bs4.Tag': """Match closest ancestor.""" return compile(select, namespaces, flags, 
**kwargs).closest(tag) -def match(select, tag, namespaces=None, flags=0, **kwargs): +def match( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[Dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[Dict[str, str]] = None, + **kwargs: Any +) -> bool: """Match node.""" return compile(select, namespaces, flags, **kwargs).match(tag) -def filter(select, iterable, namespaces=None, flags=0, **kwargs): # noqa: A001 +def filter( # noqa: A001 + select: str, + iterable: Iterable['bs4.Tag'], + namespaces: Optional[Dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[Dict[str, str]] = None, + **kwargs: Any +) -> List['bs4.Tag']: """Filter list of nodes.""" return compile(select, namespaces, flags, **kwargs).filter(iterable) -def select_one(select, tag, namespaces=None, flags=0, **kwargs): +def select_one( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[Dict[str, str]] = None, + flags: int = 0, + *, + custom: Optional[Dict[str, str]] = None, + **kwargs: Any +) -> 'bs4.Tag': """Select a single tag.""" return compile(select, namespaces, flags, **kwargs).select_one(tag) -def select(select, tag, namespaces=None, limit=0, flags=0, **kwargs): +def select( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[Dict[str, str]] = None, + limit: int = 0, + flags: int = 0, + *, + custom: Optional[Dict[str, str]] = None, + **kwargs: Any +) -> List['bs4.Tag']: """Select the specified tags.""" return compile(select, namespaces, flags, **kwargs).select(tag, limit) -def iselect(select, tag, namespaces=None, limit=0, flags=0, **kwargs): +def iselect( + select: str, + tag: 'bs4.Tag', + namespaces: Optional[Dict[str, str]] = None, + limit: int = 0, + flags: int = 0, + *, + custom: Optional[Dict[str, str]] = None, + **kwargs: Any +) -> Iterator['bs4.Tag']: """Iterate the specified tags.""" for el in compile(select, namespaces, flags, **kwargs).iselect(tag, limit): yield el -def escape(ident): +def escape(ident: str) -> str: """Escape identifier.""" return 
cp.escape(ident) diff --git a/lib/soupsieve/__meta__.py b/lib/soupsieve/__meta__.py index eb145789..2d769fbf 100644 --- a/lib/soupsieve/__meta__.py +++ b/lib/soupsieve/__meta__.py @@ -79,7 +79,11 @@ class Version(namedtuple("Version", ["major", "minor", "micro", "release", "pre" """ - def __new__(cls, major, minor, micro, release="final", pre=0, post=0, dev=0): + def __new__( + cls, + major: int, minor: int, micro: int, release: str = "final", + pre: int = 0, post: int = 0, dev: int = 0 + ) -> "Version": """Validate version info.""" # Ensure all parts are positive integers. @@ -115,27 +119,27 @@ class Version(namedtuple("Version", ["major", "minor", "micro", "release", "pre" return super(Version, cls).__new__(cls, major, minor, micro, release, pre, post, dev) - def _is_pre(self): + def _is_pre(self) -> bool: """Is prerelease.""" - return self.pre > 0 + return bool(self.pre > 0) - def _is_dev(self): + def _is_dev(self) -> bool: """Is development.""" return bool(self.release < "alpha") - def _is_post(self): + def _is_post(self) -> bool: """Is post.""" - return self.post > 0 + return bool(self.post > 0) - def _get_dev_status(self): # pragma: no cover + def _get_dev_status(self) -> str: # pragma: no cover """Get development status string.""" return DEV_STATUS[self.release] - def _get_canonical(self): + def _get_canonical(self) -> str: """Get the canonical output string.""" # Assemble major, minor, micro version and append `pre`, `post`, or `dev` if needed.. 
@@ -153,7 +157,7 @@ class Version(namedtuple("Version", ["major", "minor", "micro", "release", "pre" return ver -def parse_version(ver): +def parse_version(ver: str) -> Version: """Parse version into a comparable Version tuple.""" m = RE_VER.match(ver) @@ -188,5 +192,5 @@ def parse_version(ver): return Version(major, minor, micro, release, pre, post, dev) -__version_info__ = Version(2, 2, 1, "final") +__version_info__ = Version(2, 3, 1, "final") __version__ = __version_info__._get_canonical() diff --git a/lib/soupsieve/css_match.py b/lib/soupsieve/css_match.py index a9eeaad2..79bb8707 100644 --- a/lib/soupsieve/css_match.py +++ b/lib/soupsieve/css_match.py @@ -2,11 +2,10 @@ from datetime import datetime from . import util import re -from .import css_types as ct +from . import css_types as ct import unicodedata -from collections.abc import Sequence - -import bs4 +import bs4 # type: ignore[import] +from typing import Iterator, Iterable, List, Any, Optional, Tuple, Union, Dict, Callable, Sequence, cast # Empty tag pattern (whitespace okay) RE_NOT_EMPTY = re.compile('[^ \t\r\n\f]') @@ -56,7 +55,7 @@ FEB_LEAP_MONTH = 29 DAYS_IN_WEEK = 7 -class _FakeParent(object): +class _FakeParent: """ Fake parent class. @@ -65,22 +64,22 @@ class _FakeParent(object): fake parent so we can traverse the root element as a child. """ - def __init__(self, element): + def __init__(self, element: 'bs4.Tag') -> None: """Initialize.""" self.contents = [element] - def __len__(self): + def __len__(self) -> 'bs4.PageElement': """Length.""" return len(self.contents) -class _DocumentNav(object): +class _DocumentNav: """Navigate a Beautiful Soup document.""" @classmethod - def assert_valid_input(cls, tag): + def assert_valid_input(cls, tag: Any) -> None: """Check if valid input tag or document.""" # Fail on unexpected types. 
@@ -88,64 +87,67 @@ class _DocumentNav(object): raise TypeError("Expected a BeautifulSoup 'Tag', but instead recieved type {}".format(type(tag))) @staticmethod - def is_doc(obj): + def is_doc(obj: 'bs4.Tag') -> bool: """Is `BeautifulSoup` object.""" return isinstance(obj, bs4.BeautifulSoup) @staticmethod - def is_tag(obj): + def is_tag(obj: 'bs4.PageElement') -> bool: """Is tag.""" return isinstance(obj, bs4.Tag) @staticmethod - def is_declaration(obj): # pragma: no cover + def is_declaration(obj: 'bs4.PageElement') -> bool: # pragma: no cover """Is declaration.""" return isinstance(obj, bs4.Declaration) @staticmethod - def is_cdata(obj): + def is_cdata(obj: 'bs4.PageElement') -> bool: """Is CDATA.""" return isinstance(obj, bs4.CData) @staticmethod - def is_processing_instruction(obj): # pragma: no cover + def is_processing_instruction(obj: 'bs4.PageElement') -> bool: # pragma: no cover """Is processing instruction.""" return isinstance(obj, bs4.ProcessingInstruction) @staticmethod - def is_navigable_string(obj): + def is_navigable_string(obj: 'bs4.PageElement') -> bool: """Is navigable string.""" return isinstance(obj, bs4.NavigableString) @staticmethod - def is_special_string(obj): + def is_special_string(obj: 'bs4.PageElement') -> bool: """Is special string.""" return isinstance(obj, (bs4.Comment, bs4.Declaration, bs4.CData, bs4.ProcessingInstruction, bs4.Doctype)) @classmethod - def is_content_string(cls, obj): + def is_content_string(cls, obj: 'bs4.PageElement') -> bool: """Check if node is content string.""" return cls.is_navigable_string(obj) and not cls.is_special_string(obj) @staticmethod - def create_fake_parent(el): + def create_fake_parent(el: 'bs4.Tag') -> _FakeParent: """Create fake parent for a given element.""" return _FakeParent(el) @staticmethod - def is_xml_tree(el): + def is_xml_tree(el: 'bs4.Tag') -> bool: """Check if element (or document) is from a XML tree.""" - return el._is_xml + return bool(el._is_xml) - def is_iframe(self, el): + def 
is_iframe(self, el: 'bs4.Tag') -> bool: """Check if element is an `iframe`.""" - return ((el.name if self.is_xml_tree(el) else util.lower(el.name)) == 'iframe') and self.is_html_tag(el) + return bool( + ((el.name if self.is_xml_tree(el) else util.lower(el.name)) == 'iframe') and + self.is_html_tag(el) # type: ignore[attr-defined] + ) - def is_root(self, el): + def is_root(self, el: 'bs4.Tag') -> bool: """ Return whether element is a root element. @@ -153,19 +155,26 @@ class _DocumentNav(object): and we check if it is the root element under an `iframe`. """ - root = self.root and self.root is el + root = self.root and self.root is el # type: ignore[attr-defined] if not root: parent = self.get_parent(el) - root = parent is not None and self.is_html and self.is_iframe(parent) + root = parent is not None and self.is_html and self.is_iframe(parent) # type: ignore[attr-defined] return root - def get_contents(self, el, no_iframe=False): + def get_contents(self, el: 'bs4.Tag', no_iframe: bool = False) -> Iterator['bs4.PageElement']: """Get contents or contents in reverse.""" if not no_iframe or not self.is_iframe(el): for content in el.contents: yield content - def get_children(self, el, start=None, reverse=False, tags=True, no_iframe=False): + def get_children( + self, + el: 'bs4.Tag', + start: Optional[int] = None, + reverse: bool = False, + tags: bool = True, + no_iframe: bool = False + ) -> Iterator['bs4.PageElement']: """Get children.""" if not no_iframe or not self.is_iframe(el): @@ -184,7 +193,12 @@ class _DocumentNav(object): if not tags or self.is_tag(node): yield node - def get_descendants(self, el, tags=True, no_iframe=False): + def get_descendants( + self, + el: 'bs4.Tag', + tags: bool = True, + no_iframe: bool = False + ) -> Iterator['bs4.PageElement']: """Get descendants.""" if not no_iframe or not self.is_iframe(el): @@ -215,7 +229,7 @@ class _DocumentNav(object): if not tags or is_tag: yield child - def get_parent(self, el, no_iframe=False): + def 
get_parent(self, el: 'bs4.Tag', no_iframe: bool = False) -> 'bs4.Tag': """Get parent.""" parent = el.parent @@ -224,25 +238,25 @@ class _DocumentNav(object): return parent @staticmethod - def get_tag_name(el): + def get_tag_name(el: 'bs4.Tag') -> Optional[str]: """Get tag.""" - return el.name + return cast(Optional[str], el.name) @staticmethod - def get_prefix_name(el): + def get_prefix_name(el: 'bs4.Tag') -> Optional[str]: """Get prefix.""" - return el.prefix + return cast(Optional[str], el.prefix) @staticmethod - def get_uri(el): + def get_uri(el: 'bs4.Tag') -> Optional[str]: """Get namespace `URI`.""" - return el.namespace + return cast(Optional[str], el.namespace) @classmethod - def get_next(cls, el, tags=True): + def get_next(cls, el: 'bs4.Tag', tags: bool = True) -> 'bs4.PageElement': """Get next sibling tag.""" sibling = el.next_sibling @@ -251,7 +265,7 @@ class _DocumentNav(object): return sibling @classmethod - def get_previous(cls, el, tags=True): + def get_previous(cls, el: 'bs4.Tag', tags: bool = True) -> 'bs4.PageElement': """Get previous sibling tag.""" sibling = el.previous_sibling @@ -260,7 +274,7 @@ class _DocumentNav(object): return sibling @staticmethod - def has_html_ns(el): + def has_html_ns(el: 'bs4.Tag') -> bool: """ Check if element has an HTML namespace. @@ -269,16 +283,16 @@ class _DocumentNav(object): """ ns = getattr(el, 'namespace') if el else None - return ns and ns == NS_XHTML + return bool(ns and ns == NS_XHTML) @staticmethod - def split_namespace(el, attr_name): + def split_namespace(el: 'bs4.Tag', attr_name: str) -> Tuple[Optional[str], Optional[str]]: """Return namespace and attribute name without the prefix.""" return getattr(attr_name, 'namespace', None), getattr(attr_name, 'name', None) @classmethod - def normalize_value(cls, value): + def normalize_value(cls, value: Any) -> Union[str, Sequence[str]]: """Normalize the value to be a string or list of strings.""" # Treat `None` as empty string. 
@@ -297,20 +311,26 @@ class _DocumentNav(object): if isinstance(value, Sequence): new_value = [] for v in value: - if isinstance(v, Sequence): - # This is most certainly a user error and will crash and burn later, - # but to avoid excessive recursion, kick out now. - new_value.append(v) + if not isinstance(v, (str, bytes)) and isinstance(v, Sequence): + # This is most certainly a user error and will crash and burn later. + # To keep things working, we'll do what we do with all objects, + # And convert them to strings. + new_value.append(str(v)) else: # Convert the child to a string - new_value.append(cls.normalize_value(v)) + new_value.append(cast(str, cls.normalize_value(v))) return new_value # Try and make anything else a string return str(value) @classmethod - def get_attribute_by_name(cls, el, name, default=None): + def get_attribute_by_name( + cls, + el: 'bs4.Tag', + name: str, + default: Optional[Union[str, Sequence[str]]] = None + ) -> Optional[Union[str, Sequence[str]]]: """Get attribute by name.""" value = default @@ -327,39 +347,39 @@ class _DocumentNav(object): return value @classmethod - def iter_attributes(cls, el): + def iter_attributes(cls, el: 'bs4.Tag') -> Iterator[Tuple[str, Optional[Union[str, Sequence[str]]]]]: """Iterate attributes.""" for k, v in el.attrs.items(): yield k, cls.normalize_value(v) @classmethod - def get_classes(cls, el): + def get_classes(cls, el: 'bs4.Tag') -> Sequence[str]: """Get classes.""" classes = cls.get_attribute_by_name(el, 'class', []) if isinstance(classes, str): classes = RE_NOT_WS.findall(classes) - return classes + return cast(Sequence[str], classes) - def get_text(self, el, no_iframe=False): + def get_text(self, el: 'bs4.Tag', no_iframe: bool = False) -> str: """Get text.""" return ''.join( [node for node in self.get_descendants(el, tags=False, no_iframe=no_iframe) if self.is_content_string(node)] ) - def get_own_text(self, el, no_iframe=False): + def get_own_text(self, el: 'bs4.Tag', no_iframe: bool = False) -> 
List[str]: """Get Own Text.""" return [node for node in self.get_contents(el, no_iframe=no_iframe) if self.is_content_string(node)] -class Inputs(object): +class Inputs: """Class for parsing and validating input items.""" @staticmethod - def validate_day(year, month, day): + def validate_day(year: int, month: int, day: int) -> bool: """Validate day.""" max_days = LONG_MONTH @@ -370,7 +390,7 @@ class Inputs(object): return 1 <= day <= max_days @staticmethod - def validate_week(year, week): + def validate_week(year: int, week: int) -> bool: """Validate week.""" max_week = datetime.strptime("{}-{}-{}".format(12, 31, year), "%m-%d-%Y").isocalendar()[1] @@ -379,34 +399,36 @@ class Inputs(object): return 1 <= week <= max_week @staticmethod - def validate_month(month): + def validate_month(month: int) -> bool: """Validate month.""" return 1 <= month <= 12 @staticmethod - def validate_year(year): + def validate_year(year: int) -> bool: """Validate year.""" return 1 <= year @staticmethod - def validate_hour(hour): + def validate_hour(hour: int) -> bool: """Validate hour.""" return 0 <= hour <= 23 @staticmethod - def validate_minutes(minutes): + def validate_minutes(minutes: int) -> bool: """Validate minutes.""" return 0 <= minutes <= 59 @classmethod - def parse_value(cls, itype, value): + def parse_value(cls, itype: str, value: Optional[str]) -> Optional[Tuple[float, ...]]: """Parse the input value.""" - parsed = None + parsed = None # type: Optional[Tuple[float, ...]] + if value is None: + return value if itype == "date": m = RE_DATE.match(value) if m: @@ -452,23 +474,29 @@ class Inputs(object): elif itype in ("number", "range"): m = RE_NUM.match(value) if m: - parsed = float(m.group('value')) + parsed = (float(m.group('value')),) return parsed -class _Match(object): +class CSSMatch(_DocumentNav): """Perform CSS matching.""" - def __init__(self, selectors, scope, namespaces, flags): + def __init__( + self, + selectors: ct.SelectorList, + scope: 'bs4.Tag', + namespaces: 
Optional[ct.Namespaces], + flags: int + ) -> None: """Initialize.""" self.assert_valid_input(scope) self.tag = scope - self.cached_meta_lang = [] - self.cached_default_forms = [] - self.cached_indeterminate_forms = [] + self.cached_meta_lang = [] # type: List[Tuple[str, str]] + self.cached_default_forms = [] # type: List[Tuple['bs4.Tag', 'bs4.Tag']] + self.cached_indeterminate_forms = [] # type: List[Tuple['bs4.Tag', str, bool]] self.selectors = selectors - self.namespaces = {} if namespaces is None else namespaces + self.namespaces = {} if namespaces is None else namespaces # type: Union[ct.Namespaces, Dict[str, str]] self.flags = flags self.iframe_restrict = False @@ -494,12 +522,12 @@ class _Match(object): self.is_xml = self.is_xml_tree(doc) self.is_html = not self.is_xml or self.has_html_namespace - def supports_namespaces(self): + def supports_namespaces(self) -> bool: """Check if namespaces are supported in the HTML type.""" return self.is_xml or self.has_html_namespace - def get_tag_ns(self, el): + def get_tag_ns(self, el: 'bs4.Tag') -> str: """Get tag namespace.""" if self.supports_namespaces(): @@ -511,24 +539,24 @@ class _Match(object): namespace = NS_XHTML return namespace - def is_html_tag(self, el): + def is_html_tag(self, el: 'bs4.Tag') -> bool: """Check if tag is in HTML namespace.""" return self.get_tag_ns(el) == NS_XHTML - def get_tag(self, el): + def get_tag(self, el: 'bs4.Tag') -> Optional[str]: """Get tag.""" name = self.get_tag_name(el) return util.lower(name) if name is not None and not self.is_xml else name - def get_prefix(self, el): + def get_prefix(self, el: 'bs4.Tag') -> Optional[str]: """Get prefix.""" prefix = self.get_prefix_name(el) return util.lower(prefix) if prefix is not None and not self.is_xml else prefix - def find_bidi(self, el): + def find_bidi(self, el: 'bs4.Tag') -> Optional[int]: """Get directionality from element text.""" for node in self.get_children(el, tags=False): @@ -564,7 +592,7 @@ class _Match(object): return 
ct.SEL_DIR_LTR if bidi == 'L' else ct.SEL_DIR_RTL return None - def extended_language_filter(self, lang_range, lang_tag): + def extended_language_filter(self, lang_range: str, lang_tag: str) -> bool: """Filter the language tags.""" match = True @@ -615,7 +643,12 @@ class _Match(object): return match - def match_attribute_name(self, el, attr, prefix): + def match_attribute_name( + self, + el: 'bs4.Tag', + attr: str, + prefix: Optional[str] + ) -> Optional[Union[str, Sequence[str]]]: """Match attribute name and return value if it exists.""" value = None @@ -663,13 +696,13 @@ class _Match(object): break return value - def match_namespace(self, el, tag): + def match_namespace(self, el: 'bs4.Tag', tag: ct.SelectorTag) -> bool: """Match the namespace of the element.""" match = True namespace = self.get_tag_ns(el) default_namespace = self.namespaces.get('') - tag_ns = '' if tag.prefix is None else self.namespaces.get(tag.prefix, None) + tag_ns = '' if tag.prefix is None else self.namespaces.get(tag.prefix) # We must match the default namespace if one is not provided if tag.prefix is None and (default_namespace is not None and namespace != default_namespace): match = False @@ -684,27 +717,26 @@ class _Match(object): match = False return match - def match_attributes(self, el, attributes): + def match_attributes(self, el: 'bs4.Tag', attributes: Tuple[ct.SelectorAttribute, ...]) -> bool: """Match attributes.""" match = True if attributes: for a in attributes: - value = self.match_attribute_name(el, a.attribute, a.prefix) + temp = self.match_attribute_name(el, a.attribute, a.prefix) pattern = a.xml_type_pattern if self.is_xml and a.xml_type_pattern else a.pattern - if isinstance(value, list): - value = ' '.join(value) - if value is None: + if temp is None: match = False break - elif pattern is None: + value = temp if isinstance(temp, str) else ' '.join(temp) + if pattern is None: continue elif pattern.match(value) is None: match = False break return match - def 
match_tagname(self, el, tag): + def match_tagname(self, el: 'bs4.Tag', tag: ct.SelectorTag) -> bool: """Match tag name.""" name = (util.lower(tag.name) if not self.is_xml and tag.name is not None else tag.name) @@ -713,7 +745,7 @@ class _Match(object): name not in (self.get_tag(el), '*') ) - def match_tag(self, el, tag): + def match_tag(self, el: 'bs4.Tag', tag: Optional[ct.SelectorTag]) -> bool: """Match the tag.""" match = True @@ -725,10 +757,14 @@ class _Match(object): match = False return match - def match_past_relations(self, el, relation): + def match_past_relations(self, el: 'bs4.Tag', relation: ct.SelectorList) -> bool: """Match past relationship.""" found = False + # I don't think this can ever happen, but it makes `mypy` happy + if isinstance(relation[0], ct.SelectorNull): # pragma: no cover + return found + if relation[0].rel_type == REL_PARENT: parent = self.get_parent(el, no_iframe=self.iframe_restrict) while not found and parent: @@ -749,21 +785,28 @@ class _Match(object): found = self.match_selectors(sibling, relation) return found - def match_future_child(self, parent, relation, recursive=False): + def match_future_child(self, parent: 'bs4.Tag', relation: ct.SelectorList, recursive: bool = False) -> bool: """Match future child.""" match = False - children = self.get_descendants if recursive else self.get_children + if recursive: + children = self.get_descendants # type: Callable[..., Iterator['bs4.Tag']] + else: + children = self.get_children for child in children(parent, no_iframe=self.iframe_restrict): match = self.match_selectors(child, relation) if match: break return match - def match_future_relations(self, el, relation): + def match_future_relations(self, el: 'bs4.Tag', relation: ct.SelectorList) -> bool: """Match future relationship.""" found = False + # I don't think this can ever happen, but it makes `mypy` happy + if isinstance(relation[0], ct.SelectorNull): # pragma: no cover + return found + if relation[0].rel_type == REL_HAS_PARENT: 
found = self.match_future_child(el, relation, True) elif relation[0].rel_type == REL_HAS_CLOSE_PARENT: @@ -779,11 +822,14 @@ class _Match(object): found = self.match_selectors(sibling, relation) return found - def match_relations(self, el, relation): + def match_relations(self, el: 'bs4.Tag', relation: ct.SelectorList) -> bool: """Match relationship to other elements.""" found = False + if isinstance(relation[0], ct.SelectorNull) or relation[0].rel_type is None: + return found + if relation[0].rel_type.startswith(':'): found = self.match_future_relations(el, relation) else: @@ -791,7 +837,7 @@ class _Match(object): return found - def match_id(self, el, ids): + def match_id(self, el: 'bs4.Tag', ids: Tuple[str, ...]) -> bool: """Match element's ID.""" found = True @@ -801,7 +847,7 @@ class _Match(object): break return found - def match_classes(self, el, classes): + def match_classes(self, el: 'bs4.Tag', classes: Tuple[str, ...]) -> bool: """Match element's classes.""" current_classes = self.get_classes(el) @@ -812,7 +858,7 @@ class _Match(object): break return found - def match_root(self, el): + def match_root(self, el: 'bs4.Tag') -> bool: """Match element as root.""" is_root = self.is_root(el) @@ -838,12 +884,12 @@ class _Match(object): sibling = self.get_next(sibling, tags=False) return is_root - def match_scope(self, el): + def match_scope(self, el: 'bs4.Tag') -> bool: """Match element as scope.""" return self.scope is el - def match_nth_tag_type(self, el, child): + def match_nth_tag_type(self, el: 'bs4.Tag', child: 'bs4.Tag') -> bool: """Match tag type for `nth` matches.""" return( @@ -851,7 +897,7 @@ class _Match(object): (self.get_tag_ns(child) == self.get_tag_ns(el)) ) - def match_nth(self, el, nth): + def match_nth(self, el: 'bs4.Tag', nth: 'bs4.Tag') -> bool: """Match `nth` elements.""" matched = True @@ -952,7 +998,7 @@ class _Match(object): break return matched - def match_empty(self, el): + def match_empty(self, el: 'bs4.Tag') -> bool: """Check if element 
is empty (if requested).""" is_empty = True @@ -965,7 +1011,7 @@ class _Match(object): break return is_empty - def match_subselectors(self, el, selectors): + def match_subselectors(self, el: 'bs4.Tag', selectors: Tuple[ct.SelectorList, ...]) -> bool: """Match selectors.""" match = True @@ -974,11 +1020,11 @@ class _Match(object): match = False return match - def match_contains(self, el, contains): + def match_contains(self, el: 'bs4.Tag', contains: Tuple[ct.SelectorContains, ...]) -> bool: """Match element if it contains text.""" match = True - content = None + content = None # type: Optional[Union[str, Sequence[str]]] for contain_list in contains: if content is None: if contain_list.own: @@ -1002,7 +1048,7 @@ class _Match(object): match = False return match - def match_default(self, el): + def match_default(self, el: 'bs4.Tag') -> bool: """Match default.""" match = False @@ -1035,19 +1081,19 @@ class _Match(object): if name in ('input', 'button'): v = self.get_attribute_by_name(child, 'type', '') if v and util.lower(v) == 'submit': - self.cached_default_forms.append([form, child]) + self.cached_default_forms.append((form, child)) if el is child: match = True break return match - def match_indeterminate(self, el): + def match_indeterminate(self, el: 'bs4.Tag') -> bool: """Match default.""" match = False - name = self.get_attribute_by_name(el, 'name') + name = cast(str, self.get_attribute_by_name(el, 'name')) - def get_parent_form(el): + def get_parent_form(el: 'bs4.Tag') -> Optional['bs4.Tag']: """Find this input's form.""" form = None parent = self.get_parent(el, no_iframe=True) @@ -1098,11 +1144,11 @@ class _Match(object): break if not checked: match = True - self.cached_indeterminate_forms.append([form, name, match]) + self.cached_indeterminate_forms.append((form, name, match)) return match - def match_lang(self, el, langs): + def match_lang(self, el: 'bs4.Tag', langs: Tuple[ct.SelectorLang, ...]) -> bool: """Match languages.""" match = False @@ -1169,26 
+1215,26 @@ class _Match(object): content = v if c_lang and content: found_lang = content - self.cached_meta_lang.append((root, found_lang)) + self.cached_meta_lang.append((cast(str, root), cast(str, found_lang))) break if found_lang: break if not found_lang: - self.cached_meta_lang.append((root, False)) + self.cached_meta_lang.append((cast(str, root), '')) # If we determined a language, compare. if found_lang: for patterns in langs: match = False for pattern in patterns: - if self.extended_language_filter(pattern, found_lang): + if self.extended_language_filter(pattern, cast(str, found_lang)): match = True if not match: break return match - def match_dir(self, el, directionality): + def match_dir(self, el: 'bs4.Tag', directionality: int) -> bool: """Check directionality.""" # If we have to match both left and right, we can't match either. @@ -1220,13 +1266,13 @@ class _Match(object): # Auto handling for text inputs if ((is_input and itype in ('text', 'search', 'tel', 'url', 'email')) or is_textarea) and direction == 0: if is_textarea: - value = [] + temp = [] for node in self.get_contents(el, no_iframe=True): if self.is_content_string(node): - value.append(node) - value = ''.join(value) + temp.append(node) + value = ''.join(temp) else: - value = self.get_attribute_by_name(el, 'value', '') + value = cast(str, self.get_attribute_by_name(el, 'value', '')) if value: for c in value: bidi = unicodedata.bidirectional(c) @@ -1251,7 +1297,7 @@ class _Match(object): # Match parents direction return self.match_dir(self.get_parent(el, no_iframe=True), directionality) - def match_range(self, el, condition): + def match_range(self, el: 'bs4.Tag', condition: int) -> bool: """ Match range. 
@@ -1264,20 +1310,14 @@ class _Match(object): out_of_range = False itype = util.lower(self.get_attribute_by_name(el, 'type')) - mn = self.get_attribute_by_name(el, 'min', None) - if mn is not None: - mn = Inputs.parse_value(itype, mn) - mx = self.get_attribute_by_name(el, 'max', None) - if mx is not None: - mx = Inputs.parse_value(itype, mx) + mn = Inputs.parse_value(itype, cast(str, self.get_attribute_by_name(el, 'min', None))) + mx = Inputs.parse_value(itype, cast(str, self.get_attribute_by_name(el, 'max', None))) # There is no valid min or max, so we cannot evaluate a range if mn is None and mx is None: return False - value = self.get_attribute_by_name(el, 'value', None) - if value is not None: - value = Inputs.parse_value(itype, value) + value = Inputs.parse_value(itype, cast(str, self.get_attribute_by_name(el, 'value', None))) if value is not None: if itype in ("date", "datetime-local", "month", "week", "number", "range"): if mn is not None and value < mn: @@ -1297,7 +1337,7 @@ class _Match(object): return not out_of_range if condition & ct.SEL_IN_RANGE else out_of_range - def match_defined(self, el): + def match_defined(self, el: 'bs4.Tag') -> bool: """ Match defined. @@ -1313,12 +1353,14 @@ class _Match(object): name = self.get_tag(el) return ( - name.find('-') == -1 or - name.find(':') != -1 or - self.get_prefix(el) is not None + name is not None and ( + name.find('-') == -1 or + name.find(':') != -1 or + self.get_prefix(el) is not None + ) ) - def match_placeholder_shown(self, el): + def match_placeholder_shown(self, el: 'bs4.Tag') -> bool: """ Match placeholder shown according to HTML spec. 
@@ -1333,7 +1375,7 @@ class _Match(object): return match - def match_selectors(self, el, selectors): + def match_selectors(self, el: 'bs4.Tag', selectors: ct.SelectorList) -> bool: """Check if element matches one of the selectors.""" match = False @@ -1405,7 +1447,7 @@ class _Match(object): if selector.flags & DIR_FLAGS and not self.match_dir(el, selector.flags & DIR_FLAGS): continue # Validate that the tag contains the specified text. - if not self.match_contains(el, selector.contains): + if selector.contains and not self.match_contains(el, selector.contains): continue match = not is_not break @@ -1417,21 +1459,20 @@ class _Match(object): return match - def select(self, limit=0): + def select(self, limit: int = 0) -> Iterator['bs4.Tag']: """Match all tags under the targeted tag.""" - if limit < 1: - limit = None + lim = None if limit < 1 else limit for child in self.get_descendants(self.tag): if self.match(child): yield child - if limit is not None: - limit -= 1 - if limit < 1: + if lim is not None: + lim -= 1 + if lim < 1: break - def closest(self): + def closest(self) -> Optional['bs4.Tag']: """Match closest ancestor.""" current = self.tag @@ -1443,30 +1484,39 @@ class _Match(object): current = self.get_parent(current) return closest - def filter(self): # noqa A001 + def filter(self) -> List['bs4.Tag']: # noqa A001 """Filter tag's children.""" return [tag for tag in self.get_contents(self.tag) if not self.is_navigable_string(tag) and self.match(tag)] - def match(self, el): + def match(self, el: 'bs4.Tag') -> bool: """Match.""" return not self.is_doc(el) and self.is_tag(el) and self.match_selectors(el, self.selectors) -class CSSMatch(_DocumentNav, _Match): - """The Beautiful Soup CSS match class.""" - - class SoupSieve(ct.Immutable): """Compiled Soup Sieve selector matching object.""" + pattern: str + selectors: ct.SelectorList + namespaces: Optional[ct.Namespaces] + custom: Dict[str, str] + flags: int + __slots__ = ("pattern", "selectors", "namespaces", 
"custom", "flags", "_hash") - def __init__(self, pattern, selectors, namespaces, custom, flags): + def __init__( + self, + pattern: str, + selectors: ct.SelectorList, + namespaces: Optional[ct.Namespaces], + custom: Optional[ct.CustomSelectors], + flags: int + ): """Initialize.""" - super(SoupSieve, self).__init__( + super().__init__( pattern=pattern, selectors=selectors, namespaces=namespaces, @@ -1474,17 +1524,17 @@ class SoupSieve(ct.Immutable): flags=flags ) - def match(self, tag): + def match(self, tag: 'bs4.Tag') -> bool: """Match.""" return CSSMatch(self.selectors, tag, self.namespaces, self.flags).match(tag) - def closest(self, tag): + def closest(self, tag: 'bs4.Tag') -> 'bs4.Tag': """Match closest ancestor.""" return CSSMatch(self.selectors, tag, self.namespaces, self.flags).closest() - def filter(self, iterable): # noqa A001 + def filter(self, iterable: Iterable['bs4.Tag']) -> List['bs4.Tag']: # noqa A001 """ Filter. @@ -1501,24 +1551,24 @@ class SoupSieve(ct.Immutable): else: return [node for node in iterable if not CSSMatch.is_navigable_string(node) and self.match(node)] - def select_one(self, tag): + def select_one(self, tag: 'bs4.Tag') -> 'bs4.Tag': """Select a single tag.""" tags = self.select(tag, limit=1) return tags[0] if tags else None - def select(self, tag, limit=0): + def select(self, tag: 'bs4.Tag', limit: int = 0) -> List['bs4.Tag']: """Select the specified tags.""" return list(self.iselect(tag, limit)) - def iselect(self, tag, limit=0): + def iselect(self, tag: 'bs4.Tag', limit: int = 0) -> Iterator['bs4.Tag']: """Iterate the specified tags.""" for el in CSSMatch(self.selectors, tag, self.namespaces, self.flags).select(limit): yield el - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover """Representation.""" return "SoupSieve(pattern={!r}, namespaces={!r}, custom={!r}, flags={!r})".format( diff --git a/lib/soupsieve/css_parser.py b/lib/soupsieve/css_parser.py index 462aa947..0536b80f 100644 --- 
a/lib/soupsieve/css_parser.py +++ b/lib/soupsieve/css_parser.py @@ -6,6 +6,7 @@ from . import css_match as cm from . import css_types as ct from .util import SelectorSyntaxError import warnings +from typing import Optional, Dict, Match, Tuple, Type, Any, List, Union, Iterator, cast UNICODE_REPLACEMENT_CHAR = 0xFFFD @@ -196,32 +197,42 @@ FLG_OPEN = 0x40 FLG_IN_RANGE = 0x80 FLG_OUT_OF_RANGE = 0x100 FLG_PLACEHOLDER_SHOWN = 0x200 +FLG_FORGIVE = 0x400 # Maximum cached patterns to store _MAXCACHE = 500 @lru_cache(maxsize=_MAXCACHE) -def _cached_css_compile(pattern, namespaces, custom, flags): +def _cached_css_compile( + pattern: str, + namespaces: Optional[ct.Namespaces], + custom: Optional[ct.CustomSelectors], + flags: int +) -> cm.SoupSieve: """Cached CSS compile.""" custom_selectors = process_custom(custom) return cm.SoupSieve( pattern, - CSSParser(pattern, custom=custom_selectors, flags=flags).process_selectors(), + CSSParser( + pattern, + custom=custom_selectors, + flags=flags + ).process_selectors(), namespaces, custom, flags ) -def _purge_cache(): +def _purge_cache() -> None: """Purge the cache.""" _cached_css_compile.cache_clear() -def process_custom(custom): +def process_custom(custom: Optional[ct.CustomSelectors]) -> Dict[str, Union[str, ct.SelectorList]]: """Process custom.""" custom_selectors = {} @@ -236,14 +247,14 @@ def process_custom(custom): return custom_selectors -def css_unescape(content, string=False): +def css_unescape(content: str, string: bool = False) -> str: """ Unescape CSS value. Strings allow for spanning the value on multiple strings by escaping a new line. 
""" - def replace(m): + def replace(m: Match[str]) -> str: """Replace with the appropriate substitute.""" if m.group(1): @@ -263,7 +274,7 @@ def css_unescape(content, string=False): return (RE_CSS_ESC if not string else RE_CSS_STR_ESC).sub(replace, content) -def escape(ident): +def escape(ident: str) -> str: """Escape identifier.""" string = [] @@ -291,21 +302,21 @@ def escape(ident): return ''.join(string) -class SelectorPattern(object): +class SelectorPattern: """Selector pattern.""" - def __init__(self, name, pattern): + def __init__(self, name: str, pattern: str) -> None: """Initialize.""" self.name = name self.re_pattern = re.compile(pattern, re.I | re.X | re.U) - def get_name(self): + def get_name(self) -> str: """Get name.""" return self.name - def match(self, selector, index, flags): + def match(self, selector: str, index: int, flags: int) -> Optional[Match[str]]: """Match the selector.""" return self.re_pattern.match(selector, index) @@ -314,7 +325,7 @@ class SelectorPattern(object): class SpecialPseudoPattern(SelectorPattern): """Selector pattern.""" - def __init__(self, patterns): + def __init__(self, patterns: Tuple[Tuple[str, Tuple[str, ...], str, Type[SelectorPattern]], ...]) -> None: """Initialize.""" self.patterns = {} @@ -324,15 +335,15 @@ class SpecialPseudoPattern(SelectorPattern): for pseudo in p[1]: self.patterns[pseudo] = pattern - self.matched_name = None + self.matched_name = None # type: Optional[SelectorPattern] self.re_pseudo_name = re.compile(PAT_PSEUDO_CLASS_SPECIAL, re.I | re.X | re.U) - def get_name(self): + def get_name(self) -> str: """Get name.""" - return self.matched_name.get_name() + return '' if self.matched_name is None else self.matched_name.get_name() - def match(self, selector, index, flags): + def match(self, selector: str, index: int, flags: int) -> Optional[Match[str]]: """Match the selector.""" pseudo = None @@ -348,7 +359,7 @@ class SpecialPseudoPattern(SelectorPattern): return pseudo -class _Selector(object): +class 
_Selector: """ Intermediate selector class. @@ -357,23 +368,23 @@ class _Selector(object): the data in an object that can be pickled and hashed. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: """Initialize.""" - self.tag = kwargs.get('tag', None) - self.ids = kwargs.get('ids', []) - self.classes = kwargs.get('classes', []) - self.attributes = kwargs.get('attributes', []) - self.nth = kwargs.get('nth', []) - self.selectors = kwargs.get('selectors', []) - self.relations = kwargs.get('relations', []) - self.rel_type = kwargs.get('rel_type', None) - self.contains = kwargs.get('contains', []) - self.lang = kwargs.get('lang', []) - self.flags = kwargs.get('flags', 0) - self.no_match = kwargs.get('no_match', False) + self.tag = kwargs.get('tag', None) # type: Optional[ct.SelectorTag] + self.ids = kwargs.get('ids', []) # type: List[str] + self.classes = kwargs.get('classes', []) # type: List[str] + self.attributes = kwargs.get('attributes', []) # type: List[ct.SelectorAttribute] + self.nth = kwargs.get('nth', []) # type: List[ct.SelectorNth] + self.selectors = kwargs.get('selectors', []) # type: List[ct.SelectorList] + self.relations = kwargs.get('relations', []) # type: List[_Selector] + self.rel_type = kwargs.get('rel_type', None) # type: Optional[str] + self.contains = kwargs.get('contains', []) # type: List[ct.SelectorContains] + self.lang = kwargs.get('lang', []) # type: List[ct.SelectorLang] + self.flags = kwargs.get('flags', 0) # type: int + self.no_match = kwargs.get('no_match', False) # type: bool - def _freeze_relations(self, relations): + def _freeze_relations(self, relations: List['_Selector']) -> ct.SelectorList: """Freeze relation.""" if relations: @@ -383,7 +394,7 @@ class _Selector(object): else: return ct.SelectorList() - def freeze(self): + def freeze(self) -> Union[ct.Selector, ct.SelectorNull]: """Freeze self.""" if self.no_match: @@ -403,7 +414,7 @@ class _Selector(object): self.flags ) - def __str__(self): # pragma: 
no cover + def __str__(self) -> str: # pragma: no cover """String representation.""" return ( @@ -417,7 +428,7 @@ class _Selector(object): __repr__ = __str__ -class CSSParser(object): +class CSSParser: """Parse CSS selectors.""" css_tokens = ( @@ -447,7 +458,12 @@ class CSSParser(object): SelectorPattern("combine", PAT_COMBINE) ) - def __init__(self, selector, custom=None, flags=0): + def __init__( + self, + selector: str, + custom: Optional[Dict[str, Union[str, ct.SelectorList]]] = None, + flags: int = 0 + ) -> None: """Initialize.""" self.pattern = selector.replace('\x00', '\ufffd') @@ -455,7 +471,7 @@ class CSSParser(object): self.debug = self.flags & util.DEBUG self.custom = {} if custom is None else custom - def parse_attribute_selector(self, sel, m, has_selector): + def parse_attribute_selector(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Create attribute selector from the returned regex match.""" inverse = False @@ -465,22 +481,22 @@ class CSSParser(object): attr = css_unescape(m.group('attr_name')) is_type = False pattern2 = None + value = '' if case: - flags = re.I if case == 'i' else 0 + flags = (re.I if case == 'i' else 0) | re.DOTALL elif util.lower(attr) == 'type': - flags = re.I + flags = re.I | re.DOTALL is_type = True else: - flags = 0 + flags = re.DOTALL if op: if m.group('value').startswith(('"', "'")): value = css_unescape(m.group('value')[1:-1], True) else: value = css_unescape(m.group('value')) - else: - value = None + if not op: # Attribute name pattern = None @@ -525,7 +541,7 @@ class CSSParser(object): has_selector = True return has_selector - def parse_tag_pattern(self, sel, m, has_selector): + def parse_tag_pattern(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse tag pattern from regex match.""" prefix = css_unescape(m.group('tag_ns')[:-1]) if m.group('tag_ns') else None @@ -534,7 +550,7 @@ class CSSParser(object): has_selector = True return has_selector - def 
parse_pseudo_class_custom(self, sel, m, has_selector): + def parse_pseudo_class_custom(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """ Parse custom pseudo class alias. @@ -552,7 +568,7 @@ class CSSParser(object): ) if not isinstance(selector, ct.SelectorList): - self.custom[pseudo] = None + del self.custom[pseudo] selector = CSSParser( selector, custom=self.custom, flags=self.flags ).process_selectors(flags=FLG_PSEUDO) @@ -562,7 +578,14 @@ class CSSParser(object): has_selector = True return has_selector - def parse_pseudo_class(self, sel, m, has_selector, iselector, is_html): + def parse_pseudo_class( + self, + sel: _Selector, + m: Match[str], + has_selector: bool, + iselector: Iterator[Tuple[str, Match[str]]], + is_html: bool + ) -> Tuple[bool, bool]: """Parse pseudo class.""" complex_pseudo = False @@ -650,7 +673,13 @@ class CSSParser(object): return has_selector, is_html - def parse_pseudo_nth(self, sel, m, has_selector, iselector): + def parse_pseudo_nth( + self, + sel: _Selector, + m: Match[str], + has_selector: bool, + iselector: Iterator[Tuple[str, Match[str]]] + ) -> bool: """Parse `nth` pseudo.""" mdict = m.groupdict() @@ -671,23 +700,23 @@ class CSSParser(object): s2 = 1 var = True else: - nth_parts = RE_NTH.match(content) - s1 = '-' if nth_parts.group('s1') and nth_parts.group('s1') == '-' else '' + nth_parts = cast(Match[str], RE_NTH.match(content)) + _s1 = '-' if nth_parts.group('s1') and nth_parts.group('s1') == '-' else '' a = nth_parts.group('a') var = a.endswith('n') if a.startswith('n'): - s1 += '1' + _s1 += '1' elif var: - s1 += a[:-1] + _s1 += a[:-1] else: - s1 += a - s2 = '-' if nth_parts.group('s2') and nth_parts.group('s2') == '-' else '' + _s1 += a + _s2 = '-' if nth_parts.group('s2') and nth_parts.group('s2') == '-' else '' if nth_parts.group('b'): - s2 += nth_parts.group('b') + _s2 += nth_parts.group('b') else: - s2 = '0' - s1 = int(s1, 10) - s2 = int(s2, 10) + _s2 = '0' + s1 = int(_s1, 10) + s2 = int(_s2, 10) 
pseudo_sel = mdict['name'] if postfix == '_child': @@ -709,20 +738,38 @@ class CSSParser(object): has_selector = True return has_selector - def parse_pseudo_open(self, sel, name, has_selector, iselector, index): + def parse_pseudo_open( + self, + sel: _Selector, + name: str, + has_selector: bool, + iselector: Iterator[Tuple[str, Match[str]]], + index: int + ) -> bool: """Parse pseudo with opening bracket.""" flags = FLG_PSEUDO | FLG_OPEN if name == ':not': flags |= FLG_NOT - if name == ':has': - flags |= FLG_RELATIVE + elif name == ':has': + flags |= FLG_RELATIVE | FLG_FORGIVE + elif name in (':where', ':is'): + flags |= FLG_FORGIVE sel.selectors.append(self.parse_selectors(iselector, index, flags)) has_selector = True + return has_selector - def parse_has_combinator(self, sel, m, has_selector, selectors, rel_type, index): + def parse_has_combinator( + self, + sel: _Selector, + m: Match[str], + has_selector: bool, + selectors: List[_Selector], + rel_type: str, + index: int + ) -> Tuple[bool, _Selector, str]: """Parse combinator tokens.""" combinator = m.group('relation').strip() @@ -731,12 +778,9 @@ class CSSParser(object): if combinator == COMMA_COMBINATOR: if not has_selector: # If we've not captured any selector parts, the comma is either at the beginning of the pattern - # or following another comma, both of which are unexpected. Commas must split selectors. - raise SelectorSyntaxError( - "The combinator '{}' at postion {}, must have a selector before it".format(combinator, index), - self.pattern, - index - ) + # or following another comma, both of which are unexpected. But shouldn't fail the pseudo-class. + sel.no_match = True + sel.rel_type = rel_type selectors[-1].relations.append(sel) rel_type = ":" + WS_COMBINATOR @@ -757,44 +801,63 @@ class CSSParser(object): self.pattern, index ) + # Set the leading combinator for the next selector. 
rel_type = ':' + combinator - sel = _Selector() + sel = _Selector() has_selector = False return has_selector, sel, rel_type - def parse_combinator(self, sel, m, has_selector, selectors, relations, is_pseudo, index): + def parse_combinator( + self, + sel: _Selector, + m: Match[str], + has_selector: bool, + selectors: List[_Selector], + relations: List[_Selector], + is_pseudo: bool, + is_forgive: bool, + index: int + ) -> Tuple[bool, _Selector]: """Parse combinator tokens.""" combinator = m.group('relation').strip() if not combinator: combinator = WS_COMBINATOR if not has_selector: - raise SelectorSyntaxError( - "The combinator '{}' at postion {}, must have a selector before it".format(combinator, index), - self.pattern, - index - ) + if not is_forgive or combinator != COMMA_COMBINATOR: + raise SelectorSyntaxError( + "The combinator '{}' at postion {}, must have a selector before it".format(combinator, index), + self.pattern, + index + ) - if combinator == COMMA_COMBINATOR: - if not sel.tag and not is_pseudo: - # Implied `*` - sel.tag = ct.SelectorTag('*', None) - sel.relations.extend(relations) - selectors.append(sel) - del relations[:] + # If we are in a forgiving pseudo class, just make the selector a "no match" + if combinator == COMMA_COMBINATOR: + sel.no_match = True + del relations[:] + selectors.append(sel) else: - sel.relations.extend(relations) - sel.rel_type = combinator - del relations[:] - relations.append(sel) - sel = _Selector() + if combinator == COMMA_COMBINATOR: + if not sel.tag and not is_pseudo: + # Implied `*` + sel.tag = ct.SelectorTag('*', None) + sel.relations.extend(relations) + selectors.append(sel) + del relations[:] + else: + sel.relations.extend(relations) + sel.rel_type = combinator + del relations[:] + relations.append(sel) + sel = _Selector() has_selector = False + return has_selector, sel - def parse_class_id(self, sel, m, has_selector): + def parse_class_id(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse 
HTML classes and ids.""" selector = m.group(0) @@ -805,7 +868,7 @@ class CSSParser(object): has_selector = True return has_selector - def parse_pseudo_contains(self, sel, m, has_selector): + def parse_pseudo_contains(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse contains.""" pseudo = util.lower(css_unescape(m.group('name'))) @@ -826,11 +889,11 @@ class CSSParser(object): else: value = css_unescape(value) patterns.append(value) - sel.contains.append(ct.SelectorContains(tuple(patterns), contains_own)) + sel.contains.append(ct.SelectorContains(patterns, contains_own)) has_selector = True return has_selector - def parse_pseudo_lang(self, sel, m, has_selector): + def parse_pseudo_lang(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse pseudo language.""" values = m.group('values') @@ -851,7 +914,7 @@ class CSSParser(object): return has_selector - def parse_pseudo_dir(self, sel, m, has_selector): + def parse_pseudo_dir(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool: """Parse pseudo direction.""" value = ct.SEL_DIR_LTR if util.lower(m.group('dir')) == 'ltr' else ct.SEL_DIR_RTL @@ -859,15 +922,23 @@ class CSSParser(object): has_selector = True return has_selector - def parse_selectors(self, iselector, index=0, flags=0): + def parse_selectors( + self, + iselector: Iterator[Tuple[str, Match[str]]], + index: int = 0, + flags: int = 0 + ) -> ct.SelectorList: """Parse selectors.""" + # Initialize important variables sel = _Selector() selectors = [] has_selector = False closed = False - relations = [] + relations = [] # type: List[_Selector] rel_type = ":" + WS_COMBINATOR + + # Setup various flags is_open = bool(flags & FLG_OPEN) is_pseudo = bool(flags & FLG_PSEUDO) is_relative = bool(flags & FLG_RELATIVE) @@ -878,7 +949,9 @@ class CSSParser(object): is_in_range = bool(flags & FLG_IN_RANGE) is_out_of_range = bool(flags & FLG_OUT_OF_RANGE) is_placeholder_shown = bool(flags & FLG_PLACEHOLDER_SHOWN) + is_forgive 
= bool(flags & FLG_FORGIVE) + # Print out useful debug stuff if self.debug: # pragma: no cover if is_pseudo: print(' is_pseudo: True') @@ -900,7 +973,10 @@ class CSSParser(object): print(' is_out_of_range: True') if is_placeholder_shown: print(' is_placeholder_shown: True') + if is_forgive: + print(' is_forgive: True') + # The algorithm for relative selectors require an initial selector in the selector list if is_relative: selectors.append(_Selector()) @@ -929,11 +1005,13 @@ class CSSParser(object): is_html = True elif key == 'pseudo_close': if not has_selector: - raise SelectorSyntaxError( - "Expected a selector at postion {}".format(m.start(0)), - self.pattern, - m.start(0) - ) + if not is_forgive: + raise SelectorSyntaxError( + "Expected a selector at postion {}".format(m.start(0)), + self.pattern, + m.start(0) + ) + sel.no_match = True if is_open: closed = True break @@ -950,7 +1028,7 @@ class CSSParser(object): ) else: has_selector, sel = self.parse_combinator( - sel, m, has_selector, selectors, relations, is_pseudo, index + sel, m, has_selector, selectors, relations, is_pseudo, is_forgive, index ) elif key == 'attribute': has_selector = self.parse_attribute_selector(sel, m, has_selector) @@ -969,6 +1047,7 @@ class CSSParser(object): except StopIteration: pass + # Handle selectors that are not closed if is_open and not closed: raise SelectorSyntaxError( "Unclosed pseudo-class at position {}".format(index), @@ -976,6 +1055,7 @@ class CSSParser(object): index ) + # Cleanup completed selector piece if has_selector: if not sel.tag and not is_pseudo: # Implied `*` @@ -987,8 +1067,28 @@ class CSSParser(object): sel.relations.extend(relations) del relations[:] selectors.append(sel) - else: + + # Forgive empty slots in pseudo-classes that have lists (and are forgiving) + elif is_forgive: + if is_relative: + # Handle relative selectors pseudo-classes with empty slots like `:has()` + if selectors and selectors[-1].rel_type is None and rel_type == ': ': + sel.rel_type = 
rel_type + sel.no_match = True + selectors[-1].relations.append(sel) + has_selector = True + else: + # Handle normal pseudo-classes with empty slots + if not selectors or not relations: + # Others like `:is()` etc. + sel.no_match = True + del relations[:] + selectors.append(sel) + has_selector = True + + if not has_selector: # We will always need to finish a selector when `:has()` is used as it leads with combining. + # May apply to others as well. raise SelectorSyntaxError( 'Expected a selector at position {}'.format(index), self.pattern, @@ -1009,9 +1109,10 @@ class CSSParser(object): if is_placeholder_shown: selectors[-1].flags = ct.SEL_PLACEHOLDER_SHOWN + # Return selector list return ct.SelectorList([s.freeze() for s in selectors], is_not, is_html) - def selector_iter(self, pattern): + def selector_iter(self, pattern: str) -> Iterator[Tuple[str, Match[str]]]: """Iterate selector tokens.""" # Ignore whitespace and comments at start and end of pattern @@ -1052,7 +1153,7 @@ class CSSParser(object): if self.debug: # pragma: no cover print('## END PARSING') - def process_selectors(self, index=0, flags=0): + def process_selectors(self, index: int = 0, flags: int = 0) -> ct.SelectorList: """Process selectors.""" return self.parse_selectors(self.selector_iter(self.pattern), index, flags) diff --git a/lib/soupsieve/css_types.py b/lib/soupsieve/css_types.py index c2b9f30d..e5a6e49c 100644 --- a/lib/soupsieve/css_types.py +++ b/lib/soupsieve/css_types.py @@ -1,6 +1,7 @@ """CSS selector structure items.""" import copyreg -from collections.abc import Hashable, Mapping +from .pretty import pretty +from typing import Any, Type, Tuple, Union, Dict, Iterator, Hashable, Optional, Pattern, Iterable, Mapping __all__ = ( 'Selector', @@ -29,12 +30,14 @@ SEL_DEFINED = 0x200 SEL_PLACEHOLDER_SHOWN = 0x400 -class Immutable(object): +class Immutable: """Immutable.""" - __slots__ = ('_hash',) + __slots__: Tuple[str, ...] 
= ('_hash',) - def __init__(self, **kwargs): + _hash: int + + def __init__(self, **kwargs: Any) -> None: """Initialize.""" temp = [] @@ -45,12 +48,12 @@ class Immutable(object): super(Immutable, self).__setattr__('_hash', hash(tuple(temp))) @classmethod - def __base__(cls): + def __base__(cls) -> "Type[Immutable]": """Get base class.""" return cls - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: """Equal.""" return ( @@ -58,7 +61,7 @@ class Immutable(object): all([getattr(other, key) == getattr(self, key) for key in self.__slots__ if key != '_hash']) ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: """Equal.""" return ( @@ -66,63 +69,74 @@ class Immutable(object): any([getattr(other, key) != getattr(self, key) for key in self.__slots__ if key != '_hash']) ) - def __hash__(self): + def __hash__(self) -> int: """Hash.""" return self._hash - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: """Prevent mutability.""" raise AttributeError("'{}' is immutable".format(self.__class__.__name__)) - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover """Representation.""" return "{}({})".format( - self.__base__(), ', '.join(["{}={!r}".format(k, getattr(self, k)) for k in self.__slots__[:-1]]) + self.__class__.__name__, ', '.join(["{}={!r}".format(k, getattr(self, k)) for k in self.__slots__[:-1]]) ) __str__ = __repr__ + def pretty(self) -> None: # pragma: no cover + """Pretty print.""" -class ImmutableDict(Mapping): + print(pretty(self)) + + +class ImmutableDict(Mapping[Any, Any]): """Hashable, immutable dictionary.""" - def __init__(self, arg): + def __init__( + self, + arg: Union[Dict[Any, Any], Iterable[Tuple[Any, Any]]] + ) -> None: """Initialize.""" - arg - is_dict = isinstance(arg, dict) - if ( - is_dict and not all([isinstance(v, Hashable) for v in arg.values()]) or - not is_dict and not all([isinstance(k, Hashable) and isinstance(v, Hashable) for k, v in 
arg]) - ): - raise TypeError('All values must be hashable') - + self._validate(arg) self._d = dict(arg) self._hash = hash(tuple([(type(x), x, type(y), y) for x, y in sorted(self._d.items())])) - def __iter__(self): + def _validate(self, arg: Union[Dict[Any, Any], Iterable[Tuple[Any, Any]]]) -> None: + """Validate arguments.""" + + if isinstance(arg, dict): + if not all([isinstance(v, Hashable) for v in arg.values()]): + raise TypeError('{} values must be hashable'.format(self.__class__.__name__)) + elif not all([isinstance(k, Hashable) and isinstance(v, Hashable) for k, v in arg]): + raise TypeError('{} values must be hashable'.format(self.__class__.__name__)) + + def __iter__(self) -> Iterator[Any]: """Iterator.""" return iter(self._d) - def __len__(self): + def __len__(self) -> int: """Length.""" return len(self._d) - def __getitem__(self, key): + def __getitem__(self, key: Any) -> Any: """Get item: `namespace['key']`.""" + return self._d[key] - def __hash__(self): + def __hash__(self) -> int: """Hash.""" return self._hash - def __repr__(self): # pragma: no cover + def __repr__(self) -> str: # pragma: no cover """Representation.""" return "{!r}".format(self._d) @@ -133,37 +147,37 @@ class ImmutableDict(Mapping): class Namespaces(ImmutableDict): """Namespaces.""" - def __init__(self, arg): + def __init__(self, arg: Union[Dict[str, str], Iterable[Tuple[str, str]]]) -> None: """Initialize.""" - # If there are arguments, check the first index. - # `super` should fail if the user gave multiple arguments, - # so don't bother checking that. 
- is_dict = isinstance(arg, dict) - if is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg.items()]): - raise TypeError('Namespace keys and values must be Unicode strings') - elif not is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]): - raise TypeError('Namespace keys and values must be Unicode strings') + super().__init__(arg) - super(Namespaces, self).__init__(arg) + def _validate(self, arg: Union[Dict[str, str], Iterable[Tuple[str, str]]]) -> None: + """Validate arguments.""" + + if isinstance(arg, dict): + if not all([isinstance(v, str) for v in arg.values()]): + raise TypeError('{} values must be hashable'.format(self.__class__.__name__)) + elif not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]): + raise TypeError('{} keys and values must be Unicode strings'.format(self.__class__.__name__)) class CustomSelectors(ImmutableDict): """Custom selectors.""" - def __init__(self, arg): + def __init__(self, arg: Union[Dict[str, str], Iterable[Tuple[str, str]]]) -> None: """Initialize.""" - # If there are arguments, check the first index. - # `super` should fail if the user gave multiple arguments, - # so don't bother checking that. 
- is_dict = isinstance(arg, dict) - if is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg.items()]): - raise TypeError('CustomSelectors keys and values must be Unicode strings') - elif not is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]): - raise TypeError('CustomSelectors keys and values must be Unicode strings') + super().__init__(arg) - super(CustomSelectors, self).__init__(arg) + def _validate(self, arg: Union[Dict[str, str], Iterable[Tuple[str, str]]]) -> None: + """Validate arguments.""" + + if isinstance(arg, dict): + if not all([isinstance(v, str) for v in arg.values()]): + raise TypeError('{} values must be hashable'.format(self.__class__.__name__)) + elif not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]): + raise TypeError('{} keys and values must be Unicode strings'.format(self.__class__.__name__)) class Selector(Immutable): @@ -174,13 +188,35 @@ class Selector(Immutable): 'relation', 'rel_type', 'contains', 'lang', 'flags', '_hash' ) + tag: Optional['SelectorTag'] + ids: Tuple[str, ...] + classes: Tuple[str, ...] + attributes: Tuple['SelectorAttribute', ...] + nth: Tuple['SelectorNth', ...] + selectors: Tuple['SelectorList', ...] + relation: 'SelectorList' + rel_type: Optional[str] + contains: Tuple['SelectorContains', ...] + lang: Tuple['SelectorLang', ...] 
+ flags: int + def __init__( - self, tag, ids, classes, attributes, nth, selectors, - relation, rel_type, contains, lang, flags + self, + tag: Optional['SelectorTag'], + ids: Tuple[str, ...], + classes: Tuple[str, ...], + attributes: Tuple['SelectorAttribute', ...], + nth: Tuple['SelectorNth', ...], + selectors: Tuple['SelectorList', ...], + relation: 'SelectorList', + rel_type: Optional[str], + contains: Tuple['SelectorContains', ...], + lang: Tuple['SelectorLang', ...], + flags: int ): """Initialize.""" - super(Selector, self).__init__( + super().__init__( tag=tag, ids=ids, classes=classes, @@ -198,10 +234,10 @@ class Selector(Immutable): class SelectorNull(Immutable): """Null Selector.""" - def __init__(self): + def __init__(self) -> None: """Initialize.""" - super(SelectorNull, self).__init__() + super().__init__() class SelectorTag(Immutable): @@ -209,13 +245,13 @@ class SelectorTag(Immutable): __slots__ = ("name", "prefix", "_hash") - def __init__(self, name, prefix): + name: str + prefix: Optional[str] + + def __init__(self, name: str, prefix: Optional[str]) -> None: """Initialize.""" - super(SelectorTag, self).__init__( - name=name, - prefix=prefix - ) + super().__init__(name=name, prefix=prefix) class SelectorAttribute(Immutable): @@ -223,10 +259,21 @@ class SelectorAttribute(Immutable): __slots__ = ("attribute", "prefix", "pattern", "xml_type_pattern", "_hash") - def __init__(self, attribute, prefix, pattern, xml_type_pattern): + attribute: str + prefix: str + pattern: Optional[Pattern[str]] + xml_type_pattern: Optional[Pattern[str]] + + def __init__( + self, + attribute: str, + prefix: str, + pattern: Optional[Pattern[str]], + xml_type_pattern: Optional[Pattern[str]] + ) -> None: """Initialize.""" - super(SelectorAttribute, self).__init__( + super().__init__( attribute=attribute, prefix=prefix, pattern=pattern, @@ -239,13 +286,13 @@ class SelectorContains(Immutable): __slots__ = ("text", "own", "_hash") - def __init__(self, text, own): + text: Tuple[str, 
...] + own: bool + + def __init__(self, text: Iterable[str], own: bool) -> None: """Initialize.""" - super(SelectorContains, self).__init__( - text=text, - own=own - ) + super().__init__(text=tuple(text), own=own) class SelectorNth(Immutable): @@ -253,10 +300,17 @@ class SelectorNth(Immutable): __slots__ = ("a", "n", "b", "of_type", "last", "selectors", "_hash") - def __init__(self, a, n, b, of_type, last, selectors): + a: int + n: bool + b: int + of_type: bool + last: bool + selectors: 'SelectorList' + + def __init__(self, a: int, n: bool, b: int, of_type: bool, last: bool, selectors: 'SelectorList') -> None: """Initialize.""" - super(SelectorNth, self).__init__( + super().__init__( a=a, n=n, b=b, @@ -271,24 +325,24 @@ class SelectorLang(Immutable): __slots__ = ("languages", "_hash",) - def __init__(self, languages): + languages: Tuple[str, ...] + + def __init__(self, languages: Iterable[str]): """Initialize.""" - super(SelectorLang, self).__init__( - languages=tuple(languages) - ) + super().__init__(languages=tuple(languages)) - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Iterator.""" return iter(self.languages) - def __len__(self): # pragma: no cover + def __len__(self) -> int: # pragma: no cover """Length.""" return len(self.languages) - def __getitem__(self, index): # pragma: no cover + def __getitem__(self, index: int) -> str: # pragma: no cover """Get item.""" return self.languages[index] @@ -299,36 +353,45 @@ class SelectorList(Immutable): __slots__ = ("selectors", "is_not", "is_html", "_hash") - def __init__(self, selectors=tuple(), is_not=False, is_html=False): + selectors: Tuple[Union['Selector', 'SelectorNull'], ...] 
class SelectorList(Immutable):
    """Tuple-backed list of selectors with `is_not` and `is_html` flags."""

    __slots__ = ("selectors", "is_not", "is_html", "_hash")

    selectors: Tuple[Union['Selector', 'SelectorNull'], ...]
    is_not: bool
    is_html: bool

    def __init__(
        self,
        selectors: Optional[Iterable[Union['Selector', 'SelectorNull']]] = None,
        is_not: bool = False,
        is_html: bool = False
    ) -> None:
        """Freeze `selectors` into a tuple (empty when omitted) and store the flags."""

        if selectors is None:
            members = ()
        else:
            members = tuple(selectors)

        super().__init__(selectors=members, is_not=is_not, is_html=is_html)

    def __iter__(self) -> Iterator[Union['Selector', 'SelectorNull']]:
        """Iterate over the contained selectors."""

        return iter(self.selectors)

    def __len__(self) -> int:
        """Return how many selectors are contained."""

        return len(self.selectors)

    def __getitem__(self, index: int) -> Union['Selector', 'SelectorNull']:
        """Return the selector at `index`."""

        return self.selectors[index]


def _pickle(p: Any) -> Any:
    """Return `(p.__base__(), values)` for `copyreg`, skipping the last slot."""

    # The final entry of `__slots__` is left out of the values passed back.
    slot_names = p.__slots__[:-1]
    return p.__base__(), tuple(getattr(p, slot) for slot in slot_names)


def pickle_register(obj: Any) -> None:
    """Register `_pickle` with `copyreg` as the reducer for `obj`."""

    copyreg.pickle(obj, _pickle)
"""
Format a pretty string of a `SoupSieve` object for easy debugging.

This won't necessarily support all types and such, and definitely
not support custom outputs.

It is mainly geared towards our types as the `SelectorList`
object is a beast to look at without some indentation and newlines.
The format and various output types are fairly well known (though it
hasn't been tested extensively to make sure we aren't missing corners).

Example:

```
>>> import soupsieve as sv
>>> sv.compile('this > that.class[name=value]').selectors.pretty()
SelectorList(
    selectors=(
        Selector(
            tag=SelectorTag(
                name='that',
                prefix=None),
            ids=(),
            classes=(
                'class',
                ),
            attributes=(
                SelectorAttribute(
                    attribute='name',
                    prefix='',
                    pattern=re.compile(
                        '^value$'),
                    xml_type_pattern=None),
                ),
            nth=(),
            selectors=(),
            relation=SelectorList(
                selectors=(
                    Selector(
                        tag=SelectorTag(
                            name='this',
                            prefix=None),
                        ids=(),
                        classes=(),
                        attributes=(),
                        nth=(),
                        selectors=(),
                        relation=SelectorList(
                            selectors=(),
                            is_not=False,
                            is_html=False),
                        rel_type='>',
                        contains=(),
                        lang=(),
                        flags=0),
                    ),
                is_not=False,
                is_html=False),
            rel_type=None,
            contains=(),
            lang=(),
            flags=0),
        ),
    is_not=False,
    is_html=False)
```
"""
import re
from typing import Any

RE_CLASS = re.compile(r'(?i)[a-z_][_a-z\d\.]+\(')
RE_PARAM = re.compile(r'(?i)[_a-z][_a-z\d]+=')
RE_EMPTY = re.compile(r'\(\)|\[\]|\{\}')
RE_LSTRT = re.compile(r'\[')
RE_DSTRT = re.compile(r'\{')
RE_TSTRT = re.compile(r'\(')
RE_LEND = re.compile(r'\]')
RE_DEND = re.compile(r'\}')
RE_TEND = re.compile(r'\)')
RE_INT = re.compile(r'\d+')
RE_KWORD = re.compile(r'(?i)[_a-z][_a-z\d]+')
RE_DQSTR = re.compile(r'"(?:\\.|[^"\\])*"')
RE_SQSTR = re.compile(r"'(?:\\.|[^'\\])*'")
RE_SEP = re.compile(r'\s*(,)\s*')
RE_DSEP = re.compile(r'\s*(:)\s*')

# Tried in insertion order at every position; the first pattern that matches wins.
TOKENS = {
    'class': RE_CLASS,
    'param': RE_PARAM,
    'empty': RE_EMPTY,
    'lstrt': RE_LSTRT,
    'dstrt': RE_DSTRT,
    'tstrt': RE_TSTRT,
    'lend': RE_LEND,
    'dend': RE_DEND,
    'tend': RE_TEND,
    'sqstr': RE_SQSTR,
    'sep': RE_SEP,
    'dsep': RE_DSEP,
    'int': RE_INT,
    'kword': RE_KWORD,
    'dqstr': RE_DQSTR
}

# Token names grouped by the layout action `pretty` applies to them.
_OPENERS = frozenset(('class', 'lstrt', 'dstrt', 'tstrt'))
_CLOSERS = frozenset(('lend', 'dend', 'tend'))
_VERBATIM = frozenset(('param', 'int', 'kword', 'sqstr', 'dqstr', 'empty'))


def pretty(obj: Any) -> str:  # pragma: no cover
    """
    Make the object output string pretty.

    `str(obj)` is scanned left to right with the `TOKENS` table:

    - an opening token (`name(`, `[`, `{`, `(`) is followed by a newline and
      the indentation grows by four spaces;
    - a closing token (`]`, `}`, `)`) shrinks the indentation by four spaces;
    - a comma is followed by a newline at the current indentation;
    - a colon is followed by a single space;
    - parameters, integers, keywords, strings and empty brackets are copied as is.

    A character that no token recognizes (for example the one-letter name in
    `a=2`, or a minus sign) is copied through unchanged, so the scan always
    advances and terminates.

    Returns the re-flowed text.
    """

    sel = str(obj)
    length = len(sel)
    index = 0
    indent = 0
    output = []

    while index < length:
        for name, pattern in TOKENS.items():
            m = pattern.match(sel, index)
            if not m:
                continue

            index = m.end(0)
            if name in _OPENERS:
                indent += 4
                output.append('{}\n{}'.format(m.group(0), " " * indent))
            elif name in _VERBATIM:
                output.append(m.group(0))
            elif name in _CLOSERS:
                indent -= 4
                output.append(m.group(0))
            elif name == 'sep':
                output.append('{}\n{}'.format(m.group(1), " " * indent))
            elif name == 'dsep':
                output.append('{} '.format(m.group(1)))
            break
        else:
            # No token matched here: emit the character and step past it.
            # Without this the index never moves and the loop cannot end.
            output.append(sel[index])
            index += 1

    return ''.join(output)
def deprecated(message: str, stacklevel: int = 2) -> Callable[..., Any]:  # pragma: no cover
    """
    Emit a `DeprecationWarning` when the wrapped function/method is called.

    The warning text is "'<function name>' is deprecated. <message>" and
    `stacklevel` is forwarded to `warnings.warn`. The wrapped callable still
    runs and its result is returned unchanged.

    Usage:

        @deprecated("This method will be removed in version X; use Y instead.")
        def some_method():
            pass
    """

    def _wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
        # `wraps` copies the metadata of `func` (such as its name) onto the replacement.
        @wraps(func)
        def _deprecated_func(*args: Any, **kwargs: Any) -> Any:
            warnings.warn(
                f"'{func.__name__}' is deprecated. {message}",
                category=DeprecationWarning,
                stacklevel=stacklevel
            )
            return func(*args, **kwargs)
        return _deprecated_func
    return _wrapper