You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
267 lines
10 KiB
Python
267 lines
10 KiB
Python
import importlib
|
|
import inspect
|
|
import json
|
|
import logging
|
|
from typing import Callable, Dict, List
|
|
|
|
from fastapi import HTTPException
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.models.tool_model import ToolDefinition
|
|
from app.repositories.tool_repository import ToolRepository
|
|
from app.services.tools.handlers import (
|
|
agendar_revisao,
|
|
avaliar_veiculo_troca,
|
|
cancelar_agendamento_revisao,
|
|
cancelar_pedido,
|
|
editar_data_revisao,
|
|
listar_agendamentos_revisao,
|
|
listar_pedidos,
|
|
consultar_estoque,
|
|
consultar_frota_aluguel,
|
|
abrir_locacao_aluguel,
|
|
registrar_devolucao_aluguel,
|
|
registrar_pagamento_aluguel,
|
|
realizar_pedido,
|
|
validar_cliente_venda,
|
|
)
|
|
from shared.contracts import (
|
|
GENERATED_TOOL_ENTRYPOINT,
|
|
GENERATED_TOOLS_PACKAGE,
|
|
ToolParameterType,
|
|
ToolRuntimePublicationManifest,
|
|
get_generated_tool_publication_manifest_path,
|
|
)
|
|
|
|
# Logger named after this module's import path.
logger = logging.getLogger(name=__name__)


# Core tool catalog: tool name -> handler shipped with the application.
# ToolRegistry copies this mapping as its base handler lookup, and generated tools
# are rejected when their (normalized) name matches any key here.
HANDLERS: Dict[str, Callable] = dict(
    consultar_estoque=consultar_estoque,
    consultar_frota_aluguel=consultar_frota_aluguel,
    validar_cliente_venda=validar_cliente_venda,
    avaliar_veiculo_troca=avaliar_veiculo_troca,
    agendar_revisao=agendar_revisao,
    listar_agendamentos_revisao=listar_agendamentos_revisao,
    cancelar_agendamento_revisao=cancelar_agendamento_revisao,
    editar_data_revisao=editar_data_revisao,
    cancelar_pedido=cancelar_pedido,
    listar_pedidos=listar_pedidos,
    realizar_pedido=realizar_pedido,
    abrir_locacao_aluguel=abrir_locacao_aluguel,
    registrar_devolucao_aluguel=registrar_devolucao_aluguel,
    registrar_pagamento_aluguel=registrar_pagamento_aluguel,
)


# Translates each published ToolParameterType into the JSON-Schema-style "type"
# name emitted when building the parameter schema of a generated tool.
_PARAMETER_SCHEMA_TYPE_MAPPING = dict(
    (
        (ToolParameterType.STRING, "string"),
        (ToolParameterType.INTEGER, "integer"),
        (ToolParameterType.NUMBER, "number"),
        (ToolParameterType.BOOLEAN, "boolean"),
        (ToolParameterType.OBJECT, "object"),
        (ToolParameterType.ARRAY, "array"),
    )
)


class GeneratedToolCoreBoundaryViolation(RuntimeError):
    """Error for a generated tool that crosses into core runtime territory.

    ``ToolRegistry`` raises it when a generated tool would reuse the name of a core
    handler or of an already-registered tool, when its handler does not live in the
    isolated generated-tools package, or when the handler is not the governed
    entrypoint.
    """


