Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions audio_separator/separator/architectures/mdx_separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ def load_model(self):
ort_session_options.log_severity_level = 0

ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
session_providers = ort_inference_session.get_providers()

requested_provider = self.onnx_execution_provider[0] if self.onnx_execution_provider else None
if requested_provider and requested_provider not in session_providers:
self.logger.warning(
f"ONNX Runtime could not activate requested provider {requested_provider}; "
f"session is using {session_providers}. This usually means required CUDA/cuDNN "
f"runtime libraries are not visible to the dynamic loader."
)
else:
self.logger.debug(f"ONNX Runtime session providers: {session_providers}")

self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.")
else:
Expand Down
18 changes: 18 additions & 0 deletions audio_separator/separator/separator.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,26 @@ def setup_accelerated_inferencing_device(self):
system_info = self.get_system_info()
self.check_ffmpeg_installed()
self.log_onnxruntime_packages()
self.preload_onnxruntime_dependencies()
self.setup_torch_device(system_info)

def preload_onnxruntime_dependencies(self):
"""
Preload ONNX Runtime shared library dependencies when supported by the installed package.

This helps pip-installed CUDA/cuDNN runtime wheels become visible to ONNX Runtime before
the first CUDAExecutionProvider session is created.
"""
if not hasattr(ort, "preload_dlls"):
self.logger.debug("Installed ONNX Runtime does not provide preload_dlls(); skipping dependency preload.")
return

try:
ort.preload_dlls()
self.logger.debug("Preloaded ONNX Runtime shared library dependencies.")
except Exception as exc:
self.logger.warning(f"Unable to preload ONNX Runtime shared library dependencies: {exc}")

def get_system_info(self):
"""
This method logs the system information, including the operating system, CPU archutecture and Python version
Expand Down
8 changes: 6 additions & 2 deletions audio_separator/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,13 @@ def main():
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)

from audio_separator.separator import Separator

if args.env_info:
from audio_separator.separator import Separator
separator = Separator()
sys.exit(0)

if args.list_models:
from audio_separator.separator import Separator
separator = Separator(info_only=True)

if args.list_format == "json":
Expand Down Expand Up @@ -190,6 +190,7 @@ def main():
sys.exit(0)

if args.list_presets:
from audio_separator.separator import Separator
separator = Separator(info_only=True)
presets = separator.list_ensemble_presets()

Expand Down Expand Up @@ -217,6 +218,7 @@ def main():
sys.exit(0)

if args.download_model_only:
from audio_separator.separator import Separator
models_to_download = [args.model_filename] + (args.extra_models or [])
separator = Separator(log_formatter=log_formatter, log_level=log_level, model_file_dir=args.model_file_dir)
for model in models_to_download:
Expand All @@ -233,6 +235,8 @@ def main():

logger.info(f"Separator version {package_version} beginning with input path(s): {', '.join(audio_files)}")

from audio_separator.separator import Separator

separator = Separator(
log_formatter=log_formatter,
log_level=log_level,
Expand Down
11 changes: 9 additions & 2 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,15 @@ def test_cli_version_subprocess():

# Test the CLI with no arguments
def test_cli_no_args(capsys):
# Skip subprocess CLI tests - require proper CLI installation
pytest.skip("CLI subprocess tests require proper installation")
test_args = ["cli.py"]

with patch("sys.argv", test_args), patch.dict("sys.modules", {"audio_separator.separator": None}):
with pytest.raises(SystemExit) as exc_info:
main()

assert exc_info.value.code == 1
captured = capsys.readouterr()
assert "Separate audio file into different stems." in captured.out


# Test with multiple filename arguments
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/test_gpu_runtime_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from unittest.mock import patch

from audio_separator.separator.separator import Separator


def test_setup_accelerated_inferencing_device_preloads_onnxruntime_dependencies():
separator = Separator(info_only=True)
system_info = object()

with patch.object(separator, "get_system_info", return_value=system_info), patch.object(separator, "check_ffmpeg_installed"), patch.object(
separator, "log_onnxruntime_packages"
), patch("audio_separator.separator.separator.ort.preload_dlls", create=True) as mock_preload, patch.object(separator, "setup_torch_device") as mock_setup:
separator.setup_accelerated_inferencing_device()

mock_preload.assert_called_once_with()
mock_setup.assert_called_once_with(system_info)


def test_setup_accelerated_inferencing_device_continues_when_preload_fails():
separator = Separator(info_only=True)
system_info = object()

with patch.object(separator, "get_system_info", return_value=system_info), patch.object(separator, "check_ffmpeg_installed"), patch.object(
separator, "log_onnxruntime_packages"
), patch("audio_separator.separator.separator.ort.preload_dlls", side_effect=RuntimeError("boom"), create=True), patch.object(
separator, "setup_torch_device"
) as mock_setup, patch.object(separator.logger, "warning") as mock_warning:
separator.setup_accelerated_inferencing_device()

mock_setup.assert_called_once_with(system_info)
mock_warning.assert_called_once()
Loading