[`scarlet`](https://pmelchior.github.io/scarlet/) works fine for deblending. It is, to my knowledge, the first method (in astronomy, at least) to exploit the advantage of deblending in multiple imaging filters. It's robust and fast, and currently the default deblender in the LSST pipeline. It performs really well on HSC data, our precursor data set before Rubin comes online.
But I have become dissatisfied. Scarlet1 (as I shall call it) is a pretty old code by now, the first commit being from May 8, 2017. Fred Moolekamp and I had been tinkering with the approach before it became `scarlet`. For reference, in the first working version we had to compute gradients _ourselves_, such were the times. We later switched to [`autograd`](https://github.com/HIPS/autograd), which served us fine, but it shows its age. In particular, it locks us in on CPUs, so we don't benefit from the performance gains of GPUs and TPUs, which we'll really want for the data volumes ahead.
It's not just the age. To make bulk processing robust, we designed classes to describe astronomical sources in a "best effort" way, geared towards the vast number of smallish, faintish galaxies and stars that dominate extragalactic surveys. Extending these source classes to other use cases (think e.g. strong lensing arcs, or dust lanes) is doable, but it's not easy to see how. We packaged up the most common options in distinct source classes, which does work for bulk processing, but it made them convoluted and not really extendable. In short, we didn't have the right "language" to specify source models. I want it to be flexible, modular, intuitive, and, ideally, concise.
Scarlet1 was designed for the traditional case of having one data source. But it's now quite common to have multiple instruments observing the same part of the sky, and the 2020s will bring much more overlapping data. We've been [arguing](http://ui.adsabs.harvard.edu/abs/2017ApJS..233...21R) [for](http://ui.adsabs.harvard.edu/abs/2019BAAS...51c.418E) [a](http://ui.adsabs.harvard.edu/abs/2019BAAS...51c.201R) [while](http://ui.adsabs.harvard.edu/abs/2019BAAS...51c..44C) that there are [substantial benefits](http://ui.adsabs.harvard.edu/abs/2020arXiv200810663C) in _jointly_ analyzing these datasets [at the pixel level](https://arxiv.org/abs/2201.03862). To support this, Remy Joseph [developed a method in scarlet](http://ui.adsabs.harvard.edu/abs/2021arXiv210706984J) for fast resampling and convolution. It works, but it could really use the boost from going to GPUs. With this use case in mind, we also need to redefine the models in _sky_ coordinates instead of the _image_ coordinates we have used so far.
Scarlet1 supports only one fitting method: proximal gradient descent, which can enforce constraints on the source models, such as monotonicity of the shape. (I wrote a [note](https://pmelchior.net/blog/the-magic-of-proximal-operators.html) on proximal methods and what is so great about them.) We even developed a [method](http://ui.adsabs.harvard.edu/abs/2019arXiv191010094M) for proximal optimization with the now ubiquitous Adam optimizer as its gradient update scheme. And while the constraints were necessary for stable and good fits, they did introduce some undesired behavior (e.g. radial "spikes" that _are_ monotonic but not realistic). Together with Francois Lanusse, we [replaced the constraints with a neural network](http://ui.adsabs.harvard.edu/abs/2019arXiv191203980L) that acts as a prior, steering the solutions towards more common galaxy shapes. Bolting tensorflow and numpy together was doable, but not fast. Having one coding framework for `scarlet` and neural nets is the way to go.
And finally, sampling. I couldn't find a paper, or any avenue, for drawing samples from the posterior of a model whose parameter constraints are expressed by proximal operators. This is why `scarlet1` doesn't allow sampling. But maybe, if neural network priors are as useful as I believe them to be, we don't need these constraints anymore. Gradients of the likelihood and the prior: combine!
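That is, after all, what a gradient-based sampler like HMC needs: the gradient of the log-posterior splits as $\nabla_\theta \log p(\theta \mid x) = \nabla_\theta \log p(x \mid \theta) + \nabla_\theta \log p(\theta)$ (the evidence term drops out), and automatic differentiation delivers both terms once likelihood and prior are differentiable.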
In short, here are my requirements for a successful `scarlet2` design. It needs to support
1. automatic differentiation, neural networks, and hardware acceleration
2. easy custom model specification
3. ingesting data from multiple instruments
4. fitting _and_ sampling
I'll share my thoughts on items 1 and 2 below. More posts about this redesign process will come later.
## 1. Torch vs JAX
I like PyTorch. It's modular and intuitive (see item 2). One thing that I really like is that source classes _have_ parameters, so you can combine those parameters to generate the model. As an example, this is how to write a standard factorized source in torch:
```python
import torch.nn as nn

class Source(nn.Module):
    def __init__(self, spectrum, morphology):
        super().__init__()
        # register as parameters so torch tracks and optimizes them
        self.spectrum = nn.Parameter(spectrum)
        self.morphology = nn.Parameter(morphology)

    def forward(self):
        # outer product of (C,) spectrum and (H, W) morphology -> (C, H, W) model
        return self.spectrum[:, None, None] * self.morphology[None, :, :]
```
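For illustration, a quick usage sketch (the shapes are made up): because `spectrum` and `morphology` are registered as parameters, any torch optimizer picks them up automatically.

```python
import torch

spectrum = torch.ones(3)          # C = 3 bands
morphology = torch.rand(10, 10)   # 10x10 pixel image
source = Source(spectrum, morphology)

model = source()                  # calls forward(), shape (3, 10, 10)
optimizer = torch.optim.Adam(source.parameters(), lr=1e-2)
```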
That's how we think about generative models! Torch also has dynamic computation graphs, which is useful when the shape of the generated model needs to change, e.g. when we need to switch to a larger model type because we initially underestimated the size of an extended source.
But torch has downsides. It has explicit device placement, which is often a pain. And it is dynamic, i.e. it spends time reevaluating the graph even when it doesn't need to. And while it is very actively used in the ML community, the astro community is moving to JAX.
JAX is closer to `numpy` (which we've been using so far), and has automatic differentiation, just-in-time compilation, and vectorization. It translates natively, by virtue of the ingenious XLA, to GPUs and TPUs. But this comes at a price: it only operates on pure functions. That makes generative modeling awkward:
```python
import jax.numpy as jnp

def model(source):
    # a pure function: the model generation lives outside of `source`
    return source.spectrum[:, None, None] * source.morphology[None, :, :]
```
The `model` function is not part of `source`! Code written this way is not very modular, and it's hard to parse for any user who'd like to do custom source modeling.
This is why I was excited to find [`Equinox`](https://docs.kidger.site/equinox/). It's a torch-like API implemented in JAX. It treats the entire model, including callable functions, as a JAX pytree. So, you can write something like this:
```python
import equinox as eqx
import jax.numpy as jnp

class Source(eqx.Module):
    spectrum: jnp.ndarray
    morphology: jnp.ndarray

    def __call__(self):
        return self.spectrum[:, None, None] * self.morphology[None, :, :]
```
This is much more readable. It uses python's `dataclass` syntax, which makes it straightforward to declare and initialize the class members.
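And because an `eqx.Module` is just a JAX pytree, the usual transformations apply to it directly. A minimal sketch (shapes and loss are made up for illustration):

```python
import equinox as eqx
import jax.numpy as jnp

source = Source(jnp.ones(3), jnp.ones((10, 10)))
data = jnp.zeros((3, 10, 10))

# gradients with respect to all array leaves of the module
loss_fn = lambda src: jnp.sum((src() - data) ** 2)
grads = eqx.filter_grad(loss_fn)(source)  # grads is again a Source pytree
```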
There will be kinks to iron out, including the lack of many methods in JAX or `Equinox`, but I'm willing to bet on this combo for the redesign.
## 2. Model specification
Languages evolve to effectively communicate. We need to communicate how to specify what a source model is (what it does is quite clear: it contributes photon counts to a region of the modeled sky). So, here's a `scarlet1` setup from our [quickstart guide](https://pmelchior.github.io/scarlet/0-quickstart.html):
```python
source1 = scarlet.PointSource(model_frame, center1, observation)
source2 = scarlet.ExtendedSource(model_frame, center2, observation, K=2)
source3 = scarlet.ExtendedSource(model_frame, center3, observation, compact=True)
sources = [source1, source2, source3]
blend = scarlet.Blend(sources, observation)
```
That's not too bad, but some things are strange. Why do all sources need to know the `model_frame`, which is our description of the piece of the sky `blend` will represent? Well, because a source needs to know in what coordinates e.g. `center1` is specified. Also, why do sources need to have `observation` (i.e. the observational data) specified? It's nominally only relevant for the fitting procedure. But we're using it to _initialize_ all parameters of these sources, without the user knowing anything about that, or even what those parameters are. Hmmm.
Here's a more verbose, but clearly more modular and transparent specification for `scarlet2` (not for the same example as above):
```python
import jax.numpy as jnp
import scarlet2 as sct
import distrax

with sct.Scene(frame=model_frame) as scene:
    spectrum1 = sct.Spectrum(jnp.zeros(scene.C), constraint=sct.PositiveConstraint())
    mean1 = jnp.array([12., 13.])
    sigma1 = jnp.array([0.1, 0.1])
    center1 = sct.Parameter(mean1, prior=distrax.MultivariateNormalDiag(mean1, sigma1))
    low, high = -0.85, 4.  # hypothetical bounds for the Spergel index
    nu = sct.Parameter(1.5, prior=distrax.Uniform(low, high))
    morph1 = sct.SpergelProfileMorphology(nu=nu)
    source1 = sct.Source(center1, spectrum1, morph1)
```
This API exposes all modeling choices to the user. All parameters are initialized explicitly, and it's clear how they are constrained or attached to a prior. For instance, it's clear that `center1` follows a Normal distribution with a width of 0.1. `source1` accepts what all sources do: a center, a spectrum, and a morphology. The user decides what these are. It's a steeper learning curve, but we can still package up some defaults like we used to.
Also note the convenience from the context manager that is invoked by `with` (see [here](https://book.pythontips.com/en/latest/context_managers.html) for an intro to context managers). This is inspired by how [PyMC adds random variables to their probability models](https://www.pymc.io/projects/examples/en/latest/howto/api_quickstart.html). It allows us to refer to the `model_frame` implicitly when initializing the sources. In detail, `Scene.__enter__` creates a *global* store, let's call it `Scenery`, that holds the current `scene` and thus all info any source model or parameter could need during its `__init__`. `Scene.__exit__` then safely deletes this global store, so that other scenes can be created in the same process (just not at the same time). We can also automatically add a source to the `scene`, simply by having each `source` add _itself_ to `scene` as the last step of its initialization (a minimal sketch of this mechanism follows after the next code block). And don't worry, we can add more sources later:
```python
with scene:
    source2 = sct.Source...
```
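To make this concrete, here's a minimal sketch of how such a context manager could work. The implementation details are illustrative, not the actual `scarlet2` internals:

```python
class Scenery:
    # global store for the currently active scene
    scene = None

class Scene:
    def __init__(self, frame):
        self.frame = frame
        self.sources = []

    def __enter__(self):
        Scenery.scene = self   # make this scene globally accessible
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        Scenery.scene = None   # safely delete the store again

class Source:
    def __init__(self, center, spectrum, morphology):
        self.center, self.spectrum, self.morphology = center, spectrum, morphology
        # last step of __init__: add this source to the active scene
        if Scenery.scene is not None:
            Scenery.scene.sources.append(self)
```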
Another detail: we're using a different mechanism for imposing constraints, namely parameter transformation: there's an unconstrained parameter that is being optimized or sampled, but the generator sees it only in transformed form, e.g. $\sigma = \exp(\eta)$ with the unconstrained parameter $\eta = \log\sigma$ for non-negative scalars. This keeps the problem entirely in the realm of smooth optimization / sampling (re requirement 4 above), and thus easy to work with, e.g. with [`optax`](https://github.com/deepmind/optax) or [`blackjax`](https://blackjax-devs.github.io/blackjax/).
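A minimal sketch of this idea in plain JAX (the function name is made up):

```python
import jax
import jax.numpy as jnp

# the optimizer/sampler only ever sees the unconstrained eta;
# the generative model sees sigma = exp(eta), which is always positive
def to_constrained(eta):
    return jnp.exp(eta)

eta = jnp.log(0.1)               # initialize from sigma = 0.1
sigma = to_constrained(eta)      # smooth and differentiable in eta
grad = jax.grad(to_constrained)(eta)
```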
Enough for now. If you want to contribute to the development of `scarlet2`, please head over to the [github repo](https://github.com/pmelchior/scarlet2) and peruse the [discussion forum](https://github.com/pmelchior/scarlet2/discussions) there.