"""
APEX V16 — State / broker reconciliation.

Root-cause solution for V15-BUG-9 (state-broker desync). Applies a
5-case decision matrix between `state.active_trades` and the broker's
view of the world (positions + pending orders), at startup and
periodically during the run.

Decision matrix
---------------
For each (state, broker) pair on a symbol:

  (i)   state OPEN  + broker OPEN  + sizes match
        -> OK, no action.

  (ii)  state OPEN  + broker FLAT
        -> position closed broker-side while bot was down. Recover
           realized P&L from `broker.recent_trades`, taking the most
           recent trade on the symbol closed since opened_at (contract
           counts are not compared); classify win/loss and call
           risk_manager.update_daily_pnl + register_tp_hit/register_sl_hit.
           Either way, drop the trade from state.active_trades.

  (iii) state FLAT  + broker OPEN
        -> position exists broker-side that V16 didn't open. Conservative
           policy: LOG-ONLY (do NOT auto-adopt). Operator decides whether
           to flatten manually or restart with a state file that includes
           the trade. Auto-adopt would inject an unsafe TradeEntry with
           zero indicators-at-entry (V15-BUG-2/5 territory).

  (iv)  state OPEN  + broker OPEN  + sizes mismatch
        -> partial fill / partial close happened broker-side. The full
           remedy would update runtime contracts and flag the
           inconsistency; the current V16 conservative path is LOG-ONLY:
           log the mismatch and treat the trade as still open with the
           state-recorded size.

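  (v)   broker position with no pending protective STOP order on its
        contract ("naked position"; a missing TP is not flagged)
        -> remediate or LOG-ONLY (current policy: log-only,
           V15 _watchdog_naked_positions parity). Orphan SL/TP orders
           without a matching position are not currently detected.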

Why log-only for ambiguous cases:
  V16 is stricter than V15 on auto-adopt. V15 sometimes inferred
  positions from broker state; that masked the real bug (state save
  failures) and produced "phantom" entries with no indicators.
  V16 prefers explicit operator action over silent auto-recovery.

Watchdog cadence:
  Orchestrator calls `watchdog_naked_positions()` every N iterations
  (V15 parity: roughly every 30s). Cheaper than full reconcile_startup
  because it only checks for naked positions (case v), which is the
  only one that can newly emerge mid-run without a state mutation.
"""

from __future__ import annotations

import logging
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional

from broker.broker_base import BrokerBase, ClosedTrade, Order, Position
from core.contracts import BrainName, Direction

if TYPE_CHECKING:
    from persistence.state_store import SessionState
    from trading.risk_manager import RiskManager


# Module-level logger. Reconciler._log also falls back to it when no
# LoggerBundle is configured or the bundle's write fails.
log = logging.getLogger(name="reconciliation")


# ============================================================
# REPORTS — return values for orchestrator/main observability
# ============================================================

@dataclass
class ReconciliationReport:
    """Outcome of a reconcile_startup() call.

    Each ``case_*`` field is the list of symbols that fell into that case
    of the decision matrix (take ``len()`` for counts).
    ``pnl_recovered_usd`` is the sum of the P&L recovered from broker
    history for case (ii).
    """
    # (i) state OPEN + broker OPEN with equal contract counts.
    case_i_ok: list[str] = field(default_factory=list)
    # (ii) state OPEN + broker FLAT; each such symbol is dropped from state.
    case_ii_state_open_broker_flat: list[str] = field(default_factory=list)
    # Subset of case (ii) whose close P&L was found in broker history.
    case_ii_recovered_via_history: list[str] = field(default_factory=list)
    # (iii) open on the broker but absent from state (log-only).
    case_iii_state_flat_broker_open: list[str] = field(default_factory=list)
    # (iv) open on both sides but contract counts differ (log-only).
    case_iv_size_mismatch: list[str] = field(default_factory=list)
    # (v) symbols of broker POSITIONS lacking a pending STOP order
    # (log-only). Despite the field name, these are positions, not orders.
    case_v_naked_orders: list[str] = field(default_factory=list)
    # Sum (USD) of P&L recovered for case (ii) symbols.
    pnl_recovered_usd: float = 0.0


@dataclass
class NakedPositionsReport:
    """Result of a single watchdog_naked_positions() scan.

    ``naked_symbols`` lists the symbol of every broker position found
    without a pending STOP order at scan time; it is empty when none were.
    """
    # default_factory gives each report its own list (never shared).
    naked_symbols: list[str] = field(default_factory=list)


# ============================================================
# RECONCILER
# ============================================================

class Reconciler:
    """
    Coordinator that runs the 5-case matrix between
    `SessionState.active_trades` and the broker's open positions /
    pending orders / recent trade history.

    Keeps no reconciliation results between calls: every public method
    queries the broker afresh and returns a new report. The only state it
    mutates is dropping case (ii) trades from `state.active_trades` (plus
    risk-manager counter updates when case (ii) P&L is recovered).
    """

    def __init__(
        self,
        broker: BrokerBase,
        state: "SessionState",
        risk_manager: Optional["RiskManager"] = None,
        logger=None,
    ) -> None:
        """
        Args:
            broker:       a connected BrokerBase.
            state:        the live SessionState (mutated in case ii).
            risk_manager: required for case (ii) P&L recovery; if None,
                          case (ii) drops the trade without updating
                          daily counters (degraded mode).
            logger:       LoggerBundle. Reconciliation events go to its
                          brain_log; when None (or a write raises) they go
                          to the module logger instead.
        """
        self._broker = broker
        self._state = state
        self._risk = risk_manager
        self._logger = logger
        # Phase tag of the most recent reconcile_startup() call; initialised
        # here so the attribute always exists.
        self._post_reconnect = False

    # ============================================================
    # PUBLIC: STARTUP
    # ============================================================

    async def reconcile_startup(
        self, *, post_reconnect: bool = False,
    ) -> ReconciliationReport:
        """
        Full reconciliation between state and broker. Run once at
        boot, after broker.connect, before orchestrator.run(). Also
        run after mid-loop reconnect (C2b), with post_reconnect=True
        so the events log distinguishes the two phases.

        Returns a ReconciliationReport. Mutates state.active_trades for
        case (ii) only: the trade is dropped after attempting P&L
        recovery, even if recovery raises (the exception still
        propagates), so a retry never re-applies the same P&L.

        Idempotent: a second invocation finds no remaining case (ii)
        ghosts and is effectively a no-op for state mutation.
        """
        report = ReconciliationReport()
        self._post_reconnect = post_reconnect

        positions = await self._broker.positions_get()
        pending = await self._broker.pending_orders()

        # positions_get returns user-facing symbols, so positions are
        # indexed by symbol for O(1) lookups. pending_orders returns raw
        # contractIds; those are matched later by tag containment
        # (see _find_naked_orders / _contract_matches_symbol).
        broker_pos_by_symbol: dict[str, Position] = {
            p.symbol: p for p in positions
        }

        state_symbols = set(self._state.active_trades)
        broker_symbols = set(broker_pos_by_symbol)

        # ----- case (i), (ii), (iv) on symbols in state ----------
        # Iterate a sorted snapshot: case (ii) pops from active_trades
        # inside the loop, and sorting keeps report/log order stable.
        for symbol in sorted(state_symbols):
            entry = self._state.active_trades[symbol].entry
            bp = broker_pos_by_symbol.get(symbol)

            if bp is None:
                # Case (ii): state OPEN, broker FLAT.
                report.case_ii_state_open_broker_flat.append(symbol)
                self._log("recon_case_ii_state_open_broker_flat",
                          symbol=symbol,
                          state_contracts=entry.contracts)
                try:
                    recovered_pnl = await self._recover_case_ii(symbol)
                finally:
                    # Drop the orphan regardless of the recovery outcome,
                    # including an exception raised midway, so a later
                    # run cannot apply the same P&L a second time.
                    self._state.active_trades.pop(symbol, None)
                if recovered_pnl is not None:
                    report.case_ii_recovered_via_history.append(symbol)
                    report.pnl_recovered_usd += recovered_pnl
            elif bp.contracts == entry.contracts:
                report.case_i_ok.append(symbol)
                self._log("recon_case_i_ok", symbol=symbol,
                          contracts=bp.contracts)
            else:
                report.case_iv_size_mismatch.append(symbol)
                self._log("recon_case_iv_size_mismatch",
                          symbol=symbol,
                          state_contracts=entry.contracts,
                          broker_contracts=bp.contracts)

        # ----- case (iii) on symbols only on broker --------------
        for symbol in sorted(broker_symbols - state_symbols):
            bp = broker_pos_by_symbol[symbol]
            report.case_iii_state_flat_broker_open.append(symbol)
            self._log("recon_case_iii_state_flat_broker_open",
                      symbol=symbol,
                      broker_contracts=bp.contracts,
                      direction=bp.direction)

        # ----- case (v) naked positions (no STOP order) ----------
        report.case_v_naked_orders = self._find_naked_orders(positions, pending)
        for symbol in report.case_v_naked_orders:
            self._log("recon_case_v_naked_order", symbol=symbol)

        self._log("recon_startup_done",
                  post_reconnect=self._post_reconnect,
                  case_i=len(report.case_i_ok),
                  case_ii=len(report.case_ii_state_open_broker_flat),
                  case_ii_recovered=len(report.case_ii_recovered_via_history),
                  case_iii=len(report.case_iii_state_flat_broker_open),
                  case_iv=len(report.case_iv_size_mismatch),
                  case_v=len(report.case_v_naked_orders),
                  pnl_recovered_usd=report.pnl_recovered_usd)
        return report

    # ============================================================
    # PUBLIC: WATCHDOG
    # ============================================================

    async def watchdog_naked_positions(self) -> NakedPositionsReport:
        """
        Lightweight periodic check: scan broker positions + pending orders
        for naked positions (position without protective SL). V15 parity
        (_watchdog_naked_positions, lines 2035-2115).

        Currently LOG-ONLY: V15 chose not to auto-flatten because operator
        review is safer on naked positions (a missed SL is rare and usually
        a transient broker quirk). V16 calibration may decide to escalate.
        Skips the pending-orders query entirely when there are no positions.
        """
        positions = await self._broker.positions_get()
        if not positions:
            return NakedPositionsReport()
        pending = await self._broker.pending_orders()
        naked = self._find_naked_orders(positions, pending)
        for symbol in naked:
            self._log("watchdog_naked_position", symbol=symbol)
        return NakedPositionsReport(naked_symbols=list(naked))

    # ============================================================
    # internals
    # ============================================================

    async def _recover_case_ii(self, symbol: str) -> Optional[float]:
        """
        Try to fetch the close P&L from broker history and apply it to
        risk_manager counters. Returns the recovered pnl_usd if a match
        was found, None otherwise (also None when no risk_manager is
        configured, the trade is not in state, or the history call raises).

        Match heuristic (V15 _fetch_last_close_price parity), as
        implemented here:
          - ask broker.recent_trades for `symbol` since entry.opened_at
            (a naive timestamp is interpreted as UTC), limit 50;
          - rely on the broker to drop profitAndLoss == 0 trades and to
            sort by closed_at descending, then take the first trade.
        NOTE(review): no clock-skew allowance is applied to `since`, and
        contract counts are not compared, so the match is symbol + time.
        """
        if self._risk is None:
            return None

        at = self._state.active_trades.get(symbol)
        if at is None:
            return None

        try:
            since_dt = at.entry.opened_at
            if since_dt.tzinfo is None:
                since_dt = since_dt.replace(tzinfo=timezone.utc)
            trades = await self._broker.recent_trades(
                symbol=symbol, since=since_dt, limit=50,
            )
        except Exception as e:
            log.warning("_recover_case_ii(%s) recent_trades raised: %s",
                        symbol, e)
            return None

        if not trades:
            self._log("recon_case_ii_no_history_match", symbol=symbol)
            return None

        # Already filtered by symbol + pnl != 0 + sorted desc by closed_at
        # in TopstepXBroker.recent_trades. Take the most recent.
        latest = trades[0]
        pnl = float(latest.pnl_usd)

        entry_brain = at.entry.brain_name
        brain = entry_brain
        if brain not in (BrainName.TF.value, BrainName.MR.value):
            # Safe default for the counters; the log below also records
            # the brain name stored on the entry (entry_brain_name).
            brain = BrainName.TF.value
        # pnl == 0 would be booked as a loss; the broker is expected to
        # have filtered such trades out already.
        is_win = pnl > 0

        self._risk.update_daily_pnl(pnl, is_win=is_win, brain=brain)
        if is_win:
            self._risk.register_tp_hit(symbol)
        else:
            self._risk.register_sl_hit(symbol)

        self._log("recon_case_ii_recovered",
                  symbol=symbol,
                  pnl_usd=pnl,
                  is_win=is_win,
                  brain=brain,
                  entry_brain_name=entry_brain,
                  exit_price=latest.exit_price,
                  closed_at=latest.closed_at)
        return pnl

    @staticmethod
    def _find_naked_orders(
        positions: list[Position],
        pending: list[Order],
    ) -> list[str]:
        """
        V15 parity (_watchdog_naked_positions, lines 2080-2105):
        a position is "naked" if no pending STOP order exists on a
        contractId matching its symbol. Missing TP is not flagged.

        Returns the symbols of naked positions, in `positions` order.
        Despite the name, this reports positions, not orphan orders.
        """
        # Order.symbol carries the raw contractId; group order kinds by it.
        kinds_by_contract: dict[str, set[str]] = defaultdict(set)
        for o in pending:
            kinds_by_contract[o.symbol].add(o.kind)

        # Contract ids with at least one STOP order, computed once rather
        # than re-filtered for every position.
        stop_contracts = [
            cid for cid, kinds in kinds_by_contract.items() if "STOP" in kinds
        ]

        # Position.symbol is the user-facing short symbol, so it is probed
        # against contractIds by tag containment.
        return [
            p.symbol
            for p in positions
            if not any(
                _contract_matches_symbol(p.symbol, cid) for cid in stop_contracts
            )
        ]

    def _log(self, event: str, **fields) -> None:
        """Write `event` to the LoggerBundle's brain_log, falling back to
        the module logger when no bundle is set or the write raises."""
        if self._logger is not None:
            try:
                self._logger.brain_log.write(event, **fields)
                return
            except Exception:
                # Best-effort sink: a broken event log must not abort
                # reconciliation. Keep a debug trace, then fall through.
                log.debug("[recon] brain_log.write failed for %s", event,
                          exc_info=True)
        log.info("[recon] %s %s", event, fields)


# ============================================================
# helpers
# ============================================================

# Mirror of TopstepXBroker._SYMBOL_TAG_MAP. Duplicated rather than imported
# to keep reconciliation broker-agnostic.
_SYMBOL_TAG_MAP: dict[str, str] = {
    "6B": "BP6", "6E": "EU6", "6A": "A6", "6J": "JY6", "6C": "CA6",
    "MES": "ES",  "MNQ": "NQ", "MYM": "YM", "MGC": "MGC", "MCL": "MCL",
}


def _contract_matches_symbol(symbol: str, contract_id: str) -> bool:
    if not symbol or not contract_id:
        return False
    if contract_id == symbol:
        return True
    tag = _SYMBOL_TAG_MAP.get(symbol, symbol)
    return tag in contract_id
