202 lines
7.0 KiB
Python
202 lines
7.0 KiB
Python
"""Source for remote files"""
|
|
from abc import ABC, abstractmethod
|
|
from asyncio import TimeoutError as AIOTimeoutError
|
|
from contextlib import asynccontextmanager
|
|
from logging import getLogger
|
|
from pathlib import Path
|
|
from re import Pattern
|
|
from re import compile as re_compile
|
|
from shutil import rmtree
|
|
from typing import AsyncGenerator, Optional, cast
|
|
from frontools.utils import get_url_slug
|
|
|
|
from aiohttp import ClientConnectionError, ClientPayloadError, ClientSession
|
|
from playwright.async_api import BrowserContext, Error, Page, Route
|
|
from playwright.async_api import TimeoutError as PlaywrightTimeoutError
|
|
from playwright.async_api import ViewportSize, async_playwright
|
|
from xdg import xdg_cache_home
|
|
|
|
_LOGGER = getLogger("frontools")
|
|
|
|
|
|
class Browser:
    """Wrapper around Playwright BrowserContext.

    We need that to set routing on page, and not on browser context, due to a Playwright bug spamming output
    with error when setting route directly on the context.
    """

    def __init__(self, source: "Source", browser_context: BrowserContext) -> None:
        """Wraps a browser instance, with helpers methods to load pages."""
        self._source = source
        self._browser_context = browser_context

    @asynccontextmanager
    async def load_page(self, url: str) -> AsyncGenerator[Page, None]:
        """Retrieve a page and wait for it to be fully loaded.

        @param url The url to load

        @return A Playwright page, fully loaded.
        """
        page = await self._browser_context.new_page()
        # NOTE(review): "*" is kept as shipped; confirm it intercepts all request
        # urls on the Playwright version in use — "**/*" is the usual match-all glob.
        await page.route("*", self._source.route)
        try:
            # Up to 3 attempts on timeout; other Playwright errors are logged
            # and the attempt loop continues.
            for retry in range(3):
                try:
                    await page.goto(url)
                    await page.wait_for_load_state("networkidle")
                    break
                except PlaywrightTimeoutError:
                    if retry == 2:
                        _LOGGER.error(
                            f"Timeout while loading {url} : retried 3 times, skipping"
                        )
                except Error as ex:
                    _LOGGER.error(f"Error while loading {url} : {ex}")
            yield page
        finally:
            # Close the page even if the consumer of this context manager raised,
            # so pages are never leaked on the shared browser context.
            await page.close()
|
|
|
|
|
|
class Source(ABC):
    """Base class for sources.

    A source resolves urls to raw content and can back a Playwright browser
    whose network requests are routed through it.
    """

    def __init__(self, block_urls: list[Pattern[str]]) -> None:
        # Patterns of urls that must be blocked (fulfilled with a 500) when routing.
        self._block_urls = block_urls

    @abstractmethod
    async def get_url(self, url: str) -> Optional[bytes]:
        """Retrieve the given url content"""

    @asynccontextmanager
    async def get_browser(
        self, width: Optional[int] = None, height: Optional[int] = None
    ) -> AsyncGenerator[Browser, None]:
        """Return a Playwright browser that will eventually get files from local cache.

        @param width  Optional viewport width; when None, Playwright's default viewport is used.
        @param height Accepted for interface compatibility but unused (screenshots are full page).
        """
        viewport: ViewportSize = cast(
            ViewportSize, None
        )  # Playwright typings are broken

        if width is not None:
            viewport = dict(
                # height is not used, as screenshot are taken full page
                width=width,
                height=600,
            )

        async with async_playwright() as pwright:
            browser = await pwright.firefox.launch(headless=True)
            try:
                context = await browser.new_context(
                    viewport=viewport, ignore_https_errors=True
                )
                yield Browser(self, context)
            finally:
                # Close the browser even if the consumer raised, so the browser
                # process is never leaked.
                await browser.close()

    async def route(self, route: Route) -> None:
        """Playwright routing callback: block configured urls, serve everything else via get_url."""
        url = route.request.url
        if any(pattern.match(url) for pattern in self._block_urls):
            # Blocked url: answer with a server error instead of fetching it.
            await route.fulfill(status=500)
        else:
            content = await self.get_url(url)
            if content is None:
                await route.abort("connectionfailed")
            else:
                await route.fulfill(body=content, status=200)
