"""
Phase C1 tests: SIGTERM, DryRunBroker, broker connect retry, account balance.

14 tests:
  SIGTERM (3):
    1. signal handler sets stop_event -> graceful exit
    2. stop during iteration finishes the tick + saves
    3. KeyboardInterrupt path returns 130

  DryRunBroker (5):
    4. delegates get_account_balance to wrapped
    5. delegates positions_get to wrapped
    6. intercepts place_market_bracket (no-op + synthetic IDs)
    7. intercepts close_position (no-op success)
    8. intercepts modify_stop (no-op success)

  Connect retry (2):
    9. succeeds on 2nd attempt
    10. fails after max_attempts

  Account balance (3):
    11. PAPER uses dry_run_balance (no broker)
    12. LIVE uses broker.get_account_balance
    13. LIVE broker failure falls back to dry_run_balance

  Logging (1):
    14. dry-run write events written to brain_log

Run:
    cd ~/apex_v16
    python -m tests.test_orchestrator_phase_c1
"""

from __future__ import annotations

import asyncio
import sys
import tempfile
from pathlib import Path
from typing import Optional

import pandas as pd

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from broker.broker_base import BrokerBase, CancelResult, OrderResult, Position
from broker.dry_run_broker import DryRunBroker
from core.config import RuntimeConfig, RunMode, AccountKind
from core.contracts import EntryDecision
from main import _connect_with_retry
from orchestrator import Orchestrator
from persistence.state_store import SessionState, StateStore


def _ok(label: str) -> None:
    print(f"  ok  {label}")


# ============================================================
# FAKES
# ============================================================

class FakeAI:
    """AI-client double whose every request comes back as a failed response.

    Both entry points return ``AIResponse(text=None, error_kind="unknown")``.
    ``brain.ai_client`` is imported lazily, only when a method is awaited.
    """

    @staticmethod
    def _unknown_failure():
        """Build the canned failure response shared by both methods."""
        from brain.ai_client import AIResponse
        return AIResponse(text=None, error_kind="unknown")

    async def ask(self, prompt, temperature=0.2, max_tokens=None):
        return self._unknown_failure()

    async def ask_for_decision(self, prompt, max_tokens=600, where=None):
        return self._unknown_failure()


class JsonlSink:
    """In-memory stand-in for a JSONL log: each ``write`` appends one dict.

    Every record has ``"event"`` as its first key, followed by the keyword
    fields in call order.
    """

    def __init__(self):
        self.events = []

    def write(self, event, **fields):
        record = {"event": event}
        record.update(fields)
        self.events.append(record)


class FakeLogger:
    """Logger double that captures structured events in in-memory sinks.

    Routing:
      * ``log_session_event``          -> ``session_log``
      * ``log_error``                  -> ``error_log`` (event name "error")
      * ``log_trade_opened/_closed``   -> ``brain_log``
    ``system`` is a real stdlib logger named "test.c1".
    """

    def __init__(self):
        import logging
        self.brain_log, self.session_log, self.error_log = (
            JsonlSink(), JsonlSink(), JsonlSink(),
        )
        self.system = logging.getLogger("test.c1")

    def log_session_event(self, event, **fields):
        self.session_log.write(event, **fields)

    def log_error(self, where, error, **extra):
        self.error_log.write("error", where=where, error=error, **extra)

    def log_trade_opened(self, **fields):
        self.brain_log.write("trade_opened", **fields)

    def log_trade_closed(self, **fields):
        self.brain_log.write("trade_closed", **fields)


class FakeProvider:
    """Market-data double: ``get_bars`` always yields an empty OHLCV frame."""

    async def get_bars(self, symbol, timeframe, n):
        # Arguments are ignored; the frame has the OHLCV columns and no rows.
        empty_ohlcv = pd.DataFrame(columns=["open", "high", "low", "close", "volume"])
        return empty_ohlcv


class _MemoryStore:
    def __init__(self):
        self.saves = 0
    def save(self, state):
        self.saves += 1


