XCurve.Metrics.AUPRC
Compute the Area Under the Precision-Recall Curve (AUPRC) for the retrieval task. Given $d$-dimensional features $\{z_i\}_{i=1}^n$ and category labels $\{c_i\}_{i=1}^n$, the function takes the example with index $q$ as the query and computes the scores and targets. The scores are the cosine similarities between the query and every other example: $$ s_i = z_q^\top z_i / (\|z_q\|\|z_i\|),~~ 1\leq i \leq n,~ i \neq q. $$ The targets are given by the positive and negative sets $I^+ = \{i \mid c_i = c_q,~ 1\leq i \leq n,~ i \neq q\}$ and $I^- = \{i \mid c_i \neq c_q,~ 1\leq i \leq n,~ i \neq q\}$. The AUPRC of the query $q$ is then defined as $$ \text{AUPRC}_q = \frac{1}{|I^+|} \sum_{i\in I^+} \frac{\sum_{j\in I^+}\mathbb{I}[s_i \leq s_j]}{\sum_{j\in I^+ \cup I^-}\mathbb{I}[s_i \leq s_j]}. $$ The overall AUPRC is the average of $\text{AUPRC}_q$ over all possible queries $q$: $$ \text{AUPRC} = \frac{1}{n} \sum_{q=1}^n \text{AUPRC}_q. $$
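The definition above can be checked on small inputs with a brute-force NumPy sketch. This is not the XCurve implementation: the name `auprc_reference` is hypothetical, and skipping queries whose positive set $I^+$ is empty is an assumption made here, since the formula is undefined in that case.

```python
import numpy as np

def auprc_reference(feats, targets):
    """Brute-force AUPRC following the definition above (hypothetical helper).

    Builds the full n x n similarity matrix, so it is only meant for small n.
    """
    feats = np.asarray(feats, dtype=np.float64)
    targets = np.asarray(targets).reshape(-1)  # accept (n,) or (n, 1)
    n = feats.shape[0]

    # Cosine similarity between every pair of examples.
    normed = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    sims = normed @ normed.T

    per_query = []
    for q in range(n):
        others = np.arange(n) != q            # exclude the query itself
        s = sims[q, others]                   # scores s_i, i != q
        pos = targets[others] == targets[q]   # membership in I^+
        if not pos.any():
            continue  # assumption: skip queries with an empty I^+
        s_pos = s[pos]
        # For each positive i, count the j with s_i <= s_j ...
        num = (s_pos[:, None] <= s_pos[None, :]).sum(axis=1)  # ... over I^+
        den = (s_pos[:, None] <= s[None, :]).sum(axis=1)      # ... over I^+ and I^-
        per_query.append(np.mean(num / den))

    return float(np.mean(per_query)) if per_query else float("nan")
```

When every same-category example is ranked above every other example for each query, each ratio equals 1 and the sketch returns 1.0.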
| Parameters | (torch.Tensor or np.ndarray): Input features of shape (N_samples, embedding_dim). (torch.Tensor or np.ndarray): Ground-truth labels of shape (N_samples). (int or list[int]): The value or list of values of k at which to compute Recall@k. |
|---|---|
| Returns | The corresponding AUPRC score. |
Example:
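```python
import torch
from XCurve.AUPRC import AUPRC, RecallAtK

# Random features and labels: 2**14 samples, 128-dim embeddings, 1000 categories.
feats = torch.randn((2**14, 128)).numpy()
targets = torch.randint(0, 1000, (2**14, 1)).numpy()

# AUPRC averaged over all queries.
print(AUPRC(feats, targets))
# Recall@k for a single k and for a list of k values.
print(RecallAtK(feats, targets, 1))
print(RecallAtK(feats, targets, [1, 4, 16]))
```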
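This page does not define Recall@k. Under the common retrieval convention, assumed here and not confirmed to be what `RecallAtK` computes, a query counts as a hit if at least one of its k most similar examples (by cosine similarity, excluding the query itself) shares its category, and Recall@k is the fraction of queries that are hits. A minimal sketch, with the hypothetical name `recall_at_k_reference`:

```python
import numpy as np

def recall_at_k_reference(feats, targets, ks):
    """Recall@k under the assumed convention (hypothetical helper).

    Returns one value per entry of ks; each k should be at most n - 1.
    """
    feats = np.asarray(feats, dtype=np.float64)
    targets = np.asarray(targets).reshape(-1)  # accept (n,) or (n, 1)
    ks = np.atleast_1d(ks)                     # accept an int or a list of ints

    # Cosine similarity between every pair of examples.
    normed = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    sims = normed @ normed.T
    np.fill_diagonal(sims, -np.inf)            # rank the query itself last

    order = np.argsort(-sims, axis=1)          # most similar first
    hits = targets[order] == targets[:, None]  # same-category flags in rank order
    return [float(hits[:, :k].any(axis=1).mean()) for k in ks]
```

The sketch always returns a list, even for a single k; whether the library returns a scalar in that case is not stated on this page.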