-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
28 lines (19 loc) · 722 Bytes
/
main.py
File metadata and controls
28 lines (19 loc) · 722 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from models.model import load_model_and_processor
from preprocessing.data_loader import load_dataset, create_dataloader
from training.trainer import train_model
def main(labels_path="./data/labels.csv", output_dir="models/fine_tuned_trocr"):
    """Load the model and data, train the model, and save the result.

    Calling ``main()`` with no arguments behaves exactly like the original
    script: both defaults are the paths that used to be hard-coded here.

    Args:
        labels_path: Path handed to ``load_dataset``. Defaults to
            ``"./data/labels.csv"``.
        output_dir: Directory passed to ``save_pretrained`` for BOTH the
            trained model and its processor. Defaults to
            ``"models/fine_tuned_trocr"``.
    """
    # Report whether PyTorch can see a CUDA device before training starts.
    print("Is CUDA enabled?", torch.cuda.is_available())

    # Load the model and its processor.
    model, processor = load_model_and_processor()

    # Load the dataset.
    dataset = load_dataset(labels_path)

    # Build the DataLoader; the processor is passed along, presumably to
    # preprocess samples — TODO confirm in preprocessing.data_loader.
    dataloader = create_dataloader(dataset, processor)

    # Train the model.
    train_model(model, dataloader)

    # Save the trained model and the processor. Both use the single
    # `output_dir` value so the two save locations cannot diverge.
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
if __name__ == "__main__":
    # Run the training pipeline only when this file is executed directly
    # as a script, not when it is imported as a module.
    main()