class _FullFakeBroker(BrokerBase):
    """Concrete BrokerBase with every abstract method stubbed out.

    Each stub (except ``is_connected``) records ``(method_name, args_dict)``
    in ``self.calls``, so tests can check whether and how the broker was
    reached. Order-placing stubs return successful ``OrderResult`` objects
    carrying ``REAL-*`` IDs, which tests use to tell real calls apart from
    ``DRY-`` synthetic IDs.

    Args:
        balance: value returned by ``get_account_balance``.

    Attributes:
        calls: chronological list of recorded ``(name, args)`` tuples.
        fail_balance: when True, ``get_account_balance`` raises RuntimeError.
        connected: set by ``connect`` and cleared by ``disconnect``.
    """
    name = "FakeBroker"

    def __init__(self, balance: float = 50_000.0):
        self.calls: list[tuple[str, dict]] = []
        self._balance = balance
        self.fail_balance = False
        self.connected = False

    def _note_call(self, method_name: str, **args) -> None:
        """Append one ``(method_name, args)`` entry to ``self.calls``."""
        self.calls.append((method_name, args))

    # ---- connection -------------------------------------------------

    async def connect(self):
        self._note_call("connect")
        self.connected = True
        return True

    async def disconnect(self):
        self._note_call("disconnect")
        self.connected = False

    async def is_connected(self):
        return self.connected

    # ---- reads ------------------------------------------------------

    async def get_last_price(self, symbol):
        self._note_call("get_last_price", symbol=symbol)
        return 5800.0

    async def positions_get(self, symbol=None):
        self._note_call("positions_get", symbol=symbol)
        return []

    async def pending_orders(self, symbol=None):
        self._note_call("pending_orders", symbol=symbol)
        return []

    async def recent_trades(self, symbol=None, since=None, limit=50):
        self._note_call("recent_trades", symbol=symbol, since=since, limit=limit)
        return []

    # ---- writes -----------------------------------------------------

    async def place_market_bracket(self, symbol, direction, contracts,
                                   sl_price, tp_price, **kw):
        # Extra keyword options are accepted but not recorded.
        self._note_call(
            "place_market_bracket",
            symbol=symbol, direction=direction, contracts=contracts,
            sl_price=sl_price, tp_price=tp_price,
        )
        return OrderResult(
            success=True, entry_price=5800.0,
            sl_price=sl_price, tp_price=tp_price,
            entry_id="REAL-E1", stop_id="REAL-S1", target_id="REAL-T1",
        )

    async def place_stop_order(self, symbol, side, contracts, stop_price):
        self._note_call(
            "place_stop_order",
            symbol=symbol, side=side, contracts=contracts, stop_price=stop_price,
        )
        return OrderResult(
            success=True, sl_price=stop_price, stop_id="REAL-S2", entry_id="REAL-S2",
        )

    async def place_limit_order(self, symbol, side, contracts, limit_price):
        self._note_call(
            "place_limit_order",
            symbol=symbol, side=side, contracts=contracts, limit_price=limit_price,
        )
        return OrderResult(
            success=True, tp_price=limit_price, target_id="REAL-T2", entry_id="REAL-T2",
        )

    async def cancel_order(self, symbol, order_id):
        self._note_call("cancel_order", symbol=symbol, order_id=order_id)
        return CancelResult(success=True, order_id=str(order_id))

    async def cancel_all_for_symbol(self, symbol):
        self._note_call("cancel_all_for_symbol", symbol=symbol)
        return 0

    async def close_position(self, symbol, contracts=None):
        self._note_call("close_position", symbol=symbol, contracts=contracts)
        return OrderResult(success=True)

    async def partial_close_via_opposite_order(
        self, symbol, direction, contracts_to_close, residual_contracts,
        new_sl_price, new_tp_price, old_stop_order_id, old_target_order_id,
    ):
        # The old order IDs are recorded under the shorter keys
        # "old_stop" / "old_target".
        self._note_call(
            "partial_close_via_opposite_order",
            symbol=symbol, direction=direction,
            contracts_to_close=contracts_to_close,
            residual_contracts=residual_contracts,
            new_sl_price=new_sl_price, new_tp_price=new_tp_price,
            old_stop=old_stop_order_id, old_target=old_target_order_id,
        )
        return OrderResult(
            success=True, entry_id="REAL-CLOSE", stop_id="REAL-S2", target_id="REAL-T2",
            sl_price=new_sl_price, tp_price=new_tp_price,
        )

    async def modify_stop(self, symbol, order_id, new_sl_price):
        self._note_call(
            "modify_stop", symbol=symbol, order_id=order_id, new_sl_price=new_sl_price,
        )
        return OrderResult(success=True, sl_price=new_sl_price)

    # ---- account / data ---------------------------------------------

    async def get_account_balance(self):
        self._note_call("get_account_balance")
        if not self.fail_balance:
            return self._balance
        raise RuntimeError("balance fetch failed")

    async def fetch_bars(self, symbol, timeframe, n):
        self._note_call("fetch_bars", symbol=symbol, timeframe=timeframe, n=n)
        return pd.DataFrame(columns=["open", "high", "low", "close", "volume"])


