"""AdaEvolve Population component — per-island quality-diversity archives + ring migration +
dynamic island spawning.

One file per component (see scaffold.py). Faithful port of SkyDiscover's ``UnifiedArchive`` +
``CodeDiversity`` plus the island-management half of ``AdaEvolveDatabase`` (the adaptive state —
G signal, UCB bandit — lives in the SelectionPolicy; this store owns topology and admission only).

Admission mirrors the upstream eval-failure gate: children whose scores mark them invalid
(``validity`` exactly 0 or -1, or — when no ``validity`` key exists — the two
``SubprocessEvaluator`` hard-failure shapes: ``combined_score == 0.0`` with ``text_feedback``
starting ``"evaluator error:"`` or ``"timeout after"``) are rejected outright and update
*nothing*, exactly like upstream children that error out and are retried instead of added. A
legitimate 0.0-score child with ordinary diagnostic feedback IS admitted. ``add`` always stamps
``metadata["admitted"]`` (and ``metadata["eval_failed"]`` for gated children) so the policy's
``observe`` and the scaffold's ``after_step`` can read the verdict — the loop discards the return.

Determinism: this store draws no randomness of its own; every random choice (parent sampling,
preset tie-breaks) is driven by the policy's ``rng`` passed in explicitly.
"""
from __future__ import annotations

import functools
import re

from ...components.population import Population
from ...records import Genome

# ---------------------------------------------------------------------------------------------
# CodeDiversity  (skydiscover/search/adaevolve/archive/diversity.py)
# ---------------------------------------------------------------------------------------------

_TOKEN_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+\.?[0-9]*")
_IMPORT_RE = re.compile(r"^\s*(?:from\s+([\w.]+)|import\s+([\w.]+))", re.MULTILINE)
_DEF_RE = re.compile(r"\bdef\s+(\w+)")
_CLASS_RE = re.compile(r"\bclass\s+(\w+)")
_CALL_RE = re.compile(r"\b(\w+)\.(\w+)\s*\(")
_FLOW_KEYWORDS = ("for", "while", "try", "with", "yield", "async", "lambda")


@functools.lru_cache(maxsize=4096)
def _features(code: str) -> tuple[frozenset, frozenset, int]:
    """Token set (identifiers/numbers, len >= 2), structural features, and length — the three
    ingredients of ``CodeDiversity.distance``."""
    tokens = frozenset(t for t in _TOKEN_RE.findall(code) if len(t) >= 2)
    struct: set[str] = set()
    for m in _IMPORT_RE.finditer(code):
        struct.add(f"import:{m.group(1) or m.group(2)}")
    for m in _DEF_RE.finditer(code):
        struct.add(f"func:{m.group(1)}")
    for m in _CLASS_RE.finditer(code):
        struct.add(f"class:{m.group(1)}")
    for m in _CALL_RE.finditer(code):
        struct.add(f"call:{m.group(1)}.{m.group(2)}")
    for kw in _FLOW_KEYWORDS:  # control-flow pattern flags
        if re.search(rf"\b{kw}\b", code):
            struct.add(f"flow:{kw}")
    return tokens, frozenset(struct), len(code)


def _jaccard_distance(a: frozenset, b: frozenset) -> float:
    union = len(a | b)
    return 1.0 - len(a & b) / union if union else 0.0


def code_distance(code_a: str, code_b: str) -> float:
    """``CodeDiversity.distance`` between two source strings, a value in ``[0, 1]``.

    Weighted sum of three terms, each in ``[0, 1]``: ``0.5`` x Jaccard distance of the token
    sets, ``0.3`` x Jaccard distance of the structural-feature sets, and ``0.2`` x the relative
    length difference ``|len_a - len_b| / max(len_a, len_b, 1)``.
    """
    tokens_a, struct_a, size_a = _features(code_a)
    tokens_b, struct_b, size_b = _features(code_b)
    token_term = _jaccard_distance(tokens_a, tokens_b)
    struct_term = _jaccard_distance(struct_a, struct_b)
    size_term = abs(size_a - size_b) / max(size_a, size_b, 1)
    return 0.5 * token_term + 0.3 * struct_term + 0.2 * size_term