class ToolRegistry:
    """In-memory registry of the tools available to the orchestrator.

    Tools are loaded from two sources when the registry is built:

    * persisted tool definitions returned by ``ToolRepository.get_all()`` whose name
      has a handler in ``HANDLERS`` (optionally extended/overridden by
      ``extra_handlers``);
    * generated tools listed in the local publication manifest snapshot, which must
      come from the isolated ``GENERATED_TOOLS_PACKAGE`` package.
    """

    def __init__(self, db: Session, extra_handlers: Dict[str, Callable] | None = None):
        """Populate the registry from the database catalog and the generated-tool snapshot.

        Args:
            db: Session handed to ``ToolRepository`` to read persisted tool definitions.
            extra_handlers: Optional handlers merged over ``HANDLERS``; an entry whose
                name matches a core handler replaces it for this registry.
        """
        self._tools: List[ToolDefinition] = []

        available_handlers = dict(HANDLERS)
        if extra_handlers:
            available_handlers.update(extra_handlers)

        for db_tool in ToolRepository(db).get_all():
            handler = available_handlers.get(db_tool.name)
            if not handler:
                # A persisted tool without a known handler cannot be executed: skip it.
                continue
            self.register_tool(
                name=db_tool.name,
                description=db_tool.description,
                parameters=db_tool.parameters,
                handler=handler,
            )

        self._load_generated_tool_publications_from_snapshot()

    def register_tool(self, name, description, parameters, handler):
        """Register a tool in memory so the orchestrator can use it.

        Handlers whose module lives in the generated-tools package are additionally
        subject to the generated-tool boundary checks.

        Raises:
            GeneratedToolCoreBoundaryViolation: If a generated handler breaks the boundary.
        """
        if self._is_generated_handler(handler):
            self._ensure_generated_tool_boundary(name=name, handler=handler)
        self._append_tool_definition(
            name=name,
            description=description,
            parameters=parameters,
            handler=handler,
        )

    def register_generated_tool(self, name, description, parameters, handler):
        """Register a generated tool only if it respects the isolated runtime package.

        Raises:
            GeneratedToolCoreBoundaryViolation: If the tool breaks the boundary rules.
        """
        self._ensure_generated_tool_boundary(name=name, handler=handler)
        self._append_tool_definition(
            name=name,
            description=description,
            parameters=parameters,
            handler=handler,
        )

    def _load_generated_tool_publications_from_snapshot(self) -> None:
        """Register the generated tools listed in the local publication manifest.

        Best effort: a missing manifest is ignored; an unreadable/invalid manifest or
        a publication that fails to load is logged as a warning and skipped, so one
        broken generated tool does not prevent the registry from being built.
        """
        manifest_path = get_generated_tool_publication_manifest_path()
        if not manifest_path.exists():
            return

        try:
            # "utf-8-sig" tolerates a leading BOM in the manifest file.
            manifest_payload = json.loads(manifest_path.read_text(encoding="utf-8-sig"))
            manifest = ToolRuntimePublicationManifest.model_validate(manifest_payload)
        except Exception as exc:
            logger.warning(
                "Falha ao carregar snapshot local de tools publicadas em %s: %s",
                manifest_path,
                exc,
            )
            return

        # Refresh the import finders once so generated modules written to disk after
        # interpreter start-up can be found; this does not need to run per publication.
        importlib.invalidate_caches()

        for envelope in manifest.publications:
            published_tool = envelope.published_tool
            try:
                handler = self._import_generated_handler(
                    published_tool.implementation_module,
                    published_tool.implementation_callable,
                )
                self.register_generated_tool(
                    name=published_tool.tool_name,
                    description=published_tool.description,
                    parameters=self._build_generated_parameter_schema(published_tool.parameters),
                    handler=handler,
                )
            except Exception as exc:
                logger.warning(
                    "Falha ao registrar tool publicada '%s' a partir do snapshot local %s: %s",
                    published_tool.tool_name,
                    manifest_path,
                    exc,
                )

    @staticmethod
    def _import_generated_handler(module_name: str, callable_name: str) -> Callable:
        """Import a generated tool module and return the named callable from it.

        The package-prefix check runs *before* ``import_module`` so a manifest entry
        pointing at core code is rejected without executing that module's top-level
        code. The full boundary (name, entrypoint) is still enforced at registration.

        Raises:
            GeneratedToolCoreBoundaryViolation: If ``module_name`` is outside the
                generated-tools package.
        """
        if not str(module_name or "").startswith(f"{GENERATED_TOOLS_PACKAGE}."):
            raise GeneratedToolCoreBoundaryViolation(
                f"Tools geradas so podem ser carregadas do pacote isolado '{GENERATED_TOOLS_PACKAGE}.*'."
            )
        module = importlib.import_module(module_name)
        return getattr(module, callable_name)

    @staticmethod
    def _build_generated_parameter_schema(parameters) -> dict:
        """Build an object schema (``properties`` + ``required``) from published parameters.

        ``OBJECT`` parameters allow additional properties; ``ARRAY`` parameters are
        declared as arrays of strings. ``None``/empty input yields an empty schema.
        """
        properties: dict[str, dict] = {}
        required: list[str] = []
        for parameter in parameters or ():
            parameter_type = parameter.parameter_type
            schema = {
                "type": _PARAMETER_SCHEMA_TYPE_MAPPING[parameter_type],
                "description": parameter.description,
            }
            if parameter_type == ToolParameterType.OBJECT:
                schema["additionalProperties"] = True
            elif parameter_type == ToolParameterType.ARRAY:
                schema["items"] = {"type": "string"}
            properties[parameter.name] = schema
            if parameter.required:
                required.append(parameter.name)
        return {
            "type": "object",
            "properties": properties,
            "required": required,
        }

    def _append_tool_definition(self, *, name, description, parameters, handler):
        """Wrap the tool data in a ``ToolDefinition`` and append it to the registry."""
        self._tools.append(
            ToolDefinition(
                name=name,
                description=description,
                parameters=parameters,
                handler=handler,
            )
        )

    @staticmethod
    def _is_generated_handler(handler: Callable) -> bool:
        """Return True when the handler's ``__module__`` is inside the generated-tools package."""
        module_name = str(getattr(handler, "__module__", "") or "").strip()
        return module_name.startswith(f"{GENERATED_TOOLS_PACKAGE}.")

    def _ensure_generated_tool_boundary(self, *, name: str, handler: Callable) -> None:
        """Reject a generated tool that would cross into core runtime code.

        Raises:
            GeneratedToolCoreBoundaryViolation: If the (trimmed, lower-cased) name is a
                core handler or an already-registered tool, if the handler does not
                come from the generated-tools package, or if it is not the governed
                entrypoint.
        """
        normalized_name = str(name or "").strip().lower()
        if normalized_name in HANDLERS:
            raise GeneratedToolCoreBoundaryViolation(
                f"Tool gerada '{normalized_name}' nao pode sobrescrever um handler do catalogo core."
            )

        if any(str(tool.name or "").strip().lower() == normalized_name for tool in self._tools):
            raise GeneratedToolCoreBoundaryViolation(
                f"Tool gerada '{normalized_name}' nao pode sobrescrever uma tool ja registrada no runtime."
            )

        if not self._is_generated_handler(handler):
            raise GeneratedToolCoreBoundaryViolation(
                f"Tools geradas so podem ser carregadas do pacote isolado '{GENERATED_TOOLS_PACKAGE}.*'."
            )

        handler_name = str(getattr(handler, "__name__", "") or "").strip()
        if handler_name != GENERATED_TOOL_ENTRYPOINT:
            raise GeneratedToolCoreBoundaryViolation(
                f"Tools geradas precisam expor o entrypoint governado '{GENERATED_TOOL_ENTRYPOINT}'."
            )

    def get_tools(self) -> List[ToolDefinition]:
        """Return the current list of registered tools (the internal list itself)."""
        return self._tools

    async def execute(self, name: str, arguments: dict, user_id: int | None = None):
        """Run the tool requested by the model with the arguments it extracted.

        Args:
            name: Registered tool name (exact, case-sensitive match).
            arguments: Model-provided keyword arguments; keys the handler does not
                declare are dropped. A model-provided ``user_id`` is always ignored.
            user_id: Authenticated user id, injected when the handler declares a
                ``user_id`` parameter.

        Returns:
            The handler's result, awaited when the handler returns an awaitable.

        Raises:
            HTTPException: 400 ``tool_not_found`` for an unknown tool, or 400
                ``invalid_tool_arguments`` when required arguments are missing or the
                handler call raises ``TypeError``.
        """
        tool = next((t for t in self._tools if t.name == name), None)

        if not tool:
            raise HTTPException(
                status_code=400,
                detail={
                    "code": "tool_not_found",
                    "message": f"Tool {name} nao encontrada.",
                    "retryable": False,
                },
            )

        call_args = dict(arguments or {})
        # The caller identity must come from the authenticated context, never from
        # model output (prompt injection could otherwise act on another user's data).
        call_args.pop("user_id", None)
        parameters = inspect.signature(tool.handler).parameters
        if user_id is not None and "user_id" in parameters:
            call_args["user_id"] = user_id

        supported_args = {key: value for key, value in call_args.items() if key in parameters}
        missing_required = [
            parameter.name
            for parameter in parameters.values()
            if parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
            and parameter.default is inspect.Parameter.empty
            and parameter.name not in supported_args
        ]
        if missing_required:
            raise HTTPException(
                status_code=400,
                detail={
                    "code": "invalid_tool_arguments",
                    "message": f"Argumentos obrigatorios ausentes para a tool {name}: {', '.join(missing_required)}.",
                    "retryable": True,
                    "field": missing_required[0],
                },
            )

        try:
            result = tool.handler(**supported_args)
            # Support both coroutine and plain synchronous handlers; awaiting a
            # non-awaitable would raise TypeError and be misreported as bad arguments.
            if inspect.isawaitable(result):
                result = await result
            return result
        except TypeError as exc:
            raise HTTPException(
                status_code=400,
                detail={
                    "code": "invalid_tool_arguments",
                    "message": f"Argumentos invalidos para a tool {name}.",
                    "retryable": True,
                },
            ) from exc