|
|
|
|
|
|
class CachedSource(Source):
    """Source loading urls from the internet."""

    # Root directory under which every named cache lives.
    cache_base = xdg_cache_home() / "frontools"

    def __init__(self, block_urls: list[Pattern[str]], name: str, disabled: bool = False) -> None:
        """
        @param block_urls Url patterns to block (see Source).
        @param name       Name of this cache's directory under cache_base.
        @param disabled   When True, bypass the cache entirely and always fetch.
        """
        super().__init__(block_urls)
        self._name = name
        self._disabled = disabled

    async def get_url(self, url: str) -> Optional[bytes]:
        """Get a page content from the local or remote cache."""
        if self._disabled:
            return await self._load_url(url)

        cache_file_path = self._get_cache_file_path(url)
        if cache_file_path.is_file():
            # Cache hit: serve the stored bytes.
            return cache_file_path.read_bytes()

        content = await self._load_url(url)
        if content is not None:
            # Only successful fetches are cached; failures stay retryable.
            cache_file_path.write_bytes(content)
        return content

    async def _load_url(self, url: str) -> Optional[bytes]:
        """Fetch url over http, returning None on connection / payload / timeout errors."""
        try:
            async with ClientSession() as session:
                async with session.get(url) as response:
                    return await response.content.read()
        except (ClientConnectionError, ClientPayloadError, AIOTimeoutError) as ex:
            _LOGGER.error(f"error while loading {url} : {ex}")

        return None

    @staticmethod
    def prune(cache_names: list[str]) -> None:
        """Remove caches from filesystem.

        If empty list is provided, all caches will be cleaned
        """
        if not cache_names:
            cache_names = [
                it.name for it in CachedSource.cache_base.iterdir() if it.is_dir()
            ]
        for cache_name in cache_names:
            cache_path: Path = CachedSource.cache_base / cache_name
            if not cache_path.is_dir():
                _LOGGER.error(f"{cache_path} isn't a cache directory")
                continue
            _LOGGER.info(f"Removing {cache_path}")
            rmtree(cache_path)

    def _get_cache_file_path(self, url: str) -> Path:
        """Map url to its cache file path, creating parent directories as needed."""
        key_slug = get_url_slug(url)
        cache_directory = self.cache_base / self._name
        # Each "&"-separated slug component becomes a path component.
        file_path = cache_directory.joinpath(*key_slug.split("&"))
        # Truncate the final component to stay under filesystem name limits;
        # the trailing "_" presumably avoids clashes with a directory of the
        # same name — TODO confirm.
        file_path = file_path.parent / (file_path.name[:254] + "_")
        # exist_ok avoids a crash when the directory already exists or is
        # created concurrently between check and mkdir.
        file_path.parent.mkdir(parents=True, exist_ok=True)

        return file_path
|
|
|
|
|
|
class OverrideSource(Source):
    """Source overriding paths matching patterns with local files"""

    def __init__(
        self,
        block_urls: list[Pattern[str]],
        mappings: list[tuple[str, str]],
        next_source: Source,
    ):
        """Compile the mapping patterns once and remember the fallback source."""
        super().__init__(block_urls)
        self._next_source = next_source
        # Each entry pairs a compiled url pattern with its replacement template.
        self._mappings: list[tuple[Pattern[str], str]] = [
            (re_compile(raw_pattern), replacement)
            for raw_pattern, replacement in mappings
        ]

    async def get_url(self, url: str) -> Optional[bytes]:
        """Return local stylesheet"""
        for compiled, replacement in self._mappings:
            if not compiled.match(url):
                continue
            candidate = Path(compiled.sub(replacement, url))
            if candidate.is_file():
                return candidate.read_bytes()

        # No mapping produced an existing local file: delegate to the chained source.
        return await self._next_source.get_url(url)
|