# ============================================================
# FIXTURES
# ============================================================

def make_config(**over) -> RuntimeConfig:
    """Build a PAPER / INELIGIBLE RuntimeConfig tuned for fast unit tests.

    Only "MES" is in the asset filter, and every loop sleep, interval and
    phase offset is zeroed. Keyword arguments are applied last with
    ``setattr``, so they override both the constructor values and these
    test defaults.

    NOTE(review): ``setattr`` does not validate names, so a misspelled
    override silently creates a new attribute.
    """
    cfg = RuntimeConfig(mode=RunMode.PAPER, account=AccountKind.INELIGIBLE)
    test_defaults = {
        "asset_filter": ["MES"],
        "loop_sleep_seconds": 0,
        "scan_loop_phase_offset_seconds": 0.0,
        "manage_loop_interval_seconds": 0,
        "maintenance_loop_interval_seconds": 0,
    }
    for attr_name, value in test_defaults.items():
        setattr(cfg, attr_name, value)
    for attr_name, value in over.items():
        setattr(cfg, attr_name, value)
    return cfg


from datetime import datetime as _dt
from datetime import timezone as _tz

# Frozen "now" that make_orch injects as now_utc_provider, so every
# orchestrator built for these tests sees the same timezone-aware UTC instant.
FIXED_DAY_UTC: _dt = _dt(2026, 4, 28, hour=14, minute=0, second=0, tzinfo=_tz.utc)


def make_orch(*, broker=None, max_iterations=1, mode=RunMode.PAPER) -> tuple:
    """Construct an Orchestrator wired entirely to in-memory fakes.

    The clock is pinned to ``FIXED_DAY_UTC`` and no brains are dispatched.

    Args:
        broker: optional broker to inject (``None`` means no broker).
        max_iterations: forwarded to the Orchestrator.
        mode: RunMode written onto the test config.

    Returns:
        ``(orchestrator, logger)``. The FakeLogger is returned so tests can
        inspect the events it captured.
    """
    logger = FakeLogger()
    config = make_config()
    config.mode = mode
    orchestrator = Orchestrator(
        config=config,
        ai_client=FakeAI(),
        market_data_provider=FakeProvider(),
        state=SessionState(),
        store=_MemoryStore(),
        logger=logger,
        brain_dispatch={},
        broker=broker,
        now_utc_provider=lambda: FIXED_DAY_UTC,
        max_iterations=max_iterations,
    )
    return orchestrator, logger


# ============================================================
# 1-3. SIGTERM
# ============================================================

def test_signal_handler_sets_stop_event():
    """stop() sets the private _stop_event that run() uses to shut down."""
    orchestrator = make_orch()[0]
    assert not orchestrator._stop_event.is_set(), "fresh orchestrator must not be stopping"
    orchestrator.stop()
    assert orchestrator._stop_event.is_set()
    _ok("stop(): _stop_event set, run() will exit at next sleep boundary")


