"""Beam Search SelectionPolicy component — pick the parent from the beam by one of four strategies.

The sampling side of SkyDiscover's ``BeamSearchDatabase`` (``sample`` + ``_select_best`` /
``_select_stochastic`` / ``_select_round_robin`` / ``_select_diversity_weighted``). Reads the beam
maintained by :class:`~galapagos.scaffolds.beam_search.population.BeamPopulation` and shares its
fitness (``population.score``) and diversity metric (``_code_distance``).
"""
from __future__ import annotations

import math

from ...components.selection import SelectionPolicy
from ...records import Genome, RunState, Selection
from .population import _code_distance


class BeamSelectionPolicy(SelectionPolicy):
    """Draw the parent from the current beam by ``strategy``; inspirations are the global top
    programs, excluding the parent.

    Strategies: ``best`` (highest score), ``round_robin`` (cycle the beam in score order),
    ``stochastic`` (softmax over score at ``temperature``), ``diversity_weighted`` (softmax over
    ``(1-w)·score + w·avg_distance_to_recently_expanded``). Any other ``strategy`` value is treated
    like ``stochastic``. The recently-expanded parents are the policy's adaptive state; all
    randomness flows through ``self.rng``.

    NOTE(review): ``self.rng`` is not assigned in this class; presumably the ``SelectionPolicy``
    base initialiser creates it from ``seed`` — confirm against the base class.
    """

    # How many of the most recently expanded parents ``self.expanded`` retains. Must stay positive:
    # a value of 0 would make the trimming slice ``[-0:]`` keep the whole list.
    EXPANDED_HISTORY = 50
    # How many of those parents ``diversity_weighted`` measures each candidate against.
    DIVERSITY_LOOKBACK = 10

    def __init__(self, seed: int = 0, strategy: str = "diversity_weighted", temperature: float = 1.0,
                 diversity_weight: float = 0.3, num_inspirations: int = 4):
        """Store the strategy configuration and reset the adaptive state.

        Args:
            seed: forwarded unchanged to the base-class initialiser.
            strategy: name of the parent-picking strategy (see the class docstring).
            temperature: softmax temperature; a value ``<= 0`` makes softmax picks greedy.
            diversity_weight: the ``w`` in ``(1-w)·score + w·diversity``.
            num_inspirations: maximum number of inspirations returned by :meth:`select`.
        """
        super().__init__(seed)
        self.strategy = strategy
        self.temperature = temperature
        self.diversity_weight = diversity_weight
        self.num_inspirations = num_inspirations
        self.expanded: list[Genome] = []   # recently expanded parents (diversity_weighted lookback)
        self._rr = 0                       # round_robin cursor; incremented on every round_robin pick

    def select(self, population, state: RunState | None = None) -> Selection:
        """Pick a parent from the beam and gather the top-fitness inspirations.

        Args:
            population: object providing ``beam_members()``, ``best()``, ``all()`` and ``score(g)``.
            state: not read by this policy.

        Returns:
            A ``Selection`` with the chosen parent, up to ``num_inspirations`` inspirations (never
            the parent itself) and the full population as ``pool``.

        Raises:
            RuntimeError: if the beam is empty and ``population.best()`` returns ``None``.
        """
        beam = population.beam_members()
        if not beam:
            # No beam yet: fall back to a one-element beam holding the population's best genome.
            best = population.best()
            if best is None:
                raise RuntimeError("cannot select from an empty population")
            beam = [best]
        parent = self._pick(beam, population)
        # Record the expansion and keep only the newest EXPANDED_HISTORY entries (rebinding, not
        # mutating, so an earlier reference to the list is left untouched).
        self.expanded.append(parent)
        self.expanded = self.expanded[-self.EXPANDED_HISTORY:]
        # inspirations = global top-N by RAW combined_score. SkyDiscover's get_top_programs (context
        # programs) ranks by the base get_score with NO depth penalty — the depth penalty applies only
        # to beam pruning and parent selection (population.score). Rank by genome.fitness here so a
        # non-zero beam_depth_penalty never leaks into context selection. (Identical to ranking by
        # population.score when beam_depth_penalty == 0, the default.)
        everyone = population.all()
        # Take one extra candidate so that dropping the parent can still leave num_inspirations.
        top = sorted(everyone, key=lambda g: g.fitness, reverse=True)[: self.num_inspirations + 1]
        inspirations = [g for g in top if g.id != parent.id][: self.num_inspirations]
        return Selection(parent=parent, inspirations=inspirations, pool=everyone)

    def _pick(self, beam: list[Genome], population) -> Genome:
        """Return one member of the non-empty ``beam`` according to ``self.strategy``.

        A single-member beam short-circuits to that member regardless of strategy (and does not
        advance the round-robin cursor).
        """
        if len(beam) == 1 or self.strategy == "best":
            return max(beam, key=population.score)
        if self.strategy == "round_robin":
            # Re-sort on every call so the cycle follows the current score order of the beam.
            ordered = sorted(beam, key=population.score, reverse=True)
            g = ordered[self._rr % len(ordered)]
            self._rr += 1
            return g
        if self.strategy == "diversity_weighted":
            return self._diversity_weighted(beam, population)
        return self._softmax_pick(beam, [population.score(g) for g in beam])  # "stochastic" / unknown

    def _diversity_weighted(self, beam: list[Genome], population) -> Genome:
        """Softmax-pick over a blend of score and mean distance to recently expanded parents.

        Each candidate's value is ``(1-w)·population.score(g) + w·diversity`` where ``diversity`` is
        the mean ``_code_distance`` between the candidate's ``content`` and that of each of the last
        ``DIVERSITY_LOOKBACK`` expanded parents. With no expansion history yet, this is a plain
        softmax over score.
        """
        recent = self.expanded[-self.DIVERSITY_LOOKBACK:]
        if not recent:
            return self._softmax_pick(beam, [population.score(g) for g in beam])
        combined = []
        for g in beam:
            diversity = sum(_code_distance(g.content, r.content) for r in recent) / len(recent)
            combined.append((1 - self.diversity_weight) * population.score(g) + self.diversity_weight * diversity)
        return self._softmax_pick(beam, combined)

    def _softmax_pick(self, beam: list[Genome], scores: list[float]) -> Genome:
        """Sample one member of ``beam`` with probability proportional to ``exp(score / temperature)``.

        ``scores[i]`` is the value of ``beam[i]``. When ``temperature <= 0`` the first candidate with
        the maximal score is returned and no random number is drawn; otherwise exactly one value is
        drawn from ``self.rng.random()``.
        """
        if self.temperature <= 0:
            return beam[scores.index(max(scores))]
        # Subtract the maximum before exponentiating so every exponent is <= 0 and exp cannot overflow.
        peak = max(scores)
        weights = [math.exp((s - peak) / self.temperature) for s in scores]
        total = sum(weights) or 1.0
        threshold = self.rng.random() * total
        cumulative = 0.0
        # Used only if the running sum never exceeds the threshold (float rounding, or non-finite
        # weights): prefer the last candidate that actually carries weight over a blind beam[-1].
        fallback = beam[-1]
        for candidate, weight in zip(beam, weights):
            cumulative += weight
            # Strict comparison: random() may return exactly 0.0, and an inclusive test would then
            # select a leading candidate whose weight underflowed to zero.
            if threshold < cumulative:
                return candidate
            if weight > 0:
                fallback = candidate
        return fallback
