Scarlet2 – Thoughts for a major redesign
Astronomical source modeling and separation, all new and shiny
scarlet works fine for deblending. It is, to my knowledge, the first method to realize the advantage of deblending in multiple imaging filters (in astronomy, at least). It's robust and fast, and currently the default deblender in the LSST pipeline. It performs really well on HSC data, our precursor data set before Rubin comes online.
But I have become dissatisfied. Scarlet1 (as I shall call it) is a pretty old code by now, the first commit being from May 8, 2017. Fred Moolekamp and I had been tinkering with the approach before it became scarlet. For reference, in the first working version we had to compute gradients ourselves, such were the times. We later switched to autograd, which served us fine, but it shows its age. In particular, it locks us in on CPUs, so we don't benefit from the performance gains on GPUs and TPUs, which we'll really want for the data volumes ahead.
It's not just the age. To make bulk processing robust, we designed classes that describe astronomical sources in a "best effort" way, geared towards the vast number of smallish, faintish galaxies and stars that dominate extragalactic surveys. Extending these source classes to other use cases (think e.g. strong lensing arcs, or dust lanes) is doable, but it's not easy to see how. We packaged up the most common options in distinct source classes, which works for bulk processing, but it made them convoluted and not really extendable. In short, we didn't have the right "language" to specify source models. I want that language to be flexible, modular, intuitive, and, ideally, concise.
Scarlet1 was designed for the traditional case of having one data source. But it's now quite common to have multiple instruments observing the same part of the sky, and the 2020s will bring much more overlapping data. We've been arguing for a while that there are substantial benefits in jointly analyzing these datasets at the pixel level. To support this, Remy Joseph developed a method in scarlet for fast resampling and convolution. It works, but it could really use the boost from going to GPUs. With this use case in mind, we also need to redesign the models in sky coordinates, not image coordinates, as we have done so far.
Scarlet1 supports fitting only by proximal gradient descent, which enforces constraints on the source models, such as monotonicity of the shape. (I wrote a note on proximal methods and what is so great about them.) We even developed a method for proximal optimization with the now ubiquitous Adam optimizer as its gradient update scheme. And while the constraints were necessary for stable and good fits, they did introduce some undesired behavior (e.g. radial "spikes" that are monotonic but not realistic). With Francois Lanusse, we replaced the constraints with a neural network that acts as a prior, steering the solutions towards more common galaxy shapes. Bolting tensorflow and numpy together was doable, but not fast. Having one coding framework for scarlet and neural nets is the way to go.
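For context, the basic proximal gradient update alternates a gradient step on the smooth (likelihood) part with a proximal operator that maps the result back onto the constraint set. A schematic sketch with the simplest possible constraint, non-negativity, not actual scarlet1 code:

import jax.numpy as jnp

def prox_nonnegative(x):
    # proximal operator of the indicator function of {x >= 0}: a projection
    return jnp.maximum(x, 0.0)

def pgm_step(x, grad_loss, step_size):
    # gradient step on the smooth loss, then project back onto the constraint set
    return prox_nonnegative(x - step_size * grad_loss(x))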
And finally, sampling. I didn't find a paper or any avenue for drawing samples of the posterior of a model with parameter constraints that are expressed by proximal operators. This is why scarlet1 doesn't allow sampling. But maybe, if neural network priors are as useful as I believe them to be, we don't need these constraints anymore. Gradients of the likelihood and the prior: combine!
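In a toy sketch of that idea, with a Gaussian likelihood and prior standing in for the real imaging model and a neural network prior:

import jax
import jax.numpy as jnp

def log_likelihood(x, data):
    return -0.5 * jnp.sum((data - x) ** 2)

def log_prior(x):
    return -0.5 * jnp.sum(x ** 2)

def log_posterior(x, data):
    # the posterior gradient is just the sum of two autodiff gradients
    return log_likelihood(x, data) + log_prior(x)

# this gradient can drive any gradient-based sampler (HMC, MALA, ...)
grad_log_posterior = jax.grad(log_posterior)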
In short, here are my requirements for a successful scarlet2 design. It needs to support
- automatic differentiation, neural networks, and hardware acceleration
- easy custom model specification
- ingesting data from multiple instruments
- fitting and sampling
I'll share my thoughts on the first two items below. More posts about this redesign process will come later.
1. Torch vs JAX
I like PyTorch. It's modular and intuitive (see item 2). One thing I really like is that model classes own their parameters, so you can combine those parameters directly in the forward pass. As an example, this is how to write a standard factorized source in torch:
import torch.nn as nn

class Source(nn.Module):
    def __init__(self, spectrum, morphology):
        super().__init__()
        # register both factors as trainable parameters of this module
        self.spectrum = nn.Parameter(spectrum)
        self.morphology = nn.Parameter(morphology)

    def forward(self):
        # outer product: (C,) x (H, W) -> (C, H, W) model cube
        return self.spectrum[:, None, None] * self.morphology[None, :, :]
That's how we think about generative models! It also has dynamic computation graphs, which is useful when the shape of the generated model needs to change, e.g. when we have to switch to a larger model type because we initially underestimated the size of an extended source.
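To make the pattern concrete, here is a minimal usage sketch of the Source module above (band and pixel counts are made up for illustration):

import torch

# hypothetical example: 3 bands, 11x11 pixel morphology
spectrum = torch.rand(3)
morphology = torch.rand(11, 11)

source = Source(spectrum, morphology)
model = source()   # calls forward(), returns a (3, 11, 11) model cube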
But torch has downsides. It has explicit device placement, which is often a pain. And it is dynamic, i.e. it spends time re-evaluating the graph even when it doesn't need to. And while it is very actively used in the ML community, the astro community is moving to JAX.
JAX is closer to numpy (which we've been using so far), and has automatic differentiation, just-in-time compilation, and vectorization. It translates natively, by virtue of the ingenious XLA, to GPUs and TPUs. But it comes at a price: its transformations only operate on pure functions. That makes generative modeling awkward:
import jax.numpy as jnp

def model(source):
    return source.spectrum[:, None, None] * source.morphology[None, :, :]
The model function is not part of source! Functional programming like this is less modular and harder to parse for any user who'd like to do custom source modeling.
This is why I was excited to find Equinox. It's a torch-like API implemented in JAX. It treats the entire model, including callable functions, as a JAX pytree. So, you can write something like this:
import equinox as eqx
import jax.numpy as jnp

class Source(eqx.Module):
    spectrum: jnp.ndarray
    morphology: jnp.ndarray

    def __call__(self):
        return self.spectrum[:, None, None] * self.morphology[None, :, :]
This is much more readable. It uses python's dataclass syntax, which makes it straightforward to declare and initialize the class members.
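Because an eqx.Module is itself a pytree, JAX transformations apply to it directly. A small usage sketch (the shapes and the loss function are made up for illustration):

import equinox as eqx
import jax.numpy as jnp

# hypothetical example: 3 bands, 11x11 pixel morphology
source = Source(spectrum=jnp.ones(3), morphology=jnp.ones((11, 11)))
model = source()   # (3, 11, 11) model cube

# the module is a pytree, so we can differentiate a loss
# with respect to all of its array leaves at once
loss = lambda src: jnp.sum(src() ** 2)
grads = eqx.filter_grad(loss)(source)   # another Source holding the gradients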
There will be kinks to iron out, including methods that are still missing in JAX or Equinox, but I'm willing to bet on this combo for the redesign.
2. Model specification
Languages evolve to effectively communicate. We need to communicate how to specify what a source model is (what it does is quite clear: it contributes photon counts to a region of the modeled sky). So, here's a scarlet1 setup from our quickstart guide:
import scarlet

source1 = scarlet.PointSource(model_frame, center1, observation)
source2 = scarlet.ExtendedSource(model_frame, center2, observation, K=2)
source3 = scarlet.ExtendedSource(model_frame, center3, observation, compact=True)
sources = [source1, source2, source3]
blend = scarlet.Blend(sources, observation)
That's not too bad, but some things are strange. Why do all sources need to know the model_frame, which is our description of the piece of the sky blend will represent? Well, because a source needs to know in what coordinates e.g. center1 is specified. Also, why do sources need to have observation (i.e. the observational data) specified? It's nominally only relevant for the fitting procedure. But we're using it to initialize all parameters of these sources. Without the user knowing anything about that, or even what those parameters are. Hmmm.
Here's a more verbose, but clearly more modular and transparent specification for scarlet2 (not for the same example as above):
import jax.numpy as jnp
import scarlet2 as sct
import distrax

with sct.Scene(frame=model_frame) as scene:
    spectrum = sct.Spectrum(jnp.zeros(scene.C), constraint=sct.PositiveConstraint())
    center = jnp.array([12., 13.])
    sigma = jnp.array([0.1, 0.1])
    center = sct.Parameter(center, prior=distrax.MultivariateNormalDiag(center, sigma))
    nu = sct.Parameter(1.5, prior=distrax.Uniform(low, high))
    morph = sct.SpergelProfileMorphology(nu=nu)
    source = sct.Source(center, spectrum, morph)
This API exposes all user-dependent choices to the user. All parameters are initialized explicitly, and it's clear how they are constrained or attached to a prior. For instance, it's clear that center follows a Normal distribution with a width of 0.1. source accepts what all sources do: a center, a spectrum, and a morphology. The user decides what these are. It has a steeper learning curve, but we can still package up some defaults like we used to.
Also note the convenience of the context manager invoked by with (see here for an intro to context managers). This is inspired by how PyMC adds random variables to their probability models. It allows us to refer to the model_frame implicitly when initializing the sources. In detail, Scene.__enter__ creates a global store, let's call it Scenery, which holds the current scene and thus all the info any source model or parameter could need during its __init__. Scene.__exit__ then safely deletes this global store, so that other scenes can be created in the same process (just not at the same time). We can also automatically add a source to the scene, simply by having each source add itself to scene as the last step of its initialization. And don't worry, we can add more sources later:
with scene:
    source2 = sct.Source...
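For illustration, here is a minimal sketch of how this mechanism could look under the hood; the class and attribute names (Scenery in particular) are placeholders, not the final scarlet2 implementation:

class Scenery:
    # global store for the scene that is currently open
    scene = None

class Scene:
    def __init__(self, frame):
        self.frame = frame
        self.sources = []

    def __enter__(self):
        # make this scene discoverable to all models created inside the with-block
        Scenery.scene = self
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # safely remove the store so another scene can be opened later
        Scenery.scene = None

class Source:
    def __init__(self, center, spectrum, morphology):
        self.center = center
        self.spectrum = spectrum
        self.morphology = morphology
        # last step of initialization: register with the currently open scene
        Scenery.scene.sources.append(self)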
Another detail: we're using a different mechanism for imposing constraints, namely parameter transformation: there's an unconstrained parameter that is being optimized or sampled, but the generator only ever sees its transformed form (e.g. an exp or softplus transform for non-negative scalars). This keeps the problem entirely in the realm of smooth optimization / sampling (requirement 4 above) and thus easy to work with, e.g. with optax or blackjax.
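A sketch of that idea (the function names are illustrative, not scarlet2's actual Parameter machinery):

import jax
import jax.numpy as jnp

# softplus maps any real number to a positive one; its inverse initializes
# the unconstrained parameter from a constrained value
def softplus(x):
    return jnp.logaddexp(x, 0.0)

def inv_softplus(y):
    return y + jnp.log(-jnp.expm1(-y))

# the optimizer or sampler only ever sees the unconstrained value ...
unconstrained = inv_softplus(jnp.array(0.3))

# ... while the generative model sees the transformed, non-negative value
def model(unconstrained):
    flux = softplus(unconstrained)
    return flux   # stand-in for the actual source model

# gradients stay smooth and unconstrained, ready for optax or blackjax
grad = jax.grad(model)(unconstrained)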
Enough for now. If you want to contribute to the development of scarlet2, please head over to the github repo and peruse the discussion forum there.