def test_stop_during_iteration_finishes_tick_then_saves():
    """run() finishes the current tick + persists state before returning.

    A stopper task calls ``orch.stop()`` about 50 ms after ``run()`` starts.
    The whole exchange is bounded by a timeout, so a regression in which
    ``run()`` ignores the stop request fails the test with a clear message
    instead of hanging the suite.
    """
    timeout_s = 10.0

    async def runner():
        # max_iterations=100 bounds the run, but it is far more than the
        # ~50 ms window before stop() is requested.
        orch, _ = make_orch(max_iterations=100)

        async def stopper():
            await asyncio.sleep(0.05)
            orch.stop()

        try:
            await asyncio.wait_for(
                asyncio.gather(orch.run(), stopper()), timeout=timeout_s,
            )
        except asyncio.TimeoutError:
            raise AssertionError(
                f"run() did not return within {timeout_s}s of stop()"
            ) from None
        return orch

    orch = asyncio.run(runner())
    assert orch._iteration >= 1, "must complete at least one iteration"
    assert orch.store.saves >= 1, "must save at least once"
    _ok("stop(): tick finishes + save executes before exit")


def test_keyboard_interrupt_path_returns_130():
    """A KeyboardInterrupt escaping a loop task makes run() return 130."""

    async def _raise_keyboard_interrupt():
        raise KeyboardInterrupt

    async def runner():
        orchestrator = make_orch(max_iterations=100)[0]
        # Replace the scan loop with one that raises immediately. run() is
        # expected to map the interrupt to exit code 130 (128 + SIGINT).
        orchestrator._scan_loop = _raise_keyboard_interrupt   # type: ignore[assignment]
        return await orchestrator.run()

    exit_code = asyncio.run(runner())
    assert exit_code == 130, f"KeyboardInterrupt path must return 130, got {exit_code}"
    _ok("KeyboardInterrupt: orchestrator returns exit code 130")


# ============================================================
# 4-8. DryRunBroker delegate vs intercept
# ============================================================

def test_dry_run_broker_delegates_get_account_balance():
    """Balance reads pass through DryRunBroker to the wrapped broker."""
    wrapped = _FullFakeBroker(balance=42_500.0)
    balance = asyncio.run(DryRunBroker(wrapped).get_account_balance())
    assert balance == 42_500.0
    reached = [method_name for method_name, _ in wrapped.calls]
    assert "get_account_balance" in reached, "must reach the wrapped broker"
    _ok("DryRunBroker.get_account_balance: delegates to wrapped broker")


def test_dry_run_broker_delegates_positions_get():
    """positions_get is forwarded by DryRunBroker to the wrapped broker."""
    wrapped = _FullFakeBroker()
    asyncio.run(DryRunBroker(wrapped).positions_get("MES"))
    reached = [method_name for method_name, _ in wrapped.calls]
    assert "positions_get" in reached
    _ok("DryRunBroker.positions_get: delegates to wrapped broker")


def test_dry_run_broker_intercepts_place_market_bracket():
    """place_market_bracket on dry returns synthetic IDs and DOES NOT call wrapped."""
    wrapped = _FullFakeBroker()
    sink = FakeLogger()
    dry = DryRunBroker(wrapped, logger=sink)
    order = asyncio.run(dry.place_market_bracket(
        symbol="MES", direction="BUY", contracts=2,
        sl_price=5790.0, tp_price=5810.0,
    ))

    assert order.success is True
    assert order.entry_id and order.entry_id.startswith("DRY-place-E-"), order.entry_id

    reached = [method_name for method_name, _ in wrapped.calls]
    assert "place_market_bracket" not in reached, \
        "wrapped broker must NOT be called for write ops"
    # Read-only calls into the wrapped broker (e.g. a price lookup used to
    # enrich the synthetic result) are allowed and deliberately not checked.

    logged = [record["event"] for record in sink.brain_log.events]
    assert "dry_place_market_bracket" in logged
    _ok("DryRunBroker.place_market_bracket: intercepted, synthetic DRY- IDs, log emitted")


