"""Beam Search Population component — a fixed-width beam over a keep-all store, with diversity-aware
pruning and search-tree depth tracking.

One file per component (see scaffold.py). Faithful port of the admission side of SkyDiscover's
``BeamSearchDatabase``: ``add`` / ``_update_beam`` / ``_prune_beam`` / ``_diverse_selection`` /
``_get_program_score``. The shared fitness notion (``score``) and the diversity metric
(``_code_distance``) live here and are imported by the SelectionPolicy — mirroring how the OpenEvolve
scaffold keeps ``_fitness`` in its population module.
"""
from __future__ import annotations

import math

from ...components.population import Population
from ...records import Genome


def _code_distance(a: str, b: str, n: int = 3) -> float:
    """Normalised character n-gram Jaccard *distance* in ``[0, 1]`` (1 == completely different).

    SkyDiscover's ``_solution_distance``: a cheap, embedding-free diversity signal between programs.
    """
    if not a or not b:
        return 1.0
    ga = {a[i : i + n] for i in range(len(a) - n + 1)}
    gb = {b[i : i + n] for i in range(len(b) - n + 1)}
    if not ga and not gb:
        return 0.0
    union = len(ga | gb)
    return 1.0 - (len(ga & gb) / union if union else 0.0)


class BeamPopulation(Population):
    """``BeamSearchDatabase`` store: a keep-all dict plus a fixed-width ``beam`` subset of ids.

    Every program is retained (so inspirations can be drawn from the global top), and a beam of at
    most ``beam_width`` ids is maintained: on each add the program joins the beam, and an over-full
    beam is pruned to the best members — by fitness, or by a greedy fitness+diversity objective when
    ``diversity_weight > 0`` (``_prune_beam`` / ``_diverse_selection``). Search-tree ``depth`` (hops
    from the seed along parent links) is tracked and can be penalised in the fitness used for pruning
    and selection via ``depth_penalty`` (``_get_program_score``).

    Attributes:
        beam_width: Maximum number of ids kept in ``beam`` (always at least 1).
        diversity_weight: Weight ``w`` of the diversity term in the pruning objective; ``<= 0``
            switches pruning to plain fitness ranking.
        depth_penalty: Rate of the exponential depth discount applied by ``score``; ``<= 0``
            disables it.
        beam: Ids currently in the beam.
        depth: Recorded search-tree depth per stored genome id.
    """

    def __init__(self, beam_width: int = 5, diversity_weight: float = 0.3, depth_penalty: float = 0.0):
        """Create an empty population.

        Args:
            beam_width: Desired beam size; coerced with ``int`` and clamped to a minimum of 1.
            diversity_weight: Weight of the diversity term used when pruning the beam.
            depth_penalty: Exponential depth-discount rate used by ``score``.
        """
        self.beam_width = max(1, int(beam_width))
        self.diversity_weight = diversity_weight
        self.depth_penalty = depth_penalty
        self._members: dict[str, Genome] = {}   # keep-all store: genome id -> genome
        self.beam: list[str] = []
        self.depth: dict[str, int] = {}

    # ---- shared fitness  (``_get_program_score``) ---------------------------
    def score(self, genome: Genome) -> float:
        """Fitness (``combined_score``) with an optional exponential depth penalty.

        A fitness of ``-inf`` is treated as ``0.0``. When ``depth_penalty > 0`` the value is
        multiplied by ``exp(-depth_penalty * depth)``, where ``depth`` is the recorded depth of
        ``genome.id`` (``0`` if the id has no recorded depth).
        """
        base = genome.fitness
        if base == float("-inf"):
            base = 0.0
        if self.depth_penalty > 0:
            base *= math.exp(-self.depth_penalty * self.depth.get(genome.id, 0))
        return base

    # ---- admission + beam maintenance  (``add`` / ``_update_beam``) ---------
    def add(self, genome: Genome) -> bool:
        """Store ``genome``, record its depth, put it in the beam and prune the beam if over-full.

        Returns:
            ``False`` if the genome was rejected because ``metadata['valid'] is False`` (nothing is
            stored in that case); ``True`` for every stored genome, whether or not it survives
            beam pruning.
        """
        # SkyDiscover's default controller validity-gates each child and NEVER calls database.add for
        # a failed generation (default_discovery_controller._run_iteration → _process_iteration_result
        # returns early on result.error). Mirror that: drop an eval-invalid child entirely so it never
        # pollutes the keep-all store, the beam, the inspirations, or top-N ranking. Keyed on `is
        # False` (not falsy) so the seed — whose metadata never carries 'valid' — is exempt, matching
        # the ungated seed/from-scratch path.
        if genome.metadata.get("valid") is False:
            return False
        self._members[genome.id] = genome
        parent = genome.parent_id
        # One hop deeper than a known parent; anything whose parent has no recorded depth is a root.
        self.depth[genome.id] = self.depth[parent] + 1 if parent in self.depth else 0
        genome.metadata["depth"] = self.depth[genome.id]
        if genome.id not in self.beam:
            self.beam.append(genome.id)
        if len(self.beam) > self.beam_width:
            self._prune()
        # SkyDiscover stores EVERY successfully-evaluated child (keep-all) and tracks best / recent
        # over all of them, independent of beam membership — only the *parent frontier* is the bounded
        # beam (get_best_program reads the keep-all store, not the beam). So report admission for every
        # stored (valid) child, not just beam survivors: this lets the base loop record it in
        # state.recent (Previous-Attempts/Focus sourcing then matches get_statistics' recent_programs,
        # not just beam survivors) and promote it to best. Best-tracking stays correct because the base
        # loop holds state.best by reference and updates it on child.fitness > prev_best, independent of
        # beam membership (mirroring get_best_program reading the keep-all store).
        return True

    def _prune(self) -> None:
        """Shrink ``beam`` to at most ``beam_width`` ids (``_prune_beam``).

        Uses the greedy fitness+diversity selection when ``diversity_weight > 0``, otherwise keeps
        the ``beam_width`` highest-``score`` members.
        """
        items = [self._members[i] for i in self.beam if i in self._members]
        if self.diversity_weight > 0:
            self.beam = self._diverse(items, self.beam_width)
        else:
            items.sort(key=self.score, reverse=True)
            self.beam = [g.id for g in items[: self.beam_width]]

    def _diverse(self, candidates: list[Genome], k: int) -> list[str]:
        """Greedy max ``(1-w)·fitness + w·min_distance_to_selected`` selection (``_diverse_selection``).

        The highest-``score`` candidate is always taken first; each further pick maximises the
        weighted sum of its score and its ``_code_distance`` to the closest already-selected member.
        Ties go to the candidate ranked higher by score. If there are no more than ``k`` candidates
        they are all returned, in input order.

        Returns:
            The ids of the selected genomes, in selection order.
        """
        if len(candidates) <= k:
            return [g.id for g in candidates]
        remaining = sorted(candidates, key=self.score, reverse=True)
        selected = [remaining.pop(0)]                        # always keep the best
        if len(selected) >= k:
            return [g.id for g in selected]

        # Scores do not change during a prune, so evaluate each candidate once. ``nearest[i]`` is the
        # distance from remaining[i] to its closest selected member; it is refreshed against only
        # the newest pick instead of being recomputed over the whole selected set every round.
        fitness = [self.score(g) for g in remaining]
        nearest = [_code_distance(g.content, selected[0].content) for g in remaining]
        weight = self.diversity_weight
        while len(selected) < k and remaining:
            best_i, best_val = 0, float("-inf")
            for i, (fit, dist) in enumerate(zip(fitness, nearest)):
                combined = (1 - weight) * fit + weight * dist
                if combined > best_val:
                    best_val, best_i = combined, i
            picked = remaining.pop(best_i)
            del fitness[best_i]
            del nearest[best_i]
            selected.append(picked)
            if len(selected) < k:
                nearest = [
                    min(dist, _code_distance(g.content, picked.content))
                    for g, dist in zip(remaining, nearest)
                ]
        return [g.id for g in selected]

    # ---- queries ------------------------------------------------------------
    def beam_members(self) -> list[Genome]:
        """Return the genomes currently in the beam, best ``score`` first."""
        return sorted((self._members[i] for i in self.beam if i in self._members),
                      key=self.score, reverse=True)

    def query(self, spec: dict | None = None) -> list[Genome]:
        """Return stored genomes ranked by ``score``, best first.

        Recognised ``spec`` keys:
            ``beam``: if truthy, rank only the current beam members instead of the whole store.
            ``top``: if truthy, truncate the result to its first ``top`` entries; a missing, ``None``
                or ``0`` value means no limit.
        """
        spec = spec or {}
        members = self.beam_members() if spec.get("beam") else \
            sorted(self._members.values(), key=self.score, reverse=True)
        top = spec.get("top")
        return members[:top] if top else members

    def all(self) -> list[Genome]:
        """Return every stored genome, in insertion order of the keep-all store."""
        return list(self._members.values())
