"""Best-of-N SelectionPolicy component — reuse one parent for N attempts, then switch to global best.

Faithful port of SkyDiscover's ``BestOfNDatabase``: the parent-reuse counter is advanced **only by a
validly-scored child** (SkyDiscover increments it in ``add()``, which runs only for non-error
results — parse/eval failures never get added and so never spend the budget). A failed attempt is a
free retry on the same parent. The attempt-counted sibling (``best_of_n_attempts``) instead spends
one unit of the budget per *selection*, valid or not — see :class:`BestOfNAttemptsPolicy`.
"""
from __future__ import annotations

from ...components.selection import SelectionPolicy
from ...records import Genome, RunState, Selection


def _safe_score(g: Genome) -> float:
    """SkyDiscover ``BestOfNDatabase.sample``'s ``safe_score``: ``combined_score`` if numeric, else
    ``-inf``. The parent argmax uses THIS (combined_score-only), whereas the inspiration top-pool ranks
    by ``Genome.fitness`` (combined_score-else-mean = SkyDiscover ``get_score``) — mirroring the two
    distinct SkyDiscover ranking functions exactly."""
    s = g.scores.get("combined_score")
    return float(s) if isinstance(s, (int, float)) and not isinstance(s, bool) else float("-inf")


class BestOfNPolicy(SelectionPolicy):
    """Port of ``BestOfNDatabase.sample``: reuse the same parent until ``n`` **valid** children have
    been produced from it, then commit to the current global best and repeat.

    The parent-reuse counter (``current_parent_id`` / ``uses``) is the policy's adaptive state, and
    it advances in :meth:`observe` (one per valid child), not in :meth:`select`. Inspirations are a
    random sample from the top pool, excluding the parent, drawn fresh each step regardless of the
    reuse cycle. All randomness flows through ``self.rng``.

    Args:
        seed: Forwarded to :class:`SelectionPolicy`. Presumably the base class uses it to seed
            ``self.rng``; confirm against the base class.
        n: Valid children per parent before re-committing to the global best. Coerced to ``int``
            and clamped to at least 1.
        num_inspirations: Inspirations drawn per selection. Coerced to ``int`` and clamped to at
            least 0.
    """

    def __init__(self, seed: int = 0, n: int = 5, num_inspirations: int = 4):
        super().__init__(seed)
        self.n = max(1, int(n))
        # Sanitised like ``n``. A negative count would otherwise reach ``rng.sample`` as a negative
        # ``k`` (ValueError), and a float from a config would make it raise TypeError.
        self.num_inspirations = max(0, int(num_inspirations))
        # Id of the parent currently being reused; None until the first commit.
        self.current_parent_id: str | None = None
        # Valid children counted against the current parent since the last commit.
        self.uses = 0

    def _choose_parent(self, members: list[Genome]) -> Genome:
        """Reuse the current parent, or commit to the global best once the budget is spent / it is
        gone. Folds SkyDiscover's ``current_parent_id not in programs`` guard into ``current is None``.

        Committing (re)sets ``current_parent_id`` and resets ``uses`` to 0. Ties and an
        all-unscored population resolve to the earliest member in ``members`` (``max`` keeps the
        first maximum).
        """
        current = next((g for g in members if g.id == self.current_parent_id), None)
        if current is None or self.uses >= self.n:
            parent = max(members, key=_safe_score)   # commit to the current global best (safe_score)
            self.current_parent_id = parent.id
            self.uses = 0
        else:
            parent = current                         # reuse for another attempt
        return parent

    def _inspirations(self, population, parent: Genome) -> list[Genome]:
        """Random sample from the top pool, excluding the parent, re-sampled each step.

        The pool is the top ``max(2 * num_inspirations, 10)`` genomes from ``population.query``.
        Returns at most ``num_inspirations`` genomes, fewer if the pool is small, and ``[]`` when
        none are available or requested.
        """
        pool = population.query({"top": max(2 * self.num_inspirations, 10)})
        candidates = [g for g in pool if g.id != parent.id]
        k = min(self.num_inspirations, len(candidates))
        return self.rng.sample(candidates, k) if k > 0 else []

    def select(self, population, state: RunState | None = None) -> Selection:
        """Pick the parent for this step (reused or newly committed) plus fresh inspirations.

        Raises:
            RuntimeError: if ``population.all()`` is empty.
        """
        members = population.all()
        if not members:
            raise RuntimeError("cannot select from an empty population")
        parent = self._choose_parent(members)
        return Selection(parent=parent, inspirations=self._inspirations(population, parent), pool=members)

    def observe(self, genome: Genome, state: RunState | None = None) -> None:
        """Spend one unit of the current parent's budget if ``genome`` is a valid child.

        A genome counts as valid unless its ``metadata["valid"]`` is falsy; a missing key counts
        as valid. The counter is not checked against the genome's parentage: every valid observed
        genome advances ``uses``.
        """
        # SkyDiscover advances the reuse counter in add(), which runs only for a validly-scored child.
        # Count one valid child against the current parent's budget; failures are free retries.
        if genome.metadata.get("valid", True):
            self.uses += 1
