selectors: support styling with dynamic pseudo-classes

This commit is contained in:
Corentin Sechet 2023-01-17 18:13:37 +01:00
parent 4834e4aced
commit 407edaf51e
5 changed files with 236 additions and 50 deletions

View File

@ -22,7 +22,7 @@ def tests(session: Session) -> None:
def black(session: Session) -> None:
"""Check black formatting."""
session.install("black")
session.run("black", "--check", *LINT_PATHS)
session.run("black", "--line-length", "110", "--check", *LINT_PATHS)
@nox.session(reuse_venv=True)
@ -67,7 +67,7 @@ def lint(session: Session) -> None:
def checks(session: Session) -> None:
"""Run all checks"""
session.notify("lint")
session.notify("tests")
# session.notify("tests")
@nox.session(python=False)

View File

@ -11,20 +11,25 @@ DiffItem = tuple[ElementWrapper, Declaration | None, Declaration | None]
def diff(page_set: PageSet, left: Stylesheet, right: Stylesheet) -> Iterable[DiffItem]:
for element in page_set.elements:
left_style = left.style(element)
right_style = right.style(element)
left_styled_node = left.style(element)
right_styled_node = right.style(element)
left_declarations = set(left_style.keys())
right_declarations = set(right_style.keys())
states = set(left_styled_node.matching_dom_states) | set(right_styled_node.matching_dom_states)
for declaration in left_declarations - right_declarations:
yield (element, left_style[declaration], None)
for state in states:
left_style = left_styled_node.get_style(state)
right_style = right_styled_node.get_style(state)
left_declarations = set(left_style.keys())
right_declarations = set(right_style.keys())
for declaration in left_declarations & right_declarations:
left_declaration = left_style[declaration]
right_declaration = right_style[declaration]
if str(left_declaration) != str(right_declaration):
yield (element, left_style[declaration], right_style[declaration])
for declaration in left_declarations - right_declarations:
yield (element, left_style[declaration], None)
for declaration in right_declarations - left_declarations:
yield (element, None, right_style[declaration])
for declaration in left_declarations & right_declarations:
left_declaration = left_style[declaration]
right_declaration = right_style[declaration]
if str(left_declaration) != str(right_declaration):
yield (element, left_style[declaration], right_style[declaration])
for declaration in right_declarations - left_declarations:
yield (element, None, right_style[declaration])

View File

@ -1,34 +1,133 @@
from abc import ABC
from functools import cached_property
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Iterable, TypeVar, cast
from cssselect2 import ElementWrapper
from cssselect2.compiler import CompiledSelector
from cssselect2.parser import CombinedSelector as WrappedCombinedSelector
from cssselect2.parser import CompoundSelector as WrappedCompoundSelector
from cssselect2.parser import Selector as WrappedSelector
from cssselect2.parser import CombinedSelector as CSelCombinedSelector
from cssselect2.parser import CompoundSelector as CSelCompoundSelector
from cssselect2.parser import PseudoClassSelector as CSelPseudoClassSelector
from cssselect2.parser import Selector as CSelSelector
TWrapped = TypeVar("TWrapped")
TCSelSelector = TypeVar("TCSelSelector")
class Selector(Generic[TWrapped]):
def __init__(self, wrapped_selector: TWrapped) -> None:
DomState = frozenset[tuple[ElementWrapper, str]]
class Selector(ABC, Generic[TCSelSelector]):
def __init__(self, wrapped_selector: TCSelSelector) -> None:
self._wrapped_selector = wrapped_selector
@staticmethod
def wrap(selector: Any) -> "Selector[TWrapped]":
if isinstance(selector, WrappedCombinedSelector):
def wrap(selector: Any) -> "Selector[TCSelSelector]":
if isinstance(selector, CSelCombinedSelector):
return CombinedSelector(selector)
if isinstance(selector, WrappedCompoundSelector):
if isinstance(selector, CSelCompoundSelector):
return CompoundSelector(selector)
assert False
@cached_property
def compiled(self) -> CompiledSelector:
return CompiledSelector(WrappedSelector(self._wrapped_selector))
# As :hover, :focus ... are never matched by cssselect2, we strip them from the compiled
# selectors, and filter them out after matching if needed.
stripped_selector = _strip_dynamic_pseudo_classes(self._wrapped_selector)
return CompiledSelector(CSelSelector(stripped_selector))
@property
def pseudo_class(self) -> str | None:
return None
def matching_dom_states(self, node: ElementWrapper) -> Iterable[DomState]:
if self.pseudo_class:
yield frozenset([(node, self.pseudo_class)])
else:
yield frozenset()
def match_state(self, node: ElementWrapper, state: DomState) -> bool:
pseudo_class = self.pseudo_class
if pseudo_class is None:
return True
return any(node_it == node and state_it == pseudo_class for node_it, state_it in state)
def __str__(self) -> str:
return str(self._wrapped_selector)
class CompoundSelector(Selector[WrappedCompoundSelector]):
pass
class CompoundSelector(Selector[CSelCompoundSelector]):
@property
def pseudo_class(self) -> str | None:
simple_selectors = self._wrapped_selector.simple_selectors
if len(simple_selectors) < 2:
return None
last_selector = simple_selectors[-1]
if not isinstance(last_selector, CSelPseudoClassSelector):
return None
return cast(str, last_selector.name)
class CombinedSelector(Selector[WrappedCombinedSelector]):
pass
class CombinedSelector(Selector[CSelCombinedSelector]):
@cached_property
def left(self) -> Selector[Any]:
return Selector.wrap(self._wrapped_selector.left)
@property
def combinator(self) -> str:
return cast(str, self._wrapped_selector.combinator)
@cached_property
def right(self) -> Selector[Any]:
return Selector.wrap(self._wrapped_selector.right)
def matching_dom_states(self, node: ElementWrapper) -> Iterable[DomState]:
combinator = self.combinator
left = self._wrapped_selector.left
parent = node
if combinator == ">":
parent = parent.parent
elif combinator == "":
while not parent.match(left):
parent = parent.parent
elif combinator == "+":
parent = parent.previous
elif combinator == "~":
for parent in parent.iter_previous_siblings():
if parent.match(left):
break
else:
assert False
parent_contextes = list(self.left.matching_dom_states(parent))
for base_context in super().matching_dom_states(node):
for parent_context in parent_contextes:
yield parent_context | base_context
def _strip_dynamic_pseudo_classes(selector: TCSelSelector) -> TCSelSelector:
if isinstance(selector, CSelCompoundSelector):
return cast(
TCSelSelector,
CSelCompoundSelector(
[
selector
for selector in selector.simple_selectors
if not isinstance(selector, CSelPseudoClassSelector)
]
),
)
if isinstance(selector, CSelCombinedSelector):
return cast(
TCSelSelector,
CSelCombinedSelector(
_strip_dynamic_pseudo_classes(selector.left),
selector.combinator,
_strip_dynamic_pseudo_classes(selector.right),
),
)
raise Exception()

