import uuid, shutil, subprocess, sys, socket
from pathlib import Path
from flask import Flask, request, jsonify, send_from_directory

# All server state lives next to this script.
APP_DIR = Path(__file__).resolve().parent
RUNS_DIR = APP_DIR.joinpath("runs")  # one working folder per job: upload, logs, demucs output
STATIC_DIR = APP_DIR.joinpath("static")  # served at /static; finished stems are copied under jobs/
PORT_FILE = APP_DIR.joinpath("server_port.txt")  # the port picked at startup is recorded here

# Make sure both folders exist before any request handler needs them.
for _needed in (RUNS_DIR, STATIC_DIR):
    _needed.mkdir(parents=True, exist_ok=True)
del _needed

app = Flask(__name__, static_url_path="/static", static_folder=str(STATIC_DIR))

def ffmpeg_exe() -> str:
    """Return the ffmpeg executable path reported by ``imageio_ffmpeg``.

    The import is deferred to call time, so this module can be imported even
    before ``imageio_ffmpeg`` is needed.
    """
    from imageio_ffmpeg import get_ffmpeg_exe

    return get_ffmpeg_exe()

def convert_to_wav(src: Path, dst: Path, log_path: Path) -> None:
    """Transcode *src* into a stereo, 44.1 kHz WAV file at *dst* using ffmpeg.

    The full command line plus ffmpeg's stdout and stderr are always written to
    *log_path* (UTF-8), whether or not the conversion succeeds.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status. The message carries
            the return code and the tail of ffmpeg's stderr, so callers that only
            surface the message (e.g. to a UI) still show the real cause.
    """
    exe = ffmpeg_exe()
    # -y: overwrite dst without asking; -ac 2 / -ar 44100: stereo at 44.1 kHz.
    cmd = [exe, "-y", "-i", str(src), "-ac", "2", "-ar", "44100", str(dst)]
    proc = subprocess.run(
        cmd,
        # ffmpeg enables interaction on standard input by default; never let it
        # inherit (and read from) the server process's stdin.
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        errors="replace",
    )
    log_path.write_text(
        "FFMPEG CMD:\n" + " ".join(cmd) + "\n\nSTDOUT:\n" + (proc.stdout or "") + "\n\nSTDERR:\n" + (proc.stderr or ""),
        encoding="utf-8",
        errors="replace",
    )
    if proc.returncode != 0:
        # ffmpeg reports its errors at the end of stderr; keep the message bounded.
        tail_err = (proc.stderr or "")[-4000:]
        raise RuntimeError(
            f"ffmpeg convert failed (return code {proc.returncode}, see {log_path.name})"
            f"\n\n--- ffmpeg stderr (tail) ---\n{tail_err}"
        )

def run_demucs(
    input_wav: Path,
    out_dir: Path,
    stdout_log: Path,
    stderr_log: Path,
    model: str = "htdemucs",
) -> Path:
    """Run ``demucs_runner.py`` on *input_wav* and return the folder holding its stems.

    The runner is started with the current interpreter as
    ``demucs_runner.py -n <model> -o <out_dir> <input_wav>``. Its stdout and stderr
    are written to *stdout_log* / *stderr_log* (UTF-8) in all cases.

    Args:
        input_wav: Audio file to separate.
        out_dir: Output root passed to the runner via ``-o``.
        stdout_log: Where to save the runner's stdout.
        stderr_log: Where to save the runner's stderr.
        model: Model name passed via ``-n``. This is also the name of the
            sub-folder of *out_dir* that is searched for results.

    Returns:
        A sub-folder of ``out_dir / model``. The folder named after
        ``input_wav``'s stem is preferred when it exists. Otherwise the most
        recently modified sub-folder is returned.

    Raises:
        RuntimeError: if the runner exits non-zero (the message includes the tails
            of both streams), or if no output sub-folder is found.
    """
    cmd = [sys.executable, str(APP_DIR / "demucs_runner.py"), "-n", model, "-o", str(out_dir), str(input_wav)]
    proc = subprocess.run(
        cmd,
        stdin=subprocess.DEVNULL,  # the runner has no reason to read the server's stdin
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        errors="replace",
    )
    stdout_log.write_text(proc.stdout or "", encoding="utf-8", errors="replace")
    stderr_log.write_text(proc.stderr or "", encoding="utf-8", errors="replace")

    if proc.returncode != 0:
        tail_out = (proc.stdout or "")[-4000:]
        tail_err = (proc.stderr or "")[-4000:]
        raise RuntimeError(
            "demucs failed (return code %s)\n\n--- demucs stdout (tail) ---\n%s\n\n--- demucs stderr (tail) ---\n%s\n\n(Full logs saved in runs/<job>/demucs_stdout.log and demucs_stderr.log)"
            % (proc.returncode, tail_out, tail_err)
        )

    model_dir = out_dir / model
    # Presumably the runner names the track folder after the input file's stem
    # (verify against demucs_runner.py); prefer that folder when it is present.
    named = model_dir / input_wav.stem
    if named.is_dir():
        return named

    subdirs = [d for d in model_dir.iterdir() if d.is_dir()] if model_dir.exists() else []
    if not subdirs:
        raise RuntimeError("Demucs output folder not found in: " + str(out_dir))
    # Fallback: the newest sub-folder is the one this run produced.
    return max(subdirs, key=lambda d: d.stat().st_mtime)

