"""
APEX V18 — TVDataFeedProvider tests.

The tvdatafeed-enhanced package is dependency-injected through
`client_factory` and `interval_map_builder` so tests run offline:
no real TvDatafeed instantiation, no network, no auth.
"""

from __future__ import annotations

import asyncio
import json
import sys
import tempfile
from pathlib import Path

import pandas as pd

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from analysis.tv_data_provider import (
    TVDataFeedProvider, TV_SYMBOL_MAP, _build_interval_map,
)


# ============================================================
# Fakes
# ============================================================

class FakeInterval:
    """Stand-in for tvDatafeed.Interval enum members — opaque tokens.

    Equality and hashing are label-based so tests can compare the interval
    a fake client received against the FAKE_INTERVALS registry.
    """
    def __init__(self, label: str) -> None:
        # The label is the token's entire identity (see __eq__/__hash__).
        self.label = label

    def __repr__(self) -> str:
        return f"FakeInterval({self.label!r})"

    def __eq__(self, other: object) -> bool:
        # Return NotImplemented (not False) for foreign types so Python can
        # fall back to the reflected comparison — standard __eq__ protocol.
        if not isinstance(other, FakeInterval):
            return NotImplemented
        return self.label == other.label

    def __hash__(self) -> int:
        return hash(self.label)


# Opaque Interval tokens, keyed by the timeframe strings these tests exercise.
# Injected via interval_map_builder so no real tvDatafeed enum is imported.
FAKE_INTERVALS = {
    "5min":  FakeInterval("5min"),
    "1hour": FakeInterval("1hour"),
    "4hour": FakeInterval("4hour"),
}


class FakeTvClient:
    """Spy double for the tvdatafeed client.

    Records every get_hist invocation in `calls`, serves scripted frames
    from `responses` (FIFO), and stores the init kwargs (username,
    password, token, token_cache_file) so auth wiring can be asserted.
    """
    def __init__(self, username=None, password=None, token=None,
                 token_cache_file=None):
        self.username = username
        self.password = password
        self.token = token
        self.token_cache_file = token_cache_file
        # Call log + scripted responses consumed front-to-back.
        self.calls: list[dict] = []
        self.responses: list[pd.DataFrame | None] = []
        # When set, get_hist raises this AFTER recording the call.
        self.raise_on_call: Exception | None = None

    def get_hist(self, *, symbol, exchange, interval, n_bars, fut_contract):
        recorded = {
            "symbol": symbol,
            "exchange": exchange,
            "interval": interval,
            "n_bars": n_bars,
            "fut_contract": fut_contract,
        }
        self.calls.append(recorded)
        if self.raise_on_call is not None:
            raise self.raise_on_call
        return self.responses.pop(0) if self.responses else None


def _make_tv_df(rows: int = 5) -> pd.DataFrame:
    """Build a TV-shaped DataFrame: DatetimeIndex 'datetime' + symbol col."""
    idx = pd.date_range("2026-05-01 13:00", periods=rows, freq="5min")
    df = pd.DataFrame({
        "symbol": ["CME_MINI:MES1!"] * rows,
        "open":   [5800.0 + i for i in range(rows)],
        "high":   [5810.0 + i for i in range(rows)],
        "low":    [5795.0 + i for i in range(rows)],
        "close":  [5805.0 + i for i in range(rows)],
        "volume": [100 + i * 10 for i in range(rows)],
    }, index=idx)
    df.index.name = "datetime"
    return df


def _make_provider(
    client: FakeTvClient | None = None,
    *,
    token: str | None = None,
    token_cache_file: str | Path | None = None,
    username: str | None = "user",
    password: str | None = "pass",
) -> tuple[TVDataFeedProvider, FakeTvClient]:
    """Wire a TVDataFeedProvider to a FakeTvClient through the injected
    factory; returns both so tests can script responses / inspect calls."""
    fake = client or FakeTvClient()

    def factory(**kwargs):
        # Mimic the lib by capturing all kwargs on the fake client
        for key, value in kwargs.items():
            setattr(fake, key, value)
        return fake

    provider = TVDataFeedProvider(
        username=username,
        password=password,
        token=token,
        token_cache_file=token_cache_file,
        client_factory=factory,
        interval_map_builder=lambda: dict(FAKE_INTERVALS),
    )
    return provider, fake


