"""
functions/hierarchical.py
-------------------------
Functions for hierarchical / multi-level sampling.
All inputs come from request payloads; no shared state.

The data flows as: receive rows → apply mask → filter by selected parent
levels → group by levels → sample → return.
Rows are never trimmed to the hierarchy columns: each row is kept whole,
grouping/counting runs on every row that survives the mask and the
parent-level filters, and only the sampled rows are returned.
"""
from __future__ import annotations
from typing import Any


# ── helpers ──────────────────────────────────────────────────────────────────

def _path_key(parts: list[str]) -> str:
    return "||".join(str(x) for x in parts)


def _apply_mask(data: list[dict], mask_rules: list[dict]) -> list[dict]:
    """
    Apply conditional mask rules to filter rows.
    Supported ops: ==, !=, in, not in, contains, not contains, >, >=, <, <=
    """
    if not mask_rules:
        return data

    result = []
    for row in data:
        passes = True
        for rule in mask_rules:
            col = rule.get("col")
            op = (rule.get("op") or "").lower().strip()
            val = rule.get("value") or rule.get("values")

            if not col or not op or col not in row:
                continue

            cell = str(row.get(col, "")).strip()

            try:
                if op in ("==", "eq", "equals"):
                    passes = passes and (cell == str(val).strip())
                elif op in ("!=", "neq", "not equals"):
                    passes = passes and (cell != str(val).strip())
                elif op in ("in", "isin"):
                    vals = val if isinstance(val, list) else [val]
                    passes = passes and (cell in [str(v).strip() for v in vals])
                elif op in ("not in", "notin"):
                    vals = val if isinstance(val, list) else [val]
                    passes = passes and (cell not in [str(v).strip() for v in vals])
                elif op == "contains":
                    vals = val if isinstance(val, list) else [val]
                    passes = passes and any(str(v).lower() in cell.lower() for v in vals)
                elif op in ("ncontains", "not contains", "notcontains"):
                    vals = val if isinstance(val, list) else [val]
                    passes = passes and not any(str(v).lower() in cell.lower() for v in vals)
                elif op == ">":
                    try:
                        passes = passes and (float(row.get(col, 0)) > float(val))
                    except Exception:
                        pass
                elif op == ">=":
                    try:
                        passes = passes and (float(row.get(col, 0)) >= float(val))
                    except Exception:
                        pass
                elif op == "<":
                    try:
                        passes = passes and (float(row.get(col, 0)) < float(val))
                    except Exception:
                        pass
                elif op == "<=":
                    try:
                        passes = passes and (float(row.get(col, 0)) <= float(val))
                    except Exception:
                        pass
            except Exception:
                pass

        if passes:
            result.append(row)

    return result


def _apply_selected_filters(
    data: list[dict],
    selected: dict[str, list[str]],
    levels: list[str],
) -> list[dict]:
    """Filter rows to match selected values for given levels."""
    out = data
    for lvl in levels:
        vals = selected.get(lvl) or []
        if vals:
            val_set = {str(v) for v in vals}
            out = [r for r in out if str(r.get(lvl, "")) in val_set]
    return out


def _group_by_levels(data: list[dict], levels: list[str]) -> dict[str, list[dict]]:
    """
    Bucket rows by their full path across *levels*.

    The bucket key is the ``_path_key`` of each row's stringified values at
    *levels* (missing values count as ``""``). Buckets appear in the order
    their first row was seen, and rows keep their input order within a bucket.
    """
    grouped: dict[str, list[dict]] = {}
    for record in data:
        path = _path_key([str(record.get(name, "")) for name in levels])
        bucket = grouped.get(path)
        if bucket is None:
            bucket = grouped[path] = []
        bucket.append(record)
    return grouped


# ── public functions ──────────────────────────────────────────────────────────

