Audio Classification: Training a Drum Kit Classifier

This post describes a first attempt at an AI project I have been working on: audio classification of single drum hits. It covers collecting one-shots from sample packs, building a labelled dataset, fine-tuning Wav2Vec2, and publishing both the dataset and the model to Hugging Face.
The finished artefacts are public here:
- Dataset: airasoul/drum-kit
- Model: airasoul/wav2vec2-base-drum-kit
The original Jupyter notebooks can be found in the following GitHub repository:
In this post
- What was built (dataset, augmentation, training)
- Results
- Using the dataset and model
- Limitations
What was built
The goal was a classifier that takes a short WAV of a single drum or percussion hit and returns one of ten instrument labels. Everything below was done in a sequence of Jupyter notebooks; the early work is summarised here as outcomes rather than a tutorial.
Dataset. Roughly 3,000 one-shot WAV files were gathered from several open-source sample packs. Files were sorted into ten classes by matching filenames against keyword rules (kick, bd, snare, hat, tom, and so on), then renamed to a consistent pattern (kick-000001.wav, etc.). All audio was resampled to 16 kHz mono, which matches the input rate expected by facebook/wav2vec2-base and keeps training practical.
Classes: clap, conga, crash, cymbal, hat, kick, ride, rim, snare, tom.
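As a rough illustration of the keyword sorting and renaming step, here is a minimal sketch. Only kick, bd, snare, hat and tom are named as keywords in the post; the other entries in KEYWORD_RULES, the helper names (classify_filename, new_name, plan_renames) and the choice to match whole tokens rather than substrings are assumptions.

```python
import re
from collections import defaultdict

# Hypothetical keyword rules. kick, bd, snare, hat and tom come from the post;
# the rest are illustrative assumptions. The first matching class wins.
KEYWORD_RULES = [
    ("kick", ("kick", "bd")),
    ("snare", ("snare", "sd")),
    ("rim", ("rim", "rimshot")),
    ("clap", ("clap",)),
    ("hat", ("hat", "hihat", "hh")),
    ("ride", ("ride",)),
    ("crash", ("crash",)),
    ("cymbal", ("cymbal", "cym")),
    ("conga", ("conga",)),
    ("tom", ("tom",)),
]


def classify_filename(filename):
    """Return a class label for a sample filename, or None if no rule matches."""
    stem = filename.lower().rsplit(".", 1)[0]
    # Split on anything that is not a letter: "BD_Heavy_01" -> {"bd", "heavy"}
    tokens = set(re.split(r"[^a-z]+", stem)) - {""}
    for label, keywords in KEYWORD_RULES:
        if tokens & set(keywords):
            return label
    return None


def new_name(label, index):
    """Build the consistent name pattern used in the post, e.g. kick-000001.wav."""
    return f"{label}-{index:06d}.wav"


def plan_renames(filenames):
    """Map each matched file to (label, new filename), numbering per class from 1."""
    counters = defaultdict(int)
    plan = {}
    for name in sorted(filenames):
        label = classify_filename(name)
        if label is None:
            continue  # unmatched files are left for manual review
        counters[label] += 1
        plan[name] = (label, new_name(label, counters[label]))
    return plan
```

Matching whole tokens avoids false hits such as "tom" inside "custom", at the cost of missing run-together names like "OpenHat01". The resampling step can be done with `librosa.load(path, sr=16000, mono=True)`, as in the inference code further down.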
Augmentation. To stretch the data without recording more hits, each original file was left in place and three variants were written alongside it: a slight time stretch, low-level noise, and a gain change of up to ±20%. Augmented files carry the suffixes -stretch, -noise, and -gain in their filenames. This roughly quadruples the material the model sees while keeping each label tied to its parent folder.
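A minimal NumPy sketch of the three augmentations follows. Only the ±20% gain range comes from the post; the noise level (standard deviation 0.005), the ±10% stretch range, the use of `librosa.effects.time_stretch`, the helper names, and placing the suffix just before the extension are all assumptions.

```python
import numpy as np


def add_noise(y, rng, level=0.005):
    """Add low-level Gaussian noise (level is an assumed standard deviation)."""
    noisy = y + rng.normal(0.0, level, size=y.shape)
    return np.clip(noisy, -1.0, 1.0).astype(y.dtype)


def change_gain(y, rng, max_change=0.2):
    """Scale amplitude by a random factor in [1 - max_change, 1 + max_change)."""
    factor = rng.uniform(1.0 - max_change, 1.0 + max_change)
    return np.clip(y * factor, -1.0, 1.0).astype(y.dtype)


def time_stretch(y, rng, max_change=0.1):
    """Slightly stretch or compress time without changing pitch (needs librosa)."""
    import librosa  # imported here so the other helpers work without librosa

    rate = rng.uniform(1.0 - max_change, 1.0 + max_change)
    return librosa.effects.time_stretch(y, rate=rate)


def augmented_name(filename, suffix):
    """kick-000001.wav + "noise" -> kick-000001-noise.wav (placement assumed)."""
    stem, ext = filename.rsplit(".", 1)
    return f"{stem}-{suffix}.{ext}"
```

Passing a seeded generator such as `np.random.default_rng(42)` keeps the augmented set reproducible between runs.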
Hub dataset. The folder tree was pushed to Hugging Face as airasoul/drum-kit. Labels follow the usual Hub convention: the parent directory name is the class. The split is 90% train / 10% test, fixed with seed 42. Each row exposes an audio feature (path, array, sampling rate) and a string label.
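One way to build and push a dataset with this layout using the datasets library is sketched below. It is not the author's deployment notebook: collect_examples and build_and_push are hypothetical helpers, and the 16 kHz cast assumes the files were already resampled. Labels stay as plain strings taken from the parent folder name, matching the string label described above, and `train_test_split(test_size=0.1, seed=42)` gives the 90/10 split.

```python
from pathlib import Path


def collect_examples(root):
    """Return parallel lists of WAV paths and labels (label = parent folder name)."""
    paths, labels = [], []
    for wav in sorted(Path(root).rglob("*.wav")):
        paths.append(str(wav))
        labels.append(wav.parent.name)
    return {"audio": paths, "label": labels}


def build_and_push(root, repo_id="airasoul/drum-kit"):
    """Create a 90/10 train/test DatasetDict and push it to the Hugging Face Hub."""
    # Imported here so collect_examples can run without the datasets package
    from datasets import Audio, Dataset

    ds = Dataset.from_dict(collect_examples(root))
    ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
    splits = ds.train_test_split(test_size=0.1, seed=42)
    splits.push_to_hub(repo_id)  # requires a prior Hugging Face login
    return splits
```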
Model. facebook/wav2vec2-base was fine-tuned for ten-way classification on the Hub dataset. Clips are truncated or padded to 3 seconds at 16 kHz. Training ran for 10 epochs with evaluation each epoch; a weighted loss was used in places to reduce bias toward over-represented classes (e.g. the model defaulting to “snare”). The best checkpoint was published as airasoul/wav2vec2-base-drum-kit (~360 MB weights).
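The fixed 3-second input works out to 16,000 × 3 = 48,000 samples per clip. The sketch below shows that padding/truncation and one possible way to derive class weights for a weighted loss. The post does not say how its weights were computed; inverse class frequency, n_samples / (n_classes × count), which scikit-learn calls "balanced", is an assumption, as are the helper names.

```python
from collections import Counter

import numpy as np

SAMPLING_RATE = 16_000
MAX_DURATION = 3.0  # seconds, as in the post
MAX_SAMPLES = int(SAMPLING_RATE * MAX_DURATION)  # 48,000 samples


def fix_length(y, max_samples=MAX_SAMPLES):
    """Truncate or zero-pad a 1-D waveform to exactly max_samples."""
    if len(y) >= max_samples:
        return y[:max_samples]
    return np.pad(y, (0, max_samples - len(y)))


def class_weights(labels, label2id):
    """Inverse-frequency weights indexed by class id: n / (k * count)."""
    counts = Counter(labels)
    n, k = len(labels), len(label2id)
    weights = [0.0] * k
    for name, idx in label2id.items():
        weights[int(idx)] = n / (k * counts[name]) if counts[name] else 0.0
    return weights
```

Rarer classes get weights above 1 and over-represented ones below 1. A common way to apply them is `torch.nn.CrossEntropyLoss(weight=torch.tensor(weights))` inside a custom `compute_loss` on a transformers `Trainer` subclass.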
Results
On the held-out test split (300 samples), the published model reaches 97.0% accuracy.
Validation accuracy during training reached ~95.7% by epoch 10.
| Class | Precision | Recall | F1-score |
|---|---|---|---|
| clap | 1.00 | 1.00 | 1.00 |
| conga | 0.96 | 0.93 | 0.95 |
| crash | 0.97 | 0.97 | 0.97 |
| cymbal | 1.00 | 0.91 | 0.95 |
| hat | 1.00 | 0.97 | 0.98 |
| kick | 1.00 | 0.94 | 0.97 |
| ride | 0.94 | 1.00 | 0.97 |
| rim | 1.00 | 1.00 | 1.00 |
| snare | 0.89 | 0.96 | 0.93 |
| tom | 0.92 | 1.00 | 0.96 |
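Per-class tables in this shape are what scikit-learn's `classification_report` prints. For readers who want the arithmetic spelled out, here is a dependency-free sketch (the function name is illustrative): precision is TP / (TP + FP), recall is TP / (TP + FN), and F1 is their harmonic mean.

```python
from collections import Counter


def classification_metrics(y_true, y_pred):
    """Return overall accuracy and a dict of per-class (precision, recall, F1)."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted p for a clip of another class
            fn[t] += 1  # a clip of class t was missed
    accuracy = sum(tp.values()) / len(y_true)

    per_class = {}
    for label in sorted(set(y_true) | set(y_pred)):
        predicted = tp[label] + fp[label]
        actual = tp[label] + fn[label]
        precision = tp[label] / predicted if predicted else 0.0
        recall = tp[label] / actual if actual else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        per_class[label] = (precision, recall, f1)
    return accuracy, per_class
```

Because the table values are rounded to two decimals, recomputing F1 from the printed precision and recall can differ from the printed F1 in the last digit.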
Load the dataset
Load the drum-kit dataset from the Hugging Face Hub using datasets.load_dataset. The dataset was already split into train and test when it was pushed to the Hub, so no manual split is needed. The code also prints the dataset features (audio, label) and overall structure.
```python
from datasets import load_dataset

# Load the dataset (a DatasetDict with "train" and "test" splits)
dataset = load_dataset("airasoul/drum-kit")
features = dataset["train"].features

# Inspect features and overall structure
print("Features:", features)
print("Audio feature:", features["audio"])
print("Label feature:", features["label"])
print("Dataset:", dataset)
```
Get labels from the dataset
Extract the unique label values from both the train and test splits, merge them into a single sorted list, and count the number of distinct classes. This gives you the full set of drum/instrument labels (e.g. kick, snare) used in the dataset.
```python
# Get the labels for train
labels_train = dataset["train"].unique("label")
print("Available labels in train:", labels_train)

# Get the labels for test
labels_test = dataset["test"].unique("label")
print("Available labels in test:", labels_test)

# Merge the labels into one sorted list
labels = sorted(set(labels_train) | set(labels_test))
print("Merged labels:", labels)

# Get the number of labels
num_labels = len(labels)
print("Number of labels:", num_labels)
```
Index labels
Build label2id and id2label mappings so each class name maps to a numeric ID (and vice versa). These mappings are required for training or evaluating models that expect integer class indices (e.g. many audio classification models).
```python
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label
print("label2id, id2label:", label2id, id2label)
```
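Before training, the string labels have to become integer ids, and the same mappings are usually passed to the model so that predictions decode back to class names. The sketch below assumes a batched `datasets.map` call; encode_labels and load_model are hypothetical helper names, and the `from_pretrained` arguments follow the standard transformers audio-classification pattern rather than a confirmed copy of the author's training notebook.

```python
def encode_labels(batch, label2id):
    """Replace string labels with integer ids (for a batched datasets.map call)."""
    batch["label"] = [int(label2id[name]) for name in batch["label"]]
    return batch


def load_model(label2id, id2label):
    """Load wav2vec2-base with a fresh classification head sized to the label set."""
    from transformers import AutoModelForAudioClassification

    return AutoModelForAudioClassification.from_pretrained(
        "facebook/wav2vec2-base",
        num_labels=len(label2id),
        label2id=label2id,
        id2label=id2label,
    )
```

Usage might look like `dataset = dataset.map(encode_labels, batched=True, fn_kwargs={"label2id": label2id})` followed by `model = load_model(label2id, id2label)`.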
Display sample audio clips
Pick 5 random examples from the training set and display each one with IPython’s Audio widget. For each clip, show the filename, label, array shape, and sampling rate so you can quickly verify the data and listen to representative samples.
```python
import random
from IPython.display import Audio, display

for _ in range(5):
    rand_idx = random.randint(0, len(dataset["train"]) - 1)
    example = dataset["train"][rand_idx]
    audio = example["audio"]
    label_name = example["label"]
    print(f'Filename: {audio["path"]}')
    print(f"Label: {label_name}")
    print(f'Shape: {audio["array"].shape}')
    print(f'Sampling rate: {audio["sampling_rate"]}')
    display(Audio(audio["array"], rate=audio["sampling_rate"]))
```
Single-file inference
Set AUDIO_PATH to a local WAV file and run the code below; the model predicts its drum class.
```python
import torch
import librosa
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Hub id of the published model, or a path to a local fine-tuned checkpoint
MODEL_PATH = "airasoul/wav2vec2-base-drum-kit"
AUDIO_PATH = "clap-000001.wav"  # change to your WAV
MAX_DURATION = 3.0  # seconds, matching the training clip length

# Prefer CUDA, then Apple Silicon (MPS), then CPU
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_PATH)
model = AutoModelForAudioClassification.from_pretrained(MODEL_PATH)
model.to(device)
model.eval()

# Load the file as mono at the extractor's sampling rate (16 kHz)
sr = feature_extractor.sampling_rate
audio, _ = librosa.load(AUDIO_PATH, sr=sr, mono=True)
max_length = int(sr * MAX_DURATION)

inputs = feature_extractor(
    audio, sampling_rate=sr, max_length=max_length,
    truncation=True, return_tensors="pt", padding=True,
)
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    pred_id = model(**inputs).logits[0].argmax().item()

# config.id2label normally has int keys; fall back to str keys just in case
id2label = model.config.id2label
label = id2label.get(pred_id) or id2label.get(str(pred_id))
print(label)
```
Limitations
The model is trained on short, single-hit drum sounds. Performance will drop on full mixes, overlapping hits, or recordings very different from the training packs. It suits sample browsers, tagging workflows, or experimentation — not live kit identification in a band mix.
Notebook index
If you want to rebuild from scratch rather than consume the Hub artefacts:
| Notebook | What it does |
|---|---|
| 0-setup.ipynb | Dependencies, sample cleanup |
| 1-working-with-audio.ipynb | Sample rate, bit depth, amplitude notes |
| 2-creating-a-dataset.ipynb | Organise and resample samples |
| 3-audio-augmentation.ipynb | Stretch, noise, gain |
| 4-deploying-dataset.ipynb | Push dataset to Hub |
| 5-dataset-usage.ipynb | Load and inspect dataset |
| 6-train-with-transformers.ipynb | Fine-tune Wav2Vec2 |
| 7-evaluate.ipynb | Test metrics |
| 8-test.ipynb | Local single-file inference |
| 9-publish.ipynb | Push model to Hub |
Summary
Scattered one-shots from sample packs became airasoul/drum-kit and a Wav2Vec2 classifier that reaches 97% accuracy on held-out test hits. The dataset-loading and single-file inference sections above are enough to download both and run classification on your own WAVs without touching the training notebooks.