"""
APEX V16 — State persistence.

Single point of truth for saving/loading bot state.
Replaces V15's opaque v15_state.json with:
  - Schema versioning (explicit migration on schema change)
  - Pre-save backup (.prev.json kept always)
  - File scoped per (mode, account) — never mix paper with live
  - Atomic write (write temp + fsync + rename + dir-fsync)
  - load_state(store, mode) helper with FRESH timestamped archive

Schema v3 (current) — sub-dataclasses, replaces v1 flat layout:
  - active_trades:    dict[symbol -> ActiveTrade(entry, runtime)]
  - daily:            DailyCounters (per-day P&L, halt flags, counts; UTC date)
  - brain:            BrainCounters (per-Brain wins/losses + bias_calls_count)
  - bias_cache:       BiasCache (per-symbol BiasEntry frozen snapshots)
  - entry_eval_cache: EntryEvalCache (per-symbol last AI-evaluated M5 candle, v3)
  - cooldown:         CooldownState (per-symbol last_close + cooldown_until)
  - session_pnl, halted, halt_reason, metadata at top level
"""

from __future__ import annotations

import json
import os
import shutil
import tempfile
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional

from core.contracts import TradeEntry, TradeRuntime, utc_now


# Current on-disk schema version. Bump on any breaking layout change and
# add a corresponding migration step in _migrate().
SCHEMA_VERSION = 3


# ============================================================
# SUB-DATACLASSES
# ============================================================

@dataclass
class ActiveTrade:
    """
    Pair of TradeEntry (immutable) and TradeRuntime (mutable) for one
    open position.

    Serialized under SessionState.active_trades[symbol] as
    {"entry": ..., "runtime": ...} via each part's to_dict().
    """
    entry: TradeEntry      # immutable entry-time record (see core.contracts)
    runtime: TradeRuntime  # mutable runtime tracking, updated while trade is live