def test_dry_run_broker_intercepts_close_position():
    """close_position reports success without reaching the wrapped broker."""
    wrapped = _FullFakeBroker()
    outcome = asyncio.run(DryRunBroker(wrapped).close_position("MES"))
    assert outcome.success is True
    reached = [method_name for method_name, _ in wrapped.calls]
    assert "close_position" not in reached
    _ok("DryRunBroker.close_position: intercepted (wrapped NOT called)")


def test_dry_run_broker_intercepts_modify_stop():
    """modify_stop echoes the new stop price without touching the wrapped broker."""
    wrapped = _FullFakeBroker()
    outcome = asyncio.run(DryRunBroker(wrapped).modify_stop("MES", "S1", new_sl_price=5795.0))
    assert outcome.success is True
    assert outcome.sl_price == 5795.0
    reached = [method_name for method_name, _ in wrapped.calls]
    assert "modify_stop" not in reached
    _ok("DryRunBroker.modify_stop: intercepted, returns synthetic success")


# ============================================================
# 9-10. Connect retry
# ============================================================

def test_connect_with_retry_succeeds_on_2nd_attempt():
    """connect() returns False once, then True; the helper reports overall success."""

    class _Flaky:
        def __init__(self):
            self.attempts = 0

        async def connect(self):
            self.attempts += 1
            return self.attempts > 1   # only the very first attempt fails

    flaky = _Flaky()
    connected = asyncio.run(_connect_with_retry(
        flaky, logger=FakeLogger(), max_attempts=3, delays_seconds=(0, 0, 0),
    ))
    assert connected is True
    assert flaky.attempts == 2, "must stop retrying once connect() succeeds"
    _ok("_connect_with_retry: succeeds on 2nd attempt")


def test_connect_with_retry_fails_after_max_attempts():
    """A connect() that always raises is retried exactly max_attempts times, then False."""

    class _Always:
        def __init__(self):
            self.attempts = 0

        async def connect(self):
            self.attempts += 1
            raise RuntimeError("network down")

    unreachable = _Always()
    connected = asyncio.run(_connect_with_retry(
        unreachable, logger=FakeLogger(), max_attempts=3, delays_seconds=(0, 0, 0),
    ))
    assert connected is False
    assert unreachable.attempts == 3
    _ok("_connect_with_retry: fails after 3 attempts, returns False")


# ============================================================
# 11-13. Account balance resolution
# ============================================================

def test_resolve_account_balance_paper_uses_dry_run_balance():
    """PAPER mode with no broker: the balance resolves to config.dry_run_balance."""
    orchestrator = Orchestrator(
        config=make_config(dry_run_balance=75_000.0),
        ai_client=FakeAI(),
        market_data_provider=FakeProvider(),
        state=SessionState(),
        store=_MemoryStore(),
        logger=FakeLogger(),
        brain_dispatch={},
        broker=None,
        max_iterations=1,
    )
    asyncio.run(orchestrator._refresh_account_balance())
    assert orchestrator._resolve_account_balance() == 75_000.0
    _ok("PAPER (no broker): _resolve_account_balance returns config.dry_run_balance")


def test_resolve_account_balance_live_uses_broker():
    """LIVE mode: after a refresh, the broker-reported balance wins over the config value."""
    live_cfg = make_config(dry_run_balance=100_000.0)
    live_cfg.mode = RunMode.LIVE
    orchestrator = Orchestrator(
        config=live_cfg,
        ai_client=FakeAI(),
        market_data_provider=FakeProvider(),
        state=SessionState(),
        store=_MemoryStore(),
        logger=FakeLogger(),
        brain_dispatch={},
        broker=_FullFakeBroker(balance=87_654.0),
        max_iterations=1,
    )
    asyncio.run(orchestrator._refresh_account_balance())
    # 87_654 != 100_000, so a match proves the broker value was used.
    assert orchestrator._resolve_account_balance() == 87_654.0
    _ok("LIVE: _refresh -> broker.get_account_balance, _resolve uses real value")


