PyTorch: save & load model with optimizer and epoch

Topic starter

How do I save a checkpoint that contains the model, the optimizer and the current epoch?

This topic was modified 2 months ago by mrmucha
1 Answer
Topic starter

Save:

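    import torch

    # store everything needed to resume training in a single dictionary
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,  # if loss is a tensor, loss.item() stores a plain float instead
    }, MODEL_PATH)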
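For context, here is a minimal sketch of where that call usually sits: at the end of every epoch of a training loop, overwriting the same file. The tiny nn.Linear model, the random data and the temporary-file MODEL_PATH are hypothetical stand-ins for your own setup. Storing loss.item() keeps a plain float in the checkpoint instead of a tensor that is still attached to the autograd graph.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# hypothetical stand-ins for your own model, data and MODEL_PATH
MODEL_PATH = os.path.join(tempfile.gettempdir(), "checkpoint.pt")
model = nn.Linear(4, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
x, y = torch.randn(32, 4), torch.randn(32, 1)

for epoch in range(3):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

    # overwrite the checkpoint at the end of every epoch
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),  # plain Python float, not a graph-attached tensor
    }, MODEL_PATH)
```

Keeping only the latest checkpoint is one choice; putting the epoch number into the file name instead lets you roll back further, at the cost of disk space.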

Load:

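    import torch
    import torch.optim as optim

    # map_location puts every tensor from the checkpoint directly on device
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # re-create the optimizer for the restored parameters, then restore its state
    optimizer = optim.Adam(model.parameters(), lr=LR)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # optimizer GPU tensor fix: move the state tensors to the model's device
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

    epoch = checkpoint['epoch']
    loss = checkpoint['loss']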
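To check that the load step really restores everything, the sketch below simulates an interrupted run inside one script: it trains a tiny model for two epochs, saves a checkpoint, then builds fresh model and optimizer objects, loads the checkpoint and continues from checkpoint['epoch'] + 1. The model, the data, the train_one_epoch helper, the file path and the four-epoch budget are all hypothetical. The point is that Adam's internal step counter carries on from where it stopped instead of restarting at zero.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = os.path.join(tempfile.gettempdir(), "resume_demo.pt")  # hypothetical path
x, y = torch.randn(32, 4, device=device), torch.randn(32, 1, device=device)

def train_one_epoch(model, optimizer):
    """Hypothetical one-batch 'epoch' on the random data above."""
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()

# first run: train for 2 epochs, then save a checkpoint
model = nn.Linear(4, 1).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
for epoch in range(2):
    loss = train_one_epoch(model, optimizer)
torch.save({'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss}, MODEL_PATH)

# second run: fresh objects, restore everything, continue from the next epoch
resumed = nn.Linear(4, 1).to(device)
checkpoint = torch.load(MODEL_PATH, map_location=device)
resumed.load_state_dict(checkpoint['model_state_dict'])
resumed_opt = optim.Adam(resumed.parameters(), lr=1e-2)
resumed_opt.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1

resumed.train()  # make sure the model is in training mode before continuing
for epoch in range(start_epoch, 4):
    loss = train_one_epoch(resumed, resumed_opt)
```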

 

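  • Optimizer GPU tensor fix: as you can see, there is a simple hack that moves the optimizer's state tensors (e.g. Adam's running averages) onto the same device as the model. If those tensors end up on the CPU (for example when the checkpoint is loaded with map_location='cpu') while the model is on the GPU, they have to be moved. PyTorch's optimizer.load_state_dict() already casts the state to the device of the corresponding parameters, so the loop is mostly a safety net, but it is harmless to keep.
  • Important: before loading anything, rebuild your model exactly as it was when the checkpoint was saved (same backbone, same replaced layers, and so on), and only then load the checkpoint. Below is how I load my model: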
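        import torch
        import torch.nn as nn
        import torch.optim as optim
        from torchvision import models

        # MODEL_PATH, LR and NUM_LABELS are defined elsewhere (not shown)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # 1. rebuild the same architecture that was used when the checkpoint was saved
        model = models.resnet50(weights=None)
        model.fc = nn.Sequential(
            nn.Linear(model.fc.in_features, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, NUM_LABELS))
        # move the whole model, including the new head, only after replacing fc;
        # calling .cuda() before the replacement would leave the new layers on the CPU
        model.to(device)

        # fine-tune all layers
        for param in model.parameters():
            param.requires_grad = True

        # 2. load the weights, then re-create the optimizer and restore its state
        checkpoint = torch.load(MODEL_PATH, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer = optim.Adam(model.parameters(), lr=LR)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # optimizer GPU tensor fix: move the state tensors to the model's device
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

        epoch = checkpoint['epoch']
        loss = checkpoint['loss']

        print("----- FINETUNED MODEL LOADED -----")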
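The "rebuild first" rule exists because load_state_dict() only copies values into parameters that already exist; with the default strict=True it raises a RuntimeError if the keys or tensor shapes do not match the model. One way to keep both sides in sync, sketched here with a small hypothetical build_model() factory standing in for the resnet50 setup above, is to define the architecture in a single function and call it both before training and before loading.

```python
import torch
import torch.nn as nn

def build_model(num_labels):
    """Hypothetical factory: the one place where the architecture is defined,
    called by the training script and again before loading a checkpoint."""
    return nn.Sequential(
        nn.Linear(16, 8),
        nn.ReLU(inplace=True),
        nn.Linear(8, num_labels),
    )

saved = build_model(num_labels=3)
state = saved.state_dict()

# same factory, same arguments: keys and shapes match, so loading succeeds
restored = build_model(num_labels=3)
restored.load_state_dict(state)

# a different head: shapes differ, so load_state_dict raises RuntimeError
try:
    build_model(num_labels=5).load_state_dict(state)
    mismatch_detected = False
except RuntimeError as err:
    mismatch_detected = True
    print("architecture mismatch:", str(err).splitlines()[0])
```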

 
