"""
scheinman.py — Scheinman Boolean minimization (Sum-of-Products).

stdlib only: dataclasses, argparse, sys, typing
"""
from __future__ import annotations

import sys
from dataclasses import dataclass, field
from typing import Optional


# ---------------------------------------------------------------------------
# Data model
# ---------------------------------------------------------------------------

@dataclass(eq=False)
class Implicant:
    """
    A product term (cube) over ``n`` Boolean variables.

    ``mask`` marks the fixed ("care") bit positions and ``value`` supplies the
    required bit values at those positions; positions where ``mask`` is 0 are
    free.  Equality and hashing look only at the fixed part (``value & mask``
    together with ``mask``), so flags and output tags do not affect identity.
    """
    value: int   # bit pattern (don't-care bits should be 0 after normalisation)
    mask: int    # 1 = care bit (fixed), 0 = don't-care bit
    absorbed: bool = False  # True once merged into a larger term; don't-care seeds start True
    is_dc: bool = False  # True if this was originally a don't-care term
    outputs: frozenset = field(default_factory=frozenset)  # output membership (multi-output mode)

    def _key(self) -> tuple[int, int]:
        # Canonical identity: the fixed bits plus which positions are fixed.
        return (self.value & self.mask, self.mask)

    def covers(self, minterm: int) -> bool:
        """Return True when ``minterm`` agrees with this term on every fixed bit."""
        return not ((minterm ^ self.value) & self.mask)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Implicant):
            return self._key() == other._key()
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self._key())

    def __repr__(self) -> str:
        return "Implicant(value={:b}, mask={:b}, absorbed={})".format(
            self.value, self.mask, self.absorbed
        )


@dataclass
class Node:
    """One level of the Scheinman tree, as returned by build_tree()."""
    depth: int                       # tree level; the split bit is n - 1 - depth
    n: int                           # number of variables
    implicants: list[Implicant]      # items handed to this node (the original objects)
    left_implicants: list[Implicant]     # items whose split bit is 0 (same objects as in `implicants`)
    right_implicants: list[Implicant]    # bit-cleared COPIES of items whose split bit is 1; their
                                         # `absorbed` flag is a snapshot taken before the match scan,
                                         # so it does not reflect absorption that happened at this node
    matched_implicants: list[Implicant]  # merged implicants produced by this node's match scan
    pre_match_absorbed_ids: frozenset = field(default_factory=frozenset)  # id() of items already absorbed before the match scan
    left: Optional[Node] = field(default=None)     # subtree built from the bit=0 items
    right: Optional[Node] = field(default=None)    # subtree built from the ORIGINAL bit=1 items
    matched: Optional[Node] = field(default=None)  # subtree built from matched_implicants
    ancestor_right_mask: int = 0     # OR of the split bits on which ancestors branched right


@dataclass
class InputRowRecord:
    """One displayable row for an implicant column (built by build_trace)."""
    display_value: int        # value shown: implicant value with ancestor right-branch bits cleared
                              # (right rows additionally have this node's split bit cleared)
    original_value: int       # input rows: same as display_value; left rows: value before ancestor
                              # masking; right rows: bit-cleared value with the split bit set again
    outputs: frozenset        # output tags (empty in single-output mode)
    is_dc: bool               # True if the row came from a don't-care term
    absorbed_at_depth: bool   # input rows: already absorbed BEFORE this node's match scan;
                              # left/right rows: the value was matched AT this node
    mset: Optional[str] = None  # minterm-set annotation, e.g. '{1,3}', computed by build_trace
    mask: int = 0  # PI mask (which bits are fixed); used by visualize.py for PI labels at last var


@dataclass
class MatchPairRecord:
    """One matched (left, right) pair at a tree node, as displayed by the trace."""
    left_value: int            # left value with ancestor right-branch bits cleared
    left_value_raw: int        # left_imp.value before arm-masking (for PI labels)
    right_value_cleared: int   # right (bit-cleared) value with ancestor right-branch bits cleared
    right_value_original: int  # right_value_cleared | (1 << bit_pos)
    left_outputs: frozenset
    right_outputs: frozenset
    shared_outputs: frozenset  # intersection of the two tag sets; empty unless BOTH sides carry tags
    is_kept: bool              # False only when both sides carry tags and their intersection is empty


@dataclass
class TreeStepRecord:
    """All display data for one node in the Scheinman tree."""
    step_index: int            # 1-based, pre-order DFS: node, then its matched, left, right subtrees
    depth: int
    bit_pos: int               # n - 1 - depth
    path_label: str            # '' for root; otherwise 'Match(−)' / 'Left' / 'Right' segments joined by ' → '
    n_outputs: int
    input_rows: list[InputRowRecord]    # sorted by display_value
    left_rows: list[InputRowRecord]     # sorted by display_value
    right_rows: list[InputRowRecord]    # display_value is the cleared value
    match_pairs: list[MatchPairRecord]
    child_step_left: Optional[int]      # step_index of the left child, or None
    child_step_right: Optional[int]     # step_index of the right child, or None
    child_step_matched: Optional[int]   # step_index of the matched child, or None


@dataclass
class PIChartRecord:
    """Prime-implicant chart data collected by build_trace."""
    prime_implicants: list[Implicant]  # sorted by number of free bits (desc), then by masked value
    minterms: list[int]                # on-set minterms, sorted
    essential_pis: frozenset[Implicant]  # PIs that are the ONLY cover of at least one on-set minterm
    selected_pis: list[Implicant]      # the final cover chosen by select_cover()


@dataclass
class MinimizationTrace:
    """Everything a renderer needs to display one run of the Scheinman method."""
    n: int                             # number of variables
    n_outputs: int                     # 1 in single-output mode
    on_set: list[int]                  # single-output: minterms as given; multi-output: sorted union of on-sets
    dont_cares: list[int]              # don't-care indices as given
    tree_steps: list[TreeStepRecord]   # pre-order DFS; empty when there was nothing to minimise
    pi_chart: PIChartRecord


# ---------------------------------------------------------------------------
# Tree building
# ---------------------------------------------------------------------------

