You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
316 lines
11 KiB
Python
316 lines
11 KiB
Python
import os
|
|
import unittest
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import patch
|
|
|
|
# Default DEBUG to "false" before the `app.*` imports below; setdefault leaves
# any DEBUG value already present in the environment untouched.
# NOTE(review): presumably the app reads DEBUG at import time, which is why this
# line sits between the imports — confirm against the app's settings module.
os.environ.setdefault("DEBUG", "false")
|
|
|
|
from app.services.flows.order_flow import OrderFlowMixin
|
|
from app.services.orchestration.conversation_policy import ConversationPolicy
|
|
from app.services.orchestration.entity_normalizer import EntityNormalizer
|
|
from app.services.tools.handlers import _parse_data_hora_revisao
|
|
|
|
|
|
class FakeState:
|
|
def __init__(self, entries=None, contexts=None):
|
|
self.entries = entries or {}
|
|
self.contexts = contexts or {}
|
|
|
|
def get_entry(self, bucket: str, user_id: int | None, *, expire: bool = False):
|
|
if user_id is None:
|
|
return None
|
|
return self.entries.get(bucket, {}).get(user_id)
|
|
|
|
def set_entry(self, bucket: str, user_id: int | None, value: dict):
|
|
if user_id is None:
|
|
return
|
|
self.entries.setdefault(bucket, {})[user_id] = value
|
|
|
|
def pop_entry(self, bucket: str, user_id: int | None):
|
|
if user_id is None:
|
|
return None
|
|
return self.entries.get(bucket, {}).pop(user_id, None)
|
|
|
|
def get_user_context(self, user_id: int | None):
|
|
if user_id is None:
|
|
return None
|
|
return self.contexts.get(user_id)
|
|
|
|
|
|
class FakeService:
|
|
def __init__(self, state):
|
|
self.state = state
|
|
self.normalizer = EntityNormalizer()
|
|
|
|
def _is_affirmative_message(self, text: str) -> bool:
|
|
normalized = self.normalizer.normalize_text(text).strip().rstrip(".!?,;:")
|
|
return normalized in {"sim", "pode", "ok", "confirmo", "aceito", "fechado", "pode sim", "tenho", "tenho sim"}
|
|
|
|
def _get_user_context(self, user_id: int | None):
|
|
return self.state.get_user_context(user_id)
|
|
|
|
|
|
class FakeRegistry:
|
|
def __init__(self):
|
|
self.calls = []
|
|
|
|
async def execute(self, tool_name: str, arguments: dict, user_id: int | None = None):
|
|
self.calls.append((tool_name, arguments, user_id))
|
|
if tool_name == "realizar_pedido":
|
|
vehicle_map = {
|
|
1: ("Honda Civic 2021", 51524.0),
|
|
2: ("Toyota Corolla 2020", 58476.0),
|
|
}
|
|
modelo_veiculo, valor_veiculo = vehicle_map[arguments["vehicle_id"]]
|
|
return {
|
|
"numero_pedido": "PED-TESTE-123",
|
|
"status": "Ativo",
|
|
"modelo_veiculo": modelo_veiculo,
|
|
"valor_veiculo": valor_veiculo,
|
|
}
|
|
return {
|
|
"numero_pedido": arguments["numero_pedido"],
|
|
"status": "Cancelado",
|
|
"motivo": arguments["motivo"],
|
|
}
|
|
|
|
|
|
class OrderFlowHarness(OrderFlowMixin):
    """Concrete host for ``OrderFlowMixin`` wired to test doubles.

    Provides ``state``, ``registry``, ``normalizer`` and the hook methods
    below, presumably the ones ``OrderFlowMixin`` relies on (confirm against
    the mixin when it changes).
    """

    def __init__(self, state, registry):
        self.state = state
        self.registry = registry
        self.normalizer = EntityNormalizer()

    def _get_user_context(self, user_id: int | None):
        """Delegate the per-user context lookup to ``self.state``."""
        state = self.state
        return state.get_user_context(user_id)

    def _normalize_intents(self, data) -> dict:
        """Normalise intent flags with the real ``EntityNormalizer``."""
        normalizer = self.normalizer
        return normalizer.normalize_intents(data)

    def _normalize_cancel_order_fields(self, data) -> dict:
        """Normalise extracted cancel-order fields with the real ``EntityNormalizer``."""
        normalizer = self.normalizer
        return normalizer.normalize_cancel_order_fields(data)

    def _normalize_order_fields(self, data) -> dict:
        """Normalise extracted order-creation fields with the real ``EntityNormalizer``."""
        normalizer = self.normalizer
        return normalizer.normalize_order_fields(data)

    def _normalize_text(self, text: str) -> str:
        """Normalise free text with the real ``EntityNormalizer``."""
        normalizer = self.normalizer
        return normalizer.normalize_text(text)

    def _http_exception_detail(self, exc) -> str:
        """Return the exception's string form as its detail."""
        return str(exc)

    def _fallback_format_tool_result(self, tool_name: str, tool_result) -> str:
        """Render a tool result as plain text.

        ``realizar_pedido`` results get the order-created layout; every other
        tool result is rendered as a status/reason update.
        """
        if tool_name != "realizar_pedido":
            return (
                f"Pedido {tool_result['numero_pedido']} atualizado.\n"
                f"Status: {tool_result['status']}\n"
                f"Motivo: {tool_result['motivo']}"
            )
        return (
            "Pedido criado com sucesso.\n"
            f"Numero: {tool_result['numero_pedido']}\n"
            f"Veiculo: {tool_result['modelo_veiculo']}\n"
            f"Valor: R$ {tool_result['valor_veiculo']:.2f}"
        )

    def _try_prefill_order_cpf_from_user_profile(self, user_id: int | None, payload: dict) -> None:
        """No-op: the harness never prefills the CPF from a user profile."""
        return None

    def _load_vehicle_by_id(self, vehicle_id: int) -> dict | None:
        """Return a copy of the first cached stock item whose id matches ``vehicle_id``.

        Scans ``last_stock_results`` across every stored user context, in
        dict order, and returns ``None`` when nothing matches.
        """
        stock_items = (
            item
            for context in self.state.contexts.values()
            for item in context.get("last_stock_results", [])
        )
        match = next(
            (item for item in stock_items if int(item["id"]) == int(vehicle_id)),
            None,
        )
        return None if match is None else dict(match)
