"""AdaEvolve Memory component — Level-3 paradigm (solution-tactics) state.

One file per component (see scaffold.py). Faithful port of SkyDiscover's ``ParadigmTracker``
(search/adaevolve/paradigm/tracker.py) behind the :class:`~galapagos.components.memory.Memory`
interface. Per the reference code (not the paper), stagnation is the *windowed improvement rate*:
the binary window (size 10) of "did this admitted child set a new global best" is full, no active
non-exhausted paradigm exists, and the rate is below 0.12 — not "G ≤ τ_M for all islands".

Memory interface mapping:

* ``read()``                      → the active paradigm's "## BREAKTHROUGH IDEA" injection block
  (or ``""``), consumed by the PromptBuilder.
* ``read({"spec": "tried"})``     → the formatted previously-tried ideas with SUCCESS/FAILED
  outcomes, fed back to the paradigm generator.
* ``write(_, kind="improvement", improved=..., best_score=...)`` → one improvement-window tick
  (the open mirror of ``record_improvement``).
* ``write(_, kind="paradigms", paradigms=[...], best_score=...)`` → install a fresh paradigm set
  (``set_paradigms``), archiving outgoing used paradigms with their outcome.

The rotation/usage methods (:meth:`use_paradigm`, :meth:`has_active_paradigm`,
:meth:`is_stagnating`) are concrete extras called by the scaffold.
"""
from __future__ import annotations

from ...components.memory import Memory

# Prompt-injection template rendered by ``AdaEvolveParadigmMemory.read()`` via ``str.format``.
# Placeholders (all filled from the current paradigm dict, defaulting to ""):
#   {idea}, {description}, {what_to_optimize}, {cautions}, {approach_type}.
# The text is emitted verbatim at runtime, so any edit here changes the prompt itself.
_PARADIGM_BLOCK = """## BREAKTHROUGH IDEA - IMPLEMENT THIS

The search has stagnated globally. You MUST implement this breakthrough idea:

**IDEA:** {idea}

**HOW TO IMPLEMENT:**
{description}

**TARGET METRIC:** {what_to_optimize}

**CAUTIONS:** {cautions}

**APPROACH TYPE:** {approach_type}

**CRITICAL:**
- You MUST implement the breakthrough idea
- Ensure the paradigm is actually used in your implementation (not just mentioned in comments)
- Correctness is essential - your implementation must be correct and functional
- Verify output format matches evaluator requirements
- Make purposeful changes that implement the idea
- Test your implementation logic carefully"""


