skorch.callbacks
This module elevates the callbacks defined in submodules to the skorch.callbacks namespace. Remember to define __all__ in each submodule.
class skorch.callbacks.Callback
Base class for callbacks.
All custom callbacks should inherit from this class. The subclass may override any of the on_... methods. It is, however, not necessary to override all of them, since it’s okay if they don’t have any effect.
Classes that inherit from this also gain the get_params and set_params methods.
Methods
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, **kwargs): Called at the end of each batch.
- on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
- on_epoch_end(net, **kwargs): Called at the end of each epoch.
- on_grad_computed(net, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, **kwargs): Called at the beginning of training.
- on_train_end(net, **kwargs): Called at the end of training.
- get_params
- set_params
initialize()
(Re-)Set the initial state of the callback. Use this, e.g., if the callback tracks some state that should be reset when the model is re-initialized.
This method should return self.
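The subclassing pattern described above can be sketched as follows. To keep the sketch runnable without skorch installed, it defines a minimal stand-in `Callback` base class whose hooks are no-ops; in real code you would import `skorch.callbacks.Callback` instead. `BatchCounter` and the placeholder net object are hypothetical names used only for illustration.

```python
class Callback:
    """Stand-in for skorch.callbacks.Callback (assumption): every hook is a no-op."""

    def initialize(self):
        return self

    def on_train_begin(self, net, **kwargs):
        pass

    def on_epoch_begin(self, net, **kwargs):
        pass

    def on_batch_end(self, net, **kwargs):
        pass

    def on_epoch_end(self, net, **kwargs):
        pass


class BatchCounter(Callback):
    """Hypothetical callback that counts the batches seen in each epoch."""

    def initialize(self):
        # Reset the tracked state so a re-initialized net starts from scratch.
        self.counts_ = []
        return self

    def on_epoch_begin(self, net, **kwargs):
        self.counts_.append(0)

    def on_batch_end(self, net, **kwargs):
        self.counts_[-1] += 1


# Simulate a training loop with 2 epochs of 3 batches each.
cb = BatchCounter().initialize()
net = object()  # placeholder for a skorch net
for _ in range(2):
    cb.on_epoch_begin(net)
    for _ in range(3):
        cb.on_batch_end(net)
    cb.on_epoch_end(net)  # inherited no-op
```

Only the hooks the callback needs are overridden; the remaining ones fall back to the inherited no-ops.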
class skorch.callbacks.EpochTimer(**kwargs)
Measures the duration of each epoch and writes it to the history with the name dur.
Methods
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, **kwargs): Called at the end of each batch.
- on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
- on_epoch_end(net, **kwargs): Called at the end of each epoch.
- on_grad_computed(net, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, **kwargs): Called at the beginning of training.
- on_train_end(net, **kwargs): Called at the end of training.
- get_params
- set_params
class skorch.callbacks.PrintLog(keys_ignored=None, sink=<built-in function print>, tablefmt='simple', floatfmt='.4f', stralign='right')
Print useful information from the model’s history as a table.
By default, PrintLog prints everything from the history except for 'batches'.
To determine the best loss, PrintLog looks for keys that end with '_best' and associates them with the corresponding loss. E.g., 'train_loss_best' will be matched with 'train_loss'. The scoring callbacks (EpochScoring and BatchScoring) take care of creating those entries, which is why PrintLog works best in conjunction with them.
PrintLog treats keys with the 'event_' prefix in a special way. They are assumed to contain information about occasionally occurring events. False or None entries (indicating that an event did not occur) are not printed, resulting in empty cells in the table, and True entries are printed with a + symbol. PrintLog groups all event columns together and pushes them to the right, just before the 'dur' column.
Note: PrintLog will not produce good output if the number of columns varies between epochs, e.g. if the valid loss is only present on every other epoch.
Parameters:
- keys_ignored : str or list of str (default=None)
Key or list of keys that should not be part of the printed table. Note that keys ending with ‘_best’ are also ignored.
- sink : callable (default=print)
The target that the output string is sent to. By default, the output is printed to stdout, but the sink could also be a logger, etc.
- tablefmt : str (default=’simple’)
The format of the table. See the documentation of the tabulate package for more details. Can be ‘plain’, ‘grid’, ‘pipe’, ‘html’, ‘latex’, among others.
- floatfmt : str (default=’.4f’)
The number formatting. See the documentation of the tabulate package for more details.
- stralign : str (default=’right’)
The alignment of columns with strings. Can be ‘left’, ‘center’, ‘right’, or None (disable alignment). Default is ‘right’ (to be consistent with numerical columns).
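The column handling described above can be sketched in pure Python for a single history row: 'batches', ignored keys and '_best' keys are dropped, '_best' flags are paired with their losses, event columns are rendered as + or an empty cell, and events are moved just before 'dur'. The function name `render_row` is hypothetical; the real PrintLog additionally highlights best values with color and delegates table layout to tabulate.

```python
def render_row(row, keys_ignored=None, floatfmt=".4f"):
    """Hypothetical sketch of PrintLog-style column handling for one history row."""
    if keys_ignored is None:
        keys_ignored = []
    elif isinstance(keys_ignored, str):
        keys_ignored = [keys_ignored]
    ignored = set(keys_ignored) | {"batches"}
    keys = [k for k in row if k not in ignored and not k.endswith("_best")]
    # Keys whose matching '<key>_best' flag is True in this row.
    best = {k[: -len("_best")] for k in row if k.endswith("_best") and row[k]}
    events = [k for k in keys if k.startswith("event_")]
    regular = [k for k in keys if not k.startswith("event_") and k != "dur"]
    # Event columns are grouped together and placed just before 'dur'.
    ordered = regular + events + (["dur"] if "dur" in keys else [])
    cells = {}
    for k in ordered:
        value = row[k]
        if k.startswith("event_"):
            cells[k] = "+" if value is True else ""  # False/None -> empty cell
        elif isinstance(value, float):
            cells[k] = format(value, floatfmt)
        else:
            cells[k] = str(value)
    return ordered, cells, best


row = {"epoch": 3, "train_loss": 0.5, "train_loss_best": True,
       "valid_loss": 0.61, "valid_loss_best": False, "event_cp": True,
       "event_lr": None, "dur": 0.1234567, "batches": []}
ordered, cells, best = render_row(row)
```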
Methods
- format_row(row, key, color): For a given row from the table, format it (i.e. apply float formatting and color, if applicable).
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, **kwargs): Called at the end of each batch.
- on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
- on_epoch_end(net, **kwargs): Called at the end of each epoch.
- on_grad_computed(net, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, **kwargs): Called at the beginning of training.
- on_train_end(net, **kwargs): Called at the end of training.
- get_params
- set_params
- table

format_row(row, key, color)
For a given row from the table, format it (i.e. apply float formatting and color, if applicable).
class skorch.callbacks.ProgressBar(batches_per_epoch='count', detect_notebook=True, postfix_keys=None)
Display a progress bar for each epoch, including duration, estimated remaining time and user-defined metrics.
In Jupyter notebooks, a non-ASCII progress bar is printed instead. To use this feature, you need to have ipywidgets (https://ipywidgets.readthedocs.io/en/stable/user_install.html) installed.
Parameters:
- batches_per_epoch : int, str (default=’count’)
In 'count' mode, the progress bar determines the number of batches per epoch by itself: the number is only known after one epoch, which leaves you without a progress bar during the first epoch. To fix that, you can provide this number manually or set it to 'auto', in which case the callback attempts to compute the number of batches per epoch beforehand.
- detect_notebook : bool (default=True)
If enabled, the progress bar determines whether its current environment is a Jupyter notebook and, if so, switches to a non-ASCII progress bar.
- postfix_keys : list of str (default=[‘train_loss’, ‘valid_loss’])
You can use this list to specify additional info displayed in the progress bar, such as metrics and losses. A prerequisite is that these values already reside in the history at batch level, i.e. they must be accessible via
>>> net.history[-1, 'batches', -1, key]
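For the 'auto' mode, the number of batches per epoch can be estimated up front from the dataset sizes and the batch size. The sketch below assumes that each epoch iterates over the training set and then the validation set and that the last, partial batch is kept; `batches_per_epoch` and its arguments are hypothetical names, not the callback's internals.

```python
import math


def batches_per_epoch(n_train, n_valid, batch_size, valid_batch_size=None):
    """Estimate batches per epoch (train + validation), keeping partial last batches."""
    if valid_batch_size is None:
        valid_batch_size = batch_size
    train_batches = math.ceil(n_train / batch_size)
    # No validation split means no validation batches.
    valid_batches = math.ceil(n_valid / valid_batch_size) if n_valid else 0
    return train_batches + valid_batches


# 800 training and 200 validation samples with batch_size=128: 7 + 2 batches.
total = batches_per_epoch(800, 200, 128)
```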
Methods
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, **kwargs): Called at the end of each batch.
- on_epoch_begin(net[, X, X_valid]): Called at the beginning of each epoch.
- on_epoch_end(net, **kwargs): Called at the end of each epoch.
- on_grad_computed(net, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, **kwargs): Called at the beginning of training.
- on_train_end(net, **kwargs): Called at the end of training.
- get_params
- in_ipynb
- set_params
class skorch.callbacks.LRScheduler(policy='WarmRestartLR', monitor='train_loss', **kwargs)
Callback that sets the learning rate of each parameter group according to some policy.
Parameters:
- policy : str or _LRScheduler class (default=’WarmRestartLR’)
Learning rate policy name or scheduler to be used.
- monitor : str or callable (default=’train_loss’)
Value of the history to monitor, or a function/callable. In the latter case, the callable receives the net instance as argument and is expected to return the score (float) used to determine the learning rate adjustment.
- kwargs
Additional arguments passed to the lr scheduler.
Attributes:
- kwargs
Methods
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, **kwargs): Called at the end of each batch.
- on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
- on_epoch_end(net, **kwargs): Called at the end of each epoch.
- on_grad_computed(net, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, **kwargs): Called at the beginning of training.
- on_train_end(net, **kwargs): Called at the end of training.
- simulate(steps, initial_lr): Simulates the learning rate scheduler.
- get_params
- set_params
class skorch.callbacks.WarmRestartLR(optimizer, min_lr=1e-06, max_lr=0.05, base_period=10, period_mult=2, last_epoch=-1)
Stochastic Gradient Descent with Warm Restarts (SGDR) scheduler.
This scheduler sets the learning rate of each parameter group according to the stochastic gradient descent with warm restarts (SGDR) policy. This policy simulates periodic warm restarts of SGD: at each restart, the learning rate is re-initialized to some value and is then scheduled to decrease.
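The SGDR schedule from [1] anneals the learning rate along a cosine curve from max_lr towards min_lr within each period and then restarts with a period multiplied by period_mult. A minimal per-epoch sketch of the paper's formula with this class's default values follows; `sgdr_lr` is a hypothetical helper, and the exact step indexing inside WarmRestartLR may differ.

```python
import math


def sgdr_lr(epoch, min_lr=1e-6, max_lr=0.05, base_period=10, period_mult=2):
    """Learning rate at `epoch` under cosine annealing with warm restarts (SGDR)."""
    period = base_period
    t = epoch
    # Skip over completed periods; each restart multiplies the period length.
    while t >= period:
        t -= period
        period *= period_mult
    # Cosine decay from max_lr (t = 0) towards min_lr (t -> period).
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t / period))


schedule = [sgdr_lr(e) for e in range(40)]
```

With the defaults, the periods are 10, 20, 40, ... epochs long, so the learning rate jumps back to max_lr at epochs 10 and 30.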
Parameters:
- optimizer : torch.optim.Optimizer instance
Optimizer algorithm.
- min_lr : float or list of float (default=1e-6)
Minimum allowed learning rate during each period for all param groups (float) or each group (list).
- max_lr : float or list of float (default=0.05)
Maximum allowed learning rate during each period for all param groups (float) or each group (list).
- base_period : int (default=10)
Initial restart period to be multiplied at each restart.
- period_mult : int (default=2)
Multiplicative factor to increase the period between restarts.
- last_epoch : int (default=-1)
The index of the last valid epoch.
References
[1] Ilya Loshchilov and Frank Hutter, 2017, “Stochastic Gradient Descent with Warm Restarts,” ICLR. https://arxiv.org/pdf/1608.03983.pdf
Methods
- load_state_dict(state_dict): Loads the scheduler’s state.
- state_dict(): Returns the state of the scheduler as a dict.
- get_lr
- step
class skorch.callbacks.CyclicLR(optimizer, base_lr=0.001, max_lr=0.006, step_size_up=2000, step_size_down=None, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', last_batch_idx=-1, step_size=None)
Sets the learning rate of each parameter group according to the cyclical learning rate (CLR) policy. The policy cycles the learning rate between two boundaries with a constant frequency, as detailed in the paper. The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis.
The cyclical learning rate policy changes the learning rate after every batch. batch_step should be called after a batch has been used for training. To resume training, save last_batch_idx and use it to instantiate CyclicLR.
This class has three built-in policies, as put forth in the paper:
- “triangular”: A basic triangular cycle with no amplitude scaling.
- “triangular2”: A basic triangular cycle that scales the initial amplitude by half each cycle.
- “exp_range”: A cycle that scales the initial amplitude by gamma**(cycle iterations) at each cycle iteration.
This implementation was adapted from the GitHub repository bckenstler/CLR.
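The three policies can be expressed with the closed-form triangular schedule used in the bckenstler/CLR repository. The sketch below assumes symmetric cycles (step_size_down equal to step_size_up) and returns the learning rate for a given training iteration; `clr_lr` is a hypothetical helper, not a CyclicLR method.

```python
import math


def clr_lr(iteration, base_lr=1e-3, max_lr=6e-3, step_size=2000,
           mode="triangular", gamma=1.0):
    """Cyclical learning rate at a given training iteration (symmetric cycles)."""
    cycle = math.floor(1 + iteration / (2 * step_size))
    # x is 1 at the cycle boundaries and 0 at the peak of the triangle.
    x = abs(iteration / step_size - 2 * cycle + 1)
    if mode == "triangular":
        scale = 1.0
    elif mode == "triangular2":
        scale = 1.0 / (2 ** (cycle - 1))  # halve the amplitude every cycle
    elif mode == "exp_range":
        scale = gamma ** iteration  # shrink the amplitude every iteration
    else:
        raise ValueError(f"unknown mode: {mode}")
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x) * scale
```

With the defaults, the learning rate rises from base_lr to max_lr over the first 2000 iterations and falls back to base_lr at iteration 4000; in "triangular2" mode the second peak only reaches half the amplitude.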
Parameters:
- optimizer : torch.optim.Optimizer instance
Optimizer algorithm.
- base_lr : float or list of float (default=1e-3)
Initial learning rate, which is the lower boundary in the cycle, for all param groups (float) or for each group (list).
- max_lr : float or list of float (default=6e-3)
Upper boundary in the cycle for all param groups (float) or for each group (list). Functionally, it defines the cycle amplitude (max_lr - base_lr). The lr at any cycle is the sum of base_lr and some scaling of the amplitude; therefore, max_lr may not actually be reached, depending on the scaling function.
- step_size_up : int (default=2000)
Number of training iterations in the increasing half of a cycle.
- step_size_down : int (default=None)
Number of training iterations in the decreasing half of a cycle. If step_size_down is None, it is set to step_size_up.
- mode : str (default=’triangular’)
One of {triangular, triangular2, exp_range}. Values correspond to policies detailed above. If scale_fn is not None, this argument is ignored.
- gamma : float (default=1.0)
Constant in ‘exp_range’ scaling function: gamma**(cycle iterations)
- scale_fn : function (default=None)
Custom scaling policy defined by a single-argument lambda function, where 0 <= scale_fn(x) <= 1 for all x >= 0. If set, the mode parameter is ignored.
- scale_mode : str (default=’cycle’)
One of {‘cycle’, ‘iterations’}. Defines whether scale_fn is evaluated on cycle number or cycle iterations (training iterations since start of cycle).
- last_batch_idx : int (default=-1)
The index of the last batch.
References
[1] Leslie N. Smith, 2017, “Cyclical Learning Rates for Training Neural Networks,” ICLR. https://arxiv.org/abs/1506.01186
Examples
>>> import torch
>>> from skorch.callbacks import CyclicLR
>>> # model and train_batch(...) stand for your own module and training step
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = CyclicLR(optimizer)
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
...     for batch in data_loader:
...         scheduler.batch_step()
...         train_batch(...)
Methods
- batch_step([batch_idx]): Updates the learning rate for the batch index batch_idx.
- get_lr(): Calculates the learning rate at batch index self.last_batch_idx.
- step([epoch]): Not used by CyclicLR; use batch_step instead.
class skorch.callbacks.GradientNormClipping(gradient_clip_value=None, gradient_clip_norm_type=2)
Clips the gradient norm of a module’s parameters.
The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
See torch.nn.utils.clip_grad_norm_() for more information.
Parameters:
- gradient_clip_value : float (default=None)
If not None, clip the norm of all model parameter gradients to this value. The type of the norm is determined by the gradient_clip_norm_type parameter and defaults to L2.
- gradient_clip_norm_type : float (default=2)
Norm to use when gradient clipping is active. The default is to use the L2 norm. Can be ‘inf’ for the infinity norm.
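Treating all gradients as one concatenated vector means a single scale factor is applied to every parameter, which preserves the direction of the overall gradient. The pure-Python sketch below mirrors that behavior on lists of floats; the small 1e-6 term follows the convention of torch.nn.utils.clip_grad_norm_ but should be read as an assumption, and `clip_grad_norm` is a hypothetical helper.

```python
import math


def clip_grad_norm(grads, max_norm, norm_type=2.0):
    """Clip a list of gradient vectors in place by their combined norm; return that norm."""
    flat = [g for grad in grads for g in grad]  # "concatenate" all gradients
    if norm_type == math.inf or norm_type == "inf":
        total_norm = max(abs(g) for g in flat)
    else:
        total_norm = sum(abs(g) ** norm_type for g in flat) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        # One shared factor for all parameters, applied in place.
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters" with one gradient entry each
total = clip_grad_norm(grads, max_norm=1.0)
```

Here the combined L2 norm is 5, so both entries are scaled by roughly 1/5 and the clipped vector has norm 1.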
Methods
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, **kwargs): Called at the end of each batch.
- on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
- on_epoch_end(net, **kwargs): Called at the end of each epoch.
- on_grad_computed(_, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, **kwargs): Called at the beginning of training.
- on_train_end(net, **kwargs): Called at the end of training.
- get_params
- set_params
class skorch.callbacks.BatchScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)
Callback that performs generic scoring on batches.
This callback determines the score after each batch and stores it in the net’s history in the column given by name. At the end of the epoch, the average of the batch scores is computed and also stored in the history. Furthermore, the callback determines whether this average score is the best score yet, and that information is also stored in the history.
In contrast to EpochScoring, this callback determines the score for each batch and then averages the scores at the end of the epoch. This can be disadvantageous for some scores if the batch size is small, e.g. the area under the ROC curve will be incorrect in this case. Therefore, it is recommended to use EpochScoring unless you really need the scores for each batch.
If y is None, the scoring function with signature (model, X, y) must be able to handle X as a Tensor and y=None.
Parameters:
- scoring : None, str, or callable
If None, use the score method of the model. If str, it should be a valid sklearn metric (e.g. “f1_score”, “accuracy_score”). If a callable, it should have the signature (model, X, y) and return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.
- lower_is_better : bool (default=True)
Whether lower (e.g. log loss) or higher (e.g. accuracy) scores are better.
- on_train : bool (default=False)
Whether this should be called during training or validation.
- name : str or None (default=None)
If not an explicit string, tries to infer the name from the scoring argument.
- target_extractor : callable (default=to_numpy)
This is called on y before it is passed to scoring.
- use_caching : bool (default=True)
Re-use the model’s prediction for computing the loss to calculate the score. Turning this off will result in an additional inference step for each batch.
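To see why averaging per-batch scores can mislead for ranking metrics such as ROC AUC, the sketch below compares the batch-size-weighted average of per-batch AUCs with the AUC computed once over all predictions of the epoch. Weighting by batch size is an assumption about how the average is formed, and `auc` is a hypothetical pairwise (Mann-Whitney) implementation used to avoid dependencies.

```python
def auc(y_true, y_score):
    """Pairwise ROC AUC: fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined when a batch contains only one class")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Two small batches of (labels, predicted scores).
batches = [([0, 1], [0.8, 0.9]), ([0, 1], [0.1, 0.2])]

# BatchScoring-style: score each batch, then average weighted by batch size.
batch_scores = [auc(y, s) for y, s in batches]
sizes = [len(y) for y, _ in batches]
batch_avg = sum(sc * n for sc, n in zip(batch_scores, sizes)) / sum(sizes)

# EpochScoring-style: one score over all predictions of the epoch.
all_y = [t for y, _ in batches for t in y]
all_s = [v for _, s in batches for v in s]
epoch_score = auc(all_y, all_s)
```

Each batch is ranked perfectly on its own, so the batch average is 1.0, while pooling the predictions reveals that a negative sample (0.8) outranks a positive one (0.2), giving 0.75.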
Methods
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, X, y, training, **kwargs): Called at the end of each batch.
- on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
- on_epoch_end(net, **kwargs): Called at the end of each epoch.
- on_grad_computed(net, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, X, y, **kwargs): Called at the beginning of training.
- on_train_end(net, **kwargs): Called at the end of training.
- get_avg_score
- get_params
- set_params
class skorch.callbacks.EpochScoring(scoring, lower_is_better=True, on_train=False, name=None, target_extractor=<function to_numpy>, use_caching=True)
Callback that performs generic scoring on predictions.
At the end of each epoch, this callback makes a prediction on train or validation data, determines the score for that prediction and whether it is the best yet, and stores the result in the net’s history.
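The "best yet" bookkeeping described above can be sketched with the history modelled as a list of dicts, one per epoch, where the score is stored under its name and a '<name>_best' flag is set alongside it. The helper `record_score` and this history representation are assumptions for illustration; skorch's History object supports richer indexing such as net.history[-1, 'batches', ...].

```python
def record_score(history, name, score, lower_is_better=True):
    """Store `score` in the last epoch row and flag whether it is the best so far."""
    previous = [row[name] for row in history[:-1] if name in row]
    if not previous:
        is_best = True  # the first score is trivially the best
    elif lower_is_better:
        is_best = score < min(previous)
    else:
        is_best = score > max(previous)
    history[-1][name] = score
    history[-1][name + "_best"] = is_best
    return is_best


history = []
for epoch, acc in enumerate([0.70, 0.80, 0.75], start=1):
    history.append({"epoch": epoch})
    # Accuracy: higher is better, so lower_is_better=False.
    record_score(history, "valid_acc", acc, lower_is_better=False)
```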
If you already computed a score value for each batch, you can skip the score computation step by returning the value from the history. For example:
>>> import numpy as np
>>> def my_score(net, X=None, y=None):
...     losses = net.history[-1, 'batches', :, 'my_score']
...     batch_sizes = net.history[-1, 'batches', :, 'valid_batch_size']
...     return np.average(losses, weights=batch_sizes)
>>> net = MyNet(callbacks=[
...     ('my_score', EpochScoring(my_score, name='my_score'))])
If you fit with a custom dataset, this callback should work as expected as long as use_caching=True, which enables the collection of y values from the dataset. If you decide to disable the caching of predictions and y values, you need to write your own scoring function that is able to deal with the dataset and returns a scalar, for example:
>>> import sklearn.metrics
>>> import torchvision
>>> def ds_accuracy(net, ds, y=None):
...     # assume ds yields (X, y), e.g. torchvision.datasets.MNIST
...     y_true = [y for _, y in ds]
...     y_pred = net.predict(ds)
...     return sklearn.metrics.accuracy_score(y_true, y_pred)
>>> net = MyNet(callbacks=[
...     EpochScoring(ds_accuracy, lower_is_better=False, use_caching=False)])
>>> ds = torchvision.datasets.MNIST(root=mnist_path)
>>> net.fit(ds)
Parameters:
- scoring : None, str, or callable (default=None)
If None, use the score method of the model. If str, it should be a valid sklearn scorer (e.g. “f1”, “accuracy”). If a callable, it should have the signature (model, X, y) and return a scalar. This works analogously to the scoring parameter in sklearn’s GridSearchCV et al.
- lower_is_better : bool (default=True)
Whether lower scores should be considered better or worse.
- on_train : bool (default=False)
Whether this should be called on training or validation data.
- name : str or None (default=None)
If not an explicit string, tries to infer the name from the scoring argument.
- target_extractor : callable (default=to_numpy)
This is called on y before it is passed to scoring.
- use_caching : bool (default=True)
Collect labels and predictions (y_true and y_pred) over the course of one epoch and use the cached values for computing the score. The cached values are shared between all EpochScoring instances. Disabling this will result in an additional inference step for each epoch and an inability to use arbitrary datasets as input (since we don’t know how to extract y_true from an arbitrary dataset).
Methods
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, y, y_pred, training, **kwargs): Called at the end of each batch.
- on_epoch_begin(net, dataset_train, …): Called at the beginning of each epoch.
- on_epoch_end(net, dataset_train, …): Called at the end of each epoch.
- on_grad_computed(net, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, X, y, **kwargs): Called at the beginning of training.
- on_train_end(*args, **kwargs): Called at the end of training.
- get_params
- set_params

initialize()
(Re-)Set the initial state of the callback. Use this, e.g., if the callback tracks some state that should be reset when the model is re-initialized.
This method should return self.
class skorch.callbacks.Checkpoint(target=None, monitor='valid_loss_best', f_params='params.pt', f_history=None, f_pickle=None, sink=<function noop>)
Save the model during training if the given metric improved.
By default, this callback works in conjunction with the validation scoring callback, since that callback creates a valid_loss_best value in the history, which Checkpoint uses to determine whether the current epoch is worth saving.
You can also specify your own metric to monitor or supply a callback that dynamically evaluates whether the model should be saved in this epoch.
Some or all of the following can be saved:
- model parameters (see the f_params parameter);
- training history (see the f_history parameter);
- entire model object (see the f_pickle parameter).
You can implement your own save protocol by subclassing Checkpoint and overriding save_model().
This callback writes a bool flag to the history column event_cp indicating whether a checkpoint was created or not.
Example:
>>> net = MyNet(callbacks=[Checkpoint()])
>>> net.fit(X, y)
Example using a custom monitor where models are saved only in epochs where both the validation and the train losses are best:
>>> monitor = lambda net: all(net.history[-1, (
...     'train_loss_best', 'valid_loss_best')])
>>> net = MyNet(callbacks=[Checkpoint(monitor=monitor)])
>>> net.fit(X, y)
Parameters: - target : deprecated
- monitor : str, function, None (default=’valid_loss_best’)
Value of the history to monitor, or a callback that determines whether this epoch should lead to a checkpoint. The callback takes the network instance as parameter.
If monitor is set to None, the callback will save the network at every epoch.
Note: If you supply a lambda expression as monitor, you cannot pickle the wrapper anymore, as lambdas cannot be pickled. You can mitigate this problem by using importable functions instead.
- f_params : file-like object, str, None (default=’params.pt’)
File path to the file or file-like object where the model parameters should be saved. Pass None to disable saving model parameters.
If the value is a string, you can also use format specifiers to, for example, indicate the current epoch. Accessible format values are net, last_epoch and last_batch. Example to include the last epoch number in the file name:
>>> cb = Checkpoint(f_params="params_{last_epoch[epoch]}.pt")
- f_history : file-like object, str, None (default=None)
File path to the file or file-like object where the model training history should be saved. Pass None to disable saving history.
Supports the same format specifiers as f_params.
- f_pickle : file-like object, str, None (default=None)
File path to the file or file-like object where the entire model object should be pickled. Pass None to disable pickling.
Supports the same format specifiers as f_params.
- sink : callable (default=noop)
The target that the information about created checkpoints is sent to. This can be a logger or the print function (to send to stdout). By default, the output is discarded.
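The format specifiers can be read as Python str.format fields with net, last_epoch and last_batch passed as keyword arguments, so {last_epoch[epoch]} looks up the 'epoch' key of the last history row. The sketch below assumes plain dicts stand in for the history rows; `resolve_filename` is a hypothetical helper, not part of Checkpoint.

```python
def resolve_filename(template, net, history):
    """Fill Checkpoint-style format fields from the last epoch and batch entries."""
    last_epoch = history[-1]
    batches = last_epoch.get("batches") or [{}]
    last_batch = batches[-1]
    return template.format(net=net, last_epoch=last_epoch, last_batch=last_batch)


# Hypothetical two-epoch history with per-batch entries.
history = [
    {"epoch": 1, "batches": [{"train_loss": 0.9}]},
    {"epoch": 2, "batches": [{"train_loss": 0.7}, {"train_loss": 0.6}]},
]
fname = resolve_filename("params_{last_epoch[epoch]}.pt", net=None, history=history)
```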
Methods
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, **kwargs): Called at the end of each batch.
- on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
- on_epoch_end(net, **kwargs): Called at the end of each epoch.
- on_grad_computed(net, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, **kwargs): Called at the beginning of training.
- on_train_end(net, **kwargs): Called at the end of training.
- save_model(net): Save the model.
- get_params
- set_params
class skorch.callbacks.EarlyStopping(monitor='valid_loss', patience=5, threshold=0.0001, threshold_mode='rel', lower_is_better=True, sink=<built-in function print>)
Callback for stopping training when scores don’t improve.
Stop training early if a specified monitor metric did not improve by at least threshold within patience epochs.
Parameters:
- monitor : str (default=’valid_loss’)
Value of the history to monitor to decide whether to stop training or not. The value is expected to be a float and is commonly provided by scoring callbacks such as skorch.callbacks.EpochScoring.
- lower_is_better : bool (default=True)
Whether lower scores should be considered better or worse.
- patience : int (default=5)
Number of epochs to wait for improvement of the monitor value until the training process is stopped.
- threshold : float (default=1e-4)
Ignore score improvements smaller than threshold.
- threshold_mode : str (default=’rel’)
One of ‘rel’, ‘abs’. Decides whether the threshold value is interpreted in absolute terms or as a fraction of the best score so far (relative).
- sink : callable (default=print)
The target that the information about early stopping is sent to. By default, the output is printed to stdout, but the sink could also be a logger or noop().
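The interplay of patience, threshold and threshold_mode can be sketched as a plain loop over per-epoch scores. This is an illustration under assumptions: in 'rel' mode the required margin is taken as threshold times the best score so far, training stops once the number of consecutive non-improving epochs reaches patience, and `stopping_epoch` is a hypothetical helper rather than the callback's internals.

```python
def stopping_epoch(scores, patience=5, threshold=1e-4, threshold_mode="rel",
                   lower_is_better=True):
    """Return the 1-based epoch at which training would stop, or None."""
    best = None
    misses = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None:
            improved = True  # the first score always sets the baseline
        else:
            # Improvements smaller than the margin are ignored.
            margin = threshold * abs(best) if threshold_mode == "rel" else threshold
            if lower_is_better:
                improved = score < best - margin
            else:
                improved = score > best + margin
        if improved:
            best, misses = score, 0
        else:
            misses += 1
            if misses >= patience:
                return epoch
    return None


# Validation loss plateaus after epoch 2; with patience=3, training stops at epoch 5.
stop = stopping_epoch([1.0, 0.9, 0.9, 0.9, 0.9], patience=3)
```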
Methods
- initialize(): (Re-)Set the initial state of the callback.
- on_batch_begin(net, **kwargs): Called at the beginning of each batch.
- on_batch_end(net, **kwargs): Called at the end of each batch.
- on_epoch_begin(net, **kwargs): Called at the beginning of each epoch.
- on_epoch_end(net, **kwargs): Called at the end of each epoch.
- on_grad_computed(net, named_parameters, **kwargs): Called once per batch after gradients have been computed but before an update step was performed.
- on_train_begin(net, **kwargs): Called at the beginning of training.
- on_train_end(net, **kwargs): Called at the end of training.
- get_params
- set_params