You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
93 lines
3.1 KiB
Python
93 lines
3.1 KiB
Python
import inspect
|
|
from typing import Callable, Dict, List
|
|
|
|
from fastapi import HTTPException
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.models.tool_model import ToolDefinition
|
|
from app.repositories.tool_repository import ToolRepository
|
|
from app.services.tools.handlers import (
|
|
agendar_revisao,
|
|
avaliar_veiculo_troca,
|
|
cancelar_agendamento_revisao,
|
|
cancelar_pedido,
|
|
editar_data_revisao,
|
|
listar_agendamentos_revisao,
|
|
listar_pedidos,
|
|
consultar_estoque,
|
|
realizar_pedido,
|
|
validar_cliente_venda,
|
|
)
|
|
|
|
|
|
HANDLERS: Dict[str, Callable] = {
|
|
"consultar_estoque": consultar_estoque,
|
|
"validar_cliente_venda": validar_cliente_venda,
|
|
"avaliar_veiculo_troca": avaliar_veiculo_troca,
|
|
"agendar_revisao": agendar_revisao,
|
|
"listar_agendamentos_revisao": listar_agendamentos_revisao,
|
|
"cancelar_agendamento_revisao": cancelar_agendamento_revisao,
|
|
"editar_data_revisao": editar_data_revisao,
|
|
"cancelar_pedido": cancelar_pedido,
|
|
"listar_pedidos": listar_pedidos,
|
|
"realizar_pedido": realizar_pedido,
|
|
}
|
|
|
|
|
|
# Registry em memoria das tools disponiveis para o orquestrador.
|
|
class ToolRegistry:
|
|
|
|
def __init__(self, db: Session, extra_handlers: Dict[str, Callable] | None = None):
|
|
"""Carrega tools do banco e registra apenas as que possuem handler conhecido."""
|
|
self._tools = []
|
|
available_handlers = dict(HANDLERS)
|
|
if extra_handlers:
|
|
available_handlers.update(extra_handlers)
|
|
repo = ToolRepository(db)
|
|
db_tools = repo.get_all()
|
|
for db_tool in db_tools:
|
|
handler = available_handlers.get(db_tool.name)
|
|
if not handler:
|
|
continue
|
|
self.register_tool(
|
|
name=db_tool.name,
|
|
description=db_tool.description,
|
|
parameters=db_tool.parameters,
|
|
handler=handler,
|
|
)
|
|
|
|
def register_tool(self, name, description, parameters, handler):
|
|
"""Registra uma tool em memoria para uso pelo orquestrador."""
|
|
self._tools.append(
|
|
ToolDefinition(
|
|
name=name,
|
|
description=description,
|
|
parameters=parameters,
|
|
handler=handler,
|
|
)
|
|
)
|
|
|
|
def get_tools(self) -> List[ToolDefinition]:
|
|
"""Retorna a lista atual de tools registradas."""
|
|
return self._tools
|
|
|
|
async def execute(self, name: str, arguments: dict, user_id: int | None = None):
|
|
"""Executa a tool solicitada pelo modelo com os argumentos extraidos."""
|
|
tool = next((t for t in self._tools if t.name == name), None)
|
|
|
|
if not tool:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail={
|
|
"code": "tool_not_found",
|
|
"message": f"Tool {name} nao encontrada.",
|
|
"retryable": False,
|
|
},
|
|
)
|
|
|
|
call_args = dict(arguments or {})
|
|
if user_id is not None and "user_id" in inspect.signature(tool.handler).parameters:
|
|
call_args["user_id"] = user_id
|
|
|
|
return await tool.handler(**call_args)
|