"""
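Tests for the candle-keyed tech-snapshot cache: assert that build_tech_snapshot
does not re-fetch the H1 and H4 bars for the same M5 candle key, that a new M5
candle forces a full re-fetch, and that the cache evicts its oldest entries
when it reaches its size cap.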
"""
from __future__ import annotations

import asyncio
import sys
from pathlib import Path

import pandas as pd

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from analysis.market_data import FakeMarketDataProvider
from analysis.tech_snapshot import (
    _TECH_CACHE,
    build_tech_snapshot,
    clear_tech_cache,
)


def _make_bars(n: int, last_ts: int = 1735689600) -> pd.DataFrame:
    """OHLCV bars con timestamp incrementali (M5)."""
    times = [last_ts - (n - 1 - i) * 300 for i in range(n)]
    return pd.DataFrame({
        "time":   times,
        "open":   [5800.0 + i * 0.25 for i in range(n)],
        "high":   [5805.0 + i * 0.25 for i in range(n)],
        "low":    [5795.0 + i * 0.25 for i in range(n)],
        "close":  [5802.0 + i * 0.25 for i in range(n)],
        "volume": [1000 + i for i in range(n)],
    })


def test_cache_hit_skips_h1_h4_fetch():
    """A second build on the same M5 candle must re-fetch only the M5 probe.

    Both builds use the same provider and the same bars. The cold build is
    expected to make 3 fetches. The warm build is expected to make exactly
    1 more: the M5 probe, with H1/H4 coming from the cache.
    """
    clear_tech_cache()
    provider = FakeMarketDataProvider({
        ("MES", "5min"):  _make_bars(200),
        ("MES", "1hour"): _make_bars(100),
        ("MES", "4hour"): _make_bars(50),
    })

    def run_snapshot():
        # Same symbol/instrument parameters on every call, so only the
        # cache state differs between the two builds.
        asyncio.run(build_tech_snapshot(
            symbol="MES", provider=provider,
            tick_size=0.25, tick_value=1.25,
        ))

    run_snapshot()
    cold_calls = len(provider.calls)
    assert cold_calls == 3, \
        f"first call: 3 fetches expected, got {cold_calls}"

    run_snapshot()
    warm_calls = len(provider.calls) - cold_calls
    assert warm_calls == 1, \
        f"cache hit must fetch only M5 (1 call), got {warm_calls}"
    print("OK: second call hits cache, only M5 probe re-fetched")


def test_cache_miss_when_candle_time_changes():
    """A new last M5 candle must miss the cache and force a full re-fetch.

    The two feeds differ only in the M5 frame's timestamps: the second ends
    300 s (one M5 candle) later. The H1/H4 frames are identical.
    """
    clear_tech_cache()

    def feed(m5_last_ts: int) -> dict:
        # Only the M5 timestamps vary between feeds; H1/H4 stay fixed.
        return {
            ("MES", "5min"):  _make_bars(200, last_ts=m5_last_ts),
            ("MES", "1hour"): _make_bars(100),
            ("MES", "4hour"): _make_bars(50),
        }

    provider1 = FakeMarketDataProvider(feed(1735689600))
    asyncio.run(build_tech_snapshot(
        symbol="MES", provider=provider1,
        tick_size=0.25, tick_value=1.25,
    ))
    # Precondition: on a cleared cache the first build must do the full
    # 3-fetch. Without this check, a misbehaving cold build (e.g. a stale
    # entry surviving clear_tech_cache) would go unnoticed.
    assert len(provider1.calls) == 3, \
        f"first call: 3 fetches expected, got {len(provider1.calls)}"

    provider2 = FakeMarketDataProvider(feed(1735689900))
    asyncio.run(build_tech_snapshot(
        symbol="MES", provider=provider2,
        tick_size=0.25, tick_value=1.25,
    ))
    assert len(provider2.calls) == 3, \
        f"new candle → full re-fetch (3 calls), got {len(provider2.calls)}"
    print("OK: new candle_time forces full re-fetch")


def test_lru_evicts_oldest():
    """With the cap lowered to 3, five distinct candles leave only the newest 3.

    Each iteration advances the last M5 candle by 300 s, so every build should
    be a cache miss that inserts a new entry. The two oldest entries must then
    be evicted.
    """
    clear_tech_cache()
    from analysis import tech_snapshot as ts_mod
    base_ts = 1735689600
    n_builds = 5
    saved_max = ts_mod._TECH_CACHE_MAX
    ts_mod._TECH_CACHE_MAX = 3
    try:
        for i in range(n_builds):
            bars = {
                ("MES", "5min"):  _make_bars(200, last_ts=base_ts + i * 300),
                ("MES", "1hour"): _make_bars(100),
                ("MES", "4hour"): _make_bars(50),
            }
            asyncio.run(build_tech_snapshot(
                symbol="MES", provider=FakeMarketDataProvider(bars),
                tick_size=0.25, tick_value=1.25,
            ))
        # Read the cache through the module attribute (the same way the cap is
        # patched) instead of a name bound by a from-import. If the module ever
        # rebinds _TECH_CACHE, a from-imported name would still point at the
        # old object.
        cache = ts_mod._TECH_CACHE
        assert len(cache) == 3, \
            f"LRU cap=3, got {len(cache)}"
        # NOTE(review): assumes key[1] is the M5 candle timestamp and that the
        # cache iterates oldest-first, as the original keys[0] check did.
        survivors = [key[1] for key in cache]
        expected = [base_ts + i * 300 for i in range(2, n_builds)]
        assert survivors == expected, \
            f"oldest evicted, expected candle times {expected}; got {survivors}"
    finally:
        ts_mod._TECH_CACHE_MAX = saved_max
        # Don't leak entries built under the patched cap into later tests.
        clear_tech_cache()
    print("OK: LRU evicts oldest entries")


if __name__ == "__main__":
    # Script mode: run the tests in order without pytest. The count in the
    # summary is derived from the tuple so it cannot drift when tests change.
    _tests = (
        test_cache_hit_skips_h1_h4_fetch,
        test_cache_miss_when_candle_time_changes,
        test_lru_evicts_oldest,
    )
    for _test in _tests:
        _test()
    print(f"ALL {len(_tests)} TESTS PASSED")
