-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path predictors.py
More file actions
28 lines (24 loc) · 1.03 KB
/
predictors.py
File metadata and controls
28 lines (24 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from typing import Tuple
from overrides import overrides
from allennlp.common.util import JsonDict
from allennlp.data import Instance
from allennlp.predictors import Predictor
@Predictor.register("slot_tagging_predictor")
class SlotTaggingPredictor(Predictor):
    """Predictor for a slot-tagging model that consumes pre-tokenized JSON.

    Registered under the name ``"slot_tagging_predictor"``. Each input JSON
    dict must provide a ``"tokens"`` entry; an optional ``"true_labels"``
    entry is echoed back unchanged so predictions can be compared with gold
    labels.
    """

    def predict(self, inputs: JsonDict) -> JsonDict:
        """Tag one example and translate predicted tag indices into label strings.

        Args:
            inputs: JSON dict with a ``"tokens"`` key and, optionally, a
                ``"true_labels"`` key.

        Returns:
            A dict containing:
              * ``"tokens"`` -- copied from ``inputs``;
              * ``"predict_labels"`` -- each entry of the model's
                ``"predicted_tags"`` output looked up in the ``"labels"``
                vocabulary namespace;
              * ``"predict_score"`` -- the model's ``"predicted_score"`` output;
              * ``"true_labels"`` -- only when present in ``inputs``.
        """
        model_output = self.predict_instance(self._json_to_instance(inputs))

        source_tokens = inputs["tokens"]
        # NOTE(review): assumes "predicted_tags" holds integer indices into the
        # "labels" namespace -- confirm against the model's forward() output.
        label_strings = [
            self._model.vocab.get_token_from_index(tag_id, namespace="labels")
            for tag_id in model_output["predicted_tags"]
        ]
        result = dict(
            tokens=source_tokens,
            predict_labels=label_strings,
            predict_score=model_output["predicted_score"],
        )

        # Pass gold labels straight through (even a None value) when supplied.
        if "true_labels" in inputs:
            result["true_labels"] = inputs["true_labels"]
        return result

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """Build an ``Instance`` from the ``"tokens"`` entry of ``json_dict``.

        The tokens are handed as-is to the dataset reader's
        ``text_to_instance`` via the ``tokens`` keyword.
        """
        return self._dataset_reader.text_to_instance(tokens=json_dict["tokens"])