diff --git a/app/services/flows/order_flow.py b/app/services/flows/order_flow.py index 508d890..e1bc7e7 100644 --- a/app/services/flows/order_flow.py +++ b/app/services/flows/order_flow.py @@ -79,6 +79,7 @@ class OrderFlowMixin: if budget: generic_memory["orcamento_max"] = int(round(budget)) context.setdefault("shared_memory", {})["orcamento_max"] = int(round(budget)) + self._save_user_context(user_id=user_id, context=context) def _try_prefill_order_cpf_from_memory(self, user_id: int | None, payload: dict) -> None: if user_id is None or payload.get("cpf"): @@ -142,6 +143,7 @@ class OrderFlowMixin: context["last_stock_results"] = sanitized if sanitized: context["selected_vehicle"] = None + self._save_user_context(user_id=user_id, context=context) def _store_selected_vehicle(self, user_id: int | None, vehicle: dict | None) -> None: if user_id is None: @@ -150,6 +152,7 @@ class OrderFlowMixin: if not context: return context["selected_vehicle"] = dict(vehicle) if isinstance(vehicle, dict) else None + self._save_user_context(user_id=user_id, context=context) def _vehicle_to_payload(self, vehicle: dict) -> dict: return { @@ -244,6 +247,7 @@ class OrderFlowMixin: return context["last_stock_results"] = [] context["selected_vehicle"] = None + self._save_user_context(user_id=user_id, context=context) def _match_vehicle_from_message_index(self, message: str, stock_results: list[dict]) -> dict | None: tokens = [token for token in re.findall(r"\d+", str(message or "")) if token.isdigit()] diff --git a/app/services/orchestration/conversation_policy.py b/app/services/orchestration/conversation_policy.py index 7af354f..022c409 100644 --- a/app/services/orchestration/conversation_policy.py +++ b/app/services/orchestration/conversation_policy.py @@ -17,6 +17,11 @@ class ConversationPolicy: def __init__(self, service: "OrquestradorService"): self.service = service + def _save_context(self, user_id: int | None, context: dict | None) -> None: + if user_id is None or not isinstance(context, dict): + return + self.service._save_user_context(user_id=user_id, context=context) + def _decision_action(self, turn_decision: dict | None) -> str: return str((turn_decision or {}).get("action") or "").strip().lower() @@ -72,6 +77,7 @@ class ConversationPolicy: "created_at": datetime.utcnow().isoformat(), } ) + self._save_context(user_id=user_id, context=context) # Transforma as entidades extraídas de um pedido em uma memória temporária pronta para usar quando esse pedido for processado. @@ -109,7 +115,9 @@ class ConversationPolicy: queue = context.setdefault("order_queue", []) if not queue: return None - return queue.pop(0) + popped = queue.pop(0) + self._save_context(user_id=user_id, context=context) + return popped @@ -196,6 +204,7 @@ class ConversationPolicy: queued_count += 1 context["active_domain"] = first["domain"] context["generic_memory"] = self.build_order_memory_seed(user_id=user_id, order=first) + self._save_context(user_id=user_id, context=context) queue_notice = self.render_queue_notice(queued_count) return first["message"], queue_notice, None @@ -226,6 +235,7 @@ class ConversationPolicy: ], "expires_at": datetime.utcnow() + timedelta(minutes=PENDING_ORDER_SELECTION_TTL_MINUTES), } + self._save_context(user_id=user_id, context=context) # Cria o texto de escolha para o usuário. @@ -437,10 +447,12 @@ class ConversationPolicy: return None if pending.get("expires_at") and pending["expires_at"] < datetime.utcnow(): context["pending_order_selection"] = None + self._save_context(user_id=user_id, context=context) return None orders = pending.get("orders") or [] if len(orders) < 2: context["pending_order_selection"] = None + self._save_context(user_id=user_id, context=context) return None decision_action = self._decision_action(turn_decision) @@ -459,6 +471,7 @@ class ConversationPolicy: if selected_index is None: if self.looks_like_fresh_operational_request(message, turn_decision=turn_decision): context["pending_order_selection"] = None + self._save_context(user_id=user_id, context=context) return None return self.render_order_selection_prompt(orders) @@ -480,6 +493,7 @@ class ConversationPolicy: selected_memory = dict(selected_order.get("memory_seed") or {}) if selected_memory: context["generic_memory"] = selected_memory + self._save_context(user_id=user_id, context=context) next_response = await self.service.handle_message(str(selected_order.get("message") or ""), user_id=user_id) return f"{intro}\n{next_response}" @@ -564,6 +578,7 @@ class ConversationPolicy: "memory_seed": dict(next_order.get("memory_seed") or self.service._new_tab_memory(user_id=user_id)), "expires_at": datetime.utcnow() + timedelta(minutes=15), } + self._save_context(user_id=user_id, context=context) transition = self.build_next_order_transition(next_order["domain"]) return ( f"{base_response}\n\n" @@ -622,6 +637,7 @@ class ConversationPolicy: return None if pending_switch.get("expires_at") and pending_switch["expires_at"] < datetime.utcnow(): context["pending_switch"] = None + self._save_context(user_id=user_id, context=context) return None queued_message = str(pending_switch.get("queued_message") or "").strip() if not queued_message: @@ -630,6 +646,7 @@ class ConversationPolicy: decision_action = self._decision_action(turn_decision) if self.service._is_negative_message(message) and decision_action != "continue_queue": context["pending_switch"] = None + self._save_context(user_id=user_id, context=context) return "Tudo bem. Mantive o proximo pedido fora da fila por enquanto." if not ( self.is_continue_queue_message(message, turn_decision=turn_decision) @@ -643,6 +660,7 @@ class ConversationPolicy: refreshed = self.service._get_user_context(user_id) if refreshed is not None: refreshed["generic_memory"] = memory_seed + self._save_context(user_id=user_id, context=refreshed) transition = self.build_next_order_transition(target_domain) next_response = await self.service.handle_message(queued_message, user_id=user_id) return f"{transition}\n{next_response}" @@ -681,6 +699,7 @@ class ConversationPolicy: context["generic_memory"] = self.service._new_tab_memory(user_id=user_id) context["pending_order_selection"] = None context["pending_switch"] = None + self._save_context(user_id=user_id, context=context) # Controla a confirmação de “você quer mesmo sair deste assunto e ir para outro?”. @@ -698,17 +717,20 @@ class ConversationPolicy: if pending_switch: if pending_switch["expires_at"] < datetime.utcnow(): context["pending_switch"] = None + self._save_context(user_id=user_id, context=context) elif self.is_context_switch_confirmation(message, turn_decision=turn_decision): if self.service._is_affirmative_message(message) or self._decision_domain(turn_decision) == pending_switch["target_domain"]: target_domain = pending_switch["target_domain"] self.apply_domain_switch(user_id=user_id, target_domain=target_domain) return self.render_context_switched_message(target_domain=target_domain) context["pending_switch"] = None + self._save_context(user_id=user_id, context=context) return "Perfeito, vamos continuar no fluxo atual." pending_order_selection = context.get("pending_order_selection") if pending_order_selection and pending_order_selection.get("expires_at") < datetime.utcnow(): context["pending_order_selection"] = None + self._save_context(user_id=user_id, context=context) current_domain = context.get("active_domain", "general") if target_domain_hint == "general" or target_domain_hint == current_domain: @@ -721,6 +743,7 @@ class ConversationPolicy: "target_domain": target_domain_hint, "expires_at": datetime.utcnow() + timedelta(minutes=15), } + self._save_context(user_id=user_id, context=context) return self.render_context_switch_confirmation(source_domain=current_domain, target_domain=target_domain_hint) @@ -729,6 +752,7 @@ class ConversationPolicy: context = self.service._get_user_context(user_id) if context and domain_hint != "general": context["active_domain"] = domain_hint + self._save_context(user_id=user_id, context=context) # Serve para exibir o nome do domínio em mensagens para o usuário. diff --git a/app/services/orchestration/conversation_state_repository.py b/app/services/orchestration/conversation_state_repository.py index 7df8b10..e1694a6 100644 --- a/app/services/orchestration/conversation_state_repository.py +++ b/app/services/orchestration/conversation_state_repository.py @@ -11,6 +11,10 @@ class ConversationStateRepository(ABC): def get_user_context(self, user_id: int | None) -> dict | None: pass + @abstractmethod + def save_user_context(self, user_id: int | None, context: dict) -> None: + pass + @abstractmethod def get_entry(self, bucket: str, user_id: int | None, *, expire: bool = False) -> dict | None: pass diff --git a/app/services/orchestration/conversation_state_store.py b/app/services/orchestration/conversation_state_store.py index 2f12d68..0608109 100644 --- a/app/services/orchestration/conversation_state_store.py +++ b/app/services/orchestration/conversation_state_store.py @@ -47,6 +47,11 @@ class ConversationStateStore(ConversationStateRepository): return None return context + def save_user_context(self, user_id: int | None, context: dict) -> None: + if user_id is None or not isinstance(context, dict): + return + self.user_contexts[user_id] = context + def get_entry(self, bucket: str, user_id: int | None, *, expire: bool = False) -> dict | None: if user_id is None: return None diff --git a/app/services/orchestration/orquestrador_service.py b/app/services/orchestration/orquestrador_service.py index 3018a8a..0a283e8 100644 --- a/app/services/orchestration/orquestrador_service.py +++ b/app/services/orchestration/orquestrador_service.py @@ -552,6 +552,7 @@ class OrquestradorService(ReviewFlowMixin, OrderFlowMixin): context["pending_switch"] = None context["last_stock_results"] = [] context["selected_vehicle"] = None + self._save_user_context(user_id=user_id, context=context) def _clear_pending_order_navigation(self, user_id: int | None) -> int: context = self._get_user_context(user_id) @@ -566,6 +567,7 @@ class OrquestradorService(ReviewFlowMixin, OrderFlowMixin): context["order_queue"] = [] context["pending_switch"] = None context["pending_order_selection"] = None + self._save_user_context(user_id=user_id, context=context) return dropped def _cancel_active_flow(self, user_id: int | None) -> str: @@ -581,6 +583,7 @@ class OrquestradorService(ReviewFlowMixin, OrderFlowMixin): self._reset_pending_order_states(user_id=user_id) context["pending_switch"] = None + self._save_user_context(user_id=user_id, context=context) if had_flow: return f"Fluxo atual de {self._domain_label(active_domain)} cancelado." return "Nao havia fluxo em andamento para cancelar." @@ -603,6 +606,7 @@ class OrquestradorService(ReviewFlowMixin, OrderFlowMixin): refreshed = self._get_user_context(user_id) if refreshed is not None: refreshed["generic_memory"] = memory_seed + self._save_user_context(user_id=user_id, context=refreshed) transition = self._build_next_order_transition(target_domain) next_response = await self.handle_message(queued_message, user_id=user_id) return f"{transition}\n{next_response}" @@ -617,6 +621,7 @@ class OrquestradorService(ReviewFlowMixin, OrderFlowMixin): refreshed = self._get_user_context(user_id) if refreshed is not None: refreshed["generic_memory"] = memory_seed + self._save_user_context(user_id=user_id, context=refreshed) transition = self._build_next_order_transition(target_domain) next_response = await self.handle_message(str(next_order.get("message") or ""), user_id=user_id) return f"{transition}\n{next_response}" @@ -668,6 +673,11 @@ class OrquestradorService(ReviewFlowMixin, OrderFlowMixin): def _get_user_context(self, user_id: int | None) -> dict | None: return self.state.get_user_context(user_id) + def _save_user_context(self, user_id: int | None, context: dict | None) -> None: + if user_id is None or not isinstance(context, dict): + return + self.state.save_user_context(user_id=user_id, context=context) + def _extract_generic_memory_fields(self, llm_generic_fields: dict | None = None) -> dict: extracted: dict = {} llm_fields = llm_generic_fields or {} @@ -704,6 +714,7 @@ class OrquestradorService(ReviewFlowMixin, OrderFlowMixin): # Campos novos entram e campos repetidos sobrescrevem valor antigo. context["generic_memory"].update(fields) context.setdefault("shared_memory", {}).update(fields) + self._save_user_context(user_id=user_id, context=context) def _capture_tool_result_context( self, @@ -738,6 +749,7 @@ class OrquestradorService(ReviewFlowMixin, OrderFlowMixin): context["last_stock_results"] = sanitized if sanitized: context["selected_vehicle"] = None + self._save_user_context(user_id=user_id, context=context) async def _maybe_build_stock_suggestion_response( self, diff --git a/app/services/orchestration/redis_state_repository.py b/app/services/orchestration/redis_state_repository.py index 514bef2..fe13f60 100644 --- a/app/services/orchestration/redis_state_repository.py +++ b/app/services/orchestration/redis_state_repository.py @@ -58,6 +58,16 @@ class RedisConversationStateRepository(ConversationStateRepository): return None return context + def save_user_context(self, user_id: int | None, context: dict) -> None: + if user_id is None or not isinstance(context, dict): + return + payload = dict(context) + ttl_seconds = self._ttl_from_entry(payload) + if ttl_seconds is None: + payload["expires_at"] = datetime.utcnow().replace(microsecond=0) + self._minutes_delta(self.default_ttl_minutes) + ttl_seconds = self.default_ttl_minutes * 60 + self._save(self._bucket_key("user_contexts", user_id), payload, ttl_seconds=ttl_seconds) + def get_entry(self, bucket: str, user_id: int | None, *, expire: bool = False) -> dict | None: if user_id is None: return None diff --git a/tests/test_conversation_adjustments.py b/tests/test_conversation_adjustments.py index 84d6f3f..882c44f 100644 --- a/tests/test_conversation_adjustments.py +++ b/tests/test_conversation_adjustments.py @@ -40,6 +40,11 @@ class FakeState: return None return self.contexts.get(user_id) + def save_user_context(self, user_id: int | None, context: dict): + if user_id is None: + return + self.contexts[user_id] = context + class FakeService: def __init__(self, state): @@ -53,6 +58,11 @@ class FakeService: def _get_user_context(self, user_id: int | None): return self.state.get_user_context(user_id) + def _save_user_context(self, user_id: int | None, context: dict | None) -> None: + if user_id is None or not isinstance(context, dict): + return + self.state.save_user_context(user_id, context) + class FakeRegistry: def __init__(self): @@ -108,6 +118,11 @@ class OrderFlowHarness(OrderFlowMixin): def _get_user_context(self, user_id: int | None): return self.state.get_user_context(user_id) + def _save_user_context(self, user_id: int | None, context: dict | None) -> None: + if user_id is None or not isinstance(context, dict): + return + self.state.save_user_context(user_id, context) + def _normalize_intents(self, data) -> dict: return self.normalizer.normalize_intents(data) diff --git a/tests/test_turn_decision_contract.py b/tests/test_turn_decision_contract.py index 7184e28..a381263 100644 --- a/tests/test_turn_decision_contract.py +++ b/tests/test_turn_decision_contract.py @@ -48,6 +48,11 @@ class FakeState: return None return self.contexts.get(user_id) + def save_user_context(self, user_id: int | None, context: dict): + if user_id is None: + return + self.contexts[user_id] = context + class FakeToolExecutor: def __init__(self, result=None): @@ -83,6 +88,11 @@ class FakePolicyService: return None return self.state.contexts.get(user_id) + def _save_user_context(self, user_id: int | None, context: dict | None) -> None: + if user_id is None or not isinstance(context, dict): + return + self.state.save_user_context(user_id, context) + def _new_tab_memory(self, user_id: int | None): return {}