def build_tree(
    implicants: list[Implicant],
    n: int,
    depth: int = 0,
    overlapping: bool = True,
    ancestor_right_mask: int = 0,
) -> Optional[Node]:
    """
    Recursively build the Scheinman binary tree.

    At each level we partition on bit_pos = n-1-depth (MSB-first).
    Items with that bit = 0 go left; items with that bit = 1 go right
    (with the bit cleared in value so values are comparable across branches).
    Left/right items whose (bit-cleared) values are equal are merged into a new
    implicant with bit_pos removed from its mask; those merges seed the
    "matched" subtree, which is built before the left and right subtrees.

    Parameters
    ----------
    implicants:          items at this node.  Their ``absorbed`` flags are
                         mutated IN PLACE when they take part in a merge.
    n:                   number of variables.
    depth:               current depth; selects bit_pos = n-1-depth.
    overlapping:         if False, items absorbed at this node are dropped
                         before recursing into the left/right subtrees.
    ancestor_right_mask: OR of the bits on which ancestors branched right;
                         stored on the Node and extended when recursing right.

    Returns
    -------
    The Node for this level, or None when depth == n, when ``implicants`` is
    empty, or (below the root) when every implicant is already absorbed.

    NOTE(review): ``absorbed`` doubles as "is a don't-care" and "was merged
    elsewhere".  A merge of two already-absorbed items therefore starts
    absorbed (never reported as a PI), and depth>0 nodes whose items are all
    absorbed are pruned, so merges among already-absorbed items can be missed.
    extract_prime_implicants' per-minterm fallback keeps the cover valid, but
    minimality may suffer — confirm this is the intended Scheinman behaviour.
    """
    # --- Base cases ---
    if depth == n:
        return None
    if len(implicants) == 0:
        return None
    # Prune early if all absorbed, but NOT at the root (depth > 0)
    if depth > 0 and all(imp.absorbed for imp in implicants):
        return None

    bit_pos = n - 1 - depth

    # --- Partition ---
    left_list: list[Implicant] = []
    # right_pairs holds (bit-cleared copy, original_index) for matching comparison only
    right_pairs: list[tuple[Implicant, int]] = []
    # right_originals holds references to the original right-branch items for recursion
    right_originals: list[Implicant] = []

    for idx, imp in enumerate(implicants):
        if imp.value & (1 << bit_pos):
            # Bit is 1 — create a value-cleared copy for cross-branch comparison.
            # The mask is UNCHANGED so matching works correctly (compare same-bit-position values).
            # The copy's `absorbed` is a snapshot: later absorption is written to
            # the ORIGINAL (implicants[idx]), not to this copy.
            cleared_imp = Implicant(
                value=imp.value & ~(1 << bit_pos),
                mask=imp.mask,
                absorbed=imp.absorbed,
                is_dc=imp.is_dc,
                outputs=imp.outputs,
            )
            right_pairs.append((cleared_imp, idx))
            right_originals.append(imp)
        else:
            # Bit is 0 — use reference to the original object
            left_list.append(imp)

    right_list = [rp[0] for rp in right_pairs]

    # Snapshot which items are absorbed BEFORE this depth's match scan.
    # Used by build_trace() so the input list can mark previously-absorbed items.
    pre_match_absorbed = frozenset(id(imp) for imp in implicants if imp.absorbed)

    # --- Match scan O(n²) ---
    matched_implicants: list[Implicant] = []

    for i, left_imp in enumerate(left_list):
        for j, right_imp in enumerate(right_list):
            # Only values are compared: within one node every item has the same
            # mask (all descend from the same sequence of merges).
            if left_imp.value == right_imp.value:
                # Build merged implicant: value from left (bit=0), mask with bit_pos cleared
                orig_right_imp = implicants[right_pairs[j][1]]

                # Multi-output: only keep matches where at least one output benefits
                shared_outputs = left_imp.outputs & orig_right_imp.outputs
                if left_imp.outputs and orig_right_imp.outputs and not shared_outputs:
                    # Both have output tags but no overlap — skip this match
                    continue

                # Use intersection if both have tags; otherwise fall through (single-output mode)
                effective_outputs = shared_outputs

                # Starts absorbed only if BOTH operands already were (e.g. two
                # don't-cares), so such merges are not reported as PIs.
                matched_imp = Implicant(
                    value=left_imp.value,
                    mask=left_imp.mask & ~(1 << bit_pos),
                    absorbed=left_imp.absorbed and orig_right_imp.absorbed,
                    outputs=effective_outputs,
                )

                # Mark both originals absorbed, but only when the matched implicant
                # covers ALL outputs that the original serves.  In multi-output mode
                # the intersection may be a strict subset of one or both operands'
                # output tags, meaning the originals are still prime implicants for
                # the uncovered outputs.
                if not left_imp.outputs or effective_outputs >= left_imp.outputs:
                    left_list[i].absorbed = True
                if not orig_right_imp.outputs or effective_outputs >= orig_right_imp.outputs:
                    orig_right_imp.absorbed = True  # the real original in implicants[]

                matched_implicants.append(matched_imp)
                break  # each left matches at most one right

    # --- CRITICAL: build matched subtree FIRST ---
    matched_node = build_tree(matched_implicants, n, depth + 1, overlapping,
                              ancestor_right_mask=ancestor_right_mask)

    # --- Filter for non-overlapping mode ---
    # Right subtree recursion uses ORIGINAL right-branch items (not bit-cleared copies).
    # This preserves correct bit values so that PIs produced within the right subtree
    # (covering the bit=1 subspace) have correct value/mask combinations.
    if not overlapping:
        left_recurse = [imp for imp in left_list if not imp.absorbed]
        right_recurse = [imp for imp in right_originals if not imp.absorbed]
    else:
        left_recurse = left_list
        right_recurse = right_originals

    # --- Recurse into left and right ---
    left_node = build_tree(left_recurse, n, depth + 1, overlapping,
                           ancestor_right_mask=ancestor_right_mask)
    right_node = build_tree(right_recurse, n, depth + 1, overlapping,
                            ancestor_right_mask=ancestor_right_mask | (1 << bit_pos))

    return Node(
        depth=depth,
        n=n,
        implicants=implicants,
        left_implicants=left_list,
        right_implicants=right_list,
        matched_implicants=matched_implicants,
        pre_match_absorbed_ids=pre_match_absorbed,
        left=left_node,
        right=right_node,
        matched=matched_node,
        ancestor_right_mask=ancestor_right_mask,
    )


# ---------------------------------------------------------------------------
# Prime implicant extraction
# ---------------------------------------------------------------------------

def extract_prime_implicants(
    root: Optional[Node],
    minterms: list[int],
    n: int,
) -> list[Implicant]:
    """
    Gather the prime implicants produced by a Scheinman tree.

    Nodes are visited in pre-order (node, then its left, right and matched
    subtrees); every matched implicant still flagged as not absorbed becomes a
    candidate, with equal cubes collapsed.  Any on-set minterm left uncovered
    receives its own single-minterm implicant, so the result always covers
    ``minterms``.
    """
    found: set[Implicant] = set()

    # Explicit stack instead of recursion.  Children are pushed in reverse so
    # they pop as left, right, matched — the same order as a recursive walk.
    pending: list[Optional[Node]] = [root]
    while pending:
        node = pending.pop()
        if node is None:
            continue
        found.update(imp for imp in node.matched_implicants if not imp.absorbed)
        pending.append(node.matched)
        pending.append(node.right)
        pending.append(node.left)

    all_fixed = (1 << n) - 1
    for minterm in minterms:
        if not any(candidate.covers(minterm) for candidate in found):
            found.add(Implicant(value=minterm, mask=all_fixed))

    return list(found)


# ---------------------------------------------------------------------------
# PI chart
# ---------------------------------------------------------------------------

def build_pi_chart(
    prime_implicants: list[Implicant],
    minterms: list[int],
) -> dict[int, list[Implicant]]:
    """
    Build the coverage table for cover selection.

    Each on-set minterm (in ``minterms`` order) maps to the prime implicants
    that cover it, listed in ``prime_implicants`` order.
    """
    chart: dict[int, list[Implicant]] = {}
    for minterm in minterms:
        chart[minterm] = list(filter(lambda pi: pi.covers(minterm), prime_implicants))
    return chart


# ---------------------------------------------------------------------------
# Cover selection
# ---------------------------------------------------------------------------

def select_cover(
    pi_chart: dict[int, list[Implicant]],
    prime_implicants: list[Implicant],
) -> list[Implicant]:
    """
    Choose a cover of the chart's minterms from ``prime_implicants``.

    Phase 1: repeatedly pick essential PIs (the only candidate for some still
             uncovered minterm) and drop every minterm they cover.
    Phase 2: greedy — pick the PI covering the most uncovered minterms, with
             ties going to the earliest PI in ``prime_implicants``.  Stops
             once no remaining PI covers any uncovered minterm.  Such
             minterms can only occur if the chart lists minterms that no PI
             covers, and they are left uncovered rather than padding the
             result with useless PIs.

    Parameters
    ----------
    pi_chart:         minterm -> PIs covering it (see build_pi_chart)
    prime_implicants: candidate PIs; chart entries not in this list are ignored

    Returns
    -------
    The selected PIs in selection order.
    """
    uncovered: set[int] = set(pi_chart.keys())
    selected: list[Implicant] = []
    # Built once (it used to be rebuilt for every chart entry).
    known_pis = set(prime_implicants)

    # Phase 1 — essential PIs
    changed = True
    while changed:
        changed = False
        for m in list(uncovered):
            covering = [pi for pi in pi_chart[m] if pi in known_pis]
            if len(covering) == 1:
                essential = covering[0]
                if essential not in selected:
                    selected.append(essential)
                    changed = True
                # Remove all minterms covered by this essential PI
                uncovered -= {mt for mt in uncovered if essential.covers(mt)}
                break  # restart scan since uncovered changed

    # Phase 2 — greedy
    remaining_pis = [pi for pi in prime_implicants if pi not in selected]
    while uncovered:
        best_pi: Optional[Implicant] = None
        best_gain = 0
        for pi in remaining_pis:
            gain = sum(1 for m in uncovered if pi.covers(m))
            # Strict '>' keeps the first PI among equals (same as max()).
            if gain > best_gain:
                best_pi, best_gain = pi, gain
        if best_pi is None:
            # Nothing left can cover any uncovered minterm.
            break
        selected.append(best_pi)
        remaining_pis.remove(best_pi)
        uncovered -= {m for m in uncovered if best_pi.covers(m)}

    return selected


# ---------------------------------------------------------------------------
# Trace builder — collects all display data for visualize.py
# ---------------------------------------------------------------------------

