"""OpenEvolve Population component — per-island MAP-Elites grids over the behaviour
descriptors (complexity, diversity) + a global elite archive + lazy ring migration.

One file per component (see scaffold.py). Faithful port of OpenEvolve ProgramDatabase.
"""
from __future__ import annotations

import math

from ...components.population import Population
from ...records import Genome, RunState

def _fitness(genome: Genome, feature_dimensions: list[str]) -> float:
    """Fitness = ``combined_score`` if present, else the mean of numeric metrics EXCLUDING the
    MAP-Elites feature dimensions (so the QD descriptors never pollute the fitness comparison).

    This is the open mirror of ``get_fitness_score``: it diverges from the shared ``Genome.fitness`` by
    (a) the feature-dimension exclusion above, and (b) the empty floor — when no numeric metric exists
    at all this returns ``0.0`` (matching OpenEvolve ``safe_numeric_average`` / ``get_fitness_score``),
    whereas ``Genome.fitness`` floors at ``-inf`` (SkyDiscover "un-scored = worst" semantics). This is
    openevolve-local on purpose: it keeps openevolve's MAP-Elites cell competition / archive ranking
    faithful to OpenEvolve without changing the shared fitness used by topk/best_of_n.
    """
    scores = genome.scores
    if "combined_score" in scores:
        try:
            return float(scores["combined_score"])
        except (TypeError, ValueError):
            pass
    nums = [float(v) for k, v in scores.items()
            if k not in feature_dimensions and isinstance(v, (int, float)) and not isinstance(v, bool)]
    if not nums:
        nums = [float(v) for v in scores.values()
                if isinstance(v, (int, float)) and not isinstance(v, bool)]
    return sum(nums) / len(nums) if nums else 0.0   # OpenEvolve get_fitness_score floor (NOT -inf)


# ---------------------------------------------------------------------------------------------
# Population — MAP-Elites over islands  (openevolve/database.py::ProgramDatabase)
# ---------------------------------------------------------------------------------------------


