Simple L2/L1 Regularization in Torch 7
A few days ago, I was trying to improve the generalization ability of my neural networks. A common approach is to add a regularization term on the network parameters to the training loss, so that the space of possible solutions is constrained to plausible values. A popular choice for this term is a p-norm of the parameters, which is defined as:

\[ \|x\|_p = \left( \sum_{i=1}^{N} |x_i|^p \right)^{1/p} \]

where \(x\) belongs to an \(N\)-dimensional vector space and \(i\) indexes elements of \(x\). Popular choices for \(p\) are \(p=1\) and \(p=2\). With \(p=1\) we obtain the L1 norm, which is known to induce sparsity. With \(p=2\), the p-norm becomes the familiar Euclidean norm. When L1/L2 regularization is used properly, network parameters tend to stay small during training.
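As a quick sanity check of the definition, here is a minimal pure-Lua sketch that evaluates the p-norm for \(p=1\) and \(p=2\). It works on plain Lua tables rather than Torch tensors, and the helper name p_norm is hypothetical. In Torch itself, the tensor method norm(p) used later in this post computes the same quantity.

```lua
-- Hypothetical helper: p-norm of a plain Lua array of numbers,
-- (sum_i |x_i|^p)^(1/p), following the definition above.
local function p_norm(x, p)
  local sum = 0
  for i = 1, #x do
    sum = sum + math.abs(x[i]) ^ p
  end
  return sum ^ (1 / p)
end

local x = {3, -4}
print(p_norm(x, 1))  -- L1 norm: |3| + |-4| = 7
print(p_norm(x, 2))  -- L2 (Euclidean) norm: sqrt(9 + 16) = 5
```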
When I was trying to add L1/L2 penalization to my network, I was surprised to see that the stochastic gradient descent optimizer in the Torch nn package (nn.StochasticGradient, referred to as SGDC below) does not support regularization out of the box. Thankfully, you can easily add regularization using a callback.
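Concretely, adding both penalties to the training loss \(L(w)\), with weights \(\lambda_1\) and \(\lambda_2\) (l1_weight and l2_weight in the code of this post) and the squared Euclidean norm for the L2 term, gives the standard objective and (sub)gradient below. This is a sketch of the textbook form, taking \(\operatorname{sign}(0) = 0\) at the point where the L1 term is not differentiable:

```latex
\[
\tilde{L}(w) = L(w) + \lambda_1 \|w\|_1 + \lambda_2 \|w\|_2^2,
\qquad
\frac{\partial \tilde{L}}{\partial w_i}
  = \frac{\partial L}{\partial w_i} + \lambda_1 \operatorname{sign}(w_i) + 2 \lambda_2 w_i
\]
```

Because the two penalty terms do not depend on the training data, their contribution can be applied to the parameters as a separate step. The parameter update further down in this post takes such a step without a learning-rate factor, so constants like the learning rate and the factor 2 are effectively folded into l1_weight and l2_weight.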
Adding Regularization to SGDC
Torch’s implementation of SGDC is simple to follow. The relevant part of the optimizer is the following three lines of code:
currentError = currentError + criterion:forward(module:forward(input), target)
module:updateGradInput(input, criterion:updateGradInput(module.output, target))
module:accUpdateGradParameters(input, criterion.gradInput, currentLearningRate)
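These lines run once for every training example. The first line calculates the loss using the forward pass of the network (module) given the input and the current network parameters, and adds it to the running error. The second line backpropagates: criterion:updateGradInput computes the gradient of the loss with respect to the network output, and module:updateGradInput propagates it back through the network to obtain the gradients with respect to the inputs of each layer. The third line computes the gradient with respect to the network parameters and immediately updates the parameters, using currentLearningRate as the step size. In order to add regularization, we need to modify currentError to reflect the L1/L2 regularization penalty and also modify the update rule for the network parameters. This can be achieved using the following hook in SGDC: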
if self.hookIteration then
  self.hookIteration(self, iteration, currentError)
end
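If a hookIteration function is set on the trainer, SGDC calls it at the end of every iteration, where an iteration is one full pass over the training set and currentError is the average error over that pass. We can therefore define a suitable callback function that implements regularization and pass it to SGDC. One implementation can be the following: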
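local function callback(trainer, iteration, currentError)
  -- l1_weight and l2_weight are defined in an enclosing scope
  currentError = currentError + regularization_penalty(trainer.module, l1_weight, l2_weight)
  print('# iteration ' .. iteration .. ', regularized error = ' .. currentError)
  regularize_parameters(trainer.module, l1_weight, l2_weight)
end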
Here, trainer is an instance of SGDC. We can attach our callback by overriding the hookIteration field of the trainer:
trainer.hookIteration = callback
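Note that currentError is a plain Lua number, and Lua passes numbers by value, so the reassignment inside the callback only updates the callback's local copy; the error that the trainer itself reports is unchanged. To monitor the regularized loss, log it from inside the callback.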
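In Lua, numbers are passed to functions by value, so reassigning currentError inside a hook only rebinds the hook's own parameter. The tiny pure-Lua sketch below makes this concrete. runIteration and fakeTrainer are hypothetical stand-ins that only mimic how an optimizer hands its error to hookIteration; they are not Torch code.

```lua
-- Hypothetical stand-in for the trainer loop (NOT Torch's nn.StochasticGradient):
-- it passes its local error to hookIteration as a plain number.
local function runIteration(trainer, iteration, currentError)
  if trainer.hookIteration then
    trainer.hookIteration(trainer, iteration, currentError)
  end
  return currentError -- the error the trainer would go on to report
end

local seenInHook
local fakeTrainer = {}
fakeTrainer.hookIteration = function(trainer, iteration, currentError)
  -- rebinds only the hook's own parameter; the caller's variable is untouched
  currentError = currentError + 0.5
  seenInHook = currentError
end

local reported = runIteration(fakeTrainer, 1, 2.0)
-- seenInHook is 2.5, while reported is still 2.0
print(seenInHook, reported)
```

Tables and Torch tensors, by contrast, are passed by reference, which is why the parameter updates in the callback do reach the network.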
Implementing the regularization_penalty and regularize_parameters functions is easy. The first one is not strictly necessary, given that we know how to differentiate the L1/L2 norms analytically, but it is useful to implement it so that we can visualize the total loss during optimization. This can be achieved in the following manner:
function regularization_penalty(network, l1_weight, l2_weight)
  -- parameters is a table of parameter tensors (e.g. weights and biases)
  local parameters, _ = network:parameters()
  local penalty = 0
  for i = 1, #parameters do
    penalty = penalty + l1_weight * parameters[i]:norm(1) + l2_weight * parameters[i]:norm(2) ^ 2
  end
  return penalty
end
The only ambiguous line might be the loop over parameters. It does not iterate over individual scalar parameters: network:parameters() returns a table of parameter tensors, typically one weight tensor and one bias tensor for each layer that has parameters, so i indexes those tensors. Now, let’s move on to updating the network parameters:
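function regularize_parameters(network, l1_weight, l2_weight)
  local parameters, _ = network:parameters()
  for i = 1, #parameters do
    -- L1 part: clamp each weight to [-l1_weight, l1_weight]
    local update = torch.clamp(parameters[i], -l1_weight, l1_weight)
    -- L2 part: add l2_weight * w without modifying w in place
    update:add(l2_weight, parameters[i])
    -- w <- w - update, in place, so the network sees the change
    parameters[i]:csub(update)
  end
end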
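Here, parameters holds references to the network's parameter tensors, so updating them in place changes the state of the network. The only ambiguous part is the clamp of the parameters to the interval [-l1_weight, l1_weight]. By doing this we effectively reduce the step size when parameters are close to zero, thereby reducing oscillatory movement around the origin. Writing \(\lambda_1\) for l1_weight, the L1 part of the update gives \(w - \operatorname{clamp}(w, -\lambda_1, \lambda_1) = \operatorname{sign}(w)\max(|w| - \lambda_1, 0)\): weights whose magnitude exceeds l1_weight shrink by exactly l1_weight, while smaller weights are set exactly to zero instead of overshooting past it. This is the well-known soft-thresholding operator.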
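To see the update rule \(w \leftarrow w - (\operatorname{clamp}(w, -\lambda_1, \lambda_1) + \lambda_2 w)\) and the clamp trick on concrete numbers, here is a pure-Lua sketch that mirrors regularization_penalty and regularize_parameters on a plain array instead of Torch tensors. The names clamp, penalty and regularize are hypothetical helpers for illustration only. With l2_weight set to 0, weights whose magnitude is below l1_weight land exactly on zero, while larger weights shrink by l1_weight.

```lua
-- Pure-Lua sketch (no Torch): hypothetical helpers operating on a plain
-- array of numbers instead of on the network's parameter tensors.

-- Clamp v to the interval [lo, hi], like torch.clamp does element-wise.
local function clamp(v, lo, hi)
  return math.max(lo, math.min(hi, v))
end

-- Same penalty as regularization_penalty: l1 * ||w||_1 + l2 * ||w||_2^2
local function penalty(weights, l1_weight, l2_weight)
  local l1_sum, l2_sum = 0, 0
  for i = 1, #weights do
    l1_sum = l1_sum + math.abs(weights[i])
    l2_sum = l2_sum + weights[i] ^ 2
  end
  return l1_weight * l1_sum + l2_weight * l2_sum
end

-- Same rule as regularize_parameters: w <- w - (clamp(w, -l1, l1) + l2 * w),
-- applied in place to every element of the array.
local function regularize(weights, l1_weight, l2_weight)
  for i = 1, #weights do
    local w = weights[i]
    local update = clamp(w, -l1_weight, l1_weight) + l2_weight * w
    weights[i] = w - update
  end
end

-- Usage: L1 only (l2_weight = 0), so the step is pure soft-thresholding.
local weights = {0.5, -0.05, 0.02, -1.0}
local before = penalty(weights, 0.1, 0)
regularize(weights, 0.1, 0)
local after = penalty(weights, 0.1, 0)
-- weights is now approximately {0.4, 0, 0, -0.9}, and after < before
print(before, after)
```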