310 lines
12 KiB
Python
310 lines
12 KiB
Python
# model.py: Plot class for spectral fits and models
|
|
#
|
|
# Authors: William Cleveland (USRA),
|
|
# Adam Goldstein (USRA) and
|
|
# Daniel Kocevski (NASA)
|
|
#
|
|
# Portions of the code are Copyright 2020 William Cleveland and
|
|
# Adam Goldstein, Universities Space Research Association
|
|
# All rights reserved.
|
|
#
|
|
# Written for the Fermi Gamma-ray Burst Monitor (Fermi-GBM)
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
#
|
|
|
|
from .gbmplot import GbmPlot, Histo, ModelData, Collection, ModelSamples
|
|
from .lib import *
|
|
import warnings
|
|
|
|
class ModelFit(GbmPlot):
    """Class for plotting spectral fits.

    Parameters:
        fitter (:class:`~gbm.spectra.fitting.SpectralFitter`, optional):
            The spectral fitter
        canvas (optional): Accepted for interface compatibility; it is not
            used by this class.
        view (str, optional): The plot view, one of 'counts', 'photon',
            'energy' or 'nufnu'. Default is 'counts'
        resid (bool, optional): If True, plots the residuals in counts view.
            Default is True.
        interactive (bool, optional): If True, call ``plt.ion()`` after the
            initial plot is made. Default is True.

    Attributes:
        ax (:class:`matplotlib.axes`): The matplotlib axes object for the plot
        canvas (Canvas Backend object): The plotting canvas, if set upon
            initialization.
        count_data (Collection of :class:`~.gbmplot.ModelData`):
            The count data plot elements
        count_models (Collection of :class:`~.gbmplot.Histo`):
            The count model plot elements
        fig (:class:`matplotlib.figure`): The matplotlib figure object
        model_spectrum (list of :class:`~.gbmplot.ModelSamples`):
            The model spectrum sample elements. None until a photon, energy
            or nufnu view has been plotted.
        residuals (Collection of :class:`~gbmplot.ModelData`):
            The fit residual plot elements
        view (str): The current plot view
        xlim (float, float): The plotting range of the x axis.
            This attribute can be set.
        xscale (str): The scale of the x axis, either 'linear' or 'log'.
            This attribute can be set.
        ylim (float, float): The plotting range of the y axis.
            This attribute can be set.
        yscale (str): The scale of the y axis, either 'linear' or 'log'.
            This attribute can be set.
    """
    # Define a list of default plotting colors to cycle through
    colors = '#7F3C8D,#11A579,#3969AC,#F2B701,#E73F74,#80BA5A,#E68310,#008695,#CF1C90,#f97b72,#4b4b8f,#A5AA99'.split(',')
    # Lower y-axis limit enforced in the model (photon/energy/nufnu) views
    # whenever the autoscaled lower limit falls below it.
    _min_y = 1e-10

    def __init__(self, fitter=None, canvas=None, view='counts', resid=True,
                 interactive=True):

        # Silence NumPy's VisibleDeprecationWarning. The class lives in
        # np.exceptions since NumPy 1.25 and the top-level alias was removed
        # in NumPy 2.0, so look it up defensively instead of crashing.
        vis_warning = getattr(getattr(np, 'exceptions', None),
                              'VisibleDeprecationWarning', None)
        if vis_warning is None:
            vis_warning = getattr(np, 'VisibleDeprecationWarning', None)
        if vis_warning is not None:
            warnings.filterwarnings("ignore", category=vis_warning)

        # main plot on top, residuals underneath (3:1 height), sharing x
        self._figure, axes = plt.subplots(2, 1, sharex=True, sharey=False,
                                          figsize=(5.7, 6.7), dpi=100,
                                          gridspec_kw={'height_ratios': [3, 1]})
        plt.subplots_adjust(hspace=0)
        self._ax = axes[0]
        self._resid_ax = axes[1]

        self._view = view
        self._fitter = None
        self._count_models = Collection()
        self._count_data = Collection()
        self._resids = Collection()
        self._model_spectrum = None

        # plot the fit if a fitter was given on init
        if fitter is not None:
            self.set_fit(fitter, resid=resid)

        if interactive:
            plt.ion()

    @property
    def view(self):
        """(str): The current plot view"""
        return self._view

    @property
    def count_models(self):
        """(Collection of :class:`~.gbmplot.Histo`): The count model plot
        elements"""
        return self._count_models

    @property
    def count_data(self):
        """(Collection of :class:`~.gbmplot.ModelData`): The count data plot
        elements"""
        return self._count_data

    @property
    def residuals(self):
        """(Collection of :class:`~.gbmplot.ModelData`): The fit residual
        plot elements"""
        return self._resids

    @property
    def model_spectrum(self):
        """(list of :class:`~.gbmplot.ModelSamples`): The model spectrum
        sample elements, or None if no model view has been plotted yet"""
        return self._model_spectrum

    def set_fit(self, fitter, resid=False):
        """Set the fitter. If a fitter already exists, this triggers a replot of
        the fit.

        Args:
            fitter (:class:`~gbm.spectra.fitting.SpectralFitter`):
                The spectral fitter for which a fit has been performed
            resid (bool, optional): If True, plot the fit residuals. Only
                used in the 'counts' view.
        """
        self._fitter = fitter

        if self._view == 'counts':
            self.count_spectrum()
            if resid:
                self.show_residuals()
            else:
                self.hide_residuals()

        elif self._view == 'photon':
            self.photon_spectrum()

        elif self._view == 'energy':
            self.energy_spectrum()

        elif self._view == 'nufnu':
            self.nufnu_spectrum()

        else:
            # unrecognized view: nothing is plotted
            pass

    def count_spectrum(self):
        """Plot the count spectrum fit: the model count spectrum of each
        dataset drawn as a histogram and the observed count data overlaid.
        """
        self._view = 'counts'
        self._ax.clear()

        model_counts = self._fitter.model_count_spectrum()
        energy, chanwidths, data_counts, data_counts_err, ulmasks = \
            self._fitter.data_count_spectrum()

        # one model + data pair per detector/dataset, each in its own color
        for i in range(self._fitter.num_sets):
            det = self._fitter.detectors[i]
            self._count_models.insert(det, Histo(model_counts[i], self._ax,
                                                 edges_to_zero=False,
                                                 color=self.colors[i],
                                                 alpha=1.0, label=det))
            self._count_data.insert(det, ModelData(energy[i], data_counts[i],
                                                   chanwidths[i],
                                                   data_counts_err[i],
                                                   self._ax, ulmask=ulmasks[i],
                                                   color=self.colors[i],
                                                   alpha=0.7, linewidth=0.9))

        self._ax.set_ylabel(r'Rate [count s$^{-1}$ keV$^{-1}$]')
        self._set_view()
        self._ax.legend()

    def photon_spectrum(self, **kwargs):
        """Plot the photon spectrum model

        Args:
            num_samples (int, optional): The number of sample spectra.
                Default is 100.
            plot_components (bool, optional): If True and the model has more
                than one component, also plot each component. Default is True.
        """
        self._view = 'photon'
        self._plot_spectral_model(**kwargs)
        self._ax.set_ylabel(r'Photon Flux [ph cm$^{-2}$ s$^{-1}$ keV$^{-1}$]', fontsize=PLOTFONTSIZE)

    def energy_spectrum(self, **kwargs):
        """Plot the energy spectrum model

        Args:
            num_samples (int, optional): The number of sample spectra.
                Default is 100.
            plot_components (bool, optional): If True and the model has more
                than one component, also plot each component. Default is True.
        """
        self._view = 'energy'
        self._plot_spectral_model(**kwargs)
        self._ax.set_ylabel(r'Energy Flux [ph cm$^{-2}$ s$^{-1}$]', fontsize=PLOTFONTSIZE)

    def nufnu_spectrum(self, **kwargs):
        """Plot the nuFnu spectrum model

        Args:
            num_samples (int, optional): The number of sample spectra.
                Default is 100.
            plot_components (bool, optional): If True and the model has more
                than one component, also plot each component. Default is True.
        """
        self._view = 'nufnu'
        self._plot_spectral_model(**kwargs)
        self._ax.set_ylabel(r'$\nu F_\nu$ [keV ph cm$^{-2}$ s$^{-1}$]', fontsize=PLOTFONTSIZE)

    def show_residuals(self, sigma=True):
        """Show the fit residuals

        Args:
            sigma (bool, optional): If True, plot the residuals in units of
                model sigma, otherwise in units of counts.
                Default is True.
        """
        # re-attach the residuals axis if it was removed by hide_residuals()
        if self._resid_ax not in self._figure.axes:
            self._figure.add_axes(self._resid_ax)

        # get the residuals
        energy, chanwidths, resid, resid_err = self._fitter.residuals(sigma=sigma)

        # plot for each detector/dataset, tracking the extent of the
        # residuals +/- their errors to set the y range below
        ymins, ymaxs = [], []
        for i in range(self._fitter.num_sets):
            det = self._fitter.detectors[i]
            self._resids.insert(det, ModelData(energy[i], resid[i],
                                               chanwidths[i], resid_err[i],
                                               self._resid_ax,
                                               color=self.colors[i],
                                               alpha=0.7, linewidth=0.9))
            # nan-aware so a single undefined residual does not poison the
            # axis limits
            ymins.append(np.nanmin(resid[i] - resid_err[i]))
            ymaxs.append(np.nanmax(resid[i] + resid_err[i]))

        # the zero line
        self._resid_ax.axhline(0.0, color='black')
        self._resid_ax.set_xlabel('Energy [keV]', fontsize=PLOTFONTSIZE)

        if sigma:
            self._resid_ax.set_ylabel('Residuals [sigma]', fontsize=PLOTFONTSIZE)
        else:
            self._resid_ax.set_ylabel('Residuals [counts]', fontsize=PLOTFONTSIZE)

        # we have to set the y-axis range manually, because the y-axis
        # autoscale is broken (known issue) in matplotlib for this situation.
        # Pad each limit outward by 10% of its magnitude.
        ymin = np.nanmin(ymins)
        ymax = np.nanmax(ymaxs)
        # matplotlib rejects NaN/Inf limits; leave the range alone then
        if np.isfinite(ymin) and np.isfinite(ymax):
            self._resid_ax.set_ylim((1.0 - np.sign(ymin) * 0.1) * ymin,
                                    (1.0 + np.sign(ymax) * 0.1) * ymax)

    def hide_residuals(self):
        """Hide the fit residuals by removing the residuals axis from the
        figure and restoring the x tick labels and label on the main axis.
        Prints a message if the residuals are already hidden.
        """
        if self._resid_ax not in self._figure.axes:
            print('Residuals already hidden')
            return

        self._figure.delaxes(self._resid_ax)
        # the main axis shares x with the residuals axis, so its tick labels
        # were suppressed; turn them back on now that it is the bottom axis
        self._ax.xaxis.set_tick_params(which='both', labelbottom=True)
        self._ax.set_xlabel('Energy [keV]', fontsize=PLOTFONTSIZE)

    def _set_view(self):
        """Set the view properties: x range from the fitter's energy range,
        log-log scaling and the energy axis label.
        """
        self._ax.set_xlim(self._fitter.energy_range)
        self._ax.yaxis.set_tick_params(labelsize=PLOTFONTSIZE)
        self._ax.set_xscale('log')
        self._ax.set_yscale('log')
        self._ax.set_xlabel('Energy [keV]', fontsize=PLOTFONTSIZE)

    def _plot_spectral_model(self, num_samples=100, plot_components=True):
        """Plot the spectral model by sampling from the Gaussian approximation
        to the parameters' posterior.

        The sampled elements are stored and exposed through
        :attr:`model_spectrum`.

        Args:
            num_samples (int, optional): The number of sample spectra.
                Default is 100.
            plot_components (bool, optional): If True and the model has more
                than one component, plot each component in addition to the
                total model. Default is True.
        """
        # clean plot and hide residuals if any.
        # NOTE: this installs a process-wide filter, not a scoped one.
        warnings.filterwarnings("ignore", category=UserWarning)
        self._ax.clear()
        self.hide_residuals()

        num_comp = self._fitter.num_components
        comps = self._fitter.function_components
        name = self._fitter.function_name

        # if the number of model components is > 1, plot each one
        if (num_comp > 1) and plot_components:
            energies, samples = self._fitter.sample_spectrum(which=self._view,
                                                             num_samples=num_samples,
                                                             components=True)
            # component i is samples[:, i, :]; colors[0] is kept for the total
            model_spectrum = [ModelSamples(energies, samples[:, i, :], self._ax,
                                           label=comps[i],
                                           color=self.colors[i + 1],
                                           alpha=0.1, lw=0.3)
                              for i in range(num_comp)]
            # collapse the component axis to get the total model
            samples = samples.sum(axis=1)
        else:
            # or just plot the function
            model_spectrum = []
            energies, samples = self._fitter.sample_spectrum(which=self._view,
                                                             num_samples=num_samples)

        y_max = samples.max(axis=(1, 0))
        model_spectrum.append(ModelSamples(energies, samples, self._ax,
                                           label=name, color=self.colors[0],
                                           alpha=0.1, lw=0.3))
        # Store under the name read by the model_spectrum property; the old
        # private name is kept as an alias for any code that relied on it.
        self._model_spectrum = model_spectrum
        self._spectrum_model = model_spectrum
        self._set_view()

        # fix the alphas for the legend. matplotlib 3.7 renamed
        # Legend.legendHandles to legend_handles and 3.9 removed the old name.
        legend = self._ax.legend()
        handles = getattr(legend, 'legend_handles', None)
        if handles is None:
            handles = legend.legendHandles
        for lh in handles:
            lh.set_alpha(1)
            lh.set_linewidth(1.0)

        # keep the log-scale y axis from extending down to ~0
        if self._ax.get_ylim()[0] < self._min_y:
            self._ax.set_ylim(self._min_y, 10.0 * y_max)
|