You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.5 KiB
Python
74 lines
2.5 KiB
Python
from typing import Any, Literal
|
|
|
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
|
|
|
|
# This module defines the structured contract the model is expected to
# return on every conversation turn.

# Conversation area the turn belongs to.
TurnDomain = Literal[
    "review",
    "sales",
    "general",
]

# What the model understood the user wants in this turn.
TurnIntent = Literal[
    # review_* intents
    "review_schedule",
    "review_list",
    "review_cancel",
    "review_reschedule",
    # order_* and inventory intents
    "order_create",
    "order_list",
    "order_cancel",
    "inventory_search",
    # conversation / queue control
    "conversation_reset",
    "queue_continue",
    "discard_queue",
    "cancel_active_flow",
    # catch-all intent
    "general",
]

# Next step the model chooses to take for this turn.
TurnAction = Literal[
    # collect_* actions, named after the flow they gather data for
    "collect_review_schedule",
    "collect_review_management",
    "collect_order_create",
    "collect_order_cancel",
    # direct responses / tool invocation
    "ask_missing_fields",
    "answer_user",
    "call_tool",
    # conversation / queue control
    "clear_context",
    "continue_queue",
    "discard_queue",
    "cancel_active_flow",
]
|
|
|
|
|
|
class DecisionEntities(BaseModel):
    """Entities extracted by the model for one turn, bucketed per flow type.

    Each bucket is a free-form ``dict[str, Any]`` that starts out as a fresh,
    empty dict for every instance. Unknown bucket names are rejected because
    the model is configured with ``extra="forbid"``.
    """

    model_config = ConfigDict(extra="forbid")

    # Entities are kept split by flow type to stay compatible with the
    # existing mixins and technical validators.
    generic_memory: dict[str, Any] = Field(default_factory=dict)
    review_fields: dict[str, Any] = Field(default_factory=dict)
    review_management_fields: dict[str, Any] = Field(default_factory=dict)
    order_fields: dict[str, Any] = Field(default_factory=dict)
    cancel_order_fields: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
class TurnDecision(BaseModel):
    """Structured decision the model returns for a single conversation turn.

    The model picks the turn's intent, domain and action; the remaining
    fields carry the data needed to carry out that action. Unknown keys are
    rejected (``extra="forbid"``) and ``validate_contract`` enforces the
    cross-field rules that depend on the chosen action.
    """

    model_config = ConfigDict(extra="forbid")

    # The model decides the intent, the domain and the action of the turn.
    intent: TurnIntent = "general"
    domain: TurnDomain = "general"
    action: TurnAction = "answer_user"
    entities: DecisionEntities = Field(default_factory=DecisionEntities)
    # Must contain at least one non-blank name when action is
    # "ask_missing_fields".
    missing_fields: list[str] = Field(default_factory=list)
    # Must be >= 0 when set.
    selection_index: int | None = None
    # Must be non-blank when action is "call_tool".
    tool_name: str | None = None
    tool_arguments: dict[str, Any] = Field(default_factory=dict)
    # Must be non-blank when action is "ask_missing_fields".
    response_to_user: str | None = None

    @model_validator(mode="after")
    def validate_contract(self):
        """Enforce the action-dependent rules of the turn contract.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: if ``action == "ask_missing_fields"`` without a
                non-blank field name in ``missing_fields`` or without a
                non-blank ``response_to_user``; if ``action == "call_tool"``
                without a non-blank ``tool_name``; or if ``selection_index``
                is negative. (Pydantic surfaces these as ``ValidationError``.)
        """
        if self.action == "ask_missing_fields":
            # A list made only of blank names carries no usable field, so it
            # must be rejected just like an empty list.
            has_field = any(name.strip() for name in self.missing_fields)
            has_prompt = bool((self.response_to_user or "").strip())
            if not (has_field and has_prompt):
                raise ValueError("ask_missing_fields exige missing_fields e response_to_user")

        if self.action == "call_tool" and not (self.tool_name or "").strip():
            raise ValueError("call_tool exige tool_name")

        if self.selection_index is not None and self.selection_index < 0:
            raise ValueError("selection_index deve ser maior ou igual a zero")

        return self
|