Bandit-v2 - Инференс: различия между версиями
Владимир (обсуждение | вклад) Новая страница: « = Запуск инференса модели Bandit-v2 для аудио сепарации = == Суть задачи == В данной статье описывается процесс настройки и запуска инференса нейросетевой модели '''Bandit-v2''' из репозитория [https://github.com/kwatcharasupat/bandit-v2 kwatcharasupat/bandit-v2] для задачи '''Cinematic Audio Source Separation'''...» |
Владимир (обсуждение | вклад) |
||
| (не показана 1 промежуточная версия этого же участника) | |||
| Строка 1: | Строка 1: | ||
= Запуск инференса модели Bandit-v2 для аудио сепарации = | = Запуск инференса модели Bandit-v2 для аудио сепарации = | ||
== Суть задачи == | == Суть задачи == | ||
В данной статье описывается процесс настройки и запуска инференса нейросетевой модели '''Bandit-v2''' из репозитория [https://github.com/kwatcharasupat/bandit-v2 kwatcharasupat/bandit-v2] для задачи '''Cinematic Audio Source Separation''' — разделения аудиодорожки на составляющие: речь (speech), музыку (music) и звуковые эффекты (sfx). | В данной статье описывается процесс настройки и запуска инференса нейросетевой модели '''Bandit-v2''' из репозитория [https://github.com/kwatcharasupat/bandit-v2 kwatcharasupat/bandit-v2] для задачи '''Cinematic Audio Source Separation''' — разделения аудиодорожки на составляющие: речь (speech), музыку (music) и звуковые эффекты (sfx). | ||
| Строка 48: | Строка 47: | ||
== Процесс формирования рабочего решения == | == Процесс формирования рабочего решения == | ||
=== Проблема 1: Netflix-специфичные зависимости === | === Проблема 1: Netflix-специфичные зависимости === | ||
Файл <code>requirements.txt</code> содержит ссылки на приватный PyPI сервер Netflix:<pre> | Файл <code>requirements.txt</code> содержит ссылки на приватный PyPI сервер Netflix:<pre> | ||
| Строка 87: | Строка 85: | ||
== Исходный код == | == Исходный код == | ||
{ | {| class="mw-collapsible mw-collapsed wikitable" | ||
# | !simple_inference.py | ||
</syntaxhighlight> | |- | ||
|<syntaxhighlight lang="python"> | |||
""" | |||
Memory-efficient inference for Bandit-v2 | |||
Processes chunks one at a time to minimize VRAM usage | |||
""" | |||
import os | |||
import sys | |||
import time | |||
import torch | |||
import torchaudio as ta | |||
from tqdm import tqdm | |||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |||
from src.models.bandit.bandit import Bandit | |||
def create_model(stems, fs=48000): | |||
"""Create Bandit model""" | |||
model = Bandit( | |||
in_channels=1, # Model trained on mono, but uses treat_channel_as_feature | |||
stems=stems, | |||
band_type="musical", | |||
n_bands=64, | |||
normalize_channel_independently=False, | |||
treat_channel_as_feature=True, | |||
n_sqm_modules=8, | |||
emb_dim=128, | |||
rnn_dim=256, | |||
bidirectional=True, | |||
rnn_type="GRU", | |||
mlp_dim=512, | |||
hidden_activation="Tanh", | |||
hidden_activation_kwargs=None, | |||
complex_mask=True, | |||
use_freq_weights=True, | |||
n_fft=2048, | |||
win_length=2048, | |||
hop_length=512, | |||
window_fn="hann_window", | |||
wkwargs=None, | |||
power=None, | |||
center=True, | |||
normalized=True, | |||
pad_mode="reflect", | |||
onesided=True, | |||
fs=fs, | |||
) | |||
return model | |||
def simple_chunked_inference( | |||
model, | |||
audio, # (channels, samples) | |||
fs=48000, | |||
chunk_seconds=30.0, # Much larger chunks, less overlap | |||
overlap_seconds=2.0, # Small overlap for crossfade | |||
device="cuda", | |||
use_half=True, | |||
): | |||
""" | |||
Simple chunk-based inference with crossfade overlap | |||
Processes ONE chunk at a time to minimize VRAM | |||
""" | |||
n_channels, n_samples = audio.shape | |||
chunk_samples = int(chunk_seconds * fs) | |||
overlap_samples = int(overlap_seconds * fs) | |||
hop_samples = chunk_samples - overlap_samples | |||
# Create output buffers (on CPU to save VRAM) | |||
stems = model.stems | |||
outputs = {stem: torch.zeros(n_channels, n_samples) for stem in stems} | |||
# Create crossfade window | |||
fade_in = torch.linspace(0, 1, overlap_samples) | |||
fade_out = torch.linspace(1, 0, overlap_samples) | |||
# Calculate chunks | |||
n_chunks = max(1, (n_samples - overlap_samples) // hop_samples + 1) | |||
print(f"Processing {n_chunks} chunks of {chunk_seconds}s with {overlap_seconds}s overlap") | |||
for i in tqdm(range(n_chunks), desc="Processing"): | |||
start = i * hop_samples | |||
end = min(start + chunk_samples, n_samples) | |||
# Get chunk | |||
chunk = audio[:, start:end] | |||
actual_len = chunk.shape[1] | |||
# Pad if needed | |||
if actual_len < chunk_samples: | |||
chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - actual_len)) | |||
# Process on GPU | |||
chunk_gpu = chunk[None, :, :].to(device) | |||
with torch.inference_mode(): | |||
with torch.cuda.amp.autocast(enabled=use_half): | |||
batch = {"mixture": {"audio": chunk_gpu}} | |||
result = model(batch) | |||
# Get results back to CPU immediately | |||
for stem in stems: | |||
stem_audio = result["estimates"][stem]["audio"][0, :, :actual_len].float().cpu() | |||
# Apply crossfade for overlap regions | |||
if i == 0: | |||
# First chunk - no fade in | |||
outputs[stem][:, start:start+actual_len] = stem_audio | |||
else: | |||
# Apply crossfade in overlap region | |||
overlap_start = start | |||
overlap_end = start + overlap_samples | |||
if overlap_end <= n_samples: | |||
# Fade out previous, fade in current | |||
outputs[stem][:, overlap_start:overlap_end] *= fade_out | |||
outputs[stem][:, overlap_start:overlap_end] += stem_audio[:, :overlap_samples] * fade_in | |||
# Copy rest without fade | |||
if overlap_samples < actual_len: | |||
outputs[stem][:, overlap_end:start+actual_len] = stem_audio[:, overlap_samples:actual_len] | |||
else: | |||
outputs[stem][:, start:start+actual_len] = stem_audio | |||
# Clear GPU cache | |||
del chunk_gpu, result | |||
torch.cuda.empty_cache() | |||
return outputs | |||
def run_inference( | |||
checkpoint_path: str, | |||
audio_path: str, | |||
output_dir: str = None, | |||
stems: list = None, | |||
fs: int = 48000, | |||
device: str = "cuda", | |||
use_half: bool = True, | |||
chunk_seconds: float = 30.0, | |||
overlap_seconds: float = 2.0, | |||
): | |||
if stems is None: | |||
stems = ["speech", "music", "sfx"] | |||
if output_dir is None: | |||
output_dir = os.path.join(os.path.dirname(audio_path), "estimates") | |||
os.makedirs(output_dir, exist_ok=True) | |||
print(f"GPU: {torch.cuda.get_device_name(0)}") | |||
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB") | |||
print(f"\n=== Config ===") | |||
print(f"Chunk: {chunk_seconds}s, Overlap: {overlap_seconds}s") | |||
print(f"Half precision: {use_half}") | |||
# Load model | |||
print(f"\n=== Loading Model ===") | |||
model = create_model(stems=stems, fs=fs) | |||
checkpoint = torch.load(checkpoint_path, map_location="cpu") | |||
state_dict = checkpoint.get("state_dict", checkpoint) | |||
cleaned = {k.replace("model.", "", 1) if k.startswith("model.") else k: v | |||
for k, v in state_dict.items()} | |||
model.load_state_dict(cleaned, strict=False) | |||
model.to(device) | |||
# Note: Keep model in FP32, autocast handles mixed precision during inference | |||
model.eval() | |||
print("Model loaded") | |||
# Load audio | |||
print(f"\n=== Loading Audio ===") | |||
audio, audio_fs = ta.load(audio_path) | |||
duration = audio.shape[1] / audio_fs | |||
print(f"Duration: {duration:.1f}s, Channels: {audio.shape[0]}, SR: {audio_fs}") | |||
if audio_fs != fs: | |||
print(f"Resampling {audio_fs} -> {fs}") | |||
audio = ta.functional.resample(audio, audio_fs, fs) | |||
n_channels = audio.shape[0] | |||
# Run inference | |||
print(f"\n=== Inference ===") | |||
torch.cuda.reset_peak_memory_stats() | |||
t0 = time.time() | |||
# Process each channel separately to preserve stereo | |||
all_outputs = [] | |||
for ch in range(n_channels): | |||
if n_channels > 1: | |||
print(f"\nProcessing channel {ch+1}/{n_channels}...") | |||
audio_ch = audio[ch:ch+1, :] # Keep dim: (1, samples) | |||
ch_outputs = simple_chunked_inference( | |||
model, audio_ch, fs=fs, | |||
chunk_seconds=chunk_seconds, | |||
overlap_seconds=overlap_seconds, | |||
device=device, | |||
use_half=use_half, | |||
) | |||
all_outputs.append(ch_outputs) | |||
# Combine channels | |||
outputs = {} | |||
for stem in model.stems: | |||
outputs[stem] = torch.cat([ch_out[stem] for ch_out in all_outputs], dim=0) | |||
elapsed = time.time() - t0 | |||
peak_vram = torch.cuda.max_memory_allocated() / 1024**3 | |||
print(f"\nDone in {elapsed:.1f}s ({duration/elapsed:.1f}x realtime)") | |||
print(f"Peak VRAM: {peak_vram:.2f} GB") | |||
# Save | |||
print(f"\n=== Saving ===") | |||
for stem, audio_out in outputs.items(): | |||
path = os.path.join(output_dir, f"{stem}_estimate.wav") | |||
ta.save(path, audio_out, fs) | |||
print(f" {path}") | |||
print("\nComplete!") | |||
if __name__ == "__main__": | |||
import argparse | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument("--checkpoint", "-c", required=True) | |||
parser.add_argument("--audio", "-a", required=True) | |||
parser.add_argument("--output", "-o", default=None) | |||
parser.add_argument("--stems", nargs="+", default=["speech", "music", "sfx"]) | |||
parser.add_argument("--fs", type=int, default=48000) | |||
parser.add_argument("--device", default="cuda") | |||
parser.add_argument("--no-half", action="store_true") | |||
parser.add_argument("--chunk", type=float, default=30.0, help="Chunk size in seconds") | |||
parser.add_argument("--overlap", type=float, default=2.0, help="Overlap in seconds") | |||
args = parser.parse_args() | |||
run_inference( | |||
checkpoint_path=args.checkpoint, | |||
audio_path=args.audio, | |||
output_dir=args.output, | |||
stems=args.stems, | |||
fs=args.fs, | |||
device=args.device, | |||
use_half=not args.no_half, | |||
chunk_seconds=args.chunk, | |||
overlap_seconds=args.overlap, | |||
) | |||
</syntaxhighlight> | |||
|} | |||
== Описание simple_inference.py == | == Описание simple_inference.py == | ||
| Строка 127: | Строка 386: | ||
== Инструкция по установке == | == Инструкция по установке == | ||
=== Требования === | === Требования === | ||
| Строка 142: | Строка 400: | ||
=== Шаг 2: Установка PyTorch с CUDA 11.8 === | === Шаг 2: Установка PyTorch с CUDA 11.8 === | ||
'''Внимание:''' Версия PyTorch 2.0.0 из оригинального requirements.txt больше недоступна. Используйте 2.5.0 или новее.<pre> | |||
pip install torch==2.5.0+cu118 torchaudio==2.5.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 | pip install torch==2.5.0+cu118 torchaudio==2.5.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 | ||
</pre> | </pre> | ||
=== Шаг 3: Установка зависимостей === | === Шаг 3: Установка зависимостей === | ||
'''Внимание:''' Пакет <code>asteroid</code> требует сборки <code>pesq</code> с Visual C++ Build Tools. Для инференса он не нужен, поэтому не устанавливаем.<pre> | |||
pip install "pytorch_lightning>=2.3.0" hydra-core omegaconf librosa soundfile einops tqdm julius huggingface-hub pyyaml scipy pandas | pip install "pytorch_lightning>=2.3.0" hydra-core omegaconf librosa soundfile einops tqdm julius huggingface-hub pyyaml scipy pandas | ||
</pre> | </pre> | ||
| Строка 164: | Строка 422: | ||
== Инструкция по запуску == | == Инструкция по запуску == | ||
=== Активация окружения === | === Активация окружения === | ||
<pre> | <pre> | ||
Текущая версия от 07:39, 16 января 2026
Запуск инференса модели Bandit-v2 для аудио сепарации
Суть задачи
В данной статье описывается процесс настройки и запуска инференса нейросетевой модели Bandit-v2 из репозитория kwatcharasupat/bandit-v2 для задачи Cinematic Audio Source Separation — разделения аудиодорожки на составляющие: речь (speech), музыку (music) и звуковые эффекты (sfx).
Проект изначально разработан для внутренней инфраструктуры Netflix и содержит зависимости от закрытых пакетов, что делает прямой запуск невозможным. В статье представлено рабочее решение для запуска инференса на локальной машине с GPU.
Результат работы: скрипт simple_inference.py, позволяющий обрабатывать аудиофайлы любой длительности со скоростью ~17x realtime на RTX 4090.
О проекте Bandit-v2
Bandit (Band-Split RNN) — архитектура нейросети для разделения аудио на источники. Основные особенности:
- Band-Split подход — спектрограмма делится на 64 частотных диапазона (bands), каждый обрабатывается отдельно
- Dual-path RNN — последовательная обработка по временной и частотной осям с использованием GRU
- Mask Estimation — для каждого stem (речь/музыка/sfx) предсказывается комплексная маска
Технические характеристики модели:
| Параметр | Значение |
|---|---|
| Sample rate | 48000 Hz |
| n_bands | 64 |
| n_fft | 2048 |
| hop_length | 512 |
| emb_dim | 128 |
| rnn_dim | 256 |
| rnn_type | GRU (bidirectional) |
| n_sqm_modules | 8 |
Веса модели доступны на Zenodo: checkpoint-multi.ckpt
Процесс формирования рабочего решения
Проблема 1: Netflix-специфичные зависимости
Файл requirements.txt содержит ссылки на приватный PyPI сервер Netflix:
--index-url https://pypi.netflix.net/simple
И пакеты: nflx-manta, nflx-metaflow, jasper, storage, metatron и др.
Решение: Установка только минимально необходимых публичных пакетов для инференса.
Проблема 2: Ray зависимость
Модуль src/system/utils.py импортирует Ray (распределённые вычисления), который не нужен для локального инференса.
Решение: Создан отдельный скрипт simple_inference.py, который напрямую импортирует модель без системных обёрток.
Проблема 3: Устаревший PyTorch
В requirements указан torch==2.0.0+cu118, который больше недоступен в PyPI.
Решение: Установлен torch==2.5.0+cu118.
Проблема 4: Ошибка сборки pesq
Пакет asteroid тянет зависимость pesq, требующую компиляции C++.
Решение: Пакет asteroid не устанавливается — он нужен только для метрик обучения, не для инференса.
Проблема 5: Чрезмерное потребление VRAM
Оригинальный StandardTensorChunkedInferenceHandler с batch_size=32 и overlap 7 секунд (8с chunk, 1с hop) создаёт буфер на ~50GB.
Решение: Реализован простой chunked подход с последовательной обработкой чанков (30с chunk, 2с overlap), потребление VRAM снижено до ~2GB.
Проблема 6: Stereo обработка
Модель обучена с in_channels=1 (mono). При подаче stereo возникает ошибка LayerNorm.
Решение: Каждый канал stereo обрабатывается отдельно, результаты объединяются.
Проблема 7: Half precision (FP16)
Прямая конвертация модели в .half() вызывает ошибку типов в LayerNorm.
Решение: Используется torch.cuda.amp.autocast() для автоматического mixed precision.
Исходный код
| simple_inference.py |
|---|
"""
Memory-efficient inference for Bandit-v2
Processes chunks one at a time to minimize VRAM usage
"""
import os
import sys
import time
import torch
import torchaudio as ta
from tqdm import tqdm
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.models.bandit.bandit import Bandit
def create_model(stems, fs=48000):
"""Create Bandit model"""
model = Bandit(
in_channels=1, # Model trained on mono, but uses treat_channel_as_feature
stems=stems,
band_type="musical",
n_bands=64,
normalize_channel_independently=False,
treat_channel_as_feature=True,
n_sqm_modules=8,
emb_dim=128,
rnn_dim=256,
bidirectional=True,
rnn_type="GRU",
mlp_dim=512,
hidden_activation="Tanh",
hidden_activation_kwargs=None,
complex_mask=True,
use_freq_weights=True,
n_fft=2048,
win_length=2048,
hop_length=512,
window_fn="hann_window",
wkwargs=None,
power=None,
center=True,
normalized=True,
pad_mode="reflect",
onesided=True,
fs=fs,
)
return model
def simple_chunked_inference(
model,
audio, # (channels, samples)
fs=48000,
chunk_seconds=30.0, # Much larger chunks, less overlap
overlap_seconds=2.0, # Small overlap for crossfade
device="cuda",
use_half=True,
):
"""
Simple chunk-based inference with crossfade overlap
Processes ONE chunk at a time to minimize VRAM
"""
n_channels, n_samples = audio.shape
chunk_samples = int(chunk_seconds * fs)
overlap_samples = int(overlap_seconds * fs)
hop_samples = chunk_samples - overlap_samples
# Create output buffers (on CPU to save VRAM)
stems = model.stems
outputs = {stem: torch.zeros(n_channels, n_samples) for stem in stems}
# Create crossfade window
fade_in = torch.linspace(0, 1, overlap_samples)
fade_out = torch.linspace(1, 0, overlap_samples)
# Calculate chunks
n_chunks = max(1, (n_samples - overlap_samples) // hop_samples + 1)
print(f"Processing {n_chunks} chunks of {chunk_seconds}s with {overlap_seconds}s overlap")
for i in tqdm(range(n_chunks), desc="Processing"):
start = i * hop_samples
end = min(start + chunk_samples, n_samples)
# Get chunk
chunk = audio[:, start:end]
actual_len = chunk.shape[1]
# Pad if needed
if actual_len < chunk_samples:
chunk = torch.nn.functional.pad(chunk, (0, chunk_samples - actual_len))
# Process on GPU
chunk_gpu = chunk[None, :, :].to(device)
with torch.inference_mode():
with torch.cuda.amp.autocast(enabled=use_half):
batch = {"mixture": {"audio": chunk_gpu}}
result = model(batch)
# Get results back to CPU immediately
for stem in stems:
stem_audio = result["estimates"][stem]["audio"][0, :, :actual_len].float().cpu()
# Apply crossfade for overlap regions
if i == 0:
# First chunk - no fade in
outputs[stem][:, start:start+actual_len] = stem_audio
else:
# Apply crossfade in overlap region
overlap_start = start
overlap_end = start + overlap_samples
if overlap_end <= n_samples:
# Fade out previous, fade in current
outputs[stem][:, overlap_start:overlap_end] *= fade_out
outputs[stem][:, overlap_start:overlap_end] += stem_audio[:, :overlap_samples] * fade_in
# Copy rest without fade
if overlap_samples < actual_len:
outputs[stem][:, overlap_end:start+actual_len] = stem_audio[:, overlap_samples:actual_len]
else:
outputs[stem][:, start:start+actual_len] = stem_audio
# Clear GPU cache
del chunk_gpu, result
torch.cuda.empty_cache()
return outputs
def run_inference(
checkpoint_path: str,
audio_path: str,
output_dir: str = None,
stems: list = None,
fs: int = 48000,
device: str = "cuda",
use_half: bool = True,
chunk_seconds: float = 30.0,
overlap_seconds: float = 2.0,
):
if stems is None:
stems = ["speech", "music", "sfx"]
if output_dir is None:
output_dir = os.path.join(os.path.dirname(audio_path), "estimates")
os.makedirs(output_dir, exist_ok=True)
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
print(f"\n=== Config ===")
print(f"Chunk: {chunk_seconds}s, Overlap: {overlap_seconds}s")
print(f"Half precision: {use_half}")
# Load model
print(f"\n=== Loading Model ===")
model = create_model(stems=stems, fs=fs)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)
cleaned = {k.replace("model.", "", 1) if k.startswith("model.") else k: v
for k, v in state_dict.items()}
model.load_state_dict(cleaned, strict=False)
model.to(device)
# Note: Keep model in FP32, autocast handles mixed precision during inference
model.eval()
print("Model loaded")
# Load audio
print(f"\n=== Loading Audio ===")
audio, audio_fs = ta.load(audio_path)
duration = audio.shape[1] / audio_fs
print(f"Duration: {duration:.1f}s, Channels: {audio.shape[0]}, SR: {audio_fs}")
if audio_fs != fs:
print(f"Resampling {audio_fs} -> {fs}")
audio = ta.functional.resample(audio, audio_fs, fs)
n_channels = audio.shape[0]
# Run inference
print(f"\n=== Inference ===")
torch.cuda.reset_peak_memory_stats()
t0 = time.time()
# Process each channel separately to preserve stereo
all_outputs = []
for ch in range(n_channels):
if n_channels > 1:
print(f"\nProcessing channel {ch+1}/{n_channels}...")
audio_ch = audio[ch:ch+1, :] # Keep dim: (1, samples)
ch_outputs = simple_chunked_inference(
model, audio_ch, fs=fs,
chunk_seconds=chunk_seconds,
overlap_seconds=overlap_seconds,
device=device,
use_half=use_half,
)
all_outputs.append(ch_outputs)
# Combine channels
outputs = {}
for stem in model.stems:
outputs[stem] = torch.cat([ch_out[stem] for ch_out in all_outputs], dim=0)
elapsed = time.time() - t0
peak_vram = torch.cuda.max_memory_allocated() / 1024**3
print(f"\nDone in {elapsed:.1f}s ({duration/elapsed:.1f}x realtime)")
print(f"Peak VRAM: {peak_vram:.2f} GB")
# Save
print(f"\n=== Saving ===")
for stem, audio_out in outputs.items():
path = os.path.join(output_dir, f"{stem}_estimate.wav")
ta.save(path, audio_out, fs)
print(f" {path}")
print("\nComplete!")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", "-c", required=True)
parser.add_argument("--audio", "-a", required=True)
parser.add_argument("--output", "-o", default=None)
parser.add_argument("--stems", nargs="+", default=["speech", "music", "sfx"])
parser.add_argument("--fs", type=int, default=48000)
parser.add_argument("--device", default="cuda")
parser.add_argument("--no-half", action="store_true")
parser.add_argument("--chunk", type=float, default=30.0, help="Chunk size in seconds")
parser.add_argument("--overlap", type=float, default=2.0, help="Overlap in seconds")
args = parser.parse_args()
run_inference(
checkpoint_path=args.checkpoint,
audio_path=args.audio,
output_dir=args.output,
stems=args.stems,
fs=args.fs,
device=args.device,
use_half=not args.no_half,
chunk_seconds=args.chunk,
overlap_seconds=args.overlap,
)
|
Описание simple_inference.py
Скрипт состоит из трёх основных функций:
create_model()
Создаёт экземпляр модели Bandit с параметрами, соответствующими checkpoint:
in_channels=1— mono (stereo обрабатывается поканально)n_bands=64,n_sqm_modules=8n_fft=2048,hop_length=512fs=48000
simple_chunked_inference()
Обрабатывает аудио чанками с crossfade:
- Аудио делится на чанки (по умолчанию 30 секунд)
- Соседние чанки перекрываются (по умолчанию 2 секунды)
- В зоне перекрытия применяется линейный crossfade
- Каждый чанк обрабатывается отдельно, результат сразу переносится на CPU
- GPU кэш очищается после каждого чанка
run_inference()
Основная функция:
- Загружает модель и веса из checkpoint
- Загружает аудио (с ресемплингом если нужно)
- Для stereo — обрабатывает каждый канал отдельно
- Сохраняет результаты в отдельные WAV файлы для каждого stem
Особенности реализации:
- Autocast mixed precision для ускорения без проблем с типами
- Минимальное потребление VRAM (~2GB на 30-секундный чанк)
- Поддержка произвольной длительности аудио
- Сохранение stereo (не конвертируется в mono)
Инструкция по установке
Требования
- Windows 10/11 или Linux
- NVIDIA GPU с поддержкой CUDA 11.8
- Anaconda или Miniconda
- ~3GB свободного места для PyTorch
Шаг 1: Создание conda окружения
conda create -n bandit-v2 python=3.10 -y conda activate bandit-v2
Шаг 2: Установка PyTorch с CUDA 11.8
Внимание: Версия PyTorch 2.0.0 из оригинального requirements.txt больше недоступна. Используйте 2.5.0 или новее.
pip install torch==2.5.0+cu118 torchaudio==2.5.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
Шаг 3: Установка зависимостей
Внимание: Пакет asteroid требует сборки pesq с Visual C++ Build Tools. Для инференса он не нужен, поэтому не устанавливаем.
pip install "pytorch_lightning>=2.3.0" hydra-core omegaconf librosa soundfile einops tqdm julius huggingface-hub pyyaml scipy pandas
Шаг 4: Клонирование репозитория
git clone https://github.com/kwatcharasupat/bandit-v2.git cd bandit-v2
Шаг 5: Скачивание весов модели
Скачайте checkpoint-multi.ckpt с Zenodo и поместите в удобное место.
Шаг 6: Добавление simple_inference.py
Скопируйте скрипт simple_inference.py (см. раздел "Исходный код") в корень репозитория bandit-v2/.
Инструкция по запуску
Активация окружения
conda activate bandit-v2 cd путь/к/bandit-v2
Базовый запуск
python simple_inference.py -c "путь/к/checkpoint-multi.ckpt" -a "путь/к/аудио.wav"
Полный синтаксис
python simple_inference.py \
--checkpoint "путь/к/checkpoint-multi.ckpt" \
--audio "путь/к/аудио.wav" \
--output "путь/к/выходной/папке" \
--chunk 30 \
--overlap 2 \
--stems speech music sfx
| Параметр | По умолчанию | Описание |
|---|---|---|
-c, --checkpoint
|
(обязательный) | Путь к файлу весов .ckpt |
-a, --audio
|
(обязательный) | Путь к входному аудиофайлу |
-o, --output
|
./estimates/ | Папка для выходных файлов |
--chunk
|
30.0 | Размер чанка в секундах |
--overlap
|
2.0 | Размер перекрытия в секундах |
--stems
|
speech music sfx | Список stem для разделения |
--fs
|
48000 | Sample rate (менять не рекомендуется) |
--no-half
|
False | Отключить mixed precision |
Пример вывода
GPU: NVIDIA GeForce RTX 4090 VRAM: 22.5 GB === Config === Chunk: 30.0s, Overlap: 2.0s Half precision: True === Loading Model === Model loaded === Loading Audio === Duration: 420.5s, Channels: 2, SR: 48000 === Inference === Processing channel 1/2... Processing: 100%|████████████████| 15/15 [00:12<00:00, 1.20it/s] Processing channel 2/2... Processing: 100%|████████████████| 15/15 [00:12<00:00, 1.21it/s] Done in 24.9s (16.9x realtime) Peak VRAM: 1.99 GB === Saving === speech_estimate.wav music_estimate.wav sfx_estimate.wav Complete!
Выходные файлы
В указанной папке создаются три файла:
speech_estimate.wav— речь/диалогиmusic_estimate.wav— музыкаsfx_estimate.wav— звуковые эффекты
Все файлы сохраняются с тем же sample rate (48000 Hz) и количеством каналов, что и входной файл.