"""
Smoke tests for persistence/state_store.py (schema v3).

Covers:
  - Fresh defaults and per-dataclass roundtrip (DailyCounters, BrainCounters,
    BiasEntry, BiasCache, EntryEvalCache, CooldownState, ActiveTrade)
  - SessionState save/load roundtrip with full payload
  - Atomic save: .prev.json backup created on existing file
  - Schema migrations v1 (flat) -> v2 (nested) and v2 -> v3 (entry_eval_cache)
  - Migration loud failure on unknown / future version
  - load_state RESUME (existing + missing) and FRESH (timestamped archive)
  - auto_daily_reset on stale date in LIVE-style flow
  - Corrupt JSON raises loud RuntimeError

Run:
    cd ~/apex_v16
    python -m tests.test_state_store
"""

from __future__ import annotations

import json
import sys
import tempfile
from datetime import datetime, timedelta, timezone
from pathlib import Path

# Prepend the parent of this file's directory (the project root) to sys.path
# so the `core` and `persistence` imports below resolve even when this file
# is executed directly rather than via `python -m tests.test_state_store`.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from core.contracts import TradeEntry, TradeRuntime, utc_now
from persistence.state_store import (
    ActiveTrade,
    BiasCache,
    BiasEntry,
    BrainCounters,
    CooldownState,
    DailyCounters,
    EntryEvalCache,
    SCHEMA_VERSION,
    SessionState,
    StateLoadMode,
    StateStore,
    _migrate,
    load_state,
)


def _ok(label: str) -> None:
    """Report a passing check by printing an indented ``ok`` line for *label*."""
    print("  ok ", label)


def _make_entry(symbol: str = "MES") -> TradeEntry:
    """Build a fixture TradeEntry: a 2-contract TF BUY opened at the current time.

    Args:
        symbol: Instrument symbol stored on the entry (defaults to "MES").

    Returns:
        A TradeEntry with fixed prices and fixed entry-time indicator snapshot.
    """
    order_fields = dict(
        symbol=symbol,
        brain_name="TF",
        direction="BUY",
        contracts=2,
        entry_price=5800.0,
        sl_price=5797.5,
        tp_price=5805.0,
        opened_at=utc_now(),
    )
    # Indicator / regime snapshot captured at entry time.
    snapshot_fields = dict(
        rsi_m5_at_entry=55.0,
        rsi_h1_at_entry=58.0,
        rsi_h4_at_entry=60.0,
        atr_ratio_at_entry=1.1,
        market_structure_at_entry="BULLISH_EXPANSION",
        regime_at_entry="TRENDING",
        h1_compat_at_entry=0.9,
        confidence_at_entry=78,
    )
    return TradeEntry(**order_fields, **snapshot_fields)


def _make_runtime() -> TradeRuntime:
    """Build a fixture TradeRuntime: 12 minutes open, 42% progress, no partial."""
    runtime_fields = {
        "minutes_open": 12.0,
        "progress_pct": 0.42,
        "net_profit_usd": 37.5,
        "partial_done": False,
    }
    return TradeRuntime(**runtime_fields)


# ============================================================
# 1. Fresh defaults
# ============================================================

def test_fresh_session_state_defaults():
    """A default SessionState is at schema v3 with empty / zeroed sections."""
    state = SessionState()
    assert state.schema_version == SCHEMA_VERSION == 3
    assert state.active_trades == {}
    # Every nested section is present and of the expected type.
    for section, expected_cls in (
        (state.daily, DailyCounters),
        (state.brain, BrainCounters),
        (state.bias_cache, BiasCache),
        (state.entry_eval_cache, EntryEvalCache),
        (state.cooldown, CooldownState),
    ):
        assert isinstance(section, expected_cls)
    daily = state.daily
    assert daily.daily_pnl == 0.0
    assert daily.daily_loss_hard_stop_hit is False
    assert daily.profit_target_hit is False
    assert state.brain.bias_calls_count == 0
    assert state.entry_eval_cache.last_eval == {}
    assert state.session_pnl == 0.0
    assert state.halted is False
    _ok("fresh SessionState defaults are sane (v3 nested)")