class MapElitesIslandsPopulation(Population):
    """``ProgramDatabase``: ``num_islands`` islands, each a MAP-Elites grid over
    ``feature_dimensions``, plus a global elite archive and ring migration.

    Differs from the generic :class:`~galapagos.components.population.IslandPopulation` in three
    faithfulness-load-bearing ways:

    * **All programs are retained** (``self.programs``) — not only one elite per cell. The
      exploitation tier samples from a global ``archive`` and the random tier samples from every
      program, so the store cannot be a pure one-per-cell grid.
    * **Feature scaling is global running min/max** over the raw descriptor values (code length and
      a cheap length/line-count/character-set distance to a reference sample of programs), exactly
      like ``_calculate_feature_coords`` → ``_scale_feature_value`` (``minmax``), rather than fixed
      bucket boundaries.
    * **Migration is lazy on generation counters** — ``max(island_generations) -
      last_migration_generation >= migration_interval`` — and copies the top ``migration_rate`` of
      each island to its TWO ring neighbours ``(i±1) % n``, marking migrants so they never re-migrate.
    """

    def __init__(self, num_islands: int = 5, archive_size: int = 100, population_size: int = 1000,
                 feature_dimensions: list[str] | None = None, feature_bins: int = 10,
                 migration_interval: int = 50, migration_rate: float = 0.1,
                 diversity_reference_size: int = 20):
        self.num_islands = num_islands
        self.archive_size = archive_size
        self.population_size = population_size
        self.feature_dimensions = list(feature_dimensions or ["complexity", "diversity"])
        # OpenEvolve floors the bin count so the grid can hold the archive (database.py __init__):
        # feature_bins = max(configured, ceil(archive_size ** (1/num_dims))).
        ndims = max(1, len(self.feature_dimensions))
        self.feature_bins = max(feature_bins, int(archive_size ** (1.0 / ndims) + 0.99))
        self.migration_interval = migration_interval
        self.migration_rate = migration_rate
        self.diversity_reference_size = diversity_reference_size

        # core stores (mirror ProgramDatabase fields)
        self.programs: dict[str, Genome] = {}                      # id -> genome, every stored program
        self.island_feature_maps: list[dict[str, str]] = [dict() for _ in range(num_islands)]  # cell key -> elite id
        self.islands: list[set[str]] = [set() for _ in range(num_islands)]  # ids per island
        self.archive: set[str] = set()                             # global elite archive (ids)
        self.best_program_id: str | None = None
        self.island_best_programs: list[str | None] = [None] * num_islands

        # island scheduling / migration state
        self.current_island = 0
        self.island_generations = [0] * num_islands
        self.last_migration_generation = 0

        # feature scaling state: per-dimension running min/max  (minmax method)
        self.feature_stats: dict[str, dict[str, float]] = {}
        # diversity reference set (a small sample of program codes), as in the reference
        self.diversity_reference: list[str] = []

        self._rr = -1  # round-robin fallback when a genome has no island (not read by this class)

    # ---- admission --------------------------------------------------------------------------
    def add(self, genome: Genome) -> bool:
        """Store ``genome`` and enter it into its island's MAP-Elites cell competition.

        Island routing: a valid int ``metadata["island"]`` wins; otherwise the stored parent's int
        island; otherwise ``current_island``. The result is reduced modulo ``num_islands``. The
        resolved island and the cell coordinates are written back to ``genome.metadata``
        (``"island"`` / ``"cell"``).

        Returns ``True`` always: losing a cell is not rejection (OpenEvolve ``database.add``).
        """
        self.programs[genome.id] = genome

        # route to island: explicit metadata > inherit parent's island > current island
        isl = genome.metadata.get("island")
        if not isinstance(isl, int) or not (0 <= isl < self.num_islands):
            parent = self.programs.get(genome.parent_id) if genome.parent_id else None
            if parent is not None and isinstance(parent.metadata.get("island"), int):
                isl = parent.metadata["island"]
            else:
                isl = self.current_island
        isl %= self.num_islands
        genome.metadata["island"] = isl

        # MAP-Elites cell coordinates (also refreshes the diversity reference set)
        coords = self._feature_coords(genome)
        genome.metadata["cell"] = coords
        feature_key = "-".join(str(c) for c in coords)

        grid = self.island_feature_maps[isl]
        incumbent_id = grid.get(feature_key)
        if incumbent_id is None or incumbent_id not in self.programs:
            # empty cell, or a stale reference to an evicted program
            grid[feature_key] = genome.id
        elif self._fit(genome) > self._fit(self.programs[incumbent_id]):
            # cell improved: hand off archive membership, drop incumbent from island set
            if incumbent_id in self.archive:
                self.archive.discard(incumbent_id)
                self.archive.add(genome.id)
            self.islands[isl].discard(incumbent_id)
            grid[feature_key] = genome.id
        # else: the cell keeps its fitter elite, but this genome is still stored in self.programs and the
        # island set — OpenEvolve's database.add ALWAYS adds to self.programs (losing a MAP-Elites cell is
        # not rejection), so it stays a sampleable parent/inspiration and a best candidate.

        self.islands[isl].add(genome.id)
        self._update_archive(genome)
        # enforce the size limit BEFORE best tracking, never evicting the genome just added
        self._enforce_population_limit(exclude=genome.id)
        self._update_best(genome)
        self._update_island_best(genome, isl)
        return True   # always stored in self.programs (admitted), mirroring OpenEvolve's database.add

    # ---- queries ----------------------------------------------------------------------------
    def query(self, spec: dict | None = None) -> list[Genome]:
        """Sorted-by-fitness views used by the selection policy.

        ``{"island": i}`` → that island's programs; ``{"archive": True}`` → the elite archive;
        ``{"top": n, ...}`` truncates (a falsy ``top`` means no truncation). Default → all programs.
        Results are sorted best-first by ``_fit``.
        """
        spec = spec or {}
        if spec.get("archive"):
            members = [self.programs[p] for p in self.archive if p in self.programs]
        elif "island" in spec:
            members = [self.programs[p] for p in self.islands[spec["island"]] if p in self.programs]
        else:
            members = list(self.programs.values())
        members.sort(key=self._fit, reverse=True)
        top = spec.get("top")
        return members[:top] if top else members

    def all(self) -> list[Genome]:
        """Every stored program, in insertion order."""
        return list(self.programs.values())

    def best(self) -> Genome | None:
        """The tracked best program; falls back to a fitness scan if the tracked id is gone."""
        if self.best_program_id and self.best_program_id in self.programs:
            return self.programs[self.best_program_id]
        members = self.all()
        return max(members, key=self._fit) if members else None

    # ---- island scheduling / migration  (called by the scaffold/policy) ---------------------
    def next_island(self) -> int:
        """Advance ``current_island`` round-robin and return it."""
        self.current_island = (self.current_island + 1) % self.num_islands
        return self.current_island

    def increment_island_generation(self, island_idx: int) -> None:
        """Bump the generation counter of ``island_idx`` (taken modulo ``num_islands``)."""
        self.island_generations[island_idx % self.num_islands] += 1

    def should_migrate(self) -> bool:
        """True once the most-advanced island is ``migration_interval`` generations past the last
        migration."""
        return (max(self.island_generations) - self.last_migration_generation) >= self.migration_interval

    def migrate(self) -> None:
        """Ring migration: copy the top ``migration_rate`` of each island to BOTH neighbours
        ``(i±1) % n``; copies are marked ``migrant`` so they never migrate again (avoids the
        exponential-duplication failure mode documented in the reference).

        A copy is skipped when the target island already holds a program with identical content.
        Each copy is a new ``Genome`` whose ``parent_id`` is the source program, admitted via
        :meth:`add` with its ``metadata["island"]`` set to the target.
        """
        n = self.num_islands
        if n < 2:
            return
        for i, island in enumerate(self.islands):
            members = sorted((self.programs[p] for p in island if p in self.programs),
                             key=self._fit, reverse=True)
            if not members:
                continue
            num_to_migrate = max(1, int(len(members) * self.migration_rate))
            # With n == 2 both ring neighbours are the same island; visit it once (the
            # duplicate-content check below would have skipped the second visit anyway).
            targets = list(dict.fromkeys(((i + 1) % n, (i - 1) % n)))
            for migrant in members[:num_to_migrate]:
                if migrant.metadata.get("migrant"):
                    continue
                for tgt in targets:
                    if any(self.programs[p].content == migrant.content
                           for p in self.islands[tgt] if p in self.programs):
                        continue
                    clone = Genome(content=migrant.content, parent_id=migrant.id,
                                   lineage=migrant.lineage, scores=dict(migrant.scores),
                                   metadata={**migrant.metadata, "island": tgt, "migrant": True},
                                   artifacts=dict(migrant.artifacts))
                    self.add(clone)
        self.last_migration_generation = max(self.island_generations)

    # ---- checkpoint / resume (mirror ProgramDatabase.save/load scheduling+scaling state) ----
    def state_dict(self) -> dict:
        """Persist the island-scheduling + feature-scaling state so a resume continues migration on
        the saved cadence and keeps the same minmax binning (the population genomes themselves are
        round-tripped separately via population.jsonl)."""
        return {
            "current_island": self.current_island,
            "island_generations": list(self.island_generations),
            "last_migration_generation": self.last_migration_generation,
            "feature_stats": {k: {"min": v["min"], "max": v["max"]} for k, v in self.feature_stats.items()},
            "diversity_reference": list(self.diversity_reference),
        }

    def load_state_dict(self, st: dict) -> None:
        """Restore state written by :meth:`state_dict`. Non-dict input is ignored; missing or
        malformed entries keep their current values (``island_generations`` is only restored
        when its length matches ``num_islands``)."""
        if not isinstance(st, dict):
            return
        self.current_island = int(st.get("current_island", self.current_island))
        ig = st.get("island_generations")
        if isinstance(ig, list) and len(ig) == self.num_islands:
            self.island_generations = [int(x) for x in ig]
        self.last_migration_generation = int(st.get("last_migration_generation", self.last_migration_generation))
        fs = st.get("feature_stats")
        if isinstance(fs, dict):
            self.feature_stats = {k: {"min": float(v["min"]), "max": float(v["max"])}
                                  for k, v in fs.items() if isinstance(v, dict) and "min" in v and "max" in v}
        dr = st.get("diversity_reference")
        if isinstance(dr, list):
            self.diversity_reference = [str(x) for x in dr]

    # ---- internals --------------------------------------------------------------------------
    def _fit(self, genome: Genome) -> float:
        """OpenEvolve-style fitness, excluding this population's feature dimensions."""
        return _fitness(genome, self.feature_dimensions)

    def _feature_coords(self, genome: Genome) -> tuple[int, ...]:
        """Per-dimension bin index, mirroring ``_calculate_feature_coords``.

        For each ``dim``: a custom evaluator metric of that name wins; else the built-ins
        ``complexity`` = ``len(code)`` and ``diversity`` = avg fast-code-distance to a reference
        sample. ``score`` uses ``_fit`` (0.0 when there are no scores); any other unknown dimension
        gets a raw value of 0.0. Each raw value updates a running min/max and is minmax-scaled to
        ``[0, 1]``, then ``bin = int(scaled * bins)`` clamped to ``[0, bins-1]``.
        """
        coords: list[int] = []
        for dim in self.feature_dimensions:
            if dim in genome.scores and isinstance(genome.scores[dim], (int, float)):
                coords.append(self._bin(dim, float(genome.scores[dim])))
                continue
            if dim == "diversity" and len(self.programs) < 2:
                coords.append(0)   # OpenEvolve assigns the diversity bin 0 directly (feature_stats left
                continue           # untouched) when <2 programs — avoids pinning the running min at a
                                   # fake 0.0 that would skew every later diversity bin (database.py:867-868)
            if dim == "complexity":
                raw = float(len(genome.content))
            elif dim == "diversity":
                raw = self._diversity(genome)
            elif dim == "score":
                raw = self._fit(genome) if genome.scores else 0.0
            else:
                raw = 0.0
            coords.append(self._bin(dim, raw))
        return tuple(coords)

    def _bin(self, dim: str, value: float) -> int:
        """Update ``dim``'s running min/max with ``value`` and return its minmax bin index.

        Non-finite values (NaN / ±inf) get bin 0 WITHOUT touching ``feature_stats``. Otherwise a
        single NaN would poison min/max forever (``min(nan, x)`` is ``nan``), and an infinite bound
        would squash every later value into bin 0.
        """
        if not math.isfinite(value):
            return 0
        stats = self.feature_stats.setdefault(dim, {"min": value, "max": value})
        stats["min"] = min(stats["min"], value)
        stats["max"] = max(stats["max"], value)
        lo, hi = stats["min"], stats["max"]
        scaled = 0.5 if hi == lo else (value - lo) / (hi - lo)
        scaled = min(1.0, max(0.0, scaled))
        return max(0, min(self.feature_bins - 1, int(scaled * self.feature_bins)))

    def _diversity(self, genome: Genome) -> float:
        """Avg ``_fast_code_diversity`` against the diversity reference set (OpenEvolve
        ``_calculate_diversity``); reference entries identical to the genome's code are skipped."""
        ref = self._build_diversity_reference()
        diffs = [self._fast_code_diversity(genome.content, r) for r in ref if r != genome.content]
        return sum(diffs) / len(diffs) if diffs else 0.0

    def _build_diversity_reference(self) -> list[str]:
        """OpenEvolve ``_update_diversity_reference_set``: all programs when <= reference_size, else a
        greedy max-min-diversity (farthest-point) selection. We start from the first-inserted program
        (deterministic) instead of OpenEvolve's random first element — the greedy max-min dominates the
        composition, and keeping the start deterministic avoids gratuitous run-to-run variance."""
        codes = [p.content for p in self.programs.values()]
        if len(codes) <= self.diversity_reference_size:
            self.diversity_reference = codes
            return self.diversity_reference
        selected = [codes[0]]
        remaining = codes[1:]
        # nearest[i] = distance from remaining[i] to its closest already-selected code. Folding in each
        # newly selected code gives the same greedy picks as recomputing min-over-selected from scratch,
        # but costs O(N·R) instead of O(N·R²) distance evaluations.
        nearest = [self._fast_code_diversity(code, selected[0]) for code in remaining]
        while len(selected) < self.diversity_reference_size and remaining:
            # max() returns the FIRST maximal index, i.e. the same tie-break as a strict ``>`` scan
            pick = max(range(len(remaining)), key=nearest.__getitem__)
            chosen = remaining.pop(pick)
            nearest.pop(pick)
            selected.append(chosen)
            nearest = [min(d, self._fast_code_diversity(code, chosen))
                       for code, d in zip(remaining, nearest)]
        self.diversity_reference = selected
        return self.diversity_reference

    @staticmethod
    def _fast_code_diversity(code1: str, code2: str) -> float:
        """Cheap symmetric distance: weighted length, newline-count and character-set differences."""
        if code1 == code2:
            return 0.0
        length_diff = abs(len(code1) - len(code2))
        line_diff = abs(code1.count("\n") - code2.count("\n"))
        char_diff = len(set(code1).symmetric_difference(set(code2)))
        return length_diff * 0.1 + line_diff * 10 + char_diff * 0.5

    def _update_archive(self, genome: Genome) -> None:
        """Admit ``genome`` to the elite archive while it has room, else replace the archive's worst
        member if ``genome`` is strictly fitter. Stale ids (evicted programs) are purged first."""
        if genome.id in self.archive:
            # Already archived (e.g. inherited the slot of a cell incumbent it just beat). Running
            # the replacement below would evict the current worst member for nothing and leave a
            # full archive one slot short.
            return
        if len(self.archive) < self.archive_size:
            self.archive.add(genome.id)
            return
        valid = [self.programs[p] for p in self.archive if p in self.programs]
        self.archive = {p.id for p in valid}
        if len(self.archive) < self.archive_size:
            self.archive.add(genome.id)
            return
        if not valid:
            return  # archive_size <= 0: nothing may be archived (and min() of [] would raise)
        worst = min(valid, key=self._fit)
        if self._fit(genome) > self._fit(worst):
            self.archive.discard(worst.id)
            self.archive.add(genome.id)

    def _enforce_population_limit(self, exclude: str | None = None) -> None:
        """Evict the lowest-fitness programs until ``population_size`` is respected, never evicting
        the tracked best or ``exclude``; evicted ids are removed from grids, islands and archive."""
        if len(self.programs) <= self.population_size:
            return
        num_remove = len(self.programs) - self.population_size
        ordered = sorted(self.programs.values(), key=self._fit)  # worst first
        protected = {self.best_program_id, exclude} - {None}
        removed = 0
        for g in ordered:
            if removed >= num_remove:
                break
            if g.id in protected:
                continue
            del self.programs[g.id]
            for grid in self.island_feature_maps:
                for k in [k for k, v in grid.items() if v == g.id]:
                    del grid[k]
            for island in self.islands:
                island.discard(g.id)
            self.archive.discard(g.id)
            removed += 1

    def _update_best(self, genome: Genome) -> None:
        """Track the global best id (replaced only by a strictly fitter genome or when stale)."""
        if self.best_program_id is None or self.best_program_id not in self.programs:
            self.best_program_id = genome.id
        elif self._fit(genome) > self._fit(self.programs[self.best_program_id]):
            self.best_program_id = genome.id

    def _update_island_best(self, genome: Genome, isl: int) -> None:
        """Track island ``isl``'s best id (replaced only by a strictly fitter genome or when stale)."""
        cur = self.island_best_programs[isl]
        if cur is None or cur not in self.programs or self._fit(genome) > self._fit(self.programs[cur]):
            self.island_best_programs[isl] = genome.id
