"""
Flask web application for the Scheinman Boolean minimization tool.

NOTE: This app binds to port 80, which requires elevated privileges on Linux.
      Run with: sudo python web/app.py
      For development without sudo: python web/app.py --dev (still port 80, use 8080 if needed)
      To use a different port, set the PORT environment variable.
"""

import sys
import os
import logging
import argparse
from typing import Any

# ---------------------------------------------------------------------------
# Path fix: import scheinman from ../python/
# ---------------------------------------------------------------------------
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'python'))
from scheinman import minimize, minimize_multi, format_expression  # noqa: E402

from flask import Flask, render_template, request, jsonify

# The visualize module is imported on first use rather than at startup,
# because it pulls in matplotlib, which is slow to import.
_visualize = None


def _get_visualize():
    """Return the ``visualize`` module, importing it on the first call.

    Before the first import, the directory containing this file is put at the
    front of ``sys.path`` so a ``visualize`` module next to it can be found.
    The imported module is cached in ``_visualize`` and returned on later calls.
    """
    global _visualize
    if _visualize is not None:
        return _visualize
    here = os.path.dirname(__file__)
    sys.path.insert(0, here)
    import visualize as viz_module
    _visualize = viz_module
    return _visualize

# Module-level logger and the Flask application object the routes attach to.
log = logging.getLogger(__name__)
app = Flask(__name__)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _parse_int_list(raw) -> list[int]:
    """Accept a JSON list of ints, or a comma-separated string. Returns [] for blank."""
    if isinstance(raw, list):
        return [int(x) for x in raw]
    raw = str(raw).strip()
    if not raw:
        return []
    parts = [p.strip() for p in raw.split(',') if p.strip()]
    return [int(p) for p in parts]


def _validate_n(n_raw: Any) -> int:
    n = int(n_raw)
    if not (1 <= n <= 20):
        raise ValueError(f"Number of variables must be between 1 and 20, got {n}")
    return n


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------

@app.route('/')
def index() -> str:
    """Render and return the ``index.html`` template for the root URL."""
    page = render_template('index.html')
    return page


@app.route('/api/minimize', methods=['POST'])
def api_minimize():
    """
    Single-output minimization endpoint.

    JSON body (must be a JSON object):
        n           : int
        minterms    : str (comma-separated) or list[int]
        dont_cares  : str (comma-separated) or list[int], optional
        overlapping : bool (default true; the strings "true"/"1"/"yes"/"on"
                      are read as true, any other string as false)

    Returns:
        { expression, term_count, variables } on success
        { error } with status 400 for invalid input, 500 for unexpected failures
    """
    # silent=True makes malformed JSON return None instead of raising Flask's
    # default (non-JSON) 400 response, so clients always get a JSON error body.
    # A JSON value that is not an object would otherwise fail on data.get()
    # and surface as a 500.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    try:
        n = _validate_n(data.get('n', ''))
        minterms = _parse_int_list(data.get('minterms', []))
        dont_cares = _parse_int_list(data.get('dont_cares', []))
        overlapping = data.get('overlapping', True)
        if isinstance(overlapping, str):
            # bool("false") is True, so interpret string spellings explicitly.
            overlapping = overlapping.strip().lower() in ('true', '1', 'yes', 'on')
        overlapping = bool(overlapping)

        selected = minimize(minterms, dont_cares, n, overlapping=overlapping)
        expression = format_expression(selected, n, use_unicode=True)

        return jsonify({
            'expression': expression,
            'term_count': len(selected),
            'variables': n,
        })
    except (ValueError, TypeError) as exc:
        log.warning("minimize error: %s", exc)
        return jsonify({'error': str(exc)}), 400
    except Exception:
        # Full details go to the server log; the exception text is not echoed
        # to the client because it can expose internals.
        log.exception("Unexpected error in /api/minimize")
        return jsonify({'error': 'Internal error'}), 500


@app.route('/api/minimize-multi', methods=['POST'])
def api_minimize_multi():
    """
    Multi-output minimization endpoint.

    JSON body (must be a JSON object):
        n           : int
        functions   : list, each entry a comma-separated str or list[int]
        dont_cares  : str (comma-separated) or list[int], shared, optional
        overlapping : bool (default true; the strings "true"/"1"/"yes"/"on"
                      are read as true, any other string as false)

    Returns:
        { outputs: [{ expression, term_count }], unique_pi_count, variables } on success
        { error } with status 400 for invalid input, 500 for unexpected failures
    """
    # silent=True: malformed JSON yields None rather than Flask's non-JSON 400
    # page; a non-object body would otherwise crash on data.get() with a 500.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    try:
        n = _validate_n(data.get('n', ''))
        raw_functions = data.get('functions', [])
        if not raw_functions:
            raise ValueError("At least one output function is required")
        # A bare string would otherwise be iterated character by character.
        if not isinstance(raw_functions, list):
            raise ValueError("'functions' must be a JSON list")
        functions = [_parse_int_list(f) for f in raw_functions]
        dont_cares = _parse_int_list(data.get('dont_cares', []))
        overlapping = data.get('overlapping', True)
        if isinstance(overlapping, str):
            # bool("false") is True, so interpret string spellings explicitly.
            overlapping = overlapping.strip().lower() in ('true', '1', 'yes', 'on')
        overlapping = bool(overlapping)

        result = minimize_multi(functions, dont_cares, n, overlapping=overlapping)

        outputs = []
        # Prime implicants shared between outputs are counted once, keyed by
        # (value & mask, mask) rather than object identity.
        unique_pis: set = set()
        for i in range(len(functions)):
            cover = result.get(i, [])
            unique_pis.update((pi.value & pi.mask, pi.mask) for pi in cover)
            outputs.append({
                'expression': format_expression(cover, n, use_unicode=True),
                'term_count': len(cover),
            })

        return jsonify({
            'outputs': outputs,
            'unique_pi_count': len(unique_pis),
            'variables': n,
        })
    except (ValueError, TypeError) as exc:
        log.warning("minimize_multi error: %s", exc)
        return jsonify({'error': str(exc)}), 400
    except Exception:
        # Details stay in the server log; don't echo exception text to clients.
        log.exception("Unexpected error in /api/minimize-multi")
        return jsonify({'error': 'Internal error'}), 500


