"""Tests for the review-slot locking performed by ``review_service``.

The first three tests drive the service against ``ReviewLockingSession``, an
in-memory fake that records the SQL it is asked to execute. The race test uses
a real in-memory SQLite engine and swaps the slot-lock helpers for
``threading.Lock``-based fakes, so concurrent scheduling attempts can be
exercised inside a single process.
"""

import asyncio
import threading
import time
import unittest
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import patch

from fastapi import HTTPException
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from app.db.mock_database import MockBase
from app.db.mock_models import ReviewSchedule
from app.services.domain import review_service

# Keyword arguments for the ``agendar_revisao`` call shared by the success test
# and the race test. Only ever unpacked with ``**``, never mutated.
_SCHEDULE_REQUEST = {
    "placa": "ABC1234",
    "data_hora": "18/03/2026 09:00",
    "modelo": "Onix",
    "ano": 2022,
    "km": 15000,
    "revisao_previa_concessionaria": False,
    "user_id": 7,
}

# ``HTTPException.detail["code"]`` values the race test accepts as a
# legitimate "slot already taken" outcome for a losing attempt.
_CONFLICT_CODES = frozenset({"review_schedule_conflict", "review_slot_busy"})


def _is_conflict_error(result):
    """Return True if ``result`` is an HTTPException carrying a conflict code."""
    return (
        isinstance(result, HTTPException)
        and isinstance(result.detail, dict)
        and result.detail.get("code") in _CONFLICT_CODES
    )


class ReviewLockingQuery:
    """Minimal stand-in for a SQLAlchemy ``Query`` that returns canned results.

    ``filter`` ignores its criteria and returns the same query. Each ``first``
    call pops the next queued result and returns ``None`` once the queue is
    exhausted.
    """

    def __init__(self, results=None):
        # Copy so popping never mutates the caller's list.
        self.results = list(results or [])

    def filter(self, *args, **kwargs):
        return self

    def first(self):
        if self.results:
            return self.results.pop(0)
        return None


class ReviewLockingSession:
    """Fake DB session that records lock SQL and persistence calls.

    ``execute`` only understands statements whose text contains ``GET_LOCK``
    or ``RELEASE_LOCK``. Any other statement, or a ``query`` for a model other
    than ``review_service.ReviewSchedule``, fails the test with
    ``AssertionError``.

    Args:
        query_results: Values returned, in order, by successive
            ``query(ReviewSchedule).filter(...).first()`` calls.
        lock_acquired: Scalar returned for ``GET_LOCK`` statements. The tests
            use the default ``1`` for a free slot and ``0`` for a busy one.
    """

    def __init__(self, *, query_results=None, lock_acquired=1):
        self.query_instance = ReviewLockingQuery(query_results)
        self.lock_acquired = lock_acquired
        self.execute_calls = []  # (sql_text, params) for every execute() call
        self.added = []
        self.committed = False
        self.closed = False
        self.refreshed = []

    def query(self, model):
        if model is review_service.ReviewSchedule:
            return self.query_instance
        raise AssertionError(f"unexpected model query: {model}")

    def execute(self, statement, params=None):
        sql_text = str(statement)
        self.execute_calls.append((sql_text, params))
        if "GET_LOCK" in sql_text:
            return SimpleNamespace(scalar=lambda: self.lock_acquired)
        if "RELEASE_LOCK" in sql_text:
            return SimpleNamespace(scalar=lambda: 1)
        raise AssertionError(f"unexpected execute call: {sql_text}")

    def ran_sql(self, fragment):
        """Return True if any recorded ``execute`` statement contains ``fragment``."""
        return any(fragment in sql for sql, _ in self.execute_calls)

    def add(self, item):
        self.added.append(item)

    def commit(self):
        self.committed = True

    def refresh(self, item):
        self.refreshed.append(item)

    def close(self):
        self.closed = True