@app.get("/")
def home():
    """Serve the single-page UI, ``index.html`` from the application folder."""
    app_root = str(APP_DIR)
    return send_from_directory(app_root, "index.html")

def _upload_suffix(filename: str) -> str:
    """Return a safe lower-case extension (e.g. ``".mp3"``) for an uploaded filename.

    Only the final extension is kept, and only when it is 1-10 ASCII letters or
    digits. Anything else yields ``""``, so the result never contains a path
    separator or a ``..`` component.
    """
    # Treat backslashes as separators too, so Windows-style client paths are
    # handled the same way on every server OS.
    suffix = Path(filename.replace("\\", "/")).suffix.lower()
    ext = suffix[1:]
    if ext and len(ext) <= 10 and ext.isascii() and ext.isalnum():
        return suffix
    return ""


@app.post("/api/separate")
def separate():
    """Split an uploaded audio file into vocals/drums/bass/other stems.

    Expects a multipart form field ``file``. On success, returns JSON of the form
    ``{"job_id": ..., "stems": {name: url}}``. Each URL points to a copy of the
    stem under ``/static/jobs/<job_id>/``.

    A missing upload or an empty filename gets a 400 with a short message. Any
    failure during conversion or separation gets a 500 whose plain-text body is
    the error message. The ffmpeg/demucs logs stay in ``runs/<job_id>/``.
    """
    if "file" not in request.files:
        return "No file uploaded", 400
    f = request.files["file"]
    if not f.filename:
        return "No filename", 400

    job_id = uuid.uuid4().hex[:10]
    job_dir = RUNS_DIR / job_id
    job_dir.mkdir(parents=True, exist_ok=True)

    # Never use the client-supplied filename as a path. It may contain "../" or be
    # absolute, which escapes job_dir. It may also collide with this job's own files:
    # an upload named "input.wav" would make ffmpeg's input and output the same file.
    # Keep only a vetted extension as a format hint.
    src_path = job_dir / ("upload" + _upload_suffix(f.filename))
    f.save(src_path)

    ffmpeg_log = job_dir / "ffmpeg.log"
    demucs_out_log = job_dir / "demucs_stdout.log"
    demucs_err_log = job_dir / "demucs_stderr.log"

    wav_path = job_dir / "input.wav"
    out_dir = job_dir / "out"
    out_dir.mkdir(parents=True, exist_ok=True)

    try:
        convert_to_wav(src_path, wav_path, ffmpeg_log)
        track_dir = run_demucs(wav_path, out_dir, demucs_out_log, demucs_err_log)

        expected = {
            name: track_dir / f"{name}.wav"
            for name in ("vocals", "drums", "bass", "other")
        }
        missing = [k for k, p in expected.items() if not p.exists()]
        if missing:
            found = [x.name for x in track_dir.glob("*.wav")]
            raise RuntimeError("Missing stems: " + ", ".join(missing) + "\nFound: " + ", ".join(found))

        # Copy the stems under the static folder so Flask serves them directly.
        public_dir = STATIC_DIR / "jobs" / job_id
        if public_dir.exists():
            shutil.rmtree(public_dir)
        public_dir.mkdir(parents=True, exist_ok=True)

        out_urls = {}
        for k, p in expected.items():
            shutil.copy2(p, public_dir / f"{k}.wav")
            out_urls[k] = f"/static/jobs/{job_id}/{k}.wav"

        return jsonify({"job_id": job_id, "stems": out_urls})

    except Exception as e:
        # Top-level boundary for the job. Record the traceback server-side, then return
        # the message to the UI as plain text. It can echo tool output, so it must not
        # be interpreted as HTML. The full logs stay in job_dir.
        app.logger.exception("Job %s failed", job_id)
        return str(e), 500, {"Content-Type": "text/plain; charset=utf-8"}

def pick_port(preferred=8080, tries=50, host="127.0.0.1") -> int:
    """Return the first TCP port in ``[preferred, preferred + tries)`` bindable on *host*.

    This is a best-effort probe. The probe socket is closed before returning, so
    another process could still take the port before the server binds it.

    Raises:
        RuntimeError: if no port in the range can be bound (including ``tries <= 0``).
    """
    for port in range(preferred, preferred + tries):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            # On POSIX, SO_REUSEADDR only lets the bind ignore connections left in
            # TIME_WAIT. On Windows it permits binding a port another socket is
            # actively using, which would make every port look free, so skip it there.
            if not sys.platform.startswith("win"):
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind((host, port))
            except OSError:
                continue
            return port
    raise RuntimeError(f"No free port found in range [{preferred}, {preferred + tries})")

if __name__ == "__main__":
    port = pick_port(8080, tries=50)
    # Best effort: record the chosen port (presumably read by a launcher; confirm).
    # The server still starts if the file cannot be written, but say so instead of
    # failing silently.
    try:
        PORT_FILE.write_text(str(port), encoding="utf-8")
    except OSError as e:
        print(f"Warning: could not write {PORT_FILE}: {e}", file=sys.stderr)
    print(f"Starting AI Stem Mixer on http://localhost:{port}")
    app.run(host="127.0.0.1", port=port, debug=False, use_reloader=False)