# ============================================================
# 2. DailyCounters with explicit V16 flags
# ============================================================

def test_daily_counters_flags_roundtrip():
    """DailyCounters to_dict/from_dict keeps both explicit flags and the counters."""
    source = DailyCounters(
        date="2026-04-28",
        daily_pnl=-275.0,
        daily_loss_hard_stop_hit=False,
        profit_target_hit=True,
        approved_count=4,
        executed_count=3,
        rejected_count=11,
    )
    restored = DailyCounters.from_dict(source.to_dict())
    assert restored.profit_target_hit is True
    assert restored.daily_loss_hard_stop_hit is False
    assert restored.daily_pnl == -275.0
    assert restored.executed_count == 3
    _ok("DailyCounters roundtrip preserves V16 explicit flags")


# ============================================================
# 3. BrainCounters with bias_calls_count (V16-new)
# ============================================================

def test_brain_counters_bias_calls_count_roundtrip():
    """BrainCounters roundtrip keeps bias_calls_count alongside win/loss tallies."""
    restored = BrainCounters.from_dict(
        BrainCounters(
            mr_wins=3,
            mr_losses=2,
            tf_wins=5,
            tf_losses=4,
            bias_calls_count=42,
        ).to_dict()
    )
    assert restored.bias_calls_count == 42
    assert restored.mr_wins == 3 and restored.tf_losses == 4
    _ok("BrainCounters roundtrip with bias_calls_count")


# ============================================================
# 4. BiasEntry frozen + roundtrip
# ============================================================

def test_bias_entry_frozen_and_roundtrip():
    """BiasEntry rejects attribute assignment and survives a dict roundtrip.

    Only AttributeError counts as proof of immutability. That includes
    dataclasses.FrozenInstanceError, which is what a frozen dataclass
    raises. Any other exception type signals a real bug, so it propagates
    instead of being mistaken for "frozen".
    """
    be = BiasEntry(
        direction="BUY", confidence=72,
        rationale="EMA20>EMA50, RSI4h=58, structure bullish",
        computed_at=utc_now().isoformat(),
    )
    try:
        be.confidence = 99   # type: ignore[misc]
    except AttributeError:
        pass
    else:
        raise AssertionError("BiasEntry must be frozen")
    # A rejected assignment must also leave the original value intact.
    assert be.confidence == 72, "rejected assignment mutated BiasEntry"
    be2 = BiasEntry.from_dict(be.to_dict())
    assert be2 == be
    _ok("BiasEntry frozen + roundtrip equal")


# ============================================================
# 5. BiasCache multi-symbol
# ============================================================

def test_bias_cache_multi_symbol_roundtrip():
    """BiasCache keeps one BiasEntry per symbol across a dict roundtrip."""
    stamp = utc_now().isoformat()
    seeded = {
        "MES": BiasEntry("BUY", 70, "trend up", stamp),
        "MNQ": BiasEntry("SELL", 65, "trend down", stamp),
    }
    restored = BiasCache.from_dict(BiasCache(entries=seeded).to_dict())
    assert set(restored.entries) == {"MES", "MNQ"}
    assert restored.entries["MES"].direction == "BUY"
    assert restored.entries["MNQ"].confidence == 65
    _ok("BiasCache multi-symbol roundtrip")


# ============================================================
# 6. CooldownState roundtrip
# ============================================================

def test_cooldown_state_roundtrip():
    """CooldownState keeps last-close times and cooldown deadlines as separate maps."""
    closed_at = utc_now().isoformat()
    resume_at = (utc_now() + timedelta(minutes=20)).isoformat()
    restored = CooldownState.from_dict(
        CooldownState(
            last_trade_close_at={"MES": closed_at, "6E": closed_at},
            cooldown_until={"MES": resume_at},
        ).to_dict()
    )
    assert restored.last_trade_close_at["MES"] == closed_at
    assert restored.cooldown_until["MES"] == resume_at
    # 6E has a close time but no cooldown; it must not leak into the other map.
    assert "6E" not in restored.cooldown_until
    _ok("CooldownState roundtrip preserves both maps")