class AdaEvolveParadigmMemory(Memory):
    """``ParadigmTracker``: bounded binary improvement window, paradigm rotation with bounded
    uses, and an outcome-annotated history of tried ideas.

    State kept on the instance:

    * ``improvement_history`` — the last ``window_size`` ticks, each ``1.0`` (improved) or ``0.0``.
    * ``active_paradigms`` / ``usage_counts`` / ``current_index`` — the installed paradigm dicts,
      a parallel per-paradigm use counter, and the round-robin cursor.
    * ``tried_paradigms`` — formatted ``SUCCESS``/``FAILED`` lines, at most ``max_tried`` of them.
    * ``best_score_at_generation`` / ``best_score_during`` — the best score when the active set
      was installed and the best score observed since, used to label archived paradigms.

    ``num_to_generate`` is stored but not read by this class.
    """

    def __init__(self, window_size: int = 10, improvement_threshold: float = 0.12,
                 max_uses: int = 2, max_tried: int = 10, num_to_generate: int = 3):
        """Store the configuration (coerced to ``int``/``float``) and start with empty state.

        Args:
            window_size: number of improvement ticks kept in the sliding window.
            improvement_threshold: rate strictly below which a full window counts as stagnation.
            max_uses: uses after which a paradigm is considered exhausted.
            max_tried: maximum number of archived tried-idea lines kept.
            num_to_generate: stored on the instance; not used by this class.
        """
        self.window_size = int(window_size)
        self.improvement_threshold = float(improvement_threshold)
        self.max_uses = int(max_uses)
        self.max_tried = int(max_tried)
        self.num_to_generate = int(num_to_generate)
        self.improvement_history: list[float] = []      # bounded binary window
        self.active_paradigms: list[dict] = []
        self.usage_counts: list[int] = []
        self.current_index = 0
        self.tried_paradigms: list[str] = []            # bounded outcome history
        self.best_score_at_generation = 0.0
        self.best_score_during = 0.0

    # ---- Memory interface --------------------------------------------------------------------
    def read(self, spec: dict | None = None) -> str:
        """Return the text for the requested view.

        With ``spec == {"spec": "tried"}`` the archived tried-idea lines are returned joined by
        newlines. Otherwise the current paradigm is rendered through ``_PARADIGM_BLOCK``; ``""``
        is returned when no non-exhausted paradigm exists. Missing paradigm keys render as ``""``.

        Note that the default view goes through :meth:`current_paradigm`, which may advance
        ``current_index`` past an exhausted paradigm.
        """
        if spec and spec.get("spec") == "tried":
            return "\n".join(self.tried_paradigms)
        paradigm = self.current_paradigm()
        if paradigm is None:
            return ""
        return _PARADIGM_BLOCK.format(
            idea=paradigm.get("idea", ""),
            description=paradigm.get("description", ""),
            what_to_optimize=paradigm.get("what_to_optimize", ""),
            cautions=paradigm.get("cautions", ""),
            approach_type=paradigm.get("approach_type", ""),
        )

    def write(self, knowledge: str, **meta) -> None:
        """Dispatch on ``meta["kind"]``; ``knowledge`` itself is ignored.

        * ``kind="improvement"`` → :meth:`record_improvement` with ``improved`` and ``best_score``.
        * ``kind="paradigms"``   → :meth:`set_paradigms` with ``paradigms`` and ``best_score``.

        ``best_score`` defaults to ``0.0`` when absent; a missing/``None`` ``paradigms`` installs
        an empty set. Any other ``kind`` is a no-op.
        """
        kind = meta.get("kind")
        if kind == "improvement":
            self.record_improvement(bool(meta.get("improved")), float(meta.get("best_score", 0.0)))
        elif kind == "paradigms":
            self.set_paradigms(list(meta.get("paradigms") or []),
                               float(meta.get("best_score", 0.0)))

    # ---- ParadigmTracker port ------------------------------------------------------------------
    def record_improvement(self, improved: bool, current_best_score: float) -> None:
        """One tick of the binary window — called once per admitted non-migrant child.

        Appends ``1.0``/``0.0`` and drops the oldest tick once the window exceeds ``window_size``.
        While a paradigm set is installed, also raises ``best_score_during`` to
        ``current_best_score`` if that is higher.
        """
        self.improvement_history.append(1.0 if improved else 0.0)
        if len(self.improvement_history) > self.window_size:
            self.improvement_history.pop(0)
        if self.active_paradigms:  # track the best score reached during the paradigms' tenure
            self.best_score_during = max(self.best_score_during, current_best_score)

    def get_improvement_rate(self) -> float:
        """Return the mean of the window (fraction of improving ticks), or ``0.0`` if empty."""
        if not self.improvement_history:
            return 0.0
        return sum(self.improvement_history) / len(self.improvement_history)

    def is_stagnating(self) -> bool:
        """Trigger condition: full window AND no active non-exhausted paradigm AND rate below the
        threshold (the reference-code divergence from the paper's G-based trigger)."""
        return (len(self.improvement_history) >= self.window_size
                and not self.has_active_paradigm()
                and self.get_improvement_rate() < self.improvement_threshold)

    def has_active_paradigm(self) -> bool:
        """True iff a non-exhausted paradigm exists (auto-rotates past exhausted ones).

        Side effect: when the paradigm under the cursor is exhausted, ``current_index`` is moved
        to the next paradigm that still has uses left, if there is one.
        """
        if not self.active_paradigms:
            return False
        if self.usage_counts[self.current_index] >= self.max_uses:
            self._rotate()
        return self.usage_counts[self.current_index] < self.max_uses

    def current_paradigm(self) -> dict | None:
        """Return the paradigm under the cursor, or ``None`` when none has uses left."""
        if not self.has_active_paradigm():
            return None
        return self.active_paradigms[self.current_index]

    def use_paradigm(self) -> None:
        """Increment the current paradigm's usage, then rotate round-robin to the next
        non-exhausted paradigm. Does nothing when no non-exhausted paradigm exists."""
        if not self.has_active_paradigm():
            return
        self.usage_counts[self.current_index] += 1
        self._rotate()

    def _rotate(self) -> None:
        """Round-robin past exhausted paradigms (may stay put if only the current has uses left).

        Leaves ``current_index`` unchanged when every paradigm is exhausted or the set is empty.
        """
        n = len(self.active_paradigms)
        # offset == n wraps back to the current index, so the current paradigm is tried last.
        for offset in range(1, n + 1):
            j = (self.current_index + offset) % n
            if self.usage_counts[j] < self.max_uses:
                self.current_index = j
                return

    def set_paradigms(self, paradigms: list[dict], current_best_score: float) -> None:
        """Install a fresh set; archive outgoing *used* paradigms with outcome SUCCESS iff the
        best score rose > 0.001 during their tenure (bounded history of ``max_tried``).

        Every outgoing paradigm with at least one use gets the same outcome label, because the
        improvement is measured for the whole set's tenure, not per paradigm. Unused paradigms
        are dropped without being archived. Afterwards the usage counters, cursor and both
        score baselines are reset to the new set / ``current_best_score``.

        A non-positive ``max_tried`` keeps no history at all.
        """
        improvement = self.best_score_during - self.best_score_at_generation
        # The label depends only on the set-wide improvement, so compute it once.
        outcome = "SUCCESS" if improvement > 0.001 else "FAILED"
        for i, paradigm in enumerate(self.active_paradigms):
            if self.usage_counts[i] > 0:
                self.tried_paradigms.append(
                    f"{outcome}: {paradigm.get('approach_type', '?')} - {paradigm.get('idea', '?')}"
                    f" (improvement: {improvement:+.4f})")
        # A bare ``[-self.max_tried:]`` is wrong for max_tried == 0: ``[-0:]`` is ``[0:]`` and
        # keeps the entire list (and a negative bound would drop a prefix instead of bounding).
        if self.max_tried > 0:
            self.tried_paradigms = self.tried_paradigms[-self.max_tried:]
        else:
            self.tried_paradigms = []
        self.active_paradigms = list(paradigms)
        self.usage_counts = [0] * len(self.active_paradigms)
        self.current_index = 0
        self.best_score_at_generation = current_best_score
        self.best_score_during = current_best_score

    def get_previously_tried_ideas(self) -> list[str]:
        """Return a copy of the archived tried-idea lines (oldest first)."""
        return list(self.tried_paradigms)
