common: clean code

This commit is contained in:
Corentin Sechet 2022-04-12 11:35:10 +02:00
parent 5125e02d3c
commit 832c903e81
5 changed files with 58 additions and 100 deletions

View File

@ -88,11 +88,6 @@ async def main(
exclude_tags,
)
def _on_close() -> None:
ctx.obj.echo_error_summary()
ctx.call_on_close(_on_close)
@main.command(name="prune-caches")
@argument("cache_names", nargs=-1)
@ -110,7 +105,7 @@ async def css_diff_cli(config: Config, right_source: str) -> None:
for _, url in config.urls:
await css_diff(
url,
config.default_source,
config.source,
config.get_source(right_source),
)

View File

@ -10,9 +10,8 @@ from xdg import xdg_config_dirs, xdg_config_home
from yaml import Loader
from yaml import load as load_yaml
from frontools.cache import Cache, FileCache, NullCache
from frontools.cache import FileCache, NullCache
from frontools.sources import CachedSource, OverrideSource, Source
from frontools.utils import ErrorSummary
REMOTE_SOURCE_NAME = "remote"
@ -40,54 +39,36 @@ class Config:
def __init__(
self,
source_name: Optional[str],
use_cache: bool,
default_source_name: Optional[str],
include_urls: list[str],
exclude_urls: list[str],
include_tags: list[str],
exclude_tags: list[str],
):
self._use_cache = use_cache
self._sources: dict[str, Source] = {}
self._themes: dict[str, ThemeConfig] = {}
self._block_urls: list[Pattern[str]] = []
if default_source_name is None:
self._default_source_name = REMOTE_SOURCE_NAME
else:
self._default_source_name = default_source_name
self._add_source(
REMOTE_SOURCE_NAME,
CachedSource,
FileCache(REMOTE_SOURCE_NAME) if use_cache else NullCache(),
)
self._source_name = source_name if source_name else REMOTE_SOURCE_NAME
self._error_summary = ErrorSummary()
if use_cache:
remote_cache: Cache = FileCache(REMOTE_SOURCE_NAME)
else:
remote_cache = NullCache()
self._add_source(REMOTE_SOURCE_NAME, CachedSource, remote_cache)
self._include_urls = [re_compile(it) for it in include_urls]
self._exclude_urls = [re_compile(it) for it in exclude_urls]
self._include_tags = set(include_tags)
self._exclude_tags = set(exclude_tags)
@staticmethod
async def load(
config_path: Optional[Path],
default_source_name: Optional[str],
use_cache: bool,
include_urls: list[str],
exclude_urls: list[str],
include_tags: list[str],
exclude_tags: list[str],
) -> "Config":
"""Load config from the given path"""
config = Config(
use_cache,
default_source_name,
self._filter = _Filter(
include_urls,
exclude_urls,
include_tags,
exclude_tags,
)
@staticmethod
async def load(config_path: Optional[Path], *args: Any, **kwargs: Any) -> "Config":
"""Load config from the given path"""
config = Config(*args, **kwargs)
if config_path is None:
config_path = _find_config()
@ -112,14 +93,9 @@ class Config:
return config
@property
def remote_source(self) -> Source:
def source(self) -> Source:
"""get the default source for this context"""
return self.get_source(REMOTE_SOURCE_NAME)
@property
def default_source(self) -> Source:
"""get the default source for this context"""
return self.get_source(self._default_source_name)
return self.get_source(self._source_name)
@property
def urls(self) -> Iterable[tuple[str, str]]:
@ -129,15 +105,15 @@ class Config:
if self._filter(url, config.tags):
yield theme_name, url
def add_theme_url(
self, name: str, url: str, tags: Optional[Iterable[str]] = None
def add_url(
self, theme_name: str, url: str, tags: Optional[Iterable[str]] = None
) -> None:
"""Add an url for a theme"""
theme = self._themes.get(name, None)
theme = self._themes.get(theme_name, None)
if theme is None:
theme = ThemeConfig()
self._themes[name] = theme
self._themes[theme_name] = theme
if tags is None:
new_tags = set()
@ -157,47 +133,54 @@ class Config:
yaml_document = load_yaml(yaml_file, Loader)
for theme_name, urls in yaml_document.items():
for url, tags in urls.items():
self.add_theme_url(theme_name, url, tags)
self.add_url(theme_name, url, tags)
def block_url_patterns(self, *patterns: str) -> None:
def block_urls(self, *patterns: str) -> None:
"""Will return 500 error for urls matching this pattern."""
for pattern in patterns:
self._block_urls.append(re_compile(pattern))
def add_override_source(
def override(
self,
name: str,
source_name: str,
mappings: list[tuple[str, str]],
next_source_name: Optional[str] = None,
) -> None:
"""Add a source overriding given patterns"""
assert name not in self._sources
if next_source_name is None:
next_source = self.default_source
else:
next_source = self.get_source(next_source_name)
self._add_source(name, OverrideSource, mappings, next_source)
assert source_name not in self._sources
next_source = self.get_source(
next_source_name if next_source_name else REMOTE_SOURCE_NAME
)
self._add_source(source_name, OverrideSource, mappings, next_source)
def get_source(self, name: str) -> Source:
"""Get an alternate source in the configured ones"""
if name not in self._sources:
raise Exception(f"No source configured matching {name}")
raise ConfigError(f"No source configured matching {name}")
return self._sources[name]
def echo_error_summary(self) -> None:
"""Echo the error summary to the user."""
self._error_summary.echo()
def _add_source(
self, name: str, source_class: Type[Source], *args: Any, **kwargs: Any
) -> None:
if name in self._sources:
raise Exception(f"Source {name} already configured")
self._sources[name] = source_class(
self._error_summary, self._block_urls, *args, **kwargs
)
raise ConfigError(f"Source {name} already configured")
self._sources[name] = source_class(self._block_urls, *args, **kwargs)
def _filter(self, url: str, tags: set[str]) -> bool:
class _Filter:
def __init__(
self,
include_urls: list[str],
exclude_urls: list[str],
include_tags: list[str],
exclude_tags: list[str],
):
self._include_urls = [re_compile(it) for it in include_urls]
self._exclude_urls = [re_compile(it) for it in exclude_urls]
self._include_tags = set(include_tags)
self._exclude_tags = set(exclude_tags)
def __call__(self, url: str, tags: set[str]) -> bool:
if self._include_urls:
if all(not it.match(url) for it in self._include_urls):
return False

View File

@ -29,7 +29,7 @@ async def screenshot_diff(
output_path = Path(output_directory)
output_path.mkdir(parents=True)
left_source = config.default_source
left_source = config.source
right_source = config.get_source(right_source_name)
async with left_source.get_browser(width=screen_width) as left_browser:

View File

@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from asyncio import TimeoutError as AIOTimeoutError
from contextlib import asynccontextmanager
from logging import getLogger
from pathlib import Path
from re import Pattern
from re import compile as re_compile
@ -14,7 +15,8 @@ from playwright.async_api import TimeoutError as PlaywrightTimeoutError
from playwright.async_api import ViewportSize, async_playwright
from frontools.cache import Cache
from frontools.utils import ErrorSummary
_LOGGER = getLogger("frontools")
class Browser:
@ -46,13 +48,11 @@ class Browser:
break
except PlaywrightTimeoutError:
if retry == 2:
self._source._error_summary.add_error(
_LOGGER.error(
f"Timeout while loading {url} : retried 3 times, skipping"
)
except Error as ex:
self._source._error_summary.add_error(
f"Error while loading {url} : {ex}"
)
_LOGGER.error(f"Error while loading {url} : {ex}")
yield page
await page.close()
@ -60,10 +60,7 @@ class Browser:
class Source(ABC):
"""Base class for sources"""
def __init__(
self, error_summary: ErrorSummary, block_urls: list[Pattern[str]]
) -> None:
self._error_summary = error_summary
def __init__(self, block_urls: list[Pattern[str]]) -> None:
self._block_urls = block_urls
@abstractmethod
@ -112,11 +109,10 @@ class CachedSource(Source):
def __init__(
self,
error_summary: ErrorSummary,
block_urls: list[Pattern[str]],
cache: Cache,
) -> None:
super().__init__(error_summary, block_urls)
super().__init__(block_urls)
self._cache = cache
async def get_url(self, url: str) -> Optional[bytes]:
@ -129,7 +125,7 @@ class CachedSource(Source):
async with session.get(url) as response:
return await response.content.read()
except (ClientConnectionError, ClientPayloadError, AIOTimeoutError) as ex:
self._error_summary.add_error(f"error while loading {url} : {ex}")
_LOGGER.error(f"error while loading {url} : {ex}")
return None
@ -139,12 +135,11 @@ class OverrideSource(Source):
def __init__(
self,
error_summary: ErrorSummary,
block_urls: list[Pattern[str]],
mappings: list[tuple[str, str]],
next_source: Source,
):
super().__init__(error_summary, block_urls)
super().__init__(block_urls)
self._mappings: list[tuple[Pattern[str], str]] = []
self._next_source = next_source

View File

@ -9,21 +9,6 @@ from click import echo, progressbar
from xdg import xdg_config_home
class ErrorSummary:
def __init__(self) -> None:
self._errors: list[str] = []
def add_error(self, message: str) -> None:
self._errors.append(message)
def echo(self) -> None:
if not len(self._errors):
return
echo("***** Error summary :")
for error in self._errors:
echo(error, err=True)
TaskListType = list[tuple[str, Awaitable[None]]]