diff --git a/astrbot/core/collection/__init__.py b/astrbot/core/collection/__init__.py new file mode 100644 index 000000000..b2167ef3a --- /dev/null +++ b/astrbot/core/collection/__init__.py @@ -0,0 +1,9 @@ +"""Plugin collection module.""" + +from .models import CollectionMetadata, CollectionPlugin, PluginCollection + +__all__ = [ + "CollectionMetadata", + "CollectionPlugin", + "PluginCollection", +] diff --git a/astrbot/core/collection/compatibility.py b/astrbot/core/collection/compatibility.py new file mode 100644 index 000000000..ff64a06ba --- /dev/null +++ b/astrbot/core/collection/compatibility.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from astrbot.core import logger, sp +from astrbot.core.star.star_handler import star_handlers_registry + + +@dataclass(slots=True) +class PriorityApplyResult: + persisted: bool + applied_in_memory: bool + + +class PriorityCompatibility: + """Compatibility layer for handler priority overrides (PR4716).""" + + @staticmethod + async def get_priority_overrides() -> dict[str, int]: + try: + overrides = await sp.global_get("handler_priority_overrides", {}) + if isinstance(overrides, dict) and overrides: + return {str(k): int(v) for k, v in overrides.items()} + except Exception as e: + logger.debug(f"Failed to read handler_priority_overrides from sp: {e!s}") + + overrides: dict[str, int] = {} + for name, handler in star_handlers_registry.star_handlers_map.items(): + try: + priority = int(handler.extras_configs.get("priority", 0) or 0) + except Exception: + priority = 0 + if priority != 0: + overrides[name] = priority + return overrides + + @staticmethod + async def apply_priority_overrides( + overrides: dict[str, int], + ) -> PriorityApplyResult: + normalized: dict[str, int] = { + str(k): int(v) for k, v in (overrides or {}).items() + } + + try: + await sp.global_put("handler_priority_overrides", normalized) + return PriorityApplyResult(persisted=True, applied_in_memory=False) + except Exception as e: + logger.debug(f"Failed to write handler_priority_overrides to sp: {e!s}") + + for name, priority in normalized.items(): + handler = star_handlers_registry.star_handlers_map.get(name) + if handler is None: + continue + handler.extras_configs["priority"] = int(priority) + + # Best-effort compatibility: star_handlers_registry internals may change upstream. + try: + handlers = getattr(star_handlers_registry, "_handlers", None) + if isinstance(handlers, list): + handlers.sort( + key=lambda h: -int(h.extras_configs.get("priority", 0) or 0), + ) + except Exception as e: + logger.debug( + f"Failed to sort handler registry after overrides (best effort): {e!s}" + ) + + return PriorityApplyResult(persisted=False, applied_in_memory=True) + + @staticmethod + async def is_pr4716_available() -> bool: + try: + await sp.global_get("handler_priority_overrides", {}) + except Exception: + return False + + try: + _ = star_handlers_registry.star_handlers_map + except Exception: + return False + + return True + + +class ConflictDetectionCompatibility: + """Compatibility layer for conflict detection (PR4451).""" + + @staticmethod + async def check_conflicts(plugins: list[str]) -> dict[str, Any] | None: + try: + from astrbot.core.star.conflict_detection import ( + detect_conflicts, # type: ignore + ) + + return await detect_conflicts(plugins) + except ImportError: + return None + except Exception as e: + logger.warning(f"Conflict detection failed: {e!s}") + return None + + @staticmethod + def is_conflict_detection_available() -> bool: + try: + from astrbot.core.star.conflict_detection import ( + detect_conflicts, # type: ignore + ) + + _ = detect_conflicts + return True + except ImportError: + return False diff --git a/astrbot/core/collection/exporter.py b/astrbot/core/collection/exporter.py new file mode 100644 index 000000000..02c023f75 --- /dev/null +++ b/astrbot/core/collection/exporter.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any + +from astrbot.core import logger +from astrbot.core.config.default import VERSION +from astrbot.core.star.star_manager import PluginManager + +from .compatibility import PriorityCompatibility +from .models import CollectionMetadata, CollectionPlugin, PluginCollection +from .sensitive_filter import SensitiveFilter + + +@dataclass(slots=True) +class ExportOptions: + name: str + description: str = "" + author: str = "" + version: str = "1.0.0" + include_configs: bool = True + include_priority: bool = True + exclude_plugins: list[str] | None = None + + +class CollectionExporter: + def __init__(self, plugin_manager: PluginManager) -> None: + self.plugin_manager = plugin_manager + + async def export(self, options: ExportOptions) -> dict[str, Any]: + exclude = set(options.exclude_plugins or []) + + plugins: list[CollectionPlugin] = [] + plugin_configs: dict[str, Any] | None = {} if options.include_configs else None + + for plugin in self.plugin_manager.context.get_all_stars(): + if plugin.name in exclude: + continue + if not plugin.repo: + continue + + if not plugin.name: + continue + + plugins.append( + CollectionPlugin( + name=str(plugin.name), + repo=plugin.repo, + exported_version=plugin.version or "", + display_name=plugin.display_name, + ), + ) + + if options.include_configs and plugin_configs is not None: + cfg = getattr(plugin, "config", None) + if cfg is None: + continue + try: + cfg_dict = dict(cfg) + except Exception as e: + # Log the exception to help diagnose config conversion issues + logger.warning( + f"Failed to convert config for plugin {plugin.name}: {e}" + ) + cfg_dict = {} + + # Some plugins may have a missing/None name in edge cases; avoid using None as dict key. + if not plugin.name: + continue + plugin_configs[str(plugin.name)] = SensitiveFilter.filter_data(cfg_dict) + + created_at = datetime.now(timezone.utc).isoformat() + metadata = CollectionMetadata( + name=options.name, + description=options.description, + author=options.author, + version=options.version, + created_at=created_at, + astrbot_version=VERSION, + plugin_count=len(plugins), + ) + + handler_priority_overrides: dict[str, int] | None = None + if options.include_priority: + handler_priority_overrides = ( + await PriorityCompatibility.get_priority_overrides() + ) + + collection = PluginCollection( + schema_version=PluginCollection.SCHEMA_VERSION, + metadata=metadata, + plugins=plugins, + plugin_configs=plugin_configs, + handler_priority_overrides=handler_priority_overrides, + ) + + payload = collection.to_dict() + PluginCollection.validate_dict(payload) + return payload diff --git a/astrbot/core/collection/importer.py b/astrbot/core/collection/importer.py new file mode 100644 index 000000000..bd6ce6fee --- /dev/null +++ b/astrbot/core/collection/importer.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Any + +from astrbot.core import logger +from astrbot.core.star.star_manager import PluginManager + +from .compatibility import ConflictDetectionCompatibility, PriorityCompatibility +from .models import ImportOptions, PluginCollection + + +class CollectionImporter: + def __init__(self, plugin_manager: PluginManager) -> None: + self.plugin_manager = plugin_manager + + async def preview( + self, collection: PluginCollection, *, import_mode: str + ) -> dict[str, Any]: + installed = self._get_installed_names() + + plugins_to_install = [] + plugins_to_skip = [] + for p in collection.plugins: + if p.name in installed: + plugins_to_skip.append( + {"name": p.name, "repo": p.repo, "status": "installed"} + ) + else: + plugins_to_install.append( + { + "name": p.name, + "repo": p.repo, + "exported_version": p.exported_version, + "status": "not_installed", + }, + ) + + plugins_to_uninstall: list[dict[str, Any]] = [] + if import_mode == "clean": + keep = {p.name for p in collection.plugins} + for p in self.plugin_manager.context.get_all_stars(): + if p.reserved: + continue + if p.name in keep: + continue + plugins_to_uninstall.append({"name": p.name}) + + configs_count = 0 + if isinstance(collection.plugin_configs, dict): + configs_count = len(collection.plugin_configs) + + return { + "metadata": collection.metadata.to_dict(), + "plugins_to_install": plugins_to_install, + "plugins_to_skip": plugins_to_skip, + "plugins_to_uninstall": plugins_to_uninstall, + "configs_count": configs_count, + "has_priority_overrides": bool(collection.handler_priority_overrides), + "conflict_detection_available": ConflictDetectionCompatibility.is_conflict_detection_available(), + } + + def _get_installed_names(self) -> set[str]: + return {p.name for p in self.plugin_manager.context.get_all_stars() if p.name} + + async def _run_install_tasks( + self, tasks: list[Awaitable[dict[str, Any]]] + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + installed_results: list[dict[str, Any]] = [] + failed_results: list[dict[str, Any]] = [] + + raw = await asyncio.gather(*tasks, return_exceptions=True) + for r in raw: + if isinstance(r, asyncio.CancelledError): + raise r + if isinstance(r, BaseException): + failed_results.append( + {"name": "unknown", "status": "error", "message": str(r)} + ) + elif r.get("status") == "ok": + installed_results.append(r) + else: + failed_results.append(r) + + return installed_results, failed_results + + def _build_install_task_factory( + self, *, proxy: str + ) -> Callable[[str, str], Awaitable[dict[str, Any]]]: + sem = asyncio.Semaphore(3) + + async def _install_one(name: str, repo: str) -> dict[str, Any]: + async with sem: + try: + await self.plugin_manager.install_plugin(repo, proxy) + return { + "name": name, + "status": "ok", + "message": "installed", + "exported_version_note": ( + "Recorded for reference only; collection import does not lock plugin version." + ), + } + except Exception as e: + return {"name": name, "status": "error", "message": str(e)} + + return _install_one + + async def _uninstall_for_clean_mode( + self, *, keep_set: set[str], import_mode: str + ) -> dict[str, Any]: + if import_mode != "clean": + return {"ok": [], "failed": [], "skipped": []} + + ok: list[dict[str, Any]] = [] + failed: list[dict[str, Any]] = [] + skipped: list[dict[str, Any]] = [] + + for p in list(self.plugin_manager.context.get_all_stars()): + if p.reserved: + skipped.append( + {"name": p.name, "status": "skipped", "reason": "reserved"} + ) + continue + if not p.name: + skipped.append( + {"name": p.name, "status": "skipped", "reason": "empty_name"} + ) + continue + if p.name in keep_set: + skipped.append({"name": p.name, "status": "skipped", "reason": "kept"}) + continue + + try: + await self.plugin_manager.uninstall_plugin(p.name) + ok.append({"name": p.name, "status": "ok", "message": "uninstalled"}) + except Exception as e: + logger.error(f"Uninstall plugin failed ({p.name}): {e!s}") + failed.append({"name": p.name, "status": "error", "message": str(e)}) + + return {"ok": ok, "failed": failed, "skipped": skipped} + + async def _install_plugins( + self, + collection: PluginCollection, + *, + options: ImportOptions, + installed_before: set[str], + ) -> dict[str, Any]: + ok: list[dict[str, Any]] = [] + failed: list[dict[str, Any]] = [] + skipped: list[dict[str, Any]] = [] + + install_one = self._build_install_task_factory(proxy=options.proxy) + tasks = [] + for p in collection.plugins: + if not p.name: + skipped.append( + {"name": p.name, "status": "skipped", "reason": "empty_name"} + ) + continue + if options.import_mode == "add" and p.name in installed_before: + skipped.append( + {"name": p.name, "status": "skipped", "reason": "already_installed"} + ) + continue + tasks.append(install_one(p.name, p.repo)) + + installed_results, failed_results = await self._run_install_tasks(tasks) + ok.extend(installed_results) + failed.extend(failed_results) + + return {"ok": ok, "failed": failed, "skipped": skipped} + + async def _apply_configs( + self, + collection: PluginCollection, + *, + options: ImportOptions, + installed_before: set[str], + ) -> dict[str, Any]: + if not options.apply_configs or not isinstance(collection.plugin_configs, dict): + return {"ok": [], "failed": [], "skipped": [], "reload_queue": []} + + ok: list[dict[str, Any]] = [] + failed: list[dict[str, Any]] = [] + skipped: list[dict[str, Any]] = [] + reload_queue: list[str] = [] + + for plugin_name, cfg in collection.plugin_configs.items(): + if not isinstance(cfg, dict): + skipped.append( + { + "name": str(plugin_name), + "status": "skipped", + "reason": "invalid_config", + } + ) + continue + + md = self.plugin_manager.context.get_registered_star(plugin_name) + config = getattr(md, "config", None) if md is not None else None + if config is None: + skipped.append( + {"name": plugin_name, "status": "skipped", "reason": "no_config"} + ) + continue + + is_existing = plugin_name in installed_before + if ( + options.import_mode == "add" + and is_existing + and not options.overwrite_existing_configs + ): + skipped.append( + { + "name": plugin_name, + "status": "skipped", + "reason": "existing_config_not_overwritten", + } + ) + continue + + try: + current_cfg = dict(config) + except Exception: + current_cfg = {} + + merged_cfg = {**current_cfg, **cfg} + + try: + config.save_config(merged_cfg) + ok.append( + {"name": plugin_name, "status": "ok", "message": "config_applied"} + ) + if plugin_name not in reload_queue: + reload_queue.append(plugin_name) + except Exception as e: + logger.error(f"Apply config failed ({plugin_name}): {e!s}") + failed.append( + {"name": plugin_name, "status": "error", "message": str(e)} + ) + + return { + "ok": ok, + "failed": failed, + "skipped": skipped, + "reload_queue": reload_queue, + } + + async def _reload_plugins(self, *, reload_queue: list[str]) -> dict[str, Any]: + ok: list[dict[str, Any]] = [] + failed: list[dict[str, Any]] = [] + + for plugin_name in reload_queue: + try: + await self.plugin_manager.reload(plugin_name) + ok.append({"name": plugin_name, "status": "ok", "message": "reloaded"}) + except Exception as e: + logger.error(f"Reload plugin failed ({plugin_name}): {e!s}") + failed.append( + {"name": plugin_name, "status": "error", "message": str(e)} + ) + + return {"ok": ok, "failed": failed} + + async def _apply_priority_overrides( + self, collection: PluginCollection, *, options: ImportOptions + ) -> dict[str, Any]: + if not options.apply_priority or not isinstance( + collection.handler_priority_overrides, dict + ): + return { + "ok": [], + "failed": [], + "skipped": [ + { + "name": "priority", + "status": "skipped", + "reason": "disabled_or_missing", + } + ], + "priority_persisted": False, + "priority_applied_in_memory": False, + "priority_note": "", + } + + apply_result = await PriorityCompatibility.apply_priority_overrides( + collection.handler_priority_overrides, + ) + priority_note = "" + if apply_result.applied_in_memory and not apply_result.persisted: + priority_note = "Priority overrides could not be persisted; applied in memory only for this process." + + return { + "ok": [{"name": "priority", "status": "ok", "message": "applied"}], + "failed": [], + "skipped": [], + "priority_persisted": apply_result.persisted, + "priority_applied_in_memory": apply_result.applied_in_memory, + "priority_note": priority_note, + } + + async def import_collection( + self, collection: PluginCollection, options: ImportOptions + ) -> dict[str, Any]: + if options.import_mode not in {"add", "clean"}: + raise ValueError("import_mode must be 'add' or 'clean'") + + installed_before = self._get_installed_names() + + conflict_report = await ConflictDetectionCompatibility.check_conflicts( + [p.name for p in collection.plugins], + ) + + keep = {p.name for p in collection.plugins if p.name} + uninstall_result = await self._uninstall_for_clean_mode( + keep_set=keep, + import_mode=options.import_mode, + ) + + install_result = await self._install_plugins( + collection, + options=options, + installed_before=installed_before, + ) + + config_result = await self._apply_configs( + collection, + options=options, + installed_before=installed_before, + ) + + reload_queue = list(config_result.get("reload_queue") or []) + reload_result = await self._reload_plugins(reload_queue=reload_queue) + + priority_result = await self._apply_priority_overrides( + collection, + options=options, + ) + + uninstalled = list(uninstall_result.get("ok") or []) + uninstall_failed = [ + i.get("name") + for i in (uninstall_result.get("failed") or []) + if i.get("name") + ] + + configs_applied = len(config_result.get("ok") or []) + configs_failed = [ + {"name": i.get("name"), "message": i.get("message")} + for i in (config_result.get("failed") or []) + if i.get("name") + ] + + result: dict[str, Any] = { + "installed": list(install_result.get("ok") or []), + "failed": list(install_result.get("failed") or []), + "skipped": list(install_result.get("skipped") or []), + "uninstalled": uninstalled, + "uninstall_failed": uninstall_failed, + "configs_applied": configs_applied, + "configs_failed": configs_failed, + "reloaded": reload_queue, + "reload_failed": [ + {"name": i.get("name"), "message": i.get("message")} + for i in (reload_result.get("failed") or []) + if i.get("name") + ], + "priority_persisted": bool(priority_result.get("priority_persisted")), + "priority_applied_in_memory": bool( + priority_result.get("priority_applied_in_memory") + ), + } + + priority_note = str(priority_result.get("priority_note") or "") + if priority_note: + result["priority_note"] = priority_note + if conflict_report is not None: + result["conflicts"] = conflict_report + + return result diff --git a/astrbot/core/collection/models.py b/astrbot/core/collection/models.py new file mode 100644 index 000000000..4f89c22c5 --- /dev/null +++ b/astrbot/core/collection/models.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, ClassVar + +from jsonschema import Draft7Validator + + +@dataclass(slots=True) +class ImportOptions: + import_mode: str = "add" # add | clean + apply_configs: bool = True + overwrite_existing_configs: bool = False + apply_priority: bool = True + proxy: str = "" + + +class CollectionValidationError(ValueError): + """Raised when a plugin collection payload is invalid.""" + + +@dataclass(slots=True) +class CollectionPlugin: + name: str + repo: str + exported_version: str + display_name: str | None = None + + def to_dict(self) -> dict[str, Any]: + data: dict[str, Any] = { + "name": self.name, + "repo": self.repo, + "exported_version": self.exported_version, + } + if self.display_name is not None: + data["display_name"] = self.display_name + return data + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> CollectionPlugin: + exported_version = data.get("exported_version") + if exported_version is None: + # Backward compatibility for schema 1.0 payloads. + exported_version = data.get("version", "") + return cls( + name=str(data.get("name", "")), + repo=str(data.get("repo", "")), + exported_version=str(exported_version or ""), + display_name=( + str(data["display_name"]) + if data.get("display_name") is not None + else None + ), + ) + + +@dataclass(slots=True) +class CollectionMetadata: + name: str + description: str + author: str + version: str + created_at: str + astrbot_version: str + plugin_count: int + + def to_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "description": self.description, + "author": self.author, + "version": self.version, + "created_at": self.created_at, + "astrbot_version": self.astrbot_version, + "plugin_count": self.plugin_count, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> CollectionMetadata: + return cls( + name=str(data.get("name", "")), + description=str(data.get("description", "")), + author=str(data.get("author", "")), + version=str(data.get("version", "")), + created_at=str(data.get("created_at", "")), + astrbot_version=str(data.get("astrbot_version", "")), + plugin_count=int(data.get("plugin_count", 0) or 0), + ) + + +@dataclass(slots=True) +class PluginCollection: + schema_version: str + metadata: CollectionMetadata + plugins: list[CollectionPlugin] + plugin_configs: dict[str, Any] | None = None + handler_priority_overrides: dict[str, int] | None = None + + SCHEMA_VERSION: ClassVar[str] = "1.0" + + JSON_SCHEMA: ClassVar[dict[str, Any]] = { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "required": ["schema_version", "metadata", "plugins"], + "additionalProperties": False, + "properties": { + "schema_version": {"type": "string", "enum": ["1.0"]}, + "metadata": { + "type": "object", + "required": [ + "name", + "description", + "author", + "version", + "created_at", + "astrbot_version", + "plugin_count", + ], + "additionalProperties": False, + "properties": { + "name": {"type": "string", "maxLength": 100}, + "description": {"type": "string", "maxLength": 500}, + "author": {"type": "string"}, + "version": {"type": "string"}, + "created_at": {"type": "string", "format": "date-time"}, + "astrbot_version": {"type": "string"}, + "plugin_count": {"type": "integer", "minimum": 0}, + }, + }, + "plugins": { + "type": "array", + "items": { + "type": "object", + "required": ["name", "repo", "exported_version"], + "additionalProperties": False, + "properties": { + "name": {"type": "string"}, + "repo": {"type": "string"}, + "exported_version": {"type": "string"}, + "display_name": {"type": ["string", "null"]}, + }, + }, + }, + "plugin_configs": { + "type": ["object", "null"], + "additionalProperties": True, + }, + "handler_priority_overrides": { + "type": ["object", "null"], + "additionalProperties": {"type": "integer"}, + }, + }, + } + + _validator: ClassVar[Draft7Validator] = Draft7Validator(JSON_SCHEMA) + + def to_dict(self) -> dict[str, Any]: + data: dict[str, Any] = { + "schema_version": self.schema_version, + "metadata": self.metadata.to_dict(), + "plugins": [p.to_dict() for p in self.plugins], + } + if self.plugin_configs is not None: + data["plugin_configs"] = self.plugin_configs + if self.handler_priority_overrides is not None: + data["handler_priority_overrides"] = self.handler_priority_overrides + return data + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> PluginCollection: + cls.validate_dict(data) + + meta = CollectionMetadata.from_dict(data["metadata"]) + plugins = [CollectionPlugin.from_dict(p) for p in (data.get("plugins") or [])] + return cls( + schema_version=str(data.get("schema_version", cls.SCHEMA_VERSION)), + metadata=meta, + plugins=plugins, + plugin_configs=data.get("plugin_configs"), + handler_priority_overrides=data.get("handler_priority_overrides"), + ) + + @classmethod + def validate_dict(cls, data: dict[str, Any]) -> None: + errors = sorted(cls._validator.iter_errors(data), key=lambda e: list(e.path)) + if not errors: + return + parts: list[str] = [] + for err in errors: + loc = "/".join(str(p) for p in err.absolute_path) + prefix = f"{loc}: " if loc else "" + parts.append(prefix + err.message) + raise CollectionValidationError("; ".join(parts)) diff --git a/astrbot/core/collection/sensitive_filter.py b/astrbot/core/collection/sensitive_filter.py new file mode 100644 index 000000000..43d9587d2 --- /dev/null +++ b/astrbot/core/collection/sensitive_filter.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import re +from typing import Any + + +class SensitiveFilter: + """Filter sensitive values from nested config structures. + + Sensitive fields are removed instead of being replaced. + """ + + SENSITIVE_KEYWORDS = [ + "key", + "secret", + "token", + "password", + "credential", + "api_key", + "access_token", + "private_key", + "auth", + "group_id", + "group", + "qq_group", + "guild", + "channel", + "user_id", + "user", + "qq", + "uin", + "openid", + "uid", + "endpoint", + "base_url", + "api_base", + "url", + "uri", + "webhook", + "callback", + "host", + "domain", + ] + + SENSITIVE_KEY_PATTERNS = [ + re.compile(r".*_key$", re.IGNORECASE), + re.compile(r".*_secret$", re.IGNORECASE), + re.compile(r".*_token$", re.IGNORECASE), + re.compile(r".*_password$", re.IGNORECASE), + re.compile(r".*_credential$", re.IGNORECASE), + re.compile(r".*_access_token$", re.IGNORECASE), + re.compile(r".*_private_key$", re.IGNORECASE), + re.compile(r".*_group(_id)?$", re.IGNORECASE), + re.compile(r".*_guild(_id)?$", re.IGNORECASE), + re.compile(r".*_channel(_id)?$", re.IGNORECASE), + re.compile(r".*_user(_id)?$", re.IGNORECASE), + re.compile(r".*_qq$", re.IGNORECASE), + re.compile(r".*_uin$", re.IGNORECASE), + re.compile(r".*_openid$", re.IGNORECASE), + re.compile(r".*_endpoint$", re.IGNORECASE), + re.compile(r".*_base_url$", re.IGNORECASE), + re.compile(r".*_api_base$", re.IGNORECASE), + re.compile(r".*_url$", re.IGNORECASE), + re.compile(r".*_uri$", re.IGNORECASE), + re.compile(r".*_webhook(_url)?$", re.IGNORECASE), + re.compile(r".*_callback(_url)?$", re.IGNORECASE), + re.compile(r".*_host$", re.IGNORECASE), + re.compile(r".*_domain$", re.IGNORECASE), + ] + + SENSITIVE_VALUE_PATTERNS = [ + re.compile(r"^sk-.*", re.IGNORECASE), + re.compile(r"^bearer\s+.+", re.IGNORECASE), + re.compile(r"^https?://.+", re.IGNORECASE), + re.compile(r"^wss?://.+", re.IGNORECASE), + ] + + @classmethod + def is_sensitive_key(cls, key: str) -> bool: + key_lower = key.lower() + if any(k in key_lower for k in cls.SENSITIVE_KEYWORDS): + return True + return any(p.fullmatch(key) is not None for p in cls.SENSITIVE_KEY_PATTERNS) + + @classmethod + def is_sensitive_value(cls, value: Any) -> bool: + if not isinstance(value, str): + return False + s = value.strip() + return any(p.match(s) is not None for p in cls.SENSITIVE_VALUE_PATTERNS) + + @classmethod + def filter_data(cls, data: Any, *, depth: int = 0, max_depth: int = 20) -> Any: + if depth > max_depth: + return data + + if isinstance(data, dict): + filtered: dict[str, Any] = {} + for k, v in data.items(): + key = str(k) + if cls.is_sensitive_key(key): + continue + if cls.is_sensitive_value(v): + continue + filtered[key] = cls.filter_data(v, depth=depth + 1, max_depth=max_depth) + return filtered + + if isinstance(data, list): + return [ + cls.filter_data(v, depth=depth + 1, max_depth=max_depth) for v in data + ] + + if cls.is_sensitive_value(data): + return None + + return data diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index 481be2f89..da0cc167b 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -2,6 +2,7 @@ from .backup import BackupRoute from .chat import ChatRoute from .chatui_project import ChatUIProjectRoute +from .collection import CollectionRoute from .command import CommandRoute from .config import ConfigRoute from .conversation import ConversationRoute @@ -25,6 +26,7 @@ "BackupRoute", "ChatRoute", "ChatUIProjectRoute", + "CollectionRoute", "CommandRoute", "ConfigRoute", "ConversationRoute", diff --git a/astrbot/dashboard/routes/collection.py b/astrbot/dashboard/routes/collection.py new file mode 100644 index 000000000..3f9656a15 --- /dev/null +++ b/astrbot/dashboard/routes/collection.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import traceback +from dataclasses import dataclass +from typing import Any + +from quart import request + +from astrbot.core import DEMO_MODE, logger +from astrbot.core.collection.exporter import CollectionExporter, ExportOptions +from astrbot.core.collection.importer import CollectionImporter +from astrbot.core.collection.models import ( + CollectionValidationError, + ImportOptions, + PluginCollection, +) +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.star.star_manager import PluginManager + +from .route import Response, Route, RouteContext + + +@dataclass(slots=True) +class _ImportRequest: + collection: dict[str, Any] + import_mode: str = "add" + apply_configs: bool = True + overwrite_existing_configs: bool = False + apply_priority: bool = True + proxy: str = "" + + +class CollectionRoute(Route): + def __init__( + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, + plugin_manager: PluginManager, + ) -> None: + super().__init__(context) + self.core_lifecycle = core_lifecycle + self.plugin_manager = plugin_manager + self.exporter = CollectionExporter(plugin_manager) + self.importer = CollectionImporter(plugin_manager) + + self.routes = { + "/plugin/collection/export": ("GET", self.export_collection_get), + "/plugin/collection/import": ("POST", self.import_collection), + "/plugin/collection/preview": ("POST", self.preview_collection), + "/plugin/collection/validate": ("POST", self.validate_collection), + } + self.register_routes() + + async def export_collection_get(self): + if DEMO_MODE: + return ( + Response() + .error("You are not permitted to do this operation in demo mode") + .__dict__ + ) + + name = str(request.args.get("name") or "").strip() + if not name: + return Response().error("name is required").__dict__ + + options = ExportOptions( + name=name, + description=str(request.args.get("description") or ""), + author=str(request.args.get("author") or ""), + version=str(request.args.get("version") or "1.0.0"), + include_configs=str(request.args.get("include_configs", "true")).lower() + not in {"0", "false", "no"}, + include_priority=str(request.args.get("include_priority", "true")).lower() + not in {"0", "false", "no"}, + exclude_plugins=( + request.args.getlist("exclude_plugins") + if request.args.getlist("exclude_plugins") + else None + ), + ) + + try: + payload = await self.exporter.export(options) + return Response().ok(payload).__dict__ + except Exception as e: + logger.error( + f"/api/plugin/collection/export(GET): {traceback.format_exc()}" + ) + return Response().error(str(e)).__dict__ + + async def _parse_collection(self, body: Any) -> PluginCollection: + if not isinstance(body, dict): + raise CollectionValidationError("collection must be an object") + if "collection" in body and isinstance(body.get("collection"), dict): + body = body["collection"] + if not isinstance(body, dict): + raise CollectionValidationError("collection must be an object") + return PluginCollection.from_dict(body) + + async def validate_collection(self): + try: + data = await request.get_json() + _ = await self._parse_collection(data) + return Response().ok({"valid": True}).__dict__ + except CollectionValidationError as e: + return Response().ok({"valid": False, "errors": [str(e)]}).__dict__ + except Exception as e: + logger.error(f"/api/plugin/collection/validate: {traceback.format_exc()}") + return Response().error(str(e)).__dict__ + + async def preview_collection(self): + try: + data = await request.get_json() + import_mode = "add" + if isinstance(data, dict) and data.get("import_mode") in {"add", "clean"}: + import_mode = str(data.get("import_mode")) + collection = await self._parse_collection(data) + preview = await self.importer.preview(collection, import_mode=import_mode) + return Response().ok(preview).__dict__ + except CollectionValidationError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"/api/plugin/collection/preview: {traceback.format_exc()}") + return Response().error(str(e)).__dict__ + + async def import_collection(self): + if DEMO_MODE: + return ( + Response() + .error("You are not permitted to do this operation in demo mode") + .__dict__ + ) + + try: + data = await request.get_json() + if not isinstance(data, dict): + return Response().error("Invalid request body").__dict__ + + raw_collection = data.get("collection") + req = _ImportRequest( + collection=raw_collection if isinstance(raw_collection, dict) else {}, + import_mode=str(data.get("import_mode") or "add"), + apply_configs=bool(data.get("apply_configs", True)), + overwrite_existing_configs=bool( + data.get("overwrite_existing_configs", False) + ), + apply_priority=bool(data.get("apply_priority", True)), + proxy=str(data.get("proxy") or ""), + ) + + collection = PluginCollection.from_dict(req.collection) + result = await self.importer.import_collection( + collection, + ImportOptions( + import_mode=req.import_mode, + apply_configs=req.apply_configs, + overwrite_existing_configs=req.overwrite_existing_configs, + apply_priority=req.apply_priority, + proxy=req.proxy, + ), + ) + return Response().ok(result).__dict__ + + except CollectionValidationError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"/api/plugin/collection/import: {traceback.format_exc()}") + return Response().error(str(e)).__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 57b8ad741..075a00ecf 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -72,6 +72,11 @@ def __init__( core_lifecycle, core_lifecycle.plugin_manager, ) + self.collection_route = CollectionRoute( + self.context, + core_lifecycle, + core_lifecycle.plugin_manager, + ) self.command_route = CommandRoute(self.context) self.cr = ConfigRoute(self.context, core_lifecycle) self.lr = LogRoute(self.context, core_lifecycle.log_broker) diff --git a/dashboard/src/components/shared/ExtensionCard.vue b/dashboard/src/components/shared/ExtensionCard.vue index 3ad621b29..7fe9b0f74 100644 --- a/dashboard/src/components/shared/ExtensionCard.vue +++ b/dashboard/src/components/shared/ExtensionCard.vue @@ -1,9 +1,13 @@