def _ok(msg: str) -> None:
    print(f"  ok  {msg}")


# ============================================================
# 1. Symbol map sanity
# ============================================================

def test_symbol_map_covers_all_v16_universe():
    """Every symbol in TRADING_HOURS (V16 universe) must have a TV mapping."""
    from core.config_futures import TRADING_HOURS
    missing = [sym for sym in TRADING_HOURS if sym not in TV_SYMBOL_MAP]
    assert missing == [], f"TV_SYMBOL_MAP missing entries for: {missing}"
    _ok("TV_SYMBOL_MAP covers all V16 universe symbols")


# ============================================================
# 2. Happy path: fetch + normalize
# ============================================================

def test_get_bars_happy_path_returns_normalized_df():
    """
    TV returns a DatetimeIndex DataFrame with 'symbol' column;
    provider strips 'symbol', renames index to 'time', sorts ascending.
    """
    provider, fake = _make_provider()
    fake.responses = [_make_tv_df(rows=5)]

    result = asyncio.run(provider.get_bars("MES", "5min", 200))

    expected_cols = ["time", "open", "high", "low", "close", "volume"]
    assert list(result.columns) == expected_cols
    assert "symbol" not in result.columns
    assert len(result) == 5
    # ascending order
    assert result["time"].is_monotonic_increasing
    # last row = most recent bar (matches V16 contract)
    assert result["close"].iloc[-1] == 5809.0

    # Forwarded args to tvdatafeed
    assert len(fake.calls) == 1
    forwarded = fake.calls[0]
    assert forwarded["symbol"] == "MES"
    assert forwarded["exchange"] == "CME_MINI"
    assert forwarded["interval"] == FAKE_INTERVALS["5min"]
    assert forwarded["n_bars"] == 200
    assert forwarded["fut_contract"] == 1   # front-month continuous
    _ok("happy path: normalized columns + ascending sort + fwd args")


# ============================================================
# 3. Multiple symbols / timeframes
# ============================================================

def test_get_bars_forwards_correct_exchange_per_symbol():
    """6E -> CME, MGC -> COMEX, MCL -> NYMEX, MYM -> CBOT_MINI."""
    expectations = {
        "6E":  ("6E",  "CME"),
        "MGC": ("MGC", "COMEX"),
        "MCL": ("MCL", "NYMEX"),
        "MYM": ("MYM", "CBOT_MINI"),
        "MNQ": ("MNQ", "CME_MINI"),
    }
    for v16_sym, (tv_ticker, tv_exchange) in expectations.items():
        provider, fake = _make_provider()
        fake.responses = [_make_tv_df(rows=2)]
        asyncio.run(provider.get_bars(v16_sym, "1hour", 50))
        first_call = fake.calls[0]
        assert first_call["symbol"] == tv_ticker
        assert first_call["exchange"] == tv_exchange
        assert first_call["interval"] == FAKE_INTERVALS["1hour"]
    _ok("symbol-mapping forwards correct (TV ticker, exchange) per asset")


def test_get_bars_forwards_correct_interval_per_timeframe():
    """Each supported timeframe forwards its mapped Interval token."""
    for timeframe in ("5min", "1hour", "4hour"):
        provider, fake = _make_provider()
        fake.responses = [_make_tv_df(rows=2)]
        asyncio.run(provider.get_bars("MES", timeframe, 100))
        assert fake.calls[0]["interval"] == FAKE_INTERVALS[timeframe]
    _ok("timeframe-mapping forwards correct Interval enum")


# ============================================================
# 4. Failure modes -> empty DataFrame, no exception
# ============================================================

def _assert_empty_v16_df(df: pd.DataFrame) -> None:
    assert isinstance(df, pd.DataFrame)
    assert df.empty
    assert list(df.columns) == ["time", "open", "high", "low", "close", "volume"]


