Trainer

Trainer for training tasks.

class mindnlp.engine.trainer.Trainer(network, args=None, loss_fn: Optional[Cell] = None, optimizer: Optional[Cell] = None, train_dataset: Optional[Dataset] = None, eval_dataset: Optional[Dataset] = None, metrics: Optional[Metric] = None, callbacks: Optional[Union[Callback, List]] = None, **kwargs)

Bases: object

Trainer to train the model.

Parameters:
  • network (Cell) – A training network.

  • train_dataset (Dataset) – A training dataset iterator. If loss_fn is defined, the data and label are passed to the network and to loss_fn respectively, so the dataset should return a (data, label) tuple. If there are multiple data or label columns, set loss_fn to None and implement the loss calculation in the network; the dataset then returns a tuple (data1, data2, data3, …) whose elements are all passed to the network.

  • eval_dataset (Dataset) – An evaluation dataset iterator. If loss_fn is defined, the data and label are passed to the network and to loss_fn respectively, so the dataset should return a (data, label) tuple. If there are multiple data or label columns, set loss_fn to None and implement the loss calculation in the network; the dataset then returns a tuple (data1, data2, data3, …) whose elements are all passed to the network.

  • metrics (Optional[list[Metric], Metric]) – A metric object, or a list of metric objects, used during evaluation. Default: None.

  • epochs (int) – Total number of passes (epochs) over the training data. Default: 10.

  • optimizer (Cell) – Optimizer for updating the weights. If optimizer is None, the network must perform backpropagation and update the weights itself. Default: None.

  • loss_fn (Cell) – Objective function. If loss_fn is None, the network must compute the loss itself (and handle parallelism, if needed). Default: None.

  • callbacks (Optional[list[Callback], Callback]) – A callback object, or a list of callback objects, executed during training. Default: None.

  • jit (bool) – Whether to use Just-In-Time (JIT) compilation.
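The data-routing rule in the train_dataset and eval_dataset descriptions can be sketched in plain Python. The helper compute_loss and the toy functions below are hypothetical; plain floats stand in for tensors, and this is not how Trainer is implemented internally. It only illustrates which batch elements reach the network and which reach loss_fn in each of the two modes.

```python
from typing import Any, Callable, Optional, Sequence


def compute_loss(network: Callable[..., Any],
                 loss_fn: Optional[Callable[[Any, Any], Any]],
                 batch: Sequence[Any]) -> Any:
    """Route one dataset batch as described for train_dataset/eval_dataset.

    Hypothetical helper for illustration only; it is not part of MindNLP.
    """
    if loss_fn is not None:
        # With a loss_fn, the dataset must yield (data, label):
        # data goes to the network, label goes to loss_fn.
        data, label = batch
        return loss_fn(network(data), label)
    # Without a loss_fn, every element goes to the network,
    # which must compute and return the loss itself.
    return network(*batch)


def double(x: float) -> float:
    return 2 * x


def squared_error(pred: float, target: float) -> float:
    return (pred - target) ** 2


def net_with_loss(x1: float, x2: float, label: float) -> float:
    # A "network" that takes two inputs plus the label and returns its own loss.
    return (x1 + x2 - label) ** 2


print(compute_loss(double, squared_error, (3.0, 5.0)))     # 1.0
print(compute_loss(net_with_loss, None, (1.0, 2.0, 4.0)))  # 1.0
```

In the second call the dataset tuple has three elements, so loss_fn is None and the network receives everything.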

add_callback()

Add a callback.

evaluate()

Evaluate the model.

evaluate_loop()

Run the evaluation loop.
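An evaluation loop typically runs predictions batch by batch and accumulates each object passed as metrics. The sketch below shows that pattern in plain Python. The Accuracy class, its clear/update/eval methods, and the evaluate_loop function signature are hypothetical stand-ins chosen for illustration; they are not MindNLP's metric classes or the real evaluate_loop().

```python
from typing import Callable, Iterable, List, Sequence, Tuple


class Accuracy:
    """Hypothetical streaming accuracy metric (not MindNLP's class)."""

    def __init__(self) -> None:
        self.clear()

    def clear(self) -> None:
        # Reset the running counts before a new evaluation pass.
        self.correct = 0
        self.total = 0

    def update(self, preds: Sequence[int], labels: Sequence[int]) -> None:
        self.correct += sum(int(p == l) for p, l in zip(preds, labels))
        self.total += len(labels)

    def eval(self) -> float:
        return self.correct / self.total if self.total else 0.0


def evaluate_loop(predict: Callable[[Sequence[int]], List[int]],
                  eval_batches: Iterable[Tuple[Sequence[int], Sequence[int]]],
                  metrics: List[Accuracy]) -> List[float]:
    """Run predictions batch by batch and accumulate every metric."""
    for metric in metrics:
        metric.clear()
    for data, labels in eval_batches:
        preds = predict(data)
        for metric in metrics:
            metric.update(preds, labels)
    return [metric.eval() for metric in metrics]


def predict(xs: Sequence[int]) -> List[int]:
    # Toy "model": predict 1 for positive inputs, else 0.
    return [int(x > 0) for x in xs]


batches = [([3, -1], [1, 0]), ([2, -5], [0, 0])]
print(evaluate_loop(predict, batches, [Accuracy()]))  # [0.75]
```

Accumulating counts across batches, rather than averaging per-batch scores, keeps the result correct when the last batch is smaller than the others.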

predict(test_dataset)

Run prediction on test_dataset.

predict_loop()

Run the prediction loop.

predict_step(inputs, return_loss_only=False)

Run a single prediction step on a batch of inputs.

remove_callback(name_or_type)

Remove a callback, identified by its name or its type.
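The name_or_type argument suggests that a callback can be removed either by name or by class. The registry below is a hypothetical sketch of that behavior, assuming "name" means the callback's class name. Note that the add_callback() signature documented above lists no parameters; the callback argument used here is an assumption made for illustration, and none of these classes are MindNLP's.

```python
from typing import List, Union


class Callback:
    """Minimal stand-in for a callback base class (hypothetical)."""


class EarlyStop(Callback):
    pass


class Logger(Callback):
    pass


class CallbackRegistry:
    """Hypothetical sketch of add/remove semantics; not MindNLP's implementation."""

    def __init__(self) -> None:
        self.callbacks: List[Callback] = []

    def add_callback(self, callback: Callback) -> None:
        # Assumed argument: the documented signature shows none.
        self.callbacks.append(callback)

    def remove_callback(self, name_or_type: Union[str, type]) -> None:
        if isinstance(name_or_type, str):
            # A string is matched against the callback's class name.
            self.callbacks = [cb for cb in self.callbacks
                              if type(cb).__name__ != name_or_type]
        else:
            # A class removes every callback that is an instance of it.
            self.callbacks = [cb for cb in self.callbacks
                              if not isinstance(cb, name_or_type)]


registry = CallbackRegistry()
registry.add_callback(EarlyStop())
registry.add_callback(Logger())
registry.remove_callback("EarlyStop")
print([type(cb).__name__ for cb in registry.callbacks])  # ['Logger']
registry.remove_callback(Logger)
print(registry.callbacks)  # []
```

Matching by type with isinstance also removes subclasses of the given class, which may or may not match the real method's behavior.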

run(tgt_columns=None)

Entry point for the training process.

Parameters:

tgt_columns (Optional[list[str], str]) – Name, or list of names, of the target label columns passed to the loss function.
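Because tgt_columns accepts either a single column name or a list of names, one way to picture its role is as a split of each dataset row into network inputs and loss-function targets. The split_columns helper below is hypothetical, and the dict-based row and column names are made-up illustrations; it does not describe how run() processes datasets internally.

```python
from typing import Any, Dict, List, Optional, Tuple, Union


def split_columns(row: Dict[str, Any],
                  tgt_columns: Optional[Union[List[str], str]] = None
                  ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split one dataset row into network inputs and loss-function targets.

    Hypothetical helper; it only mirrors the idea that tgt_columns names
    the label columns consumed by the loss function.
    """
    if tgt_columns is None:
        targets: List[str] = []
    elif isinstance(tgt_columns, str):
        # A single column name is accepted as well as a list.
        targets = [tgt_columns]
    else:
        targets = list(tgt_columns)
    inputs = {k: v for k, v in row.items() if k not in targets}
    labels = {k: row[k] for k in targets}
    return inputs, labels


row = {"input_ids": [101, 2023, 102], "attention_mask": [1, 1, 1], "label": 1}
inputs, labels = split_columns(row, "label")
print(sorted(inputs))  # ['attention_mask', 'input_ids']
print(labels)          # {'label': 1}
```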

save_model(output_dir, model_name=None)

Save the model to the specified directory (output_dir).

set_amp(level='O1', loss_scaler=None)

Configure automatic mixed precision (AMP).
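Assuming loss_scaler refers to the standard loss-scaling technique used with mixed precision, the idea is: multiply the loss by a constant before backpropagation so that small float16 gradients do not underflow, then divide the gradients by the same constant before the optimizer update. The StaticLossScaler class below is a pure-Python illustration of that idea, not MindSpore's or MindNLP's loss-scaler class, and the gradient values are made up rather than computed.

```python
from typing import List


class StaticLossScaler:
    """Hypothetical static loss scaler, shown only to illustrate the idea."""

    def __init__(self, scale_value: float = 1024.0) -> None:
        self.scale_value = scale_value

    def scale(self, loss: float) -> float:
        # Enlarge the loss so its gradients stay representable in float16.
        return loss * self.scale_value

    def unscale(self, grads: List[float]) -> List[float]:
        # Undo the scaling before the optimizer sees the gradients.
        return [g / self.scale_value for g in grads]


scaler = StaticLossScaler(scale_value=1024.0)
scaled_loss = scaler.scale(0.5)      # 512.0
# Pretend backprop on the scaled loss produced these scaled gradients.
scaled_grads = [2048.0, -512.0]
print(scaler.unscale(scaled_grads))  # [2.0, -0.5]
```

A dynamic scaler would additionally shrink the scale when gradients overflow and grow it after a run of stable steps; the static version keeps it fixed.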

set_forward_fn(forward_fn)

Set the forward function.

set_optimizer(optimizer)

Set the optimizer.

set_step_fn(step_fn)

Set the step function.
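A common way to split training logic is a forward function that returns the loss, and a step function that computes gradients and applies the optimizer update. The sketch below illustrates that split on a one-parameter linear model with an analytic gradient and plain SGD. Whether set_forward_fn and set_step_fn expect functions of this shape is an assumption; forward_fn, grad_fn, and make_step_fn here are hypothetical.

```python
from typing import Callable, Dict, Tuple

Params = Dict[str, float]


def forward_fn(params: Params, x: float, y: float) -> float:
    """Forward pass plus loss: squared error of a 1-D linear model w * x."""
    return (params["w"] * x - y) ** 2


def grad_fn(params: Params, x: float, y: float) -> Tuple[float, Params]:
    # Analytic gradient of (w*x - y)^2 with respect to w: 2 * (w*x - y) * x.
    loss = forward_fn(params, x, y)
    return loss, {"w": 2 * (params["w"] * x - y) * x}


def make_step_fn(lr: float) -> Callable[[Params, float, float], float]:
    """Build a step function: compute loss and gradients, then apply plain SGD."""
    def step_fn(params: Params, x: float, y: float) -> float:
        loss, grads = grad_fn(params, x, y)
        for name, g in grads.items():
            params[name] -= lr * g  # the "optimizer" update, done in place
        return loss
    return step_fn


params = {"w": 0.0}
step_fn = make_step_fn(lr=0.1)
print(step_fn(params, 1.0, 2.0))  # 4.0 (loss before the update)
print(params)                     # {'w': 0.4}
```

Keeping the forward function separate lets the same loss definition be reused for evaluation, where no gradient or update is needed.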

train(target_columns)

train_loop(train_dataset)

Run the training loop.

train_step(inputs)

Run a single training step.
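The relationship between a training loop and a training step can be sketched as an outer loop over epochs that calls a per-batch step and records the mean loss. The train_loop function, the toy train_step, and the module-level state dict below are hypothetical illustrations of that structure, not the signatures or internals of Trainer.train_loop() and Trainer.train_step().

```python
from typing import Callable, Dict, Iterable, List, Sequence


def train_loop(train_step: Callable[[Sequence[float]], float],
               batches: Iterable[Sequence[float]],
               epochs: int) -> List[float]:
    """Hypothetical outer loop: run train_step on every batch, once per epoch.

    Returns the mean loss of each epoch.
    """
    batches = list(batches)  # allow re-iteration across epochs
    epoch_losses: List[float] = []
    for _ in range(epochs):
        losses = [train_step(batch) for batch in batches]
        epoch_losses.append(sum(losses) / len(losses))
    return epoch_losses


state: Dict[str, float] = {"w": 0.0}


def train_step(batch: Sequence[float]) -> float:
    # One SGD update on the squared error of the model w * x.
    x, y = batch
    err = state["w"] * x - y
    state["w"] -= 0.1 * 2 * err * x
    return err ** 2


history = train_loop(train_step, [(1.0, 2.0), (2.0, 4.0)], epochs=3)
print(history[0] > history[-1])  # True: the mean loss decreases
```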