# Source code for mindnlp.engine.callbacks.earlystop_callback

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Callback for Early Stop.
"""
from mindspore import log
from mindnlp.abc import Callback

class EarlyStopCallback(Callback):
    """
    Stop training when the evaluation metrics have not improved for `patience` evaluations.

    Args:
        patience (int): Number of consecutive evaluations without improvement after
            which training is stopped. Default: 10.
        larger_better (bool): Whether a larger value of the metric is better. When
            False, a smaller value counts as an improvement. Default: True.
    """
    def __init__(self, patience=10, larger_better=True):
        # Number of consecutive evaluations without improvement so far.
        self.wait = 0
        self.patience = patience
        # Best metrics values seen so far; an empty list means "nothing recorded yet".
        self.best_metrics_values = []
        self.larger_better = larger_better

    def evaluate_end(self, run_context):
        """
        Called after evaluating.

        Records the metrics as the new best when they improve and resets the
        wait counter; otherwise increments the counter. Once the counter reaches
        `patience`, sets `run_context.earlystop` to True and logs a warning.

        Args:
            run_context (RunContext): Information about the model. This method reads
                `metrics_values` and `metrics_names`, and may set `earlystop`.
        """
        metrics_values = run_context.metrics_values
        # No metrics were produced by this evaluation; nothing to compare.
        if metrics_values is None:
            return
        if self.is_better_metric_value(metrics_values):
            self.wait = 0
            self.best_metrics_values = metrics_values
        else:
            self.wait += 1
        if self.wait >= self.patience:
            run_context.earlystop = True
            log.warning(f"After {self.wait} Evaluations, no improvement for "
                        f"metric `{run_context.metrics_names}`(best value: {self.best_metrics_values})")

    def is_better_metric_value(self, metrics_values):
        """
        Compare the given metrics values with the best metrics values so far.

        Args:
            metrics_values (float): metrics values used to compare with the best
                metrics values so far. NOTE(review): if the evaluator supplies a
                list of values, Python compares lists lexicographically, so the
                first metric dominates — confirm this is intended.

        Returns:
            bool: True if no best value has been recorded yet, or if `metrics_values`
            is strictly larger (when `larger_better` is True) or strictly smaller
            (when `larger_better` is False) than the best value so far.
        """
        best = self.best_metrics_values
        # `best_metrics_values` starts as an empty list. The former `== {}` check never
        # matched it, so the first comparison ran against `[]` (a TypeError for floats).
        if best is None or (isinstance(best, (list, tuple, dict)) and len(best) == 0):
            return True
        # Strict comparison in the configured direction; the former `& larger_better`
        # made every comparison False when smaller values were meant to be better.
        if self.larger_better:
            return metrics_values > best
        return metrics_values < best