def build_trace(
    minterms: list[int],
    dont_cares: list[int],
    n: int,
    overlapping: bool = True,
    functions: list[list[int]] | None = None,
) -> MinimizationTrace:
    """
    Run the full Scheinman pipeline and collect all display data into a
    MinimizationTrace (dataclasses only — no matplotlib/rendering).

    Single-output mode: pass minterms, dont_cares, n.
    Multi-output mode:  pass functions=[[...], ...]; minterms is ignored and
    ``dont_cares`` is one flat list shared by every output.

    Parameters
    ----------
    minterms:    on-set indices (single-output mode)
    dont_cares:  don't-care indices
    n:           number of variables
    overlapping: if False, use the non-overlapping Scheinman variant
    functions:   per-output on-set lists; enables multi-output mode

    Returns
    -------
    A MinimizationTrace.  When there is nothing to minimise (no implicants,
    no on-set minterms, or an empty tree) it has no tree steps and an empty
    PI chart.
    """
    n_outputs = 1 if functions is None else len(functions)
    max_val = 1 << n
    full_mask = max_val - 1

    def _empty_trace() -> MinimizationTrace:
        # Callers always get a valid MinimizationTrace, even with nothing to show.
        return MinimizationTrace(
            n=n, n_outputs=n_outputs,
            on_set=on_set, dont_cares=dc_list,
            tree_steps=[],
            pi_chart=PIChartRecord(
                prime_implicants=[],
                minterms=[],
                essential_pis=frozenset(),
                selected_pis=[],
            ),
        )

    # ── Build initial implicant list (mirrors build_visualize_payload logic) ──
    if functions is None:
        all_implicants: list[Implicant] = []
        for m in sorted(minterms):
            all_implicants.append(
                Implicant(value=m, mask=full_mask, absorbed=False, is_dc=False)
            )
        for d in sorted(dont_cares):
            all_implicants.append(
                Implicant(value=d, mask=full_mask, absorbed=True, is_dc=True)
            )
        all_implicants.sort(key=lambda imp: imp.value)
        on_set = list(minterms)
        dc_list = list(dont_cares)
    else:
        dc_list_raw = list(dont_cares)
        dc_set = set(dc_list_raw)  # O(1) membership in the loop below
        all_on_set: set[int] = set()
        for fn in functions:
            all_on_set.update(fn)
        all_dc_only = dc_set - all_on_set

        all_implicants = []
        for m in sorted(all_on_set | all_dc_only):
            on_out = frozenset(i for i, fn in enumerate(functions) if m in fn)
            dc_out = frozenset(
                i for i in range(len(functions))
                if m in dc_set and m not in functions[i]
            )
            if on_out:
                all_implicants.append(
                    Implicant(value=m, mask=full_mask,
                              absorbed=False, is_dc=False, outputs=on_out)
                )
            elif dc_out:
                all_implicants.append(
                    Implicant(value=m, mask=full_mask,
                              absorbed=True, is_dc=True, outputs=dc_out)
                )
        on_set = sorted(all_on_set)
        dc_list = dc_list_raw

    if not all_implicants or not on_set:
        return _empty_trace()

    # ── Build tree ────────────────────────────────────────────────────────────
    root = build_tree(all_implicants, n, 0, overlapping)
    if root is None:
        return _empty_trace()

    # ── DFS walk: assign step indices (matched → left → right) ───────────────
    dfs_order: list[tuple[Node, str]] = []

    def _collect(node: Optional[Node], path: str) -> None:
        if node is None:
            return
        dfs_order.append((node, path))
        _collect(node.matched, (path + ' → ' if path else '') + 'Match(−)')
        _collect(node.left,    (path + ' → ' if path else '') + 'Left')
        _collect(node.right,   (path + ' → ' if path else '') + 'Right')

    _collect(root, '')

    # Map node object id → 1-based step index
    node_to_step: dict[int, int] = {
        id(nd): idx + 1 for idx, (nd, _) in enumerate(dfs_order)
    }

    # ── Helper: mset annotation (mirrors _mset_annotation in visualize.py) ───
    def _mset(imp: Implicant) -> Optional[str]:
        covered = [m for m in on_set if imp.covers(m)]
        return ('{' + ','.join(str(m) for m in sorted(covered)) + '}'
                if len(covered) > 1 else None)

    # ── Build TreeStepRecord for each node ───────────────────────────────────
    tree_steps: list[TreeStepRecord] = []

    for nd, path in dfs_order:
        depth = nd.depth
        bit_pos = n - 1 - depth
        arm = nd.ancestor_right_mask  # bits cleared by all ancestor right-branches
        matched_cleared_values: set[int] = {mi.value for mi in nd.matched_implicants}

        # input_rows: mark items absorbed BEFORE this depth's match scan
        # (don't-cares + items matched at an ancestor depth)
        input_rows: list[InputRowRecord] = [
            InputRowRecord(
                display_value=imp.value & ~arm,
                original_value=imp.value & ~arm,
                outputs=imp.outputs,
                is_dc=imp.is_dc,
                absorbed_at_depth=id(imp) in nd.pre_match_absorbed_ids,
                mset=_mset(imp),
                mask=imp.mask,
            )
            for imp in sorted(nd.implicants, key=lambda x: x.value & ~arm)
        ]

        # left_rows: absorbed_at_depth iff value in matched_cleared_values
        left_rows: list[InputRowRecord] = [
            InputRowRecord(
                display_value=imp.value & ~arm,
                original_value=imp.value,          # pre-arm value, used for PI labels
                outputs=imp.outputs,
                is_dc=imp.is_dc,
                absorbed_at_depth=imp.value in matched_cleared_values,
                mset=_mset(imp),
                mask=imp.mask,
            )
            for imp in sorted(nd.left_implicants, key=lambda x: x.value & ~arm)
        ]

        # right_rows: display the bit-cleared copy value with ancestor bits masked;
        # original_value restores the cleared bit to recover the full original value
        right_rows: list[InputRowRecord] = [
            InputRowRecord(
                display_value=imp.value & ~arm,
                original_value=imp.value | (1 << bit_pos),  # full original value, for PI labels
                outputs=imp.outputs,
                is_dc=imp.is_dc,
                absorbed_at_depth=imp.value in matched_cleared_values,
                mset=_mset(imp),
                mask=imp.mask,
            )
            for imp in sorted(nd.right_implicants, key=lambda x: x.value & ~arm)
        ]

        # match_pairs: pair left items with same-value right items
        right_by_value: dict[int, Implicant] = {}
        for r in nd.right_implicants:
            if r.value not in right_by_value:
                right_by_value[r.value] = r

        match_pairs: list[MatchPairRecord] = []
        for left_imp in nd.left_implicants:
            if left_imp.value not in right_by_value:
                continue
            right_imp = right_by_value[left_imp.value]

            if left_imp.outputs and right_imp.outputs:
                shared = left_imp.outputs & right_imp.outputs
                is_kept = bool(shared)
            else:
                shared = frozenset()
                is_kept = True

            match_pairs.append(MatchPairRecord(
                left_value=left_imp.value & ~arm,
                left_value_raw=left_imp.value,
                right_value_cleared=right_imp.value & ~arm,
                right_value_original=(right_imp.value & ~arm) | (1 << bit_pos),
                left_outputs=left_imp.outputs,
                right_outputs=right_imp.outputs,
                shared_outputs=shared,
                is_kept=is_kept,
            ))

        step_idx = node_to_step[id(nd)]
        tree_steps.append(TreeStepRecord(
            step_index=step_idx,
            depth=depth,
            bit_pos=bit_pos,
            path_label=path,
            n_outputs=n_outputs,
            input_rows=input_rows,
            left_rows=left_rows,
            right_rows=right_rows,
            match_pairs=match_pairs,
            child_step_left=node_to_step.get(id(nd.left)),
            child_step_right=node_to_step.get(id(nd.right)),
            child_step_matched=node_to_step.get(id(nd.matched)),
        ))

    # ── PI chart ──────────────────────────────────────────────────────────────
    pis = extract_prime_implicants(root, on_set, n)

    # Multi-output only: partially-absorbed items at depth n-1 never appear in
    # matched_implicants (they were absorbed for some outputs but not all), yet
    # they are valid leaf-level PIs for their remaining outputs.  Collect them
    # here for the visualization PI chart without touching extract_prime_implicants
    # (which feeds minimize_multi and must stay clean).
    if functions is not None:
        _existing = {(pi.value & pi.mask, pi.mask) for pi in pis}

        def _add_leaf_pis(node: Optional[Node]) -> None:
            if node is None:
                return
            if node.depth == n - 1:
                # Split bit here is bit 0.  Read `absorbed` from the ORIGINAL
                # objects in node.implicants: node.right_implicants holds
                # bit-cleared copies whose flag was snapshotted before this
                # node's match scan, so it misses absorption done at this node.
                # Visit bit-0 = 0 items first, then bit-0 = 1 items, matching
                # the left_list / right_originals order used by build_tree.
                left_side = [imp for imp in node.implicants if not (imp.value & 1)]
                right_side = [imp for imp in node.implicants if imp.value & 1]
                for imp in left_side + right_side:
                    if imp.absorbed:
                        continue
                    key = (imp.value & imp.mask, imp.mask)
                    if key not in _existing:
                        pis.append(Implicant(value=imp.value, mask=imp.mask,
                                             absorbed=False, is_dc=imp.is_dc,
                                             outputs=imp.outputs))
                        _existing.add(key)
            _add_leaf_pis(node.left)
            _add_leaf_pis(node.right)
            _add_leaf_pis(node.matched)

        _add_leaf_pis(root)

    pi_chart_raw = build_pi_chart(pis, on_set)
    selected = select_cover(pi_chart_raw, pis)

    sorted_pis = sorted(
        pis,
        key=lambda p: (-bin(~p.mask & full_mask).count('1'), p.value & p.mask),
    )
    sorted_minterms = sorted(on_set)

    # Essential = sole cover of at least one on-set minterm (reuses the chart).
    essential_pis: set[Implicant] = set()
    for m in sorted_minterms:
        covering = pi_chart_raw[m]
        if len(covering) == 1:
            essential_pis.add(covering[0])

    pi_chart_record = PIChartRecord(
        prime_implicants=sorted_pis,
        minterms=sorted_minterms,
        essential_pis=frozenset(essential_pis),
        selected_pis=selected,
    )

    return MinimizationTrace(
        n=n,
        n_outputs=n_outputs,
        on_set=on_set,
        dont_cares=dc_list,
        tree_steps=tree_steps,
        pi_chart=pi_chart_record,
    )


