Files
wiseclaw/backend/app/llm/ollama_client.py

324 lines
13 KiB
Python

import asyncio
import json
from typing import Any
import httpx
from httpx import HTTPError, HTTPStatusError, ReadTimeout
from app.models import ModelProvider, OllamaStatus
class OllamaClient:
    """Async client for local OpenAI-compatible endpoints and the Z.AI Anthropic-style API."""

    def __init__(self, base_url: str, provider: ModelProvider = "local", api_key: str = "") -> None:
        """Store connection settings; trailing slashes are stripped from *base_url*."""
        self.api_key = api_key
        self.provider = provider
        self.base_url = base_url.rstrip("/")
async def health(self) -> bool:
    """Return True when the model listing endpoint answers without an HTTP error."""
    try:
        await self._fetch_models()
    except HTTPError:
        return False
    else:
        return True
async def status(self, model: str) -> OllamaStatus:
    """Build an OllamaStatus snapshot describing endpoint reachability and model presence."""

    def _unreachable(message: str) -> OllamaStatus:
        # Shared shape for every "endpoint not usable" outcome.
        return OllamaStatus(
            reachable=False,
            provider=self.provider,
            base_url=self.base_url,
            model=model,
            message=message,
        )

    if self.provider == "zai" and not self.api_key.strip():
        return _unreachable("Z.AI API key is not configured.")
    try:
        installed_models = await self._fetch_models()
    except HTTPError as exc:
        return _unreachable(f"LLM endpoint unreachable: {exc}")
    found = model in installed_models
    return OllamaStatus(
        reachable=True,
        provider=self.provider,
        base_url=self.base_url,
        model=model,
        installed_models=installed_models,
        message="Model found." if found else "LLM endpoint reachable but model is not installed.",
    )
async def chat(self, model: str, system_prompt: str, user_message: str) -> str:
    """Run a single system+user exchange and return the assistant's text reply.

    Raises HTTPError when the model asks for tools or returns empty content.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]
    result = await self.chat_completion(model=model, messages=conversation)
    # Plain chat mode must not come back with tool invocations.
    if result["tool_calls"]:
        raise HTTPError("Chat completion requested tools in plain chat mode.")
    text = result["content"].strip()
    if not text:
        raise HTTPError("Chat completion returned empty content.")
    return text
async def chat_completion(
    self,
    model: str,
    messages: list[dict[str, object]],
    tools: list[dict[str, Any]] | None = None,
    tool_choice: str | dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Send a chat completion request and normalize the response.

    Returns a dict with "content" (str), "tool_calls" (list of
    {id, name, arguments} dicts), and the raw provider "message".
    Raises HTTPError on timeout, HTTP failure, or an empty choices list.
    """
    self._ensure_provider_ready()
    if self.provider == "zai":
        # Z.AI speaks the Anthropic Messages API, not OpenAI chat completions.
        # NOTE(review): tool_choice is not forwarded on this path.
        return await self._anthropic_chat_completion(model, messages, tools)
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "temperature": 0.3,
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = tool_choice or "auto"
    # Bug fix: the old `if self.provider == "zai"` branch here was dead code —
    # the zai provider already returned above, so only the OpenAI-compatible
    # endpoint is ever used at this point.
    endpoint = f"{self.base_url}/v1/chat/completions"
    try:
        async with httpx.AsyncClient(timeout=180.0) as client:
            response = await self._post_with_retry(client, endpoint, payload)
    except ReadTimeout as exc:
        raise HTTPError("LLM request timed out after 180 seconds.") from exc
    data = response.json()
    choices = data.get("choices", [])
    if not choices:
        raise HTTPError("Chat completion returned no choices.")
    message = choices[0].get("message", {})
    content = message.get("content", "")
    if isinstance(content, list):
        # Some servers return content as a list of typed parts; join the text ones.
        text_parts = [part.get("text", "") for part in content if isinstance(part, dict)]
        content = "".join(text_parts)
    tool_calls = []
    for call in message.get("tool_calls", []) or []:
        function = call.get("function", {})
        raw_arguments = function.get("arguments", "{}")
        try:
            arguments = json.loads(raw_arguments) if isinstance(raw_arguments, str) else raw_arguments
        except json.JSONDecodeError:
            # Keep unparseable argument payloads instead of silently dropping them.
            arguments = {"raw": raw_arguments}
        tool_calls.append(
            {
                "id": call.get("id", ""),
                "name": function.get("name", ""),
                "arguments": arguments,
            }
        )
    return {
        "content": str(content or ""),
        "tool_calls": tool_calls,
        "message": message,
    }
