-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtemp.py
More file actions
65 lines (40 loc) · 1.72 KB
/
temp.py
File metadata and controls
65 lines (40 loc) · 1.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from numpy.core import function_base
import tvm
from tvm import relay
from tvm.relay import GlobalVar
from tvm.relay.op.nn.nn import batch_flatten
from tvm.relay.op.tensor import cos
from tvm.relay.testing import run_infer_type
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient, PartialEvaluate, DeadCodeElimination
from tvm.relay.testing import run_infer_type, rand, check_grad, create_workload
import numpy as np
from tvm.relay.transform.transform import ForwardFoldScaleAxis
from tvm.contrib import graph_runtime as runtime
def normal_mlp(data_shape, label_shape, dtype="float32"):
    """Build a 3-layer MLP followed by softmax and a cross-entropy loss.

    The network is dense(128) -> relu -> dense(64) -> relu -> dense(10) ->
    softmax, with a bias_add after every dense layer, and the softmax output
    is scored against ``label`` with ``relay.nn.cross_entropy``.

    Parameters
    ----------
    data_shape : tuple of int
        Shape of the ``"data"`` input.
    label_shape : tuple of int
        Shape of the ``"label"`` input. Presumably ``(batch, 10)`` to match
        the 10-unit output layer -- TODO confirm against cross_entropy's
        type relation.
    dtype : str
        Element type of both inputs (default ``"float32"``).

    Returns
    -------
    relay.Function
        Function whose parameters are all free variables of the loss:
        ``data``, ``label`` and the ``fcN_weight`` / ``fcN_bias`` variables.
    """
    data = relay.var("data", shape=data_shape, dtype=dtype)
    # Must NOT reuse the name "data": inputs and bound params are matched by
    # name_hint, so two vars called "data" cannot be told apart.
    label = relay.var("label", shape=label_shape, dtype=dtype)

    # Weight/bias vars carry no explicit type; dense(units=...) and bias_add
    # let type inference fill their shapes in.
    fc1 = relay.nn.dense(data, relay.var("fc1_weight"), units=128)
    fc1 = relay.nn.bias_add(fc1, relay.var("fc1_bias"), axis=-1)
    act1 = relay.nn.relu(fc1)

    fc2 = relay.nn.dense(act1, relay.var("fc2_weight"), units=64)
    fc2 = relay.nn.bias_add(fc2, relay.var("fc2_bias"), axis=-1)
    act2 = relay.nn.relu(fc2)

    fc3 = relay.nn.dense(act2, relay.var("fc3_weight"), units=10)
    fc3 = relay.nn.bias_add(fc3, relay.var("fc3_bias"), axis=-1)

    loss = relay.nn.cross_entropy(relay.nn.softmax(data=fc3), label)
    return relay.Function(relay.analysis.free_vars(loss), loss)


def _init_params(mod, input_names=("data", "label"), seed=0):
    """Return Xavier-initialized NDArrays for the parameters of ``mod["main"]``.

    Mirrors ``tvm.relay.testing.create_workload``'s initialization, but skips
    every name in ``input_names`` so those stay runtime graph inputs.
    ``create_workload`` only skips a parameter named exactly ``"data"``.

    Parameters
    ----------
    mod : tvm.IRModule
        Type-checked module; every ``main`` parameter needs a concrete
        tensor type.
    input_names : tuple of str
        Parameter names to leave out of the returned dict.
    seed : int
        Seed for NumPy's global RNG, which the initializer draws from.

    Returns
    -------
    dict[str, tvm.nd.NDArray]
        Mapping from parameter name to its initialized value.
    """
    from tvm.relay.testing.init import Xavier

    np.random.seed(seed)
    initializer = Xavier()
    params = {}
    for var in mod["main"].params:
        name = var.name_hint
        if name in input_names:
            continue
        ttype = var.checked_type
        value = np.zeros(ttype.concrete_shape, dtype=ttype.dtype)
        initializer(name, value)  # fills ``value`` in place
        params[name] = tvm.nd.array(value)
    return params


batch_size = 10
data_shape = (batch_size, 784)
label_shape = (batch_size, 10)
valid_shape = (100, 784)  # NOTE(review): not used anywhere below
dtype = "float32"

func = run_infer_type(normal_mlp(data_shape, label_shape, dtype))
# First-order AD emits only lets/tuples. The default higher-order mode emits
# references and closures, which the graph executor targeted by relay.build
# cannot lower.
back = run_infer_type(gradient(func, mode="first_order"))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(back))
params = _init_params(mod)

opt_level = 0
with tvm.transform.PassContext(opt_level=opt_level):
    lib = relay.build(mod, target="llvm", params=params)