@dataclass
class DailyCounters:
    """
    Counters scoped to a single UTC day.

    Reset at UTC midnight: load_state(auto_daily_reset=True) swaps in a
    fresh instance whenever state.daily.date no longer matches today.

    The two halt flags are deliberately separate (a V16 change — V15
    folded both into one `halted` kill switch):
      daily_loss_hard_stop_hit — hard stop reached, block new entries.
      profit_target_hit        — target reached; risk-reduced trades may
                                 still be allowed.
    Persisting both lets RiskManager act directly instead of re-deriving
    them from daily_pnl at runtime.
    """
    date: str = field(default_factory=lambda: utc_now().date().isoformat())
    daily_pnl: float = 0.0
    daily_loss_hard_stop_hit: bool = False
    profit_target_hit: bool = False
    approved_count: int = 0
    executed_count: int = 0
    rejected_count: int = 0

    def to_dict(self) -> dict:
        """Serialize to a flat JSON-ready dict."""
        return {
            "date": self.date,
            "daily_pnl": self.daily_pnl,
            "daily_loss_hard_stop_hit": self.daily_loss_hard_stop_hit,
            "profit_target_hit": self.profit_target_hit,
            "approved_count": self.approved_count,
            "executed_count": self.executed_count,
            "rejected_count": self.rejected_count,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DailyCounters":
        """Deserialize; every missing field falls back to its default."""
        today = utc_now().date().isoformat()
        return cls(
            date=data.get("date", today),
            daily_pnl=data.get("daily_pnl", 0.0),
            daily_loss_hard_stop_hit=data.get("daily_loss_hard_stop_hit", False),
            profit_target_hit=data.get("profit_target_hit", False),
            approved_count=data.get("approved_count", 0),
            executed_count=data.get("executed_count", 0),
            rejected_count=data.get("rejected_count", 0),
        )


@dataclass
class BrainCounters:
    """
    Per-Brain win/loss counters plus AI cost analytics.

    bias_calls_count is new in V16: V15 kept AI-bias call counts only in
    an in-memory cache and never persisted them. Storing the count here
    enables Anthropic per-token cost analytics across sessions.
    """
    mr_wins: int = 0
    mr_losses: int = 0
    tf_wins: int = 0
    tf_losses: int = 0
    bias_calls_count: int = 0

    def to_dict(self) -> dict:
        """Serialize to a flat JSON-ready dict."""
        return {
            "mr_wins": self.mr_wins,
            "mr_losses": self.mr_losses,
            "tf_wins": self.tf_wins,
            "tf_losses": self.tf_losses,
            "bias_calls_count": self.bias_calls_count,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "BrainCounters":
        """Deserialize; every missing counter falls back to 0."""
        names = ("mr_wins", "mr_losses", "tf_wins", "tf_losses", "bias_calls_count")
        return cls(**{name: data.get(name, 0) for name in names})


@dataclass(frozen=True)
class BiasEntry:
    """
    Immutable snapshot of one AI bias decision, taken at computed_at.

    The state store only persists the snapshot — it does NOT enforce
    freshness. The consumer (BrainBias) MUST compare (utc_now - computed_at)
    against its own freshness policy (e.g. >1h → recompute) before trusting
    a cached entry: bias is time-sensitive, and a 4-hour-old bias on a
    fast-moving session is misleading.
    """
    direction: str        # "BUY" | "SELL" | "NEUTRAL"
    confidence: int       # 0-100
    rationale: str
    computed_at: str      # ISO datetime UTC

    def to_dict(self) -> dict:
        """Serialize to a flat JSON-ready dict."""
        return {
            "direction": self.direction,
            "confidence": self.confidence,
            "rationale": self.rationale,
            "computed_at": self.computed_at,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "BiasEntry":
        """Deserialize. direction/confidence/computed_at are required
        (KeyError if absent); rationale defaults to ""."""
        rationale = data.get("rationale", "")
        return cls(data["direction"], data["confidence"], rationale, data["computed_at"])


@dataclass
class BiasCache:
    """Per-symbol cache of frozen BiasEntry snapshots. See BiasEntry."""
    entries: dict[str, BiasEntry] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize each entry under its symbol key."""
        serialized = {}
        for symbol, entry in self.entries.items():
            serialized[symbol] = entry.to_dict()
        return {"entries": serialized}

    @classmethod
    def from_dict(cls, data: dict) -> "BiasCache":
        """Deserialize; missing "entries" key yields an empty cache."""
        parsed = {}
        for symbol, raw in data.get("entries", {}).items():
            parsed[symbol] = BiasEntry.from_dict(raw)
        return cls(entries=parsed)


@dataclass
class EntryEvalCache:
    """
    Last AI-entry-evaluated M5 candle time per symbol (UTC unix seconds).

    Brain.evaluate_entry consults this to skip a duplicate AI call when
    the current M5 candle for a symbol has already been evaluated. The
    cache is persisted, so same-candle dedup survives restarts and crashes.

    Update timing (orchestrator): write last_eval[symbol] = candle_time
    whenever EntryEvalResult.evaluated_candle_time is populated. The brain
    fills that field only when the AI responded (any parse outcome) or
    failed permanently (credit/invalid); transient failures
    (unknown/overload) leave it None so the next iteration retries.
    """
    last_eval: dict[str, float] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize; copies the mapping so later mutation can't alias."""
        return {"last_eval": {**self.last_eval}}

    @classmethod
    def from_dict(cls, data: dict) -> "EntryEvalCache":
        """Deserialize; copies the incoming mapping defensively."""
        stored = data.get("last_eval", {})
        return cls(last_eval=dict(stored))


@dataclass
class CooldownState:
    """
    Per-symbol post-trade cooldown tracking.

    last_trade_close_at — ISO datetime of the most recent close per symbol.
    cooldown_until      — optional explicit cooldown expiry per symbol
                          (e.g. enforced after consecutive losses).
    """
    last_trade_close_at: dict[str, str] = field(default_factory=dict)
    cooldown_until: dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Serialize; both mappings are copied."""
        return {
            "last_trade_close_at": dict(self.last_trade_close_at),
            "cooldown_until": dict(self.cooldown_until),
        }

    @classmethod
    def from_dict(cls, data: dict) -> "CooldownState":
        """Deserialize; missing mappings fall back to empty dicts."""
        closes = dict(data.get("last_trade_close_at", {}))
        expiries = dict(data.get("cooldown_until", {}))
        return cls(last_trade_close_at=closes, cooldown_until=expiries)


# ============================================================
# SESSION STATE
# ============================================================

@dataclass
class SessionState:
    """
    The complete persisted bot state, written atomically by StateStore.

    The on-disk layout is versioned: any breaking change requires bumping
    SCHEMA_VERSION and adding a step to _migrate(), which from_dict
    applies before parsing fields.
    """
    schema_version: int = SCHEMA_VERSION
    started_at: str = field(default_factory=lambda: utc_now().isoformat())
    saved_at: str = field(default_factory=lambda: utc_now().isoformat())

    # Open positions, keyed by symbol.
    active_trades: dict[str, ActiveTrade] = field(default_factory=dict)

    # Nested sub-state; each part (de)serializes itself.
    daily: DailyCounters = field(default_factory=DailyCounters)
    brain: BrainCounters = field(default_factory=BrainCounters)
    bias_cache: BiasCache = field(default_factory=BiasCache)
    entry_eval_cache: EntryEvalCache = field(default_factory=EntryEvalCache)
    cooldown: CooldownState = field(default_factory=CooldownState)

    # Session-scoped values — these survive the daily reset.
    session_pnl: float = 0.0
    halted: bool = False
    halt_reason: str = ""

    # Free-form metadata for debugging.
    metadata: dict = field(default_factory=dict)

    # ------------------------------------------------------------
    # SERIALIZATION
    # ------------------------------------------------------------

    def to_dict(self) -> dict:
        """Serialize the full state to a JSON-ready dict."""
        serialized_trades = {}
        for symbol, trade in self.active_trades.items():
            serialized_trades[symbol] = {
                "entry": trade.entry.to_dict(),
                "runtime": trade.runtime.to_dict(),
            }

        out = {
            "schema_version": self.schema_version,
            "started_at": self.started_at,
            "saved_at": self.saved_at,
            "active_trades": serialized_trades,
        }
        # Sub-dataclasses each provide their own to_dict.
        for name in ("daily", "brain", "bias_cache", "entry_eval_cache", "cooldown"):
            out[name] = getattr(self, name).to_dict()
        out["session_pnl"] = self.session_pnl
        out["halted"] = self.halted
        out["halt_reason"] = self.halt_reason
        out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, data: dict) -> "SessionState":
        """Deserialize, migrating older schemas first (raises on unknown)."""
        data = _migrate(data)

        trades = {
            symbol: ActiveTrade(
                entry=TradeEntry.from_dict(raw["entry"]),
                runtime=TradeRuntime.from_dict(raw["runtime"]),
            )
            for symbol, raw in data.get("active_trades", {}).items()
        }

        return cls(
            schema_version=data.get("schema_version", SCHEMA_VERSION),
            started_at=data.get("started_at", utc_now().isoformat()),
            saved_at=data.get("saved_at", utc_now().isoformat()),
            active_trades=trades,
            daily=DailyCounters.from_dict(data.get("daily", {})),
            brain=BrainCounters.from_dict(data.get("brain", {})),
            bias_cache=BiasCache.from_dict(data.get("bias_cache", {})),
            entry_eval_cache=EntryEvalCache.from_dict(data.get("entry_eval_cache", {})),
            cooldown=CooldownState.from_dict(data.get("cooldown", {})),
            session_pnl=data.get("session_pnl", 0.0),
            halted=data.get("halted", False),
            halt_reason=data.get("halt_reason", ""),
            metadata=data.get("metadata", {}),
        )


# ============================================================
# MIGRATIONS
# ============================================================

def _migrate_v1_to_v2(data: dict) -> dict:
    """
    Schema v1 (flat) → v2 (nested sub-dataclasses).

    Field mapping:
      approved_count, executed_count, rejected_count → daily.*
      daily_pnl, daily_pnl_date                      → daily.daily_pnl, daily.date
      mr_wins, mr_losses, tf_wins, tf_losses         → brain.*
      session_pnl, halted, halt_reason, metadata     → unchanged (top-level)
      active_trades                                  → unchanged

    daily_loss_hard_stop_hit / profit_target_hit / bias_calls_count are
    new in v2 — defaulted to False/0 for migrated v1 data.
    """
    return {
        "schema_version": 2,
        "started_at": data.get("started_at"),
        "saved_at": data.get("saved_at"),
        "active_trades": data.get("active_trades", {}),
        "daily": {
            "date": data.get("daily_pnl_date", utc_now().date().isoformat()),
            "daily_pnl": data.get("daily_pnl", 0.0),
            "daily_loss_hard_stop_hit": False,
            "profit_target_hit": False,
            "approved_count": data.get("approved_count", 0),
            "executed_count": data.get("executed_count", 0),
            "rejected_count": data.get("rejected_count", 0),
        },
        "brain": {
            "mr_wins": data.get("mr_wins", 0),
            "mr_losses": data.get("mr_losses", 0),
            "tf_wins": data.get("tf_wins", 0),
            "tf_losses": data.get("tf_losses", 0),
            "bias_calls_count": 0,
        },
        "bias_cache": {"entries": {}},
        "cooldown": {"last_trade_close_at": {}, "cooldown_until": {}},
        "session_pnl": data.get("session_pnl", 0.0),
        "halted": data.get("halted", False),
        "halt_reason": data.get("halt_reason", ""),
        "metadata": data.get("metadata", {}),
    }


def _migrate_v2_to_v3(data: dict) -> dict:
    """
    Schema v2 -> v3: add entry_eval_cache (per-symbol same-candle dedup
    for evaluate_entry AI calls). Migration is purely additive: existing
    fields untouched, new field defaulted to empty cache.
    """
    out = dict(data)
    out["schema_version"] = 3
    out["entry_eval_cache"] = {"last_eval": {}}
    return out


def _migrate(data: dict) -> dict:
    """
    Upgrade a loaded state dict to the current SCHEMA_VERSION, applying
    migration steps one version at a time.

    Fails loudly on any schema it cannot handle — silent acceptance of
    unknown data is worse than a crash at load time.

    Raises:
        ValueError: stored schema is newer than this code, or no
            migration path exists from the stored version (includes
            pre-versioned files, which read as version 0).
    """
    version = data.get("schema_version", 0)

    if version > SCHEMA_VERSION:
        raise ValueError(
            f"State file schema_version={version} is newer than "
            f"this code's SCHEMA_VERSION={SCHEMA_VERSION}. "
            f"Refusing to load (would corrupt data)."
        )

    # One step per known source version; each step bumps schema_version.
    steps = {1: _migrate_v1_to_v2, 2: _migrate_v2_to_v3}
    while version < SCHEMA_VERSION:
        step = steps.get(version)
        if step is None:
            raise ValueError(
                f"Unknown migration path from schema_version={version} "
                f"to {SCHEMA_VERSION}. Add migration step in _migrate()."
            )
        data = step(data)
        version = data["schema_version"]

    return data


# ============================================================
# STATE STORE
# ============================================================

class StateStore:
    """
    Atomic file-based state persistence.

    Owns two sibling paths:
      state_file  — the live JSON state (schema-versioned, see SessionState).
      backup_file — <stem>.prev.json, overwritten with the previous
                    contents on every save().

    Usage:
        store = StateStore(config.state_file)
        state = load_state(store, mode=StateLoadMode.RESUME, auto_daily_reset=True)
        # ... mutate state ...
        store.save(state)
    """

    def __init__(self, state_file: Path) -> None:
        # Normalize to Path once so the suffix manipulation below works
        # even when a plain string is passed in.
        self.state_file = Path(state_file)
        # e.g. "state.json" -> "state.prev.json" (with_suffix swaps the
        # final suffix only).
        self.backup_file = self.state_file.with_suffix(".prev.json")

    # ------------------------------------------------------------
    # LOAD
    # ------------------------------------------------------------

    def load(self) -> Optional[SessionState]:
        """
        Load state from disk. Returns None if file doesn't exist.
        Raises on corrupt file or unknown schema (loud failure).

        Raises:
            RuntimeError: file exists but is not valid JSON.
            ValueError: unknown/newer schema (via SessionState.from_dict).
        """
        if not self.state_file.exists():
            return None

        try:
            with self.state_file.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except json.JSONDecodeError as e:
            # Wrap with a pointer to the backup so the operator can
            # recover manually.
            raise RuntimeError(
                f"State file {self.state_file} is corrupt: {e}. "
                f"Backup at {self.backup_file} may be usable."
            ) from e

        # from_dict migrates older schemas and raises on unknown ones.
        return SessionState.from_dict(data)

    def load_or_fresh(self) -> SessionState:
        """Load state, or return a fresh SessionState if no file exists."""
        loaded = self.load()
        return loaded if loaded is not None else SessionState()

    # ------------------------------------------------------------
    # SAVE — atomic + fsync
    # ------------------------------------------------------------

    def save(self, state: SessionState) -> None:
        """
        Atomic durable save:
          1. If state_file exists, copy it to backup_file (.prev.json)
          2. Write new state to temp file (same dir for atomic rename)
          3. flush + fsync the temp file (durable on disk)
          4. rename temp -> state_file (atomic on POSIX)
          5. fsync the parent directory (durable rename)

        Updates state.saved_at timestamp before serialization.
        """
        state.saved_at = utc_now().isoformat()

        # Step 1: snapshot the previous state before anything can go wrong.
        if self.state_file.exists():
            shutil.copy2(self.state_file, self.backup_file)

        self.state_file.parent.mkdir(parents=True, exist_ok=True)
        # The temp file MUST live in the same directory as state_file:
        # rename is only atomic within a single filesystem.
        fd, tmp_path = tempfile.mkstemp(
            prefix=self.state_file.stem + ".",
            suffix=".tmp",
            dir=str(self.state_file.parent),
        )
        try:
            # open(fd, ...) adopts the descriptor and closes it on exit.
            with open(fd, "w", encoding="utf-8") as f:
                json.dump(state.to_dict(), f, indent=2, ensure_ascii=False)
                f.flush()
                # Data must be durable BEFORE the rename publishes it.
                os.fsync(f.fileno())

            # Step 4: atomic publish (Path.replace overwrites the target).
            Path(tmp_path).replace(self.state_file)

            # fsync the parent directory so the rename is durable across crash.
            # Best-effort: not portable everywhere (Windows raises on dir fsync).
            try:
                dir_fd = os.open(str(self.state_file.parent), os.O_RDONLY)
                try:
                    os.fsync(dir_fd)
                finally:
                    os.close(dir_fd)
            except OSError:
                pass

        except Exception:
            # Never leave a stray temp file behind; the previous
            # state_file (if any) is still intact at this point.
            try:
                Path(tmp_path).unlink(missing_ok=True)
            except Exception:
                pass
            raise

    # ------------------------------------------------------------
    # UTILITIES
    # ------------------------------------------------------------

    def reset(self) -> None:
        """Delete state file and .prev backup. Use with --fresh-start."""
        self.state_file.unlink(missing_ok=True)
        self.backup_file.unlink(missing_ok=True)

    def exists(self) -> bool:
        """True if the state file is present on disk (backup not considered)."""
        return self.state_file.exists()


# ============================================================
# LOAD MODES — explicit fresh-vs-resume API
# ============================================================

class StateLoadMode(str, Enum):
    """
    How load_state should treat an existing state file.

    Subclasses str so a mode compares and serializes as its plain value.
    """
    RESUME = "RESUME"   # Load existing state; fresh SessionState if the file is missing.
    FRESH = "FRESH"     # Archive existing to .deleted-<UTC>.json, return fresh.


def load_state(
    store: StateStore,
    *,
    mode: StateLoadMode = StateLoadMode.RESUME,
    auto_daily_reset: bool = False,
) -> SessionState:
    """
    Load state with explicit mode and optional daily reset.

    Modes:
      RESUME — Load existing state. Fresh SessionState if file missing.
      FRESH  — If state file exists, atomically rename it to
               <state_file>.deleted-<UTC timestamp>.json and drop the
               .prev backup. Return a fresh SessionState.
               Caller is responsible for the first save().

    auto_daily_reset:
      When True (typically LIVE runs), if state.daily.date != today UTC
      reset DailyCounters to a fresh one for today AND clear halted/halt_reason
      (halts are daily-scoped). session_pnl and active_trades are preserved
      (active_trades survive across days; closure logic resets them).
      V15 parity: V15 reset daily counters at UTC midnight via the
      session loop. V16 makes the reset declarative at load time.
    """
    if mode == StateLoadMode.FRESH:
        if store.state_file.exists():
            ts = utc_now().strftime("%Y%m%dT%H%M%SZ")
            archive = store.state_file.with_suffix(f".deleted-{ts}.json")
            # Atomic rename, not copy2+unlink: a crash between copy and
            # unlink would leave the old state_file behind, and a later
            # RESUME would silently revive state the operator asked to
            # drop. Path.replace is a single os.replace within the same
            # directory, so it is atomic and cannot half-complete.
            store.state_file.replace(archive)
            store.backup_file.unlink(missing_ok=True)
        return SessionState()

    state = store.load_or_fresh()

    if auto_daily_reset:
        today = utc_now().date().isoformat()
        if state.daily.date != today:
            # New UTC day: fresh counters for today. Halts are
            # daily-scoped, so the kill switch is cleared with them.
            state.daily = DailyCounters(date=today)
            state.halted = False
            state.halt_reason = ""

    return state