# ---------------------------------------------------------------------------
# Visualization endpoint
# ---------------------------------------------------------------------------

@app.route('/api/visualize', methods=['POST'])
def api_visualize():
    """
    Generate step-by-step PNG images for a minimization run.

    JSON body (must be a JSON object): the fields of /api/minimize (mode
    "single") or /api/minimize-multi (mode "multi"), plus:
        mode : "single" | "multi"  (default "single")

    Guards: n <= 4, and in multi mode between 1 and 3 output functions.
    Returns: the payload from visualize.build_visualize_payload as JSON
    (documented as { steps: [{ title, image }] } with image a base64 PNG).
    """
    # silent=True: malformed JSON yields None rather than Flask's non-JSON 400
    # page; a non-object body would otherwise crash on data.get() with a 500.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    try:
        n = _validate_n(data.get('n', ''))
        if n > 4:
            return jsonify({'error': 'Visualization is only available for n ≤ 4 variables'}), 400

        # Missing/null mode still means "single"; any other unknown value is
        # rejected instead of being silently treated as single-output.
        mode = data.get('mode') or 'single'
        if mode not in ('single', 'multi'):
            raise ValueError(f"Unknown mode {mode!r}; expected 'single' or 'multi'")
        dont_cares = _parse_int_list(data.get('dont_cares', []))
        overlapping = data.get('overlapping', True)
        if isinstance(overlapping, str):
            # bool("false") is True, so interpret string spellings explicitly.
            overlapping = overlapping.strip().lower() in ('true', '1', 'yes', 'on')
        overlapping = bool(overlapping)

        if mode == 'multi':
            raw_functions = data.get('functions', [])
            if not raw_functions:
                raise ValueError("At least one output function is required")
            # A string would pass len() checks and be iterated per character.
            if not isinstance(raw_functions, list):
                raise ValueError("'functions' must be a JSON list")
            if len(raw_functions) > 3:
                return jsonify(
                    {'error': 'Visualization is only available for ≤ 3 output functions'}
                ), 400
            minterms = []
            extra = {'functions': [_parse_int_list(f) for f in raw_functions]}
        else:
            minterms = _parse_int_list(data.get('minterms', []))
            extra = {}

        # Import the (matplotlib-backed) visualize module only once the input
        # has been validated.
        viz = _get_visualize()
        payload = viz.build_visualize_payload(
            minterms, dont_cares, n, overlapping, **extra
        )
        return jsonify(payload)
    except (ValueError, TypeError) as exc:
        log.warning("visualize error: %s", exc)
        return jsonify({'error': str(exc)}), 400
    except Exception:
        # Details stay in the server log; don't echo exception text to clients.
        log.exception("Unexpected error in /api/visualize")
        return jsonify({'error': 'Internal error'}), 500


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Scheinman Boolean Minimizer web server')
    parser.add_argument('--dev', action='store_true',
                        help='Enable Flask debug mode (development only)')
    # argparse runs a string default through type=int, so a malformed PORT
    # environment variable yields a usage error instead of a traceback.
    parser.add_argument('--port', type=int,
                        default=os.environ.get('PORT', '80'),
                        help='Port to bind (default: $PORT or 80; ports below '
                             '1024 require root on Linux)')
    parser.add_argument('--host', default=None,
                        help='Interface to bind (default: 0.0.0.0, or '
                             '127.0.0.1 with --dev)')
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.dev else logging.INFO,
        format='%(asctime)s %(levelname)s %(name)s: %(message)s',
    )

    # --dev turns on Flask debug mode, whose interactive debugger can execute
    # arbitrary code, so by default it is only reachable from this machine.
    host = args.host or ('127.0.0.1' if args.dev else '0.0.0.0')
    if args.dev and host not in ('127.0.0.1', 'localhost', '::1'):
        log.warning(
            "Debug mode is enabled on non-loopback host %s; the interactive "
            "debugger is reachable from the network.", host
        )

    # Privileged ports need root on Linux whether or not --dev is set.
    # os.geteuid only exists on POSIX systems.
    geteuid = getattr(os, 'geteuid', None)
    if 0 < args.port < 1024 and geteuid is not None and geteuid() != 0:
        log.warning(
            "Binding to port %d requires root privileges on Linux. "
            "Run with sudo, or choose another port with --port.", args.port
        )

    app.run(host=host, port=args.port, debug=args.dev)
