Update soupsieve==2.3.1

This commit is contained in:
JonnyWong16 2021-11-28 14:13:48 -08:00
parent dcfd8abddd
commit 36b55398a8
No known key found for this signature in database
GPG key ID: B1F1F9807184697A
8 changed files with 791 additions and 375 deletions

View file

@ -30,6 +30,8 @@ from . import css_parser as cp
from . import css_match as cm from . import css_match as cm
from . import css_types as ct from . import css_types as ct
from .util import DEBUG, SelectorSyntaxError # noqa: F401 from .util import DEBUG, SelectorSyntaxError # noqa: F401
import bs4 # type: ignore[import]
from typing import Dict, Optional, Any, List, Iterator, Iterable
__all__ = ( __all__ = (
'DEBUG', 'SelectorSyntaxError', 'SoupSieve', 'DEBUG', 'SelectorSyntaxError', 'SoupSieve',
@ -40,15 +42,18 @@ __all__ = (
SoupSieve = cm.SoupSieve SoupSieve = cm.SoupSieve
def compile(pattern, namespaces=None, flags=0, **kwargs): # noqa: A001 def compile( # noqa: A001
pattern: str,
namespaces: Optional[Dict[str, str]] = None,
flags: int = 0,
*,
custom: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> cm.SoupSieve:
"""Compile CSS pattern.""" """Compile CSS pattern."""
if namespaces is not None: ns = ct.Namespaces(namespaces) if namespaces is not None else namespaces # type: Optional[ct.Namespaces]
namespaces = ct.Namespaces(namespaces) cs = ct.CustomSelectors(custom) if custom is not None else custom # type: Optional[ct.CustomSelectors]
custom = kwargs.get('custom')
if custom is not None:
custom = ct.CustomSelectors(custom)
if isinstance(pattern, SoupSieve): if isinstance(pattern, SoupSieve):
if flags: if flags:
@ -59,53 +64,103 @@ def compile(pattern, namespaces=None, flags=0, **kwargs): # noqa: A001
raise ValueError("Cannot process 'custom' argument on a compiled selector list") raise ValueError("Cannot process 'custom' argument on a compiled selector list")
return pattern return pattern
return cp._cached_css_compile(pattern, namespaces, custom, flags) return cp._cached_css_compile(pattern, ns, cs, flags)
def purge(): def purge() -> None:
"""Purge cached patterns.""" """Purge cached patterns."""
cp._purge_cache() cp._purge_cache()
def closest(select, tag, namespaces=None, flags=0, **kwargs): def closest(
select: str,
tag: 'bs4.Tag',
namespaces: Optional[Dict[str, str]] = None,
flags: int = 0,
*,
custom: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> 'bs4.Tag':
"""Match closest ancestor.""" """Match closest ancestor."""
return compile(select, namespaces, flags, **kwargs).closest(tag) return compile(select, namespaces, flags, **kwargs).closest(tag)
def match(select, tag, namespaces=None, flags=0, **kwargs): def match(
select: str,
tag: 'bs4.Tag',
namespaces: Optional[Dict[str, str]] = None,
flags: int = 0,
*,
custom: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> bool:
"""Match node.""" """Match node."""
return compile(select, namespaces, flags, **kwargs).match(tag) return compile(select, namespaces, flags, **kwargs).match(tag)
def filter(select, iterable, namespaces=None, flags=0, **kwargs): # noqa: A001 def filter( # noqa: A001
select: str,
iterable: Iterable['bs4.Tag'],
namespaces: Optional[Dict[str, str]] = None,
flags: int = 0,
*,
custom: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> List['bs4.Tag']:
"""Filter list of nodes.""" """Filter list of nodes."""
return compile(select, namespaces, flags, **kwargs).filter(iterable) return compile(select, namespaces, flags, **kwargs).filter(iterable)
def select_one(select, tag, namespaces=None, flags=0, **kwargs): def select_one(
select: str,
tag: 'bs4.Tag',
namespaces: Optional[Dict[str, str]] = None,
flags: int = 0,
*,
custom: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> 'bs4.Tag':
"""Select a single tag.""" """Select a single tag."""
return compile(select, namespaces, flags, **kwargs).select_one(tag) return compile(select, namespaces, flags, **kwargs).select_one(tag)
def select(select, tag, namespaces=None, limit=0, flags=0, **kwargs): def select(
select: str,
tag: 'bs4.Tag',
namespaces: Optional[Dict[str, str]] = None,
limit: int = 0,
flags: int = 0,
*,
custom: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> List['bs4.Tag']:
"""Select the specified tags.""" """Select the specified tags."""
return compile(select, namespaces, flags, **kwargs).select(tag, limit) return compile(select, namespaces, flags, **kwargs).select(tag, limit)
def iselect(select, tag, namespaces=None, limit=0, flags=0, **kwargs): def iselect(
select: str,
tag: 'bs4.Tag',
namespaces: Optional[Dict[str, str]] = None,
limit: int = 0,
flags: int = 0,
*,
custom: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> Iterator['bs4.Tag']:
"""Iterate the specified tags.""" """Iterate the specified tags."""
for el in compile(select, namespaces, flags, **kwargs).iselect(tag, limit): for el in compile(select, namespaces, flags, **kwargs).iselect(tag, limit):
yield el yield el
def escape(ident): def escape(ident: str) -> str:
"""Escape identifier.""" """Escape identifier."""
return cp.escape(ident) return cp.escape(ident)

View file

@ -79,7 +79,11 @@ class Version(namedtuple("Version", ["major", "minor", "micro", "release", "pre"
""" """
def __new__(cls, major, minor, micro, release="final", pre=0, post=0, dev=0): def __new__(
cls,
major: int, minor: int, micro: int, release: str = "final",
pre: int = 0, post: int = 0, dev: int = 0
) -> "Version":
"""Validate version info.""" """Validate version info."""
# Ensure all parts are positive integers. # Ensure all parts are positive integers.
@ -115,27 +119,27 @@ class Version(namedtuple("Version", ["major", "minor", "micro", "release", "pre"
return super(Version, cls).__new__(cls, major, minor, micro, release, pre, post, dev) return super(Version, cls).__new__(cls, major, minor, micro, release, pre, post, dev)
def _is_pre(self): def _is_pre(self) -> bool:
"""Is prerelease.""" """Is prerelease."""
return self.pre > 0 return bool(self.pre > 0)
def _is_dev(self): def _is_dev(self) -> bool:
"""Is development.""" """Is development."""
return bool(self.release < "alpha") return bool(self.release < "alpha")
def _is_post(self): def _is_post(self) -> bool:
"""Is post.""" """Is post."""
return self.post > 0 return bool(self.post > 0)
def _get_dev_status(self): # pragma: no cover def _get_dev_status(self) -> str: # pragma: no cover
"""Get development status string.""" """Get development status string."""
return DEV_STATUS[self.release] return DEV_STATUS[self.release]
def _get_canonical(self): def _get_canonical(self) -> str:
"""Get the canonical output string.""" """Get the canonical output string."""
# Assemble major, minor, micro version and append `pre`, `post`, or `dev` if needed.. # Assemble major, minor, micro version and append `pre`, `post`, or `dev` if needed..
@ -153,7 +157,7 @@ class Version(namedtuple("Version", ["major", "minor", "micro", "release", "pre"
return ver return ver
def parse_version(ver): def parse_version(ver: str) -> Version:
"""Parse version into a comparable Version tuple.""" """Parse version into a comparable Version tuple."""
m = RE_VER.match(ver) m = RE_VER.match(ver)
@ -188,5 +192,5 @@ def parse_version(ver):
return Version(major, minor, micro, release, pre, post, dev) return Version(major, minor, micro, release, pre, post, dev)
__version_info__ = Version(2, 2, 1, "final") __version_info__ = Version(2, 3, 1, "final")
__version__ = __version_info__._get_canonical() __version__ = __version_info__._get_canonical()

View file

@ -2,11 +2,10 @@
from datetime import datetime from datetime import datetime
from . import util from . import util
import re import re
from .import css_types as ct from . import css_types as ct
import unicodedata import unicodedata
from collections.abc import Sequence import bs4 # type: ignore[import]
from typing import Iterator, Iterable, List, Any, Optional, Tuple, Union, Dict, Callable, Sequence, cast
import bs4
# Empty tag pattern (whitespace okay) # Empty tag pattern (whitespace okay)
RE_NOT_EMPTY = re.compile('[^ \t\r\n\f]') RE_NOT_EMPTY = re.compile('[^ \t\r\n\f]')
@ -56,7 +55,7 @@ FEB_LEAP_MONTH = 29
DAYS_IN_WEEK = 7 DAYS_IN_WEEK = 7
class _FakeParent(object): class _FakeParent:
""" """
Fake parent class. Fake parent class.
@ -65,22 +64,22 @@ class _FakeParent(object):
fake parent so we can traverse the root element as a child. fake parent so we can traverse the root element as a child.
""" """
def __init__(self, element): def __init__(self, element: 'bs4.Tag') -> None:
"""Initialize.""" """Initialize."""
self.contents = [element] self.contents = [element]
def __len__(self): def __len__(self) -> 'bs4.PageElement':
"""Length.""" """Length."""
return len(self.contents) return len(self.contents)
class _DocumentNav(object): class _DocumentNav:
"""Navigate a Beautiful Soup document.""" """Navigate a Beautiful Soup document."""
@classmethod @classmethod
def assert_valid_input(cls, tag): def assert_valid_input(cls, tag: Any) -> None:
"""Check if valid input tag or document.""" """Check if valid input tag or document."""
# Fail on unexpected types. # Fail on unexpected types.
@ -88,64 +87,67 @@ class _DocumentNav(object):
raise TypeError("Expected a BeautifulSoup 'Tag', but instead recieved type {}".format(type(tag))) raise TypeError("Expected a BeautifulSoup 'Tag', but instead recieved type {}".format(type(tag)))
@staticmethod @staticmethod
def is_doc(obj): def is_doc(obj: 'bs4.Tag') -> bool:
"""Is `BeautifulSoup` object.""" """Is `BeautifulSoup` object."""
return isinstance(obj, bs4.BeautifulSoup) return isinstance(obj, bs4.BeautifulSoup)
@staticmethod @staticmethod
def is_tag(obj): def is_tag(obj: 'bs4.PageElement') -> bool:
"""Is tag.""" """Is tag."""
return isinstance(obj, bs4.Tag) return isinstance(obj, bs4.Tag)
@staticmethod @staticmethod
def is_declaration(obj): # pragma: no cover def is_declaration(obj: 'bs4.PageElement') -> bool: # pragma: no cover
"""Is declaration.""" """Is declaration."""
return isinstance(obj, bs4.Declaration) return isinstance(obj, bs4.Declaration)
@staticmethod @staticmethod
def is_cdata(obj): def is_cdata(obj: 'bs4.PageElement') -> bool:
"""Is CDATA.""" """Is CDATA."""
return isinstance(obj, bs4.CData) return isinstance(obj, bs4.CData)
@staticmethod @staticmethod
def is_processing_instruction(obj): # pragma: no cover def is_processing_instruction(obj: 'bs4.PageElement') -> bool: # pragma: no cover
"""Is processing instruction.""" """Is processing instruction."""
return isinstance(obj, bs4.ProcessingInstruction) return isinstance(obj, bs4.ProcessingInstruction)
@staticmethod @staticmethod
def is_navigable_string(obj): def is_navigable_string(obj: 'bs4.PageElement') -> bool:
"""Is navigable string.""" """Is navigable string."""
return isinstance(obj, bs4.NavigableString) return isinstance(obj, bs4.NavigableString)
@staticmethod @staticmethod
def is_special_string(obj): def is_special_string(obj: 'bs4.PageElement') -> bool:
"""Is special string.""" """Is special string."""
return isinstance(obj, (bs4.Comment, bs4.Declaration, bs4.CData, bs4.ProcessingInstruction, bs4.Doctype)) return isinstance(obj, (bs4.Comment, bs4.Declaration, bs4.CData, bs4.ProcessingInstruction, bs4.Doctype))
@classmethod @classmethod
def is_content_string(cls, obj): def is_content_string(cls, obj: 'bs4.PageElement') -> bool:
"""Check if node is content string.""" """Check if node is content string."""
return cls.is_navigable_string(obj) and not cls.is_special_string(obj) return cls.is_navigable_string(obj) and not cls.is_special_string(obj)
@staticmethod @staticmethod
def create_fake_parent(el): def create_fake_parent(el: 'bs4.Tag') -> _FakeParent:
"""Create fake parent for a given element.""" """Create fake parent for a given element."""
return _FakeParent(el) return _FakeParent(el)
@staticmethod @staticmethod
def is_xml_tree(el): def is_xml_tree(el: 'bs4.Tag') -> bool:
"""Check if element (or document) is from a XML tree.""" """Check if element (or document) is from a XML tree."""
return el._is_xml return bool(el._is_xml)
def is_iframe(self, el): def is_iframe(self, el: 'bs4.Tag') -> bool:
"""Check if element is an `iframe`.""" """Check if element is an `iframe`."""
return ((el.name if self.is_xml_tree(el) else util.lower(el.name)) == 'iframe') and self.is_html_tag(el) return bool(
((el.name if self.is_xml_tree(el) else util.lower(el.name)) == 'iframe') and
self.is_html_tag(el) # type: ignore[attr-defined]
)
def is_root(self, el): def is_root(self, el: 'bs4.Tag') -> bool:
""" """
Return whether element is a root element. Return whether element is a root element.
@ -153,19 +155,26 @@ class _DocumentNav(object):
and we check if it is the root element under an `iframe`. and we check if it is the root element under an `iframe`.
""" """
root = self.root and self.root is el root = self.root and self.root is el # type: ignore[attr-defined]
if not root: if not root:
parent = self.get_parent(el) parent = self.get_parent(el)
root = parent is not None and self.is_html and self.is_iframe(parent) root = parent is not None and self.is_html and self.is_iframe(parent) # type: ignore[attr-defined]
return root return root
def get_contents(self, el, no_iframe=False): def get_contents(self, el: 'bs4.Tag', no_iframe: bool = False) -> Iterator['bs4.PageElement']:
"""Get contents or contents in reverse.""" """Get contents or contents in reverse."""
if not no_iframe or not self.is_iframe(el): if not no_iframe or not self.is_iframe(el):
for content in el.contents: for content in el.contents:
yield content yield content
def get_children(self, el, start=None, reverse=False, tags=True, no_iframe=False): def get_children(
self,
el: 'bs4.Tag',
start: Optional[int] = None,
reverse: bool = False,
tags: bool = True,
no_iframe: bool = False
) -> Iterator['bs4.PageElement']:
"""Get children.""" """Get children."""
if not no_iframe or not self.is_iframe(el): if not no_iframe or not self.is_iframe(el):
@ -184,7 +193,12 @@ class _DocumentNav(object):
if not tags or self.is_tag(node): if not tags or self.is_tag(node):
yield node yield node
def get_descendants(self, el, tags=True, no_iframe=False): def get_descendants(
self,
el: 'bs4.Tag',
tags: bool = True,
no_iframe: bool = False
) -> Iterator['bs4.PageElement']:
"""Get descendants.""" """Get descendants."""
if not no_iframe or not self.is_iframe(el): if not no_iframe or not self.is_iframe(el):
@ -215,7 +229,7 @@ class _DocumentNav(object):
if not tags or is_tag: if not tags or is_tag:
yield child yield child
def get_parent(self, el, no_iframe=False): def get_parent(self, el: 'bs4.Tag', no_iframe: bool = False) -> 'bs4.Tag':
"""Get parent.""" """Get parent."""
parent = el.parent parent = el.parent
@ -224,25 +238,25 @@ class _DocumentNav(object):
return parent return parent
@staticmethod @staticmethod
def get_tag_name(el): def get_tag_name(el: 'bs4.Tag') -> Optional[str]:
"""Get tag.""" """Get tag."""
return el.name return cast(Optional[str], el.name)
@staticmethod @staticmethod
def get_prefix_name(el): def get_prefix_name(el: 'bs4.Tag') -> Optional[str]:
"""Get prefix.""" """Get prefix."""
return el.prefix return cast(Optional[str], el.prefix)
@staticmethod @staticmethod
def get_uri(el): def get_uri(el: 'bs4.Tag') -> Optional[str]:
"""Get namespace `URI`.""" """Get namespace `URI`."""
return el.namespace return cast(Optional[str], el.namespace)
@classmethod @classmethod
def get_next(cls, el, tags=True): def get_next(cls, el: 'bs4.Tag', tags: bool = True) -> 'bs4.PageElement':
"""Get next sibling tag.""" """Get next sibling tag."""
sibling = el.next_sibling sibling = el.next_sibling
@ -251,7 +265,7 @@ class _DocumentNav(object):
return sibling return sibling
@classmethod @classmethod
def get_previous(cls, el, tags=True): def get_previous(cls, el: 'bs4.Tag', tags: bool = True) -> 'bs4.PageElement':
"""Get previous sibling tag.""" """Get previous sibling tag."""
sibling = el.previous_sibling sibling = el.previous_sibling
@ -260,7 +274,7 @@ class _DocumentNav(object):
return sibling return sibling
@staticmethod @staticmethod
def has_html_ns(el): def has_html_ns(el: 'bs4.Tag') -> bool:
""" """
Check if element has an HTML namespace. Check if element has an HTML namespace.
@ -269,16 +283,16 @@ class _DocumentNav(object):
""" """
ns = getattr(el, 'namespace') if el else None ns = getattr(el, 'namespace') if el else None
return ns and ns == NS_XHTML return bool(ns and ns == NS_XHTML)
@staticmethod @staticmethod
def split_namespace(el, attr_name): def split_namespace(el: 'bs4.Tag', attr_name: str) -> Tuple[Optional[str], Optional[str]]:
"""Return namespace and attribute name without the prefix.""" """Return namespace and attribute name without the prefix."""
return getattr(attr_name, 'namespace', None), getattr(attr_name, 'name', None) return getattr(attr_name, 'namespace', None), getattr(attr_name, 'name', None)
@classmethod @classmethod
def normalize_value(cls, value): def normalize_value(cls, value: Any) -> Union[str, Sequence[str]]:
"""Normalize the value to be a string or list of strings.""" """Normalize the value to be a string or list of strings."""
# Treat `None` as empty string. # Treat `None` as empty string.
@ -297,20 +311,26 @@ class _DocumentNav(object):
if isinstance(value, Sequence): if isinstance(value, Sequence):
new_value = [] new_value = []
for v in value: for v in value:
if isinstance(v, Sequence): if not isinstance(v, (str, bytes)) and isinstance(v, Sequence):
# This is most certainly a user error and will crash and burn later, # This is most certainly a user error and will crash and burn later.
# but to avoid excessive recursion, kick out now. # To keep things working, we'll do what we do with all objects,
new_value.append(v) # And convert them to strings.
new_value.append(str(v))
else: else:
# Convert the child to a string # Convert the child to a string
new_value.append(cls.normalize_value(v)) new_value.append(cast(str, cls.normalize_value(v)))
return new_value return new_value
# Try and make anything else a string # Try and make anything else a string
return str(value) return str(value)
@classmethod @classmethod
def get_attribute_by_name(cls, el, name, default=None): def get_attribute_by_name(
cls,
el: 'bs4.Tag',
name: str,
default: Optional[Union[str, Sequence[str]]] = None
) -> Optional[Union[str, Sequence[str]]]:
"""Get attribute by name.""" """Get attribute by name."""
value = default value = default
@ -327,39 +347,39 @@ class _DocumentNav(object):
return value return value
@classmethod @classmethod
def iter_attributes(cls, el): def iter_attributes(cls, el: 'bs4.Tag') -> Iterator[Tuple[str, Optional[Union[str, Sequence[str]]]]]:
"""Iterate attributes.""" """Iterate attributes."""
for k, v in el.attrs.items(): for k, v in el.attrs.items():
yield k, cls.normalize_value(v) yield k, cls.normalize_value(v)
@classmethod @classmethod
def get_classes(cls, el): def get_classes(cls, el: 'bs4.Tag') -> Sequence[str]:
"""Get classes.""" """Get classes."""
classes = cls.get_attribute_by_name(el, 'class', []) classes = cls.get_attribute_by_name(el, 'class', [])
if isinstance(classes, str): if isinstance(classes, str):
classes = RE_NOT_WS.findall(classes) classes = RE_NOT_WS.findall(classes)
return classes return cast(Sequence[str], classes)
def get_text(self, el, no_iframe=False): def get_text(self, el: 'bs4.Tag', no_iframe: bool = False) -> str:
"""Get text.""" """Get text."""
return ''.join( return ''.join(
[node for node in self.get_descendants(el, tags=False, no_iframe=no_iframe) if self.is_content_string(node)] [node for node in self.get_descendants(el, tags=False, no_iframe=no_iframe) if self.is_content_string(node)]
) )
def get_own_text(self, el, no_iframe=False): def get_own_text(self, el: 'bs4.Tag', no_iframe: bool = False) -> List[str]:
"""Get Own Text.""" """Get Own Text."""
return [node for node in self.get_contents(el, no_iframe=no_iframe) if self.is_content_string(node)] return [node for node in self.get_contents(el, no_iframe=no_iframe) if self.is_content_string(node)]
class Inputs(object): class Inputs:
"""Class for parsing and validating input items.""" """Class for parsing and validating input items."""
@staticmethod @staticmethod
def validate_day(year, month, day): def validate_day(year: int, month: int, day: int) -> bool:
"""Validate day.""" """Validate day."""
max_days = LONG_MONTH max_days = LONG_MONTH
@ -370,7 +390,7 @@ class Inputs(object):
return 1 <= day <= max_days return 1 <= day <= max_days
@staticmethod @staticmethod
def validate_week(year, week): def validate_week(year: int, week: int) -> bool:
"""Validate week.""" """Validate week."""
max_week = datetime.strptime("{}-{}-{}".format(12, 31, year), "%m-%d-%Y").isocalendar()[1] max_week = datetime.strptime("{}-{}-{}".format(12, 31, year), "%m-%d-%Y").isocalendar()[1]
@ -379,34 +399,36 @@ class Inputs(object):
return 1 <= week <= max_week return 1 <= week <= max_week
@staticmethod @staticmethod
def validate_month(month): def validate_month(month: int) -> bool:
"""Validate month.""" """Validate month."""
return 1 <= month <= 12 return 1 <= month <= 12
@staticmethod @staticmethod
def validate_year(year): def validate_year(year: int) -> bool:
"""Validate year.""" """Validate year."""
return 1 <= year return 1 <= year
@staticmethod @staticmethod
def validate_hour(hour): def validate_hour(hour: int) -> bool:
"""Validate hour.""" """Validate hour."""
return 0 <= hour <= 23 return 0 <= hour <= 23
@staticmethod @staticmethod
def validate_minutes(minutes): def validate_minutes(minutes: int) -> bool:
"""Validate minutes.""" """Validate minutes."""
return 0 <= minutes <= 59 return 0 <= minutes <= 59
@classmethod @classmethod
def parse_value(cls, itype, value): def parse_value(cls, itype: str, value: Optional[str]) -> Optional[Tuple[float, ...]]:
"""Parse the input value.""" """Parse the input value."""
parsed = None parsed = None # type: Optional[Tuple[float, ...]]
if value is None:
return value
if itype == "date": if itype == "date":
m = RE_DATE.match(value) m = RE_DATE.match(value)
if m: if m:
@ -452,23 +474,29 @@ class Inputs(object):
elif itype in ("number", "range"): elif itype in ("number", "range"):
m = RE_NUM.match(value) m = RE_NUM.match(value)
if m: if m:
parsed = float(m.group('value')) parsed = (float(m.group('value')),)
return parsed return parsed
class _Match(object): class CSSMatch(_DocumentNav):
"""Perform CSS matching.""" """Perform CSS matching."""
def __init__(self, selectors, scope, namespaces, flags): def __init__(
self,
selectors: ct.SelectorList,
scope: 'bs4.Tag',
namespaces: Optional[ct.Namespaces],
flags: int
) -> None:
"""Initialize.""" """Initialize."""
self.assert_valid_input(scope) self.assert_valid_input(scope)
self.tag = scope self.tag = scope
self.cached_meta_lang = [] self.cached_meta_lang = [] # type: List[Tuple[str, str]]
self.cached_default_forms = [] self.cached_default_forms = [] # type: List[Tuple['bs4.Tag', 'bs4.Tag']]
self.cached_indeterminate_forms = [] self.cached_indeterminate_forms = [] # type: List[Tuple['bs4.Tag', str, bool]]
self.selectors = selectors self.selectors = selectors
self.namespaces = {} if namespaces is None else namespaces self.namespaces = {} if namespaces is None else namespaces # type: Union[ct.Namespaces, Dict[str, str]]
self.flags = flags self.flags = flags
self.iframe_restrict = False self.iframe_restrict = False
@ -494,12 +522,12 @@ class _Match(object):
self.is_xml = self.is_xml_tree(doc) self.is_xml = self.is_xml_tree(doc)
self.is_html = not self.is_xml or self.has_html_namespace self.is_html = not self.is_xml or self.has_html_namespace
def supports_namespaces(self): def supports_namespaces(self) -> bool:
"""Check if namespaces are supported in the HTML type.""" """Check if namespaces are supported in the HTML type."""
return self.is_xml or self.has_html_namespace return self.is_xml or self.has_html_namespace
def get_tag_ns(self, el): def get_tag_ns(self, el: 'bs4.Tag') -> str:
"""Get tag namespace.""" """Get tag namespace."""
if self.supports_namespaces(): if self.supports_namespaces():
@ -511,24 +539,24 @@ class _Match(object):
namespace = NS_XHTML namespace = NS_XHTML
return namespace return namespace
def is_html_tag(self, el): def is_html_tag(self, el: 'bs4.Tag') -> bool:
"""Check if tag is in HTML namespace.""" """Check if tag is in HTML namespace."""
return self.get_tag_ns(el) == NS_XHTML return self.get_tag_ns(el) == NS_XHTML
def get_tag(self, el): def get_tag(self, el: 'bs4.Tag') -> Optional[str]:
"""Get tag.""" """Get tag."""
name = self.get_tag_name(el) name = self.get_tag_name(el)
return util.lower(name) if name is not None and not self.is_xml else name return util.lower(name) if name is not None and not self.is_xml else name
def get_prefix(self, el): def get_prefix(self, el: 'bs4.Tag') -> Optional[str]:
"""Get prefix.""" """Get prefix."""
prefix = self.get_prefix_name(el) prefix = self.get_prefix_name(el)
return util.lower(prefix) if prefix is not None and not self.is_xml else prefix return util.lower(prefix) if prefix is not None and not self.is_xml else prefix
def find_bidi(self, el): def find_bidi(self, el: 'bs4.Tag') -> Optional[int]:
"""Get directionality from element text.""" """Get directionality from element text."""
for node in self.get_children(el, tags=False): for node in self.get_children(el, tags=False):
@ -564,7 +592,7 @@ class _Match(object):
return ct.SEL_DIR_LTR if bidi == 'L' else ct.SEL_DIR_RTL return ct.SEL_DIR_LTR if bidi == 'L' else ct.SEL_DIR_RTL
return None return None
def extended_language_filter(self, lang_range, lang_tag): def extended_language_filter(self, lang_range: str, lang_tag: str) -> bool:
"""Filter the language tags.""" """Filter the language tags."""
match = True match = True
@ -615,7 +643,12 @@ class _Match(object):
return match return match
def match_attribute_name(self, el, attr, prefix): def match_attribute_name(
self,
el: 'bs4.Tag',
attr: str,
prefix: Optional[str]
) -> Optional[Union[str, Sequence[str]]]:
"""Match attribute name and return value if it exists.""" """Match attribute name and return value if it exists."""
value = None value = None
@ -663,13 +696,13 @@ class _Match(object):
break break
return value return value
def match_namespace(self, el, tag): def match_namespace(self, el: 'bs4.Tag', tag: ct.SelectorTag) -> bool:
"""Match the namespace of the element.""" """Match the namespace of the element."""
match = True match = True
namespace = self.get_tag_ns(el) namespace = self.get_tag_ns(el)
default_namespace = self.namespaces.get('') default_namespace = self.namespaces.get('')
tag_ns = '' if tag.prefix is None else self.namespaces.get(tag.prefix, None) tag_ns = '' if tag.prefix is None else self.namespaces.get(tag.prefix)
# We must match the default namespace if one is not provided # We must match the default namespace if one is not provided
if tag.prefix is None and (default_namespace is not None and namespace != default_namespace): if tag.prefix is None and (default_namespace is not None and namespace != default_namespace):
match = False match = False
@ -684,27 +717,26 @@ class _Match(object):
match = False match = False
return match return match
def match_attributes(self, el, attributes): def match_attributes(self, el: 'bs4.Tag', attributes: Tuple[ct.SelectorAttribute, ...]) -> bool:
"""Match attributes.""" """Match attributes."""
match = True match = True
if attributes: if attributes:
for a in attributes: for a in attributes:
value = self.match_attribute_name(el, a.attribute, a.prefix) temp = self.match_attribute_name(el, a.attribute, a.prefix)
pattern = a.xml_type_pattern if self.is_xml and a.xml_type_pattern else a.pattern pattern = a.xml_type_pattern if self.is_xml and a.xml_type_pattern else a.pattern
if isinstance(value, list): if temp is None:
value = ' '.join(value)
if value is None:
match = False match = False
break break
elif pattern is None: value = temp if isinstance(temp, str) else ' '.join(temp)
if pattern is None:
continue continue
elif pattern.match(value) is None: elif pattern.match(value) is None:
match = False match = False
break break
return match return match
def match_tagname(self, el, tag): def match_tagname(self, el: 'bs4.Tag', tag: ct.SelectorTag) -> bool:
"""Match tag name.""" """Match tag name."""
name = (util.lower(tag.name) if not self.is_xml and tag.name is not None else tag.name) name = (util.lower(tag.name) if not self.is_xml and tag.name is not None else tag.name)
@ -713,7 +745,7 @@ class _Match(object):
name not in (self.get_tag(el), '*') name not in (self.get_tag(el), '*')
) )
def match_tag(self, el, tag): def match_tag(self, el: 'bs4.Tag', tag: Optional[ct.SelectorTag]) -> bool:
"""Match the tag.""" """Match the tag."""
match = True match = True
@ -725,10 +757,14 @@ class _Match(object):
match = False match = False
return match return match
def match_past_relations(self, el, relation): def match_past_relations(self, el: 'bs4.Tag', relation: ct.SelectorList) -> bool:
"""Match past relationship.""" """Match past relationship."""
found = False found = False
# I don't think this can ever happen, but it makes `mypy` happy
if isinstance(relation[0], ct.SelectorNull): # pragma: no cover
return found
if relation[0].rel_type == REL_PARENT: if relation[0].rel_type == REL_PARENT:
parent = self.get_parent(el, no_iframe=self.iframe_restrict) parent = self.get_parent(el, no_iframe=self.iframe_restrict)
while not found and parent: while not found and parent:
@ -749,21 +785,28 @@ class _Match(object):
found = self.match_selectors(sibling, relation) found = self.match_selectors(sibling, relation)
return found return found
def match_future_child(self, parent, relation, recursive=False): def match_future_child(self, parent: 'bs4.Tag', relation: ct.SelectorList, recursive: bool = False) -> bool:
"""Match future child.""" """Match future child."""
match = False match = False
children = self.get_descendants if recursive else self.get_children if recursive:
children = self.get_descendants # type: Callable[..., Iterator['bs4.Tag']]
else:
children = self.get_children
for child in children(parent, no_iframe=self.iframe_restrict): for child in children(parent, no_iframe=self.iframe_restrict):
match = self.match_selectors(child, relation) match = self.match_selectors(child, relation)
if match: if match:
break break
return match return match
def match_future_relations(self, el, relation): def match_future_relations(self, el: 'bs4.Tag', relation: ct.SelectorList) -> bool:
"""Match future relationship.""" """Match future relationship."""
found = False found = False
# I don't think this can ever happen, but it makes `mypy` happy
if isinstance(relation[0], ct.SelectorNull): # pragma: no cover
return found
if relation[0].rel_type == REL_HAS_PARENT: if relation[0].rel_type == REL_HAS_PARENT:
found = self.match_future_child(el, relation, True) found = self.match_future_child(el, relation, True)
elif relation[0].rel_type == REL_HAS_CLOSE_PARENT: elif relation[0].rel_type == REL_HAS_CLOSE_PARENT:
@ -779,11 +822,14 @@ class _Match(object):
found = self.match_selectors(sibling, relation) found = self.match_selectors(sibling, relation)
return found return found
def match_relations(self, el, relation): def match_relations(self, el: 'bs4.Tag', relation: ct.SelectorList) -> bool:
"""Match relationship to other elements.""" """Match relationship to other elements."""
found = False found = False
if isinstance(relation[0], ct.SelectorNull) or relation[0].rel_type is None:
return found
if relation[0].rel_type.startswith(':'): if relation[0].rel_type.startswith(':'):
found = self.match_future_relations(el, relation) found = self.match_future_relations(el, relation)
else: else:
@ -791,7 +837,7 @@ class _Match(object):
return found return found
def match_id(self, el, ids): def match_id(self, el: 'bs4.Tag', ids: Tuple[str, ...]) -> bool:
"""Match element's ID.""" """Match element's ID."""
found = True found = True
@ -801,7 +847,7 @@ class _Match(object):
break break
return found return found
def match_classes(self, el, classes): def match_classes(self, el: 'bs4.Tag', classes: Tuple[str, ...]) -> bool:
"""Match element's classes.""" """Match element's classes."""
current_classes = self.get_classes(el) current_classes = self.get_classes(el)
@ -812,7 +858,7 @@ class _Match(object):
break break
return found return found
def match_root(self, el): def match_root(self, el: 'bs4.Tag') -> bool:
"""Match element as root.""" """Match element as root."""
is_root = self.is_root(el) is_root = self.is_root(el)
@ -838,12 +884,12 @@ class _Match(object):
sibling = self.get_next(sibling, tags=False) sibling = self.get_next(sibling, tags=False)
return is_root return is_root
def match_scope(self, el): def match_scope(self, el: 'bs4.Tag') -> bool:
"""Match element as scope.""" """Match element as scope."""
return self.scope is el return self.scope is el
def match_nth_tag_type(self, el, child): def match_nth_tag_type(self, el: 'bs4.Tag', child: 'bs4.Tag') -> bool:
"""Match tag type for `nth` matches.""" """Match tag type for `nth` matches."""
return( return(
@ -851,7 +897,7 @@ class _Match(object):
(self.get_tag_ns(child) == self.get_tag_ns(el)) (self.get_tag_ns(child) == self.get_tag_ns(el))
) )
def match_nth(self, el, nth): def match_nth(self, el: 'bs4.Tag', nth: 'bs4.Tag') -> bool:
"""Match `nth` elements.""" """Match `nth` elements."""
matched = True matched = True
@ -952,7 +998,7 @@ class _Match(object):
break break
return matched return matched
def match_empty(self, el): def match_empty(self, el: 'bs4.Tag') -> bool:
"""Check if element is empty (if requested).""" """Check if element is empty (if requested)."""
is_empty = True is_empty = True
@ -965,7 +1011,7 @@ class _Match(object):
break break
return is_empty return is_empty
def match_subselectors(self, el, selectors): def match_subselectors(self, el: 'bs4.Tag', selectors: Tuple[ct.SelectorList, ...]) -> bool:
"""Match selectors.""" """Match selectors."""
match = True match = True
@ -974,11 +1020,11 @@ class _Match(object):
match = False match = False
return match return match
def match_contains(self, el, contains): def match_contains(self, el: 'bs4.Tag', contains: Tuple[ct.SelectorContains, ...]) -> bool:
"""Match element if it contains text.""" """Match element if it contains text."""
match = True match = True
content = None content = None # type: Optional[Union[str, Sequence[str]]]
for contain_list in contains: for contain_list in contains:
if content is None: if content is None:
if contain_list.own: if contain_list.own:
@ -1002,7 +1048,7 @@ class _Match(object):
match = False match = False
return match return match
def match_default(self, el): def match_default(self, el: 'bs4.Tag') -> bool:
"""Match default.""" """Match default."""
match = False match = False
@ -1035,19 +1081,19 @@ class _Match(object):
if name in ('input', 'button'): if name in ('input', 'button'):
v = self.get_attribute_by_name(child, 'type', '') v = self.get_attribute_by_name(child, 'type', '')
if v and util.lower(v) == 'submit': if v and util.lower(v) == 'submit':
self.cached_default_forms.append([form, child]) self.cached_default_forms.append((form, child))
if el is child: if el is child:
match = True match = True
break break
return match return match
def match_indeterminate(self, el): def match_indeterminate(self, el: 'bs4.Tag') -> bool:
"""Match default.""" """Match default."""
match = False match = False
name = self.get_attribute_by_name(el, 'name') name = cast(str, self.get_attribute_by_name(el, 'name'))
def get_parent_form(el): def get_parent_form(el: 'bs4.Tag') -> Optional['bs4.Tag']:
"""Find this input's form.""" """Find this input's form."""
form = None form = None
parent = self.get_parent(el, no_iframe=True) parent = self.get_parent(el, no_iframe=True)
@ -1098,11 +1144,11 @@ class _Match(object):
break break
if not checked: if not checked:
match = True match = True
self.cached_indeterminate_forms.append([form, name, match]) self.cached_indeterminate_forms.append((form, name, match))
return match return match
def match_lang(self, el, langs): def match_lang(self, el: 'bs4.Tag', langs: Tuple[ct.SelectorLang, ...]) -> bool:
"""Match languages.""" """Match languages."""
match = False match = False
@ -1169,26 +1215,26 @@ class _Match(object):
content = v content = v
if c_lang and content: if c_lang and content:
found_lang = content found_lang = content
self.cached_meta_lang.append((root, found_lang)) self.cached_meta_lang.append((cast(str, root), cast(str, found_lang)))
break break
if found_lang: if found_lang:
break break
if not found_lang: if not found_lang:
self.cached_meta_lang.append((root, False)) self.cached_meta_lang.append((cast(str, root), ''))
# If we determined a language, compare. # If we determined a language, compare.
if found_lang: if found_lang:
for patterns in langs: for patterns in langs:
match = False match = False
for pattern in patterns: for pattern in patterns:
if self.extended_language_filter(pattern, found_lang): if self.extended_language_filter(pattern, cast(str, found_lang)):
match = True match = True
if not match: if not match:
break break
return match return match
def match_dir(self, el, directionality): def match_dir(self, el: 'bs4.Tag', directionality: int) -> bool:
"""Check directionality.""" """Check directionality."""
# If we have to match both left and right, we can't match either. # If we have to match both left and right, we can't match either.
@ -1220,13 +1266,13 @@ class _Match(object):
# Auto handling for text inputs # Auto handling for text inputs
if ((is_input and itype in ('text', 'search', 'tel', 'url', 'email')) or is_textarea) and direction == 0: if ((is_input and itype in ('text', 'search', 'tel', 'url', 'email')) or is_textarea) and direction == 0:
if is_textarea: if is_textarea:
value = [] temp = []
for node in self.get_contents(el, no_iframe=True): for node in self.get_contents(el, no_iframe=True):
if self.is_content_string(node): if self.is_content_string(node):
value.append(node) temp.append(node)
value = ''.join(value) value = ''.join(temp)
else: else:
value = self.get_attribute_by_name(el, 'value', '') value = cast(str, self.get_attribute_by_name(el, 'value', ''))
if value: if value:
for c in value: for c in value:
bidi = unicodedata.bidirectional(c) bidi = unicodedata.bidirectional(c)
@ -1251,7 +1297,7 @@ class _Match(object):
# Match parents direction # Match parents direction
return self.match_dir(self.get_parent(el, no_iframe=True), directionality) return self.match_dir(self.get_parent(el, no_iframe=True), directionality)
def match_range(self, el, condition): def match_range(self, el: 'bs4.Tag', condition: int) -> bool:
""" """
Match range. Match range.
@ -1264,20 +1310,14 @@ class _Match(object):
out_of_range = False out_of_range = False
itype = util.lower(self.get_attribute_by_name(el, 'type')) itype = util.lower(self.get_attribute_by_name(el, 'type'))
mn = self.get_attribute_by_name(el, 'min', None) mn = Inputs.parse_value(itype, cast(str, self.get_attribute_by_name(el, 'min', None)))
if mn is not None: mx = Inputs.parse_value(itype, cast(str, self.get_attribute_by_name(el, 'max', None)))
mn = Inputs.parse_value(itype, mn)
mx = self.get_attribute_by_name(el, 'max', None)
if mx is not None:
mx = Inputs.parse_value(itype, mx)
# There is no valid min or max, so we cannot evaluate a range # There is no valid min or max, so we cannot evaluate a range
if mn is None and mx is None: if mn is None and mx is None:
return False return False
value = self.get_attribute_by_name(el, 'value', None) value = Inputs.parse_value(itype, cast(str, self.get_attribute_by_name(el, 'value', None)))
if value is not None:
value = Inputs.parse_value(itype, value)
if value is not None: if value is not None:
if itype in ("date", "datetime-local", "month", "week", "number", "range"): if itype in ("date", "datetime-local", "month", "week", "number", "range"):
if mn is not None and value < mn: if mn is not None and value < mn:
@ -1297,7 +1337,7 @@ class _Match(object):
return not out_of_range if condition & ct.SEL_IN_RANGE else out_of_range return not out_of_range if condition & ct.SEL_IN_RANGE else out_of_range
def match_defined(self, el): def match_defined(self, el: 'bs4.Tag') -> bool:
""" """
Match defined. Match defined.
@ -1313,12 +1353,14 @@ class _Match(object):
name = self.get_tag(el) name = self.get_tag(el)
return ( return (
name.find('-') == -1 or name is not None and (
name.find(':') != -1 or name.find('-') == -1 or
self.get_prefix(el) is not None name.find(':') != -1 or
self.get_prefix(el) is not None
)
) )
def match_placeholder_shown(self, el): def match_placeholder_shown(self, el: 'bs4.Tag') -> bool:
""" """
Match placeholder shown according to HTML spec. Match placeholder shown according to HTML spec.
@ -1333,7 +1375,7 @@ class _Match(object):
return match return match
def match_selectors(self, el, selectors): def match_selectors(self, el: 'bs4.Tag', selectors: ct.SelectorList) -> bool:
"""Check if element matches one of the selectors.""" """Check if element matches one of the selectors."""
match = False match = False
@ -1405,7 +1447,7 @@ class _Match(object):
if selector.flags & DIR_FLAGS and not self.match_dir(el, selector.flags & DIR_FLAGS): if selector.flags & DIR_FLAGS and not self.match_dir(el, selector.flags & DIR_FLAGS):
continue continue
# Validate that the tag contains the specified text. # Validate that the tag contains the specified text.
if not self.match_contains(el, selector.contains): if selector.contains and not self.match_contains(el, selector.contains):
continue continue
match = not is_not match = not is_not
break break
@ -1417,21 +1459,20 @@ class _Match(object):
return match return match
def select(self, limit=0): def select(self, limit: int = 0) -> Iterator['bs4.Tag']:
"""Match all tags under the targeted tag.""" """Match all tags under the targeted tag."""
if limit < 1: lim = None if limit < 1 else limit
limit = None
for child in self.get_descendants(self.tag): for child in self.get_descendants(self.tag):
if self.match(child): if self.match(child):
yield child yield child
if limit is not None: if lim is not None:
limit -= 1 lim -= 1
if limit < 1: if lim < 1:
break break
def closest(self): def closest(self) -> Optional['bs4.Tag']:
"""Match closest ancestor.""" """Match closest ancestor."""
current = self.tag current = self.tag
@ -1443,30 +1484,39 @@ class _Match(object):
current = self.get_parent(current) current = self.get_parent(current)
return closest return closest
def filter(self): # noqa A001 def filter(self) -> List['bs4.Tag']: # noqa A001
"""Filter tag's children.""" """Filter tag's children."""
return [tag for tag in self.get_contents(self.tag) if not self.is_navigable_string(tag) and self.match(tag)] return [tag for tag in self.get_contents(self.tag) if not self.is_navigable_string(tag) and self.match(tag)]
def match(self, el): def match(self, el: 'bs4.Tag') -> bool:
"""Match.""" """Match."""
return not self.is_doc(el) and self.is_tag(el) and self.match_selectors(el, self.selectors) return not self.is_doc(el) and self.is_tag(el) and self.match_selectors(el, self.selectors)
class CSSMatch(_DocumentNav, _Match):
"""The Beautiful Soup CSS match class."""
class SoupSieve(ct.Immutable): class SoupSieve(ct.Immutable):
"""Compiled Soup Sieve selector matching object.""" """Compiled Soup Sieve selector matching object."""
pattern: str
selectors: ct.SelectorList
namespaces: Optional[ct.Namespaces]
custom: Dict[str, str]
flags: int
__slots__ = ("pattern", "selectors", "namespaces", "custom", "flags", "_hash") __slots__ = ("pattern", "selectors", "namespaces", "custom", "flags", "_hash")
def __init__(self, pattern, selectors, namespaces, custom, flags): def __init__(
self,
pattern: str,
selectors: ct.SelectorList,
namespaces: Optional[ct.Namespaces],
custom: Optional[ct.CustomSelectors],
flags: int
):
"""Initialize.""" """Initialize."""
super(SoupSieve, self).__init__( super().__init__(
pattern=pattern, pattern=pattern,
selectors=selectors, selectors=selectors,
namespaces=namespaces, namespaces=namespaces,
@ -1474,17 +1524,17 @@ class SoupSieve(ct.Immutable):
flags=flags flags=flags
) )
def match(self, tag): def match(self, tag: 'bs4.Tag') -> bool:
"""Match.""" """Match."""
return CSSMatch(self.selectors, tag, self.namespaces, self.flags).match(tag) return CSSMatch(self.selectors, tag, self.namespaces, self.flags).match(tag)
def closest(self, tag): def closest(self, tag: 'bs4.Tag') -> 'bs4.Tag':
"""Match closest ancestor.""" """Match closest ancestor."""
return CSSMatch(self.selectors, tag, self.namespaces, self.flags).closest() return CSSMatch(self.selectors, tag, self.namespaces, self.flags).closest()
def filter(self, iterable): # noqa A001 def filter(self, iterable: Iterable['bs4.Tag']) -> List['bs4.Tag']: # noqa A001
""" """
Filter. Filter.
@ -1501,24 +1551,24 @@ class SoupSieve(ct.Immutable):
else: else:
return [node for node in iterable if not CSSMatch.is_navigable_string(node) and self.match(node)] return [node for node in iterable if not CSSMatch.is_navigable_string(node) and self.match(node)]
def select_one(self, tag): def select_one(self, tag: 'bs4.Tag') -> 'bs4.Tag':
"""Select a single tag.""" """Select a single tag."""
tags = self.select(tag, limit=1) tags = self.select(tag, limit=1)
return tags[0] if tags else None return tags[0] if tags else None
def select(self, tag, limit=0): def select(self, tag: 'bs4.Tag', limit: int = 0) -> List['bs4.Tag']:
"""Select the specified tags.""" """Select the specified tags."""
return list(self.iselect(tag, limit)) return list(self.iselect(tag, limit))
def iselect(self, tag, limit=0): def iselect(self, tag: 'bs4.Tag', limit: int = 0) -> Iterator['bs4.Tag']:
"""Iterate the specified tags.""" """Iterate the specified tags."""
for el in CSSMatch(self.selectors, tag, self.namespaces, self.flags).select(limit): for el in CSSMatch(self.selectors, tag, self.namespaces, self.flags).select(limit):
yield el yield el
def __repr__(self): # pragma: no cover def __repr__(self) -> str: # pragma: no cover
"""Representation.""" """Representation."""
return "SoupSieve(pattern={!r}, namespaces={!r}, custom={!r}, flags={!r})".format( return "SoupSieve(pattern={!r}, namespaces={!r}, custom={!r}, flags={!r})".format(

View file

@ -6,6 +6,7 @@ from . import css_match as cm
from . import css_types as ct from . import css_types as ct
from .util import SelectorSyntaxError from .util import SelectorSyntaxError
import warnings import warnings
from typing import Optional, Dict, Match, Tuple, Type, Any, List, Union, Iterator, cast
UNICODE_REPLACEMENT_CHAR = 0xFFFD UNICODE_REPLACEMENT_CHAR = 0xFFFD
@ -196,32 +197,42 @@ FLG_OPEN = 0x40
FLG_IN_RANGE = 0x80 FLG_IN_RANGE = 0x80
FLG_OUT_OF_RANGE = 0x100 FLG_OUT_OF_RANGE = 0x100
FLG_PLACEHOLDER_SHOWN = 0x200 FLG_PLACEHOLDER_SHOWN = 0x200
FLG_FORGIVE = 0x400
# Maximum cached patterns to store # Maximum cached patterns to store
_MAXCACHE = 500 _MAXCACHE = 500
@lru_cache(maxsize=_MAXCACHE) @lru_cache(maxsize=_MAXCACHE)
def _cached_css_compile(pattern, namespaces, custom, flags): def _cached_css_compile(
pattern: str,
namespaces: Optional[ct.Namespaces],
custom: Optional[ct.CustomSelectors],
flags: int
) -> cm.SoupSieve:
"""Cached CSS compile.""" """Cached CSS compile."""
custom_selectors = process_custom(custom) custom_selectors = process_custom(custom)
return cm.SoupSieve( return cm.SoupSieve(
pattern, pattern,
CSSParser(pattern, custom=custom_selectors, flags=flags).process_selectors(), CSSParser(
pattern,
custom=custom_selectors,
flags=flags
).process_selectors(),
namespaces, namespaces,
custom, custom,
flags flags
) )
def _purge_cache(): def _purge_cache() -> None:
"""Purge the cache.""" """Purge the cache."""
_cached_css_compile.cache_clear() _cached_css_compile.cache_clear()
def process_custom(custom): def process_custom(custom: Optional[ct.CustomSelectors]) -> Dict[str, Union[str, ct.SelectorList]]:
"""Process custom.""" """Process custom."""
custom_selectors = {} custom_selectors = {}
@ -236,14 +247,14 @@ def process_custom(custom):
return custom_selectors return custom_selectors
def css_unescape(content, string=False): def css_unescape(content: str, string: bool = False) -> str:
""" """
Unescape CSS value. Unescape CSS value.
Strings allow for spanning the value on multiple strings by escaping a new line. Strings allow for spanning the value on multiple strings by escaping a new line.
""" """
def replace(m): def replace(m: Match[str]) -> str:
"""Replace with the appropriate substitute.""" """Replace with the appropriate substitute."""
if m.group(1): if m.group(1):
@ -263,7 +274,7 @@ def css_unescape(content, string=False):
return (RE_CSS_ESC if not string else RE_CSS_STR_ESC).sub(replace, content) return (RE_CSS_ESC if not string else RE_CSS_STR_ESC).sub(replace, content)
def escape(ident): def escape(ident: str) -> str:
"""Escape identifier.""" """Escape identifier."""
string = [] string = []
@ -291,21 +302,21 @@ def escape(ident):
return ''.join(string) return ''.join(string)
class SelectorPattern(object): class SelectorPattern:
"""Selector pattern.""" """Selector pattern."""
def __init__(self, name, pattern): def __init__(self, name: str, pattern: str) -> None:
"""Initialize.""" """Initialize."""
self.name = name self.name = name
self.re_pattern = re.compile(pattern, re.I | re.X | re.U) self.re_pattern = re.compile(pattern, re.I | re.X | re.U)
def get_name(self): def get_name(self) -> str:
"""Get name.""" """Get name."""
return self.name return self.name
def match(self, selector, index, flags): def match(self, selector: str, index: int, flags: int) -> Optional[Match[str]]:
"""Match the selector.""" """Match the selector."""
return self.re_pattern.match(selector, index) return self.re_pattern.match(selector, index)
@ -314,7 +325,7 @@ class SelectorPattern(object):
class SpecialPseudoPattern(SelectorPattern): class SpecialPseudoPattern(SelectorPattern):
"""Selector pattern.""" """Selector pattern."""
def __init__(self, patterns): def __init__(self, patterns: Tuple[Tuple[str, Tuple[str, ...], str, Type[SelectorPattern]], ...]) -> None:
"""Initialize.""" """Initialize."""
self.patterns = {} self.patterns = {}
@ -324,15 +335,15 @@ class SpecialPseudoPattern(SelectorPattern):
for pseudo in p[1]: for pseudo in p[1]:
self.patterns[pseudo] = pattern self.patterns[pseudo] = pattern
self.matched_name = None self.matched_name = None # type: Optional[SelectorPattern]
self.re_pseudo_name = re.compile(PAT_PSEUDO_CLASS_SPECIAL, re.I | re.X | re.U) self.re_pseudo_name = re.compile(PAT_PSEUDO_CLASS_SPECIAL, re.I | re.X | re.U)
def get_name(self): def get_name(self) -> str:
"""Get name.""" """Get name."""
return self.matched_name.get_name() return '' if self.matched_name is None else self.matched_name.get_name()
def match(self, selector, index, flags): def match(self, selector: str, index: int, flags: int) -> Optional[Match[str]]:
"""Match the selector.""" """Match the selector."""
pseudo = None pseudo = None
@ -348,7 +359,7 @@ class SpecialPseudoPattern(SelectorPattern):
return pseudo return pseudo
class _Selector(object): class _Selector:
""" """
Intermediate selector class. Intermediate selector class.
@ -357,23 +368,23 @@ class _Selector(object):
the data in an object that can be pickled and hashed. the data in an object that can be pickled and hashed.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
"""Initialize.""" """Initialize."""
self.tag = kwargs.get('tag', None) self.tag = kwargs.get('tag', None) # type: Optional[ct.SelectorTag]
self.ids = kwargs.get('ids', []) self.ids = kwargs.get('ids', []) # type: List[str]
self.classes = kwargs.get('classes', []) self.classes = kwargs.get('classes', []) # type: List[str]
self.attributes = kwargs.get('attributes', []) self.attributes = kwargs.get('attributes', []) # type: List[ct.SelectorAttribute]
self.nth = kwargs.get('nth', []) self.nth = kwargs.get('nth', []) # type: List[ct.SelectorNth]
self.selectors = kwargs.get('selectors', []) self.selectors = kwargs.get('selectors', []) # type: List[ct.SelectorList]
self.relations = kwargs.get('relations', []) self.relations = kwargs.get('relations', []) # type: List[_Selector]
self.rel_type = kwargs.get('rel_type', None) self.rel_type = kwargs.get('rel_type', None) # type: Optional[str]
self.contains = kwargs.get('contains', []) self.contains = kwargs.get('contains', []) # type: List[ct.SelectorContains]
self.lang = kwargs.get('lang', []) self.lang = kwargs.get('lang', []) # type: List[ct.SelectorLang]
self.flags = kwargs.get('flags', 0) self.flags = kwargs.get('flags', 0) # type: int
self.no_match = kwargs.get('no_match', False) self.no_match = kwargs.get('no_match', False) # type: bool
def _freeze_relations(self, relations): def _freeze_relations(self, relations: List['_Selector']) -> ct.SelectorList:
"""Freeze relation.""" """Freeze relation."""
if relations: if relations:
@ -383,7 +394,7 @@ class _Selector(object):
else: else:
return ct.SelectorList() return ct.SelectorList()
def freeze(self): def freeze(self) -> Union[ct.Selector, ct.SelectorNull]:
"""Freeze self.""" """Freeze self."""
if self.no_match: if self.no_match:
@ -403,7 +414,7 @@ class _Selector(object):
self.flags self.flags
) )
def __str__(self): # pragma: no cover def __str__(self) -> str: # pragma: no cover
"""String representation.""" """String representation."""
return ( return (
@ -417,7 +428,7 @@ class _Selector(object):
__repr__ = __str__ __repr__ = __str__
class CSSParser(object): class CSSParser:
"""Parse CSS selectors.""" """Parse CSS selectors."""
css_tokens = ( css_tokens = (
@ -447,7 +458,12 @@ class CSSParser(object):
SelectorPattern("combine", PAT_COMBINE) SelectorPattern("combine", PAT_COMBINE)
) )
def __init__(self, selector, custom=None, flags=0): def __init__(
self,
selector: str,
custom: Optional[Dict[str, Union[str, ct.SelectorList]]] = None,
flags: int = 0
) -> None:
"""Initialize.""" """Initialize."""
self.pattern = selector.replace('\x00', '\ufffd') self.pattern = selector.replace('\x00', '\ufffd')
@ -455,7 +471,7 @@ class CSSParser(object):
self.debug = self.flags & util.DEBUG self.debug = self.flags & util.DEBUG
self.custom = {} if custom is None else custom self.custom = {} if custom is None else custom
def parse_attribute_selector(self, sel, m, has_selector): def parse_attribute_selector(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool:
"""Create attribute selector from the returned regex match.""" """Create attribute selector from the returned regex match."""
inverse = False inverse = False
@ -465,22 +481,22 @@ class CSSParser(object):
attr = css_unescape(m.group('attr_name')) attr = css_unescape(m.group('attr_name'))
is_type = False is_type = False
pattern2 = None pattern2 = None
value = ''
if case: if case:
flags = re.I if case == 'i' else 0 flags = (re.I if case == 'i' else 0) | re.DOTALL
elif util.lower(attr) == 'type': elif util.lower(attr) == 'type':
flags = re.I flags = re.I | re.DOTALL
is_type = True is_type = True
else: else:
flags = 0 flags = re.DOTALL
if op: if op:
if m.group('value').startswith(('"', "'")): if m.group('value').startswith(('"', "'")):
value = css_unescape(m.group('value')[1:-1], True) value = css_unescape(m.group('value')[1:-1], True)
else: else:
value = css_unescape(m.group('value')) value = css_unescape(m.group('value'))
else:
value = None
if not op: if not op:
# Attribute name # Attribute name
pattern = None pattern = None
@ -525,7 +541,7 @@ class CSSParser(object):
has_selector = True has_selector = True
return has_selector return has_selector
def parse_tag_pattern(self, sel, m, has_selector): def parse_tag_pattern(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool:
"""Parse tag pattern from regex match.""" """Parse tag pattern from regex match."""
prefix = css_unescape(m.group('tag_ns')[:-1]) if m.group('tag_ns') else None prefix = css_unescape(m.group('tag_ns')[:-1]) if m.group('tag_ns') else None
@ -534,7 +550,7 @@ class CSSParser(object):
has_selector = True has_selector = True
return has_selector return has_selector
def parse_pseudo_class_custom(self, sel, m, has_selector): def parse_pseudo_class_custom(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool:
""" """
Parse custom pseudo class alias. Parse custom pseudo class alias.
@ -552,7 +568,7 @@ class CSSParser(object):
) )
if not isinstance(selector, ct.SelectorList): if not isinstance(selector, ct.SelectorList):
self.custom[pseudo] = None del self.custom[pseudo]
selector = CSSParser( selector = CSSParser(
selector, custom=self.custom, flags=self.flags selector, custom=self.custom, flags=self.flags
).process_selectors(flags=FLG_PSEUDO) ).process_selectors(flags=FLG_PSEUDO)
@ -562,7 +578,14 @@ class CSSParser(object):
has_selector = True has_selector = True
return has_selector return has_selector
def parse_pseudo_class(self, sel, m, has_selector, iselector, is_html): def parse_pseudo_class(
self,
sel: _Selector,
m: Match[str],
has_selector: bool,
iselector: Iterator[Tuple[str, Match[str]]],
is_html: bool
) -> Tuple[bool, bool]:
"""Parse pseudo class.""" """Parse pseudo class."""
complex_pseudo = False complex_pseudo = False
@ -650,7 +673,13 @@ class CSSParser(object):
return has_selector, is_html return has_selector, is_html
def parse_pseudo_nth(self, sel, m, has_selector, iselector): def parse_pseudo_nth(
self,
sel: _Selector,
m: Match[str],
has_selector: bool,
iselector: Iterator[Tuple[str, Match[str]]]
) -> bool:
"""Parse `nth` pseudo.""" """Parse `nth` pseudo."""
mdict = m.groupdict() mdict = m.groupdict()
@ -671,23 +700,23 @@ class CSSParser(object):
s2 = 1 s2 = 1
var = True var = True
else: else:
nth_parts = RE_NTH.match(content) nth_parts = cast(Match[str], RE_NTH.match(content))
s1 = '-' if nth_parts.group('s1') and nth_parts.group('s1') == '-' else '' _s1 = '-' if nth_parts.group('s1') and nth_parts.group('s1') == '-' else ''
a = nth_parts.group('a') a = nth_parts.group('a')
var = a.endswith('n') var = a.endswith('n')
if a.startswith('n'): if a.startswith('n'):
s1 += '1' _s1 += '1'
elif var: elif var:
s1 += a[:-1] _s1 += a[:-1]
else: else:
s1 += a _s1 += a
s2 = '-' if nth_parts.group('s2') and nth_parts.group('s2') == '-' else '' _s2 = '-' if nth_parts.group('s2') and nth_parts.group('s2') == '-' else ''
if nth_parts.group('b'): if nth_parts.group('b'):
s2 += nth_parts.group('b') _s2 += nth_parts.group('b')
else: else:
s2 = '0' _s2 = '0'
s1 = int(s1, 10) s1 = int(_s1, 10)
s2 = int(s2, 10) s2 = int(_s2, 10)
pseudo_sel = mdict['name'] pseudo_sel = mdict['name']
if postfix == '_child': if postfix == '_child':
@ -709,20 +738,38 @@ class CSSParser(object):
has_selector = True has_selector = True
return has_selector return has_selector
def parse_pseudo_open(self, sel, name, has_selector, iselector, index): def parse_pseudo_open(
self,
sel: _Selector,
name: str,
has_selector: bool,
iselector: Iterator[Tuple[str, Match[str]]],
index: int
) -> bool:
"""Parse pseudo with opening bracket.""" """Parse pseudo with opening bracket."""
flags = FLG_PSEUDO | FLG_OPEN flags = FLG_PSEUDO | FLG_OPEN
if name == ':not': if name == ':not':
flags |= FLG_NOT flags |= FLG_NOT
if name == ':has': elif name == ':has':
flags |= FLG_RELATIVE flags |= FLG_RELATIVE | FLG_FORGIVE
elif name in (':where', ':is'):
flags |= FLG_FORGIVE
sel.selectors.append(self.parse_selectors(iselector, index, flags)) sel.selectors.append(self.parse_selectors(iselector, index, flags))
has_selector = True has_selector = True
return has_selector return has_selector
def parse_has_combinator(self, sel, m, has_selector, selectors, rel_type, index): def parse_has_combinator(
self,
sel: _Selector,
m: Match[str],
has_selector: bool,
selectors: List[_Selector],
rel_type: str,
index: int
) -> Tuple[bool, _Selector, str]:
"""Parse combinator tokens.""" """Parse combinator tokens."""
combinator = m.group('relation').strip() combinator = m.group('relation').strip()
@ -731,12 +778,9 @@ class CSSParser(object):
if combinator == COMMA_COMBINATOR: if combinator == COMMA_COMBINATOR:
if not has_selector: if not has_selector:
# If we've not captured any selector parts, the comma is either at the beginning of the pattern # If we've not captured any selector parts, the comma is either at the beginning of the pattern
# or following another comma, both of which are unexpected. Commas must split selectors. # or following another comma, both of which are unexpected. But shouldn't fail the pseudo-class.
raise SelectorSyntaxError( sel.no_match = True
"The combinator '{}' at postion {}, must have a selector before it".format(combinator, index),
self.pattern,
index
)
sel.rel_type = rel_type sel.rel_type = rel_type
selectors[-1].relations.append(sel) selectors[-1].relations.append(sel)
rel_type = ":" + WS_COMBINATOR rel_type = ":" + WS_COMBINATOR
@ -757,44 +801,63 @@ class CSSParser(object):
self.pattern, self.pattern,
index index
) )
# Set the leading combinator for the next selector. # Set the leading combinator for the next selector.
rel_type = ':' + combinator rel_type = ':' + combinator
sel = _Selector()
sel = _Selector()
has_selector = False has_selector = False
return has_selector, sel, rel_type return has_selector, sel, rel_type
def parse_combinator(self, sel, m, has_selector, selectors, relations, is_pseudo, index): def parse_combinator(
self,
sel: _Selector,
m: Match[str],
has_selector: bool,
selectors: List[_Selector],
relations: List[_Selector],
is_pseudo: bool,
is_forgive: bool,
index: int
) -> Tuple[bool, _Selector]:
"""Parse combinator tokens.""" """Parse combinator tokens."""
combinator = m.group('relation').strip() combinator = m.group('relation').strip()
if not combinator: if not combinator:
combinator = WS_COMBINATOR combinator = WS_COMBINATOR
if not has_selector: if not has_selector:
raise SelectorSyntaxError( if not is_forgive or combinator != COMMA_COMBINATOR:
"The combinator '{}' at postion {}, must have a selector before it".format(combinator, index), raise SelectorSyntaxError(
self.pattern, "The combinator '{}' at postion {}, must have a selector before it".format(combinator, index),
index self.pattern,
) index
)
if combinator == COMMA_COMBINATOR: # If we are in a forgiving pseudo class, just make the selector a "no match"
if not sel.tag and not is_pseudo: if combinator == COMMA_COMBINATOR:
# Implied `*` sel.no_match = True
sel.tag = ct.SelectorTag('*', None) del relations[:]
sel.relations.extend(relations) selectors.append(sel)
selectors.append(sel)
del relations[:]
else: else:
sel.relations.extend(relations) if combinator == COMMA_COMBINATOR:
sel.rel_type = combinator if not sel.tag and not is_pseudo:
del relations[:] # Implied `*`
relations.append(sel) sel.tag = ct.SelectorTag('*', None)
sel = _Selector() sel.relations.extend(relations)
selectors.append(sel)
del relations[:]
else:
sel.relations.extend(relations)
sel.rel_type = combinator
del relations[:]
relations.append(sel)
sel = _Selector()
has_selector = False has_selector = False
return has_selector, sel return has_selector, sel
def parse_class_id(self, sel, m, has_selector): def parse_class_id(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool:
"""Parse HTML classes and ids.""" """Parse HTML classes and ids."""
selector = m.group(0) selector = m.group(0)
@ -805,7 +868,7 @@ class CSSParser(object):
has_selector = True has_selector = True
return has_selector return has_selector
def parse_pseudo_contains(self, sel, m, has_selector): def parse_pseudo_contains(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool:
"""Parse contains.""" """Parse contains."""
pseudo = util.lower(css_unescape(m.group('name'))) pseudo = util.lower(css_unescape(m.group('name')))
@ -826,11 +889,11 @@ class CSSParser(object):
else: else:
value = css_unescape(value) value = css_unescape(value)
patterns.append(value) patterns.append(value)
sel.contains.append(ct.SelectorContains(tuple(patterns), contains_own)) sel.contains.append(ct.SelectorContains(patterns, contains_own))
has_selector = True has_selector = True
return has_selector return has_selector
def parse_pseudo_lang(self, sel, m, has_selector): def parse_pseudo_lang(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool:
"""Parse pseudo language.""" """Parse pseudo language."""
values = m.group('values') values = m.group('values')
@ -851,7 +914,7 @@ class CSSParser(object):
return has_selector return has_selector
def parse_pseudo_dir(self, sel, m, has_selector): def parse_pseudo_dir(self, sel: _Selector, m: Match[str], has_selector: bool) -> bool:
"""Parse pseudo direction.""" """Parse pseudo direction."""
value = ct.SEL_DIR_LTR if util.lower(m.group('dir')) == 'ltr' else ct.SEL_DIR_RTL value = ct.SEL_DIR_LTR if util.lower(m.group('dir')) == 'ltr' else ct.SEL_DIR_RTL
@ -859,15 +922,23 @@ class CSSParser(object):
has_selector = True has_selector = True
return has_selector return has_selector
def parse_selectors(self, iselector, index=0, flags=0): def parse_selectors(
self,
iselector: Iterator[Tuple[str, Match[str]]],
index: int = 0,
flags: int = 0
) -> ct.SelectorList:
"""Parse selectors.""" """Parse selectors."""
# Initialize important variables
sel = _Selector() sel = _Selector()
selectors = [] selectors = []
has_selector = False has_selector = False
closed = False closed = False
relations = [] relations = [] # type: List[_Selector]
rel_type = ":" + WS_COMBINATOR rel_type = ":" + WS_COMBINATOR
# Setup various flags
is_open = bool(flags & FLG_OPEN) is_open = bool(flags & FLG_OPEN)
is_pseudo = bool(flags & FLG_PSEUDO) is_pseudo = bool(flags & FLG_PSEUDO)
is_relative = bool(flags & FLG_RELATIVE) is_relative = bool(flags & FLG_RELATIVE)
@ -878,7 +949,9 @@ class CSSParser(object):
is_in_range = bool(flags & FLG_IN_RANGE) is_in_range = bool(flags & FLG_IN_RANGE)
is_out_of_range = bool(flags & FLG_OUT_OF_RANGE) is_out_of_range = bool(flags & FLG_OUT_OF_RANGE)
is_placeholder_shown = bool(flags & FLG_PLACEHOLDER_SHOWN) is_placeholder_shown = bool(flags & FLG_PLACEHOLDER_SHOWN)
is_forgive = bool(flags & FLG_FORGIVE)
# Print out useful debug stuff
if self.debug: # pragma: no cover if self.debug: # pragma: no cover
if is_pseudo: if is_pseudo:
print(' is_pseudo: True') print(' is_pseudo: True')
@ -900,7 +973,10 @@ class CSSParser(object):
print(' is_out_of_range: True') print(' is_out_of_range: True')
if is_placeholder_shown: if is_placeholder_shown:
print(' is_placeholder_shown: True') print(' is_placeholder_shown: True')
if is_forgive:
print(' is_forgive: True')
# The algorithm for relative selectors require an initial selector in the selector list
if is_relative: if is_relative:
selectors.append(_Selector()) selectors.append(_Selector())
@ -929,11 +1005,13 @@ class CSSParser(object):
is_html = True is_html = True
elif key == 'pseudo_close': elif key == 'pseudo_close':
if not has_selector: if not has_selector:
raise SelectorSyntaxError( if not is_forgive:
"Expected a selector at postion {}".format(m.start(0)), raise SelectorSyntaxError(
self.pattern, "Expected a selector at postion {}".format(m.start(0)),
m.start(0) self.pattern,
) m.start(0)
)
sel.no_match = True
if is_open: if is_open:
closed = True closed = True
break break
@ -950,7 +1028,7 @@ class CSSParser(object):
) )
else: else:
has_selector, sel = self.parse_combinator( has_selector, sel = self.parse_combinator(
sel, m, has_selector, selectors, relations, is_pseudo, index sel, m, has_selector, selectors, relations, is_pseudo, is_forgive, index
) )
elif key == 'attribute': elif key == 'attribute':
has_selector = self.parse_attribute_selector(sel, m, has_selector) has_selector = self.parse_attribute_selector(sel, m, has_selector)
@ -969,6 +1047,7 @@ class CSSParser(object):
except StopIteration: except StopIteration:
pass pass
# Handle selectors that are not closed
if is_open and not closed: if is_open and not closed:
raise SelectorSyntaxError( raise SelectorSyntaxError(
"Unclosed pseudo-class at position {}".format(index), "Unclosed pseudo-class at position {}".format(index),
@ -976,6 +1055,7 @@ class CSSParser(object):
index index
) )
# Cleanup completed selector piece
if has_selector: if has_selector:
if not sel.tag and not is_pseudo: if not sel.tag and not is_pseudo:
# Implied `*` # Implied `*`
@ -987,8 +1067,28 @@ class CSSParser(object):
sel.relations.extend(relations) sel.relations.extend(relations)
del relations[:] del relations[:]
selectors.append(sel) selectors.append(sel)
else:
# Forgive empty slots in pseudo-classes that have lists (and are forgiving)
elif is_forgive:
if is_relative:
# Handle relative selectors pseudo-classes with empty slots like `:has()`
if selectors and selectors[-1].rel_type is None and rel_type == ': ':
sel.rel_type = rel_type
sel.no_match = True
selectors[-1].relations.append(sel)
has_selector = True
else:
# Handle normal pseudo-classes with empty slots
if not selectors or not relations:
# Others like `:is()` etc.
sel.no_match = True
del relations[:]
selectors.append(sel)
has_selector = True
if not has_selector:
# We will always need to finish a selector when `:has()` is used as it leads with combining. # We will always need to finish a selector when `:has()` is used as it leads with combining.
# May apply to others as well.
raise SelectorSyntaxError( raise SelectorSyntaxError(
'Expected a selector at position {}'.format(index), 'Expected a selector at position {}'.format(index),
self.pattern, self.pattern,
@ -1009,9 +1109,10 @@ class CSSParser(object):
if is_placeholder_shown: if is_placeholder_shown:
selectors[-1].flags = ct.SEL_PLACEHOLDER_SHOWN selectors[-1].flags = ct.SEL_PLACEHOLDER_SHOWN
# Return selector list
return ct.SelectorList([s.freeze() for s in selectors], is_not, is_html) return ct.SelectorList([s.freeze() for s in selectors], is_not, is_html)
def selector_iter(self, pattern): def selector_iter(self, pattern: str) -> Iterator[Tuple[str, Match[str]]]:
"""Iterate selector tokens.""" """Iterate selector tokens."""
# Ignore whitespace and comments at start and end of pattern # Ignore whitespace and comments at start and end of pattern
@ -1052,7 +1153,7 @@ class CSSParser(object):
if self.debug: # pragma: no cover if self.debug: # pragma: no cover
print('## END PARSING') print('## END PARSING')
def process_selectors(self, index=0, flags=0): def process_selectors(self, index: int = 0, flags: int = 0) -> ct.SelectorList:
"""Process selectors.""" """Process selectors."""
return self.parse_selectors(self.selector_iter(self.pattern), index, flags) return self.parse_selectors(self.selector_iter(self.pattern), index, flags)

View file

@ -1,6 +1,7 @@
"""CSS selector structure items.""" """CSS selector structure items."""
import copyreg import copyreg
from collections.abc import Hashable, Mapping from .pretty import pretty
from typing import Any, Type, Tuple, Union, Dict, Iterator, Hashable, Optional, Pattern, Iterable, Mapping
__all__ = ( __all__ = (
'Selector', 'Selector',
@ -29,12 +30,14 @@ SEL_DEFINED = 0x200
SEL_PLACEHOLDER_SHOWN = 0x400 SEL_PLACEHOLDER_SHOWN = 0x400
class Immutable(object): class Immutable:
"""Immutable.""" """Immutable."""
__slots__ = ('_hash',) __slots__: Tuple[str, ...] = ('_hash',)
def __init__(self, **kwargs): _hash: int
def __init__(self, **kwargs: Any) -> None:
"""Initialize.""" """Initialize."""
temp = [] temp = []
@ -45,12 +48,12 @@ class Immutable(object):
super(Immutable, self).__setattr__('_hash', hash(tuple(temp))) super(Immutable, self).__setattr__('_hash', hash(tuple(temp)))
@classmethod @classmethod
def __base__(cls): def __base__(cls) -> "Type[Immutable]":
"""Get base class.""" """Get base class."""
return cls return cls
def __eq__(self, other): def __eq__(self, other: Any) -> bool:
"""Equal.""" """Equal."""
return ( return (
@ -58,7 +61,7 @@ class Immutable(object):
all([getattr(other, key) == getattr(self, key) for key in self.__slots__ if key != '_hash']) all([getattr(other, key) == getattr(self, key) for key in self.__slots__ if key != '_hash'])
) )
def __ne__(self, other): def __ne__(self, other: Any) -> bool:
"""Equal.""" """Equal."""
return ( return (
@ -66,63 +69,74 @@ class Immutable(object):
any([getattr(other, key) != getattr(self, key) for key in self.__slots__ if key != '_hash']) any([getattr(other, key) != getattr(self, key) for key in self.__slots__ if key != '_hash'])
) )
def __hash__(self): def __hash__(self) -> int:
"""Hash.""" """Hash."""
return self._hash return self._hash
def __setattr__(self, name, value): def __setattr__(self, name: str, value: Any) -> None:
"""Prevent mutability.""" """Prevent mutability."""
raise AttributeError("'{}' is immutable".format(self.__class__.__name__)) raise AttributeError("'{}' is immutable".format(self.__class__.__name__))
def __repr__(self): # pragma: no cover def __repr__(self) -> str: # pragma: no cover
"""Representation.""" """Representation."""
return "{}({})".format( return "{}({})".format(
self.__base__(), ', '.join(["{}={!r}".format(k, getattr(self, k)) for k in self.__slots__[:-1]]) self.__class__.__name__, ', '.join(["{}={!r}".format(k, getattr(self, k)) for k in self.__slots__[:-1]])
) )
__str__ = __repr__ __str__ = __repr__
def pretty(self) -> None: # pragma: no cover
"""Pretty print."""
class ImmutableDict(Mapping): print(pretty(self))
class ImmutableDict(Mapping[Any, Any]):
"""Hashable, immutable dictionary.""" """Hashable, immutable dictionary."""
def __init__(self, arg): def __init__(
self,
arg: Union[Dict[Any, Any], Iterable[Tuple[Any, Any]]]
) -> None:
"""Initialize.""" """Initialize."""
arg self._validate(arg)
is_dict = isinstance(arg, dict)
if (
is_dict and not all([isinstance(v, Hashable) for v in arg.values()]) or
not is_dict and not all([isinstance(k, Hashable) and isinstance(v, Hashable) for k, v in arg])
):
raise TypeError('All values must be hashable')
self._d = dict(arg) self._d = dict(arg)
self._hash = hash(tuple([(type(x), x, type(y), y) for x, y in sorted(self._d.items())])) self._hash = hash(tuple([(type(x), x, type(y), y) for x, y in sorted(self._d.items())]))
def __iter__(self): def _validate(self, arg: Union[Dict[Any, Any], Iterable[Tuple[Any, Any]]]) -> None:
"""Validate arguments."""
if isinstance(arg, dict):
if not all([isinstance(v, Hashable) for v in arg.values()]):
raise TypeError('{} values must be hashable'.format(self.__class__.__name__))
elif not all([isinstance(k, Hashable) and isinstance(v, Hashable) for k, v in arg]):
raise TypeError('{} values must be hashable'.format(self.__class__.__name__))
def __iter__(self) -> Iterator[Any]:
"""Iterator.""" """Iterator."""
return iter(self._d) return iter(self._d)
def __len__(self): def __len__(self) -> int:
"""Length.""" """Length."""
return len(self._d) return len(self._d)
def __getitem__(self, key): def __getitem__(self, key: Any) -> Any:
"""Get item: `namespace['key']`.""" """Get item: `namespace['key']`."""
return self._d[key] return self._d[key]
def __hash__(self): def __hash__(self) -> int:
"""Hash.""" """Hash."""
return self._hash return self._hash
def __repr__(self): # pragma: no cover def __repr__(self) -> str: # pragma: no cover
"""Representation.""" """Representation."""
return "{!r}".format(self._d) return "{!r}".format(self._d)
@ -133,37 +147,37 @@ class ImmutableDict(Mapping):
class Namespaces(ImmutableDict): class Namespaces(ImmutableDict):
"""Namespaces.""" """Namespaces."""
def __init__(self, arg): def __init__(self, arg: Union[Dict[str, str], Iterable[Tuple[str, str]]]) -> None:
"""Initialize.""" """Initialize."""
# If there are arguments, check the first index. super().__init__(arg)
# `super` should fail if the user gave multiple arguments,
# so don't bother checking that.
is_dict = isinstance(arg, dict)
if is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg.items()]):
raise TypeError('Namespace keys and values must be Unicode strings')
elif not is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]):
raise TypeError('Namespace keys and values must be Unicode strings')
super(Namespaces, self).__init__(arg) def _validate(self, arg: Union[Dict[str, str], Iterable[Tuple[str, str]]]) -> None:
"""Validate arguments."""
if isinstance(arg, dict):
if not all([isinstance(v, str) for v in arg.values()]):
raise TypeError('{} values must be hashable'.format(self.__class__.__name__))
elif not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]):
raise TypeError('{} keys and values must be Unicode strings'.format(self.__class__.__name__))
class CustomSelectors(ImmutableDict): class CustomSelectors(ImmutableDict):
"""Custom selectors.""" """Custom selectors."""
def __init__(self, arg): def __init__(self, arg: Union[Dict[str, str], Iterable[Tuple[str, str]]]) -> None:
"""Initialize.""" """Initialize."""
# If there are arguments, check the first index. super().__init__(arg)
# `super` should fail if the user gave multiple arguments,
# so don't bother checking that.
is_dict = isinstance(arg, dict)
if is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg.items()]):
raise TypeError('CustomSelectors keys and values must be Unicode strings')
elif not is_dict and not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]):
raise TypeError('CustomSelectors keys and values must be Unicode strings')
super(CustomSelectors, self).__init__(arg) def _validate(self, arg: Union[Dict[str, str], Iterable[Tuple[str, str]]]) -> None:
"""Validate arguments."""
if isinstance(arg, dict):
if not all([isinstance(v, str) for v in arg.values()]):
raise TypeError('{} values must be hashable'.format(self.__class__.__name__))
elif not all([isinstance(k, str) and isinstance(v, str) for k, v in arg]):
raise TypeError('{} keys and values must be Unicode strings'.format(self.__class__.__name__))
class Selector(Immutable): class Selector(Immutable):
@ -174,13 +188,35 @@ class Selector(Immutable):
'relation', 'rel_type', 'contains', 'lang', 'flags', '_hash' 'relation', 'rel_type', 'contains', 'lang', 'flags', '_hash'
) )
tag: Optional['SelectorTag']
ids: Tuple[str, ...]
classes: Tuple[str, ...]
attributes: Tuple['SelectorAttribute', ...]
nth: Tuple['SelectorNth', ...]
selectors: Tuple['SelectorList', ...]
relation: 'SelectorList'
rel_type: Optional[str]
contains: Tuple['SelectorContains', ...]
lang: Tuple['SelectorLang', ...]
flags: int
def __init__( def __init__(
self, tag, ids, classes, attributes, nth, selectors, self,
relation, rel_type, contains, lang, flags tag: Optional['SelectorTag'],
ids: Tuple[str, ...],
classes: Tuple[str, ...],
attributes: Tuple['SelectorAttribute', ...],
nth: Tuple['SelectorNth', ...],
selectors: Tuple['SelectorList', ...],
relation: 'SelectorList',
rel_type: Optional[str],
contains: Tuple['SelectorContains', ...],
lang: Tuple['SelectorLang', ...],
flags: int
): ):
"""Initialize.""" """Initialize."""
super(Selector, self).__init__( super().__init__(
tag=tag, tag=tag,
ids=ids, ids=ids,
classes=classes, classes=classes,
@ -198,10 +234,10 @@ class Selector(Immutable):
class SelectorNull(Immutable): class SelectorNull(Immutable):
"""Null Selector.""" """Null Selector."""
def __init__(self): def __init__(self) -> None:
"""Initialize.""" """Initialize."""
super(SelectorNull, self).__init__() super().__init__()
class SelectorTag(Immutable): class SelectorTag(Immutable):
@ -209,13 +245,13 @@ class SelectorTag(Immutable):
__slots__ = ("name", "prefix", "_hash") __slots__ = ("name", "prefix", "_hash")
def __init__(self, name, prefix): name: str
prefix: Optional[str]
def __init__(self, name: str, prefix: Optional[str]) -> None:
"""Initialize.""" """Initialize."""
super(SelectorTag, self).__init__( super().__init__(name=name, prefix=prefix)
name=name,
prefix=prefix
)
class SelectorAttribute(Immutable): class SelectorAttribute(Immutable):
@ -223,10 +259,21 @@ class SelectorAttribute(Immutable):
__slots__ = ("attribute", "prefix", "pattern", "xml_type_pattern", "_hash") __slots__ = ("attribute", "prefix", "pattern", "xml_type_pattern", "_hash")
def __init__(self, attribute, prefix, pattern, xml_type_pattern): attribute: str
prefix: str
pattern: Optional[Pattern[str]]
xml_type_pattern: Optional[Pattern[str]]
def __init__(
self,
attribute: str,
prefix: str,
pattern: Optional[Pattern[str]],
xml_type_pattern: Optional[Pattern[str]]
) -> None:
"""Initialize.""" """Initialize."""
super(SelectorAttribute, self).__init__( super().__init__(
attribute=attribute, attribute=attribute,
prefix=prefix, prefix=prefix,
pattern=pattern, pattern=pattern,
@ -239,13 +286,13 @@ class SelectorContains(Immutable):
__slots__ = ("text", "own", "_hash") __slots__ = ("text", "own", "_hash")
def __init__(self, text, own): text: Tuple[str, ...]
own: bool
def __init__(self, text: Iterable[str], own: bool) -> None:
"""Initialize.""" """Initialize."""
super(SelectorContains, self).__init__( super().__init__(text=tuple(text), own=own)
text=text,
own=own
)
class SelectorNth(Immutable): class SelectorNth(Immutable):
@ -253,10 +300,17 @@ class SelectorNth(Immutable):
__slots__ = ("a", "n", "b", "of_type", "last", "selectors", "_hash") __slots__ = ("a", "n", "b", "of_type", "last", "selectors", "_hash")
def __init__(self, a, n, b, of_type, last, selectors): a: int
n: bool
b: int
of_type: bool
last: bool
selectors: 'SelectorList'
def __init__(self, a: int, n: bool, b: int, of_type: bool, last: bool, selectors: 'SelectorList') -> None:
"""Initialize.""" """Initialize."""
super(SelectorNth, self).__init__( super().__init__(
a=a, a=a,
n=n, n=n,
b=b, b=b,
@ -271,24 +325,24 @@ class SelectorLang(Immutable):
__slots__ = ("languages", "_hash",) __slots__ = ("languages", "_hash",)
def __init__(self, languages): languages: Tuple[str, ...]
def __init__(self, languages: Iterable[str]):
"""Initialize.""" """Initialize."""
super(SelectorLang, self).__init__( super().__init__(languages=tuple(languages))
languages=tuple(languages)
)
def __iter__(self): def __iter__(self) -> Iterator[str]:
"""Iterator.""" """Iterator."""
return iter(self.languages) return iter(self.languages)
def __len__(self): # pragma: no cover def __len__(self) -> int: # pragma: no cover
"""Length.""" """Length."""
return len(self.languages) return len(self.languages)
def __getitem__(self, index): # pragma: no cover def __getitem__(self, index: int) -> str: # pragma: no cover
"""Get item.""" """Get item."""
return self.languages[index] return self.languages[index]
@ -299,36 +353,45 @@ class SelectorList(Immutable):
__slots__ = ("selectors", "is_not", "is_html", "_hash") __slots__ = ("selectors", "is_not", "is_html", "_hash")
def __init__(self, selectors=tuple(), is_not=False, is_html=False): selectors: Tuple[Union['Selector', 'SelectorNull'], ...]
is_not: bool
is_html: bool
def __init__(
self,
selectors: Optional[Iterable[Union['Selector', 'SelectorNull']]] = None,
is_not: bool = False,
is_html: bool = False
) -> None:
"""Initialize.""" """Initialize."""
super(SelectorList, self).__init__( super().__init__(
selectors=tuple(selectors), selectors=tuple(selectors) if selectors is not None else tuple(),
is_not=is_not, is_not=is_not,
is_html=is_html is_html=is_html
) )
def __iter__(self): def __iter__(self) -> Iterator[Union['Selector', 'SelectorNull']]:
"""Iterator.""" """Iterator."""
return iter(self.selectors) return iter(self.selectors)
def __len__(self): def __len__(self) -> int:
"""Length.""" """Length."""
return len(self.selectors) return len(self.selectors)
def __getitem__(self, index): def __getitem__(self, index: int) -> Union['Selector', 'SelectorNull']:
"""Get item.""" """Get item."""
return self.selectors[index] return self.selectors[index]
def _pickle(p): def _pickle(p: Any) -> Any:
return p.__base__(), tuple([getattr(p, s) for s in p.__slots__[:-1]]) return p.__base__(), tuple([getattr(p, s) for s in p.__slots__[:-1]])
def pickle_register(obj): def pickle_register(obj: Any) -> None:
"""Allow object to be pickled.""" """Allow object to be pickled."""
copyreg.pickle(obj, _pickle) copyreg.pickle(obj, _pickle)

137
lib/soupsieve/pretty.py Normal file
View file

@ -0,0 +1,137 @@
"""
Format a pretty string of a `SoupSieve` object for easy debugging.
This won't necessarily support all types and such, and definitely
not support custom outputs.
It is mainly geared towards our types as the `SelectorList`
object is a beast to look at without some indentation and newlines.
The format and various output types is fairly known (though it
hasn't been tested extensively to make sure we aren't missing corners).
Example:
```
>>> import soupsieve as sv
>>> sv.compile('this > that.class[name=value]').selectors.pretty()
SelectorList(
selectors=(
Selector(
tag=SelectorTag(
name='that',
prefix=None),
ids=(),
classes=(
'class',
),
attributes=(
SelectorAttribute(
attribute='name',
prefix='',
pattern=re.compile(
'^value$'),
xml_type_pattern=None),
),
nth=(),
selectors=(),
relation=SelectorList(
selectors=(
Selector(
tag=SelectorTag(
name='this',
prefix=None),
ids=(),
classes=(),
attributes=(),
nth=(),
selectors=(),
relation=SelectorList(
selectors=(),
is_not=False,
is_html=False),
rel_type='>',
contains=(),
lang=(),
flags=0),
),
is_not=False,
is_html=False),
rel_type=None,
contains=(),
lang=(),
flags=0),
),
is_not=False,
is_html=False)
```
"""
import re
from typing import Any
RE_CLASS = re.compile(r'(?i)[a-z_][_a-z\d\.]+\(')
RE_PARAM = re.compile(r'(?i)[_a-z][_a-z\d]+=')
RE_EMPTY = re.compile(r'\(\)|\[\]|\{\}')
RE_LSTRT = re.compile(r'\[')
RE_DSTRT = re.compile(r'\{')
RE_TSTRT = re.compile(r'\(')
RE_LEND = re.compile(r'\]')
RE_DEND = re.compile(r'\}')
RE_TEND = re.compile(r'\)')
RE_INT = re.compile(r'\d+')
RE_KWORD = re.compile(r'(?i)[_a-z][_a-z\d]+')
RE_DQSTR = re.compile(r'"(?:\\.|[^"\\])*"')
RE_SQSTR = re.compile(r"'(?:\\.|[^'\\])*'")
RE_SEP = re.compile(r'\s*(,)\s*')
RE_DSEP = re.compile(r'\s*(:)\s*')
TOKENS = {
'class': RE_CLASS,
'param': RE_PARAM,
'empty': RE_EMPTY,
'lstrt': RE_LSTRT,
'dstrt': RE_DSTRT,
'tstrt': RE_TSTRT,
'lend': RE_LEND,
'dend': RE_DEND,
'tend': RE_TEND,
'sqstr': RE_SQSTR,
'sep': RE_SEP,
'dsep': RE_DSEP,
'int': RE_INT,
'kword': RE_KWORD,
'dqstr': RE_DQSTR
}
def pretty(obj: Any) -> str: # pragma: no cover
"""Make the object output string pretty."""
sel = str(obj)
index = 0
end = len(sel) - 1
indent = 0
output = []
while index <= end:
m = None
for k, v in TOKENS.items():
m = v.match(sel, index)
if m:
name = k
index = m.end(0)
if name in ('class', 'lstrt', 'dstrt', 'tstrt'):
indent += 4
output.append('{}\n{}'.format(m.group(0), " " * indent))
elif name in ('param', 'int', 'kword', 'sqstr', 'dqstr', 'empty'):
output.append(m.group(0))
elif name in ('lend', 'dend', 'tend'):
indent -= 4
output.append(m.group(0))
elif name in ('sep',):
output.append('{}\n{}'.format(m.group(1), " " * indent))
elif name in ('dsep',):
output.append('{} '.format(m.group(1)))
break
return ''.join(output)

0
lib/soupsieve/py.typed Normal file
View file

View file

@ -2,6 +2,7 @@
from functools import wraps, lru_cache from functools import wraps, lru_cache
import warnings import warnings
import re import re
from typing import Callable, Any, Optional, Tuple, List
DEBUG = 0x00001 DEBUG = 0x00001
@ -12,7 +13,7 @@ UC_Z = ord('Z')
@lru_cache(maxsize=512) @lru_cache(maxsize=512)
def lower(string): def lower(string: str) -> str:
"""Lower.""" """Lower."""
new_string = [] new_string = []
@ -25,7 +26,7 @@ def lower(string):
class SelectorSyntaxError(Exception): class SelectorSyntaxError(Exception):
"""Syntax error in a CSS selector.""" """Syntax error in a CSS selector."""
def __init__(self, msg, pattern=None, index=None): def __init__(self, msg: str, pattern: Optional[str] = None, index: Optional[int] = None) -> None:
"""Initialize.""" """Initialize."""
self.line = None self.line = None
@ -37,30 +38,34 @@ class SelectorSyntaxError(Exception):
self.context, self.line, self.col = get_pattern_context(pattern, index) self.context, self.line, self.col = get_pattern_context(pattern, index)
msg = '{}\n line {}:\n{}'.format(msg, self.line, self.context) msg = '{}\n line {}:\n{}'.format(msg, self.line, self.context)
super(SelectorSyntaxError, self).__init__(msg) super().__init__(msg)
def deprecated(message, stacklevel=2): # pragma: no cover def deprecated(message: str, stacklevel: int = 2) -> Callable[..., Any]: # pragma: no cover
""" """
Raise a `DeprecationWarning` when wrapped function/method is called. Raise a `DeprecationWarning` when wrapped function/method is called.
Borrowed from https://stackoverflow.com/a/48632082/866026 Usage:
@deprecated("This method will be removed in version X; use Y instead.")
def some_method()"
pass
""" """
def _decorator(func): def _wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func) @wraps(func)
def _func(*args, **kwargs): def _deprecated_func(*args: Any, **kwargs: Any) -> Any:
warnings.warn( warnings.warn(
"'{}' is deprecated. {}".format(func.__name__, message), f"'{func.__name__}' is deprecated. {message}",
category=DeprecationWarning, category=DeprecationWarning,
stacklevel=stacklevel stacklevel=stacklevel
) )
return func(*args, **kwargs) return func(*args, **kwargs)
return _func return _deprecated_func
return _decorator return _wrapper
def warn_deprecated(message, stacklevel=2): # pragma: no cover def warn_deprecated(message: str, stacklevel: int = 2) -> None: # pragma: no cover
"""Warn deprecated.""" """Warn deprecated."""
warnings.warn( warnings.warn(
@ -70,14 +75,15 @@ def warn_deprecated(message, stacklevel=2): # pragma: no cover
) )
def get_pattern_context(pattern, index): def get_pattern_context(pattern: str, index: int) -> Tuple[str, int, int]:
"""Get the pattern context.""" """Get the pattern context."""
last = 0 last = 0
current_line = 1 current_line = 1
col = 1 col = 1
text = [] text = [] # type: List[str]
line = 1 line = 1
offset = None # type: Optional[int]
# Split pattern by newline and handle the text before the newline # Split pattern by newline and handle the text before the newline
for m in RE_PATTERN_LINE_SPLIT.finditer(pattern): for m in RE_PATTERN_LINE_SPLIT.finditer(pattern):