|
|
|
|
|
|
class ConversationAdjustmentsTests(unittest.TestCase):
    """Policy deferral and accented-connector ("às") date handling."""

    def test_defer_flow_cancel_when_order_cancel_draft_waits_for_reason(self):
        """With a cancel-order draft pending, "desisti" defers flow cancellation; "cancelar fluxo atual" does not."""
        # NOTE(review): naive UTC timestamp; presumably the policy/state compares
        # against naive UTC too — confirm before moving to timezone-aware values.
        pending_draft = {
            "payload": {"numero_pedido": "PED-20260305120000-ABC123"},
            "expires_at": datetime.utcnow() + timedelta(minutes=30),
        }
        state = FakeState(entries={"pending_cancel_order_drafts": {7: pending_draft}})
        policy = ConversationPolicy(service=FakeService(state))

        self.assertTrue(policy.should_defer_flow_cancellation_control("desisti", user_id=7))
        self.assertFalse(policy.should_defer_flow_cancellation_control("cancelar fluxo atual", user_id=7))

    def test_normalize_datetime_connector_accepts_as_com_acento(self):
        """The accented connector "às" between date and time is removed."""
        connector_normalized = EntityNormalizer().normalize_datetime_connector("10/03/2026 às 09:00")

        self.assertEqual(connector_normalized, "10/03/2026 09:00")

    def test_parse_review_datetime_accepts_as_com_acento(self):
        """A review date written with "às" parses to the same naive datetime."""
        parsed = _parse_data_hora_revisao("10/03/2026 às 09:00")

        expected = datetime(2026, 3, 10, 9, 0)
        self.assertEqual(parsed, expected)
|
|
|
|
|
|
class CancelOrderFlowTests(unittest.IsolatedAsyncioTestCase):
    """Tests for ``_try_collect_and_cancel_order`` run through ``OrderFlowHarness``."""

    _ORDER_NUMBER = "PED-20260305120000-ABC123"

    @staticmethod
    def _state_with_draft(bucket: str, user_id: int, payload: dict) -> "FakeState":
        """Build a FakeState holding one draft for ``user_id`` in ``bucket`` that expires in 30 minutes."""
        # NOTE(review): naive UTC timestamp; presumably the flow compares against
        # naive UTC too — confirm before switching to timezone-aware datetimes
        # (datetime.utcnow is deprecated since Python 3.12).
        return FakeState(
            entries={
                bucket: {
                    user_id: {
                        "payload": payload,
                        "expires_at": datetime.utcnow() + timedelta(minutes=30),
                    }
                }
            }
        )

    async def test_cancel_order_flow_consumes_free_text_reason(self):
        """A free-text reply ("desisti") to a pending cancel draft becomes the reason and the order is cancelled."""
        state = self._state_with_draft(
            "pending_cancel_order_drafts", 42, {"numero_pedido": self._ORDER_NUMBER}
        )
        registry = FakeRegistry()
        flow = OrderFlowHarness(state=state, registry=registry)

        response = await flow._try_collect_and_cancel_order(
            message="desisti",
            user_id=42,
            extracted_fields={},
            intents={},
        )

        self.assertEqual(len(registry.calls), 1)
        tool_name, arguments, tool_user_id = registry.calls[0]
        self.assertEqual(tool_name, "cancelar_pedido")
        self.assertEqual(tool_user_id, 42)
        self.assertEqual(arguments["numero_pedido"], self._ORDER_NUMBER)
        self.assertEqual(arguments["motivo"], "desisti")
        # Guard first: assertIn on None raises TypeError instead of failing cleanly.
        self.assertIsNotNone(response)
        self.assertIn("Status: Cancelado", response)
        self.assertIsNone(state.get_entry("pending_cancel_order_drafts", 42))

    async def test_cancel_order_flow_still_requests_reason_when_message_is_too_short(self):
        """A too-short reply ("ok") does not cancel; the flow asks for the reason again and keeps the draft."""
        state = self._state_with_draft(
            "pending_cancel_order_drafts", 42, {"numero_pedido": self._ORDER_NUMBER}
        )
        registry = FakeRegistry()
        flow = OrderFlowHarness(state=state, registry=registry)

        response = await flow._try_collect_and_cancel_order(
            message="ok",
            user_id=42,
            extracted_fields={},
            intents={},
        )

        self.assertEqual(registry.calls, [])
        # Guard first: assertIn on None raises TypeError instead of failing cleanly.
        self.assertIsNotNone(response)
        self.assertIn("o motivo do cancelamento", response)
        self.assertIsNotNone(state.get_entry("pending_cancel_order_drafts", 42))

    async def test_cancel_order_flow_does_not_override_active_order_creation_draft(self):
        """With an order-creation draft active, the cancel flow returns None and calls no tool."""
        state = self._state_with_draft("pending_order_drafts", 42, {})
        registry = FakeRegistry()
        flow = OrderFlowHarness(state=state, registry=registry)

        response = await flow._try_collect_and_cancel_order(
            message="2",
            user_id=42,
            extracted_fields={},
            intents={"order_cancel": True},
        )

        self.assertIsNone(response)
        self.assertEqual(registry.calls, [])
