forked from eriklindernoren/PyTorch-YOLOv3
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path prune_std.py
More file actions
38 lines (32 loc) · 1.44 KB
/
prune_std.py
File metadata and controls
38 lines (32 loc) · 1.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
import numpy as np
import pytorchyolo.train as train
import pytorchyolo.test as test
from pruning_modules import print_nonzeros
from pytorchyolo.utils.parse_config import parse_data_config
from pytorchyolo.models import load_model
import argparse
def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, and every
    non-empty string is truthy, so ``--train False`` would *enable* training.
    This converter interprets the text instead, while still accepting the
    ``--flag True`` / ``--flag False`` spelling the script already supported.

    Args:
        value: The raw argument string, or an already-parsed ``bool``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is accepted as False because it was falsy under the
    # previous ``type=bool`` parsing; keep that spelling working.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def parse_args(argv=None):
    """Build the CLI parser and parse ``argv`` (defaults to ``sys.argv[1:]``).

    Args:
        argv: Optional list of argument strings; ``None`` means use the real
            command line.

    Returns:
        An ``argparse.Namespace`` with ``sen``, ``train``, ``prune``, ``data``,
        ``model`` and ``weights`` attributes.
    """
    parser = argparse.ArgumentParser(description="Get threshold.")
    parser.add_argument("--sen", type=float, default=0.25, help="Sensitivity for pruning.")
    # nargs="?" + const=True lets the bare flag (``--train``) mean True while
    # ``--train True`` / ``--train False`` keep working.
    parser.add_argument("--train", type=str2bool, nargs="?", const=True, default=False,
                        help="Retrain Model.")
    parser.add_argument("--prune", type=str2bool, nargs="?", const=True, default=False,
                        help="Prune Model.")
    # Previously hard-coded paths; defaults are unchanged.
    parser.add_argument("--data", type=str, default="config/coco.data",
                        help="Path to the data config file.")
    parser.add_argument("--model", type=str, default="config/yolov3.cfg",
                        help="Path to the model definition (.cfg) file.")
    parser.add_argument("--weights", type=str, default="weights/yolov3.weights",
                        help="Path to the initial weights/checkpoint file.")
    return parser.parse_args(argv)


def main(argv=None):
    """Load the model, prune it by weight std, evaluate, and optionally retrain.

    The flow is: load the model, print its non-zero counts, call
    ``model.prune_by_std`` with sensitivity ``--sen``, print the counts again,
    and evaluate with ``test.run``. When ``--train`` is set, the model is then
    retrained with ``train.run``, and the counts are printed and the model
    re-evaluated once more.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    print("Loading Model\n")
    # NOTE(review): the parsed data config is never passed to test.run /
    # train.run below. The call is kept because it presumably raises on a
    # missing/malformed data file, which gives an early failure -- confirm.
    parse_data_config(args.data)

    # NOTE(review): --prune only selects load_model's pruning mode; the
    # prune_by_std step below runs regardless of it.
    model = load_model(args.model, weights_path=args.weights, pruning=args.prune)
    print_nonzeros(model)

    print("Pruned Test Outcome\n")
    # The value returned by prune_by_std replaces the checkpoint path used by
    # every subsequent evaluation / retraining step.
    checkpoint_path = model.prune_by_std(args.weights, s=args.sen)
    print_nonzeros(model)
    test.run(model=model, weights=checkpoint_path)

    if args.train:
        print("Pruned and Retrained and Test Outcome\n")
        model, checkpoint_path = train.run(model=model, pretrained_weights=checkpoint_path)
        print_nonzeros(model)
        test.run(model=model, weights=checkpoint_path)


if __name__ == "__main__":
    main()