# ---------------------------------------------------------------------------
# Top-level single-output minimizer
# ---------------------------------------------------------------------------

def minimize(
    minterms: list[int],
    dont_cares: list[int],
    n: int,
    overlapping: bool = True,
) -> list[Implicant]:
    """
    Reduce a single-output Boolean function to a sum-of-products cover.

    Parameters
    ----------
    minterms:    on-set indices, each in [0, 2**n)
    dont_cares:  don't-care indices, each in [0, 2**n), disjoint from minterms
    n:           number of variables
    overlapping: when False, run the non-overlapping Scheinman variant

    Returns
    -------
    ``[]`` for an empty on-set; a single all-free implicant when the on-set
    plus don't-cares fill the whole space; otherwise the cover chosen by
    select_cover().

    Raises
    ------
    ValueError
        If an index is out of range, or a value is both a minterm and a
        don't-care.
    """
    universe = 1 << n

    def _require_in_range(values: list[int], label: str) -> None:
        for v in values:
            if not (0 <= v < universe):
                raise ValueError(f"{label} {v} out of range [0, {universe})")

    _require_in_range(minterms, "Minterm")
    _require_in_range(dont_cares, "Don't-care")

    on_values, dc_values = set(minterms), set(dont_cares)
    if on_values & dc_values:
        raise ValueError("Overlap between minterms and don't-cares")

    if len(minterms) == 0:
        return []
    if on_values | dc_values == set(range(universe)):
        # Every point is on or don't-care: the function is constant 1.
        return [Implicant(0, 0)]

    care_all = universe - 1
    seeds = [
        Implicant(value=m, mask=care_all, absorbed=False, is_dc=False)
        for m in sorted(minterms)
    ]
    seeds.extend(
        Implicant(value=d, mask=care_all, absorbed=True, is_dc=True)
        for d in sorted(dont_cares)
    )
    # Stable sort, so ties keep their insertion order.
    seeds.sort(key=lambda imp: imp.value)

    tree = build_tree(seeds, n, 0, overlapping)
    candidates = extract_prime_implicants(tree, minterms, n)
    return select_cover(build_pi_chart(candidates, minterms), candidates)


# ---------------------------------------------------------------------------
# Top-level multiple-output minimizer
# ---------------------------------------------------------------------------

def minimize_multi(
    functions: list[list[int]],
    dont_cares: list[list[int]] | list[int],
    n: int,
    overlapping: bool = True,
) -> dict[int, list[Implicant]]:
    """
    Minimize multiple Boolean functions simultaneously using the Scheinman
    multiple-output method.

    Parameters
    ----------
    functions:   List of minterm lists, one per output function.
                 e.g. [[1,5,7,9,13,15], [0,1,2,3,5,7,15], [7,8,9,10,11,13,15]]
    dont_cares:  Either a single shared list of don't-care indices, or a list of
                 per-output don't-care lists (exactly one per function).
    n:           Number of variables.
    overlapping: If False, use non-overlapping variant.

    Returns
    -------
    dict mapping output index -> list of selected prime implicants for that output.
    Outputs with an empty on-set map to ``[]``.

    Raises
    ------
    ValueError
        If a minterm or don't-care is outside [0, 2**n), a function's on-set
        overlaps its own don't-cares, or the number of per-output don't-care
        lists differs from the number of functions.
    """
    max_val = 1 << n
    full_mask = max_val - 1
    num_outputs = len(functions)

    # --- Normalize dont_cares ---
    if not dont_cares or isinstance(dont_cares[0], int):
        # Flat list of ints (possibly empty) — shared across all outputs
        dc_per_fn: list[list[int]] = [list(dont_cares)] * num_outputs  # type: ignore[arg-type]
    else:
        # List of lists — one per output
        dc_per_fn = [list(dc) for dc in dont_cares]  # type: ignore[arg-type]
        if len(dc_per_fn) != num_outputs:
            # Previously zip() silently truncated validation; fail loudly instead.
            raise ValueError(
                f"Expected {num_outputs} per-output don't-care lists, "
                f"got {len(dc_per_fn)}"
            )

    # --- Validate ---
    for i, (fn, dc) in enumerate(zip(functions, dc_per_fn)):
        for m in fn:
            if not (0 <= m < max_val):
                raise ValueError(f"Function {i}: minterm {m} out of range [0, {max_val})")
        for d in dc:
            if not (0 <= d < max_val):
                raise ValueError(f"Function {i}: don't-care {d} out of range [0, {max_val})")
        if set(fn) & set(dc):
            raise ValueError(f"Function {i}: overlap between minterms and don't-cares")

    # --- Trivial: any function with 0 minterms produces an empty cover ---
    result: dict[int, list[Implicant]] = {
        i: [] for i, fn in enumerate(functions) if len(fn) == 0
    }
    # If ALL outputs are trivially empty there is nothing to minimise.
    if len(result) == num_outputs:
        return result

    fn_sets = [set(fn) for fn in functions]
    dc_sets = [set(dc) for dc in dc_per_fn]

    # Collect all unique values that appear as either on-set or dc in any function
    all_on_set: set[int] = set().union(*fn_sets)
    all_dc_only: set[int] = set().union(*dc_sets) - all_on_set
    all_values: set[int] = all_on_set | all_dc_only

    # --- Build combined implicant list ---
    combined: list[Implicant] = []
    for m in sorted(all_values):
        on_set_outputs = frozenset(i for i, fs in enumerate(fn_sets) if m in fs)
        dc_outputs = frozenset(i for i, ds in enumerate(dc_sets) if m in ds)

        if on_set_outputs:
            # Minterm is in at least one on-set — not absorbed
            combined.append(Implicant(
                value=m,
                mask=full_mask,
                absorbed=False,
                is_dc=False,
                outputs=on_set_outputs,
            ))
        elif dc_outputs:
            # Only a don't-care — absorbed; tagged only with the outputs listing it
            combined.append(Implicant(
                value=m,
                mask=full_mask,
                absorbed=True,
                is_dc=True,
                outputs=dc_outputs,
            ))

    # --- Build tree over the combined implicant set ---
    root = build_tree(combined, n, 0, overlapping)

    # --- Extract all PIs from the combined tree ---
    # The full on-set union is passed for fallback coverage.
    all_pis_raw = extract_prime_implicants(root, list(all_on_set), n)

    # --- Expand PI output tags for global sharing ---
    # A PI is valid for output i iff every minterm it covers is in that output's
    # on-set or in THAT output's own don't-cares.  (Using a global don't-care
    # set here would let a PI cover another output's don't-care that is an
    # OFF-set point for output i, producing an incorrect cover.)
    allowed_per_output = [fn_sets[i] | dc_sets[i] for i in range(num_outputs)]
    expanded_pis: list[Implicant] = []
    for pi in all_pis_raw:
        pi_covered = {m for m in range(max_val) if pi.covers(m)}
        new_outputs = set(pi.outputs)
        new_outputs.update(
            i for i, allowed in enumerate(allowed_per_output) if pi_covered <= allowed
        )
        expanded_pis.append(
            Implicant(
                value=pi.value,
                mask=pi.mask,
                absorbed=pi.absorbed,
                is_dc=pi.is_dc,
                outputs=frozenset(new_outputs),
            )
        )

    # --- Sharing-aware per-output cover selection ---
    globally_selected: set[Implicant] = set()

    # Build per-output PI lists with single-minterm fallback
    per_output_pis: dict[int, list[Implicant]] = {}
    for i, fn in enumerate(functions):
        if i in result:
            continue
        output_pis = [pi for pi in expanded_pis if i in pi.outputs]
        for m in fn:
            if not any(pi.covers(m) for pi in output_pis):
                output_pis.append(Implicant(
                    value=m,
                    mask=full_mask,
                    absorbed=False,
                    outputs=frozenset({i}),
                ))
        per_output_pis[i] = output_pis
    # Set views for O(1) "is this cube valid for output j" checks (cube equality).
    per_output_pi_sets: dict[int, set[Implicant]] = {
        i: set(p) for i, p in per_output_pis.items()
    }

    # Phase 1: per-output essential PIs
    per_output_uncovered: dict[int, set[int]] = {}
    per_output_selected: dict[int, list[Implicant]] = {}
    for i, pis_i in per_output_pis.items():
        chart_i = build_pi_chart(pis_i, list(functions[i]))
        uncovered: set[int] = set(functions[i])
        selected_i: list[Implicant] = []
        changed = True
        while changed:
            changed = False
            for m in list(uncovered):
                covering = chart_i[m]
                if len(covering) == 1:
                    ess = covering[0]
                    if ess not in selected_i:
                        selected_i.append(ess)
                        globally_selected.add(ess)
                        changed = True
                    uncovered -= {mt for mt in uncovered if ess.covers(mt)}
                    break
        per_output_uncovered[i] = uncovered
        per_output_selected[i] = selected_i

    # Phase 2: global greedy — prefer PIs already chosen elsewhere, then PIs that
    # give new coverage to the most outputs, then the most total new coverage.
    remaining: dict[int, set[int]] = {i: set(u) for i, u in per_output_uncovered.items()}

    while any(remaining.values()):
        best_pi: Optional[Implicant] = None
        best_score: tuple[int, int, int] = (-1, -1, -1)

        for i, uncov in remaining.items():
            if not uncov:
                continue
            for pi in per_output_pis[i]:
                if pi in per_output_selected[i]:
                    continue
                if not any(pi.covers(m) for m in uncov):
                    continue
                # Total uncovered minterms this PI covers across all outputs it is valid for
                total_coverage = sum(
                    sum(1 for m in remaining[j] if pi.covers(m))
                    for j in remaining
                    if pi in per_output_pi_sets[j]
                )
                # Number of distinct outputs this PI provides NEW coverage for
                outputs_covered = sum(
                    1 for j in remaining
                    if pi in per_output_pi_sets[j] and any(pi.covers(m) for m in remaining[j])
                )
                already_selected = 1 if pi in globally_selected else 0
                score = (already_selected, outputs_covered, total_coverage)
                if score > best_score:
                    best_score = score
                    best_pi = pi

        if best_pi is None:
            break

        # Apply best_pi to ALL outputs where it is valid and provides coverage
        globally_selected.add(best_pi)
        for i in list(remaining.keys()):
            if best_pi in per_output_pi_sets[i] and any(best_pi.covers(m) for m in remaining[i]):
                per_output_selected[i].append(best_pi)
                remaining[i] -= {m for m in remaining[i] if best_pi.covers(m)}

    for i in per_output_pis:
        result[i] = per_output_selected[i]

    return result


