diff --git a/.env.example b/.env.example
index 37526c6..ccd1365 100644
--- a/.env.example
+++ b/.env.example
@@ -4,7 +4,7 @@
 GOOGLE_PROJECT_ID=id_do_seu_projeto
 GOOGLE_LOCATION=loc_do_seu_projeto
 
-VERTEX_MODEL_NAME=gemini-2.5-flash
+VERTEX_MODEL_NAME=gemini-2.5-pro
 
 # ============================================
 # CONFIGURACOES DO BANCO DE DADOS (MYSQL - TOOLS)
diff --git a/app/core/settings.py b/app/core/settings.py
index 6ab82ee..c3a3883 100644
--- a/app/core/settings.py
+++ b/app/core/settings.py
@@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
 class Settings(BaseSettings):
     google_project_id: str
     google_location: str = "us-central1"
-    vertex_model_name: str = "gemini-2.5-flash"
+    vertex_model_name: str = "gemini-2.5-pro"
 
     # Tools database (MySQL)
     db_host: str = "127.0.0.1"
diff --git a/app/services/ai/llm_service.py b/app/services/ai/llm_service.py
index 69a95a9..7d88e01 100644
--- a/app/services/ai/llm_service.py
+++ b/app/services/ai/llm_service.py
@@ -27,7 +27,7 @@ class LLMService:
         LLMService._vertex_initialized = True
 
         configured = settings.vertex_model_name.strip()
-        fallback_models = ["gemini-2.5-flash", "gemini-2.0-flash-001", "gemini-1.5-pro"]
+        fallback_models = ["gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash-001"]
         self.model_names = [configured] + [m for m in fallback_models if m != configured]
 
     def build_vertex_tools(self, tools: List[ToolDefinition]) -> Optional[List[Tool]]:
@@ -73,6 +73,30 @@
         LLMService._models[model_name] = model
         return model
 
+    def _extract_response_payload(self, response) -> Dict[str, Any]:
+        candidate = response.candidates[0] if getattr(response, "candidates", None) else None
+        content = getattr(candidate, "content", None)
+        parts = list(getattr(content, "parts", None) or [])
+
+        tool_call = None
+        text_parts: list[str] = []
+        for part in parts:
+            function_call = getattr(part, "function_call", None)
+            if function_call is not None and tool_call is None:
+                tool_call = {
+                    "name": function_call.name,
+                    "arguments": dict(function_call.args),
+                }
+            text_value = getattr(part, "text", None)
+            if isinstance(text_value, str) and text_value.strip():
+                text_parts.append(text_value)
+
+        response_text = "\n".join(text_parts).strip() or None
+        return {
+            "response": response_text,
+            "tool_call": tool_call,
+        }
+
     async def generate_response(
         self,
         message: str,
@@ -106,21 +130,7 @@
                 ) from last_error
             raise RuntimeError("Falha ao gerar resposta no Vertex AI.")
 
-        part = response.candidates[0].content.parts[0]
-
-        if part.function_call:
-            return {
-                "response": None,
-                "tool_call": {
-                    "name": part.function_call.name,
-                    "arguments": dict(part.function_call.args),
-                },
-            }
-
-        return {
-            "response": response.text,
-            "tool_call": None,
-        }
+        return self._extract_response_payload(response)
 
     async def warmup(self) -> None:
         """Preaquece conexao/modelo para reduzir latencia da primeira requisicao real."""
diff --git a/tests/test_llm_service.py b/tests/test_llm_service.py
new file mode 100644
index 0000000..5e3d765
--- /dev/null
+++ b/tests/test_llm_service.py
@@ -0,0 +1,59 @@
+import os
+import unittest
+from types import SimpleNamespace
+
+os.environ.setdefault("DEBUG", "false")
+
+from app.services.ai.llm_service import LLMService
+
+
+class LLMServiceResponseParsingTests(unittest.TestCase):
+    def test_extract_response_payload_supports_text_and_function_call_in_same_candidate(self):
+        service = LLMService.__new__(LLMService)
+        response = SimpleNamespace(
+            candidates=[
+                SimpleNamespace(
+                    content=SimpleNamespace(
+                        parts=[
+                            SimpleNamespace(text="Legal! Buscando carros de ate 70 mil para voce.", function_call=None),
+                            SimpleNamespace(
+                                text=None,
+                                function_call=SimpleNamespace(
+                                    name="consultar_estoque",
+                                    args={"preco_max": 70000.0},
+                                ),
+                            ),
+                        ]
+                    )
+                )
+            ]
+        )
+
+        payload = service._extract_response_payload(response)
+
+        self.assertEqual(payload["response"], "Legal! Buscando carros de ate 70 mil para voce.")
+        self.assertEqual(
+            payload["tool_call"],
+            {
+                "name": "consultar_estoque",
+                "arguments": {"preco_max": 70000.0},
+            },
+        )
+
+    def test_extract_response_payload_handles_text_only_candidate_without_response_text_accessor(self):
+        service = LLMService.__new__(LLMService)
+        response = SimpleNamespace(
+            candidates=[
+                SimpleNamespace(
+                    content=SimpleNamespace(
+                        parts=[
+                            SimpleNamespace(text="Resposta simples", function_call=None),
+                        ]
+                    )
+                )
+            ]
+        )
+
+        payload = service._extract_response_payload(response)
+
+        self.assertEqual(payload, {"response": "Resposta simples", "tool_call": None})
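Note on consuming the new payload: the old parsing treated text and function calls as mutually exclusive (it returned "response": None whenever a function_call was present), while _extract_response_payload can surface both from the same candidate, as the first test above exercises. A minimal sketch of a downstream consumer under that assumption; send_text and execute_tool are hypothetical stubs for illustration, not part of this diff:

    # Sketch of a consumer for the {"response": ..., "tool_call": ...} payload
    # returned by generate_response. send_text and execute_tool are hypothetical
    # placeholders, not part of the diff.
    import asyncio
    from typing import Any, Dict

    async def send_text(text: str) -> None:
        print(f"text -> {text}")

    async def execute_tool(name: str, arguments: Dict[str, Any]) -> None:
        print(f"tool -> {name}({arguments})")

    async def handle_payload(payload: Dict[str, Any]) -> None:
        # The two fields are independent: a candidate may carry preamble text
        # ("Legal! Buscando carros...") alongside a tool call, so check both.
        if payload["response"]:
            await send_text(payload["response"])
        if payload["tool_call"]:
            await execute_tool(payload["tool_call"]["name"], payload["tool_call"]["arguments"])

    asyncio.run(handle_payload({
        "response": "Legal! Buscando carros de ate 70 mil para voce.",
        "tool_call": {"name": "consultar_estoque", "arguments": {"preco_max": 70000.0}},
    }))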