"""
V18 anti-freeze envelope tests.

Verify that the orchestrator survives an unresponsive broker/provider:
  - `_build_tech` returns None when the market data provider hangs
    longer than `market_data_timeout_seconds`.
  - `_fetch_h4` returns None when the provider's `get_bars` hangs.
  - `_refresh_account_balance` falls back to `dry_run_balance` (instead of
    crashing) when the broker hangs longer than
    `broker_call_timeout_seconds`.
  - The maintenance loop emits a heartbeat line on system.log on
    every iteration (even when nothing else is logged), so watchdog.sh
    (STALE_SECONDS=600) does not interpret idle-but-alive ticks as
    freezes. Setting `maintenance_heartbeat_enabled = False` suppresses
    that line.

These tests pin down the regression that caused the bot to be
SIGTERMed every ~10min in live: an SDK WS init call without timeout
froze the scan loop until system.log went stale.

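Run:
    cd ~/apex_v16
    python -m pytest tests/test_orchestrator_freeze_guard.py -q
The tests are `async def`, so the pytest form needs an asyncio plugin
(e.g. pytest-asyncio in auto mode). They can also be run without pytest:
    python tests/test_orchestrator_freeze_guard.py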
"""

from __future__ import annotations

import asyncio
import logging
import sys
from datetime import datetime, timezone
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from core.config import AccountKind, RunMode, RuntimeConfig
from orchestrator import Orchestrator
from persistence.state_store import SessionState


# ============================================================
# Minimal fakes (no broker, no real provider)
# ============================================================

class HangingProvider:
    """Fake market-data provider whose `get_bars` stalls far past any test timeout.

    Every call is counted in `calls` *before* stalling, so a test can prove
    the provider was actually reached. The stall is meant to be cut short by
    the caller's `asyncio.wait_for`, which surfaces as asyncio.TimeoutError.
    """

    def __init__(self) -> None:
        self.calls = 0

    async def get_bars(self, symbol, timeframe, n):
        self.calls = self.calls + 1
        # 60s dwarfs the 0.3s timeouts configured by make_cfg, so only a
        # cancellation (wait_for) ends this coroutine within a test run.
        stall_seconds = 60.0
        await asyncio.sleep(stall_seconds)


class HangingBroker:
    """Fake broker whose `get_account_balance` never returns in test time.

    `balance_calls` counts invocations; the 60s stall is meant to be cancelled
    by the caller's `asyncio.wait_for`, surfacing as asyncio.TimeoutError.
    """

    def __init__(self) -> None:
        self.balance_calls = 0

    async def get_account_balance(self):
        self.balance_calls = self.balance_calls + 1
        stall_seconds = 60.0  # far beyond any timeout the tests configure
        await asyncio.sleep(stall_seconds)


class _MemStore:
    """In-memory stand-in for the state store; remembers only the last save."""

    def __init__(self) -> None:
        # Nothing has been persisted yet.
        self._state = None

    def save(self, state):
        """Replace the remembered state with `state` (no history is kept)."""
        self._state = state


class _ListHandler(logging.Handler):
    """logging.Handler that keeps every record it is given in `records`.

    Lets tests assert on emitted log lines (timeouts, heartbeats) without
    touching files or stderr.
    """

    def __init__(self) -> None:
        super().__init__()
        self.records: list[logging.LogRecord] = []

    def emit(self, record: logging.LogRecord) -> None:
        sink = self.records
        sink.append(record)


class _CapturingLogger:
    """Duck-typed stand-in for LoggerBundle.

    The orchestrator only needs `.system` (a stdlib Logger) plus a handful of
    optional `log_*` hooks, so that is all this fake offers. Everything logged
    on `.system` is collected in `.system_handler.records` for inspection.
    """

    def __init__(self) -> None:
        sys_logger = logging.getLogger("test.freeze_guard")
        sys_logger.setLevel(logging.DEBUG)
        # getLogger returns the same named logger every time, so detach any
        # handler left by a previous instance; each test starts with an empty
        # record list.
        for stale in sys_logger.handlers[:]:
            sys_logger.removeHandler(stale)
        capture = _ListHandler()
        capture.setLevel(logging.DEBUG)
        sys_logger.addHandler(capture)
        self.system = sys_logger
        self.system_handler = capture

    # Hooks the orchestrator may call; the freeze-guard tests ignore them.
    def log_session_event(self, *args, **kwargs):
        return None

    def log_trade_opened(self, *args, **kwargs):
        return None

    def log_trade_closed(self, *args, **kwargs):
        return None

    def log_error(self, *args, **kwargs):
        return None


# ============================================================
# Helpers
# ============================================================

def _ok(label: str) -> None:
    """Print one passing-check line: two spaces, "ok", two spaces, `label`."""
    print("  ok  {}".format(label))


def make_cfg(**over) -> RuntimeConfig:
    """Return a PAPER / INELIGIBLE RuntimeConfig tuned for fast freeze tests.

    A fixed baseline (single MES asset, zero loop intervals, 0.3s timeouts)
    is applied first; each keyword in `over` is then set on top of it via
    setattr, so callers can override any baseline field or add new ones.
    """
    cfg = RuntimeConfig(mode=RunMode.PAPER, account=AccountKind.INELIGIBLE)
    baseline = {
        "asset_filter": ["MES"],
        "scan_loop_phase_offset_seconds": 0.0,
        "manage_loop_interval_seconds": 0,
        "maintenance_loop_interval_seconds": 0,
        # Tight timeouts so the test completes in <1s.
        "market_data_timeout_seconds": 0.3,
        "broker_call_timeout_seconds": 0.3,
        "h4_fetch_timeout_seconds": 0.3,
        "watchdog_timeout_seconds": 0.3,
        "news_sync_timeout_seconds": 0.3,
    }
    for field_name, field_value in baseline.items():
        setattr(cfg, field_name, field_value)
    for field_name, field_value in over.items():
        setattr(cfg, field_name, field_value)
    return cfg


def make_orch(cfg, *, provider=None, broker=None, logger=None) -> Orchestrator:
    """Build an Orchestrator wired with fakes for the freeze-guard tests.

    Args:
        cfg: the RuntimeConfig to run with (typically from make_cfg()).
        provider: market-data provider; a fresh HangingProvider when omitted.
        broker: broker object passed through as-is (None when omitted).
        logger: LoggerBundle-like object; a fresh _CapturingLogger when omitted.

    Returns:
        An Orchestrator with a fresh SessionState, an in-memory _MemStore,
        no AI client / opener / closer / risk manager, an empty
        brain_dispatch and max_iterations=1.
    """
    # Explicit `is None` checks: only an omitted argument is replaced by the
    # default fake. A caller-supplied object that happens to be falsy (e.g.
    # defines __len__ or __bool__) must still be used as given.
    if provider is None:
        provider = HangingProvider()
    if logger is None:
        logger = _CapturingLogger()
    return Orchestrator(
        config=cfg,
        ai_client=None,
        market_data_provider=provider,
        state=SessionState(),
        store=_MemStore(),
        logger=logger,
        brain_dispatch={},
        opener=None,
        closer=None,
        risk_manager=None,
        broker=broker,
        max_iterations=1,
    )


# ============================================================
# Tests
# ============================================================

async def test_build_tech_returns_none_when_provider_hangs():
    """A hanging provider must make _build_tech yield None (not raise) and
    leave a 'build_tech_snapshot timed out' warning on system.log."""
    cfg = make_cfg()
    provider = HangingProvider()
    logger = _CapturingLogger()
    orch = make_orch(cfg, provider=provider, logger=logger)

    snapshot = await orch._build_tech("MES")

    assert snapshot is None, "hang must be converted to None, not propagated"
    # Guard against a silent skip: the provider really was queried.
    assert provider.calls >= 1
    logged = [rec.getMessage() for rec in logger.system_handler.records]
    assert any("build_tech_snapshot timed out" in line for line in logged), logged
    _ok("_build_tech: provider hang -> None + system.log warning")


async def test_fetch_h4_returns_none_on_timeout():
    """A hanging provider.get_bars must make _fetch_h4 yield None and log an
    'H4 fetch timed out' warning on system.log."""
    cfg = make_cfg()
    provider = HangingProvider()
    logger = _CapturingLogger()
    orch = make_orch(cfg, provider=provider, logger=logger)

    h4_bars = await orch._fetch_h4("MES")

    assert h4_bars is None
    logged = [rec.getMessage() for rec in logger.system_handler.records]
    assert any("H4 fetch timed out" in line for line in logged), logged
    _ok("_fetch_h4: provider hang -> None + system.log warning")


async def test_refresh_balance_keeps_cached_on_timeout():
    """A hanging broker must not crash _refresh_account_balance: the cached
    balance falls back to cfg.dry_run_balance and a timeout warning is logged."""
    cfg = make_cfg()
    broker = HangingBroker()
    logger = _CapturingLogger()
    orch = make_orch(cfg, broker=broker, logger=logger)

    await orch._refresh_account_balance()

    # No exception escaped; the cached balance is the dry-run fallback.
    assert orch._cached_account_balance == cfg.dry_run_balance
    logged = [rec.getMessage() for rec in logger.system_handler.records]
    assert any("get_account_balance timed out" in line for line in logged), logged
    _ok("_refresh_account_balance: broker hang -> dry_run_balance fallback")


async def test_maintenance_heartbeat_emitted_each_tick():
    """Three maintenance iterations must produce at least three
    '[maintenance] heartbeat' lines on system.log."""
    cfg = make_cfg()
    logger = _CapturingLogger()
    # broker=None: per the original intent, the maintenance work (balance,
    # news, watchdog) short-circuits, so the heartbeat is the only signal.
    orch = make_orch(cfg, broker=None, logger=logger)
    orch.max_iterations = 3

    await orch.run()

    marker = "[maintenance] heartbeat"
    heartbeats = [
        rec for rec in logger.system_handler.records if marker in rec.getMessage()
    ]
    count = len(heartbeats)
    assert count >= 3, (
        f"expected >=3 heartbeats over 3 iterations, got {count}: "
        f"{[rec.getMessage() for rec in heartbeats]}"
    )
    _ok(f"maintenance heartbeat: {count} lines over 3 iterations")


async def test_maintenance_heartbeat_can_be_disabled():
    """With maintenance_heartbeat_enabled=False, no heartbeat line is logged."""
    cfg = make_cfg(maintenance_heartbeat_enabled=False)
    logger = _CapturingLogger()
    orch = make_orch(cfg, broker=None, logger=logger)
    orch.max_iterations = 2

    await orch.run()

    marker = "[maintenance] heartbeat"
    heartbeats = [
        rec for rec in logger.system_handler.records if marker in rec.getMessage()
    ]
    assert heartbeats == [], heartbeats
    _ok("maintenance heartbeat: disabled flag silences the line")


# ============================================================
# RUN
# ============================================================

async def _async_main() -> int:
    """Run every freeze-guard test in order; the first failure raises and aborts.

    Returns 0 when all tests pass.
    """
    print("test_orchestrator_freeze_guard.py")
    suite = (
        test_build_tech_returns_none_when_provider_hangs,
        test_fetch_h4_returns_none_on_timeout,
        test_refresh_balance_keeps_cached_on_timeout,
        test_maintenance_heartbeat_emitted_each_tick,
        test_maintenance_heartbeat_can_be_disabled,
    )
    for case in suite:
        await case()
    print("ALL TESTS PASSED")
    return 0


def main() -> int:
    """Script entry point: drive the async suite on a fresh event loop."""
    exit_code = asyncio.run(_async_main())
    return exit_code


if __name__ == "__main__":
    raise SystemExit(main())
