Files
wiseclaw/backend/app/db.py

134 lines
4.2 KiB
Python

from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime
from sqlalchemy import Boolean, DateTime, Integer, String, Text, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker
from app.config import get_settings
# Setting rows that `_seed_defaults` inserts when the key is not yet present
# in the `settings` table. Values are strings because `SettingORM.value` is a
# text column.
DEFAULT_SETTINGS: dict[str, str] = {
    "terminal_mode": "3",
    "search_provider": "brave",
    "ollama_base_url": "http://127.0.0.1:11434",
    "default_model": "qwen3.5:4b",
}
# Tool name -> initial enabled flag. `_seed_defaults` inserts a `tool_states`
# row for each name that does not already exist.
DEFAULT_TOOLS: dict[str, bool] = {
    "brave_search": True,
    "searxng_search": False,
    "web_fetch": True,
    "apple_notes": True,
    "files": True,
    "terminal": True,
}
class Base(DeclarativeBase):
    """Declarative base whose metadata collects every ORM model in this module."""
class SettingORM(Base):
    """One key/value application setting, stored in the ``settings`` table."""

    __tablename__ = "settings"

    # Setting name; primary key, limited to 100 characters.
    key: Mapped[str] = mapped_column(String(100), primary_key=True)
    # Setting value, always persisted as text.
    value: Mapped[str] = mapped_column(Text, nullable=False)
    # Naive UTC timestamp. `default` covers INSERT; `onupdate` refreshes it on
    # every UPDATE that does not assign the column explicitly, so the value
    # tracks the last modification instead of staying at creation time.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        nullable=False,
    )
class ToolStateORM(Base):
    """Enabled/disabled flag for one tool, stored in the ``tool_states`` table."""

    __tablename__ = "tool_states"

    # Tool identifier; primary key, limited to 100 characters.
    name: Mapped[str] = mapped_column(String(100), primary_key=True)
    # Whether the tool is enabled; new rows default to enabled.
    enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Naive UTC timestamp. `default` covers INSERT; `onupdate` refreshes it on
    # every UPDATE that does not assign the column explicitly, so toggling a
    # tool is reflected in the timestamp.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        nullable=False,
    )
class AuthorizedUserORM(Base):
    """A Telegram user record, stored in the ``authorized_users`` table."""

    __tablename__ = "authorized_users"

    # Telegram numeric user id, used directly as the primary key.
    # NOTE(review): Telegram ids can exceed 32 bits; `Integer` holds them on
    # SQLite, but confirm the column width before targeting another backend.
    telegram_user_id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Optional username and display name (nullable, up to 255 characters each).
    username: Mapped[str | None] = mapped_column(String(255))
    display_name: Mapped[str | None] = mapped_column(String(255))
    # Active flag; new rows default to active.
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Naive UTC timestamp set once when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        nullable=False,
    )
    # Naive UTC timestamp. `default` covers INSERT; `onupdate` refreshes it on
    # every UPDATE that does not assign the column explicitly, so it no longer
    # stays equal to `created_at` after the row is edited.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        nullable=False,
    )
class MemoryItemORM(Base):
    """A stored piece of memory text, kept in the ``memory_items`` table."""

    __tablename__ = "memory_items"

    # Auto-incrementing surrogate primary key.
    id: Mapped[int] = mapped_column(
        Integer,
        autoincrement=True,
        primary_key=True,
    )
    # The remembered text itself; required.
    content: Mapped[str] = mapped_column(Text, nullable=False)
    # Short category label (up to 50 characters); defaults to "message".
    kind: Mapped[str] = mapped_column(
        String(50),
        default="message",
        nullable=False,
    )
    # Naive UTC timestamp assigned when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=datetime.utcnow,
    )