# ---------------------------------------------------------------------------
# Expression formatter
# ---------------------------------------------------------------------------

def format_expression(
    implicants: list[Implicant],
    n: int,
    use_unicode: bool = True,
    var_names: Optional[list[str]] = None,
) -> str:
    """
    Convert a list of prime implicants to a human-readable SOP string.

    ``var_names[0]`` names the most significant of the ``n`` bits (bit
    ``n - 1``) and ``var_names[n - 1]`` the least significant.  A complemented
    literal is written as the variable followed by a combining macron
    (U+0304) when ``use_unicode`` is True, otherwise with a ``~`` prefix.

    Args:
        implicants: Product terms to render; only ``value`` and ``mask`` are read.
        n: Number of input variables.
        use_unicode: Overbar (True) or ``~`` prefix (False) for complements.
        var_names: Variable names, most significant first. Defaults to A..Z.

    Returns:
        ``"0"`` for an empty list; otherwise the terms joined by ``" + "``,
        ordered by descending literal count and then lexicographically.  A
        term with no care bits among the ``n`` variables renders as ``"1"``.

    Raises:
        IndexError: if a term has a care bit whose variable has no entry in
            ``var_names`` (e.g. ``n > 26`` with the default names).
    """
    if var_names is None:
        var_names = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")

    if not implicants:
        return "0"

    # (literal_count, rendered_term) pairs.  The count is recorded while the
    # term is built, so ordering stays correct for multi-character variable
    # names (counting letters of the rendered string would not).
    rendered: list[tuple[int, str]] = []
    for imp in implicants:
        literals: list[str] = []
        for depth in range(n):
            bit = 1 << (n - 1 - depth)
            if not imp.mask & bit:
                continue  # don't-care bit: variable absent from this product
            var = var_names[depth]
            if imp.value & bit:
                literals.append(var)
            elif use_unicode:
                literals.append(var + "̄")  # combining macron
            else:
                literals.append("~" + var)
        rendered.append((len(literals), "".join(literals) or "1"))

    # Sort for determinism: more literals first, then lexicographic.
    rendered.sort(key=lambda item: (-item[0], item[1]))
    return " + ".join(term for _, term in rendered)


# ---------------------------------------------------------------------------
# Semantic verification helper
# ---------------------------------------------------------------------------

def _verify_cover(
    selected: list[Implicant],
    minterms: list[int],
    n: int,
    dont_cares: Optional[list[int]] = None,
) -> tuple[bool, str]:
    dc = set(dont_cares or [])
    covered: set[int] = set()
    for pi in selected:
        for m in range(1 << n):
            if pi.covers(m):
                covered.add(m)
    required = set(minterms)
    if not required.issubset(covered):
        return False, f"Missing: {required - covered}"
    extra = covered - required - dc
    if extra:
        return False, f"Covers extra minterms not in on-set or dc-set: {extra}"
    return True, "OK"


# ---------------------------------------------------------------------------
# Test suite — single output
# ---------------------------------------------------------------------------

def _run_tests() -> tuple[int, int]:
    """
    Run the single-output regression suite, printing one PASS/FAIL line per case.

    Each case is minimized with ``minimize`` and checked in three ways:
    - the cover must be semantically exact (``_verify_cover``);
    - the term count must match ``expected_term_count`` when it is not None;
    - an optional ``extra_check`` callable ``(selected, minterms, n) -> str | None``
      may report one more issue.

    An exception raised by a case is printed with its traceback and counted as
    a failure; it does not abort the suite.

    Returns:
        ``(passes, failures)``.
    """
    tests = [
        # (name, minterms, dont_cares, n, overlapping, expected_term_count, extra_check)
        # ── Original tests (from Scheinman 1962 PDF examples) ──────────────
        ("2-var f(0,1,2)",           [0, 1, 2],                          [],      2, True,  None, None),
        ("4-var 12-minterm",         [1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14], [], 4, True, 4, None),
        ("4-var 11-minterm",         [0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 15], [],    4, True,  3,    None),
        ("4-var with don't-cares",   [2, 3, 10, 11, 12, 13, 14, 15],      [1, 6, 7], 4, True, 2, None),
        ("5-var",                    [1, 2, 6, 7, 9, 13, 14, 15, 17, 22, 23, 25, 29, 30, 31], [], 5, True, 4, None),
        ("non-overlapping 2-var",    [0, 1, 2],                          [],      2, False, None, None),
        ("single minterm",           [5],                                 [],      3, True,  1,    None),
        ("constant 0",               [],                                  [],      4, True,  0,    None),
        ("constant 1",               list(range(16)),                     [],      4, True,  1,    None),

        # ── Additional tests (sourced from online references) ────────────
        # SO-1  geeksforgeeks.org/digital-logic/quine-mccluskey-method/
        #   Expected: B'C' + A'D' + AD  (3 terms, verified)
        ("GeeksForGeeks 4-var",
         [0, 1, 2, 4, 6, 8, 9, 11, 13, 15], [], 4, True, 3, None),

        # SO-2  tutorialspoint.com/digital-electronics/quine-mccluskey-tabular-method.htm
        #   Expected: A'B'C' + A'BD + ACD'  (3 terms, verified)
        ("Tutorialspoint 4-var",
         [0, 1, 5, 7, 10, 14], [], 4, True, 3, None),

        # SO-3  elprocus.com/quine-mccluskey-method/
        #   Expected: CD' + AB' + AC  (3 terms, verified)
        ("Elprocus 4-var",
         [2, 6, 8, 9, 10, 11, 14, 15], [], 4, True, 3, None),

        # SO-4  en.wikipedia.org/wiki/Quine-McCluskey_algorithm
        #   Two equally-minimal covers: BC'D'+AC+AB' or BC'D'+AD'+AC  (3 terms each)
        ("Wikipedia QMC 4-var with don't-cares",
         [4, 8, 10, 11, 12, 15], [9, 14], 4, True, 3, None),

        # SO-5  askfilo.com (QMC tabular method worked example)
        #   Expected: A'C'D + BC' + BD'  (3 terms, verified)
        ("Askfilo 4-var with don't-cares",
         [1, 5, 6, 12, 13, 14], [2, 4], 4, True, 3, None),

        # ── From course notes (docs/1.pdf, docs/2.pdf) ──────────────────────────────
        # CN-1  Overlapping f(1,3,5,6,7): C covers {1,3,5,7}, AB covers {6,7} → C + AB
        ("2.pdf p1 3-var overlapping f(1,3,5,6,7)",
         [1, 3, 5, 6, 7], [], 3, True, 2, None),

        # CN-2  Same function non-overlapping: minterm 7 owned by exactly one PI → 2 terms
        ("2.pdf p1 3-var non-overlapping f(1,3,5,6,7)",
         [1, 3, 5, 6, 7], [], 3, False, 2, None),

        # CN-3  Trivial all-group: {12,13,14,15} → AB (1 term, A=1 B=1 CD-free)
        ("1.pdf p5 4-var f(12,13,14,15) → AB",
         [12, 13, 14, 15], [], 4, True, 1, None),

        # CN-4  D covers 8 of 9; {12} merges with {13} → ABC̄ → 2 terms total
        ("2.pdf p3 4-var f(1,3,5,7,9,11,12,13,15)",
         [1, 3, 5, 7, 9, 11, 12, 13, 15], [], 4, True, 2, None),

        # CN-5  Mixed density, no known term count — semantic check only
        ("2.pdf p4 4-var f(1,2,4,7,8,9,11,12,13,14)",
         [1, 2, 4, 7, 8, 9, 11, 12, 13, 14], [], 4, True, None, None),

        # CN-6  n=6: {0-3}→ĀB̄C̄D̄, {60-63}→ABCD → 2 terms (E,F free in each group)
        ("2.pdf p5 6-var f(0,1,2,3,60,61,62,63)",
         [0, 1, 2, 3, 60, 61, 62, 63], [], 6, True, 2, None),

        # CN-7  5-var arithmetic cell (2.pdf p12-14) — semantic only, both modes
        ("2.pdf p12 5-var S arithmetic overlapping",
         [4, 5, 6, 7, 12, 13, 14, 15, 17, 18, 20, 23, 24, 27, 29, 30], [], 5, True, None, None),

        ("2.pdf p14 5-var S arithmetic non-overlapping",
         [4, 5, 6, 7, 12, 13, 14, 15, 17, 18, 20, 23, 24, 27, 29, 30], [], 5, False, None, None),
    ]

    passes = 0
    failures = 0

    for idx, (name, minterms, dont_cares, n, overlapping, expected_terms, extra_check) in enumerate(tests, 1):
        try:
            selected = minimize(minterms, dont_cares, n, overlapping=overlapping)
            expr_str = format_expression(selected, n, use_unicode=True)
            ok, msg = _verify_cover(selected, minterms, n, dont_cares)

            issues: list[str] = []
            if not ok:
                issues.append(msg)
            if expected_terms is not None and len(selected) != expected_terms:
                issues.append(f"expected {expected_terms} terms, got {len(selected)}")
            # Optional case-specific check, same convention as the multi-output
            # suite: return an issue string, or None/"" when satisfied.
            if extra_check is not None:
                extra_issue = extra_check(selected, minterms, n)
                if extra_issue:
                    issues.append(extra_issue)

            if issues:
                print(f"[FAIL] Test {idx}: {name} -> {expr_str}  ({'; '.join(issues)})")
                failures += 1
            else:
                print(f"[PASS] Test {idx}: {name} -> {expr_str}")
                passes += 1

        except Exception as exc:  # noqa: BLE001
            import traceback
            print(f"[FAIL] Test {idx}: {name} -> EXCEPTION: {exc}")
            traceback.print_exc()
            failures += 1

    return passes, failures