# ---------------------------------------------------------------------------------------------
# Heterogeneous island presets  (AdaEvolveDatabase.ISLAND_CONFIG_PRESETS — verbatim values)
# ---------------------------------------------------------------------------------------------

# Weight presets for dynamically spawned islands. Dicts of this shape are read by
# ``AdaEvolveArchipelago.spawn_island`` (keys: name / *_weight / elite_ratio), and
# ``AdaEvolveArchipelago.preset_usage`` counts islands per ``name`` listed here.
# In every preset pareto_weight + fitness_weight + novelty_weight == 1.0. ``elite_ratio`` is the
# fraction of an island's archive (top by elite score, at least one member) that is protected
# from crowding eviction.
ISLAND_CONFIG_PRESETS: list[dict] = [
    {"name": "balanced", "pareto_weight": 0.4, "fitness_weight": 0.3, "novelty_weight": 0.3,
     "elite_ratio": 0.2},
    {"name": "quality", "pareto_weight": 0.2, "fitness_weight": 0.6, "novelty_weight": 0.2,
     "elite_ratio": 0.3},
    {"name": "diversity", "pareto_weight": 0.3, "fitness_weight": 0.2, "novelty_weight": 0.5,
     "elite_ratio": 0.1},
    {"name": "pareto", "pareto_weight": 0.6, "fitness_weight": 0.2, "novelty_weight": 0.2,
     "elite_ratio": 0.2},
    {"name": "exploration", "pareto_weight": 0.2, "fitness_weight": 0.3, "novelty_weight": 0.5,
     "elite_ratio": 0.05},
]


# ---------------------------------------------------------------------------------------------
# UnifiedArchive  (skydiscover/search/adaevolve/archive/unified_archive.py) — one per island
# ---------------------------------------------------------------------------------------------


