forked from LeonGuertler/SuperTinyLanguageModels
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
35 lines (25 loc) · 883 Bytes
/
generate.py
File metadata and controls
35 lines (25 loc) · 883 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""
The main generate code
"""
import hydra
import torch
from models.build_models import build_model
from models.generator import StandardGenerator
@hydra.main(config_path="configs", config_name="generate")
def main(cfg):
    """Run an interactive text-generation loop.

    Loads the model checkpoint named by ``cfg["model_ckpt"]`` and wraps it in a
    ``StandardGenerator`` configured by ``cfg["generator"]``. It then
    repeatedly reads a prompt from stdin and prints the generated text.

    The loop ends cleanly when stdin is closed (EOF, e.g. Ctrl-D) or the user
    interrupts while the prompt is waiting (Ctrl-C). It no longer ends with a
    traceback.

    Args:
        cfg: Hydra config composed from ``configs/generate``. It must provide
            ``model_ckpt`` (checkpoint path, possibly relative) and
            ``generator`` (settings passed to ``StandardGenerator``).
    """
    # Hydra may switch the working directory for each run (depending on
    # version/config). Resolve a relative checkpoint path against the
    # directory the script was launched from.
    cfg["model_ckpt"] = hydra.utils.to_absolute_path(cfg["model_ckpt"])

    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. No map_location is given, so tensors are restored onto the
    # device they were saved from. A GPU-saved checkpoint may need
    # map_location on a CPU-only machine (TODO confirm against build_model).
    model = build_model(checkpoint=torch.load(cfg["model_ckpt"]))
    generator = StandardGenerator(model=model, generate_cfg=cfg["generator"])

    while True:
        try:
            prompt = input("Enter the input text: ")
        except (EOFError, KeyboardInterrupt):
            # stdin closed or Ctrl-C at the prompt: leave the REPL quietly.
            # The newline keeps the shell prompt off the input line.
            print()
            break
        generated_text = generator.default_generate(input_text=prompt)
        print(generated_text)
if __name__ == "__main__":
    # Hydra's decorator injects ``cfg`` at call time. Pylint cannot see that
    # and reports a missing argument, so silence that check for this call only.
    main()  # pylint: disable=no-value-for-parameter