Proximal matrix factorization in pytorch
Constrained optimization with autograd
This is a best of 2018 post, of sorts. I decided to combine the two most important techniques I've worked on in the last year: proximal algorithms and pytorch. In particular, I have asked myself how hard it would be to implement matrix factorization with proximal constraints (which we have used for image analysis here and for the SCARLET method here) in pytorch. Turns out, it's pretty straightforward.
As I've said before, proximal algorithms provide an efficient and elegant approach to constrained optimization. (Side remark: I'm a great fan of constrained optimization because it allows me to impose structure onto a data analysis problem; this may be some intuition for how the data ought to behave, but in the physical sciences we may even have a law that needs to be obeyed.) As a recap, the basic algorithm is a one-liner called the Proximal Gradient Method (PGM) because it performs a gradient step of the function $f$ followed by the application of a proximal operator, which minimizes a penalty/regularization function $g$:
$$x \leftarrow \mathrm{prox}_{\lambda g}\left(x - \lambda \nabla f(x)\right)$$
Many proximal operators are analytic, chief among them those that constrain a solution to some submanifold, in which case $\mathrm{prox}_g$ is simply a projection onto that manifold. A fair selection of proximal operators is coded up in my proxmin package.
But note the two other elements of PGM: the gradient of $f$ and the step size $\lambda$. This is where pytorch will come in handy. I'm paraphrasing Dan Foreman-Mackey:
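To make the recap concrete, here is a minimal sketch of the method on a toy problem with an analytic gradient. The names `pgm`, `grad_f` and `prox_nonneg`, as well as the toy objective, are illustrative assumptions and not part of proxmin:

```python
import torch

def pgm(x, grad_f, prox, step, n_iter=100):
    """Proximal gradient method: gradient step on f, then proximal operator of g."""
    for _ in range(n_iter):
        x = prox(x - step * grad_f(x), step)
    return x

# toy problem (illustrative): minimize 0.5 * ||x - b||^2 subject to x >= 0
b = torch.tensor([1.0, -2.0, 3.0])
grad_f = lambda x: x - b                       # analytic gradient of f
prox_nonneg = lambda x, step: x.clamp(min=0)   # projection onto x >= 0

x = pgm(torch.zeros(3), grad_f, prox_nonneg, step=0.5)
```

The constrained solution keeps the positive entries of `b` and sets the negative one to zero.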
Even if you don't use neural networks, you will like automatic differentiation!
Like tensorflow, pytorch provides automatic differentiation (autograd). So you can write down an objective function $f$ in terms of some model parameters, and pytorch will automatically know how to compute the gradients of $f$ with respect to these parameters. There's a minor price to pay: all functions need to be expressed using tensorflow/pytorch ingredients, which look a lot like numpy, but plain numpy will not do.
Let's look at my favorite model these days, the simplest non-linear model imaginable: matrix factorization. It's the problem of factorizing a matrix $Y$ into two (simpler) matrices $A$ and $S$. The idea here is that the matrix $Y$ gets expressed by a finite number of components (the rows of $S$) with amplitudes stored in the columns of $A$. Typically one tries to reduce the number of those components and thus the dimensionality of the problem. Matrix factorization is incredibly useful and employed widely for problems like text topic analysis, audio speaker recognition, and hyperspectral imaging.
Under Gaussian noise in the data we have a standard quadratic objective function $f(A, S) = \|Y - AS\|_F^2$, which I would like to minimize by calculating its gradients with respect to the matrix factors. Sure, I could do this analytically (and I have), but bear with me, this will get interesting. The syntax of pytorch is pretty concise, so we only need code like this:
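As a small illustration of what autograd does, consider a toy objective whose gradient is known by hand. The function and values below are chosen only for illustration:

```python
import torch

# model parameters; requires_grad tells pytorch to track operations on them
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# objective written with torch operations only: f(a) = sum(a^2)
f = (a ** 2).sum()
f.backward()                  # autograd fills in a.grad

analytic = 2 * a.detach()     # the gradient computed by hand is 2a
```

After `backward()`, `a.grad` holds the same values as `analytic`.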
import torch
import torch.nn as nn

class NMF(nn.Module):
    def __init__(self, B, N, K):
        super().__init__()
        # nn.Parameter registers the factors and enables gradient tracking
        self.A = nn.Parameter(torch.rand(B, K))
        self.S = nn.Parameter(torch.rand(K, N))

    def forward(self):
        # the model is the matrix product of the two factors
        return torch.matmul(self.A, self.S)

# some data cube Y: B x N and we want to factor it into K components
B, N, K = 50, 100, 5  # example sizes, chosen only for illustration
Y = torch.rand(B, N)
nmf = NMF(B, N, K)
Y_ = nmf()
loss_fn = nn.MSELoss(reduction='sum')
loss = loss_fn(Y_, Y)
loss.backward()
All we did here is to create a class NMF with two parameters A and S (initialized at random) and told it that the model is the matrix product of these two matrices. We then defined a loss function (mean squared error of data Y and model prediction Y_, but we sum up all elements to get a proper norm) and called the magic backward() method. At this point, the gradients of loss with respect to the model parameters are computed and can be accessed as nmf.A.grad and nmf.S.grad.
The standard gradient method then proceeds like this:
# stepsizes holds one step size per parameter (here: A and S)
for param, stepsize in zip(nmf.parameters(), stepsizes):
    param.data = param.data - stepsize * param.grad
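As a sanity check, the autograd result can be compared with the analytic gradients of the summed squared error, which are $2(AS - Y)S^T$ with respect to $A$ and $2A^T(AS - Y)$ with respect to $S$. A minimal sketch, using plain tensors instead of the NMF class and sizes chosen only for illustration:

```python
import torch

torch.manual_seed(0)
B, N, K = 6, 8, 3                      # illustrative sizes
Y = torch.rand(B, N)
A = torch.rand(B, K, requires_grad=True)
S = torch.rand(K, N, requires_grad=True)

loss = ((A @ S - Y) ** 2).sum()        # same quantity as nn.MSELoss(reduction='sum')
loss.backward()

with torch.no_grad():
    R = A @ S - Y                      # residual
    grad_A = 2 * R @ S.T               # analytic gradient with respect to A
    grad_S = 2 * A.T @ R               # analytic gradient with respect to S
```

`A.grad` and `S.grad` agree with `grad_A` and `grad_S` up to floating point precision.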
Note that we have to access the data portion of the parameter because we're directly meddling with the contents here. For a fully functional implementation, we need to repeat it several times or until some form of convergence has been achieved, so the code will look like this:
for epoch in range(n_epoch):
    Y_ = nmf()
    loss = loss_fn(Y_, Y)
    nmf.zero_grad()  # need to clear the old gradients
    loss.backward()

    for param, stepsize in zip(nmf.parameters(), stepsizes):
        param.data = param.data - stepsize * param.grad
As promised, that wasn't hard.
But there's obviously no proximal trickery going on here. In fact, the model above is the plain, unconstrained version of the matrix factorization, which works, but isn't the one most people use. If we have some insight into the problem, we'd like to constrain the solution. The most common form is to require that all elements of $A$ and $S$ be non-negative. That makes a lot of sense in many applications, e.g. when the amplitudes in $A$ are considered positive contributions to some mixed signal; for instance, the rate of occurrence of a word or a topic in a document cannot be below zero. This is why the problem is often called non-negative matrix factorization (NMF), but that's just one very simple type of constraint. I therefore prefer the term constrained matrix factorization.
Anyway, if you look at the first equation for the PGM and the last piece of code, the only thing missing here is the proximal operator. For non-negativity, the proximal operator is the projection onto the non-negative numbers (duh!):
$$\mathrm{prox}_+(x) = \max(0, x)$$
which has a pytorch equivalent: prox_plus = torch.nn.Threshold(0,0). All I have to do for constraining the solution to be non-negative is to alter the last line of the loop above to this:
param.data = prox_plus(param.data - stepsize * param.grad)
Et voilà, matrix factorization with proximal constraint and gradients calculated by pytorch.
Because that was so easy, we should also add another constraint that makes the NMF much better behaved: normalizing one of the matrix factors. If that's not done, the optimizer will often just shuffle power from the $A$ to the $S$ matrix and back with no improvement in the loss function. How you normalize depends on your problem, but here we simply require that the columns of, say, $A$ sum up to one:
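To see what the projection does, here is a small check on a hand-picked vector. It assumes the operator is built with `torch.nn.Threshold(0, 0)`; `clamp(min=0)` is shown as an equivalent way to write the same projection:

```python
import torch
import torch.nn as nn

prox_plus = nn.Threshold(0, 0)         # keeps x where x > 0, sets 0 elsewhere

x = torch.tensor([-1.5, 0.0, 0.3, 2.0])
projected = prox_plus(x)               # negative entries are set to zero
clamped = x.clamp(min=0)               # equivalent projection
```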
def prox_unity(X):
    # divide every column by its sum (dim=0 sums over the row index)
    return X / X.sum(dim=0).unsqueeze(dim=0)
Note that since the proximal operations happen outside of the gradient calculation, pytorch doesn't need to know about them: they don't affect autograd. That means:
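A quick check of the normalization on a small hand-written matrix (values chosen only for illustration). Note that `dim=0` sums over the row index, so it is each column of the result that sums to one; `keepdim=True` is used here as an equivalent alternative to `unsqueeze`:

```python
import torch

def prox_unity(X):
    # divide every column by its sum (dim=0 sums over the row index)
    return X / X.sum(dim=0, keepdim=True)

A = torch.tensor([[1.0, 2.0],
                  [3.0, 6.0]])
A_unit = prox_unity(A)
column_sums = A_unit.sum(dim=0)        # one entry per column
```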
Any proximal operator is acceptable with autograd
We can use any operation we want. If you want a solution with small $\ell_1$ norm, use the soft thresholding operator proxmin.operators.prox_soft, or for one that maximizes entropy, use the operator proxmin.operators.prox_max_entropy, etc. The code now looks like this:
for epoch in range(n_epoch):
    Y_ = nmf()
    nmf.zero_grad()  # need to clear the old gradients
    loss = loss_fn(Y_, Y)
    loss.backward()

    for param, stepsize in zip(nmf.parameters(), stepsizes):
        param.data = prox_plus(param.data - stepsize * param.grad)
        if param is nmf.A:
            param.data = prox_unity(param.data)
        if param is nmf.S:
            # third argument of prox_max_entropy is the regularization strength
            param.data = proxmin.operators.prox_max_entropy(param.data, 0, 1)
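For illustration, a soft thresholding operator can also be written directly with torch operations. This is a sketch of the standard operator, not the proxmin implementation; the name `prox_soft_torch` and the argument `thresh` are hypothetical:

```python
import torch

def prox_soft_torch(X, thresh):
    # shrink every element towards zero by thresh;
    # elements with magnitude below thresh become zero
    return torch.sign(X) * torch.clamp(X.abs() - thresh, min=0)

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
shrunk = prox_soft_torch(x, 1.0)
```

Like the other operators, it acts on the parameter values after the gradient step, so autograd never sees it.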
PGM for pytorch
To make it even easier to use, I've created a proper, functional pytorch optimizer, poaching most of the logic of the SGD implementation. For some reason this class is called "Stochastic Gradient Descent", but there's nothing at all stochastic about it; it's simple gradient descent, so perfect for our purposes. It also has acceleration techniques with momentum and Nesterov's method, which will speed up the optimization.
from torch.optim.sgd import SGD
from torch.optim.optimizer import required

class PGM(SGD):
    def __init__(self, params, proxs, lr=required, momentum=0, dampening=0, nesterov=False):
        kwargs = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=0, nesterov=nesterov)
        super().__init__(params, **kwargs)

        if len(proxs) != len(self.param_groups):
            raise ValueError("Invalid length of argument proxs: {} instead of {}".format(len(proxs), len(self.param_groups)))

        for group, prox in zip(self.param_groups, list(proxs)):
            group.setdefault('prox', prox)

    def step(self, closure=None):
        # perform a gradient step
        # optionally with momentum or nesterov acceleration
        super().step(closure=closure)

        for group in self.param_groups:
            prox = group['prox']

            # apply the proximal operator to each parameter in a group
            for p in group['params']:
                p.data = prox(p.data)
The code above is also on gist.
Now you use the customary pytorch optimizer incantation (define an optimizer and loss, compute gradients, call optimizer.step) with proximal operations:
nmf = NMF(B, N, K)
# one proximal operator per parameter group; nmf.parameters() forms a single group
proxs = [prox_plus]
optimizer = PGM(nmf.parameters(), proxs, lr=0.01, momentum=0.5)

for epoch in range(n_epoch):
    Y_ = nmf()
    loss = loss_fn(Y_, Y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
For better control, you can explicitly order different parameters into groups and define per-group optimization settings. This is particularly useful for the stepsizes (learning rate in pytorch) and proximal operators:
param_list = [{'params': nmf.A, 'lr': 0.01},
              {'params': nmf.S, 'lr': 10}]
prox_list = [prox_plus, prox_max_entropy]
optimizer = PGM(param_list, prox_list, momentum=0.5)
Next steps
This works just fine and can be swapped in for the canonical pytorch optimizers (especially SGD) to get constrained optimization. But, as an outlook, you can see the one piece I'm not happy with: we still have to set the step sizes / learning rates explicitly. In the NMF model, we actually know the step sizes analytically, but often the model is so complicated that finding a good learning rate is the main challenge. There's extensive literature on the topic, but even more importantly, there are several optimizers in pytorch (e.g. Adam, RMSprop and Adagrad) that adjust the learning rates. However, the way they do that is not compatible with proximal updates. At least not directly. So, I'm looking for a way to tune the learning rates for PGM...
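For the NMF case, one common analytic choice (an assumption here, not code from this post) is the inverse of the Lipschitz constant of the gradient. For the summed squared error, that constant is $2\|SS^T\|_2$ for the update of $A$ and $2\|A^TA\|_2$ for the update of $S$, with $\|\cdot\|_2$ the spectral norm. A minimal sketch; the helper name `nmf_stepsizes` is hypothetical:

```python
import torch

def nmf_stepsizes(A, S):
    # spectral norm (largest singular value) of S S^T and A^T A
    L_A = 2 * torch.linalg.matrix_norm(S @ S.T, ord=2)
    L_S = 2 * torch.linalg.matrix_norm(A.T @ A, ord=2)
    # step sizes are the inverse Lipschitz constants
    return 1 / L_A, 1 / L_S

torch.manual_seed(0)
A, S = torch.rand(6, 3), torch.rand(3, 8)   # illustrative sizes
step_A, step_S = nmf_stepsizes(A, S)
```

Since the constants depend on the current values of the factors, the step sizes would have to be recomputed whenever $A$ or $S$ changes.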