XCurve.Losses.AUPRC
The following functions build the class used to compute the List-stable AUPRC loss:

CLASS XCurve.AUPRC.RetrievalDataset(data_dir, list_dir, subset, input_size, batchsize, num_sample_per_id, normal_mean=[0.485, 0.456, 0.406], normal_std=[0.229, 0.224, 0.225], split='train')
| Parameter description |
|---|
| Controls the surrogate loss of pos-neg pairs. See \tau_1 in Eq.(7). |
| Controls the surrogate loss of pos-pos pairs. See \tau_2 in Eq.(7). |
| Controls the exponential moving average. See \beta in Eq.(10). |
| Imbalance ratio of the id with the most positive examples. |
| Number of examples for each id. |
| Weight of the variance regularization term w.r.t. positive examples. |
| Weight of the variance regularization term w.r.t. negative examples. |
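For readers unfamiliar with the exponential moving average that \beta controls, the generic update can be sketched as follows. This is only an illustration of the general technique: the helper name `ema_update` is hypothetical, and the convention that `beta` weights the old estimate is an assumption, so the exact form of Eq.(10) may differ.

```python
def ema_update(running, new_value, beta):
    """One generic EMA step (assumed convention: beta weights the old estimate)."""
    return beta * running + (1.0 - beta) * new_value


# Start from 0.0 and observe the value 1.0 three times with beta = 0.5.
estimate = 0.0
for observed in [1.0, 1.0, 1.0]:
    estimate = ema_update(estimate, observed, beta=0.5)

print(estimate)  # 0.875
```

Under this convention, a larger `beta` makes the running estimate change more slowly.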
Example:
import torch
import torch.nn.functional as F
from XCurve.AUPRC import (ListStableAUPRC, DefaultLossCfg,
                          RetrievalDataset, DefaultInatDatasetCfg)
dataset = RetrievalDataset(**DefaultInatDatasetCfg, split='train')
criterion = ListStableAUPRC(**DefaultLossCfg)
criterion.update_cnt_per_id(dataset.get_cnt_per_id())
feats = F.normalize(torch.randn((16, 128)).cuda(), dim=1, p=2)
targets = torch.tensor([5, 5, 5, 5, 3, 3, 3, 3, 1, 1, 1, 1, 9, 9, 9, 9]).cuda()
loss = criterion(feats, targets)
print(loss.item())
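To make concrete what kind of quantity an AUPRC loss targets, the sketch below computes exact average precision (a common estimator of AUPRC) for a retrieval batch laid out like the one above, where each sample queries the rest of the batch and samples with the same id count as positives. It is a plain-Python illustration, not XCurve code: the helpers `l2_normalize`, `average_precision`, and `mean_ap` are hypothetical names, and the use of cosine similarity for ranking is an assumption.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def average_precision(scores, relevant):
    """AP of one ranked list: mean of precision@k over the ranks of relevant items."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits, precisions = 0, []
    for rank, idx in enumerate(order, start=1):
        if relevant[idx]:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def mean_ap(feats, targets):
    """Each sample queries the rest of the batch; same id means relevant."""
    feats = [l2_normalize(f) for f in feats]
    aps = []
    for q in range(len(feats)):
        gallery = [i for i in range(len(feats)) if i != q]
        # Cosine similarity (features are unit-norm) is assumed as the ranking score.
        scores = [sum(a * b for a, b in zip(feats[q], feats[i])) for i in gallery]
        relevant = [targets[i] == targets[q] for i in gallery]
        aps.append(average_precision(scores, relevant))
    return sum(aps) / len(aps)


# Toy batch: two ids with two samples each, well separated in feature space.
toy_feats = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
toy_targets = [5, 5, 3, 3]
print(mean_ap(toy_feats, toy_targets))  # 1.0
```

Because this metric depends on sorting, it is not differentiable with respect to the features, which is why a surrogate loss is used for training.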