134 lines
4.2 KiB
Python
134 lines
4.2 KiB
Python
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime

from sqlalchemy import BigInteger, Boolean, DateTime, Integer, String, Text, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, sessionmaker

from app.config import get_settings
|
|
|
|
|
|
# Seed rows for the ``settings`` table. ``_seed_defaults`` inserts an entry
# only when its key is not already present. Every value is a string because
# SettingORM.value is a text column.
DEFAULT_SETTINGS = dict(
    terminal_mode="3",
    search_provider="brave",
    ollama_base_url="http://127.0.0.1:11434",
    default_model="qwen3.5:4b",
)

# Initial enabled/disabled flag per tool, seeded into ``tool_states`` only for
# tools that do not have a row yet.
DEFAULT_TOOLS = dict(
    brave_search=True,
    searxng_search=False,
    web_fetch=True,
    apple_notes=True,
    files=True,
    terminal=True,
)
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base for every ORM model in this module.

    Its ``metadata`` collects all mapped tables so ``init_db`` can create them
    in a single ``create_all`` call.
    """
|
|
|
|
|
|
class SettingORM(Base):
    """A key/value setting row in the ``settings`` table.

    Initial rows come from ``DEFAULT_SETTINGS``; values are stored as text.
    """

    __tablename__ = "settings"

    key: Mapped[str] = mapped_column(String(100), primary_key=True)
    value: Mapped[str] = mapped_column(Text, nullable=False)
    # Stamped (naive UTC) on INSERT and refreshed on every UPDATE that does not
    # set it explicitly. Without ``onupdate`` the column would keep the creation
    # time forever, despite its name.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        nullable=False,
    )
|
|
|
|
|
|
class ToolStateORM(Base):
    """Enabled/disabled flag for one named tool, stored in ``tool_states``.

    Initial rows come from ``DEFAULT_TOOLS``.
    """

    __tablename__ = "tool_states"

    name: Mapped[str] = mapped_column(String(100), primary_key=True)
    enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Stamped (naive UTC) on INSERT and refreshed on every UPDATE that does not
    # set it explicitly, so toggling a tool records when it changed.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        nullable=False,
    )
|
|
|
|
|
|
class AuthorizedUserORM(Base):
    """Row in ``authorized_users``: a Telegram user id with optional profile fields.

    ``is_active`` is a per-user flag that defaults to ``True``.
    """

    __tablename__ = "authorized_users"

    # Telegram user ids may need more than 32 bits (Telegram documents at most
    # 52 significant bits). That overflows a 32-bit INTEGER on backends such as
    # PostgreSQL/MySQL, so BIGINT is used there. SQLite keeps plain INTEGER: it
    # is already 64-bit there, and the generated DDL stays unchanged.
    # NOTE(review): existing non-SQLite databases need a migration; create_all
    # does not alter columns of tables that already exist.
    telegram_user_id: Mapped[int] = mapped_column(
        BigInteger().with_variant(Integer(), "sqlite"),
        primary_key=True,
    )
    username: Mapped[str | None] = mapped_column(String(255))
    display_name: Mapped[str | None] = mapped_column(String(255))
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False)
    # Refreshed on every UPDATE that does not set it explicitly, so it reflects
    # the last modification rather than the creation time.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        nullable=False,
    )
|
|
|
|
|
|
class MemoryItemORM(Base):
    """Row in ``memory_items``: a piece of text ``content`` tagged with a ``kind``."""

    __tablename__ = "memory_items"

    # Autoincrementing integer surrogate key.
    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)
    content: Mapped[str] = mapped_column(Text, nullable=False)
    # Category label; rows inserted without one get "message".
    kind: Mapped[str] = mapped_column(String(50), default="message", nullable=False)
    # Insert time as naive UTC (datetime.utcnow) when not supplied.
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow)
|
|
|
|
|
|
class AuditLogORM(Base):
    """Row in ``audit_logs``: a ``message`` grouped under a short ``category``."""

    __tablename__ = "audit_logs"

    # Autoincrementing integer key; also breaks ties between rows that share
    # the same ``created_at`` when listing recent entries.
    id: Mapped[int] = mapped_column(Integer, autoincrement=True, primary_key=True)
    category: Mapped[str] = mapped_column(String(50), nullable=False)
    message: Mapped[str] = mapped_column(Text, nullable=False)
    # Insert time as naive UTC (datetime.utcnow) when not supplied.
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow)
|
|
|
|
|
|
class SecretORM(Base):
    """A key/value row in the ``secrets`` table."""

    __tablename__ = "secrets"

    key: Mapped[str] = mapped_column(String(100), primary_key=True)
    # NOTE(review): stored as a plain Text column with no encryption visible in
    # this module. Confirm whether callers encrypt values before writing.
    value: Mapped[str] = mapped_column(Text, nullable=False)
    # Stamped on INSERT and refreshed on every UPDATE that does not set it
    # explicitly, so rotating a secret records when it changed.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        nullable=False,
    )
|
|
|
|
|
|
# Application configuration, resolved once when this module is imported.
settings = get_settings()

# Python's sqlite3 driver by default refuses to use a connection from a thread
# other than the one that created it. That check is disabled for SQLite URLs;
# every other backend gets no extra connect arguments.
engine = create_engine(
    settings.db_url,
    connect_args=(
        {"check_same_thread": False}
        if settings.db_url.startswith("sqlite")
        else {}
    ),
)

# Session factory bound to ``engine``. With autoflush off, queries do not
# implicitly flush pending changes first.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
|
|
|
|
def init_db() -> None:
    """Create the mapped tables, then seed default settings and tool states.

    ``create_all`` runs with its default ``checkfirst=True``, so tables that
    already exist are skipped. Seeding only inserts rows whose key is missing.
    Calling this against an existing database therefore leaves current rows
    untouched.
    """
    Base.metadata.create_all(engine)
    with session_scope() as db:
        _seed_defaults(db)
|
|
|
|
|
|
def _seed_defaults(session: Session) -> None:
    """Add default settings and tool states that have no row yet.

    Existing rows are never overwritten. New objects are only added to
    *session*; committing is left to the caller.
    """
    missing_settings = [
        SettingORM(key=setting_key, value=setting_value)
        for setting_key, setting_value in DEFAULT_SETTINGS.items()
        if session.get(SettingORM, setting_key) is None
    ]
    session.add_all(missing_settings)

    missing_tools = [
        ToolStateORM(name=tool_name, enabled=is_enabled)
        for tool_name, is_enabled in DEFAULT_TOOLS.items()
        if session.get(ToolStateORM, tool_name) is None
    ]
    session.add_all(missing_tools)
|
|
|
|
|
|
def get_session() -> Iterator[Session]:
    """Yield a fresh session and close it when the consumer is done.

    Unlike ``session_scope``, this never commits; uncommitted work is discarded
    when the session closes. Presumably consumed as a framework dependency
    (e.g. a generator-based ``Depends``). TODO: confirm against callers.
    """
    with SessionLocal() as session:
        yield session
|
|
|
|
|
|
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional session for the duration of a ``with`` block.

    If the block finishes normally, the session is committed. If the block or
    the commit raises an ``Exception``, the session is rolled back and the
    exception re-raised. The session is closed in every case.
    """
    with SessionLocal() as db:
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
|
|
|
|
|
|
def list_recent_logs(session: Session, limit: int = 10) -> list[str]:
    """Return the messages of the most recent audit log entries, newest first.

    Args:
        session: Open session used to run the query.
        limit: Maximum number of messages to return. Values ``<= 0`` return an
            empty list without querying. A negative SQL ``LIMIT`` is
            dialect-dependent (SQLite treats it as "no limit"; PostgreSQL
            rejects it).

    Returns:
        Message strings ordered by ``created_at`` descending, with ``id``
        descending as the tie-breaker for rows that share a timestamp.
    """
    if limit <= 0:
        return []
    # Only the text is needed, so select the single column rather than
    # materialising a full AuditLogORM entity for every row.
    stmt = (
        select(AuditLogORM.message)
        .order_by(AuditLogORM.created_at.desc(), AuditLogORM.id.desc())
        .limit(limit)
    )
    return list(session.scalars(stmt))
|
|
|