diff --git a/load_t7.py b/load_t7.py index 3665d3c..954d9fb 100644 --- a/load_t7.py +++ b/load_t7.py @@ -41,7 +41,8 @@ def load(o, param_list): param_list.append(temp) # batch norm elif o['modules'][i]._typename == 'nn.SpatialBatchNormalization' or \ - o['modules'][i]._typename == 'nn.VolumetricBatchNormalization': + o['modules'][i]._typename == 'cudnn.SpatialBatchNormalization' or \ + o['modules'][i]._typename == 'nn.VolumetricBatchNormalization': param_list[-1]['gamma'] = o['modules'][i]['weight'] param_list[-1]['beta'] = o['modules'][i]['bias'] param_list[-1]['mean'] = o['modules'][i]['running_mean']