class _IslandArchive:
    """One island's quality-diversity archive (the scalar-mode ``UnifiedArchive``).

    Elite score = ``(fitness_weight + pareto_weight) * fitness_percentile + novelty_weight *
    novelty_percentile`` (Pareto weight folded into fitness in scalar mode — the upstream default
    1.0/0.4/0.0 makes it effectively pure fitness ranking). Novelty = mean code-distance to the
    ``k_neighbors`` nearest neighbours. At capacity, deterministic crowding: evict the member most
    similar to the newcomer among non-protected members iff the newcomer's elite score is
    *strictly* higher; protected = top ``max(1, n*elite_ratio)`` by elite score plus the
    fitness-best. The comparison is the reference's exact two-sided rule: the newcomer is scored
    with ``_compute_elite_score_for_new`` (against the EXISTING members only — strictly-fitter /
    strictly-lower-novelty counts), the victim with the cached archive-only ``_compute_elite_score``
    over the n members (newcomer excluded). A child marginally fitter than its most-similar member
    with nothing in between therefore TIES and is rejected, exactly like upstream.

    The archive draws no randomness of its own: the sampling methods take the caller's ``rng``.
    """

    def __init__(self, capacity: int = 20, k_neighbors: int = 5, elite_ratio: float = 0.2,
                 fitness_weight: float = 1.0, novelty_weight: float = 0.0,
                 pareto_weight: float = 0.4, preset_name: str = "balanced"):
        """Create an empty archive holding at most ``max(1, capacity)`` genomes.

        ``preset_name`` is bookkeeping only (read by the archipelago's preset usage counter); the
        numeric weights passed here are what the elite score actually uses.
        """
        self.capacity = max(1, int(capacity))
        self.k_neighbors = int(k_neighbors)
        self.elite_ratio = float(elite_ratio)
        self.fitness_weight = float(fitness_weight)
        self.novelty_weight = float(novelty_weight)
        self.pareto_weight = float(pareto_weight)
        self.preset_name = preset_name
        self.members: list[Genome] = []

    # ---- novelty + elite score ---------------------------------------------------------------
    def _novelty(self, genome: Genome, pool: list[Genome]) -> float:
        """Mean distance to the k nearest neighbours (``_compute_novelty``); 1.0 if alone."""
        dists = sorted(code_distance(genome.content, o.content) for o in pool if o.id != genome.id)
        if not dists:
            return 1.0
        k = min(self.k_neighbors, len(dists))
        return sum(dists[:k]) / k

    def _member_novelties(self) -> dict[str, float]:
        """k-NN novelty of every current member against the current members, keyed by id.

        O(n^2) ``code_distance`` calls — callers that need it more than once should compute it
        once and pass it along (see :meth:`add`)."""
        return {g.id: self._novelty(g, self.members) for g in self.members}

    def _member_elite_scores(self, novelty: dict[str, float] | None = None) -> dict[str, float]:
        """``_compute_elite_score`` for every CURRENT member — the reference's cached archive-only
        scores (newcomer excluded): fitness percentile = ``1 - rank/max(n-1, 1)``; novelty
        percentile = fraction of the archive's novelties strictly lower (denominator n).

        ``novelty`` may be a precomputed :meth:`_member_novelties` result; it is recomputed when
        omitted."""
        n = len(self.members)
        denom = max(n - 1, 1)
        by_fitness = sorted(self.members, key=lambda g: g.fitness, reverse=True)
        fitness_pct = {g.id: 1.0 - i / denom for i, g in enumerate(by_fitness)}
        if novelty is None:
            novelty = self._member_novelties()
        all_novelties = list(novelty.values())
        novelty_pct = {
            g.id: sum(1 for v in all_novelties if v < novelty[g.id]) / max(n, 1)
            for g in self.members}
        w_fit = self.fitness_weight + self.pareto_weight  # scalar mode folds pareto into fitness
        return {g.id: w_fit * fitness_pct[g.id] + self.novelty_weight * novelty_pct[g.id]
                for g in self.members}

    def _elite_score_for_new(self, genome: Genome,
                             member_novelty: dict[str, float] | None = None) -> float:
        """``_compute_elite_score_for_new``: the newcomer scored against the EXISTING members only.
        fitness percentile = ``1 - (#members strictly fitter)/max(n-1, 1)`` clamped to [0, 1];
        novelty percentile = ``(#members with strictly lower novelty)/n``.

        ``member_novelty`` may be a precomputed :meth:`_member_novelties` result; the members'
        novelties are recomputed when omitted."""
        n = len(self.members)
        w_fit = self.fitness_weight + self.pareto_weight
        if n == 0:
            return w_fit + self.novelty_weight  # max score (all components = 1.0)
        better = sum(1 for g in self.members if g.fitness > genome.fitness)
        fitness_pct = max(0.0, min(1.0, 1.0 - better / max(n - 1, 1)))
        novelty = self._novelty(genome, self.members)
        if member_novelty is None:
            existing = [self._novelty(g, self.members) for g in self.members]
        else:
            existing = [member_novelty[g.id] for g in self.members]
        novelty_pct = max(0.0, min(1.0, sum(1 for v in existing if v < novelty) / n))
        return w_fit * fitness_pct + self.novelty_weight * novelty_pct

    # ---- admission -----------------------------------------------------------------------------
    def add(self, genome: Genome) -> tuple[bool, Genome | None]:
        """``UnifiedArchive.add``: insert under capacity; at capacity replace the most-similar
        non-protected member iff the newcomer's ``_compute_elite_score_for_new`` strictly beats the
        victim's archive-only elite score, else reject. Returns ``(admitted, evicted_or_None)``."""
        if len(self.members) < self.capacity:
            self.members.append(genome)
            return True, None
        # The members' k-NN novelties feed both the victim's archive-only score and the
        # newcomer's score; computing them once halves the O(n^2) distance work per admission.
        member_novelty = self._member_novelties()
        member_scores = self._member_elite_scores(member_novelty)
        victim = self._eviction_candidate(genome, member_scores)
        if victim is None:
            return False, None  # every member is protected
        if self._elite_score_for_new(genome, member_novelty) > member_scores[victim.id]:
            self.members.remove(victim)
            self.members.append(genome)
            return True, victim
        return False, None

    def _eviction_candidate(self, genome: Genome, scores: dict[str, float]) -> Genome | None:
        """``_find_eviction_candidate``: the member *most similar to the newcomer* (deterministic
        crowding) among the non-protected set, or ``None`` if every member is protected."""
        n = len(self.members)
        by_elite = sorted(self.members, key=lambda g: scores[g.id], reverse=True)
        protected = {g.id for g in by_elite[: max(1, int(n * self.elite_ratio))]}
        protected.add(max(self.members, key=lambda g: g.fitness).id)  # fitness-best always survives
        candidates = [g for g in self.members if g.id not in protected]
        if not candidates:
            return None
        return min(candidates, key=lambda g: code_distance(genome.content, g.content))

    # ---- sampling views (driven by the policy's rng) -------------------------------------------
    def get_top_programs(self, n: int | None = None) -> list[Genome]:
        """``get_top_programs``: sorted by fitness; default n = top ~20%, capped at 10, floor 1."""
        ordered = sorted(self.members, key=lambda g: g.fitness, reverse=True)
        if n is None:
            n = max(1, min(10, len(ordered) // 5))
        return ordered[:n]

    def sample_exploitation(self, rng) -> Genome:
        """Uniform draw among :meth:`get_top_programs` (the balanced coin's exploit branch)."""
        return rng.choice(self.get_top_programs())

    def sample_top_quartile(self, rng) -> Genome:
        """``_sample_top``: uniform among the top ``max(1, n//4)`` by fitness — the
        exploitation-MODE parent pool (distinct from ``get_top_programs``, which only feeds the
        balanced coin's exploit branch, as in the reference)."""
        ordered = sorted(self.members, key=lambda g: g.fitness, reverse=True)
        return rng.choice(ordered[: max(1, len(ordered) // 4)])

    def sample_exploration(self, rng) -> Genome:
        """Roulette proportional to novelty (``sample_parent("exploration")``); each weight is
        floored at 0.001 so no member has zero probability."""
        weights = [max(self._novelty(g, self.members), 0.001) for g in self.members]
        total = sum(weights)
        r = rng.random() * total
        upto = 0.0
        for g, w in zip(self.members, weights):
            upto += w
            if upto >= r:
                return g
        return self.members[-1]  # float round-off guard

    def sample_other_context_programs(self, parent: Genome, n: int) -> list[Genome]:
        """``sample_other_context_programs``: from the top half of the archive (at least 2n) by
        fitness, return the n programs *most distant from the parent*."""
        pool_size = max(2 * n, len(self.members) // 2)
        ordered = sorted(self.members, key=lambda g: g.fitness, reverse=True)
        candidates = [g for g in ordered[:pool_size] if g.id != parent.id]
        candidates.sort(key=lambda g: code_distance(parent.content, g.content), reverse=True)
        return candidates[:n]


# ---------------------------------------------------------------------------------------------
# Archipelago  (the island-management half of AdaEvolveDatabase)
# ---------------------------------------------------------------------------------------------


class AdaEvolveArchipelago(Population):
    """``num_islands`` quality-diversity :class:`_IslandArchive` islands with ring migration,
    dynamic spawning, and a genealogy map for sibling context.

    Island routing: a genome's island lives in ``metadata["island"]``; since ``Genome.child``
    copies metadata, children inherit the parent's island automatically (the policy re-stamps it
    each select). The very first ``add`` (the task seed) seeds *all* islands with copies — the
    open mirror of upstream's seed-every-island-with-``p0`` bootstrap.
    """

    def __init__(self, num_islands: int = 2, population_size: int = 20, k_neighbors: int = 5,
                 archive_elite_ratio: float = 0.2, fitness_weight: float = 1.0,
                 novelty_weight: float = 0.0, pareto_weight: float = 0.4,
                 migration_count: int = 5):
        """Create ``max(1, num_islands)`` empty islands that all use the config weights.

        ``population_size`` is each island's archive capacity (also used for spawned islands);
        ``migration_count`` is how many top programs each island offers its ring neighbour per
        :meth:`migrate`.
        """
        self.population_size = int(population_size)
        self.k_neighbors = int(k_neighbors)
        self.migration_count = int(migration_count)
        # initial islands use the *config* weights and count as "balanced" for preset bookkeeping
        self.archives: list[_IslandArchive] = [
            _IslandArchive(capacity=population_size, k_neighbors=k_neighbors,
                           elite_ratio=archive_elite_ratio, fitness_weight=fitness_weight,
                           novelty_weight=novelty_weight, pareto_weight=pareto_weight)
            for _ in range(max(1, int(num_islands)))
        ]
        self._children: dict[str, list[Genome]] = {}  # genealogy: parent id -> admitted children

    @property
    def num_islands(self) -> int:
        """Current island count (grows by one per :meth:`spawn_island`)."""
        return len(self.archives)

    # ---- admission ------------------------------------------------------------------------------
    @staticmethod
    def _is_eval_failure(genome: Genome) -> bool:
        """The open mirror of the upstream eval-failure gate (``validity in (0, -1)`` OR a zero
        ``combined_score`` with an error present): such children are retried upstream, never added.
        Without a ``validity`` key the fallback fires ONLY on the two ``SubprocessEvaluator``
        hard-failure shapes (exec error / timeout) — a legitimate 0.0-score child with ordinary
        diagnostic feedback is admitted. A ``validity`` that cannot be converted to float counts
        as a failure."""
        validity = genome.scores.get("validity")
        if validity is not None:
            try:
                return float(validity) in (0.0, -1.0)  # reference: metrics["validity"] in (0, -1)
            except (TypeError, ValueError):
                return True
        # SubprocessEvaluator hard failures emit exactly combined_score=0.0 + a text_feedback
        # starting "evaluator error:" or "timeout after" (components/evaluator.py:69-76).
        feedback = genome.artifacts.get("text_feedback")
        return (genome.fitness == 0.0 and isinstance(feedback, str)
                and feedback.startswith(("evaluator error:", "timeout after")))

    def add(self, genome: Genome) -> bool:
        """Route ``genome`` to its island's archive; return ``True`` iff it was admitted.

        Always stamps ``metadata["admitted"]``; gated eval failures additionally get
        ``metadata["eval_failed"] = True`` and change no archive or genealogy state. The first
        call on an empty archipelago seeds every island (island 0 receives ``genome`` itself, the
        others ``seed_copy`` clones) and is admitted without going through the eval-failure gate.
        A missing, non-int or out-of-range ``metadata["island"]`` routes to island 0.
        """
        # strip inherited bookkeeping flags (Genome.child copies the parent's metadata wholesale)
        for stale in ("admitted", "eval_failed", "migrated_from", "migrated_to", "seed_copy"):
            genome.metadata.pop(stale, None)

        # bootstrap: the very first add seeds EVERY island with a copy of the seed program
        if not any(a.members for a in self.archives):
            genome.metadata.update(island=0, admitted=True)
            self.archives[0].members.append(genome)
            for i in range(1, self.num_islands):
                copy = Genome(content=genome.content, scores=dict(genome.scores),
                              metadata={"island": i, "admitted": True, "seed_copy": True},
                              artifacts=dict(genome.artifacts))
                self.archives[i].members.append(copy)
            return True

        # eval-failure gate: rejected children update nothing (no archive slot, no G, no UCB)
        if self._is_eval_failure(genome):
            genome.metadata.update(admitted=False, eval_failed=True)
            return False

        isl = genome.metadata.get("island")
        if not isinstance(isl, int) or not (0 <= isl < self.num_islands):
            isl = 0
        genome.metadata["island"] = isl
        admitted, evicted = self.archives[isl].add(genome)
        genome.metadata["admitted"] = admitted
        if evicted is not None:
            self._forget(evicted)
        if admitted and genome.parent_id:  # genealogy tracked only on successful add
            self._children.setdefault(genome.parent_id, []).append(genome)
        return admitted

    def _forget(self, genome: Genome) -> None:
        """Genealogy cleanup on eviction (mirrors the archive's evict-time bookkeeping): drop the
        evicted genome's own child list and remove it from its parent's list, deleting that list
        once it is empty so the map does not accumulate dead keys."""
        self._children.pop(genome.id, None)
        if not genome.parent_id:
            return
        siblings = self._children.get(genome.parent_id)
        if siblings is None:
            return
        remaining = [g for g in siblings if g.id != genome.id]
        if remaining:
            self._children[genome.parent_id] = remaining
        else:
            del self._children[genome.parent_id]

    # ---- queries ---------------------------------------------------------------------------------
    def query(self, spec: dict | None = None) -> list[Genome]:
        """Sorted-by-fitness views: ``{"island": i}`` → that island (index taken modulo the island
        count); default → all islands; ``{"top": n, ...}`` truncates to the best ``n``."""
        spec = spec or {}
        if "island" in spec:
            members = list(self.archives[spec["island"] % self.num_islands].members)
        else:
            members = self.all()
        members.sort(key=lambda g: g.fitness, reverse=True)
        top = spec.get("top")
        # ``is None`` rather than truthiness, so an explicit {"top": 0} yields [] instead of
        # silently returning the whole population
        return members if top is None else members[:top]

    def all(self) -> list[Genome]:
        """Every member of every island, island by island (a fresh list)."""
        return [g for archive in self.archives for g in archive.members]

    def island_top(self, island: int) -> list[Genome]:
        """The island's default ``get_top_programs`` view (index taken modulo the island count)."""
        return self.archives[island % self.num_islands].get_top_programs()

    def sample_parent(self, island: int, mode: str, rng) -> Genome:
        """Per-mode parent draw: exploitation → uniform among the top QUARTILE (reference
        ``_sample_top``); exploration → novelty-proportional roulette; balanced → 50/50 coin
        between ``get_top_programs`` (top ~20% capped 10) and the roulette — two distinct
        exploit pools, exactly as in the reference."""
        archive = self.archives[island % self.num_islands]
        if mode == "exploitation":
            return archive.sample_top_quartile(rng)
        if mode == "exploration":
            return archive.sample_exploration(rng)
        if rng.random() < 0.5:  # balanced: coin flip between the two
            return archive.sample_exploitation(rng)
        return archive.sample_exploration(rng)

    def context_candidates(self, island: int, parent: Genome, n: int) -> list[Genome]:
        """Up to ``n`` strong programs on ``island`` that are most distant from ``parent``."""
        return self.archives[island % self.num_islands].sample_other_context_programs(parent, n)

    def global_top(self, n: int, exclude: set[str] | None = None) -> list[Genome]:
        """Top fitness across ALL islands (``_sample_global_top``), excluding given ids."""
        exclude = exclude or set()
        members = [g for g in self.all() if g.id not in exclude]
        members.sort(key=lambda g: g.fitness, reverse=True)
        return members[:n]

    def children_of(self, parent_id: str, limit: int = 5, island: int | None = None) -> list[Genome]:
        """``get_children``: the most recent ``limit`` admitted children of a parent. When
        ``island`` is given, only children admitted on THAT island are returned (the reference
        consults the current island's archive genealogy only). ``limit <= 0`` returns ``[]``."""
        if limit <= 0:
            return []  # guard: ``kids[-0:]`` would be the WHOLE list, not an empty one
        kids = self._children.get(parent_id, [])
        if island is not None:
            kids = [g for g in kids if g.metadata.get("island") == island]
        return kids[-limit:]  # slicing copies, so callers never alias the genealogy list

    def preset_usage(self) -> dict[str, int]:
        """How many islands use each of the five presets (initial islands count as balanced)."""
        usage = {p["name"]: 0 for p in ISLAND_CONFIG_PRESETS}
        for archive in self.archives:
            if archive.preset_name in usage:
                usage[archive.preset_name] += 1
        return usage

    # ---- topology changes (called by the policy at end of iteration) -----------------------------
    def spawn_island(self, preset: dict) -> int:
        """``_spawn_island``: a fresh archive configured by ``preset`` (missing keys default to the
        "balanced" values), seeded with copies of the top ``min(5, n)`` programs by fitness across
        all existing islands — additionally capped at the new archive's capacity so the seeding
        never leaves it over-full. Returns the new island's index."""
        new_idx = self.num_islands
        archive = _IslandArchive(capacity=self.population_size, k_neighbors=self.k_neighbors,
                                 elite_ratio=float(preset.get("elite_ratio", 0.2)),
                                 fitness_weight=float(preset.get("fitness_weight", 0.3)),
                                 novelty_weight=float(preset.get("novelty_weight", 0.3)),
                                 pareto_weight=float(preset.get("pareto_weight", 0.4)),
                                 preset_name=str(preset.get("name", "balanced")))
        # seeds are appended directly (bypassing ``add``), so respect capacity here: an archive
        # holding more than ``capacity`` members would never shrink back
        for prog in self.global_top(min(5, archive.capacity)):
            copy = Genome(content=prog.content, parent_id=prog.id, lineage=prog.lineage,
                          scores=dict(prog.scores),
                          metadata={"island": new_idx, "admitted": True, "seed_copy": True},
                          artifacts=dict(prog.artifacts))
            archive.members.append(copy)
        self.archives.append(archive)
        return new_idx

    def migrate(self) -> list[Genome]:
        """Ring migration (``_migrate_archives``): each island copies its top ``migration_count``
        programs to ``(k+1) mod K``, skipping exact-duplicate solutions already at the destination.
        Migrants are copies with new ids + ``migrated_from``/``migrated_to`` metadata; the archive
        may still reject them. Returns the *admitted* migrants so the policy can update each
        receiving island's adaptive state (``receive_external_improvement``) — never UCB stats."""
        if self.num_islands < 2:
            return []
        admitted_migrants: list[Genome] = []
        # each source's top list is computed INSIDE the loop, AFTER earlier sources' migrants were
        # admitted — the reference's sequential semantics, where a strong program can hop multiple
        # islands within one migration event (src 0→1, then on through 1→2)
        for src in range(self.num_islands):
            programs = self.archives[src].get_top_programs(self.migration_count)
            dest = (src + 1) % self.num_islands
            dest_archive = self.archives[dest]
            for prog in programs:
                # checked against the LIVE member list, so a migrant admitted earlier in this
                # loop also blocks a same-content duplicate from the same source
                if any(m.content == prog.content for m in dest_archive.members):
                    continue  # skip exact-duplicate solutions
                migrant = Genome(content=prog.content, parent_id=prog.id, lineage=prog.lineage,
                                 scores=dict(prog.scores),
                                 metadata={**prog.metadata, "island": dest, "admitted": True,
                                           "migrated_from": src, "migrated_to": dest},
                                 artifacts=dict(prog.artifacts))
                admitted, evicted = dest_archive.add(migrant)
                migrant.metadata["admitted"] = admitted
                if evicted is not None:
                    self._forget(evicted)
                if admitted:
                    admitted_migrants.append(migrant)
        return admitted_migrants