# ============================================================
# 7. SessionState empty roundtrip via StateStore
# ============================================================

def test_session_state_roundtrip_empty():
    """A default SessionState written through StateStore reads back equivalent."""
    with tempfile.TemporaryDirectory() as workdir:
        store = StateStore(Path(workdir) / "state.json")
        written = SessionState()
        store.save(written)
        loaded = store.load()
        assert loaded is not None
        assert loaded.schema_version == SCHEMA_VERSION
        assert loaded.active_trades == {}
        assert loaded.daily.date == written.daily.date
    _ok("empty SessionState roundtrip via disk")


# ============================================================
# 8. SessionState with active trade roundtrip
# ============================================================

def test_session_state_roundtrip_with_active_trade():
    """Every nested section of a populated SessionState survives save/load."""
    eval_times = {"MES": 1714378500.0, "6E": 1714378200.0}
    with tempfile.TemporaryDirectory() as workdir:
        store = StateStore(Path(workdir) / "state.json")
        state = SessionState()

        # Populate at least one field in every nested section.
        state.active_trades["MES"] = ActiveTrade(_make_entry(), _make_runtime())
        daily, brain = state.daily, state.brain
        daily.daily_pnl = 125.0
        daily.executed_count = 1
        brain.tf_wins = 1
        brain.bias_calls_count = 7
        state.bias_cache.entries["MES"] = BiasEntry(
            "BUY", 75, "test", utc_now().isoformat()
        )
        state.cooldown.last_trade_close_at["MES"] = utc_now().isoformat()
        for symbol, candle_time in eval_times.items():
            state.entry_eval_cache.last_eval[symbol] = candle_time
        state.session_pnl = 125.0
        store.save(state)

        loaded = store.load()
        assert loaded is not None
        assert "MES" in loaded.active_trades
        trade = loaded.active_trades["MES"]
        assert trade.entry.contracts == 2
        assert trade.runtime.partial_done is False
        assert loaded.daily.daily_pnl == 125.0
        assert loaded.daily.executed_count == 1
        assert loaded.brain.tf_wins == 1
        assert loaded.brain.bias_calls_count == 7
        assert loaded.bias_cache.entries["MES"].confidence == 75
        assert "MES" in loaded.cooldown.last_trade_close_at
        for symbol, candle_time in eval_times.items():
            assert loaded.entry_eval_cache.last_eval[symbol] == candle_time
        assert loaded.session_pnl == 125.0
    _ok("SessionState full roundtrip preserves all nested fields")


# ============================================================
# 8b. EntryEvalCache standalone roundtrip + v3 migration step
# ============================================================

def test_entry_eval_cache_roundtrip():
    """EntryEvalCache keeps its per-symbol candle-time map across a dict roundtrip."""
    original = EntryEvalCache(last_eval={"MES": 1714000000.0, "6E": 1714000300.0})
    restored = EntryEvalCache.from_dict(original.to_dict())
    assert restored.last_eval == original.last_eval
    _ok("EntryEvalCache roundtrip preserves per-symbol candle_time map")


