-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodule.py
More file actions
45 lines (36 loc) · 1.19 KB
/
module.py
File metadata and controls
45 lines (36 loc) · 1.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Base class that all dataset models will inherit from
from __future__ import print_function
import nd_aggregation
import mxnet as mx
from mxnet import nd, autograd, gluon
import numpy as np
import random
import argparse
import byzantine
import os
import json
import gluonnlp
from abc import ABC, abstractmethod
class Module(ABC):
    """Abstract base for per-dataset model definitions.

    Subclasses must supply model construction and data loading
    (``createModel``, ``loadTrainingData``, ``loadTestingData``). The
    remaining hooks provide overridable defaults: an accuracy metric,
    argmax predictions, Xavier parameter initialisation and softmax
    cross-entropy loss.
    """

    @abstractmethod
    def createModel(self):
        """Build and return the network for this dataset (base stub returns None)."""

    @abstractmethod
    def loadTrainingData(self, ctx):
        """Return a ``(training_data, masks)`` pair.

        Masks are currently only produced for the REDDIT dataset. ``ctx`` is
        presumably the MXNet context the data should live on -- TODO confirm
        against subclasses. The base stub returns ``(None, None)``.
        """
        return None, None

    @abstractmethod
    def loadTestingData(self):
        """Return the testing data for this dataset (base stub returns None)."""

    def getAccuracyMetric(self):
        """Return a fresh ``mx.metric.Accuracy`` instance (default metric)."""
        metric = mx.metric.Accuracy()
        return metric

    def getPredictionsFromNetworkOutput(self, output):
        """Turn raw network output into predicted labels via argmax on axis 1.

        Assumes ``output`` is laid out as (batch, num_classes) -- TODO
        confirm for subclasses that keep this default.
        """
        predictions = nd.argmax(output, axis=1)
        return predictions

    def initializeModel(self, net, ctx):
        """(Re)initialise every parameter of ``net`` on ``ctx``.

        Uses Xavier initialisation with magnitude 2.24. ``force_reinit=True``
        makes this overwrite parameters that were already initialised.
        """
        params = net.collect_params()
        initializer = mx.init.Xavier(magnitude=2.24)
        params.initialize(initializer, force_reinit=True, ctx=ctx)

    def getLossFunction(self):
        """Return a new softmax cross-entropy loss (default loss)."""
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        return loss_fn