127 lines
4.2 KiB
Python
127 lines
4.2 KiB
Python
import httpx
|
|
from typing import Any
|
|
|
|
from app.tools.base import Tool
|
|
|
|
|
|
class BraveSearchTool(Tool):
    """Search the web or images through the Brave Search API.

    ``run`` reports expected failures (missing query, missing API key,
    HTTP/transport errors, non-JSON or unexpectedly shaped response bodies)
    as a result dict with ``"status": "error"`` instead of raising.
    """

    name = "brave_search"
    description = "Search the web with Brave Search."

    _WEB_URL = "https://api.search.brave.com/res/v1/web/search"
    _IMAGES_URL = "https://api.search.brave.com/res/v1/images/search"
    # Accepted values of the "mode" argument; "web" is the fallback.
    _MODES = ("web", "images")
    _DEFAULT_COUNT = 5
    # Bounds shared by the JSON schema and the runtime clamp so they can't drift.
    _MIN_COUNT = 1
    _MAX_COUNT = 10
    _TIMEOUT_SECONDS = 15.0

    def __init__(self, api_key: str) -> None:
        """Remember the subscription token sent with every request.

        Args:
            api_key: Brave Search API key, sent as ``X-Subscription-Token``.
                An empty key is accepted here; ``run`` then returns an error.
        """
        self.api_key = api_key

    def parameters_schema(self) -> dict[str, Any]:
        """Return the JSON Schema describing the arguments ``run`` accepts."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "The web search query.",
                },
                "count": {
                    "type": "integer",
                    "description": "Optional number of results from 1 to 10.",
                    "minimum": self._MIN_COUNT,
                    "maximum": self._MAX_COUNT,
                },
                "mode": {
                    "type": "string",
                    "description": "Search mode: web or images.",
                    "enum": list(self._MODES),
                },
            },
            "required": ["query"],
            "additionalProperties": False,
        }

    async def run(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Execute a search described by *payload*.

        Args:
            payload: Arguments matching ``parameters_schema``. ``query`` is
                required (a missing or null query is an error). ``count``
                defaults to 5 and is clamped to 1-10; falsy or non-numeric
                values fall back to 5. ``mode`` is ``"web"`` or ``"images"``;
                anything else falls back to ``"web"``.

        Returns:
            On success: ``tool``, ``status="ok"``, ``mode``, ``query``, the
            result list (``images`` in images mode, ``results`` in web mode)
            and ``total_results``. On failure: ``tool``, ``status="error"``,
            ``message`` and, except when the query itself is missing,
            ``query``.
        """
        query = self._parse_query(payload)
        count = self._parse_count(payload)
        mode = self._parse_mode(payload)

        if not query:
            return {
                "tool": self.name,
                "status": "error",
                "message": "Query is required.",
            }
        if not self.api_key:
            return self._error(query, "Brave Search API key is not configured.")

        try:
            body = await self._fetch(query, count, mode)
        except httpx.HTTPError as exc:
            # Some httpx errors (notably timeouts) stringify to an empty
            # string; fall back to the exception type so the message is useful.
            return self._error(query, str(exc) or type(exc).__name__)
        except ValueError:
            # response.json() failed: json.JSONDecodeError and
            # UnicodeDecodeError both subclass ValueError.
            return self._error(
                query, "Brave Search returned a response that is not valid JSON."
            )

        if not isinstance(body, dict):
            return self._error(
                query, "Brave Search returned an unexpected response format."
            )

        if mode == "images":
            images = [
                self._format_image(item)
                for item in self._as_list(body.get("results"))
                if isinstance(item, dict)
            ][:count]
            return {
                "tool": self.name,
                "status": "ok",
                "mode": mode,
                "query": query,
                "images": images,
                "total_results": len(images),
            }

        # Web results are nested under body["web"]["results"]; either level
        # may be absent or null, so both are accessed defensively.
        web_items = self._as_list(self._as_dict(body.get("web")).get("results"))
        results = [
            self._format_web_result(item)
            for item in web_items
            if isinstance(item, dict)
        ][:count]
        return {
            "tool": self.name,
            "status": "ok",
            "mode": mode,
            "query": query,
            "results": results,
            "total_results": len(results),
        }

    async def _fetch(self, query: str, count: int, mode: str) -> Any:
        """Send the search request for *mode* and return the decoded JSON body.

        Raises:
            httpx.HTTPError: On transport failures or a non-2xx status.
            ValueError: If the response body cannot be decoded as JSON.
        """
        url = self._IMAGES_URL if mode == "images" else self._WEB_URL
        async with httpx.AsyncClient(timeout=self._TIMEOUT_SECONDS) as client:
            response = await client.get(
                url,
                headers={
                    "Accept": "application/json",
                    "Accept-Encoding": "gzip",
                    "X-Subscription-Token": self.api_key,
                },
                params={
                    "q": query,
                    "count": count,
                    "search_lang": "en",
                    "country": "us",
                },
            )
            response.raise_for_status()
        return response.json()

    def _error(self, query: str, message: str) -> dict[str, Any]:
        """Build the error result used once a non-empty query is known."""
        return {
            "tool": self.name,
            "status": "error",
            "query": query,
            "message": message,
        }

    @staticmethod
    def _parse_query(payload: dict[str, Any]) -> str:
        """Return the stripped query, or "" when it is missing or null."""
        raw = payload.get("query")
        # str(None) would be the literal "None", which must not be searched.
        return "" if raw is None else str(raw).strip()

    @classmethod
    def _parse_count(cls, payload: dict[str, Any]) -> int:
        """Return the requested result count clamped to the allowed range.

        Falsy values (missing, null, 0, "") and values that cannot be
        converted with ``int`` fall back to the default count.
        """
        raw = payload.get("count") or cls._DEFAULT_COUNT
        try:
            count = int(raw)
        except (TypeError, ValueError, OverflowError):
            count = cls._DEFAULT_COUNT
        return max(cls._MIN_COUNT, min(cls._MAX_COUNT, count))

    @classmethod
    def _parse_mode(cls, payload: dict[str, Any]) -> str:
        """Return the normalized search mode, falling back to "web"."""
        mode = str(payload.get("mode") or "web").strip().lower()
        return mode if mode in cls._MODES else "web"

    @staticmethod
    def _as_dict(value: Any) -> dict[str, Any]:
        """Return *value* if it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    @staticmethod
    def _as_list(value: Any) -> list[Any]:
        """Return *value* if it is a list, otherwise an empty list."""
        return value if isinstance(value, list) else []

    @classmethod
    def _format_image(cls, item: dict[str, Any]) -> dict[str, Any]:
        """Project one image-search result onto the fields this tool returns."""
        return {
            "title": item.get("title", ""),
            "url": item.get("url", ""),
            "source": item.get("source", ""),
            "thumbnail": cls._as_dict(item.get("thumbnail")).get("src", ""),
            "properties_url": cls._as_dict(item.get("properties")).get("url", ""),
        }

    @staticmethod
    def _format_web_result(item: dict[str, Any]) -> dict[str, Any]:
        """Project one web-search result onto the fields this tool returns."""
        return {
            "title": item.get("title", ""),
            "url": item.get("url", ""),
            "description": item.get("description", ""),
        }
|