def test_migrate_v2_to_v3_adds_entry_eval_cache():
    """A v2 payload (no entry_eval_cache key) loads as v3 with an empty cache.

    All v2 data present in the payload must be carried over untouched.
    """
    daily_section = {
        "date": "2026-04-29",
        "daily_pnl": 0.0,
        "daily_loss_hard_stop_hit": False,
        "profit_target_hit": False,
        "approved_count": 0,
        "executed_count": 0,
        "rejected_count": 0,
    }
    brain_section = {
        "mr_wins": 1,
        "mr_losses": 2,
        "tf_wins": 3,
        "tf_losses": 4,
        "bias_calls_count": 5,
    }
    legacy = {
        "schema_version": 2,
        "started_at": "2026-04-29T08:00:00+00:00",
        "saved_at": "2026-04-29T08:30:00+00:00",
        "active_trades": {},
        "daily": daily_section,
        "brain": brain_section,
        "bias_cache": {"entries": {}},
        "cooldown": {"last_trade_close_at": {}, "cooldown_until": {}},
        "session_pnl": 12.0,
        "halted": False,
        "halt_reason": "",
        "metadata": {},
    }
    migrated = SessionState.from_dict(legacy)
    assert migrated.schema_version == 3
    assert migrated.entry_eval_cache.last_eval == {}
    # v2 data survives the upgrade.
    assert migrated.brain.tf_wins == 3
    assert migrated.brain.bias_calls_count == 5
    assert migrated.session_pnl == 12.0
    _ok("_migrate_v2_to_v3 adds empty entry_eval_cache, preserves v2 data")


# ============================================================
# 9. Atomic save: .prev.json backup created on existing file
# ============================================================

def test_save_creates_prev_backup_on_existing_file():
    """A second save() preserves the previous file's content as the backup.

    The first save has nothing to back up. The second must leave a backup
    containing the first save's payload. The backup is decoded as UTF-8
    explicitly, not with the platform's locale-dependent default encoding.
    """
    with tempfile.TemporaryDirectory() as tmp:
        store = StateStore(Path(tmp) / "state.json")
        s = SessionState()
        s.session_pnl = 10.0
        store.save(s)            # first save -> no .prev yet
        assert not store.backup_file.exists()
        s.session_pnl = 20.0
        store.save(s)            # second save -> .prev keeps the previous content
        assert store.backup_file.exists()
        with store.backup_file.open(encoding="utf-8") as f:
            prev = json.load(f)
        assert prev["session_pnl"] == 10.0
    _ok("save() writes .prev.json backup of previous file before overwrite")


# ============================================================
# 10. Migration v1 -> v2: flat fields land in nested sub-dataclasses
# ============================================================

def test_migrate_v1_to_v2_preserves_all_data():
    """A flat v1 payload is lifted into the nested layout without data loss.

    Migration chains v1 -> v2 -> v3, so the result is at schema 3. Sections
    introduced by later versions come back empty or defaulted.
    """
    legacy = {
        "schema_version": 1,
        "started_at": "2026-04-28T08:00:00+00:00",
        "saved_at": "2026-04-28T15:00:00+00:00",
        "active_trades": {},
        "approved_count": 4,
        "executed_count": 3,
        "rejected_count": 11,
        "mr_wins": 2,
        "mr_losses": 1,
        "tf_wins": 5,
        "tf_losses": 3,
        "session_pnl": 88.5,
        "daily_pnl": 88.5,
        "daily_pnl_date": "2026-04-28",
        "halted": False,
        "halt_reason": "",
        "metadata": {"note": "v1 sample"},
    }
    migrated = SessionState.from_dict(legacy)
    assert migrated.schema_version == 3

    # Flat v1 daily fields land in daily.*
    daily = migrated.daily
    assert daily.daily_pnl == 88.5
    assert daily.date == "2026-04-28"
    assert daily.approved_count == 4
    assert daily.executed_count == 3
    assert daily.rejected_count == 11
    # Flags introduced in v2 default to False.
    assert daily.daily_loss_hard_stop_hit is False
    assert daily.profit_target_hit is False

    # Flat v1 brain fields land in brain.*
    brain = migrated.brain
    assert brain.mr_wins == 2 and brain.mr_losses == 1
    assert brain.tf_wins == 5 and brain.tf_losses == 3
    assert brain.bias_calls_count == 0  # introduced in v2

    # Caches added by v2 and v3 start out empty.
    assert migrated.bias_cache.entries == {}
    assert migrated.cooldown.last_trade_close_at == {}
    assert migrated.entry_eval_cache.last_eval == {}

    # Top-level fields are carried through unchanged.
    assert migrated.session_pnl == 88.5
    assert migrated.metadata == {"note": "v1 sample"}
    _ok("_migrate_v1_to_v2 maps flat fields into nested sub-dataclasses")


