"""
APEX V16 — Risk Manager.

Single source of truth for "can the bot open this specific trade right now?".

V15 had risk checks scattered across topstep_guards_status, correlation_block,
budget_pre_check, halt, _on_sl_hit, _check_daily_reset. V16 centralizes
all pre-trade gates here. Sizing-related budget pre-checks live in
trading/sizing.py and propagate via SizingDecision.skip; risk_manager
respects that result instead of re-deriving it.

Public API:
  - RiskManager(config, state, logger)
  - .check_entry(*, entry, sizing, symbol, active_trades, now_utc)
                                                    -> RiskCheckResult
  - .halt(reason)
  - .clear_halt(why)
  - .update_daily_pnl(delta_usd, *, is_win, brain)
  - .register_sl_hit(symbol, *, now_utc)            (cooldown + consecutive_sl)
  - .register_tp_hit(symbol)                         (resets consecutive_sl)

The RiskManager:
  - reads state.* and active_trades, never mutates active_trades.
  - mutates state.daily.*, state.brain.*, state.cooldown.*, state.halted,
    state.halt_reason, state.session_pnl and the per-symbol counters in
    state.metadata['consecutive_sl_count'].
  - does NOT call state_store.save() — the orchestrator owns persistence
    (V15-BUG-9 discipline preserved; risk_manager mutates, orchestrator
    saves after each tick).

Halt semantics:
  - halt() sets state.halted=True. Persists until daily reset
    (load_state(auto_daily_reset=True)) or clear_halt().
  - profit_target REJECTS entries but does NOT halt — the bot stays
    alive to manage open positions (V15 parity, lines 666-668).
  - daily_loss_hard_stop AUTO-HALTS (V15 parity, lines 655-658).
  - daily_loss_soft_stop REJECTS entries but does NOT halt — the bot
    can resume opening if daily_pnl recovers above the soft stop
    (V15 parity, lines 661-663).
"""

from __future__ import annotations

from dataclasses import asdict, dataclass
from datetime import datetime, time as dtime, timedelta, timezone
from typing import Optional

from core.contracts import EntryDecision, RiskRule, utc_now
from trading.sizing import SizingDecision

try:
    from core.config_futures import MAX_CONTRACTS_PER_TRADE
except ImportError:
    MAX_CONTRACTS_PER_TRADE = {}


# ============================================================
# CORRELATION GROUPS — V15-aligned (config_futures.py:432-443)
# ============================================================
# V15 bundles all FX into a single FX_MAJORS group because the dollar-index
# correlation makes them all sensitive to the same macro driver (USD strength).
# Splitting them by individual currency would be arbitrary.
#
# Module-level constant: pure lookup, not runtime config. ENABLE/DISABLE
# is gated by config.enable_correlation, not by group composition.

CORRELATION_GROUPS: dict[str, list[str]] = dict(
    # Stock-index futures (micro and full-size): move together on risk-on/off.
    EQUITY_INDICES=["MES", "MNQ", "MYM", "ES", "NQ", "YM"],
    # Gold, micro and full-size.
    METALS=["MGC", "GC"],
    # Crude oil, micro and full-size.
    ENERGY=["MCL", "CL"],
    # Currency futures: all driven by USD strength (see note above).
    FX_MAJORS=["6E", "6B", "6A", "6J", "6C"],
)


# ============================================================
# RESULT TYPES (split core/audit, mirrors SizingDecision/SizingAudit)
# ============================================================

@dataclass(frozen=True)
class RiskAudit:
    """
    Immutable snapshot of the risk context seen by one check_entry() call.

    Emitted for approvals and rejections alike, so post-hoc analysis can
    measure how close an approved trade came to each gate. Intended to be
    persisted to brain_log.jsonl for forensic review and calibration.
    """

    # --- daily P&L versus the configured daily limits (USD) ---
    daily_pnl: float
    daily_loss_hard_stop: float
    daily_loss_soft_stop: float
    daily_profit_target: float
    # USD headroom between daily_pnl and the hard stop.
    daily_remaining_budget: float

    # --- the trade under evaluation ---
    # Risk of the pending trade in USD (SizingDecision.real_risk_usd).
    pending_risk_usd: float
    # pending_risk_usd / daily_remaining_budget as a fraction (despite the
    # "_pct" suffix); 1.0 when the remaining budget is 0.
    risk_vs_budget_pct: float

    # --- position / overtrading context ---
    open_trades_count: int
    consecutive_sl_count: int
    # ISO-8601 UTC timestamp at which the symbol's cooldown ends, or None.
    cooldown_until: Optional[str]
    # Symbol of the correlated open trade that blocked entry, or None.
    correlation_blocker: Optional[str]
    # ISO-8601 timestamp of the evaluation.
    now_utc: str


