Bandit-v2 - Инференс
Запуск инференса модели Bandit-v2 для аудио сепарации
Суть задачи
В данной статье описывается процесс настройки и запуска инференса нейросетевой модели Bandit-v2 из репозитория kwatcharasupat/bandit-v2 для задачи Cinematic Audio Source Separation — разделения аудиодорожки на составляющие: речь (speech), музыку (music) и звуковые эффекты (sfx).
Проект изначально разработан для внутренней инфраструктуры Netflix и содержит зависимости от закрытых пакетов, что делает прямой запуск невозможным. В статье представлено рабочее решение для запуска инференса на локальной машине с GPU.
Результат работы: скрипт simple_inference.py, позволяющий обрабатывать аудиофайлы любой длительности со скоростью ~17x realtime на RTX 4090.
О проекте Bandit-v2
Bandit (Band-Split RNN) — архитектура нейросети для разделения аудио на источники. Основные особенности:
- Band-Split подход — спектрограмма делится на 64 частотных диапазона (bands), каждый обрабатывается отдельно
- Dual-path RNN — последовательная обработка по временной и частотной осям с использованием GRU
- Mask Estimation — для каждого stem (речь/музыка/sfx) предсказывается комплексная маска
Технические характеристики модели:
| Параметр | Значение |
|---|---|
| Sample rate | 48000 Hz |
| n_bands | 64 |
| n_fft | 2048 |
| hop_length | 512 |
| emb_dim | 128 |
| rnn_dim | 256 |
| rnn_type | GRU (bidirectional) |
| n_sqm_modules | 8 |
Веса модели доступны на Zenodo: checkpoint-multi.ckpt
Процесс формирования рабочего решения
Проблема 1: Netflix-специфичные зависимости
Файл requirements.txt содержит ссылки на приватный PyPI сервер Netflix:
--index-url https://pypi.netflix.net/simple
И пакеты: nflx-manta, nflx-metaflow, jasper, storage, metatron и др.
Решение: Установка только минимально необходимых публичных пакетов для инференса.
Проблема 2: Ray зависимость
Модуль src/system/utils.py импортирует Ray (распределённые вычисления), который не нужен для локального инференса.
Решение: Создан отдельный скрипт simple_inference.py, который напрямую импортирует модель без системных обёрток.
Проблема 3: Устаревший PyTorch
В requirements указан torch==2.0.0+cu118, который больше недоступен в PyPI.
Решение: Установлен torch==2.5.0+cu118.
Проблема 4: Ошибка сборки pesq
Пакет asteroid тянет зависимость pesq, требующую компиляции C++.
Решение: Пакет asteroid не устанавливается — он нужен только для метрик обучения, не для инференса.
Проблема 5: Чрезмерное потребление VRAM
Оригинальный StandardTensorChunkedInferenceHandler с batch_size=32 и overlap 7 секунд (8с chunk, 1с hop) создаёт буфер на ~50GB.
Решение: Реализован простой chunked подход с последовательной обработкой чанков (30с chunk, 2с overlap), потребление VRAM снижено до ~2GB.
Проблема 6: Stereo обработка
Модель обучена с in_channels=1 (mono). При подаче stereo возникает ошибка LayerNorm.
Решение: Каждый канал stereo обрабатывается отдельно, результаты объединяются.
Проблема 7: Half precision (FP16)
Прямая конвертация модели в .half() вызывает ошибку типов в LayerNorm.
Решение: Используется torch.cuda.amp.autocast() для автоматического mixed precision.
Исходный код
| simple_inference.py |
|---|
"""
Memory-efficient inference for Bandit-v2
Processes chunks one at a time to minimize VRAM usage
"""
import os
import sys
import time
import torch
import torchaudio as ta
from tqdm import tqdm
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.models.bandit.bandit import Bandit
def create_model(stems, fs=48000):
"""Create Bandit model"""
model = Bandit(
in_channels=1, # Model trained on mono, but uses treat_channel_as_feature
stems=stems,
band_type="musical",
n_bands=64,
normalize_channel_independently=False,
treat_channel_as_feature=True,
n_sqm_modules=8,
emb_dim=128,
rnn_dim=256,
bidirectional=True,
rnn_type="GRU",
mlp_dim=512,
hidden_activation="Tanh",
hidden_activation_kwargs=None,
complex_mask=True,
use_freq_weights=True,
n_fft=2048,
win_length=2048,
hop_length=512,
window_fn="hann_window",
wkwargs=None,
power=None,
center=True,
normalized=True,
pad_mode="reflect",
onesided=True,
fs=fs,
)
return model
def simple_chunked_inference(
model,
audio, # (channels, samples)
fs=48000,
chunk_seconds=30.0, # Much larger chunks, less overlap
overlap_seconds=2.0, # Small overlap for crossfade
device="cuda",
use_half=True,
):
"""
Simple chunk-based inference with crossfade overlap
Processes ONE chunk at a time to minimize VRAM
"""
n_channels, n_samples = audio.shape
chunk_samples = int(chunk_seconds * fs)
overlap_samples = int(overlap_seconds * fs)
hop_samples = chunk_samples - overlap_samples
# Create output buffers (on CPU to save VRAM)
stems = model.stems
outputs = {stem: torch.zeros(n_channels, n_samples) for stem in stems}
# Create crossfade window
fade_in = torch.linspace(0, 1, overlap_samples)
fade_out = torch.linspace(1, 0, overlap_samples)
# Calculate chunks
n_chunks = max(1, (n_samples - overlap_samples) // hop_samples + 1)
print(f"Processing {n_chunks} chunks of {chunk_seconds}s with {overlap_seconds}s overlap")
for i in tqdm(range(n_chunks), desc="Processing"):
start = i * hop_samples
end = min(start + chunk_samples, n_samples)
# Get chunk
chunk = audio[:, start:end]
actual_len = chunk.shape[1]
# Pad if needed
if actual_len < chunk_samples:
chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - actual_len))
# Process on GPU
chunk_gpu = chunk[None, :, :].to(device)
with torch.inference_mode():
with torch.cuda.amp.autocast(enabled=use_half):
batch = {"mixture": {"audio": chunk_gpu}}
result = model(batch)
# Get results back to CPU immediately
for stem in stems:
stem_audio = result["estimates"][stem]["audio"][0, :, :actual_len].float().cpu()
# Apply crossfade for overlap regions
if i == 0:
# First chunk - no fade in
outputs[stem][:, start:start+actual_len] = stem_audio
else:
# Apply crossfade in overlap region
overlap_start = start
overlap_end = start + overlap_samples
if overlap_end <= n_samples:
# Fade out previous, fade in current
outputs[stem][:, overlap_start:overlap_end] *= fade_out
outputs[stem][:, overlap_start:overlap_end] += stem_audio[:, :overlap_samples] * fade_in
# Copy rest without fade
if overlap_samples < actual_len:
outputs[stem][:, overlap_end:start+actual_len] = stem_audio[:, overlap_samples:actual_len]
else:
outputs[stem][:, start:start+actual_len] = stem_audio
# Clear GPU cache
del chunk_gpu, result
torch.cuda.empty_cache()
return outputs
def run_inference(
checkpoint_path: str,
audio_path: str,
output_dir: str = None,
stems: list = None,
fs: int = 48000,
device: str = "cuda",
use_half: bool = True,
chunk_seconds: float = 30.0,
overlap_seconds: float = 2.0,
):
if stems is None:
stems = ["speech", "music", "sfx"]
if output_dir is None:
output_dir = os.path.join(os.path.dirname(audio_path), "estimates")
os.makedirs(output_dir, exist_ok=True)
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
print(f"\n=== Config ===")
print(f"Chunk: {chunk_seconds}s, Overlap: {overlap_seconds}s")
print(f"Half precision: {use_half}")
# Load model
print(f"\n=== Loading Model ===")
model = create_model(stems=stems, fs=fs)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)
cleaned = {k.replace("model.", "", 1) if k.startswith("model.") else k: v
for k, v in state_dict.items()}
model.load_state_dict(cleaned, strict=False)
model.to(device)
# Note: Keep model in FP32, autocast handles mixed precision during inference
model.eval()
print("Model loaded")
# Load audio
print(f"\n=== Loading Audio ===")
audio, audio_fs = ta.load(audio_path)
duration = audio.shape[1] / audio_fs
print(f"Duration: {duration:.1f}s, Channels: {audio.shape[0]}, SR: {audio_fs}")
if audio_fs != fs:
print(f"Resampling {audio_fs} -> {fs}")
audio = ta.functional.resample(audio, audio_fs, fs)
n_channels = audio.shape[0]
# Run inference
print(f"\n=== Inference ===")
torch.cuda.reset_peak_memory_stats()
t0 = time.time()
# Process each channel separately to preserve stereo
all_outputs = []
for ch in range(n_channels):
if n_channels > 1:
print(f"\nProcessing channel {ch+1}/{n_channels}...")
audio_ch = audio[ch:ch+1, :] # Keep dim: (1, samples)
ch_outputs = simple_chunked_inference(
model, audio_ch, fs=fs,
chunk_seconds=chunk_seconds,
overlap_seconds=overlap_seconds,
device=device,
use_half=use_half,
)
all_outputs.append(ch_outputs)
# Combine channels
outputs = {}
for stem in model.stems:
outputs[stem] = torch.cat([ch_out[stem] for ch_out in all_outputs], dim=0)
elapsed = time.time() - t0
peak_vram = torch.cuda.max_memory_allocated() / 1024**3
print(f"\nDone in {elapsed:.1f}s ({duration/elapsed:.1f}x realtime)")
print(f"Peak VRAM: {peak_vram:.2f} GB")
# Save
print(f"\n=== Saving ===")
for stem, audio_out in outputs.items():
path = os.path.join(output_dir, f"{stem}_estimate.wav")
ta.save(path, audio_out, fs)
print(f" {path}")
print("\nComplete!")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", "-c", required=True)
parser.add_argument("--audio", "-a", required=True)
parser.add_argument("--output", "-o", default=None)
parser.add_argument("--stems", nargs="+", default=["speech", "music", "sfx"])
parser.add_argument("--fs", type=int, default=48000)
parser.add_argument("--device", default="cuda")
parser.add_argument("--no-half", action="store_true")
parser.add_argument("--chunk", type=float, default=30.0, help="Chunk size in seconds")
parser.add_argument("--overlap", type=float, default=2.0, help="Overlap in seconds")
args = parser.parse_args()
run_inference(
checkpoint_path=args.checkpoint,
audio_path=args.audio,
output_dir=args.output,
stems=args.stems,
fs=args.fs,
device=args.device,
use_half=not args.no_half,
chunk_seconds=args.chunk,
overlap_seconds=args.overlap,
)
|
Описание simple_inference.py
Скрипт состоит из трёх основных функций:
create_model()
Создаёт экземпляр модели Bandit с параметрами, соответствующими checkpoint:
in_channels=1— mono (stereo обрабатывается поканально)n_bands=64,n_sqm_modules=8n_fft=2048,hop_length=512fs=48000
simple_chunked_inference()
Обрабатывает аудио чанками с crossfade:
- Аудио делится на чанки (по умолчанию 30 секунд)
- Соседние чанки перекрываются (по умолчанию 2 секунды)
- В зоне перекрытия применяется линейный crossfade
- Каждый чанк обрабатывается отдельно, результат сразу переносится на CPU
- GPU кэш очищается после каждого чанка
run_inference()
Основная функция:
- Загружает модель и веса из checkpoint
- Загружает аудио (с ресемплингом если нужно)
- Для stereo — обрабатывает каждый канал отдельно
- Сохраняет результаты в отдельные WAV файлы для каждого stem
Особенности реализации:
- Autocast mixed precision для ускорения без проблем с типами
- Минимальное потребление VRAM (~2GB на 30-секундный чанк)
- Поддержка произвольной длительности аудио
- Сохранение stereo (не конвертируется в mono)
Инструкция по установке
Требования
- Windows 10/11 или Linux
- NVIDIA GPU с поддержкой CUDA 11.8
- Anaconda или Miniconda
- ~3GB свободного места для PyTorch
Шаг 1: Создание conda окружения
conda create -n bandit-v2 python=3.10 -y conda activate bandit-v2
Шаг 2: Установка PyTorch с CUDA 11.8
Внимание: Версия PyTorch 2.0.0 из оригинального requirements.txt больше недоступна. Используйте 2.5.0 или новее.
pip install torch==2.5.0+cu118 torchaudio==2.5.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
Шаг 3: Установка зависимостей
Внимание: Пакет asteroid требует сборки pesq с Visual C++ Build Tools. Для инференса он не нужен, поэтому не устанавливаем.
pip install "pytorch_lightning>=2.3.0" hydra-core omegaconf librosa soundfile einops tqdm julius huggingface-hub pyyaml scipy pandas
Шаг 4: Клонирование репозитория
git clone https://github.com/kwatcharasupat/bandit-v2.git cd bandit-v2
Шаг 5: Скачивание весов модели
Скачайте checkpoint-multi.ckpt с Zenodo и поместите в удобное место.
Шаг 6: Добавление simple_inference.py
Скопируйте скрипт simple_inference.py (см. раздел "Исходный код") в корень репозитория bandit-v2/.
Инструкция по запуску
Активация окружения
conda activate bandit-v2 cd путь/к/bandit-v2
Базовый запуск
python simple_inference.py -c "путь/к/checkpoint-multi.ckpt" -a "путь/к/аудио.wav"
Полный синтаксис
python simple_inference.py \
--checkpoint "путь/к/checkpoint-multi.ckpt" \
--audio "путь/к/аудио.wav" \
--output "путь/к/выходной/папке" \
--chunk 30 \
--overlap 2 \
--stems speech music sfx
| Параметр | По умолчанию | Описание |
|---|---|---|
-c, --checkpoint
|
(обязательный) | Путь к файлу весов .ckpt |
-a, --audio
|
(обязательный) | Путь к входному аудиофайлу |
-o, --output
|
./estimates/ | Папка для выходных файлов |
--chunk
|
30.0 | Размер чанка в секундах |
--overlap
|
2.0 | Размер перекрытия в секундах |
--stems
|
speech music sfx | Список stem для разделения |
--fs
|
48000 | Sample rate (менять не рекомендуется) |
--no-half
|
False | Отключить mixed precision |
Пример вывода
GPU: NVIDIA GeForce RTX 4090 VRAM: 22.5 GB === Config === Chunk: 30.0s, Overlap: 2.0s Half precision: True === Loading Model === Model loaded === Loading Audio === Duration: 420.5s, Channels: 2, SR: 48000 === Inference === Processing channel 1/2... Processing: 100%|████████████████| 15/15 [00:12<00:00, 1.20it/s] Processing channel 2/2... Processing: 100%|████████████████| 15/15 [00:12<00:00, 1.21it/s] Done in 24.9s (16.9x realtime) Peak VRAM: 1.99 GB === Saving === speech_estimate.wav music_estimate.wav sfx_estimate.wav Complete!
Выходные файлы
В указанной папке создаются три файла:
speech_estimate.wav— речь/диалогиmusic_estimate.wav— музыкаsfx_estimate.wav— звуковые эффекты
Все файлы сохраняются с тем же sample rate (48000 Hz) и количеством каналов, что и входной файл.