# ============================================================
# 11. Migration loud failure: unknown version
# ============================================================

def test_migrate_unknown_version_raises():
    """_migrate refuses schema_version 0 with an "Unknown migration path" ValueError."""
    try:
        _migrate({"schema_version": 0})
    except ValueError as exc:
        matched = "Unknown migration path" in str(exc)
    else:
        matched = False
    assert matched
    _ok("_migrate raises ValueError on unknown schema_version=0")


# ============================================================
# 12. Migration loud failure: future version
# ============================================================

def test_migrate_future_version_raises():
    """_migrate refuses payloads whose schema_version exceeds SCHEMA_VERSION."""
    try:
        _migrate({"schema_version": SCHEMA_VERSION + 5})
    except ValueError as exc:
        matched = "newer than" in str(exc)
    else:
        matched = False
    assert matched
    _ok("_migrate raises ValueError on schema_version > SCHEMA_VERSION")


# ============================================================
# 13. load_state RESUME loads existing
# ============================================================

def test_load_state_resume_loads_existing():
    """RESUME mode returns the state that was previously saved to disk."""
    with tempfile.TemporaryDirectory() as workdir:
        store = StateStore(Path(workdir) / "state.json")
        saved = SessionState()
        saved.session_pnl = 75.0
        saved.brain.tf_wins = 4
        store.save(saved)
        resumed = load_state(store, mode=StateLoadMode.RESUME)
        assert resumed.session_pnl == 75.0
        assert resumed.brain.tf_wins == 4
    _ok("load_state RESUME loads existing state from disk")


# ============================================================
# 14. load_state RESUME returns fresh when file missing
# ============================================================

def test_load_state_resume_fresh_when_missing():
    """RESUME with no state file on disk falls back to a fresh SessionState."""
    with tempfile.TemporaryDirectory() as workdir:
        store = StateStore(Path(workdir) / "state.json")
        assert not store.exists()
        fresh = load_state(store, mode=StateLoadMode.RESUME)
        assert fresh.schema_version == SCHEMA_VERSION
        assert fresh.session_pnl == 0.0
    _ok("load_state RESUME returns fresh SessionState when no file")


# ============================================================
# 15. load_state FRESH archives existing to .deleted-<ts>.json
# ============================================================

def test_load_state_fresh_archives_existing_file():
    """FRESH mode moves the existing file aside and returns a new state.

    Exactly one ``state.deleted-*.json`` archive must exist afterwards,
    holding the old payload. The archive is decoded as UTF-8 explicitly,
    not with the platform's locale-dependent default encoding.
    """
    with tempfile.TemporaryDirectory() as tmp:
        store = StateStore(Path(tmp) / "state.json")
        s = SessionState()
        s.session_pnl = 999.0
        store.save(s)
        assert store.exists()
        s2 = load_state(store, mode=StateLoadMode.FRESH)
        # Old file moved away
        assert not store.exists()
        # Fresh state returned
        assert s2.session_pnl == 0.0
        # Archive present and contains old payload
        archives = list(Path(tmp).glob("state.deleted-*.json"))
        assert len(archives) == 1, f"expected 1 archive, found {archives}"
        with archives[0].open(encoding="utf-8") as f:
            archived = json.load(f)
        assert archived["session_pnl"] == 999.0
    _ok("load_state FRESH archives existing file to .deleted-<UTC>.json")


# ============================================================
# 16. load_state auto_daily_reset on stale date
# ============================================================

