pytorch-metric-learning

2.8.1 (last stable release 1 month ago)
Complexity score: High
Open issues: N/A
Dependent projects: 18
Weekly downloads (global): 142,033

License: MIT

  • Attribution: Yes
  • Linking: Permissive
  • Distribution: Permissive
  • Modification: Permissive
  • Patent grant: No
  • Private use: Yes
  • Sublicensing: Permissive
  • Trademark grant: No

News

December 11: v2.8.0

  • Added the Datasets module for easy downloading of common datasets:
    • CUB200
    • Cars196
    • INaturalist 2018
    • Stanford Online Products
  • Thank you ir2718.

November 2: v2.7.0

  • Added ThresholdConsistentMarginLoss.
  • Thank you ir2718.

Documentation

  • View the documentation here
  • View the installation instructions here
  • View the available losses, miners etc. here

Google Colab Examples

See the examples folder for notebooks you can download or run on Google Colab.

PyTorch Metric Learning Overview

This library contains 9 modules, each of which can be used independently within your existing codebase, or combined together for a complete train/test workflow.

How loss functions work

Using losses and miners in your training loop

Let’s initialize a plain TripletMarginLoss:

from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss()

To compute the loss in your training loop, pass in the embeddings computed by your model, and the corresponding labels. The embeddings should have size (N, embedding_size), and the labels should have size (N), where N is the batch size.

# your training loop
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    loss = loss_func(embeddings, labels)
    loss.backward()
    optimizer.step()

The TripletMarginLoss computes all possible triplets within the batch, based on the labels you pass into it. Anchor-positive pairs are formed by embeddings that share the same label, and anchor-negative pairs are formed by embeddings that have different labels.
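To make "all possible triplets" concrete, here is a plain-PyTorch sketch that enumerates every (anchor, positive, negative) index triple from a batch of labels. It is an illustration of the rule described above, not the library's own code (the library has its own utility, `get_all_triplets_indices`); the function name `all_triplet_indices` is hypothetical.

```python
import torch

def all_triplet_indices(labels: torch.Tensor):
    """Enumerate every (anchor, positive, negative) index triple in a batch."""
    same = labels.unsqueeze(0) == labels.unsqueeze(1)      # same[i, j]: labels match
    not_self = ~torch.eye(len(labels), dtype=torch.bool)   # exclude i == j
    pos_mask = same & not_self                             # valid (anchor, positive)
    neg_mask = ~same                                       # valid (anchor, negative)
    # triplet_mask[a, p, n] is True when (a, p) is positive and (a, n) is negative.
    triplet_mask = pos_mask.unsqueeze(2) & neg_mask.unsqueeze(1)
    return torch.where(triplet_mask)

labels = torch.tensor([0, 0, 1, 1])
a, p, n = all_triplet_indices(labels)
print(len(a))  # 8
```

With two samples in each of two classes, every anchor has one positive and two negatives, giving 4 × 2 = 8 triplets. The count grows roughly cubically with batch size, which is one reason mining (below) can be useful.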

Sometimes it can help to add a mining function:

from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner()
loss_func = losses.TripletMarginLoss()

# your training loop
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    hard_pairs = miner(embeddings, labels)
    loss = loss_func(embeddings, labels, hard_pairs)
    loss.backward()
    optimizer.step()

In the above code, the miner finds positive and negative pairs that it thinks are particularly difficult. Note that even though the TripletMarginLoss operates on triplets, it’s still possible to pass in pairs. This is because the library automatically converts pairs to triplets and triplets to pairs, when necessary.
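One natural way to turn pairs into triplets is to join each positive pair with every negative pair that shares its anchor. The plain-PyTorch sketch below does exactly that. It is an assumption about the general idea, not the library's conversion code, and `pairs_to_triplets` is a hypothetical name.

```python
import torch

def pairs_to_triplets(a1, p, a2, n):
    """Join positive pairs (a1, p) with negative pairs (a2, n) that share an anchor."""
    # match[i, j] is True when positive pair i and negative pair j have the same anchor.
    match = a1.unsqueeze(1) == a2.unsqueeze(0)
    pos_idx, neg_idx = torch.where(match)
    return a1[pos_idx], p[pos_idx], n[neg_idx]

a1, p = torch.tensor([0, 2]), torch.tensor([1, 3])
a2, n = torch.tensor([0, 0, 1]), torch.tensor([2, 3, 2])
anchors, positives, negatives = pairs_to_triplets(a1, p, a2, n)
print(list(zip(anchors.tolist(), positives.tolist(), negatives.tolist())))
# [(0, 1, 2), (0, 1, 3)]
```

Going the other way is simpler: a triplet (a, p, n) splits into the positive pair (a, p) and the negative pair (a, n). Note that the positive pair anchored at 2 produces no triplet here, because no negative pair uses anchor 2.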

Customizing loss functions

Loss functions can be customized using distances, reducers, and regularizers. In the diagram below, a miner finds the indices of hard pairs within a batch. These are used to index into the distance matrix, computed by the distance object. For this diagram, the loss function is pair-based, so it computes a loss per pair. In addition, a regularizer has been supplied, so a regularization loss is computed for each embedding in the batch. The per-pair and per-element losses are passed to the reducer, which (in this diagram) only keeps losses with a high value. The averages are computed for the high-valued pair and element losses, and are then added together to obtain the final loss.
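The flow described above can be traced by hand in plain PyTorch. This is a simplified sketch, not the library's internals: the mined pair indices are hypothetical, the "pair loss" is just the anchor-positive distance, the regularizer is each embedding's L2 norm, and `threshold_mean` is a hypothetical stand-in for a reducer that keeps only high-valued losses.

```python
import torch

def threshold_mean(losses: torch.Tensor, low: float) -> torch.Tensor:
    """Average only the losses above `low` (a simplified threshold reducer)."""
    kept = losses[losses > low]
    return kept.mean() if kept.numel() > 0 else losses.sum() * 0

embeddings = torch.tensor([[0.0, 0.0], [3.0, 4.0], [0.0, 1.0]])

# 1. Distance object: full (N, N) Euclidean distance matrix.
dist = torch.cdist(embeddings, embeddings)

# 2. Miner output (hypothetical): indices of two positive pairs.
anchors, positives = torch.tensor([0, 1]), torch.tensor([2, 2])

# 3. Pair-based loss: one value per mined pair, read out of the distance matrix.
pair_losses = dist[anchors, positives]                 # approx [1.0, 4.2426]

# 4. Regularizer: one value per embedding (its L2 norm).
element_losses = torch.linalg.norm(embeddings, dim=1)  # [0.0, 5.0, 1.0]

# 5. Reducer: keep high-valued losses, average each group, then add the averages.
total = threshold_mean(pair_losses, low=0.5) + threshold_mean(element_losses, low=0.5)
print(round(total.item(), 4))  # 5.6213
```

Both pair losses survive the threshold (mean ≈ 2.6213), while the zero-norm embedding is dropped from the regularization term (mean of 5 and 1 is 3).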

Now here’s an example of a customized TripletMarginLoss:

from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer
from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss(
    distance=CosineSimilarity(),
    reducer=ThresholdReducer(high=0.3),
    embedding_regularizer=LpRegularizer(),
)

This customized triplet loss has the following properties:

  • The loss will be computed using cosine similarity instead of Euclidean distance.
  • All triplet losses that are higher than 0.3 will be discarded.
  • The embeddings will be L2 regularized.
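The first two properties can be illustrated in plain PyTorch. With a similarity (where higher means closer), a triplet is penalized when the anchor-negative similarity is not below the anchor-positive similarity by at least a margin; a high threshold of 0.3 then drops any triplet loss above 0.3 before averaging. This is a sketch, not the library's implementation: `cosine_triplet_losses` is a hypothetical name, and the margin of 0.05 is an assumed value chosen for illustration.

```python
import torch
import torch.nn.functional as F

def cosine_triplet_losses(anchor, positive, negative, margin=0.05):
    """Per-triplet margin loss using cosine similarity (higher = more similar)."""
    s_ap = F.cosine_similarity(anchor, positive, dim=1)
    s_an = F.cosine_similarity(anchor, negative, dim=1)
    # The positive should be more similar than the negative by at least `margin`.
    return F.relu(s_an - s_ap + margin)

anchor = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
positive = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
negative = torch.tensor([[0.0, 1.0], [1.0, 1.0], [1.0, 1.0]])

per_triplet = cosine_triplet_losses(anchor, positive, negative)
# approx [0.0, 0.7571, 0.05]

# Mimic a high threshold of 0.3: discard triplet losses above it, average the rest.
kept = per_triplet[per_triplet < 0.3]
print(kept.mean().item())  # approximately 0.025
```

The second triplet, whose negative is more similar to the anchor than its positive, has the largest loss and is the one discarded by the threshold.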

Using loss functions for unsupervised / self-supervised learning

A SelfSupervisedLoss wrapper is provided for self-supervised learning:

from pytorch_metric_learning.losses import SelfSupervisedLoss, TripletMarginLoss
loss_func = SelfSupervisedLoss(TripletMarginLoss())

# your training for-loop
for i, data in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = your_model(data)
    augmented = your_model(your_augmentation(data))
    loss = loss_func(embeddings, augmented)
    loss.backward()
    optimizer.step()

If you’re interested in MoCo-style self-supervision, take a look at the MoCo on CIFAR10 notebook. It uses CrossBatchMemory to implement the momentum encoder queue, which means you can use any tuple loss, and any tuple miner to extract hard samples from the queue.
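The queue idea can be sketched as a fixed-size first-in, first-out buffer of past embeddings and labels: each new batch is written over the oldest entries, and the stored entries can then serve as extra reference samples for a tuple loss or miner. The class below is a hypothetical illustration of that data structure only; it is not the CrossBatchMemory implementation, and `EmbeddingQueue` is an invented name.

```python
import torch

class EmbeddingQueue:
    """Hypothetical fixed-size FIFO of (embedding, label) pairs, for illustration."""

    def __init__(self, embedding_size: int, memory_size: int):
        self.embeddings = torch.zeros(memory_size, embedding_size)
        self.labels = torch.zeros(memory_size, dtype=torch.long)
        self.memory_size = memory_size
        self.ptr = 0          # next slot to overwrite
        self.filled = False   # True once every slot has been written at least once

    def enqueue(self, emb: torch.Tensor, labels: torch.Tensor) -> None:
        # Detach so that stored entries act as constants (no gradient into the memory).
        for e, l in zip(emb.detach(), labels):
            self.embeddings[self.ptr] = e
            self.labels[self.ptr] = l
            self.ptr = (self.ptr + 1) % self.memory_size
            if self.ptr == 0:
                self.filled = True

    def contents(self):
        n = self.memory_size if self.filled else self.ptr
        return self.embeddings[:n], self.labels[:n]

q = EmbeddingQueue(embedding_size=2, memory_size=3)
q.enqueue(torch.ones(2, 2), torch.tensor([0, 1]))
q.enqueue(2 * torch.ones(2, 2), torch.tensor([2, 3]))
emb_in_memory, labels_in_memory = q.contents()
print(labels_in_memory.tolist())  # [3, 1, 2]
```

After the second batch, label 3 has overwritten the oldest entry (label 0), so the buffer always holds the most recent `memory_size` samples.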

Highlights of the rest of the library

  • For a convenient way to train your model, take a look at the trainers.
  • Want to test your model’s accuracy on a dataset? Try the testers.
  • To compute the accuracy of an embedding space directly, use AccuracyCalculator.
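AccuracyCalculator reports retrieval metrics such as precision at 1. To show what that number means, the sketch below computes precision@1 for a single set of embeddings in plain PyTorch: for each embedding, find its nearest other embedding and check whether the labels match. It is a brute-force illustration, not the AccuracyCalculator implementation, and `precision_at_1` is a hypothetical function name.

```python
import torch

def precision_at_1(embeddings: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of samples whose nearest neighbor (excluding itself) shares its label."""
    dist = torch.cdist(embeddings, embeddings)
    dist.fill_diagonal_(float("inf"))   # a sample must not match itself
    nn_idx = dist.argmin(dim=1)         # index of each sample's nearest neighbor
    return (labels[nn_idx] == labels).float().mean().item()

emb = torch.tensor([[0.0, 0.0], [0.0, 0.1], [5.0, 5.0], [5.0, 5.1]])
print(precision_at_1(emb, torch.tensor([0, 0, 1, 1])))  # 1.0
print(precision_at_1(emb, torch.tensor([0, 1, 1, 1])))  # 0.5
```

Brute-force distances are fine for small sets; for large query and reference sets, a nearest-neighbor index such as faiss is the usual choice.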

If you’re short of time and want a complete train/test workflow, check out the example Google Colab notebooks.

To learn more about all of the above, see the documentation.

Installation

Required PyTorch version

  • pytorch-metric-learning >= v0.9.90 requires torch >= 1.6
  • pytorch-metric-learning < v0.9.90 doesn’t have a version requirement, but was tested with torch >= 1.2

Other dependencies: numpy, scikit-learn, tqdm, torchvision

Pip

pip install pytorch-metric-learning

To get the latest dev version:

pip install pytorch-metric-learning --pre

To install on Windows:

pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch-metric-learning

To install with evaluation and logging capabilities

(This will install the unofficial PyPI version of faiss-gpu, plus record-keeper and tensorboard):

pip install pytorch-metric-learning[with-hooks]

To install with evaluation and logging capabilities (CPU)

(This will install the unofficial PyPI version of faiss-cpu, plus record-keeper and tensorboard):

pip install pytorch-metric-learning[with-hooks-cpu]

Conda

conda install -c conda-forge pytorch-metric-learning

To use the testing module, you’ll need faiss, which can be installed via conda as well. See the installation instructions for faiss.

Benchmark results

See powerful-benchmarker to view benchmark results and to use the benchmarking tool.

Development

Development is done on the dev branch:

git checkout dev

Unit tests can be run with the default unittest library:

python -m unittest discover

You can specify the test datatypes and test device as environment variables. For example, to test using float32 and float64 on the CPU:

TEST_DTYPES=float32,float64 TEST_DEVICE=cpu python -m unittest discover

To run a single test file instead of the entire test suite, specify the file name:

python -m unittest tests/losses/test_angular_loss.py

Code is formatted using black and isort:

pip install black isort
./format_code.sh

Acknowledgements

Contributors

Thanks to the contributors who made pull requests!

Contributor Highlights

  • domenicoMuscill0
    • ManifoldLoss
    • P2SGradLoss
    • HistogramLoss
    • DynamicSoftMarginLoss
    • RankedListLoss
  • mlopezantequera
    • Made the testers work on any combination of query and reference sets
    • Made AccuracyCalculator work with arbitrary label comparisons
  • cwkeam
    • SelfSupervisedLoss
    • VICRegLoss
    • Added mean reciprocal rank accuracy to AccuracyCalculator
    • BaseLossWrapper
  • ir2718
    • ThresholdConsistentMarginLoss
    • The Datasets module
  • marijnl
    • BatchEasyHardMiner
    • TwoStreamMetricLoss
    • GlobalTwoStreamEmbeddingSpaceTester
    • Example using trainers.TwoStreamMetricLoss
  • chingisooinar: SubCenterArcFaceLoss
  • elias-ramzi: HierarchicalSampler
  • fjsj: SupConLoss
  • AlenUbuntu: CircleLoss
  • interestingzhuo: PNPLoss
  • wconnell: Learning a scRNAseq Metric Embedding
  • mkmenta: Improved get_all_triplets_indices (fixed the INT_MAX error)
  • AlexSchuy: Optimized utils.loss_and_miner_utils.get_random_triplet_indices
  • JohnGiorgi: all_gather in utils.distributed
  • Hummer12007: utils.key_checker
  • vltanh: Made InferenceModel.train_indexer accept datasets
  • btseytlin: get_nearest_neighbors in InferenceModel
  • mlw214: Added return_per_class to AccuracyCalculator
  • layumi: InstanceLoss
  • NoTody: Helped add ref_emb and ref_labels to the distributed wrappers
  • ElisonSherton: Fixed an edge case in ArcFaceLoss
  • stompsjo: Improved documentation for NTXentLoss
  • Puzer: Bug fix for PNPLoss
  • elisim: Developer improvements to DistributedLossWrapper

Other contributors: GaetanLepage, z1w, thinline72, tpanum, fralik, joaqo, JoOkuma, gkouros, yutanakamura-tky, KinglittleQ, martin0258, michaeldeyzel, HSinger04, rheum, bot66

Facebook AI

Thank you to Ser-Nam Lim at Facebook AI, and my research advisor, Professor Serge Belongie. This project began during my internship at Facebook AI where I received valuable feedback from Ser-Nam, and his team of computer vision and machine learning engineers and research scientists. In particular, thanks to Ashish Shah and Austin Reiter for reviewing my code during its early stages of development.

Open-source repos

This library contains code that has been adapted and modified from the following great open-source repos:

  • https://github.com/bnu-wangxun/Deep_Metric
  • https://github.com/chaoyuaw/incubator-mxnet/blob/master/example/gluon/embedding_learning
  • https://github.com/facebookresearch/deepcluster
  • https://github.com/geonm/proxy-anchor-loss
  • https://github.com/idstcv/SoftTriple
  • https://github.com/kunhe/FastAP-metric-learning
  • https://github.com/ronekko/deep_metric_learning
  • https://github.com/tjddus9597/Proxy-Anchor-CVPR2020
  • http://kaizhao.net/regularface
  • https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts

Logo

Thanks to Jeff Musgrave for designing the logo.

Citing this library

If you’d like to cite pytorch-metric-learning in your paper, you can use this bibtex:

@article{Musgrave2020PyTorchML,
  title={PyTorch Metric Learning},
  author={Kevin Musgrave and Serge J. Belongie and Ser-Nam Lim},
  journal={ArXiv},
  year={2020},
  volume={abs/2008.09164}
}

Dependencies

Active CVE issues: 0
Scorecard score: 4.40
Test coverage: No data
Follows SemVer: No
GitHub stars: 6,041
Dependencies (total): 15
Outdated dependencies: 0
Deprecated dependencies: 0
Threat modelling: No data
Repo audits: No data


211 Releases