async def _anthropic_chat_completion(
    self,
    model: str,
    messages: list[dict[str, object]],
    tools: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """POST to the Anthropic-style /v1/messages endpoint and normalize the reply."""
    system_prompt, anthropic_messages = self._to_anthropic_messages(messages)
    payload: dict[str, Any] = {
        "model": model,
        "max_tokens": 2048,
        "messages": anthropic_messages,
    }
    if system_prompt:
        payload["system"] = system_prompt
    converted_tools = self._to_anthropic_tools(tools or [])
    if converted_tools:
        payload["tools"] = converted_tools
    try:
        async with httpx.AsyncClient(timeout=180.0) as client:
            response = await self._post_with_retry(client, f"{self.base_url}/v1/messages", payload)
    except ReadTimeout as exc:
        raise HTTPError("LLM request timed out after 180 seconds.") from exc
    data = response.json()
    text_parts: list[str] = []
    tool_calls: list[dict[str, Any]] = []
    # Anthropic responses carry a list of typed content blocks.
    for block in data.get("content", []) or []:
        if not isinstance(block, dict):
            continue
        kind = block.get("type")
        if kind == "text":
            text_parts.append(str(block.get("text", "")))
        if kind == "tool_use":
            raw_input = block.get("input")
            tool_calls.append(
                {
                    "id": str(block.get("id", "")),
                    "name": str(block.get("name", "")),
                    "arguments": raw_input if isinstance(raw_input, dict) else {},
                }
            )
    return {
        "content": "".join(text_parts).strip(),
        "tool_calls": tool_calls,
        "message": data,
    }
async def _fetch_models(self) -> list[str]:
    """List model identifiers exposed by the configured endpoint."""
    self._ensure_provider_ready()

    def _ids(body: dict[str, Any]) -> list[str]:
        # OpenAI-style listing: {"data": [{"id": ...}, ...]}.
        return [entry.get("id", "") for entry in body.get("data", []) if entry.get("id")]

    async with httpx.AsyncClient(timeout=5.0) as client:
        if self.provider == "zai":
            response = await client.get(f"{self.base_url}/v1/models", headers=self._headers())
            if not response.is_success:
                # Listing may be unavailable; fall back to the known model set.
                return ["glm-4.7", "glm-5"]
            return _ids(response.json())
        # Prefer the native Ollama tags endpoint.
        response = await client.get(f"{self.base_url}/api/tags")
        if response.is_success:
            body = response.json()
            if isinstance(body, dict) and "models" in body:
                return [entry.get("name", "") for entry in body.get("models", []) if entry.get("name")]
        # Otherwise assume an OpenAI-compatible server.
        response = await client.get(f"{self.base_url}/v1/models")
        response.raise_for_status()
        return _ids(response.json())
def _headers(self) -> dict[str, str]:
    """Request headers: Anthropic auth headers for Z.AI, nothing for local endpoints."""
    if self.provider == "zai":
        return {
            "x-api-key": self.api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        }
    return {}
def _ensure_provider_ready(self) -> None:
    """Raise HTTPError when the Z.AI provider is selected without an API key."""
    key_missing = self.provider == "zai" and not self.api_key.strip()
    if key_missing:
        raise HTTPError("Z.AI API key is not configured.")
async def _post_with_retry(
    self,
    client: httpx.AsyncClient,
    endpoint: str,
    payload: dict[str, Any],
) -> httpx.Response:
    """POST *payload*, retrying up to three attempts with backoff on HTTP 429 only."""
    delays = (0.0, 1.5, 4.0)
    last_exc: HTTPStatusError | None = None
    for attempt, delay in enumerate(delays, start=1):
        if delay > 0:
            await asyncio.sleep(delay)
        response = await client.post(endpoint, json=payload, headers=self._headers())
        try:
            response.raise_for_status()
        except HTTPStatusError as exc:
            last_exc = exc
            out_of_attempts = attempt == len(delays)
            # Only rate limiting is worth retrying; everything else fails fast.
            if response.status_code != 429 or out_of_attempts:
                raise self._translate_status_error(exc) from exc
        else:
            return response
    # Defensive tail: only reachable if the delay schedule were empty.
    if last_exc is not None:
        raise self._translate_status_error(last_exc) from last_exc
    raise HTTPError("LLM request failed.")
def _translate_status_error(self, exc: HTTPStatusError) -> HTTPError:
    """Map an HTTPStatusError to an HTTPError with a user-facing message."""
    status = exc.response.status_code
    provider = "Z.AI" if self.provider == "zai" else "LLM endpoint"
    if status == 429:
        return HTTPError(f"{provider} rate limit reached. Please wait a bit and try again.")
    if status == 401:
        return HTTPError(f"{provider} authentication failed. Check the configured API key.")
    if status == 404:
        return HTTPError("Configured LLM endpoint path was not found.")
    return HTTPError(f"LLM request failed with HTTP {status}.")
def _to_anthropic_tools(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert OpenAI function-tool specs to Anthropic tool entries, dropping unnamed ones."""
    converted: list[dict[str, Any]] = []
    for spec in tools:
        function = spec.get("function", {}) if isinstance(spec, dict) else {}
        if not isinstance(function, dict):
            continue
        name = str(function.get("name", ""))
        if not name:
            # Anthropic rejects tools without a name; skip them up front.
            continue
        converted.append(
            {
                "name": name,
                "description": str(function.get("description", "")),
                "input_schema": function.get("parameters", {"type": "object", "properties": {}}),
            }
        )
    return converted
def _to_anthropic_messages(self, messages: list[dict[str, object]]) -> tuple[str, list[dict[str, object]]]:
    """Convert OpenAI-style chat messages into (system_prompt, Anthropic messages).

    System messages are collected into one newline-joined system prompt.
    Tool-result messages become user-role "tool_result" blocks, merged into
    the preceding user turn where possible; assistant tool calls become
    "tool_use" blocks. Messages with no resulting content blocks are dropped.
    """
    system_parts: list[str] = []
    anthropic_messages: list[dict[str, object]] = []
    for message in messages:
        role = str(message.get("role", "user"))
        if role == "system":
            # System text is hoisted out; Anthropic takes it as a separate field.
            content = str(message.get("content", "")).strip()
            if content:
                system_parts.append(content)
            continue
        if role == "tool":
            content = str(message.get("content", ""))
            tool_use_id = str(message.get("tool_call_id", ""))
            tool_result_block = {
                "type": "tool_result",
                "tool_use_id": tool_use_id,
                "content": content,
            }
            # Merge into the previous user turn when its content is already
            # a block list; otherwise start a fresh user message.
            if anthropic_messages and anthropic_messages[-1]["role"] == "user":
                existing = anthropic_messages[-1]["content"]
                if isinstance(existing, list):
                    existing.append(tool_result_block)
                    continue
            anthropic_messages.append({"role": "user", "content": [tool_result_block]})
            continue
        content_blocks: list[dict[str, object]] = []
        content = message.get("content", "")
        if isinstance(content, str) and content.strip():
            content_blocks.append({"type": "text", "text": content})
        raw_tool_calls = message.get("tool_calls", [])
        if isinstance(raw_tool_calls, list):
            for call in raw_tool_calls:
                if not isinstance(call, dict):
                    continue
                function = call.get("function", {})
                if not isinstance(function, dict):
                    continue
                arguments = function.get("arguments", {})
                if isinstance(arguments, str):
                    # OpenAI encodes tool arguments as a JSON string; decode best-effort.
                    try:
                        arguments = json.loads(arguments)
                    except json.JSONDecodeError:
                        arguments = {}
                content_blocks.append(
                    {
                        "type": "tool_use",
                        "id": str(call.get("id", "")),
                        "name": str(function.get("name", "")),
                        "input": arguments if isinstance(arguments, dict) else {},
                    }
                )
        if not content_blocks:
            # Drop messages that carry neither text nor tool activity.
            continue
        anthropic_messages.append({"role": "assistant" if role == "assistant" else "user", "content": content_blocks})
    return "\n\n".join(part for part in system_parts if part), anthropic_messages