# ---------------------------------------------------------------------------
# Test suite — multiple output
# ---------------------------------------------------------------------------

def _run_tests_multi() -> tuple[int, int]:
    """
    Run the multi-output regression suite, printing one PASS/FAIL line per case.

    Every case is minimized with ``minimize_multi``.  Each output's cover is
    then verified with ``_verify_cover``, and any case-specific checks
    ``(result, functions, n) -> str | None`` are applied.

    Returns:
        ``(passes, failures)``.
    """
    passes = 0
    failures = 0
    test_num = 0

    def run_test(name: str, functions: list[list[int]], dont_cares, n: int, checks) -> None:
        """Minimize one case, verify each output, apply ``checks``, and print the outcome.

        ``dont_cares`` is either a flat list of ints shared by every output or,
        judging by the ``isinstance`` test below, one list per output.
        """
        nonlocal passes, failures, test_num
        test_num += 1
        try:
            result = minimize_multi(functions, dont_cares, n)
            issues: list[str] = []

            # Semantic correctness: each output must cover its minterms exactly.
            # A per-output don't-care list applies to its own output only.
            # Ignoring it would flag legitimately used don't-cares as extras.
            shared_dc = bool(dont_cares) and isinstance(dont_cares[0], int)
            for i, fn in enumerate(functions):
                if not dont_cares:
                    dc_i: list[int] = []
                elif shared_dc:
                    dc_i = list(dont_cares)
                else:
                    dc_i = list(dont_cares[i]) if i < len(dont_cares) else []
                cover = result.get(i, [])
                ok, msg = _verify_cover(cover, fn, n, dc_i)
                if not ok:
                    issues.append(f"output {i}: {msg}")

            # Run any extra checks
            for check_fn in checks:
                check_issue = check_fn(result, functions, n)
                if check_issue:
                    issues.append(check_issue)

            exprs = {i: format_expression(result.get(i, []), n) for i in range(len(functions))}
            summary = "  ".join(f"f{i}={exprs[i]}" for i in range(len(functions)))

            if issues:
                print(f"[FAIL] Multi-Test {test_num}: {name} -> {summary}  ({'; '.join(issues)})")
                failures += 1
            else:
                print(f"[PASS] Multi-Test {test_num}: {name} -> {summary}")
                passes += 1

        except Exception as exc:  # noqa: BLE001
            import traceback
            print(f"[FAIL] Multi-Test {test_num}: {name} -> EXCEPTION: {exc}")
            traceback.print_exc()
            failures += 1

    # ------------------------------------------------------------------
    # Test 1 — PDF Example 1: 4-input, 3-output
    # ------------------------------------------------------------------
    fa = [1, 5, 7, 9, 13, 15]
    fb = [0, 1, 2, 3, 5, 7, 15]
    fg = [7, 8, 9, 10, 11, 13, 15]

    def check_t1_unique_pis(result, functions, n):
        all_pis: set[Implicant] = set()
        for cover in result.values():
            all_pis.update(cover)
        # Minimizing each output on its own yields 7 distinct PIs:
        # C̄D, BD, ĀB̄, ĀD, AB̄, AD and BCD (BCD is shared by fb and fg).
        # Using the shared implicants ĀC̄D, AC̄D, BCD plus ĀB̄ and AB̄ covers all
        # three outputs with 5, so more than 5 means sharing was missed.
        if len(all_pis) > 5:
            return f"unique PIs = {len(all_pis)}, expected <=5 (global sharing)"
        return None

    run_test(
        "PDF Example 1 (4-var, 3-output)",
        [fa, fb, fg],
        [],
        4,
        [check_t1_unique_pis],
    )

    # ------------------------------------------------------------------
    # Test 2 — Full Adder Sum & Carry (3-input, 2-output)
    # ------------------------------------------------------------------
    fsum = [1, 2, 4, 7]
    fcarry = [3, 5, 6, 7]

    run_test(
        "Full Adder Sum & Carry (3-var, 2-output)",
        [fsum, fcarry],
        [],
        3,
        [],
    )

    # ------------------------------------------------------------------
    # Test 3 — 2-output, 3-variable: B-bar and C-bar
    # ------------------------------------------------------------------
    f0_t3 = [0, 1, 4, 5]  # B=0 in 3-variable space (ABC: A=0/1, B=0, C=0/1)
    f1_t3 = [0, 2, 4, 6]  # C=0 in 3-variable space

    def check_t3_sharing(result, functions, n):
        # Minterms {0, 4} appear in both: verify both outputs get correct cover
        for i, fn in enumerate(functions):
            for m in fn:
                if not any(pi.covers(m) for pi in result[i]):
                    return f"output {i} misses minterm {m}"
        return None

    run_test(
        "B-bar and C-bar (3-var, 2-output)",
        [f0_t3, f1_t3],
        [],
        3,
        [check_t3_sharing],
    )

    # ------------------------------------------------------------------
    # Test 4 — Disjoint outputs (no sharing possible)
    # ------------------------------------------------------------------
    f0_t4 = [0, 1]  # A-bar in 2-variable space
    f1_t4 = [2, 3]  # A in 2-variable space

    def check_t4_no_sharing(result, functions, n):
        pis_0 = set(result[0])
        pis_1 = set(result[1])
        shared = pis_0 & pis_1
        if shared:
            return f"unexpected shared PIs between disjoint outputs: {shared}"
        return None

    def check_t4_two_terms(result, functions, n):
        all_pis: set[Implicant] = set()
        for cover in result.values():
            all_pis.update(cover)
        if len(all_pis) != 2:
            return f"expected 2 unique terms for disjoint outputs, got {len(all_pis)}"
        return None

    run_test(
        "Disjoint outputs A-bar and A (2-var, 2-output)",
        [f0_t4, f1_t4],
        [],
        2,
        [check_t4_no_sharing, check_t4_two_terms],
    )

    # ------------------------------------------------------------------
    # Test 5 — Complete sharing: identical functions
    # ------------------------------------------------------------------
    f_shared = [1, 3, 5, 7]  # C in 3-variable space (LSB = 1)

    def check_t5_single_shared(result, functions, n):
        all_pis: set[Implicant] = set()
        for cover in result.values():
            all_pis.update(cover)
        if len(all_pis) != 1:
            return f"identical outputs should share 1 unique PI, got {len(all_pis)}"
        return None

    run_test(
        "Identical outputs (3-var, 2-output, complete sharing)",
        [f_shared, f_shared],
        [],
        3,
        [check_t5_single_shared],
    )

    # ------------------------------------------------------------------
    # Test 6 — Backward compatibility: single output via minimize_multi
    # ------------------------------------------------------------------
    minterms_t6 = [0, 1, 2]
    n_t6 = 2

    def check_t6_compat(result, functions, n):
        single = minimize(minterms_t6, [], n_t6)
        multi = result[0]
        # Compare by expression string (canonical form)
        expr_single = format_expression(single, n)
        expr_multi = format_expression(multi, n)
        if expr_single != expr_multi:
            return f"single={expr_single!r} != multi={expr_multi!r}"
        return None

    run_test(
        "Backward compat: single output via minimize_multi",
        [minterms_t6],
        [],
        n_t6,
        [check_t6_compat],
    )

    # ------------------------------------------------------------------
    # Test 7 — Shared PI AB across two outputs (constructed, verified)
    #   F1 = {8,9,12,13,14,15}  minimal: AB + AC̄  (2 terms)
    #   F2 = {3,7,12,13,14,15}  minimal: ĀCD + AB  (2 terms)
    #   AB = {12,13,14,15} must appear in covers of both outputs
    # ------------------------------------------------------------------
    f0_t7 = [8, 9, 12, 13, 14, 15]
    f1_t7 = [3, 7, 12, 13, 14, 15]

    def check_t7_ab_shared(result, functions, n):
        # The PI AB (value=12, mask=12) must appear in both output covers
        ab_in_f0 = any(pi.value & pi.mask == 12 and pi.mask == 12 for pi in result[0])
        ab_in_f1 = any(pi.value & pi.mask == 12 and pi.mask == 12 for pi in result[1])
        if not ab_in_f0:
            return "PI AB missing from f0 cover"
        if not ab_in_f1:
            return "PI AB missing from f1 cover"
        return None

    def check_t7_term_count(result, functions, n):
        issues = []
        for i, expected in enumerate([2, 2]):
            if len(result[i]) != expected:
                issues.append(f"f{i}: expected {expected} terms, got {len(result[i])}")
        return "; ".join(issues) or None

    run_test(
        "Shared PI AB (4-var, 2-output, constructed)",
        [f0_t7, f1_t7],
        [],
        4,
        [check_t7_ab_shared, check_t7_term_count],
    )

    # ------------------------------------------------------------------
    # Test 8 — 2×2-bit multiplier (4-var, 4-output)
    #   Inputs AB × CD, output WXYZ (4-bit product)
    #   W = {15}            → ABCD        (1 term)
    #   X = {10,11,14}      → AB̄C + ACD̄  (2 terms)
    #   Y = {6,7,9,11,13,14}→ 4 terms
    #   Z = {5,7,13,15}     → BD          (1 term, shared with nothing)
    #   Source: MIT 6.111 tutorial synthesis examples
    # ------------------------------------------------------------------
    f_W = [15]
    f_X = [10, 11, 14]
    f_Y = [6, 7, 9, 11, 13, 14]
    f_Z = [5, 7, 13, 15]

    def check_t8_term_counts(result, functions, n):
        issues = []
        for i, (expected, label) in enumerate([(1, 'W'), (2, 'X'), (4, 'Y'), (1, 'Z')]):
            if len(result[i]) != expected:
                issues.append(f"f{i}({label}): expected {expected} terms, got {len(result[i])}")
        return "; ".join(issues) or None

    def check_t8_z_is_bd(result, functions, n):
        # Z = BD: value=5 (0101), mask=5 (bits B and D set)
        cover_z = result[3]
        if len(cover_z) != 1:
            return None  # term count check handles this
        pi = cover_z[0]
        if pi.value & pi.mask != 5 or pi.mask != 5:
            return f"Z expected BD (value&mask=5, mask=5), got value={pi.value} mask={pi.mask}"
        return None

    run_test(
        "2×2-bit multiplier (4-var, 4-output, MIT 6.111)",
        [f_W, f_X, f_Y, f_Z],
        [],
        4,
        [check_t8_term_counts, check_t8_z_is_bd],
    )

    # ------------------------------------------------------------------
    # Test 9 — 2-bit binary adder carry and LSB outputs (4-var, 2-output)
    #   AB + CD → carry X and sum LSB Z
    #   X (carry-out) = {7,10,11,13,14,15} → 3 terms
    #   Z (sum LSB)   = {1,3,4,6,9,11,12,14} → B'D + BD' (XOR, 2 terms)
    #   Source: MIT 6.111 tutorial synthesis examples
    # ------------------------------------------------------------------
    f_X2 = [7, 10, 11, 13, 14, 15]
    f_Z2 = [1, 3, 4, 6, 9, 11, 12, 14]

    def check_t9_term_counts(result, functions, n):
        issues = []
        for i, (expected, label) in enumerate([(3, 'carry'), (2, 'LSB')]):
            if len(result[i]) != expected:
                issues.append(f"f{i}({label}): expected {expected} terms, got {len(result[i])}")
        return "; ".join(issues) or None

    def check_t9_lsb_is_xor(result, functions, n):
        # Z LSB = B'D + BD': two 2-literal PIs that together implement XOR(B,D)
        cover_z = result[1]
        if len(cover_z) != 2:
            return None  # term count check handles this
        # Each PI must have mask with exactly 2 care bits (B and D, i.e. bits 2 and 0)
        for pi in cover_z:
            if pi.mask != 5:   # mask = 0101 = 5 means B and D are care bits
                return f"LSB PI has unexpected mask {pi.mask}, expected 5 (B and D)"
        return None

    run_test(
        "2-bit adder carry+LSB (4-var, 2-output, MIT 6.111)",
        [f_X2, f_Z2],
        [],
        4,
        [check_t9_term_counts, check_t9_lsb_is_xor],
    )

    return passes, failures