def test_unmapped_symbol_returns_empty_df():
    """A symbol with no TV mapping short-circuits before any TV call."""
    provider, fake = _make_provider()
    result = asyncio.run(provider.get_bars("UNKNOWN", "5min", 100))
    _assert_empty_v16_df(result)
    assert fake.calls == [], "no TV call should be made for unmapped symbol"
    _ok("unmapped symbol -> empty df, no TV call")


def test_unsupported_timeframe_returns_empty_df():
    """A timeframe outside the interval map short-circuits before any TV call."""
    provider, fake = _make_provider()
    result = asyncio.run(provider.get_bars("MES", "1min", 100))
    _assert_empty_v16_df(result)
    assert fake.calls == [], "no TV call should be made for unsupported tf"
    _ok("unsupported timeframe -> empty df, no TV call")


def test_tv_returns_none_yields_empty_df():
    """get_hist returning None is treated as a soft failure, not an error."""
    provider, fake = _make_provider()
    fake.responses = [None]
    result = asyncio.run(provider.get_bars("MES", "5min", 100))
    _assert_empty_v16_df(result)
    assert len(fake.calls) == 1
    _ok("TV returns None -> empty df (failure mode)")


def test_tv_returns_empty_df_yields_empty_df():
    """An empty frame from TV normalizes to the canonical empty V16 frame."""
    provider, fake = _make_provider()
    fake.responses = [pd.DataFrame()]
    result = asyncio.run(provider.get_bars("MES", "5min", 100))
    _assert_empty_v16_df(result)
    _ok("TV returns empty df -> empty df")


def test_tv_raises_yields_empty_df():
    """Exceptions from the TV client must not propagate to the caller."""
    provider, fake = _make_provider()
    fake.raise_on_call = ConnectionError("websocket dropped")
    result = asyncio.run(provider.get_bars("MES", "5min", 100))
    _assert_empty_v16_df(result)
    assert len(fake.calls) == 1
    _ok("TV raises (network down / auth fail) -> empty df, no propagation")


# ============================================================
# 5. Client lifecycle: instantiated lazily, reused across calls
# ============================================================

def test_client_instantiated_lazily_and_reused():
    """Factory must not run at construction time, and must run exactly once."""
    recorded_kwargs: list[dict] = []
    fake = FakeTvClient()

    def factory(**kwargs):
        recorded_kwargs.append(kwargs)
        return fake

    provider = TVDataFeedProvider(
        username="alice", password="secret",
        client_factory=factory,
        interval_map_builder=lambda: dict(FAKE_INTERVALS),
    )
    # No client built yet
    assert recorded_kwargs == []

    fake.responses = [_make_tv_df(rows=2), _make_tv_df(rows=3)]
    asyncio.run(provider.get_bars("MES", "5min", 100))
    asyncio.run(provider.get_bars("MNQ", "1hour", 50))

    # Factory called exactly once -> client cached
    assert len(recorded_kwargs) == 1
    built_with = recorded_kwargs[0]
    assert built_with["username"] == "alice"
    assert built_with["password"] == "secret"
    assert built_with["token"] is None
    _ok("client built lazily, reused across get_bars calls")


# ============================================================
# 6. Normalization edge cases
# ============================================================

def test_normalize_drops_symbol_column():
    """The redundant 'symbol' column must not survive normalization."""
    normalized = TVDataFeedProvider._normalize(_make_tv_df(rows=3))
    assert "symbol" not in normalized.columns
    _ok("normalize: 'symbol' column dropped")


def test_normalize_sorts_ascending():
    """Descending TV output must come back ascending by time."""
    descending = _make_tv_df(rows=4).iloc[::-1]
    normalized = TVDataFeedProvider._normalize(descending)
    assert normalized["time"].is_monotonic_increasing
    _ok("normalize: descending input -> ascending output")


def test_normalize_handles_unnamed_index():
    """If TV ever returns an unnamed index (older lib versions), the
    fallback rename via 'index' must still produce a 'time' column."""
    raw = _make_tv_df(rows=2)
    raw.index.name = None
    normalized = TVDataFeedProvider._normalize(raw)
    assert "time" in normalized.columns
    assert list(normalized.columns) == ["time", "open", "high", "low", "close", "volume"]
    _ok("normalize: unnamed index -> 'time' column produced")