View File

@ -1,4 +1,5 @@
from functools import cached_property
from functools import cached_property, lru_cache
from itertools import chain
from pathlib import Path
from typing import IO, Any, Iterable, cast
@ -6,9 +7,45 @@ from cssselect2 import ElementWrapper, Matcher
from tinycss2 import parse_stylesheet
from stylo.nodes import Declaration, Node, QualifiedRule
from stylo.selector import Selector
from stylo.selector import DomState, Selector
from stylo.source_map import SourceMap
Match = tuple[Selector[Any], QualifiedRule]
class StyledNode:
def __init__(self, node: ElementWrapper, matches: Iterable[Match]) -> None:
self._node = node
self._matches = list(matches)
@cached_property
def matching_dom_states(self) -> set[DomState]:
states = list(
chain.from_iterable(selector.matching_dom_states(self._node) for selector, _ in self._matches)
)
return set(states)
def get_style(self, state: DomState) -> dict[str, Declaration]:
declarations: dict[str, Declaration] = {}
for selector, rule in self._matches:
if not selector.match_state(self._node, state):
continue
for declaration in rule.declarations:
name = declaration.name
previous_declaration = declarations.get(name, None)
if (
previous_declaration is not None
and previous_declaration.important
and not declaration.important
):
continue
declarations[name] = declaration
return declarations
class Stylesheet:
def __init__(self, content: IO[str]) -> None:
@ -30,25 +67,11 @@ class Stylesheet:
return list(_list())
def match(self, node: ElementWrapper) -> Iterable[tuple[Selector[Any], QualifiedRule]]:
def match(self, node: ElementWrapper) -> Iterable[Match]:
for match in self._matcher.match(node):
selector, rule = match[3]
yield cast(Selector[Any], selector), cast(QualifiedRule, rule)
def style(self, node: ElementWrapper) -> dict[str, Declaration]:
declarations: dict[str, Declaration] = {}
for _, rule in self.match(node):
for declaration in rule.declarations:
name = declaration.name
previous_declaration = declarations.get(name, None)
if (
previous_declaration is not None
and previous_declaration.important
and not declaration.important
):
continue
declarations[name] = declaration
return declarations
@lru_cache
def style(self, node: ElementWrapper) -> StyledNode:
return StyledNode(node, self.match(node))

59
tests/test_stylesheet.py Normal file
View File

@ -0,0 +1,59 @@
from io import StringIO
from cssselect2.tree import ElementWrapper
from html5lib import parse
from stylo.nodes import Declaration
from stylo.stylesheet import Stylesheet
def _load(html: str, css: str) -> tuple[ElementWrapper, Stylesheet]:
element_wrapper = ElementWrapper.from_html_root(parse(html))
return element_wrapper, Stylesheet(StringIO(css))
def _assert_style_equals(actual: dict[str, Declaration], expected: dict[str, str]) -> None:
actual_dict = {it.name: it.value for it in actual.values()}
assert actual_dict == expected
def test_match_simple_selector() -> None:
root, stylesheet = _load(
"<html><body><p class='test-class'></p></body>",
".test-class { }",
)
paragraph = root.query("p")
matches = list(stylesheet.match(paragraph))
assert len(matches) == 1
selector, _ = matches[0]
assert str(selector) == ".test-class"
def test_match_selector_pseudo_class() -> None:
root, stylesheet = _load(
"<html><body><p class='test-class'></p></body>",
".test-class:hover { }",
)
paragraph = root.query("p")
matches = list(stylesheet.match(paragraph))
assert len(matches) == 1
selector, _ = matches[0]
assert str(selector) == ".test-class:hover"
def test_style_pseudo_class() -> None:
root, stylesheet = _load(
"<html><body><a class='link'></body></html>",
"""
.link { background: blue; color: red;}
.link:hover { background: red; }
""",
)
link = root.query("a")
styled_node = stylesheet.style(link)
normal_style = styled_node.get_style(frozenset())
_assert_style_equals(normal_style, {"background": "blue", "color": "red"})
hover_style = styled_node.get_style(frozenset([(link, "hover")]))
_assert_style_equals(hover_style, {"background": "red", "color": "red"})