2019-02-23
Parameter size of different network architectures in torchvision.models

torchvision.models ships several built-in models, such as VGG and ResNet. The same page also has a table of the Top-1 and Top-5 errors of each architecture. However, the parameter count of each model is missing, so I wrote a program to fill in the table.
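A minimal sketch of how such a count can be obtained, assuming it is taken as the sum of numel() over model.parameters(). The names count_parameters, trainable_only and format_millions are hypothetical and only here for illustration.

```python
def count_parameters(model, trainable_only=False):
    """Total number of scalar parameters in a torch.nn.Module-like object.

    Only relies on model.parameters(), p.numel() and p.requires_grad,
    which torch.nn.Module and its parameters provide.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


def format_millions(n):
    """Format a raw parameter count in millions, e.g. for a table column."""
    return "{:.2f}M".format(n / 1e6)
```

With torchvision installed, format_millions(count_parameters(torchvision.models.resnet18())) would give the entry for ResNet-18. Building the model without pretrained weights is enough, since the count depends only on the architecture.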


2019-02-22
Load pytorch model trained with nn.DataParallel

I ran into a problem when loading a PyTorch model trained with nn.DataParallel. Loading fails with errors like the following:

Missing key(s) in state_dict: "embedding_net.conv1.weight" ...
Unexpected key(s) in state_dict: "module.embedding_net.conv1.weight" ...

It seems that nn.DataParallel wraps the model in an attribute named module, so every key in the saved state_dict gets a "module." prefix. After googling around I found this solution. Basically there are two workarounds:

  1. Remove the "module." prefix from each key manually

    from collections import OrderedDict

    import torch

    # original file saved from a DataParallel-wrapped model
    state_dict = torch.load('myfile.pth.tar')
    # create a new OrderedDict whose keys do not contain `module.`
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # strip the leading `module.` only when it is present
        name = k[len('module.'):] if k.startswith('module.') else k
        new_state_dict[name] = v
    # load params (model is the plain, unwrapped model built beforehand)
    model.load_state_dict(new_state_dict)
  2. When saving parameters, don’t save the state_dict of the wrapped model; save the state_dict of model.module instead

    # model is the nn.DataParallel wrapper here
    torch.save(model.module.state_dict(), path_to_file)  # save this
    # torch.save(model.state_dict(), path_to_file)       # instead of this
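The first workaround can also be wrapped in a small helper, so that the same loading code works whether or not the checkpoint was saved from an nn.DataParallel wrapper. This is only a sketch: strip_prefix is a hypothetical helper name, and the dictionary below stands in for a real state_dict, whose values would be tensors rather than strings.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Stand-in for a checkpoint saved from a DataParallel-wrapped model.
saved = OrderedDict([
    ("module.embedding_net.conv1.weight", "w"),
    ("module.embedding_net.conv1.bias", "b"),
])
cleaned = strip_prefix(saved)
print(list(cleaned.keys()))
```

With a real checkpoint this would be used as model.load_state_dict(strip_prefix(torch.load('myfile.pth.tar'))). Keys without the prefix pass through unchanged, so calling the helper on an already clean state_dict is harmless.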