def options_for_level(
    data: list[dict],
    depth_levels: list[str],
    level: str,
    selected: dict[str, list[str]],
    mask_rules: list[dict],
    min_count: int = 0,
) -> list[dict]:
    """
    List the selectable options (with row counts) for one hierarchy level.

    Returns [] when *level* is not part of *depth_levels*, when the level
    directly above it has no selection, or when no remaining row has the
    *level* column. Otherwise rows are masked, narrowed by the selections of
    all levels above *level*, and grouped by their full path down to *level*
    (so the same value under different parents yields separate options).

    Each option is ``{"value", "count", "path", "key"}``; options are ordered
    by descending count. *min_count* is enforced only for the deepest level.
    """
    try:
        position = depth_levels.index(level)
    except ValueError:
        return []

    # A child level is only offered once its immediate parent has a selection.
    if position > 0 and not selected.get(depth_levels[position - 1]):
        return []

    rows = _apply_selected_filters(
        _apply_mask(data, mask_rules), selected, depth_levels[:position]
    )
    if all(level not in row for row in rows):
        return []

    path_levels = depth_levels[: position + 1]
    buckets: dict[str, dict] = {}
    for row in rows:
        labels = [str(row.get(col, "")) for col in path_levels]
        key = _path_key(labels)
        entry = buckets.get(key)
        if entry is None:
            entry = buckets[key] = {
                "value": str(row.get(level, "")),
                "count": 0,
                "path": " | ".join(labels),
                "key": key,
            }
        entry["count"] += 1

    options = sorted(buckets.values(), key=lambda opt: opt["count"], reverse=True)

    is_deepest = position == len(depth_levels) - 1
    if is_deepest:
        options = [opt for opt in options if opt["count"] >= int(min_count)]

    return options


def preview_last_level(
    data: list[dict],
    depth_levels: list[str],
    selected: dict[str, list[str]],
    mask_rules: list[dict],
    min_count: int,
) -> list[dict]:
    """
    Summarise row counts per full-path group at the deepest hierarchy level.

    Rows are masked and narrowed by the selections of every level except the
    last, then grouped by their full path across *depth_levels*. Each entry
    is ``{"group", "count", "path", "key", "eligible"}`` where ``eligible``
    says whether the group has at least *min_count* rows. Entries are ordered
    by descending count; an empty *depth_levels* yields [].
    """
    if not depth_levels:
        return []

    *parents, last = depth_levels
    rows = _apply_selected_filters(_apply_mask(data, mask_rules), selected, parents)

    summary: dict[str, dict] = {}
    for row in rows:
        labels = [str(row.get(col, "")) for col in depth_levels]
        key = _path_key(labels)
        if key in summary:
            summary[key]["count"] += 1
            continue
        summary[key] = {
            "group": str(row.get(last, "")),
            "count": 1,
            "path": " | ".join(labels),
            "key": key,
        }

    ranked = sorted(summary.values(), key=lambda entry: entry["count"], reverse=True)
    for entry in ranked:
        entry["eligible"] = entry["count"] >= int(min_count)

    return ranked


def hierarchical_sample(
    data: list[dict],
    depth_levels: list[str],
    selected: dict[str, list[str]],
    mask_rules: list[dict],
    min_count: int,
    sample_count: int,
    quota: dict[str, int] | None = None,
    random_state: int = 42,
) -> list[dict]:
    """
    Draw samples from each selected last-level group.

    - mask_rules apply to narrow the population
    - selected filters apply hierarchically to every level except the last
    - rows are then grouped by their full path across depth_levels
    - selected[last_level] (optional) restricts which groups are sampled; an
      entry may be a full path key (the "key" returned by options_for_level /
      preview_last_level) or a bare last-level value (the same form used for
      parent levels). With no last-level selection every group is eligible.
    - quota: {full_path_key: n} overrides sample_count per group
    - groups below min_count are skipped; n is clamped to [0, group size]
    - returns only the sampled rows (the original row dicts), sorted by their
      string values along depth_levels; a fixed random_state and input order
      give a reproducible sample
    """
    import random

    if not depth_levels:
        return []

    quota = quota or {}
    rng = random.Random(random_state)
    last = depth_levels[-1]

    base = _apply_mask(data, mask_rules)
    base = _apply_selected_filters(base, selected, depth_levels[:-1])
    groups = _group_by_levels(base, depth_levels)

    selected_last = {str(x) for x in (selected.get(last) or [])}
    threshold = int(min_count)
    parts: list[dict] = []

    for group_key, group_rows in groups.items():
        if selected_last:
            # Every row in a group shares the same last-level value because
            # that value is part of the group key. Matching only on the full
            # key silently selected nothing when the caller sent bare values.
            last_value = str(group_rows[0].get(last, ""))
            if group_key not in selected_last and last_value not in selected_last:
                continue

        if len(group_rows) < threshold:
            continue

        n = int(quota.get(group_key, sample_count))
        n = max(0, min(len(group_rows), n))
        if n == 0:
            continue

        parts.extend(rng.sample(group_rows, n))

    # Sort by depth levels (best-effort on string values)
    parts.sort(key=lambda r: tuple(str(r.get(lvl, "")) for lvl in depth_levels))

    return parts


def count_effective_rows(data: list[dict], mask_rules: list[dict]) -> dict:
    """Report how many rows were received and how many survive the mask rules."""
    return {
        "total_rows": len(data),
        "filtered_rows": len(_apply_mask(data, mask_rules)),
    }