|
|
|
|
|
|
class CreateOrderFlowWithVehicleTests(unittest.IsolatedAsyncioTestCase):
    """Tests for vehicle selection in ``_try_collect_and_create_order``."""

    @staticmethod
    def _user_context(cpf: str | None = None) -> dict:
        """Return a user context whose last stock search listed a Civic (id 1) and a Corolla (id 2).

        When ``cpf`` is given it is stored in ``generic_memory``.
        """
        return {
            "generic_memory": {} if cpf is None else {"cpf": cpf},
            "last_stock_results": [
                {"id": 1, "modelo": "Honda Civic 2021", "categoria": "sedan", "preco": 51524.0},
                {"id": 2, "modelo": "Toyota Corolla 2020", "categoria": "hatch", "preco": 58476.0},
            ],
            "selected_vehicle": None,
        }

    async def test_order_flow_requests_vehicle_selection_from_last_stock_results(self):
        """With no vehicle chosen, an order request asks the user to pick one and calls no tool."""
        state = FakeState(contexts={10: self._user_context()})
        registry = FakeRegistry()
        flow = OrderFlowHarness(state=state, registry=registry)

        response = await flow._try_collect_and_create_order(
            message="Quero fazer um pedido",
            user_id=10,
            extracted_fields={},
            intents={"order_create": True},
        )

        # Guard first: response.lower() on None would raise AttributeError.
        self.assertIsNotNone(response)
        self.assertIn("escolha primeiro qual veiculo", response.lower())
        self.assertIn("Honda Civic 2021", response)
        self.assertEqual(registry.calls, [])

    async def test_order_flow_creates_order_with_selected_vehicle_from_list_index(self):
        """Replying "2" to a pending order draft creates the order for the second listed vehicle."""
        # NOTE(review): naive UTC timestamp; presumably the flow compares against
        # naive UTC too — confirm before switching to timezone-aware datetimes.
        state = FakeState(
            entries={
                "pending_order_drafts": {
                    10: {
                        "payload": {"cpf": "12345678909"},
                        "expires_at": datetime.utcnow() + timedelta(minutes=30),
                    }
                }
            },
            contexts={10: self._user_context(cpf="12345678909")},
        )
        registry = FakeRegistry()
        flow = OrderFlowHarness(state=state, registry=registry)

        async def fake_hydrate_mock_customer_from_cpf(cpf: str, user_id: int | None = None):
            return {"cpf": cpf, "user_id": user_id}

        with patch(
            "app.services.flows.order_flow.hydrate_mock_customer_from_cpf",
            new=fake_hydrate_mock_customer_from_cpf,
        ):
            response = await flow._try_collect_and_create_order(
                message="2",
                user_id=10,
                extracted_fields={},
                intents={},
            )

        self.assertEqual(len(registry.calls), 1)
        tool_name, arguments, tool_user_id = registry.calls[0]
        self.assertEqual(tool_name, "realizar_pedido")
        self.assertEqual(tool_user_id, 10)
        self.assertEqual(arguments["vehicle_id"], 2)
        self.assertEqual(arguments["cpf"], "12345678909")
        # Guard first: assertIn on None raises TypeError instead of failing cleanly.
        self.assertIsNotNone(response)
        self.assertIn("Veiculo: Toyota Corolla 2020", response)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main(module="__main__")
|