def test_load_state_auto_daily_reset_on_stale_date():
    """auto_daily_reset on a stale date clears daily scope, keeps session scope.

    Daily counters and the daily-scoped halt are reset. session_pnl and
    open trades must survive. The expected "today" is read both before and
    after load_state, so a run that crosses UTC midnight does not fail
    spuriously.
    """
    with tempfile.TemporaryDirectory() as tmp:
        store = StateStore(Path(tmp) / "state.json")
        s = SessionState()
        # Force "yesterday" in daily counters and profit_target_hit, halted
        yesterday = (utc_now() - timedelta(days=1)).date().isoformat()
        s.daily = DailyCounters(
            date=yesterday,
            daily_pnl=-500.0,
            profit_target_hit=True,
            executed_count=8,
        )
        s.halted = True
        s.halt_reason = "yesterday_loss_stop"
        s.session_pnl = -500.0
        s.active_trades["MES"] = ActiveTrade(_make_entry(), _make_runtime())
        store.save(s)

        # Bracket the load: if UTC midnight passes during load_state, the
        # reset may legitimately stamp either of these two dates.
        today_before = utc_now().date().isoformat()
        s2 = load_state(store, mode=StateLoadMode.RESUME, auto_daily_reset=True)
        today_after = utc_now().date().isoformat()
        # Daily counters reset
        assert s2.daily.date in {today_before, today_after}, s2.daily.date
        assert s2.daily.daily_pnl == 0.0
        assert s2.daily.profit_target_hit is False
        assert s2.daily.executed_count == 0
        # Halt cleared (daily-scoped)
        assert s2.halted is False
        assert s2.halt_reason == ""
        # session_pnl preserved (NOT daily-scoped)
        assert s2.session_pnl == -500.0
        # active_trades preserved
        assert "MES" in s2.active_trades
    _ok("load_state auto_daily_reset zeros daily + halt, preserves session_pnl/trades")


# ============================================================
# 17. Corrupt JSON loud failure
# ============================================================

def test_load_corrupt_json_raises():
    """load() on unparsable JSON raises a RuntimeError whose message says "corrupt"."""
    with tempfile.TemporaryDirectory() as workdir:
        state_path = Path(workdir) / "state.json"
        state_path.write_text("{ not valid json", encoding="utf-8")
        store = StateStore(state_path)
        try:
            store.load()
        except RuntimeError as exc:
            mentions_corrupt = "corrupt" in str(exc).lower()
        else:
            mentions_corrupt = False
        assert mentions_corrupt
    _ok("load() raises RuntimeError on corrupt JSON (loud failure)")


# ============================================================
# MAIN
# ============================================================

def main() -> int:
    """Run every smoke test in order and print a summary line.

    Each test raises (AssertionError or the underlying error) on failure, so
    reaching the summary means all of them passed.

    Returns:
        0, which the ``__main__`` guard passes to ``sys.exit``.
    """
    print("test_state_store.py")
    # Single source of truth for run order. The summary count is derived
    # from it, so adding or removing a test cannot leave a stale number.
    tests = (
        test_fresh_session_state_defaults,
        test_daily_counters_flags_roundtrip,
        test_brain_counters_bias_calls_count_roundtrip,
        test_bias_entry_frozen_and_roundtrip,
        test_bias_cache_multi_symbol_roundtrip,
        test_cooldown_state_roundtrip,
        test_session_state_roundtrip_empty,
        test_session_state_roundtrip_with_active_trade,
        test_entry_eval_cache_roundtrip,
        test_migrate_v2_to_v3_adds_entry_eval_cache,
        test_save_creates_prev_backup_on_existing_file,
        test_migrate_v1_to_v2_preserves_all_data,
        test_migrate_unknown_version_raises,
        test_migrate_future_version_raises,
        test_load_state_resume_loads_existing,
        test_load_state_resume_fresh_when_missing,
        test_load_state_fresh_archives_existing_file,
        test_load_state_auto_daily_reset_on_stale_date,
        test_load_corrupt_json_raises,
    )
    for test in tests:
        test()
    print(f"ALL {len(tests)} TESTS PASSED")
    return 0


if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit code.
    raise SystemExit(main())