@dataclass(frozen=True)
class RiskCheckResult:
    """
    Verdict returned by RiskManager.check_entry().

    Attributes:
        approved: True only when every gate passed and the trade may proceed.
        rule:     stable, machine-readable RiskRule code (its string value).
        reason:   human-readable explanation, intended for logs.
        audit:    RiskAudit snapshot of the context that was evaluated.
    """

    approved: bool
    rule: str
    reason: str
    audit: RiskAudit

    def audit_dict(self) -> dict:
        """Return the audit snapshot converted to a plain dict (dataclasses.asdict)."""
        snapshot = asdict(self.audit)
        return snapshot


# ============================================================
# RISK MANAGER
# ============================================================

class RiskManager:
    """
    Centralized pre-trade risk gates + lifecycle hooks.

    check_entry() runs every pre-trade gate and returns the first failure
    (or an approval). The lifecycle hooks mutate ``self.state`` in place and
    never persist it — the orchestrator saves state after each tick.
    """

    # Cooldown durations after stop-loss hits (V15 parity, lines 2128-2133)
    COOLDOWN_FIRST_SL_SECONDS      = 1800   # 30 min
    COOLDOWN_CONSECUTIVE_SL_SECONDS = 7200  # 2 h, triggers when prior_consecutive >= 2

    # state.metadata key holding the per-symbol consecutive stop-loss counters.
    # Kept in metadata (not CooldownState) to avoid another schema migration;
    # BACKLOG: promote to a first-class CooldownState field once usage is stable.
    _SL_COUNTER_KEY = "consecutive_sl_count"

    def __init__(self, config, state, logger=None) -> None:
        # config: risk limits and cutoffs read by the gates below.
        # state:  mutable bot state; the hooks below write into it.
        # logger: optional; every call through it is best-effort.
        self.config = config
        self.state = state
        self.logger = logger

    # ============================================================
    # MAIN GATE
    # ============================================================

    def check_entry(
        self,
        *,
        entry: EntryDecision,
        sizing: SizingDecision,
        symbol: str,
        active_trades: dict,
        now_utc: Optional[datetime] = None,
    ) -> RiskCheckResult:
        """
        Evaluate all pre-trade gates for `symbol`. First failing gate wins.

        Order: cheap binary checks first, calendar before intraday,
        budget arithmetic last:
            1 sizing skip, 2 halted, 3 daily hard stop (auto-halts),
            4 daily soft stop, 5 daily profit target, 6 max open trades,
            7 max daily trades, 8 last-Friday cutoff, 9 force-flat time,
            10 cooldown, 11 correlation, 12 max contracts,
            13 risk vs remaining daily budget.

        Args:
            entry: entry decision under review (not inspected by the current
                gates; kept for interface stability).
            sizing: sizing result; skip, reason, contracts and real_risk_usd
                are consulted.
            symbol: instrument being entered.
            active_trades: open trades; only its length and key membership
                are read (keys presumably symbols — confirm against caller).
            now_utc: evaluation time, defaults to utc_now(). Naive values are
                treated as UTC; aware values are converted to UTC.

        Returns:
            RiskCheckResult: approved with RiskRule.OK, or the first failing rule.

        Side effects:
            May set state.daily.daily_loss_hard_stop_hit / profit_target_hit,
            call halt(), and drop expired or corrupted cooldown entries.
        """
        # Normalize once so every time comparison below is aware-vs-aware UTC
        # (a naive now_utc used to raise TypeError at the force-flat gate).
        now_utc = self._as_utc(utc_now() if now_utc is None else now_utc)
        audit = self._build_audit(
            symbol=symbol,
            sizing=sizing,
            active_trades=active_trades,
            now_utc=now_utc,
        )

        # 1. SIZING_SKIP — sizing already determined this is not viable
        if sizing.skip:
            return self._reject(
                RiskRule.SIZING_SKIP,
                f"sizing skip: {sizing.reason}",
                audit,
            )

        # 2. HALTED — kill switch
        if self.state.halted:
            return self._reject(
                RiskRule.HALTED,
                self.state.halt_reason or "halted",
                audit,
            )

        # 3. DAILY_LOSS_HARD_STOP_HIT — auto-halt + reject
        hard = self.config.daily_loss_hard_stop
        if self.state.daily.daily_pnl <= hard:
            self.state.daily.daily_loss_hard_stop_hit = True
            self.halt(
                f"Daily loss hard stop: ${self.state.daily.daily_pnl:.2f} <= ${hard:.2f}"
            )
            return self._reject(
                RiskRule.DAILY_LOSS_HARD_STOP_HIT,
                f"daily P&L ${self.state.daily.daily_pnl:.2f} <= hard ${hard:.2f}",
                audit,
            )

        # 4. DAILY_LOSS_SOFT_STOP_HIT — block, NO halt (V15 parity, line 661)
        soft = self.config.daily_loss_soft_stop
        if self.state.daily.daily_pnl <= soft:
            return self._reject(
                RiskRule.DAILY_LOSS_SOFT_STOP_HIT,
                f"daily P&L ${self.state.daily.daily_pnl:.2f} <= soft ${soft:.2f}",
                audit,
            )

        # 5. DAILY_PROFIT_TARGET_REACHED — block, NO halt (V15 parity, line 666)
        target = self.config.daily_profit_target
        if self.state.daily.daily_pnl >= target:
            self.state.daily.profit_target_hit = True
            return self._reject(
                RiskRule.DAILY_PROFIT_TARGET_REACHED,
                f"daily target ${self.state.daily.daily_pnl:+.0f} >= ${target:.0f}",
                audit,
            )

        # 6. MAX_OPEN_TRADES_REACHED — concurrent positions cap (V15 parity)
        if len(active_trades) >= self.config.max_open_trades_total:
            return self._reject(
                RiskRule.MAX_OPEN_TRADES_REACHED,
                f"{len(active_trades)}/{self.config.max_open_trades_total} open positions",
                audit,
            )

        # 7. MAX_DAILY_TRADES_REACHED — executions today (V16-new)
        if self.state.daily.executed_count >= self.config.max_daily_trades:
            return self._reject(
                RiskRule.MAX_DAILY_TRADES_REACHED,
                f"{self.state.daily.executed_count}/{self.config.max_daily_trades} trades today",
                audit,
            )

        # 8. LAST_FRIDAY_CUTOFF — calendar-driven cutoff first
        if self._is_last_friday_of_month(now_utc):
            cutoff_h = self.config.last_friday_of_month_cutoff_utc_hour
            if now_utc.hour >= cutoff_h:
                return self._reject(
                    RiskRule.LAST_FRIDAY_CUTOFF,
                    f"last Friday of month, after {cutoff_h:02d}:00 UTC",
                    audit,
                )

        # 9. FORCE_FLAT_TIME_REACHED — intraday cutoff
        ff = dtime(
            self.config.force_flat_utc_hour,
            self.config.force_flat_utc_minute,
            tzinfo=timezone.utc,
        )
        # Both sides are aware UTC times here, so the comparison is valid.
        if now_utc.timetz() >= ff:
            return self._reject(
                RiskRule.FORCE_FLAT_TIME_REACHED,
                f"after {ff.strftime('%H:%M')} UTC force-flat cutoff",
                audit,
            )

        # 10. COOLDOWN_ACTIVE — post-SL anti-overtrading
        cooldown_map = self.state.cooldown.cooldown_until
        cooldown_until_iso = cooldown_map.get(symbol)
        if cooldown_until_iso:
            # Keep the try minimal: only parsing may fail. Comparing after
            # UTC normalization cannot raise, so a naive stored timestamp no
            # longer gets mistaken for corruption and silently dropped.
            try:
                cu = self._as_utc(datetime.fromisoformat(cooldown_until_iso))
            except (TypeError, ValueError):
                cu = None  # corrupted entry (not an ISO datetime string)
            if cu is not None and now_utc < cu:
                return self._reject(
                    RiskRule.COOLDOWN_ACTIVE,
                    f"cooldown until {cooldown_until_iso}",
                    audit,
                )
            # Expired or corrupted -> clear silently
            cooldown_map.pop(symbol, None)

        # 11. CORRELATION_BLOCKED — only when enabled
        if self.config.enable_correlation:
            blocker = self._correlation_blocker(symbol, active_trades)
            if blocker is not None:
                # Re-emit audit with blocker filled in
                audit = self._with_correlation(audit, blocker)
                return self._reject(
                    RiskRule.CORRELATION_BLOCKED,
                    f"{symbol} blocked: {blocker} already open (correlated)",
                    audit,
                )

        # 12. MAX_CONTRACTS_EXCEEDED — fail-safe against a sizing regression;
        # sizing.py is expected to clamp to MAX_CONTRACTS_PER_TRADE already.
        max_ct = MAX_CONTRACTS_PER_TRADE.get(
            symbol, MAX_CONTRACTS_PER_TRADE.get("default", 1)
        )
        if sizing.contracts > max_ct:
            return self._reject(
                RiskRule.MAX_CONTRACTS_EXCEEDED,
                f"{sizing.contracts} > cap {max_ct} for {symbol}",
                audit,
            )

        # 13. MAX_RISK_VS_DAILY_BUDGET_EXCEEDED — final budget gate.
        # Remaining budget is 0 only when daily_pnl <= hard stop, which gate 3
        # already rejected; the > 0 guard is purely defensive.
        cap_pct = self.config.max_risk_vs_daily_budget
        if audit.daily_remaining_budget > 0:
            allowed = audit.daily_remaining_budget * cap_pct
            if sizing.real_risk_usd > allowed:
                return self._reject(
                    RiskRule.MAX_RISK_VS_DAILY_BUDGET_EXCEEDED,
                    (
                        f"risk ${sizing.real_risk_usd:.2f} > "
                        f"{cap_pct*100:.0f}% of remaining ${audit.daily_remaining_budget:.2f} "
                        f"(allowed ${allowed:.2f})"
                    ),
                    audit,
                )

        return RiskCheckResult(
            approved=True,
            rule=RiskRule.OK.value,
            reason="",
            audit=audit,
        )

    # ============================================================
    # LIFECYCLE HOOKS — called by trade_closer / orchestrator
    # ============================================================

    def halt(self, reason: str) -> None:
        """
        Activate the kill switch: no new entries until daily reset or clear_halt().

        Idempotent: if already halted, the original halt_reason is kept.
        """
        if self.state.halted:
            return
        self.state.halted = True
        self.state.halt_reason = reason
        if self.logger is not None:
            try:
                self.logger.log_session_event("halt", reason=reason)
                self.logger.system.warning(f"HALT: {reason}")
            except Exception:
                # Best-effort: a logging failure must never block the halt.
                pass

    def clear_halt(self, why: str = "manual") -> None:
        """
        Clear the halt flag and reason. Use sparingly; the daily reset
        handles the common case. No-op when not halted.
        """
        if not self.state.halted:
            return
        old = self.state.halt_reason
        self.state.halted = False
        self.state.halt_reason = ""
        if self.logger is not None:
            try:
                self.logger.log_session_event("halt_cleared", why=why, was=old)
            except Exception:
                # Best-effort logging; state change already applied.
                pass

    def update_daily_pnl(
        self,
        delta_usd: float,
        *,
        is_win: bool,
        brain: str,
    ) -> None:
        """
        Apply a realized P&L delta and bump the brain's win/loss counter.

        Args:
            delta_usd: P&L change in USD, added to both state.daily.daily_pnl
                and state.session_pnl.
            is_win: selects the wins vs losses counter.
            brain: stable BrainName value, "TF" or "MR"; any other value
                updates P&L only.
        """
        self.state.daily.daily_pnl += delta_usd
        self.state.session_pnl += delta_usd
        if brain == "TF":
            if is_win:
                self.state.brain.tf_wins += 1
            else:
                self.state.brain.tf_losses += 1
        elif brain == "MR":
            if is_win:
                self.state.brain.mr_wins += 1
            else:
                self.state.brain.mr_losses += 1

    def register_sl_hit(
        self,
        symbol: str,
        *,
        now_utc: Optional[datetime] = None,
    ) -> None:
        """
        Record a stop-loss hit on `symbol` and start its cooldown
        (anti-overtrading, V15 _on_sl_hit, line 2117).

          - prior consecutive SL count >= 2 -> 2h cooldown (circuit breaker)
          - otherwise                       -> 30min cooldown

        Increments the per-symbol counter in
        state.metadata["consecutive_sl_count"] and sets
        state.cooldown.cooldown_until[symbol] and
        state.cooldown.last_trade_close_at[symbol] (ISO-8601 UTC strings).

        Args:
            symbol: instrument that was stopped out.
            now_utc: time of the hit, defaults to utc_now(); naive values are
                treated as UTC.
        """
        now_utc = self._as_utc(utc_now() if now_utc is None else now_utc)
        counters = self.state.metadata.setdefault(self._SL_COUNTER_KEY, {})
        prev = self._consecutive_sl_for(symbol)
        new_count = prev + 1
        counters[symbol] = new_count

        if prev >= 2:
            seconds = self.COOLDOWN_CONSECUTIVE_SL_SECONDS
            label = "2h"
        else:
            seconds = self.COOLDOWN_FIRST_SL_SECONDS
            label = "30min"

        cooldown_until = (now_utc + timedelta(seconds=seconds)).isoformat()
        self.state.cooldown.cooldown_until[symbol] = cooldown_until
        self.state.cooldown.last_trade_close_at[symbol] = now_utc.isoformat()

        if self.logger is not None:
            try:
                self.logger.system.info(
                    f"[SL] {symbol}: SL #{new_count} -> cooldown {label} until {cooldown_until}"
                )
            except Exception:
                # Best-effort logging; state already updated.
                pass

    def register_tp_hit(
        self,
        symbol: str,
        *,
        now_utc: Optional[datetime] = None,
    ) -> None:
        """
        Reset the consecutive-SL counter for `symbol` after a take-profit.

        Records the close time in state.cooldown.last_trade_close_at but
        does NOT touch cooldown_until — a TP does not start a cooldown, and a
        stale cooldown is cleared by check_entry() once it has expired.

        Args:
            symbol: instrument that hit its take-profit.
            now_utc: close time, defaults to utc_now(); naive values are
                treated as UTC.
        """
        now_utc = self._as_utc(utc_now() if now_utc is None else now_utc)
        counters = self.state.metadata.setdefault(self._SL_COUNTER_KEY, {})
        counters[symbol] = 0
        self.state.cooldown.last_trade_close_at[symbol] = now_utc.isoformat()

    # ============================================================
    # INTERNAL HELPERS
    # ============================================================

    @staticmethod
    def _as_utc(dt: datetime) -> datetime:
        """Return `dt` as an aware UTC datetime; naive values are assumed UTC."""
        if dt.utcoffset() is None:
            return dt.replace(tzinfo=timezone.utc)
        return dt.astimezone(timezone.utc)

    def _consecutive_sl_for(self, symbol: str) -> int:
        """Consecutive-SL count for `symbol`; 0 when missing or unparseable."""
        counters = self.state.metadata.get(self._SL_COUNTER_KEY) or {}
        try:
            return int(counters.get(symbol, 0))
        except (TypeError, ValueError):
            # Corrupted persisted value — treat as no streak rather than
            # crashing every check_entry() that builds an audit.
            return 0

    def _build_audit(
        self,
        *,
        symbol: str,
        sizing: SizingDecision,
        active_trades: dict,
        now_utc: datetime,
    ) -> RiskAudit:
        """Snapshot the current risk context for `symbol` into a RiskAudit."""
        hard = self.config.daily_loss_hard_stop
        pnl = self.state.daily.daily_pnl
        # Headroom above the hard stop, floored at 0: once pnl is at or below
        # the stop there is no budget left (abs() would report phantom
        # headroom after a breach). Equal to |hard - pnl| whenever pnl >= hard.
        remaining = max(0.0, pnl - hard)
        risk_vs_budget = (
            sizing.real_risk_usd / remaining if remaining > 0 else 1.0
        )
        return RiskAudit(
            daily_pnl=pnl,
            daily_loss_hard_stop=hard,
            daily_loss_soft_stop=self.config.daily_loss_soft_stop,
            daily_profit_target=self.config.daily_profit_target,
            daily_remaining_budget=remaining,
            pending_risk_usd=sizing.real_risk_usd,
            risk_vs_budget_pct=risk_vs_budget,
            open_trades_count=len(active_trades),
            consecutive_sl_count=self._consecutive_sl_for(symbol),
            cooldown_until=self.state.cooldown.cooldown_until.get(symbol),
            correlation_blocker=None,
            now_utc=now_utc.isoformat(),
        )

    @staticmethod
    def _with_correlation(audit: RiskAudit, blocker: str) -> RiskAudit:
        """Return a new RiskAudit with correlation_blocker set."""
        d = asdict(audit)
        d["correlation_blocker"] = blocker
        return RiskAudit(**d)

    def _reject(
        self,
        rule: RiskRule,
        reason: str,
        audit: RiskAudit,
    ) -> RiskCheckResult:
        """Build a rejected RiskCheckResult carrying `rule`'s string value."""
        return RiskCheckResult(
            approved=False,
            rule=rule.value,
            reason=reason,
            audit=audit,
        )

    @staticmethod
    def _correlation_blocker(
        symbol: str,
        active_trades: dict,
    ) -> Optional[str]:
        """
        Return the symbol of an open trade in the same correlation group as
        `symbol`, or None. An open trade on `symbol` itself is not a blocker.
        """
        for members in CORRELATION_GROUPS.values():
            if symbol not in members:
                continue
            for other in members:
                if other != symbol and other in active_trades:
                    return other
        return None

    @staticmethod
    def _is_last_friday_of_month(now_utc: datetime) -> bool:
        """True iff `now_utc` falls on the last Friday of its month."""
        if now_utc.weekday() != 4:   # 4 = Friday
            return False
        # Last Friday <=> the next Friday is already in another month.
        return (now_utc + timedelta(days=7)).month != now_utc.month
