"""
APEX V16/V18 — Dashboard state writer.

Atomic JSON snapshot emitted every maintenance tick to feed an
external dashboard / TUI / web view. The file is overwritten in place
(no rotation, no history): a single consumer reads the latest tick.

Design constraints:
  - Fail-open: a write error MUST NOT break the orchestrator loop. All
    public entry points are wrapped; on failure we log a warning via
    orchestrator._sys_log and return.
  - Atomic: write to a temp file in the same directory, fsync, rename.
    A reader can mmap/open the target file without ever seeing partial
    JSON.
  - Read-only on orchestrator state: the writer never mutates anything
    on the orchestrator or on the SessionState; it only reads fields.

Snapshot layout (top-level keys):
  version, generated_at, mode
  session    : balance / equity / unrealized / session_pnl / halts / iter
  daily      : DailyCounters serialised (date + counters)
  brain_stats: MR / TF wins-losses-trades-win_rate + bias_calls_count
  asset_stats: per-symbol trades/wins/losses/pnl/win_rate derived from
               balance_confirmed + partial_close events in brain_log.jsonl
  asset_stats_today: same shape as asset_stats but filtered to events
               on the active session day (daily.date prefix on ts).
  trade_history: chronological list of realized P&L points (ts/symbol/
                pnl_usd/cumulative_pnl/kind). Each balance_confirmed
                yields one point with pnl = balance_post -
                prev_balance_post; each partial_close yields one point
                with pnl = partial_pnl_usd. The first balance_confirmed
                event is the baseline and is not rendered as a point.
  trade_history_daily: trade_history rolled up per UTC calendar day —
                one entry per day with pnl_day + cumulative_pnl (gross
                and net), in ascending date order. Feeds the dashboard
                equity curve (one chart point per day).
  positions  : list of open trades (entry+runtime fields per symbol)
  events     : last N events merged from all JSONL streams (desc by ts)
  feed       : per-stream tail (candle/brain/order/risk/sys)

JSONL stream -> dashboard stream mapping:
  regime_log.jsonl   -> "candle"
  brain_log.jsonl    -> "brain"
  trade_log.jsonl    -> "order"
  session_log.jsonl  -> "risk"  (risk_manager logs rejections here)
  error_log.jsonl    -> "sys"
"""

from __future__ import annotations

import json
import os
import tempfile
from collections import deque
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional


# Round-trip commission per contract, USD. Used to derive net P&L from
# the gross deltas surfaced by balance_confirmed events. Fallback of 1
# contract is the safer minimum for legacy events that predate the
# `contracts` field (V18 emitters now always supply it).
COMMISSION_RT: dict[str, float] = {
    "6A": 6.48, "6B": 6.48, "6C": 6.48,
    "6E": 6.48, "6J": 6.48,
    "MGC": 2.48, "MCL": 2.08,
    "MES": 1.48, "MNQ": 1.48, "MYM": 1.48,
}
_DEFAULT_CONTRACTS = 1


def _parse_int(value) -> Optional[int]:
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _commission_for(symbol: str, contracts: Optional[int]) -> float:
    rate = COMMISSION_RT.get(symbol)
    if rate is None:
        return 0.0
    n = int(contracts) if contracts else _DEFAULT_CONTRACTS
    if n <= 0:
        n = _DEFAULT_CONTRACTS
    return rate * n


# ============================================================
# Helpers
# ============================================================

def _utc_iso() -> str:
    return datetime.now(timezone.utc).isoformat()


def _tail_jsonl(path: Path, n: int) -> list[dict]:
    """Return up to `n` most-recent valid JSON records from a JSONL file.

    Best-effort: missing file, OS errors, and malformed lines are
    silently skipped. Newest line ends up last in the returned list
    (file order preserved).
    """
    if n <= 0 or not path.exists():
        return []
    try:
        with path.open("r", encoding="utf-8") as f:
            lines = deque(f, maxlen=n)
    except OSError:
        return []
    out: list[dict] = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            rec = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(rec, dict):
            out.append(rec)
    return out


# ============================================================
# DashboardWriter
# ============================================================