class AuditLogORM(Base):
    """One audit-log entry, kept in the ``audit_logs`` table."""

    __tablename__ = "audit_logs"

    # Auto-incrementing surrogate primary key.
    id: Mapped[int] = mapped_column(
        Integer,
        autoincrement=True,
        primary_key=True,
    )
    # Short category label (up to 50 characters); required.
    category: Mapped[str] = mapped_column(String(50), nullable=False)
    # Log text; required.
    message: Mapped[str] = mapped_column(Text, nullable=False)
    # Naive UTC timestamp assigned when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=datetime.utcnow,
    )
class SecretORM(Base):
    """One key/value secret, stored in the ``secrets`` table.

    NOTE(review): the value column is plain ``Text``; nothing in this class
    encrypts it. Confirm whether callers encrypt before writing.
    """

    __tablename__ = "secrets"

    # Secret name; primary key, limited to 100 characters.
    key: Mapped[str] = mapped_column(String(100), primary_key=True)
    # Secret value, persisted as text exactly as given.
    value: Mapped[str] = mapped_column(Text, nullable=False)
    # Naive UTC timestamp. `default` covers INSERT; `onupdate` refreshes it on
    # every UPDATE that does not assign the column explicitly, so rotating a
    # secret updates the timestamp.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        nullable=False,
    )
# Module-level database wiring: importing this module loads the application
# settings and builds the engine and session factory once.
settings = get_settings()

# Only SQLite URLs receive `check_same_thread=False`; every other backend gets
# no extra connect arguments.
engine = create_engine(
    settings.db_url,
    connect_args=(
        {}
        if not settings.db_url.startswith("sqlite")
        else {"check_same_thread": False}
    ),
)

# Factory for sessions bound to `engine`, with autoflush and autocommit off.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db() -> None:
    """Create every table registered on ``Base`` and seed the default rows.

    Seeding runs inside ``session_scope``, so the inserted defaults are
    committed on success and rolled back if seeding raises.
    """
    Base.metadata.create_all(bind=engine)
    with session_scope() as db:
        _seed_defaults(db)
def _seed_defaults(session: Session) -> None:
    """Add rows for any default setting or tool state that is missing.

    Existing rows are left untouched. Nothing is committed here; the caller
    owns the transaction.

    Args:
        session: Open session used for the lookups and the pending inserts.
    """
    for setting_key, setting_value in DEFAULT_SETTINGS.items():
        if session.get(SettingORM, setting_key) is not None:
            continue
        session.add(SettingORM(key=setting_key, value=setting_value))

    for tool_name, tool_enabled in DEFAULT_TOOLS.items():
        if session.get(ToolStateORM, tool_name) is not None:
            continue
        session.add(ToolStateORM(name=tool_name, enabled=tool_enabled))
def get_session() -> Iterator[Session]:
    """Yield a new session and close it once the generator is finished.

    This is a plain generator (it is not wrapped in ``contextmanager``) and it
    never commits; the consumer decides when to commit.
    NOTE(review): presumably consumed by a framework dependency-injection
    hook; confirm against the callers.

    Yields:
        A session created by ``SessionLocal``.
    """
    # Leaving the `with` block closes the session, whether the consumer
    # finished normally or raised.
    with SessionLocal() as db:
        yield db
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional scope around a block of session work.

    The session is committed when the ``with`` body completes, rolled back
    (and the error re-raised) when the body or the commit raises an
    ``Exception``, and closed in every case.

    Yields:
        A session created by ``SessionLocal``.
    """
    # The session's own context manager takes care of the final close().
    with SessionLocal() as db:
        try:
            yield db
            # Commit stays inside the `try` so a failed commit also rolls back.
            db.commit()
        except Exception:
            db.rollback()
            raise
def list_recent_logs(session: Session, limit: int = 10) -> list[str]:
    """Return the messages of the most recent audit-log entries.

    Args:
        session: Open session used to run the query.
        limit: Maximum number of entries to return.

    Returns:
        Messages ordered newest first; entries with the same ``created_at``
        are ordered by descending ``id``.
    """
    newest_first = (AuditLogORM.created_at.desc(), AuditLogORM.id.desc())
    query = select(AuditLogORM).order_by(*newest_first).limit(limit)
    return [entry.message for entry in session.scalars(query)]