ModelTrain
Train the neural networks.
Created on Wed Jul 27 22:30:20 2022 @author: Yuxiao Yi
- class ModelTrain.ModelTrain
Bases: object
- buildModel(args)
Create a neural network.
- loadModel(modelname, epoch)
Load a checkpoint onto the CPU.
- trainMethod(args)
Initialize the optimizer and loss function.
- getLoss(input, label, device, size=5000, count_eachloss=False)
Calculate the full loss on a given dataset, that is, \(\sum_{i=1}^{n}Loss(f_{\theta}(x_i),y_i)\), where \(n\) is the size of the dataset, \(f_{\theta}\) is the DNN, and \((x_i,y_i)\) is a data pair.
- Parameters
input (torch.Tensor) – The input dataset.
label (torch.Tensor) – The label dataset.
device (str) – CPU or GPU device, e.g. ‘cuda:0’, ‘cuda:1’ or ‘cpu’.
size (int) – Batch size used to calculate the loss; it is independent of the training batch size. The loss is computed in mini-batches to avoid running out of GPU memory. Default 5000.
count_eachloss (bool) – Whether to calculate the loss of each dimension separately. Default False.
- Returns
loss_value – The full loss on the given dataset.
- Return type
torch.Tensor
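The chunked evaluation described for getLoss can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper name `full_loss`, the explicit `model` and `loss_fn` arguments, and the use of a sum-reduced MSE loss are all assumptions made for the example, and the count_eachloss option is not covered.

```python
import torch
import torch.nn as nn


def full_loss(model, loss_fn, inputs, labels, device="cpu", size=5000):
    """Sum a loss over the whole dataset in chunks of `size` samples.

    Hypothetical helper: `loss_fn` is assumed to use reduction="sum",
    so the chunk losses add up to the full-dataset loss.
    """
    model = model.to(device)
    model.eval()
    total = torch.zeros((), device=device)
    with torch.no_grad():  # evaluation only, no gradients are needed
        for start in range(0, inputs.shape[0], size):
            x = inputs[start:start + size].to(device)
            y = labels[start:start + size].to(device)
            total += loss_fn(model(x), y)
    return total


torch.manual_seed(0)
model = nn.Linear(3, 2)
loss_fn = nn.MSELoss(reduction="sum")
inputs, labels = torch.randn(12, 3), torch.randn(12, 2)

chunked = full_loss(model, loss_fn, inputs, labels, size=5)   # 3 chunks
whole = full_loss(model, loss_fn, inputs, labels, size=12)    # 1 chunk
```

With a sum-reduced loss the result does not depend on `size` (up to floating-point rounding), so `size` can be chosen purely to fit the available GPU memory.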
- trainingEntrance(input_train, label_train, input_valid, label_valid, args, norm, init_epoch=1)
Entry point that selects the training method (single GPU, DataParallel, or DistributedDataParallel). If args.use_DDP=True, call trainDDP(); otherwise call train().
- Parameters
input_train (torch.Tensor) – Input dataset for training.
label_train (torch.Tensor) – Label dataset for training.
input_valid (torch.Tensor) – Input dataset for validation.
label_valid (torch.Tensor) – Label dataset for validation.
args (argparse.Parser or easydict.EasyDict) – Hyper-parameters of the model.
norm (argparse.Parser or easydict.EasyDict) – Mean and standard deviation for normalization.
init_epoch (int) – Start epoch for training process. Default 1.
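The dispatch rule of trainingEntrance can be sketched as below. Only the branching on args.use_DDP comes from the description above; the class `TrainerSketch`, the stub method bodies, and their return values are hypothetical placeholders standing in for the real training routines.

```python
from types import SimpleNamespace


class TrainerSketch:
    """Hypothetical stand-in; only the dispatch logic is illustrated."""

    def train(self, *data, args, norm, init_epoch):
        return "train"  # placeholder for single-GPU / DataParallel training

    def trainDDP(self, *data, args, norm, init_epoch):
        return "trainDDP"  # placeholder for DistributedDataParallel training

    def trainingEntrance(self, input_train, label_train, input_valid,
                         label_valid, args, norm, init_epoch=1):
        data = (input_train, label_train, input_valid, label_valid)
        # Take the DDP path only when the flag is present and true.
        use_ddp = getattr(args, "use_DDP", False)
        runner = self.trainDDP if use_ddp else self.train
        return runner(*data, args=args, norm=norm, init_epoch=init_epoch)


trainer = TrainerSketch()
single = trainer.trainingEntrance(
    None, None, None, None, SimpleNamespace(use_DDP=False), norm=None)
ddp = trainer.trainingEntrance(
    None, None, None, None, SimpleNamespace(use_DDP=True), norm=None)
```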
- train(input_train, label_train, input_valid, label_valid, args, norm, init_epoch)
Train the model when using a single GPU or torch.nn.DataParallel.
- trainDDP(input_train, label_train, input_valid, label_valid, args, norm, init_epoch)
Train the model when using torch.nn.parallel.DistributedDataParallel.
- _runnerForDDP(rank, input_train, label_train, input_valid, label_valid, args, norm, init_epoch)
The runner that performs training with torch.nn.parallel.DistributedDataParallel.
- saveModel(args, norm, epoch)
Save the model checkpoint, hyperparameters, and normalization. Checkpoints are saved in .pt format and the others are saved in .json format.
- Parameters
args (argparse.Parser or easydict.EasyDict) – Hyper-parameters of the model.
norm (argparse.Parser or easydict.EasyDict) – Mean and standard deviation for normalization.
epoch (int) – Current epoch, counted from the beginning of training.
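The .json half of saveModel might look like the sketch below, which uses only the standard library. The helper names `to_plain_dict` and `save_settings` and the file names `args.json` and `norm.json` are assumptions for illustration; the actual file layout used by saveModel is not stated here. The .pt checkpoint is not shown; in PyTorch it is typically written with `torch.save`.

```python
import argparse
import json
import os
import tempfile


def to_plain_dict(obj):
    """Convert an argparse.Namespace or a dict-like object to a plain dict."""
    if isinstance(obj, argparse.Namespace):
        return dict(vars(obj))
    return dict(obj)  # covers dict subclasses such as easydict.EasyDict


def save_settings(args, norm, folder):
    """Write hyperparameters and normalization statistics as JSON files."""
    paths = {}
    for name, obj in (("args", args), ("norm", norm)):
        paths[name] = os.path.join(folder, name + ".json")  # hypothetical names
        with open(paths[name], "w") as f:
            json.dump(to_plain_dict(obj), f, indent=2)
    return paths


args = argparse.Namespace(lr=1e-3, batch_size=1024, use_DDP=False)
norm = {"mean": [0.0, 1.5], "std": [1.0, 2.0]}

with tempfile.TemporaryDirectory() as folder:
    paths = save_settings(args, norm, folder)
    # Read the files back to confirm the round trip.
    with open(paths["args"]) as f:
        args_back = json.load(f)
    with open(paths["norm"]) as f:
        norm_back = json.load(f)
```

Values must be JSON-serializable, so tensors or NumPy arrays holding the mean and standard deviation would need converting to lists first.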
- saveLoss(train_loss, valid_loss, args)
Save training loss and validation loss and draw the loss curves.
- Parameters
train_loss (numpy.array) – The training loss.
valid_loss (numpy.array) – The validation loss.
args (argparse.Parser or easydict.EasyDict) – Hyper-parameters.
- static plotLoss(args, axis='semilogy', dpi=200)
Draw and save training loss and validation loss curves.
- static splitIndex(sample_num, rank, world_size)
The sample indices are split into world_size parts, and the indices of the rank-th part are returned.
- Parameters
- Returns
split_array_index – The indices of the rank-th part.
- Return type
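One way to realize the partitioning described for splitIndex is sketched below. The function name `split_index`, the contiguous near-equal split, and the list return value are assumptions for illustration; the library may partition differently or return another container type.

```python
def split_index(sample_num, rank, world_size):
    """Return the indices of the rank-th of world_size contiguous parts."""
    base, extra = divmod(sample_num, world_size)
    # The first `extra` parts receive one additional sample each.
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return list(range(start, stop))


# Ten samples shared among three processes.
parts = [split_index(10, rank, 3) for rank in range(3)]
# parts == [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```

Every sample index appears in exactly one part, so each process in a distributed run works on a disjoint slice of the dataset.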