class DashboardWriter:
    """Atomically writes a dashboard JSON snapshot per maintenance tick.

    Stateless between ticks apart from configuration: each write()
    rebuilds the whole snapshot from orchestrator state and swaps it
    into `output_path` via temp-file + rename, never mutating the
    orchestrator.
    """

    # Default output: ~/apex_v16/dashboard_state.json
    # (the `~` is expanded in __init__).
    DEFAULT_OUTPUT = Path("~/apex_v16/dashboard_state.json")

    # Dashboard stream name -> per-stream JSONL file under config.log_dir.
    _FEED_FILES: dict[str, str] = {
        "candle": "regime_log.jsonl",
        "brain":  "brain_log.jsonl",
        "order":  "trade_log.jsonl",
        "risk":   "session_log.jsonl",
        "sys":    "error_log.jsonl",
    }

    def __init__(
        self,
        output_path: Optional[Path | str] = None,
        *,
        recent_events_limit: int = 50,
        feed_limit: int = 20,
    ) -> None:
        path = Path(output_path) if output_path is not None else self.DEFAULT_OUTPUT
        self.output_path = path.expanduser()
        self.recent_events_limit = int(recent_events_limit)
        self.feed_limit = int(feed_limit)

    # ------------------------------------------------------------
    # PUBLIC
    # ------------------------------------------------------------

    def write(self, orchestrator) -> None:
        """Build + atomically replace the snapshot file. Fail-open."""
        try:
            snapshot = self._build(orchestrator)
        except Exception as e:
            self._warn(orchestrator, f"build failed: {e}")
            return
        try:
            self._atomic_write(snapshot)
        except Exception as e:
            self._warn(orchestrator, f"write failed: {e}")

    # ------------------------------------------------------------
    # BUILD
    # ------------------------------------------------------------

    def _build(self, orch) -> dict:
        state = getattr(orch, "state", None)
        config = getattr(orch, "config", None)
        log_dir = self._log_dir(config)

        balance = float(getattr(orch, "_cached_account_balance", 0.0) or 0.0)
        session_pnl = float(getattr(state, "session_pnl", 0.0) or 0.0)
        daily = getattr(state, "daily", None)
        # Daily P&L source ranking:
        #   1. Broker-reported RP&L for the CME session day, cached on the
        #      orchestrator by _refresh_daily_rpnl. Matches what TopstepX
        #      shows in its UI (sum of Trade.profitAndLoss since 17:00 CT).
        #   2. In-memory daily counter (state.daily.daily_pnl), persisted
        #      across restarts. Used in PAPER mode and as a fallback if
        #      the broker call hasn't succeeded yet on this run.
        counter_pnl = float(getattr(daily, "daily_pnl", 0.0) or 0.0)
        broker_rpnl = getattr(orch, "_cached_daily_rpnl", None)
        if broker_rpnl is not None:
            daily_pnl = float(broker_rpnl)
            daily_pnl_source = "broker_rpnl"
        else:
            daily_pnl = counter_pnl
            daily_pnl_source = "counter"
        unrealized = self._sum_unrealized(state)
        # Equity = realised (already in cached balance on LIVE/DRY) + open P&L.
        equity = balance + unrealized

        daily_date = str(getattr(daily, "date", "") or "")
        closed = self._closed_trades_section(log_dir, daily_date)

        return {
            "version": 1,
            "generated_at": _utc_iso(),
            "mode": self._mode(config),
            "session": {
                "balance": round(balance, 2),
                "equity": round(equity, 2),
                "unrealized_pnl_usd": round(unrealized, 2),
                "session_pnl": round(session_pnl, 2),
                "daily_pnl": round(daily_pnl, 2),
                "daily_pnl_source": daily_pnl_source,
                "daily_pnl_counter": round(counter_pnl, 2),
                "daily_commissions": closed["daily_commissions"],
                "daily_pnl_net": round(daily_pnl - closed["daily_commissions"], 2),
                "halted": bool(getattr(state, "halted", False)),
                "halt_reason": getattr(state, "halt_reason", "") or "",
                "broker_degraded": bool(getattr(orch, "_broker_degraded", False)),
                "degraded_since": self._dt_iso(getattr(orch, "_degraded_since", None)),
                "iteration": int(getattr(orch, "_iteration", 0) or 0),
                "maintenance_iter": int(getattr(orch, "_maintenance_iter", 0) or 0),
            },
            "daily": self._daily_dict(daily),
            "brain_stats": self._brain_stats(getattr(state, "brain", None)),
            "asset_stats": closed["asset_stats"],
            "asset_stats_today": closed["asset_stats_today"],
            "trade_history": closed["trade_history"],
            "trade_history_daily": closed["trade_history_daily"],
            "positions": self._positions(state),
            "events": self._recent_events(log_dir),
            "feed": self._raw_feed(log_dir),
        }

    # ------------------------------------------------------------
    # BUILD helpers
    # ------------------------------------------------------------

    @staticmethod
    def _mode(config) -> str:
        mode = getattr(config, "mode", None)
        if mode is None:
            return ""
        return getattr(mode, "value", str(mode))

    @staticmethod
    def _log_dir(config) -> Optional[Path]:
        log_dir = getattr(config, "log_dir", None)
        return Path(log_dir) if log_dir else None

    @staticmethod
    def _dt_iso(value) -> Optional[str]:
        if value is None:
            return None
        if isinstance(value, datetime):
            return value.isoformat()
        return str(value)

    @staticmethod
    def _sum_unrealized(state) -> float:
        if state is None:
            return 0.0
        total = 0.0
        for at in getattr(state, "active_trades", {}).values():
            rt = getattr(at, "runtime", None)
            if rt is None:
                continue
            total += float(getattr(rt, "net_profit_usd", 0.0) or 0.0)
        return total

    @staticmethod
    def _daily_dict(daily) -> dict:
        if daily is None:
            return {}
        return {
            "date": getattr(daily, "date", ""),
            "daily_pnl": float(getattr(daily, "daily_pnl", 0.0) or 0.0),
            "approved_count": int(getattr(daily, "approved_count", 0) or 0),
            "executed_count": int(getattr(daily, "executed_count", 0) or 0),
            "rejected_count": int(getattr(daily, "rejected_count", 0) or 0),
            "daily_loss_hard_stop_hit":
                bool(getattr(daily, "daily_loss_hard_stop_hit", False)),
            "profit_target_hit":
                bool(getattr(daily, "profit_target_hit", False)),
        }

    @staticmethod
    def _brain_stats(brain) -> dict:
        def _rate(w: int, l: int) -> Optional[float]:
            tot = w + l
            return round(w / tot, 4) if tot > 0 else None

        if brain is None:
            return {
                "mr": {"wins": 0, "losses": 0, "trades": 0, "win_rate": None},
                "tf": {"wins": 0, "losses": 0, "trades": 0, "win_rate": None},
                "bias_calls_count": 0,
            }
        mr_w = int(getattr(brain, "mr_wins", 0) or 0)
        mr_l = int(getattr(brain, "mr_losses", 0) or 0)
        tf_w = int(getattr(brain, "tf_wins", 0) or 0)
        tf_l = int(getattr(brain, "tf_losses", 0) or 0)
        return {
            "mr": {
                "wins": mr_w, "losses": mr_l,
                "trades": mr_w + mr_l, "win_rate": _rate(mr_w, mr_l),
            },
            "tf": {
                "wins": tf_w, "losses": tf_l,
                "trades": tf_w + tf_l, "win_rate": _rate(tf_w, tf_l),
            },
            "bias_calls_count": int(getattr(brain, "bias_calls_count", 0) or 0),
        }

    @classmethod
    def _closed_trades_section(
        cls, log_dir: Optional[Path], daily_date: str,
    ) -> dict:
        """Single-pass source of closed-trade stats from brain_log.jsonl.

        trade_log.jsonl is never written in V18 — realized P&L surfaces
        as two events in brain_log.jsonl:
          * `balance_confirmed`: emitted by _handle_exit and
            _check_external_close after a FULL close (broker balance
            snapshot, pnl reconstructed from consecutive deltas).
          * `partial_close`: emitted by _handle_partial_close on a 50%
            partial fill — broker balance is NOT snapshotted here, so
            the row carries `partial_pnl_usd` directly. Walking both
            event types in ts order, we keep a single rolling
            `prev_balance` cursor and treat each partial pnl as if it
            updated the cursor by the same amount the broker did
            (gross delta).

        We read the file once (it's large; ~26 MB / 90 k lines in
        production) and pass the parsed event list to asset_stats,
        asset_stats_today and trade_history.

        `daily_date` is the active CME session day (YYYY-MM-DD) — used
        both to filter asset_stats_today and to sum daily_commissions.
        """
        events = cls._realized_events(log_dir)
        history = cls._trade_history(events)
        daily_comm = 0.0
        if daily_date:
            for row in history:
                if str(row.get("ts", "")).startswith(daily_date):
                    daily_comm += float(row.get("commission_usd", 0.0) or 0.0)
        return {
            "asset_stats": cls._asset_stats(events),
            "asset_stats_today": cls._asset_stats(events, ts_prefix=daily_date),
            "trade_history": history,
            "trade_history_daily": cls._trade_history_daily(history),
            "daily_commissions": round(daily_comm, 2),
        }

    @staticmethod
    def _realized_events(log_dir: Optional[Path]) -> list[dict]:
        """Parse `balance_confirmed` + `partial_close` events.

        Returns a single list sorted ascending by ts, with a normalized
        schema:
          {ts, symbol, kind, balance_post, partial_pnl_usd, contracts}
        where `kind` is "full" for balance_confirmed (balance_post set,
        partial_pnl_usd None) and "partial" for partial_close
        (balance_post None, partial_pnl_usd set). Fail-open: missing
        file / OS errors → []; malformed lines and events with no
        numeric pnl source are skipped.
        """
        if log_dir is None:
            return []
        path = log_dir / "brain_log.jsonl"
        if not path.exists():
            return []
        events: list[dict] = []
        try:
            with path.open("r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        rec = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(rec, dict):
                        continue
                    ev_name = rec.get("event")
                    ts = rec.get("ts")
                    if not ts:
                        continue
                    symbol = rec.get("symbol", "") or ""

                    if ev_name == "balance_confirmed":
                        bp = rec.get("balance_post")
                        if bp is None:
                            continue
                        try:
                            bp_f = float(bp)
                        except (TypeError, ValueError):
                            continue
                        events.append({
                            "ts": ts,
                            "symbol": symbol,
                            "kind": "full",
                            "balance_post": bp_f,
                            "partial_pnl_usd": None,
                            "contracts": _parse_int(rec.get("contracts")),
                        })
                    elif ev_name == "partial_close":
                        pp = rec.get("partial_pnl_usd")
                        if pp is None:
                            continue
                        try:
                            pp_f = float(pp)
                        except (TypeError, ValueError):
                            continue
                        events.append({
                            "ts": ts,
                            "symbol": symbol,
                            "kind": "partial",
                            "balance_post": None,
                            "partial_pnl_usd": pp_f,
                            "contracts": _parse_int(rec.get("contracts_closed")),
                        })
        except OSError:
            return []
        events.sort(key=lambda r: r["ts"])
        return events

    @staticmethod
    def _asset_stats(events: list[dict], *, ts_prefix: str = "") -> dict:
        """Per-symbol aggregates from realized events.

        The first `full` event seeds the rolling `prev_balance` cursor
        and is not counted (no prior balance to diff against). Each
        subsequent full event yields one trade with pnl = balance_post
        - prev_balance; each partial event yields one trade with pnl =
        partial_pnl_usd (and advances prev_balance by the same amount
        so the next full delta doesn't double-count the realized half).

        `ts_prefix` filters which events count toward stats (used for
        the per-session-day view). Events before the prefix still
        advance the rolling cursor — they just don't accumulate.
        Commission is estimated from COMMISSION_RT and the closed
        contracts count on the event.
        """
        stats: dict[str, dict] = {}
        prev: Optional[float] = None
        for ev in events:
            symbol = ev["symbol"]
            kind = ev["kind"]
            if kind == "full":
                if prev is None:
                    prev = ev["balance_post"]
                    continue
                pnl = ev["balance_post"] - prev
                prev = ev["balance_post"]
            else:  # partial
                if prev is None:
                    # No baseline yet — skip; we can't anchor cumulative
                    # accounting without a prior full event.
                    continue
                pnl = ev["partial_pnl_usd"]
                prev += pnl
            if ts_prefix and not str(ev["ts"]).startswith(ts_prefix):
                continue
            if not symbol:
                continue
            bucket = stats.setdefault(
                symbol,
                {"trades": 0, "wins": 0, "losses": 0,
                 "pnl": 0.0, "commission_usd": 0.0},
            )
            bucket["trades"] += 1
            if pnl > 0:
                bucket["wins"] += 1
            elif pnl < 0:
                bucket["losses"] += 1
            bucket["pnl"] += pnl
            bucket["commission_usd"] += _commission_for(symbol, ev.get("contracts"))
        for bucket in stats.values():
            trades = bucket["trades"]
            wins = bucket["wins"]
            bucket["pnl"] = round(bucket["pnl"], 2)
            bucket["commission_usd"] = round(bucket["commission_usd"], 2)
            bucket["pnl_net"] = round(bucket["pnl"] - bucket["commission_usd"], 2)
            bucket["win_rate"] = (
                round(wins / trades, 4) if trades > 0 else None
            )
        return stats

    @staticmethod
    def _trade_history(events: list[dict]) -> list[dict]:
        """Equity-curve points from realized events.

        The first `full` event seeds the rolling balance cursor and is
        not emitted as a point. Each subsequent full event becomes one
        row with pnl_usd = balance_post - prev; each partial event
        becomes one row with pnl_usd = partial_pnl_usd. The cursor is
        advanced by pnl on both kinds, so cumulative_pnl matches the
        sum of all rows.
        """
        rows: list[dict] = []
        prev: Optional[float] = None
        cumulative = 0.0
        cumulative_net = 0.0
        for ev in events:
            kind = ev["kind"]
            if kind == "full":
                if prev is None:
                    prev = ev["balance_post"]
                    continue
                pnl = ev["balance_post"] - prev
                prev = ev["balance_post"]
            else:  # partial
                if prev is None:
                    continue
                pnl = ev["partial_pnl_usd"]
                prev += pnl
            cumulative += pnl
            symbol = ev["symbol"]
            commission = _commission_for(symbol, ev.get("contracts"))
            pnl_net = pnl - commission
            cumulative_net += pnl_net
            rows.append({
                "ts": ev["ts"],
                "symbol": symbol,
                "kind": kind,
                "pnl_usd": round(pnl, 2),
                "cumulative_pnl": round(cumulative, 2),
                "commission_usd": round(commission, 2),
                "pnl_net": round(pnl_net, 2),
                "cumulative_pnl_net": round(cumulative_net, 2),
            })
        return rows

    @staticmethod
    def _trade_history_daily(rows: list[dict]) -> list[dict]:
        """Roll up trade_history rows into one entry per UTC calendar day.

        ts is ISO-8601 with UTC offset, so the first 10 chars are the
        YYYY-MM-DD date in UTC. The chart consumes this directly: one
        x-axis tick per day, no per-trade clutter.
        """
        buckets: dict[str, dict] = {}
        for row in rows:
            ts = str(row.get("ts", ""))
            date = ts[:10]
            if len(date) != 10:
                continue
            b = buckets.setdefault(date, {"pnl_day": 0.0, "pnl_day_net": 0.0})
            b["pnl_day"] += float(row.get("pnl_usd", 0.0) or 0.0)
            b["pnl_day_net"] += float(row.get("pnl_net", 0.0) or 0.0)
        out: list[dict] = []
        cumulative = 0.0
        cumulative_net = 0.0
        for date in sorted(buckets):
            b = buckets[date]
            cumulative += b["pnl_day"]
            cumulative_net += b["pnl_day_net"]
            out.append({
                "date": date,
                "pnl_day": round(b["pnl_day"], 2),
                "cumulative_pnl": round(cumulative, 2),
                "pnl_day_net": round(b["pnl_day_net"], 2),
                "cumulative_pnl_net": round(cumulative_net, 2),
            })
        return out

    @staticmethod
    def _positions(state) -> list[dict]:
        if state is None:
            return []
        out: list[dict] = []
        for symbol, at in getattr(state, "active_trades", {}).items():
            entry = getattr(at, "entry", None)
            runtime = getattr(at, "runtime", None)
            if entry is None or runtime is None:
                continue
            opened_at = getattr(entry, "opened_at", None)
            current_sl = float(getattr(runtime, "current_sl_price", 0.0) or 0.0)
            sl_at_entry = float(getattr(entry, "sl_price", 0.0) or 0.0)
            effective_sl = current_sl if current_sl > 0 else sl_at_entry
            out.append({
                "symbol": symbol,
                "brain": getattr(entry, "brain_name", ""),
                "direction": getattr(entry, "direction", ""),
                "contracts": int(getattr(entry, "contracts", 0) or 0),
                "entry_price": float(getattr(entry, "entry_price", 0.0) or 0.0),
                "sl_price": sl_at_entry,
                "current_sl_price": effective_sl,
                "tp_price": float(getattr(entry, "tp_price", 0.0) or 0.0),
                "opened_at": (
                    opened_at.isoformat()
                    if isinstance(opened_at, datetime) else
                    (str(opened_at) if opened_at is not None else "")
                ),
                "minutes_open": float(getattr(runtime, "minutes_open", 0.0) or 0.0),
                "progress_pct": float(getattr(runtime, "progress_pct", 0.0) or 0.0),
                "net_profit_usd":
                    float(getattr(runtime, "net_profit_usd", 0.0) or 0.0),
                "partial_done": bool(getattr(runtime, "partial_done", False)),
                "partial_pnl_usd":
                    float(getattr(runtime, "partial_pnl_usd", 0.0) or 0.0),
                "last_brain_action":
                    getattr(runtime, "last_brain_action", "") or "",
                "last_brain_reason":
                    getattr(runtime, "last_brain_reason", "") or "",
                "is_paper": bool(getattr(entry, "is_paper", False)),
                "confidence_at_entry":
                    int(getattr(entry, "confidence_at_entry", 0) or 0),
            })
        return out

    def _recent_events(self, log_dir: Optional[Path]) -> list[dict]:
        if log_dir is None or self.recent_events_limit <= 0:
            return []
        merged: list[dict] = []
        for stream, fname in self._FEED_FILES.items():
            for rec in _tail_jsonl(log_dir / fname, self.recent_events_limit):
                tagged = dict(rec)
                tagged.setdefault("stream", stream)
                merged.append(tagged)
        # ts is ISO-8601 UTC -> lexicographic sort == chronological sort
        merged.sort(key=lambda r: r.get("ts", ""), reverse=True)
        return merged[: self.recent_events_limit]

    def _raw_feed(self, log_dir: Optional[Path]) -> dict:
        feed: dict[str, list[dict]] = {k: [] for k in self._FEED_FILES}
        if log_dir is None or self.feed_limit <= 0:
            return feed
        for stream, fname in self._FEED_FILES.items():
            # Newest-first for the dashboard.
            feed[stream] = list(reversed(
                _tail_jsonl(log_dir / fname, self.feed_limit)
            ))
        return feed

    # ------------------------------------------------------------
    # ATOMIC WRITE
    # ------------------------------------------------------------

    def _atomic_write(self, snapshot: dict) -> None:
        self.output_path.parent.mkdir(parents=True, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(
            prefix=self.output_path.stem + ".",
            suffix=".tmp",
            dir=str(self.output_path.parent),
        )
        try:
            with open(fd, "w", encoding="utf-8") as f:
                json.dump(snapshot, f, ensure_ascii=False, default=str)
                f.flush()
                try:
                    os.fsync(f.fileno())
                except OSError:
                    pass
            Path(tmp_path).replace(self.output_path)
        except Exception:
            try:
                Path(tmp_path).unlink(missing_ok=True)
            except Exception:
                pass
            raise

    # ------------------------------------------------------------
    # LOG (best-effort)
    # ------------------------------------------------------------

    @staticmethod
    def _warn(orch, msg: str) -> None:
        try:
            orch._sys_log.warning("[dashboard] %s", msg)
        except Exception:
            pass
