Bandit-v2 - Инференс

Материал из wolfram
Версия от 07:37, 16 января 2026; Владимир (обсуждение | вклад) (Новая страница: « = Запуск инференса модели Bandit-v2 для аудио сепарации = == Суть задачи == В данной статье описывается процесс настройки и запуска инференса нейросетевой модели '''Bandit-v2''' из репозитория [https://github.com/kwatcharasupat/bandit-v2 kwatcharasupat/bandit-v2] для задачи '''Cinematic Audio Source Separation'''...»)
(разн.) ← Предыдущая версия | Текущая версия (разн.) | Следующая версия → (разн.)
Перейти к навигации Перейти к поиску

Запуск инференса модели Bandit-v2 для аудио сепарации

Суть задачи

В данной статье описывается процесс настройки и запуска инференса нейросетевой модели Bandit-v2 из репозитория kwatcharasupat/bandit-v2 для задачи Cinematic Audio Source Separation — разделения аудиодорожки на составляющие: речь (speech), музыку (music) и звуковые эффекты (sfx).

Проект изначально разработан для внутренней инфраструктуры Netflix и содержит зависимости от закрытых пакетов, что делает прямой запуск невозможным. В статье представлено рабочее решение для запуска инференса на локальной машине с GPU.

Результат работы: скрипт simple_inference.py, позволяющий обрабатывать аудиофайлы любой длительности со скоростью ~17x realtime на RTX 4090.

О проекте Bandit-v2

Bandit (Band-Split RNN) — архитектура нейросети для разделения аудио на источники. Основные особенности:

  • Band-Split подход — спектрограмма делится на 64 частотных диапазона (bands), каждый обрабатывается отдельно
  • Dual-path RNN — последовательная обработка по временной и частотной осям с использованием GRU
  • Mask Estimation — для каждого stem (речь/музыка/sfx) предсказывается комплексная маска

Технические характеристики модели:

Параметр Значение
Sample rate 48000 Hz
n_bands 64
n_fft 2048
hop_length 512
emb_dim 128
rnn_dim 256
rnn_type GRU (bidirectional)
n_sqm_modules 8

Веса модели доступны на Zenodo: checkpoint-multi.ckpt

Процесс формирования рабочего решения

Проблема 1: Netflix-специфичные зависимости

Файл requirements.txt содержит ссылки на приватный PyPI сервер Netflix:

--index-url https://pypi.netflix.net/simple

И пакеты: nflx-manta, nflx-metaflow, jasper, storage, metatron и др.

Решение: Установка только минимально необходимых публичных пакетов для инференса.

Проблема 2: Ray зависимость

Модуль src/system/utils.py импортирует Ray (распределённые вычисления), который не нужен для локального инференса.

Решение: Создан отдельный скрипт simple_inference.py, который напрямую импортирует модель без системных обёрток.

Проблема 3: Устаревший PyTorch

В requirements указан torch==2.0.0+cu118, который больше недоступен в PyPI.

Решение: Установлен torch==2.5.0+cu118.

Проблема 4: Ошибка сборки pesq

Пакет asteroid тянет зависимость pesq, требующую компиляции C++.

Решение: Пакет asteroid не устанавливается — он нужен только для метрик обучения, не для инференса.

Проблема 5: Чрезмерное потребление VRAM

Оригинальный StandardTensorChunkedInferenceHandler с batch_size=32 и overlap 7 секунд (8с chunk, 1с hop) создаёт буфер на ~50GB.

Решение: Реализован простой chunked подход с последовательной обработкой чанков (30с chunk, 2с overlap), потребление VRAM снижено до ~2GB.

Проблема 6: Stereo обработка

Модель обучена с in_channels=1 (mono). При подаче stereo возникает ошибка LayerNorm.

Решение: Каждый канал stereo обрабатывается отдельно, результаты объединяются.

Проблема 7: Half precision (FP16)

Прямая конвертация модели в .half() вызывает ошибку типов в LayerNorm.

Решение: Используется torch.cuda.amp.autocast() для автоматического mixed precision.

Исходный код

Шаблон:Спойлер

Описание simple_inference.py

Скрипт состоит из трёх основных функций:

create_model()

Создаёт экземпляр модели Bandit с параметрами, соответствующими checkpoint:

  • in_channels=1 — mono (stereo обрабатывается поканально)
  • n_bands=64, n_sqm_modules=8
  • n_fft=2048, hop_length=512
  • fs=48000

simple_chunked_inference()

Обрабатывает аудио чанками с crossfade:

  1. Аудио делится на чанки (по умолчанию 30 секунд)
  2. Соседние чанки перекрываются (по умолчанию 2 секунды)
  3. В зоне перекрытия применяется линейный crossfade
  4. Каждый чанк обрабатывается отдельно, результат сразу переносится на CPU
  5. GPU кэш очищается после каждого чанка

run_inference()

Основная функция:

  1. Загружает модель и веса из checkpoint
  2. Загружает аудио (с ресемплингом если нужно)
  3. Для stereo — обрабатывает каждый канал отдельно
  4. Сохраняет результаты в отдельные WAV файлы для каждого stem

Особенности реализации:

  • Autocast mixed precision для ускорения без проблем с типами
  • Минимальное потребление VRAM (~2GB на 30-секундный чанк)
  • Поддержка произвольной длительности аудио
  • Сохранение stereo (не конвертируется в mono)

Инструкция по установке

Требования

  • Windows 10/11 или Linux
  • NVIDIA GPU с поддержкой CUDA 11.8
  • Anaconda или Miniconda
  • ~3GB свободного места для PyTorch

Шаг 1: Создание conda окружения

conda create -n bandit-v2 python=3.10 -y
conda activate bandit-v2

Шаг 2: Установка PyTorch с CUDA 11.8

Шаблон:Внимание

pip install torch==2.5.0+cu118 torchaudio==2.5.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118

Шаг 3: Установка зависимостей

Шаблон:Внимание

pip install "pytorch_lightning>=2.3.0" hydra-core omegaconf librosa soundfile einops tqdm julius huggingface-hub pyyaml scipy pandas

Шаг 4: Клонирование репозитория

git clone https://github.com/kwatcharasupat/bandit-v2.git
cd bandit-v2

Шаг 5: Скачивание весов модели

Скачайте checkpoint-multi.ckpt с Zenodo и поместите в удобное место.

Шаг 6: Добавление simple_inference.py

Скопируйте скрипт simple_inference.py (см. раздел "Исходный код") в корень репозитория bandit-v2/.

Инструкция по запуску

Активация окружения

conda activate bandit-v2
cd путь/к/bandit-v2

Базовый запуск

python simple_inference.py -c "путь/к/checkpoint-multi.ckpt" -a "путь/к/аудио.wav"

Полный синтаксис

python simple_inference.py \
    --checkpoint "путь/к/checkpoint-multi.ckpt" \
    --audio "путь/к/аудио.wav" \
    --output "путь/к/выходной/папке" \
    --chunk 30 \
    --overlap 2 \
    --stems speech music sfx
Параметр По умолчанию Описание
-c, --checkpoint (обязательный) Путь к файлу весов .ckpt
-a, --audio (обязательный) Путь к входному аудиофайлу
-o, --output ./estimates/ Папка для выходных файлов
--chunk 30.0 Размер чанка в секундах
--overlap 2.0 Размер перекрытия в секундах
--stems speech music sfx Список stem для разделения
--fs 48000 Sample rate (менять не рекомендуется)
--no-half False Отключить mixed precision

Пример вывода

GPU: NVIDIA GeForce RTX 4090
VRAM: 22.5 GB

=== Config ===
Chunk: 30.0s, Overlap: 2.0s
Half precision: True

=== Loading Model ===
Model loaded

=== Loading Audio ===
Duration: 420.5s, Channels: 2, SR: 48000

=== Inference ===
Processing channel 1/2...
Processing: 100%|████████████████| 15/15 [00:12<00:00, 1.20it/s]
Processing channel 2/2...
Processing: 100%|████████████████| 15/15 [00:12<00:00, 1.21it/s]

Done in 24.9s (16.9x realtime)
Peak VRAM: 1.99 GB

=== Saving ===
  speech_estimate.wav
  music_estimate.wav
  sfx_estimate.wav

Complete!

Выходные файлы

В указанной папке создаются три файла:

  • speech_estimate.wav — речь/диалоги
  • music_estimate.wav — музыка
  • sfx_estimate.wav — звуковые эффекты

Все файлы сохраняются с тем же sample rate (48000 Hz) и количеством каналов, что и входной файл.