Async support via Redis and RQ
@@ -1,8 +1,18 @@
 from __future__ import annotations

 from typing import List, Optional

 import typer

 from basango.core.config import CrawlerConfig
 from basango.core.config_manager import ConfigManager
-from basango.domain import PageRange, DateRange, UpdateDirection
+from basango.domain import DateRange, PageRange, UpdateDirection
+from basango.services import CsvPersistor
+from basango.services.crawler.async_api import (
+    QueueSettings,
+    schedule_async_crawl,
+    start_worker,
+)
 from basango.services.crawler.html_crawler import HtmlCrawler
 from basango.services.crawler.wordpress_crawler import WordpressCrawler

@@ -21,16 +31,34 @@ def crawl_cmd(
     category: str = typer.Option(None, "--category", "-g", help="Optional category"),
     notify: bool = typer.Option(False, "--notify", "-n", help="Enable notifications"),
     env: str = typer.Option("development", "--env", "-c", help="Environment"),
+    async_mode: bool = typer.Option(
+        False,
+        "--async/--no-async",
+        help="Schedule crawl through Redis queues instead of running synchronously.",
+    ),
 ) -> None:
-    """Crawl a single source based on CLI-provided settings."""
+    """Crawl a single source, either synchronously or via the async queue."""
     manager = ConfigManager()

     pipeline = manager.get(env)
     manager.ensure_directories(pipeline)
     manager.setup_logging(pipeline)

     source = pipeline.sources.find(source_id)
-    assert source is not None, f"Source '{source_id}' not found in config"
+    if source is None:
+        raise typer.BadParameter(f"Source '{source_id}' not found in config")

+    if async_mode:
+        job_id = schedule_async_crawl(
+            source_id=source_id,
+            env=env,
+            page_range=page,
+            date_range=date,
+            category=category,
+        )
+        typer.echo(
+            f"Scheduled async crawl job {job_id} for source '{source_id}' on queue"
+        )
+        return

     crawler_config = CrawlerConfig(
         source=source,
@@ -46,8 +74,56 @@ def crawl_cmd(
         WordpressCrawler,
     ]

+    source_identifier = getattr(source, "source_id", source_id) or source_id
+    persistors = [
+        CsvPersistor(
+            data_dir=pipeline.paths.data,
+            source_id=str(source_identifier),
+        )
+    ]

     for crawler in crawlers:
         if crawler.supports() == source.source_kind:
-            crawler = crawler(crawler_config, pipeline.fetch.client)
+            crawler = crawler(
+                crawler_config,
+                pipeline.fetch.client,
+                persistors=persistors,
+            )
             crawler.fetch()
             break


+@app.command("worker")
+def worker_cmd(
+    queue: Optional[List[str]] = typer.Option(
+        None,
+        "--queue",
+        "-q",
+        help="Queue name(s) (without prefix). Provide multiple times to listen to more than one queue.",
+    ),
+    burst: bool = typer.Option(
+        False,
+        "--burst",
+        help="Process available jobs and exit instead of running continuously.",
+    ),
+    redis_url: str = typer.Option(
+        None,
+        "--redis-url",
+        help="Redis connection URL. Defaults to BASANGO_REDIS_URL.",
+    ),
+    env: str = typer.Option(
+        "development",
+        "--env",
+        "-c",
+        help="Environment used to configure logging before starting the worker.",
+    ),
+) -> None:
+    """Run an RQ worker that consumes crawler queues."""
+    manager = ConfigManager()
+    pipeline = manager.get(env)
+    manager.ensure_directories(pipeline)
+    manager.setup_logging(pipeline)
+
+    settings = QueueSettings(redis_url=redis_url) if redis_url else QueueSettings()
+    queue_names = list(queue) if queue else None
+    start_worker(queue_names=queue_names, settings=settings, burst=burst)

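Usage sketch (illustrative, not part of the commit): the two commands meet through Redis, and the same flow can be driven from Python. "example-source" is a placeholder id; this assumes the source exists in the environment's config and Redis is reachable at BASANGO_REDIS_URL:

    from basango.services.crawler.async_api import schedule_async_crawl, start_worker

    # Enqueue a listing job; returns the RQ job id.
    job_id = schedule_async_crawl(source_id="example-source", env="development")

    # The listing task fans out article jobs, so a worker that should drain the
    # whole pipeline listens on both queues; burst mode exits once they are empty.
    start_worker(queue_names=["listing", "articles"], burst=True)
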
@@ -1,11 +1,21 @@
 from .date_parser import DateParser
-from .http_client import HttpClient
+from .http_client import BaseHttpClient, SyncHttpClient, AsyncHttpClient
 from .open_graph import OpenGraphProvider
+from .persistence import BasePersistor, CsvPersistor, JsonPersistor, ApiPersistor
 from .user_agents import UserAgentProvider

+HttpClient = SyncHttpClient
+
 __all__ = [
     "DateParser",
+    "BaseHttpClient",
+    "SyncHttpClient",
+    "AsyncHttpClient",
     "HttpClient",
     "OpenGraphProvider",
     "UserAgentProvider",
+    "BasePersistor",
+    "CsvPersistor",
+    "JsonPersistor",
+    "ApiPersistor",
 ]

@@ -0,0 +1,22 @@
from .queue import QueueManager, QueueSettings
from .schemas import ListingTaskPayload, ArticleTaskPayload, ProcessedTaskPayload
from .tasks import (
    schedule_async_crawl,
    collect_listing,
    collect_article,
    forward_for_processing,
)
from .worker import start_worker

__all__ = [
    "QueueManager",
    "QueueSettings",
    "ListingTaskPayload",
    "ArticleTaskPayload",
    "ProcessedTaskPayload",
    "schedule_async_crawl",
    "collect_listing",
    "collect_article",
    "forward_for_processing",
    "start_worker",
]

@@ -0,0 +1,82 @@
from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import Iterable

from redis import Redis
from rq import Queue

from .schemas import (
    ArticleTaskPayload,
    ListingTaskPayload,
    ProcessedTaskPayload,
)


@dataclass(slots=True)
class QueueSettings:
    redis_url: str = field(
        default_factory=lambda: os.getenv(  # type: ignore[arg-type]
            "BASANGO_REDIS_URL", "redis://localhost:6379/0"
        )
    )
    prefix: str = field(
        default_factory=lambda: os.getenv("BASANGO_QUEUE_PREFIX", "crawler")
    )
    default_timeout: int = field(
        default_factory=lambda: int(os.getenv("BASANGO_QUEUE_TIMEOUT", "600"))
    )
    result_ttl: int = field(
        default_factory=lambda: int(os.getenv("BASANGO_QUEUE_RESULT_TTL", "3600"))
    )
    failure_ttl: int = field(
        default_factory=lambda: int(os.getenv("BASANGO_QUEUE_FAILURE_TTL", "3600"))
    )
    listing_queue: str = "listing"
    article_queue: str = "articles"
    processed_queue: str = "processed"


class QueueManager:
    def __init__(self, settings: QueueSettings | None = None) -> None:
        self.settings = settings or QueueSettings()
        self.connection = Redis.from_url(self.settings.redis_url)
        self.listing_queue = self._build_queue(self.settings.listing_queue)
        self.article_queue = self._build_queue(self.settings.article_queue)
        self.processed_queue = self._build_queue(self.settings.processed_queue)

    def _build_queue(self, suffix: str) -> Queue:
        return Queue(
            self.queue_name(suffix),
            connection=self.connection,
            default_timeout=self.settings.default_timeout,
            result_ttl=self.settings.result_ttl,
            failure_ttl=self.settings.failure_ttl,
        )

    def queue_name(self, suffix: str) -> str:
        return f"{self.settings.prefix}:{suffix}"

    def enqueue_listing(self, payload: ListingTaskPayload):
        return self.listing_queue.enqueue(
            "basango.services.crawler.async.tasks.collect_listing",
            payload.to_dict(),
        )

    def enqueue_article(self, payload: ArticleTaskPayload):
        return self.article_queue.enqueue(
            "basango.services.crawler.async.tasks.collect_article",
            payload.to_dict(),
        )

    def enqueue_processed(self, payload: ProcessedTaskPayload):
        return self.processed_queue.enqueue(
            "basango.services.crawler.async.tasks.forward_for_processing",
            payload.to_dict(),
        )

    def iter_queue_names(self) -> Iterable[str]:
        yield self.queue_name(self.settings.listing_queue)
        yield self.queue_name(self.settings.article_queue)
        yield self.queue_name(self.settings.processed_queue)

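Usage sketch for the settings/manager pair (illustrative; assumes a reachable Redis, and "example-source" is a placeholder id). The knobs are read from the environment by the default_factory lambdas at instantiation time, so they can be set before constructing QueueSettings:

    import os

    from basango.services.crawler.async_api import (
        ListingTaskPayload,
        QueueManager,
        QueueSettings,
    )

    os.environ["BASANGO_QUEUE_PREFIX"] = "crawler-dev"

    manager = QueueManager()  # QueueSettings() reads the environment here
    print(list(manager.iter_queue_names()))
    # ['crawler-dev:listing', 'crawler-dev:articles', 'crawler-dev:processed']

    # Or pass settings explicitly and enqueue a listing job by dotted task path.
    manager = QueueManager(settings=QueueSettings(prefix="crawler"))
    job = manager.enqueue_listing(ListingTaskPayload(source_id="example-source"))
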
@@ -0,0 +1,55 @@
from __future__ import annotations

from dataclasses import asdict, dataclass, fields
from typing import Any, Mapping


def _coerce_kwargs(cls, data: Mapping[str, Any]) -> dict[str, Any]:
    return {field.name: data.get(field.name) for field in fields(cls)}


@dataclass(slots=True)
class ListingTaskPayload:
    source_id: str
    env: str = "development"
    page_range: str | None = None
    date_range: str | None = None
    category: str | None = None

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "ListingTaskPayload":
        return cls(**_coerce_kwargs(cls, data))


@dataclass(slots=True)
class ArticleTaskPayload:
    source_id: str
    env: str = "development"
    url: str | None = None
    data: Any | None = None
    date_range: str | None = None
    category: str | None = None

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "ArticleTaskPayload":
        return cls(**_coerce_kwargs(cls, data))


@dataclass(slots=True)
class ProcessedTaskPayload:
    source_id: str
    env: str = "development"
    article: Mapping[str, Any] | None = None

    def to_dict(self) -> dict[str, Any]:
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Mapping[str, Any]) -> "ProcessedTaskPayload":
        return cls(**_coerce_kwargs(cls, data))

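These payloads travel through RQ as plain dicts, so the lossy round trip is the main invariant: _coerce_kwargs drops unknown keys and defaults missing ones to None, which keeps older workers tolerant of newer producers. A round-trip sketch:

    from basango.services.crawler.async_api import ListingTaskPayload

    payload = ListingTaskPayload(source_id="example-source", page_range="1-3")
    wire = payload.to_dict()
    wire["added_by_a_newer_producer"] = "ignored"  # unknown keys are dropped
    assert ListingTaskPayload.from_dict(wire) == payload

One caveat: a missing key comes back as None even when the dataclass default differs (env, for instance), since data.get() is used unconditionally.
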
@@ -0,0 +1,269 @@
from __future__ import annotations

import logging
from typing import Any

from basango.core.config import CrawlerConfig
from basango.core.config_manager import ConfigManager
from basango.domain import DateRange, PageRange, SourceKind, UpdateDirection
from basango.services import CsvPersistor
from basango.services.crawler.html_crawler import HtmlCrawler
from basango.services.crawler.wordpress_crawler import WordpressCrawler

from .queue import QueueManager, QueueSettings
from .schemas import (
    ArticleTaskPayload,
    ListingTaskPayload,
    ProcessedTaskPayload,
)


logger = logging.getLogger(__name__)


def schedule_async_crawl(
    *,
    source_id: str,
    env: str = "development",
    page_range: str | None = None,
    date_range: str | None = None,
    category: str | None = None,
    settings: QueueSettings | None = None,
):
    payload = ListingTaskPayload(
        source_id=source_id,
        env=env,
        page_range=page_range,
        date_range=date_range,
        category=category,
    )
    manager = QueueManager(settings=settings)
    job = manager.enqueue_listing(payload)
    logger.info("Scheduled listing collection job %s for source %s", job.id, source_id)
    return job.id


def collect_listing(payload: dict[str, Any]) -> int:
    data = ListingTaskPayload.from_dict(payload)
    manager = ConfigManager()
    pipeline = manager.get(data.env)
    source = pipeline.sources.find(data.source_id)
    if source is None:
        logger.error("Unknown source id %s", data.source_id)
        return 0

    crawler_config = CrawlerConfig(
        source=source,
        page_range=PageRange.create(data.page_range) if data.page_range else None,
        date_range=DateRange.create(data.date_range) if data.date_range else None,
        category=data.category,
        notify=False,
        direction=UpdateDirection.FORWARD,
    )
    client_config = pipeline.fetch.client
    queue_manager = QueueManager()

    if source.source_kind == SourceKind.HTML:
        crawler = HtmlCrawler(crawler_config, client_config)
        queued = _collect_html_listing(crawler, data, queue_manager)
    elif source.source_kind == SourceKind.WORDPRESS:
        crawler = WordpressCrawler(crawler_config, client_config)
        queued = _collect_wordpress_listing(crawler, data, queue_manager)
    else:
        logger.warning(
            "Async crawling not supported for source kind %s", source.source_kind
        )
        queued = 0

    logger.info("Queued %s article detail jobs for source %s", queued, data.source_id)
    return queued


def collect_article(payload: dict[str, Any]) -> dict[str, Any] | None:
    data = ArticleTaskPayload.from_dict(payload)
    manager = ConfigManager()
    pipeline = manager.get(data.env)
    source = pipeline.sources.find(data.source_id)
    if source is None:
        logger.error("Unknown source id %s", data.source_id)
        return None

    crawler_config = CrawlerConfig(
        source=source,
        date_range=DateRange.create(data.date_range) if data.date_range else None,
        category=data.category,
        notify=False,
        direction=UpdateDirection.FORWARD,
    )

    source_identifier = getattr(source, "source_id", data.source_id) or data.source_id
    persistors = [
        CsvPersistor(
            data_dir=pipeline.paths.data,
            source_id=str(source_identifier),
        )
    ]

    queue_manager = QueueManager()

    if source.source_kind == SourceKind.HTML:
        article = _collect_html_article(
            HtmlCrawler(crawler_config, pipeline.fetch.client, persistors=persistors),
            data,
        )
    elif source.source_kind == SourceKind.WORDPRESS:
        article = _collect_wordpress_article(
            WordpressCrawler(
                crawler_config, pipeline.fetch.client, persistors=persistors
            ),
            data,
        )
    else:
        logger.warning(
            "Async crawling not supported for source kind %s", source.source_kind
        )
        article = None

    if article:
        queue_manager.enqueue_processed(
            ProcessedTaskPayload(
                source_id=data.source_id,
                env=data.env,
                article=article,
            )
        )
        logger.info(
            "Persisted article %s and forwarded to processed queue",
            article.get("link"),
        )

    return article


def forward_for_processing(payload: dict[str, Any]) -> dict[str, Any] | None:
    data = ProcessedTaskPayload.from_dict(payload)
    article = dict(data.article) if data.article is not None else None
    if article is None:
        logger.info(
            "Ready for downstream processing: source=%s (no article)", data.source_id
        )
        return None
    logger.info(
        "Ready for downstream processing: source=%s link=%s",
        data.source_id,
        article.get("link"),
    )
    return article


def _collect_html_listing(
    crawler: HtmlCrawler,
    payload: ListingTaskPayload,
    queue_manager: QueueManager,
) -> int:
    source = crawler.source
    selector = source.source_selectors.articles
    if not selector:
        logger.warning(
            "No article selector configured for HTML source %s",
            source.source_id,
        )
        return 0

    page_range = crawler.config.page_range or crawler.get_pagination()
    queued = 0

    for page in range(page_range.start, page_range.end + 1):
        page_url = crawler._build_page_url(page)
        try:
            soup = crawler.crawl(page_url, page)
        except Exception as exc:  # noqa: BLE001
            logger.exception("Failed to crawl page %s: %s", page_url, exc)
            continue

        for node in soup.select(selector):
            link = crawler._extract_link(node)
            if not link:
                continue
            queue_manager.enqueue_article(
                ArticleTaskPayload(
                    source_id=payload.source_id,
                    env=payload.env,
                    url=link,
                    date_range=payload.date_range,
                    category=payload.category,
                )
            )
            queued += 1

    return queued


def _collect_wordpress_listing(
    crawler: WordpressCrawler,
    payload: ListingTaskPayload,
    queue_manager: QueueManager,
) -> int:
    page_range = crawler.config.page_range or crawler.get_pagination()
    queued = 0

    for page in range(page_range.start, page_range.end + 1):
        endpoint = crawler._posts_endpoint(page)
        try:
            response = crawler.client.get(endpoint)
            articles = response.json()
        except Exception as exc:  # noqa: BLE001
            logger.exception("Failed to fetch WordPress page %s: %s", endpoint, exc)
            continue

        if not isinstance(articles, list):
            logger.warning("Unexpected WordPress payload type: %s", type(articles))
            continue

        for entry in articles:
            queue_manager.enqueue_article(
                ArticleTaskPayload(
                    source_id=payload.source_id,
                    env=payload.env,
                    url=entry.get("link"),
                    data=entry,
                    date_range=payload.date_range,
                    category=payload.category,
                )
            )
            queued += 1

    return queued


def _collect_html_article(
    crawler: HtmlCrawler,
    payload: ArticleTaskPayload,
) -> dict[str, Any] | None:
    if not payload.url:
        logger.warning("Missing article url for HTML source %s", payload.source_id)
        return None

    crawler._current_article_url = payload.url  # type: ignore[attr-defined]
    try:
        soup = crawler.crawl(payload.url)
    except Exception as exc:  # noqa: BLE001
        logger.exception("Failed to crawl article %s: %s", payload.url, exc)
        return None

    crawler.fetch_one(str(soup), crawler.config.date_range)
    crawler.completed(False)
    return None


def _collect_wordpress_article(
    crawler: WordpressCrawler,
    payload: ArticleTaskPayload,
) -> dict[str, Any] | None:
    if payload.data is None:
        logger.warning("Missing WordPress payload for source %s", payload.source_id)
        return None

    crawler.fetch_one(payload.data, crawler.config.date_range)
    crawler.completed(False)
    return None

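The pipeline fans out in three stages: one listing job per scheduled crawl, one article job per discovered link, one processed job per persisted article. Because each task takes the plain dict form of its payload, a stage can also be exercised synchronously, which is handy in tests (a sketch; "example-source" is a placeholder and the fan-out enqueues still need Redis):

    from basango.services.crawler.async_api import ListingTaskPayload, collect_listing

    # Call the task the way RQ would, with the serialized payload.
    queued = collect_listing(ListingTaskPayload(source_id="example-source").to_dict())
    print(f"fanned out {queued} article jobs")
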
@@ -0,0 +1,29 @@
from __future__ import annotations

import logging
from typing import Sequence

from rq import Queue, Worker

from .queue import QueueManager, QueueSettings


logger = logging.getLogger(__name__)


def start_worker(
    queue_names: Sequence[str] | None = None,
    *,
    settings: QueueSettings | None = None,
    burst: bool = False,
) -> None:
    manager = QueueManager(settings=settings)
    if queue_names is None or not list(queue_names):
        queue_names = [manager.settings.article_queue]

    resolved = [manager.queue_name(name) for name in queue_names]
    queues = [Queue(name, connection=manager.connection) for name in resolved]

    logger.info("Starting RQ worker for queues %s", ", ".join(resolved))
    worker = Worker(queues, connection=manager.connection)
    worker.work(burst=burst)

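Note the default: with no queue names the worker only drains the articles queue. Covering the whole pipeline means naming all three suffixes, which are prefixed through QueueManager.queue_name (sketch):

    from basango.services.crawler.async_api import start_worker

    # Listens on crawler:listing, crawler:articles and crawler:processed
    # (with the default "crawler" prefix from QueueSettings).
    start_worker(queue_names=["listing", "articles", "processed"])
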
@@ -0,0 +1,33 @@
from __future__ import annotations

from importlib import import_module
from typing import Any, Sequence

_async_queue = import_module("basango.services.crawler.async.queue")
_async_tasks = import_module("basango.services.crawler.async.tasks")
_async_worker = import_module("basango.services.crawler.async.worker")
_async_schemas = import_module("basango.services.crawler.async.schemas")

QueueManager = getattr(_async_queue, "QueueManager")
QueueSettings = getattr(_async_queue, "QueueSettings")
ListingTaskPayload = getattr(_async_schemas, "ListingTaskPayload")
ArticleTaskPayload = getattr(_async_schemas, "ArticleTaskPayload")
ProcessedTaskPayload = getattr(_async_schemas, "ProcessedTaskPayload")
schedule_async_crawl = getattr(_async_tasks, "schedule_async_crawl")
collect_listing = getattr(_async_tasks, "collect_listing")
collect_article = getattr(_async_tasks, "collect_article")
forward_for_processing = getattr(_async_tasks, "forward_for_processing")
start_worker = getattr(_async_worker, "start_worker")

__all__ = [
    "QueueManager",
    "QueueSettings",
    "ListingTaskPayload",
    "ArticleTaskPayload",
    "ProcessedTaskPayload",
    "schedule_async_crawl",
    "collect_listing",
    "collect_article",
    "forward_for_processing",
    "start_worker",
]

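The importlib indirection is what makes this shim necessary: async is a reserved keyword in Python, so the package path cannot appear in an import statement, while import_module takes a plain string and is unaffected:

    # SyntaxError - `async` is a keyword:
    #   from basango.services.crawler.async.queue import QueueManager
    # The shim resolves the modules at runtime instead:
    from basango.services.crawler.async_api import QueueManager, schedule_async_crawl
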
@@ -1,24 +1,27 @@
 import logging
 from abc import ABC, abstractmethod
 from dataclasses import asdict, is_dataclass
-from typing import Optional, Any, Dict, List
+from typing import Optional, Any, Dict, List, Sequence

 from bs4 import BeautifulSoup

 from basango.core.config import CrawlerConfig, ClientConfig
 from basango.domain import DateRange, SourceKind, PageRange
 from basango.domain.exception import ArticleOutOfRange
-from basango.services import HttpClient, DateParser, OpenGraphProvider
+from basango.services import HttpClient, DateParser, OpenGraphProvider, BasePersistor


 class BaseCrawler(ABC):
     def __init__(
-        self, crawler_config: CrawlerConfig, client_config: ClientConfig
+        self,
+        crawler_config: CrawlerConfig,
+        client_config: ClientConfig,
+        persistors: Sequence[BasePersistor] | None = None,
     ) -> None:
         self.config = crawler_config
         self.source = crawler_config.source
         self.client = HttpClient(client_config=client_config)
         self.results: List[Dict[str, Any]] = []
+        self.persistors: list[BasePersistor] = list(persistors) if persistors else []
         self.date_parser = DateParser()
         self.open_graph = OpenGraphProvider()

@@ -49,6 +52,7 @@ class BaseCrawler(ABC):
             metadata_value = asdict(metadata)
         else:
             metadata_value = metadata

         article = {
             "title": title,
             "link": link,
@@ -58,7 +62,7 @@ class BaseCrawler(ABC):
             "timestamp": timestamp,
             "metadata": metadata_value,
         }
         self.results.append(article)
+        self._persist(article)
         logging.info(f"> {title} [saved]")

     @abstractmethod
@@ -85,7 +89,8 @@ class BaseCrawler(ABC):
         logging.info("Crawling completed")
         if notify:
             logging.info("Sending notification about completion")
-            # Implement notification logic here
+            # TODO: Implement notification logic here
+        self._shutdown_persistors()

     @classmethod
     def skip(cls, date_range: DateRange, timestamp: str, title: str, date: str) -> None:
@@ -93,3 +98,25 @@ class BaseCrawler(ABC):
             raise ArticleOutOfRange.create(timestamp, date_range)

         logging.warning(f"> {title} [Skipped {date}]")

+    def _persist(self, article: Dict[str, Any]) -> None:
+        for persistor in self.persistors:
+            try:
+                persistor.persist(article)
+            except Exception as exc:  # noqa: BLE001
+                logging.exception(
+                    "Failed to persist article via %s: %s",
+                    persistor.__class__.__name__,
+                    exc,
+                )
+
+    def _shutdown_persistors(self) -> None:
+        for persistor in self.persistors:
+            try:
+                persistor.close()
+            except Exception as exc:  # noqa: BLE001
+                logging.exception(
+                    "Failed to close persistor %s: %s",
+                    persistor.__class__.__name__,
+                    exc,
+                )

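With the new hook, every article appended to results is also pushed through each configured persistor, and completion closes them; persistence failures are logged per persistor without aborting the crawl. Wiring sketch (the config objects come from the pipeline and are elided; "example-source" is a placeholder):

    from pathlib import Path

    from basango.services import CsvPersistor, JsonPersistor

    persistors = [
        CsvPersistor(data_dir=Path("data"), source_id="example-source"),
        JsonPersistor(data_dir=Path("data"), source_id="example-source"),
    ]
    # crawler = HtmlCrawler(crawler_config, client_config, persistors=persistors)
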
@@ -1,7 +1,7 @@
 import logging
 import re
 from datetime import datetime, timezone
-from typing import Optional, cast, override
+from typing import Optional, cast, override, Sequence
 from urllib.parse import parse_qs, urljoin, urlparse

 from bs4 import BeautifulSoup, Tag
@@ -11,13 +11,17 @@ from basango.core.config.source_config import HtmlSourceConfig
 from basango.domain import DateRange, PageRange, SourceKind
 from basango.domain.exception import ArticleOutOfRange
 from basango.services.crawler.base_crawler import BaseCrawler
+from basango.services import BasePersistor


 class HtmlCrawler(BaseCrawler):
     def __init__(
-        self, crawler_config: CrawlerConfig, client_config: ClientConfig
+        self,
+        crawler_config: CrawlerConfig,
+        client_config: ClientConfig,
+        persistors: Sequence[BasePersistor] | None = None,
     ) -> None:
-        super().__init__(crawler_config, client_config)
+        super().__init__(crawler_config, client_config, persistors=persistors)
         if not self.source or self.source.source_kind != SourceKind.HTML:
             raise ValueError("HtmlCrawler requires a source of kind HTML")

@@ -1,7 +1,7 @@
 import json
 import logging
 from datetime import datetime, timezone
-from typing import Optional, override, cast, Final, Any
+from typing import Optional, override, cast, Final, Any, Sequence

 from bs4 import BeautifulSoup

@@ -9,13 +9,17 @@ from basango.core.config import WordPressSourceConfig, CrawlerConfig, ClientConf
 from basango.domain import SourceKind, PageRange, DateRange
 from basango.domain.exception import ArticleOutOfRange
 from basango.services.crawler.base_crawler import BaseCrawler
+from basango.services import BasePersistor


 class WordpressCrawler(BaseCrawler):
     def __init__(
-        self, crawler_config: CrawlerConfig, client_config: ClientConfig
+        self,
+        crawler_config: CrawlerConfig,
+        client_config: ClientConfig,
+        persistors: Sequence[BasePersistor] | None = None,
     ) -> None:
-        super().__init__(crawler_config, client_config)
+        super().__init__(crawler_config, client_config, persistors=persistors)
         if not self.source or self.source.source_kind != SourceKind.WORDPRESS:
             raise ValueError("WordpressCrawler requires a source of kind WORDPRESS")

@@ -162,9 +166,7 @@ class WordpressCrawler(BaseCrawler):
     def get_last_page(self) -> int:
         return 1

     @staticmethod
+    @override
     def supports() -> SourceKind:
         return SourceKind.WORDPRESS

@@ -1,152 +0,0 @@
import random
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any, Optional, TypeAlias

import httpx

from basango.core.config import ClientConfig
from basango.services.user_agents import UserAgentProvider

HttpHeaders: TypeAlias = dict[str, str] | None
HttpParams: TypeAlias = dict[str, Any] | None
HttpData: TypeAlias = Any | None

TRANSIENT_STATUSES = (429, 500, 502, 503, 504)


@dataclass
class HttpClient:
    client_config: ClientConfig
    user_agent_provider: UserAgentProvider | None = None
    default_headers: HttpHeaders = None

    def _compute_backoff(self, attempt: int) -> float:
        base = min(
            self.client_config.backoff_initial
            * (self.client_config.backoff_multiplier**attempt),
            self.client_config.backoff_max,
        )
        jitter = random.uniform(0, base * 0.25)
        return base + jitter

    def _retry_delay(
        self, attempt: int, response: Optional[httpx.Response] = None
    ) -> float:
        delay = 0.0

        if response is not None and self.client_config.respect_retry_after:
            retry_after = (
                response.headers.get("Retry-After") if response.headers else None
            )
            if retry_after:
                try:
                    delay = max(0.0, float(int(retry_after)))
                except ValueError:
                    try:
                        dt = parsedate_to_datetime(retry_after)
                        if dt.tzinfo is None:
                            dt = dt.replace(tzinfo=timezone.utc)
                        now = datetime.now(timezone.utc)
                        delay = max(0.0, (dt - now).total_seconds())
                    except Exception:  # noqa: BLE001
                        pass

        if delay == 0.0:
            delay = self._compute_backoff(attempt)

        return delay

    def __post_init__(self) -> None:
        if self.user_agent_provider is not None:
            user_agent = self.user_agent_provider.get()
            self._user_agent = (
                user_agent if user_agent else self.client_config.user_agent
            )
        else:
            provider = UserAgentProvider(
                rotate=self.client_config.rotate,
                fallback=self.client_config.user_agent,
            )
            user_agent = provider.get()
            self._user_agent = (
                user_agent if user_agent else self.client_config.user_agent
            )

        headers = {"User-Agent": self._user_agent}

        if self.default_headers:
            headers.update(self.default_headers)

        self._client = httpx.Client(
            follow_redirects=self.client_config.follow_redirects,
            max_redirects=5,
            verify=self.client_config.verify_ssl,
            timeout=self.client_config.timeout,
            headers=headers,
        )

    # Context manager support -------------------------------------------------
    def __enter__(self) -> "HttpClient":  # noqa: D401
        return self

    def __exit__(self, exc_type, exc, tb) -> None:  # noqa: D401
        self.close()

    def close(self) -> None:
        try:
            self._client.close()
        except Exception:  # noqa: BLE001
            pass

    # Core request with retries ----------------------------------------------
    def _request(
        self,
        method: str,
        url: str,
        *,
        headers: HttpHeaders = None,
        params: HttpParams = None,
        data: Any | None = None,
        json: Any | None = None,
    ) -> httpx.Response:
        attempt = 0
        while True:
            try:
                response = self._client.request(
                    method, url, headers=headers, params=params, data=data, json=json
                )
                if (
                    response.status_code in TRANSIENT_STATUSES
                ) and attempt < self.client_config.max_retries:
                    time.sleep(self._retry_delay(attempt, response))
                    attempt += 1
                    continue
                response.raise_for_status()
                return response
            except httpx.HTTPStatusError as e:
                status = e.response.status_code if e.response else 0
                if (
                    status in TRANSIENT_STATUSES
                ) and attempt < self.client_config.max_retries:
                    time.sleep(self._retry_delay(attempt, e.response))
                    attempt += 1
                    continue
                raise
            except httpx.RequestError:
                if attempt < self.client_config.max_retries:
                    time.sleep(self._compute_backoff(attempt))
                    attempt += 1
                    continue
                raise

    # Public helpers ----------------------------------------------------------
    def get(self, url: str) -> httpx.Response:
        return self._request("GET", url)

    def post(
        self, url: str, data: HttpData = None, json: HttpData = None
    ) -> httpx.Response:
        return self._request("POST", url, data=data, json=json)

@@ -0,0 +1,9 @@
from .base_http_client import BaseHttpClient
from .sync_http_client import SyncHttpClient
from .async_http_client import AsyncHttpClient

__all__ = [
    "BaseHttpClient",
    "SyncHttpClient",
    "AsyncHttpClient",
]

@@ -0,0 +1,123 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field

import httpx

from .base_http_client import (
    BaseHttpClient,
    HttpData,
    HttpHeaders,
    HttpParams,
    TRANSIENT_STATUSES,
)


@dataclass
class AsyncHttpClient(BaseHttpClient):
    _client: httpx.AsyncClient = field(init=False, repr=False)

    def __post_init__(self) -> None:
        super().__post_init__()
        self._client = httpx.AsyncClient(
            follow_redirects=self.client_config.follow_redirects,
            max_redirects=5,
            verify=self.client_config.verify_ssl,
            timeout=self.client_config.timeout,
            headers=dict(self._headers),
        )

    async def __aenter__(self) -> "AsyncHttpClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    def close(self) -> None:
        if self._client.is_closed:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:  # no running loop
            asyncio.run(self.aclose())
        else:
            loop.create_task(self.aclose())

    async def aclose(self) -> None:
        try:
            await self._client.aclose()
        except Exception:  # noqa: BLE001
            pass

    async def _request(
        self,
        method: str,
        url: str,
        *,
        headers: HttpHeaders = None,
        params: HttpParams = None,
        data: HttpData = None,
        json: HttpData = None,
    ) -> httpx.Response:
        attempt = 0
        while True:
            try:
                response = await self._client.request(
                    method,
                    url,
                    headers=self._build_headers(headers),
                    params=params,
                    data=data,
                    json=json,
                )
                if (
                    response.status_code in TRANSIENT_STATUSES
                ) and attempt < self.client_config.max_retries:
                    await asyncio.sleep(self._retry_delay(attempt, response))
                    attempt += 1
                    continue
                response.raise_for_status()
                return response
            except httpx.HTTPStatusError as exc:
                status = exc.response.status_code if exc.response else 0
                if (
                    status in TRANSIENT_STATUSES
                ) and attempt < self.client_config.max_retries:
                    await asyncio.sleep(self._retry_delay(attempt, exc.response))
                    attempt += 1
                    continue
                raise
            except httpx.RequestError:
                if attempt < self.client_config.max_retries:
                    await asyncio.sleep(self._compute_backoff(attempt))
                    attempt += 1
                    continue
                raise

    async def get(
        self,
        url: str,
        *,
        headers: HttpHeaders = None,
        params: HttpParams = None,
    ) -> httpx.Response:
        return await self._request("GET", url, headers=headers, params=params)

    async def post(
        self,
        url: str,
        *,
        headers: HttpHeaders = None,
        params: HttpParams = None,
        data: HttpData = None,
        json: HttpData = None,
    ) -> httpx.Response:
        return await self._request(
            "POST",
            url,
            headers=headers,
            params=params,
            data=data,
            json=json,
        )

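Usage sketch for the new async client (assumes ClientConfig defaults suffice; example.com is a placeholder). The async context manager awaits aclose(), which is the preferred shutdown path; close() only schedules it when a loop is already running:

    import asyncio

    from basango.core.config import ClientConfig
    from basango.services.http_client import AsyncHttpClient

    async def main() -> None:
        async with AsyncHttpClient(client_config=ClientConfig()) as client:
            response = await client.get("https://example.com")
            print(response.status_code)

    asyncio.run(main())
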
@@ -0,0 +1,90 @@
from __future__ import annotations

import random
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any, Optional, TypeAlias

import httpx

from basango.core.config import ClientConfig
from basango.services.user_agents import UserAgentProvider

HttpHeaders: TypeAlias = dict[str, str] | None
HttpParams: TypeAlias = dict[str, Any] | None
HttpData: TypeAlias = Any | None

TRANSIENT_STATUSES = (429, 500, 502, 503, 504)


@dataclass
class BaseHttpClient(ABC):
    client_config: ClientConfig
    user_agent_provider: UserAgentProvider | None = None
    default_headers: HttpHeaders = None
    _user_agent: str = field(init=False, repr=False)
    _headers: dict[str, str] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        provider = self.user_agent_provider or UserAgentProvider(
            rotate=self.client_config.rotate,
            fallback=self.client_config.user_agent,
        )
        user_agent = provider.get()
        self._user_agent = user_agent if user_agent else self.client_config.user_agent

        headers = {"User-Agent": self._user_agent}
        if self.default_headers:
            headers.update(self.default_headers)
        self._headers = headers

    def _compute_backoff(self, attempt: int) -> float:
        base = min(
            self.client_config.backoff_initial
            * (self.client_config.backoff_multiplier**attempt),
            self.client_config.backoff_max,
        )
        jitter = random.uniform(0, base * 0.25)
        return base + jitter

    def _retry_delay(
        self, attempt: int, response: Optional[httpx.Response] = None
    ) -> float:
        delay = 0.0
        if response is not None and self.client_config.respect_retry_after:
            retry_after = (
                response.headers.get("Retry-After") if response.headers else None
            )
            if retry_after:
                delay = self._parse_retry_after(retry_after)

        if delay == 0.0:
            delay = self._compute_backoff(attempt)
        return delay

    @staticmethod
    def _parse_retry_after(header_value: str) -> float:
        try:
            return max(0.0, float(int(header_value)))
        except (TypeError, ValueError):
            try:
                dt = parsedate_to_datetime(header_value)
                if dt.tzinfo is None:
                    dt = dt.replace(tzinfo=timezone.utc)
                now = datetime.now(timezone.utc)
                return max(0.0, (dt - now).total_seconds())
            except Exception:  # noqa: BLE001
                return 0.0

    def _build_headers(self, headers: HttpHeaders = None) -> dict[str, str]:
        merged = dict(self._headers)
        if headers:
            merged.update(headers)
        return merged

    @abstractmethod
    def close(self) -> None:  # pragma: no cover - enforced by subclasses
        """Close the underlying HTTPX client."""

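The retry schedule both subclasses share is capped exponential backoff with up to 25% jitter; a parseable Retry-After header takes precedence when respect_retry_after is set. A worked illustration with assumed config values (backoff_initial=0.5, backoff_multiplier=2.0, backoff_max=10.0):

    # Same formula as _compute_backoff, constants assumed for illustration.
    for attempt in range(6):
        base = min(0.5 * (2.0 ** attempt), 10.0)
        print(f"attempt {attempt}: sleep base {base}s + jitter up to {base * 0.25}s")
    # bases: 0.5, 1.0, 2.0, 4.0, 8.0, 10.0 (capped at backoff_max)
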
@@ -0,0 +1,109 @@
from __future__ import annotations

import time
from dataclasses import dataclass, field

import httpx

from .base_http_client import (
    BaseHttpClient,
    HttpData,
    HttpHeaders,
    HttpParams,
    TRANSIENT_STATUSES,
)


@dataclass
class SyncHttpClient(BaseHttpClient):
    _client: httpx.Client = field(init=False, repr=False)

    def __post_init__(self) -> None:
        super().__post_init__()
        self._client = httpx.Client(
            follow_redirects=self.client_config.follow_redirects,
            max_redirects=5,
            verify=self.client_config.verify_ssl,
            timeout=self.client_config.timeout,
            headers=dict(self._headers),
        )

    def __enter__(self) -> "SyncHttpClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self) -> None:
        try:
            self._client.close()
        except Exception:  # noqa: BLE001
            pass

    def _request(
        self,
        method: str,
        url: str,
        *,
        headers: HttpHeaders = None,
        params: HttpParams = None,
        data: HttpData = None,
        json: HttpData = None,
    ) -> httpx.Response:
        attempt = 0
        while True:
            try:
                response = self._client.request(
                    method,
                    url,
                    headers=self._build_headers(headers),
                    params=params,
                    data=data,
                    json=json,
                )
                if (
                    response.status_code in TRANSIENT_STATUSES
                ) and attempt < self.client_config.max_retries:
                    time.sleep(self._retry_delay(attempt, response))
                    attempt += 1
                    continue
                response.raise_for_status()
                return response
            except httpx.HTTPStatusError as exc:
                status = exc.response.status_code if exc.response else 0
                if (
                    status in TRANSIENT_STATUSES
                ) and attempt < self.client_config.max_retries:
                    time.sleep(self._retry_delay(attempt, exc.response))
                    attempt += 1
                    continue
                raise
            except httpx.RequestError:
                if attempt < self.client_config.max_retries:
                    time.sleep(self._compute_backoff(attempt))
                    attempt += 1
                    continue
                raise

    def get(
        self, url: str, *, headers: HttpHeaders = None, params: HttpParams = None
    ) -> httpx.Response:
        return self._request("GET", url, headers=headers, params=params)

    def post(
        self,
        url: str,
        *,
        headers: HttpHeaders = None,
        params: HttpParams = None,
        data: HttpData = None,
        json: HttpData = None,
    ) -> httpx.Response:
        return self._request(
            "POST",
            url,
            headers=headers,
            params=params,
            data=data,
            json=json,
        )

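Since basango.services now aliases HttpClient = SyncHttpClient, existing call sites keep working unchanged; new code can use the split API directly (sketch; example.com is a placeholder):

    from basango.core.config import ClientConfig
    from basango.services.http_client import SyncHttpClient

    # Per-request headers are merged over the client defaults by _build_headers.
    with SyncHttpClient(client_config=ClientConfig()) as client:
        response = client.get("https://example.com", headers={"Accept": "text/html"})
        print(response.status_code)
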
@@ -5,7 +5,7 @@ from typing import Optional
 import trafilatura

 from basango.core.config import ClientConfig
-from basango.services.http_client import HttpClient
+from basango.services.http_client import SyncHttpClient
 from basango.services.user_agents import UserAgentProvider


@@ -22,7 +22,7 @@ class OpenGraphProvider:
         self, user_agent_provider: UserAgentProvider = UserAgentProvider(rotate=False)
     ) -> None:
         self._user_agent = user_agent_provider.og()
-        self._http_client = HttpClient(
+        self._http_client = SyncHttpClient(
             client_config=ClientConfig(),
             default_headers={"User-Agent": self._user_agent},
         )

@@ -0,0 +1,11 @@
from .base_persistor import BasePersistor
from .csv_persistor import CsvPersistor
from .json_persistor import JsonPersistor
from .api_persistor import ApiPersistor

__all__ = [
    "BasePersistor",
    "CsvPersistor",
    "JsonPersistor",
    "ApiPersistor",
]

@@ -0,0 +1,35 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Mapping

from basango.services.http_client import SyncHttpClient

from .base_persistor import BasePersistor


@dataclass
class ApiPersistor(BasePersistor):
    endpoint: str
    http_client: SyncHttpClient
    headers: dict[str, str] | None = None
    raise_for_status: bool = True

    def persist(self, article: Mapping[str, Any]) -> None:
        try:
            response = self.http_client.post(
                self.endpoint,
                json=article,
                headers=self.headers,
            )
            if self.raise_for_status:
                response.raise_for_status()
        except Exception as exc:  # noqa: BLE001
            logging.exception(
                "[ApiPersistor] Failed to persist article at %s: %s",
                self.endpoint,
                exc,
            )
            if self.raise_for_status:
                raise

@@ -0,0 +1,16 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Mapping, Any


class BasePersistor(ABC):
    """Abstract interface for article persistence backends."""

    @abstractmethod
    def persist(self, article: Mapping[str, Any]) -> None:
        """Persist a single article payload."""

    def close(self) -> None:  # pragma: no cover - optional override
        """Hook for subclasses that need explicit shutdown."""
        return None

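The contract is deliberately small: implement persist, optionally override close for cleanup. A hypothetical third backend as a sketch (SqlitePersistor is illustrative, not part of this commit):

    import sqlite3
    from typing import Any, Mapping

    from basango.services import BasePersistor

    class SqlitePersistor(BasePersistor):
        """Illustrative persistor writing title/link pairs to SQLite."""

        def __init__(self, db_path: str) -> None:
            self._conn = sqlite3.connect(db_path)
            self._conn.execute(
                "CREATE TABLE IF NOT EXISTS articles (title TEXT, link TEXT)"
            )

        def persist(self, article: Mapping[str, Any]) -> None:
            self._conn.execute(
                "INSERT INTO articles VALUES (?, ?)",
                (article.get("title"), article.get("link")),
            )
            self._conn.commit()

        def close(self) -> None:  # invoked via BaseCrawler._shutdown_persistors()
            self._conn.close()
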
@@ -0,0 +1,69 @@
from __future__ import annotations

import csv
import json
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock
from typing import Any, Mapping, Sequence

from .base_persistor import BasePersistor


DEFAULT_FIELDS = (
    "title",
    "link",
    "body",
    "categories",
    "source",
    "timestamp",
    "metadata",
)


@dataclass
class CsvPersistor(BasePersistor):
    data_dir: Path
    source_id: str
    fieldnames: Sequence[str] = DEFAULT_FIELDS
    encoding: str = "utf-8"
    _file_path: Path = field(init=False, repr=False)
    _lock: Lock = field(default_factory=Lock, init=False, repr=False)
    _header_written: bool = field(default=False, init=False, repr=False)

    def __post_init__(self) -> None:
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self._file_path = self.data_dir / f"{self.source_id}.csv"
        if self._file_path.exists() and self._file_path.stat().st_size > 0:
            self._header_written = True

    def persist(self, article: Mapping[str, Any]) -> None:
        record = self._serialise(article)
        with self._lock:
            needs_header = not self._header_written or not self._file_path.exists()
            with self._file_path.open(
                "a", newline="", encoding=self.encoding
            ) as handle:
                writer = csv.DictWriter(handle, fieldnames=self.fieldnames)
                if needs_header:
                    writer.writeheader()
                    self._header_written = True
                writer.writerow(record)

    def _serialise(self, article: Mapping[str, Any]) -> dict[str, Any]:
        categories = article.get("categories")
        if isinstance(categories, (list, tuple)):
            serialised_categories = ";".join(str(item) for item in categories)
        else:
            serialised_categories = categories

        metadata = article.get("metadata")
        if metadata is None or isinstance(metadata, str):
            serialised_metadata = metadata
        else:
            serialised_metadata = json.dumps(metadata, ensure_ascii=False)

        record = {field: article.get(field) for field in self.fieldnames}
        record["categories"] = serialised_categories
        record["metadata"] = serialised_metadata
        return record

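Usage sketch: one CSV per source id, header written once, lists joined with ";" and non-string metadata JSON-encoded by _serialise ("example-source" is a placeholder id):

    from pathlib import Path

    from basango.services import CsvPersistor

    persistor = CsvPersistor(data_dir=Path("data"), source_id="example-source")
    persistor.persist(
        {
            "title": "Hello",
            "link": "https://example.com/hello",
            "categories": ["news", "tech"],  # stored as "news;tech"
            "metadata": {"lang": "en"},      # stored as JSON text
        }
    )
    # Appends to data/example-source.csv, creating the header on first use.
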
@@ -0,0 +1,30 @@
from __future__ import annotations

import json
from dataclasses import dataclass, field
from pathlib import Path
from threading import Lock
from typing import Any, Mapping

from .base_persistor import BasePersistor


@dataclass
class JsonPersistor(BasePersistor):
    data_dir: Path
    source_id: str
    suffix: str = ".jsonl"
    encoding: str = "utf-8"
    _file_path: Path = field(init=False, repr=False)
    _lock: Lock = field(default_factory=Lock, init=False, repr=False)

    def __post_init__(self) -> None:
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self._file_path = self.data_dir / f"{self.source_id}{self.suffix}"

    def persist(self, article: Mapping[str, Any]) -> None:
        payload = json.dumps(article, ensure_ascii=False)
        with self._lock:
            with self._file_path.open("a", encoding=self.encoding) as handle:
                handle.write(payload)
                handle.write("\n")