ModelTrain

Train the neural networks.

Created on Wed Jul 27 22:30:20 2022 @author: Yuxiao Yi

class ModelTrain.ModelTrain

Bases: object

buildModel(args)

Create a neural network.

loadModel(modelname, epoch)

Load a checkpoint onto the CPU.

trainMethod(args)

Initialize the optimizer and loss function.

getLoss(input, label, device, size=5000, count_eachloss=False)

Calculate the full loss on a given dataset, that is, \(\sum_{i=1}^{n}Loss(f_{\theta}(x_i),y_i)\), where \(n\) is the size of the dataset, \(f_{\theta}\) is the DNN, and \((x_i,y_i)\) is a data pair.

Parameters
  • input (torch.Tensor) – The input dataset.

  • label (torch.Tensor) – The label dataset.

  • device (str) – CPU or GPU device, e.g. ‘cuda:0’, ‘cuda:1’, or ‘cpu’.

  • size (int) – Batch size used when calculating the loss; it is independent of the training batch size. The dataset is processed in mini-batches to avoid running out of GPU memory. Default 5000.

  • count_eachloss (bool) – Whether to calculate the loss of each dimension separately. Default False.

Returns

loss_value – The full loss on the given dataset.

Return type

torch.Tensor
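
The role of size in getLoss can be made explicit with a short derivation. This is a sketch under two assumptions that the description above does not confirm: the batches partition the dataset without overlap, and the per-batch losses are summed rather than averaged. The index sets \(I_b\) and the batch count \(B\) are notation introduced here for illustration.

\[
\sum_{i=1}^{n} Loss(f_{\theta}(x_i), y_i)
  = \sum_{b=1}^{B} \sum_{i \in I_b} Loss(f_{\theta}(x_i), y_i),
\qquad
B = \left\lceil \frac{n}{\text{size}} \right\rceil ,
\]

where \(I_1, \dots, I_B\) are disjoint index sets covering \(\{1, \dots, n\}\), each holding at most size elements. Under these assumptions the result does not depend on size; only the peak memory needed for one forward pass does.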

trainingEntrance(input_train, label_train, input_valid, label_valid, args, norm, init_epoch=1)

Entry point that selects the training method (single GPU, DataParallel, or DistributedDataParallel). If args.use_DDP=True, trainDDP() is called; otherwise train() is called.

Parameters
  • input_train (torch.Tensor) – Input dataset for training.

  • label_train (torch.Tensor) – Label dataset for training.

  • input_valid (torch.Tensor) – Input dataset for validation.

  • label_valid (torch.Tensor) – Label dataset for validation.

  • args (argparse.Parser or easydict.EasyDict) – Hyper-parameters of the model.

  • norm (argparse.Parser or easydict.EasyDict) – Mean and standard deviation for normalization.

  • init_epoch (int) – Start epoch for training process. Default 1.
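
The dispatch described for trainingEntrance can be sketched as follows. This is only an illustration of the branching on args.use_DDP: TrainerSketch is a hypothetical stand-in for ModelTrain, and its train() and trainDDP() bodies are placeholders that return a label instead of training anything.

```python
from types import SimpleNamespace


class TrainerSketch:
    """Hypothetical stand-in for ModelTrain; only the dispatch is shown."""

    def train(self, *data, args, norm, init_epoch):
        # Placeholder for the single-GPU / DataParallel path.
        return "train"

    def trainDDP(self, *data, args, norm, init_epoch):
        # Placeholder for the DistributedDataParallel path.
        return "trainDDP"

    def trainingEntrance(self, input_train, label_train, input_valid,
                         label_valid, args, norm, init_epoch=1):
        data = (input_train, label_train, input_valid, label_valid)
        # Take the DDP path only when the flag is set; otherwise fall back
        # to the single-GPU / DataParallel path.
        method = self.trainDDP if args.use_DDP else self.train
        return method(*data, args=args, norm=norm, init_epoch=init_epoch)


trainer = TrainerSketch()
chosen = trainer.trainingEntrance(
    None, None, None, None, args=SimpleNamespace(use_DDP=True), norm=None
)
```

Here SimpleNamespace stands in for the argparse or EasyDict object, since only attribute access to use_DDP is needed.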

train(input_train, label_train, input_valid, label_valid, args, norm, init_epoch)

Train the model on a single GPU or with torch.nn.DataParallel.

trainDDP(input_train, label_train, input_valid, label_valid, args, norm, init_epoch)

Train the model with torch.nn.parallel.DistributedDataParallel.

_runnerForDDP(rank, input_train, label_train, input_valid, label_valid, args, norm, init_epoch)

The per-process runner used for torch.nn.parallel.DistributedDataParallel training.

saveModel(args, norm, epoch)

Save the model checkpoint, hyperparameters, and normalization. Checkpoints are saved in .pt format; the others are saved in .json format.

Parameters
  • args (argparse.Parser or easydict.EasyDict) – Hyper-parameters of the model.

  • norm (argparse.Parser or easydict.EasyDict) – Mean and standard deviation for normalization.

  • epoch (int) – Current epoch, counted from the beginning of training.

static saveJson(json_path, args: dict)

Save a dict to a .json file. Used by saveModel().

Parameters
  • json_path (str) – The path to save json file.

  • args (dict) – Hyper-parameters dict.
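
A minimal sketch of what such a helper might look like, using only the standard json module. The function name save_json, the indentation setting, and the example file name are assumptions made for illustration; the actual saveJson implementation may differ.

```python
import json
import os
import tempfile


def save_json(json_path, args):
    """Write a plain dict to json_path as UTF-8 JSON (illustrative helper)."""
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(args, f, indent=4)


# Round-trip a small, made-up hyper-parameter dict through a temporary file.
path = os.path.join(tempfile.mkdtemp(), "settings.json")
save_json(path, {"lr": 0.001, "layers": [64, 64]})
with open(path, encoding="utf-8") as f:
    loaded = json.load(f)
```

Every value in the dict has to be JSON-serializable, so tensors or arrays would need to be converted (for example to lists) before calling such a helper.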

saveLoss(train_loss, valid_loss, args)

Save training loss and validation loss and draw the loss curves.

Parameters
  • train_loss (numpy.array) – The training loss.

  • valid_loss (numpy.array) – The validation loss.

  • args (argparse.Parser or easydict.EasyDict) – Hyper-parameters.

static plotLoss(args, axis='semilogy', dpi=200)

Draw and save training loss and validation loss curves.

Parameters
  • args (argparse.Parser or easydict.EasyDict) – Hyper-parameters.

  • axis (str, optional) – Plot function; either ‘semilogy’ or ‘loglog’. Default ‘semilogy’.

  • dpi (int, optional) – The dpi used to save figure. Default 200.

static splitIndex(sample_num, rank, world_size)

The sample indices are split into world_size parts, and the indices of the part belonging to rank are returned.

Parameters
  • sample_num (int) – The size of given dataset.

  • rank (int) – The id of a process.

  • world_size (int) – Total process size.

Returns

split_array_index – The indices of the rank-th part.

Return type

numpy.ndarray
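
One way to picture the splitting performed by splitIndex is the sketch below. It is written in plain Python and returns a list, whereas the documented method returns a numpy.ndarray. The partition rule (contiguous blocks, with any remainder going to the lowest ranks) is an assumption, not something the description above confirms, and the name split_index is hypothetical.

```python
def split_index(sample_num, rank, world_size):
    """Return the indices of the rank-th of world_size contiguous parts."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    base, extra = divmod(sample_num, world_size)
    # The first `extra` ranks each take one additional sample.
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return list(range(start, stop))


# Ten samples shared among three processes.
parts = [split_index(10, r, 3) for r in range(3)]
```

With ten samples and three processes, rank 0 receives four indices and ranks 1 and 2 receive three each, so every sample is assigned to exactly one process.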