
Questions about train.py and test.py #13

@BugMaker2002


Hello, could you explain what this code in test.py does?
[Screenshot of the code snippet from test.py]
Also, why do train.py and test.py use different model architectures? train.py builds the network with ndim=2048, but test.py seems to switch to n_class=1. As a result, loading the pretrained model directly fails with an error, because the parameter shapes do not match. I would expect the usual approach to be: use the same architecture for training and validation, load the weights trained on the training set directly into the validation model, and then train a classifier head, nn.Linear(2048, 1), on top of it. Why didn't you do it that way?
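The shape mismatch described above, and the alternative the question proposes (often called linear probing), can be sketched in PyTorch. This is a minimal sketch under stated assumptions, not the repository's code: `make_net`, `IN_DIM`, `HIDDEN`, the random batch, and the binary loss (`BCEWithLogitsLoss`, assuming n_class=1 means a single binary logit) are hypothetical stand-ins. Only the 2048-d output, the single-output model, and `nn.Linear(2048, 1)` come from the question.

```python
import torch
import torch.nn as nn

IN_DIM, HIDDEN = 32, 64  # hypothetical sizes, chosen only for illustration


def make_net(out_dim: int) -> nn.Sequential:
    """Hypothetical stand-in for the repo's model: same body, configurable output width."""
    return nn.Sequential(
        nn.Linear(IN_DIM, HIDDEN),
        nn.ReLU(),
        nn.Linear(HIDDEN, out_dim),
    )


# Model as built in train.py (ndim=2048) and as built in test.py (n_class=1).
train_net = make_net(2048)
test_net = make_net(1)

# Loading the 2048-d checkpoint into the 1-output model fails: the last layer's
# weight is [2048, 64] in the checkpoint but [1, 64] in the new model.
try:
    test_net.load_state_dict(train_net.state_dict())
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)

# Approach suggested in the question: keep the training architecture unchanged,
# load its weights, freeze them, and train only a new nn.Linear(2048, 1) head.
encoder = make_net(2048)
encoder.load_state_dict(train_net.state_dict())  # shapes match, so this succeeds
for p in encoder.parameters():
    p.requires_grad = False
encoder.eval()

head = nn.Linear(2048, 1)
optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()  # assumption: binary target with a single logit

# One illustrative training step on random data (placeholder for real inputs/labels).
x = torch.randn(8, IN_DIM)
y = torch.randint(0, 2, (8, 1)).float()
with torch.no_grad():
    features = encoder(x)  # frozen 2048-d features, shape [8, 2048]
logits = head(features)  # shape [8, 1]
loss = loss_fn(logits, y)
optimizer.zero_grad()
loss.backward()  # gradients reach only the head's parameters
optimizer.step()
```

Because the encoder's parameters are frozen and only `head.parameters()` are passed to the optimizer, the pretrained features stay fixed and only the new 2048-to-1 layer is learned.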