# ---------------------------------------------------------------------------
# Test suite — build_trace() structural tests
# ---------------------------------------------------------------------------

def _run_trace_tests() -> tuple[int, int]:
    """
    Verify build_trace() step structure, DFS order, and absorbed flags.

    Each check function returns a list of issue strings; an empty list means
    PASS.  An exception inside a check is printed with its traceback and
    counted as a failure.

    Returns:
        ``(passes, failures)``.
    """
    passes = 0
    failures = 0
    test_num = 0

    def run(name, check_fn):
        nonlocal passes, failures, test_num
        test_num += 1
        try:
            issues = check_fn()
            if issues:
                print(f"[FAIL] Trace-Test {test_num}: {name}  ({'; '.join(issues)})")
                failures += 1
            else:
                print(f"[PASS] Trace-Test {test_num}: {name}")
                passes += 1
        except Exception as exc:  # noqa: BLE001
            import traceback
            print(f"[FAIL] Trace-Test {test_num}: {name} -> EXCEPTION: {exc}")
            traceback.print_exc()
            failures += 1

    # TR-1  Σ(0,1,2) n=2 — basic step structure + root match pair
    def tr1():
        t = build_trace([0, 1, 2], [], 2, overlapping=True)
        issues = []
        if len(t.tree_steps) != 3:
            issues.append(f"expected 3 tree steps, got {len(t.tree_steps)}")
        if t.tree_steps:
            r = t.tree_steps[0]
            if r.depth != 0:
                issues.append(f"step 1 depth expected 0, got {r.depth}")
            if len(r.input_rows) != 3:
                issues.append(f"root input expected 3 rows, got {len(r.input_rows)}")
            if len(r.match_pairs) != 1:
                issues.append(f"root expected 1 match pair, got {len(r.match_pairs)}")
            if r.match_pairs:
                mp = r.match_pairs[0]
                if mp.left_value != 0 or mp.right_value_original != 2:
                    issues.append(
                        f"root match: expected left_value=0, right_value_original=2, "
                        f"got ({mp.left_value},{mp.right_value_original})")
        return issues

    run("Σ(0,1,2) n=2 — step structure and root match pair", tr1)

    # TR-2  Σ(1,3,5,6,7) n=3 overlapping — 2 root match pairs → 2 selected PIs
    def tr2():
        t = build_trace([1, 3, 5, 6, 7], [], 3, overlapping=True)
        # Report a missing root as an ordinary issue, as TR-4/TR-6 do.
        if not t.tree_steps:
            return ["no tree steps produced"]
        issues = []
        r = t.tree_steps[0]
        if len(r.input_rows) != 5:
            issues.append(f"root input expected 5, got {len(r.input_rows)}")
        if len(r.match_pairs) != 2:
            issues.append(f"root match pairs expected 2, got {len(r.match_pairs)}")
        if len(t.pi_chart.selected_pis) != 2:
            issues.append(f"expected 2 selected PIs, got {len(t.pi_chart.selected_pis)}")
        return issues

    run("Σ(1,3,5,6,7) n=3 overlapping — 2 root matches, 2 PIs", tr2)

    # TR-3  Σ(1,3,5,6,7) n=3 non-overlapping — left subtree of root pruned
    def tr3():
        t = build_trace([1, 3, 5, 6, 7], [], 3, overlapping=False)
        if not t.tree_steps:
            return ["no tree steps produced"]
        issues = []
        r = t.tree_steps[0]
        if len(r.match_pairs) != 2:
            issues.append(f"root match pairs expected 2, got {len(r.match_pairs)}")
        if r.child_step_left is not None:
            issues.append("root left subtree should be pruned (None) in non-overlapping mode")
        if len(t.pi_chart.selected_pis) != 2:
            issues.append(f"expected 2 selected PIs, got {len(t.pi_chart.selected_pis)}")
        return issues

    run("Σ(1,3,5,6,7) n=3 non-overlapping — left subtree pruned", tr3)

    # TR-4  Ancestor-absorbed flags present in child steps.
    #   Minimum cover is ĀD̄ + B̄C̄ + CD: each of these PIs is essential
    #   (for minterms 4, 8 and 15 respectively).
    def tr4():
        t = build_trace([0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 15], [], 4, overlapping=True)
        issues = []
        if not t.tree_steps:
            return ["no tree steps produced"]
        found = any(
            row.absorbed_at_depth
            for step in t.tree_steps if step.depth > 0
            for row in step.input_rows
        )
        if not found:
            issues.append("no child step has an ancestor-absorbed item in input_rows")
        ok, msg = _verify_cover(t.pi_chart.selected_pis, [0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 15], 4, [])
        if not ok:
            issues.append(f"cover invalid: {msg}")
        if len(t.pi_chart.selected_pis) != 3:
            issues.append(f"expected 3 PIs (ĀD̄+B̄C̄+CD), got {len(t.pi_chart.selected_pis)}")
        return issues

    run("4-var 11-minterm — absorbed ancestor flags in child steps", tr4)

    # TR-5  DFS order: matched child step < left and right children.
    #   NOTE(review): only "matched precedes each child" is asserted; the
    #   relative order of left vs right is not checked despite the label.
    def tr5():
        t = build_trace([0, 1, 2, 3, 4, 6, 7, 8, 9, 11, 15], [], 4, overlapping=True)
        issues = []
        for step in t.tree_steps:
            m, l, r = step.child_step_matched, step.child_step_left, step.child_step_right
            if m is not None and l is not None and m >= l:
                issues.append(f"step {step.step_index}: matched child {m} >= left child {l}")
            if m is not None and r is not None and m >= r:
                issues.append(f"step {step.step_index}: matched child {m} >= right child {r}")
        return issues

    run("4-var 11-minterm — DFS order (matched < left < right)", tr5)

    # TR-6  Don't-cares marked is_dc=True; cover correct (1.pdf p4 example)
    def tr6():
        t = build_trace([2, 3, 10, 11, 12, 13, 14, 15], [1, 6, 7], 4, overlapping=True)
        issues = []
        if not t.tree_steps:
            return ["no tree steps"]
        dc_rows = [r for r in t.tree_steps[0].input_rows if r.is_dc]
        if len(dc_rows) != 3:
            issues.append(f"root should have 3 dc rows (1,6,7), got {len(dc_rows)}")
        ok, msg = _verify_cover(t.pi_chart.selected_pis,
                                [2, 3, 10, 11, 12, 13, 14, 15], 4, [1, 6, 7])
        if not ok:
            issues.append(f"cover invalid: {msg}")
        if len(t.pi_chart.selected_pis) != 2:
            issues.append(f"expected 2 PIs (C+AB), got {len(t.pi_chart.selected_pis)}")
        return issues

    run("4-var don't-care example — dc rows flagged, cover correct", tr6)

    # TR-7  mask field populated on all rows
    def tr7():
        t = build_trace([1, 3, 5, 6, 7], [], 3, overlapping=True)
        issues = []
        for step in t.tree_steps:
            for row in step.input_rows + step.left_rows + step.right_rows:
                if row.mask == 0 and not row.is_dc:
                    issues.append(
                        f"step {step.step_index}: mask=0 for non-dc item value={row.display_value}")
        return issues

    run("Σ(1,3,5,6,7) n=3 — mask field populated on all rows", tr7)

    return passes, failures


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _run_cli(argv: Optional[list[str]] = None) -> None:
    """
    Command-line entry point.

    Two modes share one parser:

    * single-output: ``N MINTERMS [--dont-cares LIST] [--non-overlapping] [--ascii]``
    * multi-output:  ``--multi N --fn MINTERMS [--fn MINTERMS ...]
      [--dont-cares LIST] [--non-overlapping] [--ascii]``

    Multi-output mode is chosen when ``--multi`` appears in the arguments,
    either as ``--multi N`` or ``--multi=N``.  Lists are comma-separated
    integers.  A malformed entry is reported through argparse (usage message
    and exit status 2) instead of a traceback.

    Args:
        argv: Arguments without the program name; defaults to ``sys.argv[1:]``.
    """
    import argparse

    p = argparse.ArgumentParser(prog="scheinman", description="Scheinman Boolean minimization")

    def parse_ints(text: str, what: str) -> list[int]:
        # Empty fields are skipped, so "" and "1,2," are accepted.
        try:
            return [int(x) for x in text.split(",") if x.strip()]
        except ValueError:
            p.error(f"{what}: expected comma-separated integers, got {text!r}")

    raw = argv if argv is not None else sys.argv[1:]
    # The mode must be decided before parsing because the two modes declare
    # different positionals; accept both "--multi N" and "--multi=N".
    multi_mode = any(a == "--multi" or a.startswith("--multi=") for a in raw)

    if multi_mode:
        p.add_argument("--multi", type=int, metavar="N", required=True,
                       help="Number of variables (multi-output mode)")
        p.add_argument("--fn", action="append", default=[], metavar="MINTERMS",
                       help="Comma-separated minterms for one output function (repeat per output)")
        p.add_argument("--dont-cares", default="", dest="dont_cares")
        p.add_argument("--non-overlapping", action="store_true")
        p.add_argument("--ascii", action="store_true",
                       help="Write complements as '~X' instead of an overbar")
        args = p.parse_args(raw)

        n = args.multi
        functions = [parse_ints(fn_str, "--fn") for fn_str in args.fn]
        dont_cares = parse_ints(args.dont_cares, "--dont-cares")
        result = minimize_multi(functions, dont_cares, n, overlapping=not args.non_overlapping)

        for i in range(len(functions)):
            expr = format_expression(result.get(i, []), n, use_unicode=not args.ascii)
            print(f"f{i}: {expr}")
    else:
        p.add_argument("n", type=int, help="Number of variables")
        p.add_argument("minterms", help="Comma-separated minterm indices, e.g. 1,2,3")
        p.add_argument("--dont-cares", default="", dest="dont_cares")
        p.add_argument("--non-overlapping", action="store_true")
        p.add_argument("--ascii", action="store_true",
                       help="Write complements as '~X' instead of an overbar")
        args = p.parse_args(raw)

        minterms = parse_ints(args.minterms, "minterms")
        dont_cares = parse_ints(args.dont_cares, "--dont-cares")
        result_single = minimize(minterms, dont_cares, args.n, overlapping=not args.non_overlapping)
        print(format_expression(result_single, args.n, use_unicode=not args.ascii))


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Invoked with no command-line arguments at all: run every built-in test
    # suite and exit non-zero if any case failed.  Otherwise defer to the CLI.
    if len(sys.argv) != 1:
        _run_cli()
    else:
        suites = (
            ("=== Single-output tests ===", _run_tests),
            ("=== Multi-output tests ===", _run_tests_multi),
            ("=== Trace tests ===", _run_trace_tests),
        )
        total_passes = 0
        total_failures = 0
        for banner, run_suite in suites:
            print(banner)
            suite_passes, suite_failures = run_suite()
            print()
            total_passes += suite_passes
            total_failures += suite_failures
        print(f"{total_passes}/{total_passes + total_failures} tests passed.")
        if total_failures:
            sys.exit(1)
