From 86ce0593d7de438b4c9a283d32afd60ba15f7d21 Mon Sep 17 00:00:00 2001 From: Misty-Star <144589976+Misty-Star@users.noreply.github.com> Date: Wed, 25 Mar 2026 16:02:43 +0800 Subject: [PATCH 1/2] fix: preload onnxruntime cuda dependencies --- .../separator/architectures/mdx_separator.py | 12 +++++++ audio_separator/separator/separator.py | 18 +++++++++++ audio_separator/utils/cli.py | 8 +++-- tests/unit/test_cli.py | 11 +++++-- tests/unit/test_gpu_runtime_setup.py | 31 +++++++++++++++++++ 5 files changed, 76 insertions(+), 4 deletions(-) create mode 100644 tests/unit/test_gpu_runtime_setup.py diff --git a/audio_separator/separator/architectures/mdx_separator.py b/audio_separator/separator/architectures/mdx_separator.py index 0516d3bc..88717896 100644 --- a/audio_separator/separator/architectures/mdx_separator.py +++ b/audio_separator/separator/architectures/mdx_separator.py @@ -120,6 +120,18 @@ def load_model(self): ort_session_options.log_severity_level = 0 ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options) + session_providers = ort_inference_session.get_providers() + + requested_provider = self.onnx_execution_provider[0] if self.onnx_execution_provider else None + if requested_provider and requested_provider not in session_providers: + self.logger.warning( + f"ONNX Runtime could not activate requested provider {requested_provider}; " + f"session is using {session_providers}. This usually means required CUDA/cuDNN " + f"runtime libraries are not visible to the dynamic loader." + ) + else: + self.logger.debug(f"ONNX Runtime session providers: {session_providers}") + self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0] self.logger.debug("Model loaded successfully using ONNXruntime inferencing session.") else: diff --git a/audio_separator/separator/separator.py b/audio_separator/separator/separator.py index 83024881..0f7d82e8 100644 --- a/audio_separator/separator/separator.py +++ b/audio_separator/separator/separator.py @@ -321,8 +321,26 @@ def setup_accelerated_inferencing_device(self): system_info = self.get_system_info() self.check_ffmpeg_installed() self.log_onnxruntime_packages() + self.preload_onnxruntime_dependencies() self.setup_torch_device(system_info) + def preload_onnxruntime_dependencies(self): + """ + Preload ONNX Runtime shared library dependencies when supported by the installed package. + + This helps pip-installed CUDA/cuDNN runtime wheels become visible to ONNX Runtime before + the first CUDAExecutionProvider session is created. + """ + if not hasattr(ort, "preload_dlls"): + self.logger.debug("Installed ONNX Runtime does not provide preload_dlls(); skipping dependency preload.") + return + + try: + ort.preload_dlls() + self.logger.debug("Preloaded ONNX Runtime shared library dependencies.") + except Exception as exc: + self.logger.warning(f"Unable to preload ONNX Runtime shared library dependencies: {exc}") + def get_system_info(self): """ This method logs the system information, including the operating system, CPU archutecture and Python version diff --git a/audio_separator/utils/cli.py b/audio_separator/utils/cli.py index c45fcaf2..dbf8f515 100755 --- a/audio_separator/utils/cli.py +++ b/audio_separator/utils/cli.py @@ -150,13 +150,13 @@ def main(): log_level = getattr(logging, args.log_level.upper()) logger.setLevel(log_level) - from audio_separator.separator import Separator - if args.env_info: + from audio_separator.separator import Separator separator = Separator() sys.exit(0) if args.list_models: + from audio_separator.separator import Separator separator = Separator(info_only=True) if args.list_format == "json": @@ -190,6 +190,7 @@ def main(): sys.exit(0) if args.list_presets: + from audio_separator.separator import Separator separator = Separator(info_only=True) presets = separator.list_ensemble_presets() @@ -217,6 +218,7 @@ def main(): sys.exit(0) if args.download_model_only: + from audio_separator.separator import Separator models_to_download = [args.model_filename] + (args.extra_models or []) separator = Separator(log_formatter=log_formatter, log_level=log_level, model_file_dir=args.model_file_dir) for model in models_to_download: @@ -233,6 +235,8 @@ def main(): logger.info(f"Separator version {package_version} beginning with input path(s): {', '.join(audio_files)}") + from audio_separator.separator import Separator + separator = Separator( log_formatter=log_formatter, log_level=log_level, diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index cae45184..72bb610a 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -60,8 +60,15 @@ def test_cli_version_subprocess(): # Test the CLI with no arguments def test_cli_no_args(capsys): - # Skip subprocess CLI tests - require proper CLI installation - pytest.skip("CLI subprocess tests require proper installation") + test_args = ["cli.py"] + + with patch("sys.argv", test_args), patch.dict("sys.modules", {"audio_separator.separator": None}): + with pytest.raises(SystemExit) as exc_info: + main() + + assert exc_info.value.code == 1 + captured = capsys.readouterr() + assert "Separate audio file into different stems." in captured.out # Test with multiple filename arguments diff --git a/tests/unit/test_gpu_runtime_setup.py b/tests/unit/test_gpu_runtime_setup.py new file mode 100644 index 00000000..b42402ea --- /dev/null +++ b/tests/unit/test_gpu_runtime_setup.py @@ -0,0 +1,31 @@ +from unittest.mock import patch + +from audio_separator.separator.separator import Separator + + +def test_setup_accelerated_inferencing_device_preloads_onnxruntime_dependencies(): + separator = Separator(info_only=True) + system_info = object() + + with patch.object(separator, "get_system_info", return_value=system_info), patch.object(separator, "check_ffmpeg_installed"), patch.object( + separator, "log_onnxruntime_packages" + ), patch("audio_separator.separator.separator.ort.preload_dlls") as mock_preload, patch.object(separator, "setup_torch_device") as mock_setup: + separator.setup_accelerated_inferencing_device() + + mock_preload.assert_called_once_with() + mock_setup.assert_called_once_with(system_info) + + +def test_setup_accelerated_inferencing_device_continues_when_preload_fails(): + separator = Separator(info_only=True) + system_info = object() + + with patch.object(separator, "get_system_info", return_value=system_info), patch.object(separator, "check_ffmpeg_installed"), patch.object( + separator, "log_onnxruntime_packages" + ), patch("audio_separator.separator.separator.ort.preload_dlls", side_effect=RuntimeError("boom")), patch.object( + separator, "setup_torch_device" + ) as mock_setup, patch.object(separator.logger, "warning") as mock_warning: + separator.setup_accelerated_inferencing_device() + + mock_setup.assert_called_once_with(system_info) + mock_warning.assert_called_once() From 51f5d6a2402e2c32558d148498e64551cd351c5a Mon Sep 17 00:00:00 2001 From: Misty-Star <144589976+Misty-Star@users.noreply.github.com> Date: Wed, 25 Mar 2026 19:48:18 +0800 Subject: [PATCH 2/2] test: allow preload_dlls patch on older ort --- tests/unit/test_gpu_runtime_setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_gpu_runtime_setup.py b/tests/unit/test_gpu_runtime_setup.py index b42402ea..276c9193 100644 --- a/tests/unit/test_gpu_runtime_setup.py +++ b/tests/unit/test_gpu_runtime_setup.py @@ -9,7 +9,7 @@ def test_setup_accelerated_inferencing_device_preloads_onnxruntime_dependencies( with patch.object(separator, "get_system_info", return_value=system_info), patch.object(separator, "check_ffmpeg_installed"), patch.object( separator, "log_onnxruntime_packages" - ), patch("audio_separator.separator.separator.ort.preload_dlls") as mock_preload, patch.object(separator, "setup_torch_device") as mock_setup: + ), patch("audio_separator.separator.separator.ort.preload_dlls", create=True) as mock_preload, patch.object(separator, "setup_torch_device") as mock_setup: separator.setup_accelerated_inferencing_device() mock_preload.assert_called_once_with() @@ -22,7 +22,7 @@ def test_setup_accelerated_inferencing_device_continues_when_preload_fails(): with patch.object(separator, "get_system_info", return_value=system_info), patch.object(separator, "check_ffmpeg_installed"), patch.object( separator, "log_onnxruntime_packages" - ), patch("audio_separator.separator.separator.ort.preload_dlls", side_effect=RuntimeError("boom")), patch.object( + ), patch("audio_separator.separator.separator.ort.preload_dlls", side_effect=RuntimeError("boom"), create=True), patch.object( separator, "setup_torch_device" ) as mock_setup, patch.object(separator.logger, "warning") as mock_warning: separator.setup_accelerated_inferencing_device()