class ReviewServiceLockingTests(unittest.IsolatedAsyncioTestCase):
    """Slot-lock acquisition/release behaviour of the review service."""

    def _build_threadsafe_session_local(self):
        """Return a session factory bound to a fresh in-memory SQLite database.

        ``StaticPool`` hands every session the same single connection, so all
        sessions see the same in-memory database. ``check_same_thread=False``
        lets that connection be used from worker threads. The engine is
        disposed when the test finishes.
        """
        engine = create_engine(
            "sqlite://",
            connect_args={"check_same_thread": False},
            poolclass=StaticPool,
        )
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        MockBase.metadata.create_all(bind=engine)
        self.addCleanup(engine.dispose)
        return SessionLocal

    def test_acquire_review_slot_lock_returns_conflict_when_slot_is_busy(self):
        session = ReviewLockingSession(lock_acquired=0)

        with self.assertRaises(HTTPException) as ctx:
            review_service._acquire_review_slot_lock(
                session,
                requested_dt=datetime(2026, 3, 18, 9, 0),
            )

        self.assertEqual(ctx.exception.status_code, 409)
        self.assertEqual(ctx.exception.detail["code"], "review_slot_busy")
        self.assertTrue(session.ran_sql("GET_LOCK"))

    async def test_agendar_revisao_uses_slot_lock_and_releases_after_success(self):
        # Two ``first()`` lookups, both empty: no prior schedule exists.
        session = ReviewLockingSession(query_results=[None, None])

        with patch.object(review_service, "SessionMockLocal", return_value=session):
            result = await review_service.agendar_revisao(**_SCHEDULE_REQUEST)

        self.assertTrue(session.ran_sql("GET_LOCK"))
        self.assertTrue(session.ran_sql("RELEASE_LOCK"))
        self.assertTrue(session.committed)
        self.assertEqual(len(session.added), 1)
        self.assertEqual(result["status"], "agendado")
        self.assertTrue(session.closed)

    async def test_editar_data_revisao_releases_slot_lock_when_conflict_is_detected(self):
        current_schedule = ReviewSchedule(
            id=1,
            protocolo="REV-20260318-AAAA1111",
            user_id=7,
            placa="ABC1234",
            data_hora=datetime(2026, 3, 18, 9, 0),
            status="agendado",
        )
        conflicting_schedule = ReviewSchedule(
            id=2,
            protocolo="REV-20260319-BBBB2222",
            user_id=8,
            placa="XYZ9876",
            data_hora=datetime(2026, 3, 19, 10, 0),
            status="agendado",
        )
        # First lookup finds the schedule being edited; second finds the
        # schedule already occupying the requested new slot.
        session = ReviewLockingSession(query_results=[current_schedule, conflicting_schedule])

        with patch.object(review_service, "SessionMockLocal", return_value=session):
            with self.assertRaises(HTTPException) as ctx:
                await review_service.editar_data_revisao(
                    protocolo=current_schedule.protocolo,
                    nova_data_hora="19/03/2026 10:00",
                    user_id=7,
                )

        self.assertTrue(session.ran_sql("GET_LOCK"))
        self.assertTrue(session.ran_sql("RELEASE_LOCK"))
        self.assertEqual(ctx.exception.status_code, 409)
        self.assertEqual(ctx.exception.detail["code"], "review_schedule_conflict")
        self.assertFalse(session.committed)
        self.assertTrue(session.closed)

    async def test_agendar_revisao_allows_single_success_under_race(self):
        SessionLocal = self._build_threadsafe_session_local()
        attempts = 4
        # Releases every worker thread at the same moment to maximise contention.
        start_barrier = threading.Barrier(attempts)
        slot_locks: dict[str, threading.Lock] = {}
        slot_locks_guard = threading.Lock()

        def _acquire_slot_lock(db, *, requested_dt, timeout_seconds=5, field_name="data_hora"):
            """In-process replacement for ``review_service._acquire_review_slot_lock``."""
            lock_name = review_service._review_slot_lock_name(requested_dt)
            with slot_locks_guard:
                slot_lock = slot_locks.setdefault(lock_name, threading.Lock())
            if not slot_lock.acquire(timeout=timeout_seconds):
                review_service.raise_tool_http_error(
                    status_code=409,
                    code="review_slot_busy",
                    message="Outro atendimento esta finalizando este horario de revisao. Tente novamente.",
                    retryable=True,
                    field=field_name,
                )
            # Record the held lock on this session so the release fake can find it.
            db.info.setdefault("_test_review_slot_locks", {})[lock_name] = slot_lock
            # Hold the lock briefly to widen the window for competing threads.
            time.sleep(0.05)
            return lock_name

        def _release_slot_lock(db, lock_name):
            """In-process replacement for ``review_service._release_review_slot_lock``."""
            if not lock_name:
                return
            held_lock = db.info.get("_test_review_slot_locks", {}).pop(lock_name, None)
            # Only release a lock this session registered and that is still
            # held, so a double release is a no-op instead of a RuntimeError.
            if held_lock is not None and held_lock.locked():
                held_lock.release()

        def _sync_schedule_review():
            start_barrier.wait(timeout=5)
            # Each worker thread drives the coroutine on its own event loop.
            return asyncio.run(review_service.agendar_revisao(**_SCHEDULE_REQUEST))

        with patch.object(review_service, "SessionMockLocal", SessionLocal), patch.object(
            review_service,
            "_acquire_review_slot_lock",
            side_effect=_acquire_slot_lock,
        ), patch.object(
            review_service,
            "_release_review_slot_lock",
            side_effect=_release_slot_lock,
        ):
            results = await asyncio.gather(
                *(asyncio.to_thread(_sync_schedule_review) for _ in range(attempts)),
                return_exceptions=True,
            )

        # Partition every outcome exactly once.
        successes, conflicts, unexpected = [], [], []
        for result in results:
            if isinstance(result, dict):
                successes.append(result)
            elif _is_conflict_error(result):
                conflicts.append(result)
            else:
                unexpected.append(result)

        # Check unexpected outcomes first so a failure reports the real error
        # rather than only a count mismatch.
        self.assertEqual(unexpected, [])
        self.assertEqual(len(successes), 1)
        self.assertEqual(len(conflicts), attempts - 1)

        db = SessionLocal()
        try:
            schedules = db.query(ReviewSchedule).all()
            self.assertEqual(len(schedules), 1)
            self.assertEqual(schedules[0].status, "agendado")
            self.assertEqual(schedules[0].placa, "ABC1234")
        finally:
            db.close()


if __name__ == "__main__":
    unittest.main()