-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils_bert.py
More file actions
106 lines (81 loc) · 3.85 KB
/
utils_bert.py
File metadata and controls
106 lines (81 loc) · 3.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import logging
import numpy as np
import torch.multiprocessing as mp
from typing import List, Dict, Tuple
from torch import Tensor
from typing import List, Dict, Union, Tuple
from abc import ABC, abstractmethod
from typing import Dict
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
class BaseSearch(ABC):
    """Interface for retrieval back-ends that rank corpus documents per query."""

    @abstractmethod
    def search(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str], top_k: int, **kwargs) -> Dict[str, Dict[str, float]]:
        """Return, for each query id, a mapping of document id -> relevance score."""
        ...
class SentenceBERT:
    """Bi-encoder wrapper pairing a query encoder and a document encoder.

    A single ``model_path`` string loads one SentenceTransformer shared by
    queries and documents; a 2-tuple ``(query_path, doc_path)`` loads two
    separate models.
    """

    def __init__(self, model_path: Union[str, Tuple] = None, sep: str = " ", **kwargs):
        # `sep` joins a document's title and text before encoding.
        self.sep = sep
        if isinstance(model_path, str):
            # One shared model for queries and documents.
            self.q_model = SentenceTransformer(model_path)
            self.doc_model = self.q_model
        elif isinstance(model_path, tuple):
            # Separate query / document encoders.
            self.q_model = SentenceTransformer(model_path[0])
            self.doc_model = SentenceTransformer(model_path[1])
        # NOTE(review): with model_path=None neither model is set, so later
        # encode_* calls raise AttributeError -- presumably callers always
        # pass a path; confirm before tightening this into a ValueError.

    def start_multi_process_pool(self, target_devices: List[str] = None) -> Dict[str, object]:
        """Spawn one encoding worker process per target device.

        :param target_devices: device names (e.g. ['cuda:0', 'cuda:1'] or
            ['cpu']*4); None selects all visible GPUs, else 4 CPU workers.
        :return: dict with the 'input'/'output' queues and the 'processes'
            list, in the layout SentenceTransformer's multi-process encoding
            helpers expect.
        """
        if target_devices is None:
            # Bug fix: the original crashed on the default (None) when
            # formatting the log line. Mirror sentence-transformers' default.
            if torch.cuda.is_available():
                target_devices = ['cuda:{}'.format(i) for i in range(torch.cuda.device_count())]
            else:
                target_devices = ['cpu'] * 4
        logger.info("Start multi-process pool on devices: {}".format(', '.join(map(str, target_devices))))
        ctx = mp.get_context('spawn')
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()
        processes = []
        for process_id, device_name in enumerate(target_devices):
            p = ctx.Process(target=SentenceTransformer._encode_multi_process_worker, args=(process_id, device_name, self.doc_model, input_queue, output_queue), daemon=True)
            p.start()
            processes.append(p)
        return {'input': input_queue, 'output': output_queue, 'processes': processes}

    def stop_multi_process_pool(self, pool: Dict[str, object]):
        """Drain pending results, then delegate pool shutdown to the model."""
        output_queue = pool['output']
        # Collect one result per worker so no process blocks on a full queue.
        [output_queue.get() for _ in range(len(pool['processes']))]
        return self.doc_model.stop_multi_process_pool(pool)

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
        """Encode query strings with the query model."""
        return self.q_model.encode(queries, batch_size=batch_size, **kwargs)

    def encode_corpus(self, corpus: Union[List[Dict[str, str]], Dict[str, List]], batch_size: int = 8, **kwargs) -> Union[List[Dict[str, str]], np.ndarray, Tensor]:
        """Encode documents as "title + sep + text" (title optional).

        Accepts either a columnar dict of parallel lists ({'title': [...],
        'text': [...]}) or a list of per-document dicts.
        """
        # Bug fix: was `type(corpus) is dict`, which sent dict subclasses into
        # the list branch where iteration yields keys and indexing fails.
        if isinstance(corpus, dict):
            sentences = [(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() for i in range(len(corpus['text']))]
        else:
            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        return self.doc_model.encode(sentences, batch_size=batch_size, **kwargs)
def cos_sim(a: torch.Tensor, b: torch.Tensor):
    """Pairwise cosine similarity: result[i][j] = cos_sim(a[i], b[j]).

    Accepts tensors or array-likes; 1-D inputs are promoted to one-row
    matrices so the result is always 2-D.
    """
    a = a if isinstance(a, torch.Tensor) else torch.tensor(a)
    b = b if isinstance(b, torch.Tensor) else torch.tensor(b)
    if a.dim() == 1:
        a = a.unsqueeze(0)
    if b.dim() == 1:
        b = b.unsqueeze(0)
    # L2-normalize the rows; a plain matrix product then yields cosine scores.
    a_unit = torch.nn.functional.normalize(a, p=2, dim=1)
    b_unit = torch.nn.functional.normalize(b, p=2, dim=1)
    return a_unit @ b_unit.t()  # TODO: this keeps allocating GPU memory
def dot_score(a: torch.Tensor, b: torch.Tensor):
    """Pairwise dot product: result[i][j] = dot_prod(a[i], b[j]).

    Accepts tensors or array-likes; 1-D inputs are promoted to one-row
    matrices so the result is always 2-D.
    """
    a = a if isinstance(a, torch.Tensor) else torch.tensor(a)
    b = b if isinstance(b, torch.Tensor) else torch.tensor(b)
    if a.dim() == 1:
        a = a.unsqueeze(0)
    if b.dim() == 1:
        b = b.unsqueeze(0)
    return a @ b.t()