# ============================================================
# 7. Real interval_map (only if tvdatafeed-enhanced is installed)
# ============================================================

def test_real_interval_map_has_three_entries():
    """Smoke: when the package IS installed, the real builder produces
    exactly the V16 timeframes. Skipped if package missing."""
    try:
        interval_map = _build_interval_map()
    except ImportError:
        print("  skip  tvdatafeed-enhanced not installed; skipped real builder")
        return
    assert set(interval_map) == {"5min", "1hour", "4hour"}
    _ok("real interval map: 3 V16 timeframes mapped to Interval enum")


# ============================================================
# 8. TV_TOKEN auth path
# ============================================================

def test_tv_token_pre_written_to_cache_file():
    """When TV_TOKEN provided, provider pre-writes it to the cache file
    so the lib's _load_token() picks it up (skips CAPTCHA login)."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        cache_path = Path(tmp_dir) / "tv_token.json"
        provider, fake = _make_provider(
            token="JWT_TOKEN_FROM_BROWSER",
            token_cache_file=cache_path,
            username=None, password=None,
        )
        fake.responses = [_make_tv_df(rows=2)]
        asyncio.run(provider.get_bars("MES", "5min", 100))

        assert cache_path.exists()
        payload = json.loads(cache_path.read_text())
        assert payload == {"token": "JWT_TOKEN_FROM_BROWSER"}
        # Forwarded to the lib via factory
        assert fake.token == "JWT_TOKEN_FROM_BROWSER"
        assert fake.token_cache_file == str(cache_path)
    _ok("TV_TOKEN: pre-written to cache file + forwarded to lib")


def test_anonymous_mode_when_no_creds_no_token():
    """No token, no creds -> anonymous; no cache file written."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        cache_path = Path(tmp_dir) / "tv_token.json"
        provider, fake = _make_provider(
            token=None, token_cache_file=cache_path,
            username=None, password=None,
        )
        fake.responses = [_make_tv_df(rows=2)]
        asyncio.run(provider.get_bars("MES", "5min", 100))

        assert not cache_path.exists(), "no token -> no cache write"
        assert fake.token is None
    _ok("anonymous mode: no token, no creds -> no cache, factory called clean")


def test_default_token_cache_path_is_apex_scoped():
    """Default cache path lives under ~/.apex_v18/, not the lib's
    ~/.tv_token.json — avoids clobbering other TV tools."""
    from analysis.tv_data_provider import DEFAULT_TOKEN_CACHE
    cache_str = str(DEFAULT_TOKEN_CACHE)
    assert ".apex_v18" in cache_str
    assert cache_str.endswith("tv_token.json")
    _ok("default cache path scoped under ~/.apex_v18/")


# ============================================================
# Runner
# ============================================================

def main() -> int:
    """Run every test in declaration order; returns 0 when all pass.

    A failing assert propagates, so the success banner only prints
    when the whole suite passed.
    """
    print("test_tv_data_provider.py")
    suite = (
        test_symbol_map_covers_all_v16_universe,
        test_get_bars_happy_path_returns_normalized_df,
        test_get_bars_forwards_correct_exchange_per_symbol,
        test_get_bars_forwards_correct_interval_per_timeframe,
        test_unmapped_symbol_returns_empty_df,
        test_unsupported_timeframe_returns_empty_df,
        test_tv_returns_none_yields_empty_df,
        test_tv_returns_empty_df_yields_empty_df,
        test_tv_raises_yields_empty_df,
        test_client_instantiated_lazily_and_reused,
        test_normalize_drops_symbol_column,
        test_normalize_sorts_ascending,
        test_normalize_handles_unnamed_index,
        test_real_interval_map_has_three_entries,
        test_tv_token_pre_written_to_cache_file,
        test_anonymous_mode_when_no_creds_no_token,
        test_default_token_cache_path_is_apex_scoped,
    )
    for test in suite:
        test()
    print("ALL TESTS PASSED")
    return 0


# Script entry point: run the suite; a failed assert exits non-zero via traceback.
if __name__ == "__main__":
    sys.exit(main())
