From 0ec8d2fbf616df756ceeb0f6f1d5935ce4eed9d7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maik=20Fr=C3=B6be?=
Date: Thu, 5 Dec 2024 08:38:48 +0100
Subject: [PATCH] proposal for Recall@k

---
 trectools/trec_eval.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/trectools/trec_eval.py b/trectools/trec_eval.py
index 8683ad1..9a51609 100644
--- a/trectools/trec_eval.py
+++ b/trectools/trec_eval.py
@@ -402,6 +402,60 @@ def get_map(self, depth=1000, per_query=False, trec_eval=True):
 
         return (map_per_query.sum() / nqueries)[label]
 
+    def get_recall(self, depth=1000, per_query=False, trec_eval=True, removeUnjudged=False):
+        """
+            Calculates Recall@depth: the fraction of relevant documents retrieved within the top 'depth' results.
+
+            Params
+            -------
+            depth: the evaluation depth. Default = 1000.
+            trec_eval: set to True if the result should match trec_eval, i.e., documents are sorted by score first. Default = True.
+            per_query: if True, returns the evaluation per query. Default = False.
+            removeUnjudged: set to True to remove unjudged documents before calculating this metric. Default = False.
+
+            Returns
+            --------
+            if per_query == True: returns a pandas dataframe with two cols (query, Recall@depth)
+            else: returns a float value representing the mean recall over all queries with relevant documents.
+        """
+        label = "Recall@%d" % (depth)
+
+        run = self.run.run_data
+        qrels = self.qrels.qrels_data
+
+        if removeUnjudged:
+            onlyjudged = pd.merge(run, qrels[["query","docid","rel"]], how="left")
+            onlyjudged = onlyjudged[~onlyjudged["rel"].isnull()]
+            run = onlyjudged[["query","q0","docid","rank","score","system"]]
+
+        # Sort by score (ties broken by docid) to reproduce trec_eval's ordering
+        if trec_eval:
+            run = run.sort_values(["query", "score", "docid"], ascending=[True, False, False]).reset_index(drop=True)
+
+        # Select only the top 'depth' documents per query
+        topX = run.groupby("query")[["query","docid","score"]].head(depth)
+
+        # Keep only retrieved documents that are relevant (rel > 0)
+        relevant_docs = qrels[qrels.rel > 0]
+        selection = pd.merge(topX, relevant_docs[["query","docid","rel"]], how="left")
+        selection = selection[~selection["rel"].isnull()]
+
+        # Set of relevant documents per query; queries without relevant documents are skipped
+        relevant_per_query = relevant_docs.groupby("query")["docid"].apply(set)
+
+        recall_per_query = []
+        for query, relevant in relevant_per_query.items():
+            retrieved_docs = selection[selection["query"] == query]["docid"].unique()
+            retrieved_relevant_docs = [docid for docid in retrieved_docs if docid in relevant]
+            recall_per_query.append({"query": query, label: len(retrieved_relevant_docs) / len(relevant)})
+
+        recall_per_query = pd.DataFrame(recall_per_query)
+        if per_query:
+            return recall_per_query
+        else:
+            return recall_per_query[label].mean()
+
+
     def get_rprec(self, depth=1000, per_query=False, trec_eval=True, removeUnjudged=False):
         """
             The Precision at R, where R is the number of relevant documents for a topic.
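
A minimal usage sketch for the proposed get_recall method, based on trectools' documented TrecRun/TrecQrel/TrecEval API; the file names below are placeholders, not part of the patch:

    from trectools import TrecRun, TrecQrel, TrecEval

    # Hypothetical input files in standard TREC run and qrels format
    run = TrecRun("my_run.txt")
    qrels = TrecQrel("my_qrels.txt")

    te = TrecEval(run, qrels)
    print(te.get_recall(depth=10))                   # mean Recall@10 over queries with relevant documents
    print(te.get_recall(depth=100, per_query=True))  # per-query Recall@100 as a DataFrame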