You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
orquestrador/app/services/tools/tool_registry.py

137 lines
4.8 KiB
Python

import inspect
from typing import Callable, Dict, List
from fastapi import HTTPException
from sqlalchemy.orm import Session
from app.models.tool_model import ToolDefinition
from app.repositories.tool_repository import ToolRepository
from app.services.tools.handlers import (
agendar_revisao,
avaliar_veiculo_troca,
cancelar_agendamento_revisao,
cancelar_pedido,
editar_data_revisao,
listar_agendamentos_revisao,
listar_pedidos,
consultar_estoque,
consultar_frota_aluguel,
abrir_locacao_aluguel,
registrar_devolucao_aluguel,
registrar_multa_aluguel,
registrar_pagamento_aluguel,
realizar_pedido,
validar_cliente_venda,
)
HANDLERS: Dict[str, Callable] = {
"consultar_estoque": consultar_estoque,
"consultar_frota_aluguel": consultar_frota_aluguel,
"validar_cliente_venda": validar_cliente_venda,
"avaliar_veiculo_troca": avaliar_veiculo_troca,
"agendar_revisao": agendar_revisao,
"listar_agendamentos_revisao": listar_agendamentos_revisao,
"cancelar_agendamento_revisao": cancelar_agendamento_revisao,
"editar_data_revisao": editar_data_revisao,
"cancelar_pedido": cancelar_pedido,
"listar_pedidos": listar_pedidos,
"realizar_pedido": realizar_pedido,
"abrir_locacao_aluguel": abrir_locacao_aluguel,
"registrar_devolucao_aluguel": registrar_devolucao_aluguel,
"registrar_pagamento_aluguel": registrar_pagamento_aluguel,
"registrar_multa_aluguel": registrar_multa_aluguel,
}
# Registry em memoria das tools disponiveis para o orquestrador.
class ToolRegistry:
def __init__(self, db: Session, extra_handlers: Dict[str, Callable] | None = None):
"""Carrega tools do banco e registra apenas as que possuem handler conhecido."""
self._tools = []
available_handlers = dict(HANDLERS)
if extra_handlers:
available_handlers.update(extra_handlers)
repo = ToolRepository(db)
db_tools = repo.get_all()
for db_tool in db_tools:
handler = available_handlers.get(db_tool.name)
if not handler:
continue
self.register_tool(
name=db_tool.name,
description=db_tool.description,
parameters=db_tool.parameters,
handler=handler,
)
def register_tool(self, name, description, parameters, handler):
"""Registra uma tool em memoria para uso pelo orquestrador."""
self._tools.append(
ToolDefinition(
name=name,
description=description,
parameters=parameters,
handler=handler,
)
)
def get_tools(self) -> List[ToolDefinition]:
"""Retorna a lista atual de tools registradas."""
return self._tools
async def execute(self, name: str, arguments: dict, user_id: int | None = None):
"""Executa a tool solicitada pelo modelo com os argumentos extraidos."""
tool = next((t for t in self._tools if t.name == name), None)
if not tool:
raise HTTPException(
status_code=400,
detail={
"code": "tool_not_found",
"message": f"Tool {name} nao encontrada.",
"retryable": False,
},
)
call_args = dict(arguments or {})
signature = inspect.signature(tool.handler)
if user_id is not None and "user_id" in signature.parameters:
call_args["user_id"] = user_id
supported_args = {
key: value
for key, value in call_args.items()
if key in signature.parameters
}
missing_required = [
parameter.name
for parameter in signature.parameters.values()
if parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
and parameter.default is inspect._empty
and parameter.name not in supported_args
]
if missing_required:
raise HTTPException(
status_code=400,
detail={
"code": "invalid_tool_arguments",
"message": f"Argumentos obrigatorios ausentes para a tool {name}: {', '.join(missing_required)}.",
"retryable": True,
"field": missing_required[0],
},
)
try:
return await tool.handler(**supported_args)
except TypeError as exc:
raise HTTPException(
status_code=400,
detail={
"code": "invalid_tool_arguments",
"message": f"Argumentos invalidos para a tool {name}.",
"retryable": True,
},
) from exc