"""Run Demucs while bypassing torchaudio's torchcodec-based load/save.

Newer torchaudio versions route both torchaudio.load() and torchaudio.save() through torchcodec.
On some Windows/Python setups torchcodec is unavailable or its DLLs fail to load, which breaks Demucs.

We patch:
- torchaudio.load  -> soundfile.read (returns torch.Tensor [channels, frames], sr)
- torchaudio.save  -> soundfile.write (accepts torch.Tensor [channels, frames])

Since our app pre-converts inputs to WAV and we output WAV stems, soundfile is sufficient.

Usage:
  python demucs_runner.py -n htdemucs -o <out_dir> <input.wav>
"""

import runpy
import sys

def _patch_torchaudio():
    try:
        import torchaudio  # noqa
        import soundfile as sf
        import torch
        import numpy as np

        def sf_load(path, *args, **kwargs):
            data, sr = sf.read(path, always_2d=True, dtype="float32")
            data = np.transpose(data, (1, 0))  # [frames, ch] -> [ch, frames]
            wav = torch.from_numpy(data)
            return wav, sr

        def sf_save(path, src, sample_rate, **kwargs):
            # src: torch.Tensor [channels, frames] or [frames] mono
            if hasattr(src, "detach"):
                src = src.detach().cpu()
            if isinstance(src, torch.Tensor):
                arr = src.numpy()
            else:
                arr = np.asarray(src)

            if arr.ndim == 1:
                arr = arr.reshape(1, -1)
            # [ch, frames] -> [frames, ch]
            arr = np.transpose(arr, (1, 0)).astype("float32", copy=False)

            # Map bits_per_sample when possible
            bits = kwargs.get("bits_per_sample", None)
            subtype = None
            if bits == 16:
                subtype = "PCM_16"
            elif bits == 24:
                subtype = "PCM_24"
            elif bits == 32:
                # if encoding says "PCM_S", use PCM_32; otherwise float
                subtype = "PCM_32" if kwargs.get("encoding") in ("PCM_S", "PCM_U") else "FLOAT"
            # Default: let soundfile choose (usually FLOAT for float32 input)
            sf.write(path, arr, sample_rate, subtype=subtype)

        torchaudio.load = sf_load
        torchaudio.save = sf_save
    except Exception:
        pass

# Install the soundfile shims first, so they are in place before Demucs runs.
_patch_torchaudio()

# Hand our command-line arguments to Demucs unchanged, swapping only argv[0],
# then execute the module as if it had been started as the main program.
sys.argv = ["demucs.separate", *sys.argv[1:]]
runpy.run_module("demucs.separate", run_name="__main__")