def test_resolve_account_balance_live_failure_falls_back():
    """LIVE: a failing broker balance fetch falls back to config.dry_run_balance.

    The test also checks that the broker was actually consulted. Without
    that check, it would pass even if LIVE mode never asked the broker and
    just read the config value.
    """
    cfg = make_config(dry_run_balance=100_000.0)
    cfg.mode = RunMode.LIVE
    state = SessionState()
    broker = _FullFakeBroker()
    broker.fail_balance = True   # get_account_balance will raise RuntimeError
    orch = Orchestrator(
        config=cfg, ai_client=FakeAI(),
        market_data_provider=FakeProvider(),
        state=state, store=_MemoryStore(), logger=FakeLogger(),
        brain_dispatch={}, broker=broker,
        max_iterations=1,
    )
    # The refresh must absorb the broker error; if it raised, this line would fail the test.
    asyncio.run(orch._refresh_account_balance())
    assert any(method_name == "get_account_balance" for method_name, _ in broker.calls), \
        "LIVE refresh must attempt broker.get_account_balance before falling back"
    # First failure -> fallback to dry_run_balance
    assert orch._resolve_account_balance() == 100_000.0
    _ok("LIVE: broker balance fetch fails -> fallback to dry_run_balance")


# ============================================================
# 14. Logging: dry-run events
# ============================================================

def test_dry_run_logs_events_to_brain_log():
    """Each intercepted write op emits a ``dry_<method>`` event to brain_log."""
    wrapped = _FullFakeBroker()
    sink = FakeLogger()
    dry = DryRunBroker(wrapped, logger=sink)

    # One asyncio.run per call, in this order.
    write_ops = (
        lambda: dry.place_market_bracket("MES", "BUY", 1, 5790.0, 5810.0),
        lambda: dry.modify_stop("MES", "S1", 5795.0),
        lambda: dry.close_position("MES"),
    )
    for start_op in write_ops:
        asyncio.run(start_op())

    logged = [record["event"] for record in sink.brain_log.events]
    expected_events = ("dry_place_market_bracket", "dry_modify_stop", "dry_close_position")
    for wanted in expected_events:
        assert wanted in logged, f"missing event: {wanted}"
    _ok("DryRunBroker: 3 write methods log structured events to brain_log")


# ============================================================
# MAIN
# ============================================================

def main() -> int:
    """Run every Phase C1 test in order and print a summary.

    A failing test raises AssertionError, which propagates out of main()
    and aborts the run. Reaching the summary line means all tests passed.
    The reported count comes from the ``tests`` tuple, so it cannot drift
    from the list of tests actually executed.

    Returns:
        0, used by the ``__main__`` guard as the process exit status.
    """
    print("test_orchestrator_phase_c1.py")
    tests = (
        test_signal_handler_sets_stop_event,
        test_stop_during_iteration_finishes_tick_then_saves,
        test_keyboard_interrupt_path_returns_130,
        test_dry_run_broker_delegates_get_account_balance,
        test_dry_run_broker_delegates_positions_get,
        test_dry_run_broker_intercepts_place_market_bracket,
        test_dry_run_broker_intercepts_close_position,
        test_dry_run_broker_intercepts_modify_stop,
        test_connect_with_retry_succeeds_on_2nd_attempt,
        test_connect_with_retry_fails_after_max_attempts,
        test_resolve_account_balance_paper_uses_dry_run_balance,
        test_resolve_account_balance_live_uses_broker,
        test_resolve_account_balance_live_failure_falls_back,
        test_dry_run_logs_events_to_brain_log,
    )
    for test in tests:
        test()
    print(f"ALL {len(tests)} TESTS PASSED")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
